| """ | |
| model.py — MutationPredictorCNN (fc1=304, region_importance_head.out=2) | |
| """ | |
| import os, torch, torch.nn as nn, torch.nn.functional as F | |
| import numpy as np | |
| from huggingface_hub import hf_hub_download | |
# Hugging Face Hub repository that holds the trained checkpoint. Setting the
# MODEL_REPO environment variable overrides the default repository.
HF_MODEL_REPO = os.getenv("MODEL_REPO", "nileshhanotia/mutation-predictor-models")
# Checkpoint file name looked up inside HF_MODEL_REPO by load_model().
MODEL_FILENAME = "mutation_predictor_exon_intron.pth"
def get_mutation_position_from_input(x):
    """Recover a mutation position for every row of a flat feature batch.

    ``MutationPredictorCNN.forward`` reshapes columns 0:1089 of ``x`` into an
    (11, 99) block, so columns 990:1089 are channel 10 of that block. The
    position returned for each row is the index (0..98) of the largest value
    in that channel. Ties resolve to the first maximal index, as
    ``torch.argmax`` does, so an all-zero channel yields 0.

    Args:
        x: tensor of shape (B, >=1089).

    Returns:
        int64 tensor of shape (B,).
    """
    seq_len = 99
    marker_channel = x[:, 10 * seq_len:11 * seq_len]
    return torch.argmax(marker_channel, dim=1)
