# Source: nileshhanotia — "Add model.py" (commit 947772a, verified)
"""
model.py — MutationPredictorCNN (fc1=304, region_importance_head.out=2)
"""
import os, torch, torch.nn as nn, torch.nn.functional as F
import numpy as np
from huggingface_hub import hf_hub_download
# Hub repository that hosts the checkpoint; the MODEL_REPO environment
# variable overrides the built-in default.
HF_MODEL_REPO = os.getenv("MODEL_REPO", "nileshhanotia/mutation-predictor-models")
# Name of the checkpoint file inside that repository.
MODEL_FILENAME = "mutation_predictor_exon_intron.pth"
def get_mutation_position_from_input(x):
    """Return, for each row of ``x``, the argmax index within columns 990..1088.

    That 99-wide window is the last of the eleven 99-column channels that
    ``MutationPredictorCNN.forward`` carves out of the first 1089 columns.
    Presumably it is a mutation-position mask — confirm against the encoder.

    Args:
        x: 2-D tensor with at least 1089 columns.

    Returns:
        1-D integer tensor (one entry per row) with values in ``0..98``.
    """
    position_window = x[:, 990:1089]
    return torch.argmax(position_window, dim=1)
class MutationPredictorCNN(nn.Module):
    """1-D CNN that scores a mutation from a flat feature vector.

    Column layout of each input row, as fixed by the slicing in ``forward``:

    * ``0..1088``    — sequence block, viewed as 11 channels x 99 positions
    * ``1089..1100`` — 12 mutation features, fed to ``mut_fc``
    * ``1101..1102`` — 2 region features, fed to ``region_fc``

    The fused vector given to ``fc1`` is 256 (pooled conv features)
    + 32 (mutation embedding) + 16 (region embedding) = 304 wide.
    """

    # Private layout constants; they only name the literals used for slicing.
    _N_CHANNELS = 11
    _SEQ_LEN = 99
    _SEQ_END = _N_CHANNELS * _SEQ_LEN   # 1089
    _MUT_END = _SEQ_END + 12            # 1101
    _REGION_END = _MUT_END + 2          # 1103

    def __init__(self):
        super().__init__()
        # Convolutional trunk; each padding keeps the sequence length unchanged.
        self.conv1 = nn.Conv1d(self._N_CHANNELS, 64, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        # Side branches and auxiliary heads.
        self.mut_fc = nn.Linear(12, 32)
        self.importance_head = nn.Linear(256, 1)
        self.region_importance_head = nn.Linear(256, 2)
        self.region_fc = nn.Linear(2, 16)
        # Classifier on the fused features.
        self.fc1 = nn.Linear(256 + 32 + 16, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.4)

    def forward(self, x, mutation_positions=None):
        """Run the network on a batch of flat feature rows.

        Args:
            x: ``(B, >=1103)`` tensor laid out as described on the class.
            mutation_positions: optional tensor with ``B`` position indices;
                when omitted they come from
                ``get_mutation_position_from_input(x)``. Values are clamped
                to ``0..98``.

        Returns:
            In training mode ``(logit, importance, conv_features,
            region_importance)``; otherwise ``(logit, importance,
            region_importance)``. ``logit`` is ``(B, 1)``, ``importance`` is
            ``(B, 1)`` in ``(0, 1)``, ``region_importance`` is ``(B, 2)`` in
            ``(0, 1)`` and ``conv_features`` is ``(B, 256, 99)``.
        """
        batch = x.size(0)
        sequence = x[:, : self._SEQ_END].view(batch, self._N_CHANNELS, self._SEQ_LEN)
        mutation_feats = x[:, self._SEQ_END : self._MUT_END]
        region_feats = x[:, self._MUT_END : self._REGION_END]

        feats = self.relu(self.bn1(self.conv1(sequence)))
        feats = self.relu(self.bn2(self.conv2(feats)))
        conv_features = self.relu(self.bn3(self.conv3(feats)))

        if mutation_positions is None:
            mutation_positions = get_mutation_position_from_input(x)
        index = mutation_positions.clamp(0, self._SEQ_LEN - 1).long()
        # Pick the 256-channel feature column at each sample's position.
        gather_index = index.view(batch, 1, 1).expand(batch, 256, 1)
        at_mutation = conv_features.gather(2, gather_index).squeeze(2)
        importance = torch.sigmoid(self.importance_head(at_mutation))

        pooled = self.global_pool(conv_features).squeeze(-1)
        region_importance = torch.sigmoid(self.region_importance_head(pooled))

        region_embedding = F.relu(self.region_fc(region_feats))
        mutation_embedding = self.relu(self.mut_fc(mutation_feats))

        hidden = torch.cat([pooled, mutation_embedding, region_embedding], dim=1)
        hidden = self.dropout(self.relu(self.fc1(hidden)))
        hidden = self.dropout(self.relu(self.fc2(hidden)))
        logit = self.fc3(hidden)

        if not self.training:
            return logit, importance, region_importance
        # Training additionally exposes the conv3 feature map.
        return logit, importance, conv_features, region_importance
def load_model(repo_id=HF_MODEL_REPO, filename=MODEL_FILENAME, device=None):
    """Download the checkpoint from the Hugging Face Hub and build the model.

    Args:
        repo_id: Hub repository that holds the checkpoint.
        filename: Checkpoint file name inside the repository.
        device: Target device; when ``None``, ``"cuda"`` is used if available,
            otherwise ``"cpu"``.

    Returns:
        A ``MutationPredictorCNN`` with the checkpoint weights loaded, moved
        to ``device`` and switched to eval mode.

    Raises:
        KeyError: the checkpoint has no ``"model_state_dict"`` entry.
        RuntimeError: raised by ``load_state_dict`` when keys or shapes differ.
        AssertionError: the constructed architecture does not have the
            expected head/classifier sizes.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
    model = MutationPredictorCNN()

    # SECURITY NOTE(review): weights_only=False performs a full unpickle of the
    # downloaded file, which can execute arbitrary code. Only point this at a
    # trusted repository; switch to weights_only=True once the checkpoint is
    # confirmed to contain nothing but tensors and plain containers.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])

    # Explicit checks instead of `assert` so they still run under `python -O`.
    # AssertionError is kept so existing callers see the same exception type.
    if model.region_importance_head.out_features != 2:
        raise AssertionError("region_importance_head must have 2 outputs")
    if model.fc1.in_features != 304:
        raise AssertionError("fc1 must take 304 input features")

    return model.to(device).eval()
def predict_with_importance(model, encoder, ref_seq, mut_seq,
                            chrom=None, pos=None,
                            exon_flag=None, intron_flag=None, device=None):
    """Score one mutation and report importance diagnostics.

    The model is run once under ``torch.no_grad()`` with a temporary forward
    hook on ``model.conv3``; the hook captures the raw conv3 output used for
    the per-position importance profile and is always removed afterwards.

    Args:
        model: Network exposing ``conv3``; in eval mode it must return
            ``(logit, importance, region_importance)``. It is switched to
            eval mode as a side effect.
        encoder: Object providing ``encode_mutation(ref_seq, mut_seq,
            exon_flag=..., intron_flag=...)`` (returns a 1-D tensor) and
            ``find_mutation_position(ref_seq, mut_seq)``.
        ref_seq: Reference sequence string.
        mut_seq: Mutated sequence string.
        chrom: Unused; accepted for call compatibility.
        pos: Unused; accepted for call compatibility.
        exon_flag: When not ``None``, written into encoding index 1101.
        intron_flag: When not ``None``, written into encoding index 1102.
        device: Device for the forward pass; defaults to the device of the
            model's first parameter.

    Returns:
        dict with keys ``prediction``, ``pathogenic_probability``,
        ``mutation_importance``, ``region_importance_exon``,
        ``region_importance_intron``, ``mutation_position``, ``substitution``,
        ``region``, ``top5_active_positions`` and ``pos_importance`` (a NumPy
        array of per-position conv3 channel norms, scaled so the maximum is 1
        when it is positive).
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()

    enc = encoder.encode_mutation(ref_seq, mut_seq,
                                  exon_flag=exon_flag, intron_flag=intron_flag)
    # Explicit flags take precedence over whatever the encoder wrote.
    if exon_flag is not None:
        enc[1101] = float(exon_flag)
    if intron_flag is not None:
        enc[1102] = float(intron_flag)

    captured = {}

    def _capture_conv3(_module, _inputs, output):
        captured["c"] = output.detach()

    # Register the hook before the only forward pass so the outputs and the
    # conv3 activations come from the same run; try/finally guarantees the
    # hook does not stay attached to the model if the forward pass raises.
    hook = model.conv3.register_forward_hook(_capture_conv3)
    try:
        with torch.no_grad():
            x = enc.unsqueeze(0).to(device)
            logit, imp, ri = model(x)
    finally:
        hook.remove()

    prob = torch.sigmoid(logit).item()
    ri_np = ri.cpu().numpy().flatten()

    mp = encoder.find_mutation_position(ref_seq, mut_seq)
    # A negative position would silently index from the end of the string;
    # treat it like any other out-of-range position.
    # NOTE(review): assumes a negative value means "not found" — confirm
    # against the encoder.
    rb = ref_seq[mp].upper() if 0 <= mp < len(ref_seq) else "?"
    mb = mut_seq[mp].upper() if 0 <= mp < len(mut_seq) else "?"

    ef = int(enc[1101].item())
    iff = int(enc[1102].item())
    # The exon flag wins when both flags are set.
    region = "EXON" if ef else ("INTRON" if iff else "UNKNOWN")

    # L2 norm over channels of the raw conv3 output, one value per position.
    pos_importance = captured["c"].squeeze(0).norm(dim=0).cpu().numpy()
    peak = pos_importance.max()
    if peak > 0:
        pos_importance /= peak

    return {
        "prediction": "Pathogenic" if prob >= 0.5 else "Benign",
        "pathogenic_probability": round(prob, 4),
        "mutation_importance": round(imp.item(), 4),
        "region_importance_exon": round(float(ri_np[0]), 4),
        "region_importance_intron": round(float(ri_np[1]), 4),
        "mutation_position": mp,
        "substitution": f"{rb}>{mb}",
        "region": region,
        "top5_active_positions": np.argsort(pos_importance)[::-1][:5].tolist(),
        "pos_importance": pos_importance,
    }