Spaces:
Runtime error
Runtime error
Create model_loader.py
Browse files- model_loader.py +413 -0
model_loader.py
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
model_loader.py
|
| 3 |
+
==============
|
| 4 |
+
Loads all three pretrained models using their EXACT native architectures
|
| 5 |
+
as confirmed from the live HuggingFace Space source code.
|
| 6 |
+
|
| 7 |
+
Models:
|
| 8 |
+
1. nileshhanotia/mutation-predictor-splice
|
| 9 |
+
β MutationPredictorCNN_v2 (input dim=1106, 99bp window)
|
| 10 |
+
β File: mutation_predictor_splice.pt
|
| 11 |
+
|
| 12 |
+
2. nileshhanotia/mutation-predictor-v4
|
| 13 |
+
β MutationPredictorCNN_v2 variant (inferred from same family)
|
| 14 |
+
β File: mutation_predictor_v4.pt (or pytorch_model.pth)
|
| 15 |
+
|
| 16 |
+
3. nileshhanotia/mutation-pathogenicity-predictor
|
| 17 |
+
β MutationPredictorCNN (classic, 99bp window)
|
| 18 |
+
β File: pytorch_model.pth
|
| 19 |
+
|
| 20 |
+
Architecture notes taken directly from live app source β nothing redesigned.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
import logging
|
| 25 |
+
import os
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
import torch.nn as nn
|
| 30 |
+
import numpy as np
|
| 31 |
+
|
| 32 |
+
# Module-level logger; handlers and levels are configured by the application.
logger = logging.getLogger(__name__)

