"""model.py — MutationPredictorCNN (fc1=304, region_importance_head.out=2).

Defines the CNN architecture, a loader that fetches the trained checkpoint
from the Hugging Face Hub, and a single-variant inference helper.

Flat input layout (1103 features per sample), as sliced by
``MutationPredictorCNN.forward`` and ``predict_with_importance``:

    [0, 1089)     11 channels x 99 positions, reshaped to (11, 99)
    [1089, 1101)  12 mutation features
    [1101, 1103)  2 region flags (index 1101 = exon, index 1102 = intron)
"""

import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from huggingface_hub import hf_hub_download
except ImportError:
    # Only load_model() needs the hub client; the architecture and the
    # inference helper stay importable without it.
    hf_hub_download = None

HF_MODEL_REPO = os.environ.get("MODEL_REPO", "nileshhanotia/mutation-predictor-models")
MODEL_FILENAME = "mutation_predictor_exon_intron.pth"

# Offsets into the flat input vector (see the module docstring).
_SEQ_CHANNELS = 11
_SEQ_LEN = 99
_SEQ_END = _SEQ_CHANNELS * _SEQ_LEN                  # 1089
_MUT_END = _SEQ_END + 12                             # 1101
_EXON_IDX = _MUT_END                                 # 1101
_INTRON_IDX = _MUT_END + 1                           # 1102
_INPUT_DIM = _MUT_END + 2                            # 1103
_POS_CHANNEL_START = (_SEQ_CHANNELS - 1) * _SEQ_LEN  # 990


def get_mutation_position_from_input(x):
    """Return, per sample, the argmax over the last sequence channel.

    Columns 990..1088 of the flat input are the 11th (last) of the 11
    length-99 channels. Presumably that channel is a one-hot marker of the
    mutated position — confirm against the encoder.

    Args:
        x: 2-D tensor (batch, features) with at least 1089 feature columns.

    Returns:
        1-D integer tensor of length ``batch`` with values in [0, 98].
    """
    return x[:, _POS_CHANNEL_START:_SEQ_END].argmax(dim=1)


