nileshhanotia committed on
Commit
ccc687d
·
verified ·
1 Parent(s): 1c39b6b

Create model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +413 -0
model_loader.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ model_loader.py
3
+ ==============
4
+ Loads all three pretrained models using their EXACT native architectures
5
+ as confirmed from the live HuggingFace Space source code.
6
+
7
+ Models:
8
+ 1. nileshhanotia/mutation-predictor-splice
9
+ β†’ MutationPredictorCNN_v2 (input dim=1106, 99bp window)
10
+ β†’ File: mutation_predictor_splice.pt
11
+
12
+ 2. nileshhanotia/mutation-predictor-v4
13
+ β†’ MutationPredictorCNN_v2 variant (inferred from same family)
14
+ β†’ File: mutation_predictor_v4.pt (or pytorch_model.pth)
15
+
16
+ 3. nileshhanotia/mutation-pathogenicity-predictor
17
+ β†’ MutationPredictorCNN (classic, 99bp window)
18
+ β†’ File: pytorch_model.pth
19
+
20
+ Architecture notes taken directly from live app source β€” nothing redesigned.
21
+ """
22
+
23
+ from __future__ import annotations
24
+ import logging
25
+ import os
26
+ from pathlib import Path
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+ import numpy as np
31
+
32
# Module-level logger; handlers/levels are configured by the embedding app.
logger = logging.getLogger(__name__)

# HuggingFace Hub repository IDs for the three pretrained models.
REPO_SPLICE = "nileshhanotia/mutation-predictor-splice"
REPO_V4 = "nileshhanotia/mutation-predictor-v4"
REPO_CLASSIC = "nileshhanotia/mutation-pathogenicity-predictor"
38
+
39
+
40
+ # ═══════════════════════════════════════════════════════════════════════════════
41
+ # Architecture 1 & 2 β€” MutationPredictorCNN_v2
42
+ # Source: mutation-predictor-splice-app/app.py (exact copy)
43
+ # Used by both splice model and v4 model
44
+ # ═══════════════════════════════════════════════════════════════════════════════
45
+
46
def get_mutation_position_from_input(x_flat):
    """Per-sample argmax over the diff-mask channel (flat indices 990:1089)."""
    diff_mask = x_flat[:, 990:1089]
    return diff_mask.argmax(dim=1)
48
+
49
+
50
class MutationPredictorCNN_v2(nn.Module):
    """
    Architecture mirrored from nileshhanotia/mutation-predictor-splice-app.

    ``fc_region_out`` / ``splice_fc_out`` default to the splice checkpoint's
    sizes; the loader infers them from a checkpoint's state_dict shapes so
    the same class serves both the splice and v4 checkpoints.
    """

    def __init__(self, fc_region_out: int = 8, splice_fc_out: int = 16):
        super().__init__()
        fused_dim = 256 + 32 + fc_region_out + splice_fc_out
        # NOTE: layer names and creation order are load-bearing for
        # checkpoint compatibility — do not reorder.
        self.conv1 = nn.Conv1d(11, 64, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.mut_fc = nn.Linear(12, 32)
        self.importance_head = nn.Linear(256, 1)
        self.region_importance_head = nn.Linear(256, 2)
        self.fc_region = nn.Linear(2, fc_region_out)
        self.splice_fc = nn.Linear(3, splice_fc_out)
        self.splice_importance_head = nn.Linear(256, 3)
        self.fc1 = nn.Linear(fused_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.4)

        # Explainability caches, refreshed on every forward() call.
        self._conv3_activations: torch.Tensor | None = None
        self._mutation_feature: torch.Tensor | None = None
        self._pooled: torch.Tensor | None = None

    def forward(self, x, mutation_positions=None):
        batch = x.size(0)
        # Input layout: 1089 seq channels + 12 mut one-hot + 2 region + 3 splice.
        seq_flat = x[:, :1089]
        mut_onehot = x[:, 1089:1101]
        region_feat = x[:, 1101:1103]
        splice_feat = x[:, 1103:1106]

        feat = self.relu(self.bn1(self.conv1(seq_flat.view(batch, 11, 99))))
        feat = self.relu(self.bn2(self.conv2(feat)))
        conv_out = self.relu(self.bn3(self.conv3(feat)))  # (B, 256, 99)

        # Cache conv3 activations for the explainability helpers below.
        self._conv3_activations = conv_out.detach().clone()

        if mutation_positions is None:
            mutation_positions = get_mutation_position_from_input(x)
        gather_idx = mutation_positions.clamp(0, 98).long().view(batch, 1, 1)
        mut_feat = conv_out.gather(2, gather_idx.expand(batch, 256, 1)).squeeze(2)  # (B, 256)

        # Cache the mutation-centered feature vector.
        self._mutation_feature = mut_feat.detach().clone()

        imp_score = torch.sigmoid(self.importance_head(mut_feat))
        pooled = self.global_pool(conv_out).squeeze(-1)  # (B, 256)
        self._pooled = pooled.detach().clone()

        r_imp = torch.sigmoid(self.region_importance_head(pooled))
        s_imp = torch.sigmoid(self.splice_importance_head(pooled))

        fused = torch.cat([pooled,
                           self.relu(self.mut_fc(mut_onehot)),
                           self.relu(self.fc_region(region_feat)),
                           self.relu(self.splice_fc(splice_feat))], dim=1)
        hidden = self.dropout(self.relu(self.fc1(fused)))
        hidden = self.dropout(self.relu(self.fc2(hidden)))
        return self.fc3(hidden), imp_score, r_imp, s_imp

    # ── Explainability extraction helpers ────────────────────────────────────

    def conv3_norm_profile(self) -> np.ndarray | None:
        """Per-position L2 norm across conv3 channels, max-normalized — shape (99,)."""
        if self._conv3_activations is None:
            return None
        profile = self._conv3_activations.squeeze(0).norm(dim=0).numpy()
        return profile / (profile.max() + 1e-9)

    def mutation_centered_peak(self, mutation_pos: int) -> float | None:
        """Normalized conv3 activation at the mutation position (None if unavailable)."""
        profile = self.conv3_norm_profile()
        if profile is None or not 0 <= mutation_pos < len(profile):
            return None
        return float(profile[mutation_pos])

    def mutation_peak_ratio(self, mutation_pos: int) -> float | None:
        """peak / mean of the activation profile — how focused the signal is."""
        profile = self.conv3_norm_profile()
        if profile is None or mutation_pos < 0:
            return None
        return round(float(profile[mutation_pos]) / (float(profile.mean()) + 1e-9), 4)

    def importance_head_vector(self) -> np.ndarray | None:
        """Raw 256-dim mutation-centered feature vector from the last forward()."""
        if self._mutation_feature is None:
            return None
        return self._mutation_feature.squeeze(0).numpy()
152
+
153
+
154
+ # ═══════════════════════════════════════════════════════════════════════════════
155
+ # Architecture 3 β€” MutationPredictorCNN (classic)
156
+ # Source: mutation-pathogenicity-app β€” uses external encoder.py / model.py
157
+ # We reconstruct the standard architecture from the import signature
158
+ # ═══════════════════════════════════════════════════════════════════════════════
159
+
160
class MutationPredictorCNN(nn.Module):
    """
    Classic model for nileshhanotia/mutation-pathogenicity-predictor.

    The live app imports ``MutationPredictorCNN`` with no constructor args,
    so this reproduces the default-constructor variant: a 3-layer Conv1d
    stack over an (in_channels, seq_len) input — by default the (8, 99)
    ref+mut dual one-hot produced by ``encode_for_classic``.
    """

    def __init__(self, in_channels: int = 8, seq_len: int = 99):
        super().__init__()
        # Standard 3-layer CNN matching the import signature.
        self.conv1 = nn.Conv1d(in_channels, 64, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 1)
        self.imp = nn.Linear(256, 1)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(0.3)

        # Explainability caches, refreshed on every forward() call.
        self._conv3_activations: torch.Tensor | None = None
        self._pooled: torch.Tensor | None = None

    def forward(self, x):
        feat = self.relu(self.bn1(self.conv1(x)))
        feat = self.relu(self.bn2(self.conv2(feat)))
        feat = self.relu(self.bn3(self.conv3(feat)))
        self._conv3_activations = feat.detach().clone()
        pooled = self.pool(feat).squeeze(-1)
        self._pooled = pooled.detach().clone()
        logit = self.fc2(self.drop(self.relu(self.fc1(pooled))))
        return logit, torch.sigmoid(self.imp(pooled))

    def conv3_norm_profile(self) -> np.ndarray | None:
        """Per-position L2 norm across conv3 channels, max-normalized."""
        if self._conv3_activations is None:
            return None
        profile = self._conv3_activations.squeeze(0).norm(dim=0).numpy()
        return profile / (profile.max() + 1e-9)

    def importance_score(self) -> float | None:
        """Sigmoid importance recomputed from the cached pooled features."""
        if self._pooled is None:
            return None
        return float(torch.sigmoid(self.imp(self._pooled)).squeeze().item())
207
+
208
+
209
+ # ═══════════════════════════════════════════════════════════════════════════════
210
+ # Encoders β€” taken directly from live app source
211
+ # ══════════════════════════════════════════════════════════���════════════════════
212
+
213
# Nucleotide → one-hot channel index; "N"/unknown bases share channel 4.
NUCL = {"A": 0, "T": 1, "G": 2, "C": 3, "N": 4}

# (ref_base, mut_base) → index into the 12-dim substitution one-hot.
MUT_TYPES = {
    ("A", "T"): 0, ("A", "C"): 1, ("A", "G"): 2,
    ("T", "A"): 3, ("T", "C"): 4, ("T", "G"): 5,
    ("C", "A"): 6, ("C", "T"): 7, ("C", "G"): 8,
    ("G", "A"): 9, ("G", "T"): 10, ("G", "C"): 11,
}
220
+
221
+
222
def _encode_seq_5ch(seq: str, n: int = 99) -> torch.Tensor:
    """One-hot encode *seq* into an (n, 5) tensor using the NUCL channel map.

    The sequence is upper-cased, right-padded with "N", and truncated to n
    positions; unknown characters fall into the "N" channel (index 4).
    """
    padded = (seq.upper() + "N" * n)[:n]
    out = torch.zeros(n, 5)
    for pos, base in enumerate(padded):
        out[pos, NUCL.get(base, 4)] = 1.0
    return out
229
+
230
+
231
def encode_for_v2(ref_seq: str, mut_seq: str,
                  exon_flag: int = 0, intron_flag: int = 0,
                  donor_flag: int = 0, acceptor_flag: int = 0,
                  region_flag: int = 0) -> torch.Tensor:
    """
    Full 1106-dim encoding for MutationPredictorCNN_v2.

    Layout is channel-major, matching forward()'s ``.view(bs, 11, 99)`` and
    get_mutation_position_from_input()'s 990:1089 slice:
        [0:495)     ref one-hot  — 5 channels × 99 positions
        [495:990)   mut one-hot  — 5 channels × 99 positions
        [990:1089)  diff mask    — 1 channel  × 99 positions
        [1089:1101) 12-dim mutation-type one-hot
        [1101:1103) region flags (exon, intron)
        [1103:1106) splice flags (donor, acceptor, region)
    """
    ref_enc = _encode_seq_5ch(ref_seq)   # (99, 5)
    mut_enc = _encode_seq_5ch(mut_seq)   # (99, 5)
    diff_mask = torch.zeros(99, 1)
    ref_base = mut_base = None
    for i in range(min(len(ref_seq), len(mut_seq), 99)):
        if ref_seq[i] != mut_seq[i]:
            diff_mask[i, 0] = 1.0
            if ref_base is None:
                ref_base = ref_seq[i].upper()
                mut_base = mut_seq[i].upper()
    mut_onehot = torch.zeros(12)
    if ref_base and mut_base:
        idx = MUT_TYPES.get((ref_base, mut_base))
        if idx is not None:
            mut_onehot[idx] = 1.0
    # BUG FIX: flatten channel-major ((11, 99) → 1089), not position-major.
    # The previous (99, 11).flatten() interleaved channels per position,
    # which scrambled conv1's .view(bs, 11, 99) input and left the diff
    # mask outside the 990:1089 slice the position finder reads.
    seq_flat = torch.cat([ref_enc, mut_enc, diff_mask], dim=1).t().flatten()  # 11*99=1089
    region = torch.tensor([float(exon_flag), float(intron_flag)])
    splice = torch.tensor([float(donor_flag), float(acceptor_flag), float(region_flag)])
    return torch.cat([seq_flat, mut_onehot, region, splice])  # 1106
258
+
259
+
260
def encode_for_classic(ref_seq: str, mut_seq: str) -> torch.Tensor:
    """
    8-channel encoding for MutationPredictorCNN (classic).

    Reconstructed from the MutationEncoder import in the pathogenicity app:
    a 4-channel ref one-hot stacked on a 4-channel mut one-hot along the
    channel axis → a (1, 8, 99) float tensor. Bases outside ACGT (including
    the "N" right-padding) become all-zero columns.
    """
    base_idx = {"A": 0, "C": 1, "G": 2, "T": 3}
    n = 99
    ref = (ref_seq.upper() + "N" * n)[:n]
    mut = (mut_seq.upper() + "N" * n)[:n]
    channels = np.zeros((8, n), dtype=np.float32)
    for pos in range(n):
        r_idx = base_idx.get(ref[pos])
        if r_idx is not None:
            channels[r_idx, pos] = 1.0      # ref → channels 0–3
        m_idx = base_idx.get(mut[pos])
        if m_idx is not None:
            channels[4 + m_idx, pos] = 1.0  # mut → channels 4–7
    return torch.from_numpy(channels).unsqueeze(0)  # (1, 8, 99)
277
+
278
+
279
def find_mutation_pos(ref_seq: str, mut_seq: str) -> int:
    """Index of the first differing base within the 99bp window, or -1."""
    limit = min(len(ref_seq), len(mut_seq), 99)
    for idx in range(limit):
        if ref_seq[idx] != mut_seq[idx]:
            return idx
    return -1
284
+
285
+
286
+ # ═══════════════════════════════════════════════════════════════════════════════
287
+ # Registry
288
+ # ═══════════════════════════════════════════════════════════════════════════════
289
+
290
class ModelRegistry:
    """Lazily downloads, builds, and caches the three pretrained models.

    Each model property loads on first access; any load failure drops that
    model into demo mode (random weights) and sets ``demo_mode`` True.
    """

    def __init__(self, hf_token: str | None = None):
        self.token = hf_token or os.environ.get("HF_TOKEN")
        self._splice: MutationPredictorCNN_v2 | None = None
        self._v4: MutationPredictorCNN_v2 | None = None
        self._classic: MutationPredictorCNN | None = None
        self.demo_mode = False
        self.val_acc_splice = 0.0
        self.val_acc_v4 = 0.0

    @property
    def splice(self) -> MutationPredictorCNN_v2:
        if self._splice is None:
            self._splice = self._load_v2(REPO_SPLICE, "mutation_predictor_splice.pt", "splice")
        return self._splice

    @property
    def v4(self) -> MutationPredictorCNN_v2:
        if self._v4 is None:
            self._v4 = self._load_v2(REPO_V4,
                                     "mutation_predictor_v4.pt", "v4",
                                     fallback_files=["pytorch_model.pth", "model.pth"])
        return self._v4

    @property
    def classic(self) -> MutationPredictorCNN:
        if self._classic is None:
            self._classic = self._load_classic()
        return self._classic

    def _hf_download(self, repo_id: str, filenames: list[str]) -> str | None:
        """Local path of the first filename that downloads from repo_id, else None."""
        try:
            from huggingface_hub import hf_hub_download
        except ImportError:
            return None
        for fname in filenames:
            try:
                return hf_hub_download(repo_id, fname, token=self.token,
                                       cache_dir="/tmp/mutation_xai")
            except Exception:
                continue
        return None

    def _load_v2(self, repo_id: str, primary: str, tag: str,
                 fallback_files: list[str] | None = None) -> MutationPredictorCNN_v2:
        """Load a v2-family checkpoint; fall back to demo weights on any failure."""
        candidates = [primary] + (fallback_files or [
            "pytorch_model.pth", "model.pth", "model.pt"])
        ckpt_path = self._hf_download(repo_id, candidates)

        model = None
        if ckpt_path:
            try:
                # NOTE(review): weights_only=False unpickles arbitrary objects;
                # acceptable only because these repos are trusted first-party.
                ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
                sd = ckpt.get("model_state_dict", ckpt)
                # Infer head widths from the checkpoint so one class fits
                # both the splice and v4 variants.
                model = MutationPredictorCNN_v2(
                    fc_region_out=sd["fc_region.weight"].shape[0],
                    splice_fc_out=sd["splice_fc.weight"].shape[0])
                model.load_state_dict(sd, strict=True)
                if tag == "splice":
                    self.val_acc_splice = ckpt.get("val_accuracy", 0.0)
                else:
                    self.val_acc_v4 = ckpt.get("val_accuracy", 0.0)
                logger.info("Loaded %s from %s", tag, repo_id)
            except Exception as e:
                logger.warning("Failed to load %s: %s — demo mode", tag, e)
                model = None

        if model is None:
            self.demo_mode = True
            model = MutationPredictorCNN_v2()
            logger.warning("%s running in DEMO mode (random weights)", tag)

        model.eval()
        return model

    def _load_classic(self) -> MutationPredictorCNN:
        """Load the classic model; the checkpoint filename is unknown, so probe."""
        # Diagnostic: list ALL files in the repo so we know the real filename.
        try:
            from huggingface_hub import list_repo_files
            all_files = list(list_repo_files(REPO_CLASSIC, token=self.token))
            logger.info("Files in %s: %s", REPO_CLASSIC, all_files)
            # Auto-detect any checkpoint-looking file in the repo.
            pt_files = [f for f in all_files if f.endswith(('.pt', '.pth', '.bin'))]
            if pt_files:
                logger.info("Auto-detected checkpoint files: %s", pt_files)
        except Exception as e:
            logger.warning("Could not list repo files: %s", e)
            pt_files = []

        # Probe every plausible filename, most likely names first
        # (order based on the live app source code).
        candidates = pt_files + [
            "mutation_predictor.pt",
            "mutation_pathogenicity_predictor.pt",
            "mutation_predictor_classic.pt",
            "pytorch_model.pt",
            "pytorch_model.pth",
            "model.pt",
            "model.pth",
            "checkpoint.pt",
            "best_model.pt",
            "classifier.pt",
        ]
        path = self._hf_download(REPO_CLASSIC, candidates)
        model = MutationPredictorCNN()
        if path:
            try:
                # NOTE(review): weights_only=False — see _load_v2.
                ckpt = torch.load(path, map_location="cpu", weights_only=False)
                sd = ckpt.get("model_state_dict", ckpt)
                # strict=False: the reconstructed architecture may not match
                # the checkpoint exactly; load whatever keys line up.
                model.load_state_dict(sd, strict=False)
                logger.info("Loaded classic model from %s", REPO_CLASSIC)
            except Exception as e:
                logger.warning("Failed to load classic: %s — demo mode", e)
                self.demo_mode = True
        else:
            self.demo_mode = True
            logger.warning(
                "Classic model: none of %s found in %s — running DEMO mode",
                candidates, REPO_CLASSIC
            )
        model.eval()
        return model
413
+ #Content is user-generated and unverified.