class MutationPredictorCNN(nn.Module):
    """1-D CNN that scores one mutation from a flat feature row.

    Input layout (fixed by the slicing in ``forward``):
      * columns 0:1089    - an 11 x 99 block, reshaped to (channels=11,
                            length=99) and passed through three
                            Conv1d -> BatchNorm1d -> ReLU stages;
      * columns 1089:1101 - 12 features fed to ``mut_fc`` (named a one-hot
                            of the mutation; TODO confirm against the encoder);
      * columns 1101:1103 - 2 region features fed to ``region_fc``.

    The classifier head (``fc1`` -> ``fc2`` -> ``fc3``) sees the concatenation
    of the globally pooled conv features (256), the mutation features (32) and
    the region features (16), hence ``fc1.in_features == 304``.

    Attribute names and construction order must stay as they are. The names
    are the ``state_dict`` keys of saved checkpoints. The order fixes the
    random-number draws used for default parameter initialisation.
    """

    def __init__(self):
        super().__init__()
        # Sequence encoder: 11 input channels over 99 positions.
        self.conv1 = nn.Conv1d(11, 64, 7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, 5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, 3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        # 12 mutation features -> 32.
        self.mut_fc = nn.Linear(12, 32)
        # Auxiliary heads: per-mutation importance (from the conv feature at
        # the mutation position) and a 2-way region importance (from pooled
        # features).
        self.importance_head = nn.Linear(256, 1)
        self.region_importance_head = nn.Linear(256, 2)
        # 2 region features -> 16.
        self.region_fc = nn.Linear(2, 16)
        # Classifier over [pooled(256) | mutation(32) | region(16)] = 304.
        self.fc1 = nn.Linear(304, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.4)

    def forward(self, x, mutation_positions=None):
        """Run the network on a batch of flat feature rows.

        Args:
            x: float tensor of shape (B, >=1103). Only the first 1103 columns
                are read.
            mutation_positions: optional index (or indices) along the
                99-long sequence axis at which ``importance_head`` is applied.
                - Accepts a tensor, a Python int, or a sequence of ints on any
                  device.
                - A single value is broadcast to the whole batch.
                - Values are clamped to [0, 98].
                - When omitted, positions come from
                  ``get_mutation_position_from_input(x)``.

        Returns:
            Training mode: ``(logit, imp, conv_features, region_importance)``.
            Eval mode: ``(logit, imp, region_importance)``.
            Shapes:
                - ``logit`` is (B, 1);
                - ``imp`` is (B, 1), sigmoid output;
                - ``conv_features`` is (B, 256, 99);
                - ``region_importance`` is (B, 2), sigmoid output.
        """
        batch = x.size(0)
        seq_2d = x[:, :1089].view(batch, 11, 99)
        mut_oh = x[:, 1089:1101]
        reg_f = x[:, 1101:1103]

        h = self.relu(self.bn1(self.conv1(seq_2d)))
        h = self.relu(self.bn2(self.conv2(h)))
        c = self.relu(self.bn3(self.conv3(h)))
        channels, length = c.size(1), c.size(2)

        if mutation_positions is None:
            mutation_positions = get_mutation_position_from_input(x)
        # ``gather`` needs an int64 index on the same device as ``c``.
        # as_tensor also lets callers pass plain ints/lists. It does not copy
        # a tensor that is already on the right device.
        pos = torch.as_tensor(mutation_positions, device=c.device)
        pos = pos.clamp(0, length - 1).long()
        # (B,) -> (B, 1, 1). A single value becomes (1, 1, 1), which expand()
        # broadcasts over the batch. The result is (B, C, 1) for gather.
        index = pos.reshape(-1, 1, 1).expand(batch, channels, 1)
        mf = c.gather(2, index).squeeze(2)
        imp = torch.sigmoid(self.importance_head(mf))

        p = self.global_pool(c).squeeze(-1)
        ri = torch.sigmoid(self.region_importance_head(p))
        rf = F.relu(self.region_fc(reg_f))
        m = self.relu(self.mut_fc(mut_oh))

        fused = torch.cat([p, m, rf], dim=1)
        out = self.dropout(self.relu(self.fc1(fused)))
        out = self.dropout(self.relu(self.fc2(out)))
        logit = self.fc3(out)
        if self.training:
            return logit, imp, c, ri
        return logit, imp, ri
def load_model(repo_id=HF_MODEL_REPO, filename=MODEL_FILENAME, device=None):
    """Download a checkpoint from the Hugging Face Hub and return a ready model.

    Args:
        repo_id: Hub repository to download from (defaults to
            ``HF_MODEL_REPO``).
        filename: checkpoint file inside the repository.
        device: torch device (string or object). Defaults to ``"cuda"`` when
            available, otherwise ``"cpu"``.

    Returns:
        A ``MutationPredictorCNN`` moved to ``device`` and set to eval mode.

    Raises:
        RuntimeError: if the checkpoint's weights do not match the
            architecture (raised by ``load_state_dict``), or the architecture
            no longer has the expected head sizes.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    path = hf_hub_download(repo_id=repo_id, filename=filename)

    # SECURITY NOTE(review): weights_only=False unpickles arbitrary Python
    # objects, i.e. loading runs whatever code the checkpoint embeds.
    # ``repo_id`` can come from the MODEL_REPO environment variable, so only
    # point it at trusted repositories. Consider weights_only=True once the
    # checkpoint is confirmed to hold only tensors/primitive containers.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    # Weights are expected under "model_state_dict"; a bare state dict saved
    # directly with torch.save(model.state_dict()) is accepted as well.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    model = MutationPredictorCNN()
    model.load_state_dict(state_dict)

    # Explicit checks instead of ``assert`` so they survive ``python -O``.
    if model.region_importance_head.out_features != 2:
        raise RuntimeError(
            "MutationPredictorCNN.region_importance_head must have 2 outputs, "
            f"got {model.region_importance_head.out_features}"
        )
    if model.fc1.in_features != 304:
        raise RuntimeError(
            "MutationPredictorCNN.fc1 must take 304 input features, "
            f"got {model.fc1.in_features}"
        )
    return model.to(device).eval()
def _base_at(seq, idx):
    """Return ``seq[idx]`` upper-cased, or ``"?"`` when ``idx`` is outside [0, len(seq))."""
    # Negative indices are rejected rather than counted from the end, so an
    # invalid position cannot silently report the sequence's last base.
    return seq[idx].upper() if 0 <= idx < len(seq) else "?"


def predict_with_importance(model, encoder, ref_seq, mut_seq,
                            chrom=None, pos=None,
                            exon_flag=None, intron_flag=None, device=None):
    """Score one reference/mutant sequence pair and report importance signals.

    Side effect: ``model`` is switched to eval mode and left that way.

    Args:
        model: a ``MutationPredictorCNN`` or compatible module. It must have a
            ``conv3`` submodule, and its eval-mode forward must return
            ``(logit, imp, ri)``.
        encoder: object providing
            ``encode_mutation(ref_seq, mut_seq, exon_flag=..., intron_flag=...)``
            and ``find_mutation_position(ref_seq, mut_seq)``.
            - The encoded vector is presumably a 1-D tensor with >= 1103
              entries; TODO confirm against the encoder.
            - The position is expected to be an int index into the sequences.
        ref_seq, mut_seq: reference and mutated sequences (indexable strings).
        chrom, pos: accepted for API compatibility; not used here.
        exon_flag, intron_flag: optional region flags. When given, they
            overwrite entries 1101 and 1102 of the encoded vector.
        device: device to run on. Defaults to the device of the model's first
            parameter.

    Returns:
        dict with:
            - the predicted label and pathogenic probability;
            - the mutation-importance score;
            - exon/intron region importances;
            - the mutation position and "ref>mut" substitution;
            - the region label ("EXON"/"INTRON"/"UNKNOWN");
            - the 5 positions with the highest conv3 activation norm;
            - the full normalised per-position activation norms (numpy array).
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()

    enc = encoder.encode_mutation(ref_seq, mut_seq,
                                  exon_flag=exon_flag, intron_flag=intron_flag)
    # Force the region columns to the caller's explicit flags.
    if exon_flag is not None:
        enc[1101] = float(exon_flag)
    if intron_flag is not None:
        enc[1102] = float(intron_flag)

    # Capture conv3's raw output during the same forward pass that produces
    # the prediction. In eval mode a second pass would give identical values.
    # The hook is removed even if the forward raises.
    act = {}

    def _capture(_module, _inputs, output):
        act["c"] = output.detach()

    hook = model.conv3.register_forward_hook(_capture)
    try:
        with torch.no_grad():
            x = enc.unsqueeze(0).to(device)
            logit, imp, ri = model(x)
    finally:
        hook.remove()

    prob = torch.sigmoid(logit).item()
    ri_np = ri.cpu().numpy().flatten()

    mp = encoder.find_mutation_position(ref_seq, mut_seq)
    rb = _base_at(ref_seq, mp)
    mb = _base_at(mut_seq, mp)

    ef = int(enc[1101].item())
    iff = int(enc[1102].item())
    region = "EXON" if ef else ("INTRON" if iff else "UNKNOWN")

    # Per-position L2 norm over conv3 channels, scaled so the maximum is 1.
    pi = act["c"].squeeze(0).norm(dim=0).cpu().numpy()
    if pi.max() > 0:
        pi /= pi.max()

    return {
        "prediction": "Pathogenic" if prob >= 0.5 else "Benign",
        "pathogenic_probability": round(prob, 4),
        "mutation_importance": round(imp.item(), 4),
        "region_importance_exon": round(float(ri_np[0]), 4),
        "region_importance_intron": round(float(ri_np[1]), 4),
        "mutation_position": mp,
        "substitution": f"{rb}>{mb}",
        "region": region,
        "top5_active_positions": np.argsort(pi)[::-1][:5].tolist(),
        "pos_importance": pi,
    }