# ── HuggingFace repo IDs ──────────────────────────────────────────────────────
REPO_SPLICE = "nileshhanotia/mutation-predictor-splice"          # v2 architecture
REPO_V4 = "nileshhanotia/mutation-predictor-v4"                  # v2 variant
REPO_CLASSIC = "nileshhanotia/mutation-pathogenicity-predictor"  # classic CNN
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 41 |
+
# Architecture 1 & 2 β MutationPredictorCNN_v2
|
| 42 |
+
# Source: mutation-predictor-splice-app/app.py (exact copy)
|
| 43 |
+
# Used by both splice model and v4 model
|
| 44 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 45 |
+
|
| 46 |
+
def get_mutation_position_from_input(x_flat):
    """Recover the mutation position from a batch of flat 1106-dim inputs.

    Returns the per-row argmax over dims 990:1089 of ``x_flat``.

    NOTE(review): this assumes the diff-mask occupies the contiguous slice
    990:1089 (channel-major layout). ``encode_for_v2`` in this module
    flattens position-major, where the diff bit for position i sits at
    index 11*i + 10 — confirm against the upstream app's encoder before
    relying on the inferred position; callers may also pass
    ``mutation_positions`` explicitly to forward() and bypass this.
    """
    return torch.argmax(x_flat[:, 990:1089], dim=1)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class MutationPredictorCNN_v2(nn.Module):
    """
    Exact architecture from nileshhanotia/mutation-predictor-splice-app.

    Expects a flat 1106-dim input per sample:
      [0:1089]    sequence block, viewed as (11 channels, 99 positions)
      [1089:1101] 12-dim mutation-type one-hot
      [1101:1103] 2 region features
      [1103:1106] 3 splice features

    ``fc_region_out`` and ``splice_fc_out`` are inferred from a checkpoint's
    state_dict shapes so the class auto-adapts to v4 vs splice checkpoints.
    """

    def __init__(self, fc_region_out: int = 8, splice_fc_out: int = 16):
        super().__init__()
        # Width of the fused vector fed to fc1: pooled conv features (256)
        # + mutation-type embedding (32) + region + splice embeddings.
        fc1_in = 256 + 32 + fc_region_out + splice_fc_out
        self.conv1 = nn.Conv1d(11, 64, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.mut_fc = nn.Linear(12, 32)
        self.importance_head = nn.Linear(256, 1)
        self.region_importance_head = nn.Linear(256, 2)
        self.fc_region = nn.Linear(2, fc_region_out)
        self.splice_fc = nn.Linear(3, splice_fc_out)
        self.splice_importance_head = nn.Linear(256, 3)
        self.fc1 = nn.Linear(fc1_in, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.4)

        # Explainability hooks — populated during forward()
        self._conv3_activations: torch.Tensor | None = None
        self._mutation_feature: torch.Tensor | None = None
        self._pooled: torch.Tensor | None = None

    def forward(self, x, mutation_positions=None):
        """Return (logit, importance, region_importance, splice_importance).

        x: (B, 1106) flat feature vector (see class docstring for layout).
        mutation_positions: optional (B,) tensor of 0-based positions; if
        None it is inferred via get_mutation_position_from_input().
        """
        bs = x.size(0)
        seq_flat = x[:, :1089]
        mut_onehot = x[:, 1089:1101]
        region_feat = x[:, 1101:1103]
        splice_feat = x[:, 1103:1106]

        # NOTE(review): the flat sequence is viewed channel-major (11, 99);
        # confirm the training-time encoder flattened the same way.
        h = self.relu(self.bn1(self.conv1(seq_flat.view(bs, 11, 99))))
        h = self.relu(self.bn2(self.conv2(h)))
        conv_out = self.relu(self.bn3(self.conv3(h)))  # (B, 256, 99)

        # ── hook: save conv3 activations ──────────────────────
        self._conv3_activations = conv_out.detach().clone()

        if mutation_positions is None:
            mutation_positions = get_mutation_position_from_input(x)
        pos_idx = mutation_positions.clamp(0, 98).long()
        pe = pos_idx.view(bs, 1, 1).expand(bs, 256, 1)
        mut_feat = conv_out.gather(2, pe).squeeze(2)  # (B, 256)

        # ── hook: save mutation-centered feature ──────────────
        self._mutation_feature = mut_feat.detach().clone()

        imp_score = torch.sigmoid(self.importance_head(mut_feat))
        pooled = self.global_pool(conv_out).squeeze(-1)  # (B, 256)
        self._pooled = pooled.detach().clone()

        r_imp = torch.sigmoid(self.region_importance_head(pooled))
        s_imp = torch.sigmoid(self.splice_importance_head(pooled))

        m = self.relu(self.mut_fc(mut_onehot))
        r = self.relu(self.fc_region(region_feat))
        s = self.relu(self.splice_fc(splice_feat))

        fused = torch.cat([pooled, m, r, s], dim=1)
        out = self.dropout(self.relu(self.fc1(fused)))
        out = self.dropout(self.relu(self.fc2(out)))
        return self.fc3(out), imp_score, r_imp, s_imp

    # ── Explainability extraction helpers ────────────────────────────────────

    def conv3_norm_profile(self) -> np.ndarray | None:
        """L2 norm across channels at each of 99 positions → shape (99,).

        Returns None before the first forward pass. Assumes the last forward
        used batch size 1 (squeeze(0)). FIX: the tensor is moved to CPU
        before .numpy() so this also works after CUDA inference.
        """
        if self._conv3_activations is None:
            return None
        arr = self._conv3_activations.squeeze(0).norm(dim=0).cpu().numpy()
        return arr / (arr.max() + 1e-9)

    def mutation_centered_peak(self, mutation_pos: int) -> float | None:
        """Activation value at the mutation position in conv3."""
        profile = self.conv3_norm_profile()
        if profile is None or mutation_pos < 0 or mutation_pos >= len(profile):
            return None
        return float(profile[mutation_pos])

    def mutation_peak_ratio(self, mutation_pos: int) -> float | None:
        """peak_signal / mean_signal — how focused is the activation.

        FIX: positions past the end of the profile now return None (the
        original checked only ``mutation_pos < 0`` and raised IndexError
        for >= 99), matching mutation_centered_peak's bounds handling.
        """
        profile = self.conv3_norm_profile()
        if profile is None or mutation_pos < 0 or mutation_pos >= len(profile):
            return None
        mean_val = float(profile.mean()) + 1e-9
        peak_val = float(profile[mutation_pos])
        return round(peak_val / mean_val, 4)

    def importance_head_vector(self) -> np.ndarray | None:
        """Raw mutation-centered feature vector — shape (256,)."""
        if self._mutation_feature is None:
            return None
        return self._mutation_feature.squeeze(0).cpu().numpy()
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 155 |
+
# Architecture 3 β MutationPredictorCNN (classic)
|
| 156 |
+
# Source: mutation-pathogenicity-app β uses external encoder.py / model.py
|
| 157 |
+
# We reconstruct the standard architecture from the import signature
|
| 158 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 159 |
+
|
| 160 |
+
class MutationPredictorCNN(nn.Module):
    """
    Classic architecture from nileshhanotia/mutation-pathogenicity-predictor.
    The app imports MutationPredictorCNN from model.py with no args,
    so this is the standard default-constructor variant.
    Input: encoded sequence from MutationEncoder (by default the 8-channel
    ref+mut stack produced by encode_for_classic → (B, 8, 99)).
    """

    def __init__(self, in_channels: int = 8, seq_len: int = 99):
        super().__init__()
        # seq_len is kept for signature compatibility; pooling is adaptive,
        # so the layer stack never consumes it.
        # Standard 3-layer CNN matching the import signature
        self.conv1 = nn.Conv1d(in_channels, 64, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 1)
        self.imp = nn.Linear(256, 1)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(0.3)

        # Explainability hooks — populated during forward()
        self._conv3_activations: torch.Tensor | None = None
        self._pooled: torch.Tensor | None = None

    def forward(self, x):
        """Return (logit, importance) for input of shape (B, in_channels, L)."""
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.relu(self.bn3(self.conv3(h)))
        self._conv3_activations = h.detach().clone()
        p = self.pool(h).squeeze(-1)
        self._pooled = p.detach().clone()
        logit = self.fc2(self.drop(self.relu(self.fc1(p))))
        importance = torch.sigmoid(self.imp(p))
        return logit, importance

    def conv3_norm_profile(self) -> np.ndarray | None:
        """L2 norm across channels at each position, normalized to max 1.

        Returns None before the first forward pass; assumes batch size 1.
        FIX: move to CPU before .numpy() so CUDA inference doesn't break
        this (consistent with MutationPredictorCNN_v2's helper).
        """
        if self._conv3_activations is None:
            return None
        arr = self._conv3_activations.squeeze(0).norm(dim=0).cpu().numpy()
        return arr / (arr.max() + 1e-9)

    def importance_score(self) -> float | None:
        """Scalar importance from the pooled features of the last forward."""
        if self._pooled is None:
            return None
        return float(torch.sigmoid(self.imp(self._pooled)).squeeze().item())
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 210 |
+
# Encoders β taken directly from live app source
|
| 211 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 212 |
+
|
| 213 |
+
# Nucleotide → channel index for the 5-channel one-hot ("N"/unknown → 4).
NUCL = {"A": 0, "T": 1, "G": 2, "C": 3, "N": 4}
# (ref_base, mut_base) → index in the 12-dim substitution-type one-hot.
MUT_TYPES = {
    ("A","T"):0, ("A","C"):1, ("A","G"):2,
    ("T","A"):3, ("T","C"):4, ("T","G"):5,
    ("C","A"):6, ("C","T"):7, ("C","G"):8,
    ("G","A"):9, ("G","T"):10,("G","C"):11,
}
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def _encode_seq_5ch(seq: str, n: int = 99) -> torch.Tensor:
    """One-hot encode *seq* into an (n, 5) tensor (5-channel v2 encoding).

    The sequence is upper-cased, right-padded with "N" and truncated to n
    characters; any character outside NUCL falls into the "N" channel (4).
    """
    padded = (seq.upper() + "N" * n)[:n]
    channels = torch.tensor([NUCL.get(base, 4) for base in padded])
    encoded = torch.zeros(n, 5)
    encoded[torch.arange(n), channels] = 1.0
    return encoded
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def encode_for_v2(ref_seq: str, mut_seq: str,
                  exon_flag: int = 0, intron_flag: int = 0,
                  donor_flag: int = 0, acceptor_flag: int = 0,
                  region_flag: int = 0) -> torch.Tensor:
    """
    Full 1106-dim encoding for MutationPredictorCNN_v2.
    Exact logic from splice-app/app.py encode_variant().

    Layout: 1089 sequence dims (99 positions × [5 ref + 5 mut + 1 diff]),
    then the 12-dim mutation-type one-hot, 2 region flags, 3 splice flags.
    """
    ref_enc = _encode_seq_5ch(ref_seq)
    mut_enc = _encode_seq_5ch(mut_seq)
    diff_mask = torch.zeros(99, 1)
    ref_base = mut_base = None
    # Mark every differing position; remember the first substitution pair
    # for the mutation-type one-hot.
    for i in range(min(len(ref_seq), len(mut_seq), 99)):
        if ref_seq[i] != mut_seq[i]:
            diff_mask[i, 0] = 1.0
            if ref_base is None:
                ref_base = ref_seq[i].upper()
                mut_base = mut_seq[i].upper()
    mut_onehot = torch.zeros(12)
    if ref_base and mut_base:
        mut_idx = MUT_TYPES.get((ref_base, mut_base))
        if mut_idx is not None:
            mut_onehot[mut_idx] = 1.0
    seq_flat = torch.cat([ref_enc, mut_enc, diff_mask], dim=1).flatten()  # 99*11=1089
    region = torch.tensor([float(exon_flag), float(intron_flag)])
    splice = torch.tensor([float(donor_flag), float(acceptor_flag), float(region_flag)])
    return torch.cat([seq_flat, mut_onehot, region, splice])  # 1106
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def encode_for_classic(ref_seq: str, mut_seq: str) -> torch.Tensor:
    """
    8-channel encoding for MutationPredictorCNN (classic).
    Reconstructed from the MutationEncoder import in the pathogenicity app:
    ref 4-ch one-hot stacked over mut 4-ch one-hot along channels → (1, 8, 99).
    Any non-ACGT character (including the "N" padding) stays all-zero.
    """
    base_channel = {"A": 0, "C": 1, "G": 2, "T": 3}
    n = 99
    ref = (ref_seq.upper() + "N" * n)[:n]
    mut = (mut_seq.upper() + "N" * n)[:n]
    encoded = torch.zeros(1, 8, n)  # float32; channels 0-3 ref, 4-7 mut
    for i in range(n):
        rb, mb = ref[i], mut[i]
        if rb in base_channel:
            encoded[0, base_channel[rb], i] = 1.0
        if mb in base_channel:
            encoded[0, 4 + base_channel[mb], i] = 1.0
    return encoded
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def find_mutation_pos(ref_seq: str, mut_seq: str) -> int:
    """Index of the first differing base within the 99bp window, or -1.

    Only the overlap of the two sequences (capped at 99) is scanned.
    """
    limit = min(len(ref_seq), len(mut_seq), 99)
    return next((i for i in range(limit) if ref_seq[i] != mut_seq[i]), -1)
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 287 |
+
# Registry
|
| 288 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 289 |
+
|
| 290 |
+
class ModelRegistry:
    """Lazy-loading registry for the three pretrained models.

    Each property downloads the checkpoint on first access (HuggingFace Hub),
    builds the matching architecture, and caches the model. On any failure
    (missing package, download error, unreadable checkpoint) the registry
    falls back to a randomly-initialised model and sets ``demo_mode=True``.
    """

    def __init__(self, hf_token: str | None = None):
        # Explicit token wins; otherwise fall back to the HF_TOKEN env var.
        self.token = hf_token or os.environ.get("HF_TOKEN")
        # Lazily-populated model caches (None until first property access).
        self._splice: MutationPredictorCNN_v2 | None = None
        self._v4: MutationPredictorCNN_v2 | None = None
        self._classic: MutationPredictorCNN | None = None
        # True once any model had to fall back to random weights.
        self.demo_mode = False
        # Validation accuracies read from the checkpoints, when present.
        self.val_acc_splice = 0.0
        self.val_acc_v4 = 0.0

    @property
    def splice(self) -> MutationPredictorCNN_v2:
        """Splice model (v2 architecture), loaded on first access."""
        if self._splice is None:
            self._splice = self._load_v2(REPO_SPLICE, "mutation_predictor_splice.pt", "splice")
        return self._splice

    @property
    def v4(self) -> MutationPredictorCNN_v2:
        """V4 model (v2 architecture variant), loaded on first access."""
        if self._v4 is None:
            self._v4 = self._load_v2(REPO_V4,
                                     "mutation_predictor_v4.pt", "v4",
                                     fallback_files=["pytorch_model.pth", "model.pth"])
        return self._v4

    @property
    def classic(self) -> MutationPredictorCNN:
        """Classic pathogenicity model, loaded on first access."""
        if self._classic is None:
            self._classic = self._load_classic()
        return self._classic

    def _hf_download(self, repo_id: str, filenames: list[str]) -> str | None:
        """Try each filename in order; return the first local path, else None.

        Per-file failures are swallowed so later candidates are attempted;
        a missing huggingface_hub package makes the whole lookup return None.
        """
        try:
            from huggingface_hub import hf_hub_download
            for fname in filenames:
                try:
                    return hf_hub_download(repo_id, fname, token=self.token,
                                           cache_dir="/tmp/mutation_xai")
                except Exception:
                    continue
        except ImportError:
            pass
        return None

    def _load_v2(self, repo_id: str, primary: str, tag: str,
                 fallback_files: list[str] | None = None) -> MutationPredictorCNN_v2:
        """Download and load a v2-architecture checkpoint; demo model on failure.

        The fc_region/splice_fc widths are read from the checkpoint's
        state_dict shapes so the module auto-sizes to the checkpoint.
        """
        files = [primary] + (fallback_files or [
            "pytorch_model.pth", "model.pth", "model.pt"])
        path = self._hf_download(repo_id, files)

        model = None
        if path:
            try:
                # SECURITY: weights_only=False deserializes arbitrary pickled
                # objects; acceptable only because these repos are trusted.
                ckpt = torch.load(path, map_location="cpu", weights_only=False)
                # Checkpoint may be a wrapper dict or a bare state_dict.
                sd = ckpt.get("model_state_dict", ckpt)
                # Infer layer widths from the stored weight shapes.
                fc_region_out = sd["fc_region.weight"].shape[0]
                splice_fc_out = sd["splice_fc.weight"].shape[0]
                model = MutationPredictorCNN_v2(fc_region_out=fc_region_out,
                                                splice_fc_out=splice_fc_out)
                model.load_state_dict(sd, strict=True)
                if tag == "splice":
                    self.val_acc_splice = ckpt.get("val_accuracy", 0.0)
                else:
                    self.val_acc_v4 = ckpt.get("val_accuracy", 0.0)
                logger.info("Loaded %s from %s", tag, repo_id)
            except Exception as e:
                logger.warning("Failed to load %s: %s β demo mode", tag, e)
                model = None

        if model is None:
            # Fall back to random weights so the app stays usable.
            self.demo_mode = True
            model = MutationPredictorCNN_v2()
            logger.warning("%s running in DEMO mode (random weights)", tag)

        model.eval()
        return model

    def _load_classic(self) -> MutationPredictorCNN:
        """Download and load the classic model; demo model on failure.

        The repo's checkpoint filename is unknown, so the repo listing is
        probed first and a list of plausible names is tried afterwards.
        """
        # ── Diagnostic: list ALL files in the repo so we know the real filename
        try:
            from huggingface_hub import list_repo_files
            all_files = list(list_repo_files(REPO_CLASSIC, token=self.token))
            logger.info("Files in %s: %s", REPO_CLASSIC, all_files)
            # Auto-detect any .pt or .pth file in the repo
            pt_files = [f for f in all_files if f.endswith(('.pt', '.pth', '.bin'))]
            if pt_files:
                logger.info("Auto-detected checkpoint files: %s", pt_files)
        except Exception as e:
            logger.warning("Could not list repo files: %s", e)
            pt_files = []

        # Try every plausible filename — the repo uses an unknown name.
        # Order: most likely names first based on the live app source code.
        candidates = pt_files + [
            "mutation_predictor.pt",
            "mutation_pathogenicity_predictor.pt",
            "mutation_predictor_classic.pt",
            "pytorch_model.pt",
            "pytorch_model.pth",
            "model.pt",
            "model.pth",
            "checkpoint.pt",
            "best_model.pt",
            "classifier.pt",
        ]
        path = self._hf_download(REPO_CLASSIC, candidates)
        model = MutationPredictorCNN()
        if path:
            try:
                # SECURITY: weights_only=False — see note in _load_v2.
                ckpt = torch.load(path, map_location="cpu", weights_only=False)
                sd = ckpt.get("model_state_dict", ckpt)
                # strict=False: the classic architecture is reconstructed, so
                # tolerate missing/unexpected keys rather than failing outright.
                model.load_state_dict(sd, strict=False)
                logger.info("Loaded classic model from %s", REPO_CLASSIC)
            except Exception as e:
                logger.warning("Failed to load classic: %s β demo mode", e)
                self.demo_mode = True
        else:
            self.demo_mode = True
            logger.warning(
                "Classic model: none of %s found in %s β running DEMO mode",
                candidates, REPO_CLASSIC
            )
        model.eval()
        return model
| 413 |
+
#Content is user-generated and unverified.
|