Spaces:
Build error
Build error
| """ | |
| model_loader.py | |
| =============== | |
| Authoritative architecture definitions for all three models, | |
| matched exactly to the trained checkpoint shapes. | |
| splice β MutationPredictorCNN_v2 (input: 1106-dim flat vector) | |
| v4 β MutationPredictorCNN_v4 (input: seq/mut/region/splice tensors) | |
| classic β MutationPredictorClassic (input: 1103-dim flat vector, from classic repo) | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import os | |
| from typing import Optional | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from huggingface_hub import hf_hub_download | |
| logger = logging.getLogger("mutation_xai.loader") | |
# ─────────────────────────────────────────────────────────────────────────────
# Shared constants
# ─────────────────────────────────────────────────────────────────────────────
| NUCL = {"A": 0, "T": 1, "G": 2, "C": 3, "N": 4} | |
| MUT_TYPES = { | |
| ("A","T"):0, ("A","C"):1, ("A","G"):2, | |
| ("T","A"):3, ("T","C"):4, ("T","G"):5, | |
| ("C","A"):6, ("C","T"):7, ("C","G"):8, | |
| ("G","A"):9, ("G","T"):10,("G","C"):11, | |
| } | |
| ALL_BASES = ["A", "T", "C", "G"] | |
# ─────────────────────────────────────────────────────────────────────────────
# ① SPLICE MODEL — MutationPredictorCNN_v2
# ─────────────────────────────────────────────────────────────────────────────
| def _get_mutation_position_from_input(x_flat: torch.Tensor) -> torch.Tensor: | |
| """Infer mutation position from input tensor (sequence difference mask).""" | |
| return x_flat[:, 990:1089].argmax(dim=1) | |
class MutationPredictorCNN_v2(nn.Module):
    """Splice-aware CNN — exact architecture from mutation-predictor-splice.

    Consumes a flat 1106-dim vector:
        [0:1089]    sequence channels, viewed as (11, 99)
        [1089:1101] mutation-type one-hot (12)
        [1101:1103] region flags (2)
        [1103:1106] splice flags (3)
    Returns (logit, importance_score, region_importance, splice_importance).
    """

    def __init__(self, fc_region_out: int = 8, splice_fc_out: int = 16):
        super().__init__()
        # Fused feature width: pooled conv (256) + mut (32) + region + splice.
        fused_dim = 256 + 32 + fc_region_out + splice_fc_out
        # NOTE: module creation order is kept stable so checkpoint keys and
        # seeded initialisation line up with the trained weights.
        self.conv1 = nn.Conv1d(11, 64, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.mut_fc = nn.Linear(12, 32)
        self.importance_head = nn.Linear(256, 1)
        self.region_importance_head = nn.Linear(256, 2)
        self.fc_region = nn.Linear(2, fc_region_out)
        self.splice_fc = nn.Linear(3, splice_fc_out)
        self.splice_importance_head = nn.Linear(256, 3)
        self.fc1 = nn.Linear(fused_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.4)

    def forward(self, x: torch.Tensor,
                mutation_positions: Optional[torch.Tensor] = None):
        """Run the network; infers the mutated position when not supplied."""
        batch = x.size(0)
        # Split the flat vector into its feature groups.
        seq = x[:, :1089].view(batch, 11, 99)
        mut_onehot = x[:, 1089:1101]
        region = x[:, 1101:1103]
        splice = x[:, 1103:1106]

        feat = self.relu(self.bn1(self.conv1(seq)))
        feat = self.relu(self.bn2(self.conv2(feat)))
        feat = self.relu(self.bn3(self.conv3(feat)))

        if mutation_positions is None:
            mutation_positions = _get_mutation_position_from_input(x)
        # Features of the mutated column feed the per-site importance head.
        idx = mutation_positions.long().clamp(0, 98)
        site_feat = feat.gather(
            2, idx.view(batch, 1, 1).expand(batch, 256, 1)).squeeze(2)
        imp_score = torch.sigmoid(self.importance_head(site_feat))

        pooled = self.global_pool(feat).squeeze(-1)
        region_imp = torch.sigmoid(self.region_importance_head(pooled))
        splice_imp = torch.sigmoid(self.splice_importance_head(pooled))

        fused = torch.cat([
            pooled,
            self.relu(self.mut_fc(mut_onehot)),
            self.relu(self.fc_region(region)),
            self.relu(self.splice_fc(splice)),
        ], dim=1)
        hidden = self.dropout(self.relu(self.fc1(fused)))
        hidden = self.dropout(self.relu(self.fc2(hidden)))
        return self.fc3(hidden), imp_score, region_imp, splice_imp
# ─────────────────────────────────────────────────────────────────────────────
# ② V4 MODEL — MutationPredictorCNN_v4
# ─────────────────────────────────────────────────────────────────────────────
class MutationPredictorCNN_v4(nn.Module):
    """V4 model — takes separate (seq, mut, region, splice) tensor inputs.

    seq is expected as (batch, 11, 99); mut/region/splice are flat feature
    vectors of width 12 / 2 / 3.  Returns the raw logit, shape (batch, 1).
    """

    def __init__(self):
        super().__init__()
        # Creation order preserved for checkpoint-key and seeded-init parity.
        self.conv1 = nn.Conv1d(11, 64, 7, padding=3)
        self.conv2 = nn.Conv1d(64, 128, 5, padding=2)
        self.conv3 = nn.Conv1d(128, 256, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.mut_fc = nn.Linear(12, 32)
        self.region_fc = nn.Linear(2, 8)
        self.splice_fc = nn.Linear(3, 16)
        self.fc1 = nn.Linear(312, 128)  # 256 + 32 + 8 + 16
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

    def forward(self, seq: torch.Tensor, mut: torch.Tensor,
                region: torch.Tensor, splice: torch.Tensor):
        """Return the raw pathogenicity logit for each batch element."""
        act = self.relu
        conv = act(self.conv1(seq))
        conv = act(self.conv2(conv))
        conv = act(self.conv3(conv))
        pooled = self.pool(conv).squeeze(-1)
        fused = torch.cat(
            [pooled,
             act(self.mut_fc(mut)),
             act(self.region_fc(region)),
             act(self.splice_fc(splice))],
            dim=1)
        hidden = self.dropout(act(self.fc1(fused)))
        hidden = act(self.fc2(hidden))
        return self.fc3(hidden)
# ─────────────────────────────────────────────────────────────────────────────
# ③ CLASSIC MODEL — MutationPredictorClassic
#    Mirrors the architecture in the explainable-space repo's model.py
#    Input: flat vector of 1089 sequence dims (99×5 ref + 99×5 mut + 99 diff)
#    + 12 mut one-hot + 2 region = 1103; 3 splice flags optional → 1106
#    Outputs: logit, importance score (per-sequence scalar), region_imp (2,)
# ─────────────────────────────────────────────────────────────────────────────
class MutationPredictorClassic(nn.Module):
    """Classic explainable model from mutation-pathogenicity-predictor.

    Accepts a flat vector of width 1103 (no splice flags) or 1106 (with the
    3 splice flags appended); when the flags are absent, zeros are fed to
    the splice branch.  Returns (logit, importance, region_importance).
    """

    def __init__(self, input_dim: int = 1103):
        # NOTE: input_dim is accepted for signature compatibility but unused —
        # the architecture is fixed for 99×11 sequence inputs.
        super().__init__()
        # Sequence branch: 99 positions × 11 channels.
        self.conv1 = nn.Conv1d(11, 64, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.pool = nn.AdaptiveAvgPool1d(1)
        # Explainability heads operate on the pooled 256-dim feature.
        self.importance_head = nn.Linear(256, 1)
        self.region_importance_head = nn.Linear(256, 2)
        self.mut_fc = nn.Linear(12, 32)
        self.region_fc = nn.Linear(2, 8)
        self.splice_fc = nn.Linear(3, 16)
        self.fc1 = nn.Linear(312, 128)  # 256 + 32 + 8 + 16
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

    def forward(self, x: torch.Tensor):
        """
        x: (batch, 1103) or (batch, 1106)
            [0:1089]    sequence channels, viewed as (11, 99)
            [1089:1101] mutation one-hot (12)
            [1101:1103] region flags (2)
            [1103:1106] splice flags (3) — present only in the 1106-dim variant
        """
        batch = x.size(0)
        mut_onehot = x[:, 1089:1101]
        region_feat = x[:, 1101:1103]
        if x.size(1) >= 1106:
            splice_feat = x[:, 1103:1106]
        else:
            # 1103-dim inputs carry no splice flags; substitute zeros.
            splice_feat = torch.zeros(batch, 3, device=x.device)

        feat = self.relu(self.bn1(self.conv1(x[:, :1089].view(batch, 11, 99))))
        feat = self.relu(self.bn2(self.conv2(feat)))
        feat = self.relu(self.bn3(self.conv3(feat)))
        pooled = self.pool(feat).squeeze(-1)

        imp = torch.sigmoid(self.importance_head(pooled))
        region_imp = torch.sigmoid(self.region_importance_head(pooled))
        fused = torch.cat([
            pooled,
            self.relu(self.mut_fc(mut_onehot)),
            self.relu(self.region_fc(region_feat)),
            self.relu(self.splice_fc(splice_feat)),
        ], dim=1)
        out = self.dropout(self.relu(self.fc1(fused)))
        out = self.relu(self.fc2(out))
        return self.fc3(out), imp, region_imp
# ─────────────────────────────────────────────────────────────────────────────
# ENCODERS
# ─────────────────────────────────────────────────────────────────────────────
def _encode_seq_11ch(seq: str, n: int = 99) -> torch.Tensor:
    """Encode *seq* as an (n, 5) one-hot tensor, channel order A/T/G/C/N.

    The sequence is upper-cased, right-padded with "N" and truncated to n
    characters; any unrecognised character falls into the N channel.
    """
    padded = (seq.upper() + "N" * n)[:n]
    channel_idx = [NUCL.get(base, 4) for base in padded]
    onehot = torch.zeros(n, 5)
    onehot[torch.arange(n), channel_idx] = 1.0
    return onehot
def encode_for_v2(ref_seq: str, mut_seq: str,
                  exon_flag: int = 0, intron_flag: int = 0,
                  donor_flag: int = 0, acceptor_flag: int = 0,
                  region_flag: int = 0) -> torch.Tensor:
    """
    Build the 1106-dim input vector used by both splice and classic models.

    Layout (channel-major, matching the models' ``view(bs, 11, 99)``):
        [0:495]     ref one-hot — 5 channels × 99 positions
        [495:990]   mut one-hot — 5 channels × 99 positions
        [990:1089]  diff mask — 1 channel × 99 positions
        [1089:1101] mutation-type one-hot (12)
        [1101:1103] region flags [exon, intron]
        [1103:1106] splice flags [donor, acceptor, region]
    """
    n = 99
    ref_enc = _encode_seq_11ch(ref_seq, n)   # (99, 5)
    mut_enc = _encode_seq_11ch(mut_seq, n)   # (99, 5)
    diff = torch.zeros(n, 1)
    ref_base = mut_base = None
    for i in range(min(len(ref_seq), len(mut_seq), n)):
        if ref_seq[i].upper() != mut_seq[i].upper():
            diff[i, 0] = 1.0
            if ref_base is None:  # remember the first substitution only
                ref_base = ref_seq[i].upper()
                mut_base = mut_seq[i].upper()
    mut_onehot = torch.zeros(12)
    if ref_base and mut_base:
        idx = MUT_TYPES.get((ref_base, mut_base))
        if idx is not None:
            mut_onehot[idx] = 1.0
    # BUG FIX: flatten channel-major, not position-major.  The previous
    # ``torch.cat([...], dim=1).flatten()`` interleaved the 11 channels per
    # position, contradicting the layout documented above, the diff-mask
    # reader (x[:, 990:1089] in _get_mutation_position_from_input) and the
    # models' ``view(bs, 11, 99)``.  Transposing (99, 11) → (11, 99) first
    # produces the channel-contiguous 1089-dim block.
    seq_flat = torch.cat([ref_enc, mut_enc, diff], dim=1).t().reshape(-1)
    region = torch.tensor([float(exon_flag), float(intron_flag)])
    splice = torch.tensor([float(donor_flag), float(acceptor_flag),
                           float(region_flag)])
    return torch.cat([seq_flat, mut_onehot, region, splice])
def encode_for_v4(ref_seq: str, mut_seq: str,
                  exon_flag: int = 0, intron_flag: int = 0,
                  donor_flag: int = 0, acceptor_flag: int = 0,
                  region_flag: int = 0):
    """
    Return separate (seq, mut_oh, region, splice) tensors for
    MutationPredictorCNN_v4, each with a leading batch dim of 1.

    The flat v2 vector is built once and sliced apart; the sequence part
    comes back reshaped to (1, 11, 99).
    """
    flat = encode_for_v2(ref_seq, mut_seq, exon_flag, intron_flag,
                         donor_flag, acceptor_flag, region_flag)
    pieces = (
        flat[:1089].view(11, 99),   # sequence channels
        flat[1089:1101],            # mutation one-hot (12)
        flat[1101:1103],            # region flags (2)
        flat[1103:1106],            # splice flags (3)
    )
    return tuple(piece.unsqueeze(0) for piece in pieces)
def find_mutation_pos(ref_seq: str, mut_seq: str) -> int:
    """Return the 0-indexed position of the first differing base, or -1.

    Comparison is case-insensitive and limited to the first 99 positions
    of the shorter sequence.
    """
    for pos, (a, b) in enumerate(zip(ref_seq[:99], mut_seq[:99])):
        if a.upper() != b.upper():
            return pos
    return -1
# ─────────────────────────────────────────────────────────────────────────────
# MODEL REGISTRY — loads each model lazily, exactly once
# ─────────────────────────────────────────────────────────────────────────────
| SPLICE_REPO = "nileshhanotia/mutation-predictor-splice" | |
| V4_REPO = "nileshhanotia/mutation-predictor-v4" | |
| CLASSIC_REPO = "nileshhanotia/mutation-pathogenicity-predictor" | |
| SPLICE_FILENAME = "mutation_predictor_splice.pt" | |
| V4_FILENAME = "mutation_predictor_splice_v4.pt" | |
| CLASSIC_FILENAME = "mutation_predictor.pt" # common name; fallback tried | |
def _load_ckpt(repo: str, filename: str,
               token: Optional[str] = None) -> dict:
    """Download *filename* from *repo* and unpickle it on CPU.

    Returns whatever the file contains — either a bare state dict or a
    full checkpoint dict.
    """
    local_path = hf_hub_download(repo_id=repo, filename=filename, token=token)
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects;
    # acceptable only because these are trusted first-party checkpoint repos.
    return torch.load(local_path, map_location="cpu", weights_only=False)
| def _try_filenames(repo: str, candidates: list[str], | |
| token: Optional[str] = None) -> dict: | |
| for fn in candidates: | |
| try: | |
| return _load_ckpt(repo, fn, token) | |
| except Exception: | |
| continue | |
| raise FileNotFoundError( | |
| f"None of {candidates} found in repo {repo}") | |
class ModelRegistry:
    """Lazy singleton that loads each model exactly once.

    The accessors (splice / v4 / classic) download, build and cache their
    model on first call; val_accs() reports the checkpoint-recorded
    validation accuracies (0.0 until the corresponding model has loaded).
    """

    def __init__(self, hf_token: Optional[str] = None):
        self._token = hf_token
        self._splice = None
        self._v4 = None
        self._classic = None
        self._splice_val_acc = 0.0
        self._v4_val_acc = 0.0
        self._classic_val_acc = 0.0

    # ── individual loaders ────────────────────────────────────────────────
    def _load_splice(self):
        logger.info("Loading splice model …")
        checkpoint = _try_filenames(
            SPLICE_REPO,
            [SPLICE_FILENAME, "model.pt", "pytorch_model.pt"],
            self._token)
        state = checkpoint.get("model_state_dict", checkpoint)
        # Head widths are read off the checkpoint so shapes always agree.
        model = MutationPredictorCNN_v2(
            fc_region_out=state["fc_region.weight"].shape[0],
            splice_fc_out=state["splice_fc.weight"].shape[0])
        model.load_state_dict(state)
        model.eval()
        self._splice_val_acc = float(checkpoint.get("val_accuracy", 0))
        logger.info(f"Splice model ready (val_acc={self._splice_val_acc:.4f})")
        return model

    def _load_v4(self):
        logger.info("Loading v4 model …")
        checkpoint = _try_filenames(
            V4_REPO,
            [V4_FILENAME, "mutation_predictor_splice_v4.pt",
             "model.pt", "pytorch_model.pt"],
            self._token)
        state = checkpoint.get("model_state_dict", checkpoint)
        model = MutationPredictorCNN_v4()
        # strict=False so we survive minor shape drift between checkpoints.
        missing, unexpected = model.load_state_dict(state, strict=False)
        if missing:
            logger.warning(f"V4 missing keys: {missing[:6]}")
        model.eval()
        self._v4_val_acc = float(checkpoint.get("val_accuracy", 0))
        logger.info(f"V4 model ready (val_acc={self._v4_val_acc:.4f})")
        return model

    def _load_classic(self):
        logger.info("Loading classic model …")
        checkpoint = _try_filenames(
            CLASSIC_REPO,
            [CLASSIC_FILENAME, "model.pt",
             "mutation_predictor_classic.pt",
             "pytorch_model.pt"],
            self._token)
        state = checkpoint.get("model_state_dict", checkpoint)
        model = MutationPredictorClassic()
        missing, unexpected = model.load_state_dict(state, strict=False)
        if missing:
            logger.warning(f"Classic missing keys: {missing[:6]}")
        model.eval()
        self._classic_val_acc = float(checkpoint.get("val_accuracy", 0))
        logger.info(f"Classic model ready (val_acc={self._classic_val_acc:.4f})")
        return model

    # ── lazy accessors ────────────────────────────────────────────────────
    def splice(self) -> MutationPredictorCNN_v2:
        if self._splice is None:
            self._splice = self._load_splice()
        return self._splice

    def v4(self) -> MutationPredictorCNN_v4:
        if self._v4 is None:
            self._v4 = self._load_v4()
        return self._v4

    def classic(self) -> MutationPredictorClassic:
        if self._classic is None:
            self._classic = self._load_classic()
        return self._classic

    def val_accs(self) -> dict:
        """Checkpoint validation accuracies for every model slot."""
        return {
            "splice": self._splice_val_acc,
            "v4": self._v4_val_acc,
            "classic": self._classic_val_acc,
        }