MEDNA-DFM: A Dual-View FiLM-MoE Model for Explainable DNA Methylation Prediction
This repository contains the official model weights and inference code for the paper "MEDNA-DFM: A Dual-View FiLM-MoE Model for Explainable DNA Methylation Prediction" (He et al., 2026).
This repository hosts the model fine-tuned for N4-methylcytosine (4mC) prediction in S. cerevisiae (4mC_S.cerevisiae).
Paper Link: https://arxiv.org/abs/2602.22850
Because of architecture-specific dependencies, make sure transformers==4.18.0 is installed. Then run the following Python script, which downloads the model from the Hugging Face Hub and runs inference on sample sequences:
"""
Quick Start: MEDNA-DFM for DNA Methylation Prediction
This script demonstrates how to dynamically load the MEDNA-DFM model
from the Hugging Face Hub and perform inference on raw DNA sequences.
"""
import sys
import torch
import requests
from huggingface_hub import snapshot_download
# Environment Patch for Legacy Dependencies
def patch_legacy_requests():
    """
    Patches the requests session to ensure compatibility between legacy
    transformers (v4.18.0) and modern Hugging Face API endpoints.
    """
    original_request = requests.Session.request

    def patched_request(self, method, url, *args, **kwargs):
        if url.startswith('/api/'):
            # Route relative API paths through a stable mirror or hub URL
            url = 'https://hf-mirror.com' + url
        return original_request(self, method, url, *args, **kwargs)

    requests.Session.request = patched_request
patch_legacy_requests()
print("Fetching MEDNA-DFM from Hugging Face Hub...")
REPO_ID = "hy-0003/MEDNA-DFM_4mC_S.cerevisiae"
cloud_model_path = snapshot_download(repo_id=REPO_ID)
sys.path.insert(0, cloud_model_path)
from configuration_medna import MednaConfig
from modeling_medna import MEDNADFMForSequenceClassification
print("Initializing model weights...")
model = MEDNADFMForSequenceClassification.from_pretrained(cloud_model_path)
model.eval()
# Inference on sample DNA sequences.
# Inputs are typically sequences centered on the target modification site (e.g., C or A).
test_sequences = [
    "CGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGCGC",  # Sample 1: GC-rich motif
    "AAAAATTTTTAAAAATTTTTGAGGAAAAATTTTTAAAAATT",  # Sample 2: A-tract motif
]
with torch.no_grad():
    outputs = model(test_sequences)
    probabilities = torch.softmax(outputs.logits, dim=-1)

print(f"{'MEDNA-DFM Prediction Results':^50}")
for i, seq in enumerate(test_sequences):
    prob_0 = probabilities[i][0].item()
    prob_1 = probabilities[i][1].item()
    # Thresholding for binary classification
    prediction = "Methylated (1)" if prob_1 > 0.5 else "Unmethylated (0)"
    print(f"Sequence {i+1}:")
    print(f"  Snippet         : {seq[:10]} ... {seq[-10:]}")
    print(f"  P(Unmethylated) : {prob_0:.4f}")
    print(f"  P(Methylated)   : {prob_1:.4f}")
    print(f"  Prediction      : {prediction}")
    print("-" * 50)
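
The script above expects inputs that are already centered on the candidate modification site. When starting from a longer sequence, one way to prepare such inputs is to slide over it and cut a fixed-length window around each candidate base. The helper below is a minimal sketch of that idea, not part of the official inference code: the function name `extract_centered_windows`, the default target base `"C"`, and the 41-nt window length are assumptions chosen to resemble the sample inputs, and the model's actual preprocessing may differ.

```python
from typing import List, Tuple


def extract_centered_windows(
    sequence: str, target: str = "C", window: int = 41
) -> List[Tuple[int, str]]:
    """Return (position, window) pairs for every `target` base that has
    enough flanking sequence on both sides.

    Hypothetical helper: the default window length of 41 is an assumption,
    not a value confirmed by the MEDNA-DFM repository.
    """
    if window % 2 == 0:
        raise ValueError("window must be odd so the target sits at the center")

    sequence = sequence.upper()
    half = window // 2
    windows = []
    # Only positions with `half` bases available on each side are considered.
    for pos in range(half, len(sequence) - half):
        if sequence[pos] == target:
            windows.append((pos, sequence[pos - half : pos + half + 1]))
    return windows
```

The returned windows can then be passed to the model in place of `test_sequences`, for example `model([w for _, w in extract_centered_windows(long_sequence)])`, where `long_sequence` is a hypothetical user-supplied string. Candidate bases too close to either end are skipped rather than padded, since the padding convention used in training is not stated here.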