class MutationPredictorCNN(nn.Module):
    """Three-layer 1-D CNN with auxiliary importance heads.

    The convolutional trunk runs on the (11, 99) sequence block. Its globally
    average-pooled output (256) is concatenated with an embedding of the
    12 mutation features (32) and an embedding of the 2 region flags (16),
    giving the 304 inputs of ``fc1``. Two sigmoid heads produce auxiliary
    scores that do not feed into the main logit.

    Submodule attribute names are part of the checkpoint format
    (``state_dict`` keys) and must not be renamed.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(11, 64, 7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, 5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, 3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.mut_fc = nn.Linear(12, 32)
        self.importance_head = nn.Linear(256, 1)
        self.region_importance_head = nn.Linear(256, 2)
        self.region_fc = nn.Linear(2, 16)
        self.fc1 = nn.Linear(304, 128)  # 256 pooled + 32 mutation + 16 region
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.4)

    def forward(self, x, mutation_positions=None):
        """Run the network on a batch of flat encoded inputs.

        Args:
            x: float tensor of shape (batch, 1103).
            mutation_positions: optional integer tensor with one position per
                sample. When omitted it is derived from ``x`` via
                ``get_mutation_position_from_input``. Values are clamped to
                [0, 98].

        Returns:
            In training mode: ``(logit, imp, c, ri)``.
            In eval mode: ``(logit, imp, ri)``.

            ``logit`` is (batch, 1) and is returned without a sigmoid.
            ``imp`` is (batch, 1), a sigmoid score computed from the trunk
            features at the mutation position. ``c`` is the (batch, 256, 99)
            trunk activation. ``ri`` is (batch, 2), a sigmoid score computed
            from the pooled trunk features.
        """
        batch = x.size(0)
        seq = x[:, :_SEQ_END].reshape(batch, _SEQ_CHANNELS, _SEQ_LEN)
        mut_feats = x[:, _SEQ_END:_MUT_END]
        region_flags = x[:, _MUT_END:_INPUT_DIM]

        h = self.relu(self.bn1(self.conv1(seq)))
        h = self.relu(self.bn2(self.conv2(h)))
        c = self.relu(self.bn3(self.conv3(h)))

        if mutation_positions is None:
            mutation_positions = get_mutation_position_from_input(x)
        # Move to the activation's device and use reshape (not view) so a
        # caller-supplied CPU or non-contiguous index tensor still works.
        pos = mutation_positions.to(c.device).clamp(0, _SEQ_LEN - 1).long()
        gather_idx = pos.reshape(batch, 1, 1).expand(batch, c.size(1), 1)
        # Pick the 256-channel feature column at each sample's position.
        mut_feat = c.gather(2, gather_idx).squeeze(2)
        imp = torch.sigmoid(self.importance_head(mut_feat))

        pooled = self.global_pool(c).squeeze(-1)
        ri = torch.sigmoid(self.region_importance_head(pooled))

        region_emb = F.relu(self.region_fc(region_flags))
        mut_emb = self.relu(self.mut_fc(mut_feats))
        fused = torch.cat([pooled, mut_emb, region_emb], dim=1)

        out = self.dropout(self.relu(self.fc1(fused)))
        out = self.dropout(self.relu(self.fc2(out)))
        logit = self.fc3(out)

        if self.training:
            return logit, imp, c, ri
        return logit, imp, ri


def load_model(repo_id=HF_MODEL_REPO, filename=MODEL_FILENAME, device=None):
    """Download the checkpoint from the Hugging Face Hub and build the model.

    Args:
        repo_id: Hub repository holding the checkpoint.
        filename: checkpoint file name inside the repository.
        device: target device; defaults to "cuda" when available, else "cpu".

    Returns:
        A ``MutationPredictorCNN`` on ``device`` in eval mode, with weights
        loaded from the checkpoint's ``"model_state_dict"`` entry.

    Raises:
        ImportError: if ``huggingface_hub`` is not installed.
        KeyError: if the checkpoint has no ``"model_state_dict"`` entry.
    """
    if hf_hub_download is None:
        raise ImportError(
            "huggingface_hub is required to download the model checkpoint"
        )
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    path = hf_hub_download(repo_id=repo_id, filename=filename)
    model = MutationPredictorCNN()
    # SECURITY: weights_only=False unpickles arbitrary Python objects, so the
    # checkpoint can execute code on load. Only point repo_id / MODEL_REPO at
    # trusted repositories. Switch to weights_only=True once the checkpoint
    # is confirmed to contain only tensors and plain containers.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])

    # Sanity checks on the architecture variant named in the module docstring.
    assert model.region_importance_head.out_features == 2
    assert model.fc1.in_features == 304
    return model.to(device).eval()


def predict_with_importance(model, encoder, ref_seq, mut_seq, chrom=None, pos=None,
                            exon_flag=None, intron_flag=None, device=None):
    """Score one variant and report importance diagnostics.

    Puts ``model`` into eval mode as a side effect. The model is run exactly
    once; a temporary forward hook on ``model.conv3`` captures the activation
    used for the per-position importance profile and is always removed, even
    if the forward pass raises.

    Args:
        model: a ``MutationPredictorCNN`` (or compatible module with a
            ``conv3`` submodule returning ``(logit, imp, ri)`` in eval mode).
        encoder: object providing ``encode_mutation(ref_seq, mut_seq,
            exon_flag=..., intron_flag=...)``, returning a 1-D tensor, and
            ``find_mutation_position(ref_seq, mut_seq)``.
        ref_seq: reference sequence string.
        mut_seq: mutated sequence string.
        chrom: accepted for interface compatibility; unused here.
        pos: accepted for interface compatibility; unused here.
        exon_flag: if not None, overwrites element 1101 of the encoding.
        intron_flag: if not None, overwrites element 1102 of the encoding.
        device: device for inference; defaults to the model's device.

    Returns:
        dict with the predicted label, pathogenic probability, mutation and
        region importance scores, mutation position, substitution string,
        region label, the five positions with the largest activation norm,
        and the full max-normalised ``pos_importance`` NumPy array.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()

    enc = encoder.encode_mutation(
        ref_seq, mut_seq, exon_flag=exon_flag, intron_flag=intron_flag
    )
    # Explicit flags override whatever the encoder wrote into the region slots.
    if exon_flag is not None:
        enc[_EXON_IDX] = float(exon_flag)
    if intron_flag is not None:
        enc[_INTRON_IDX] = float(intron_flag)

    captured = {}

    def _capture_conv3(_module, _inputs, output):
        # Raw conv3 output, i.e. before bn3 and the ReLU.
        captured["c"] = output.detach()

    hook = model.conv3.register_forward_hook(_capture_conv3)
    try:
        with torch.no_grad():
            x = enc.unsqueeze(0).to(device)
            logit, imp, ri = model(x)
    finally:
        hook.remove()

    prob = torch.sigmoid(logit).item()
    ri_np = ri.cpu().numpy().flatten()

    mp = encoder.find_mutation_position(ref_seq, mut_seq)
    # NOTE(review): a negative mp would index from the end of the sequence;
    # confirm the encoder never returns one.
    rb = ref_seq[mp].upper() if mp < len(ref_seq) else "?"
    mb = mut_seq[mp].upper() if mp < len(mut_seq) else "?"

    ef = int(enc[_EXON_IDX].item())
    iff = int(enc[_INTRON_IDX].item())
    region = "EXON" if ef else ("INTRON" if iff else "UNKNOWN")

    # L2 norm across channels gives one score per position, scaled to max 1.
    pi = captured["c"].squeeze(0).norm(dim=0).cpu().numpy()
    if pi.max() > 0:
        pi /= pi.max()

    return {
        "prediction": "Pathogenic" if prob >= 0.5 else "Benign",
        "pathogenic_probability": round(prob, 4),
        "mutation_importance": round(imp.item(), 4),
        "region_importance_exon": round(float(ri_np[0]), 4),
        "region_importance_intron": round(float(ri_np[1]), 4),
        "mutation_position": mp,
        "substitution": f"{rb}>{mb}",
        "region": region,
        "top5_active_positions": np.argsort(pi)[::-1][:5].tolist(),
        "pos_importance": pi,
    }