nileshhanotia committed on
Commit
1bf9b9d
·
verified ·
1 Parent(s): 6b78362

Upload model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +425 -0
model_loader.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ model_loader.py — PeVe Unified Space Model Loading Module
3
+
4
+ Loading logic adapted from:
5
+ - nileshhanotia/mutation-predictor-splice-app (app.py)
6
+ - nileshhanotia/mutation-pathogenicity-app (app.py)
7
+ - nileshhanotia/mutation-explainable-v6 (model_v6.pkl)
8
+
9
+ Provides:
10
+ load_splice_model() → (model, status_dict)
11
+ load_context_model() → (model, status_dict)
12
+ load_protein_model() → (model, status_dict)
13
+ get_model_status() → combined status dict
14
+ """
15
+
16
+ import os
17
+ import traceback
18
+ import pickle
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
# ── Optional: Hugging Face access token for private repos ──────────────────
# Taken from the HF_TOKEN environment variable; the value is None when the
# variable is not set. Export HF_TOKEN before running if a repo is private.
HF_TOKEN = os.environ.get("HF_TOKEN")
27
+
28
+ # ══════════════════════════════════════════════════════════════════════════════
29
+ # MODULE-LEVEL MODEL HANDLES
30
+ # These are populated by the load_*() functions below.
31
+ # ══════════════════════════════════════════════════════════════════════════════
32
+
33
# All three handles start out empty (None) at import time.
_splice_model = _context_model = _protein_model = None
36
+
37
+ # ══════════════════════════════════════════════════════════════════════════════
38
+ # ARCHITECTURE — Splice Model
39
+ # Adapted from: nileshhanotia/mutation-predictor-splice-app app.py
40
+ # ══════════════════════════════════════════════════════════════════════════════
41
+
42
def _get_mutation_position_from_input(x_flat):
    """Return, for each row of ``x_flat``, the argmax index inside columns 990-1088.

    Internal helper used by MutationPredictorCNN_v2.forward() when the caller
    does not pass explicit mutation positions. The result is an index in the
    range 0..98, relative to column 990.
    """
    marker_window = x_flat[:, 990:1089]
    return torch.argmax(marker_window, dim=1)
45
+
46
+
47
class MutationPredictorCNN_v2(nn.Module):
    """
    Splice-aware mutation predictor.

    The layers follow mutation-predictor-splice-app/app.py so that its
    checkpoint loads unchanged. Attribute names double as state-dict keys,
    so they must not be renamed.

    Columns of the input row read by forward():
        0-1088     sequence block, viewed as 11 channels x 99 positions
        1089-1100  12 values fed to ``mut_fc``
        1101-1102  2 values fed to ``fc_region``
        1103-1105  3 values fed to ``splice_fc``
    """

    def __init__(self, fc_region_out=8, splice_fc_out=16):
        """Create the layers.

        ``fc_region_out`` and ``splice_fc_out`` are the output widths of the
        region and splice embeddings; together with the 256 pooled channels
        and the 32-wide mutation embedding they fix the input width of fc1.
        """
        super().__init__()
        fused_width = sum((256, 32, fc_region_out, splice_fc_out))

        # Convolutional trunk; the paddings keep the length at 99 throughout.
        self.conv1 = nn.Conv1d(11, 64, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.global_pool = nn.AdaptiveAvgPool1d(1)

        # Side-feature embedding and the auxiliary importance heads.
        self.mut_fc = nn.Linear(12, 32)
        self.importance_head = nn.Linear(256, 1)
        self.region_importance_head = nn.Linear(256, 2)
        self.fc_region = nn.Linear(2, fc_region_out)
        self.splice_fc = nn.Linear(3, splice_fc_out)
        self.splice_importance_head = nn.Linear(256, 3)

        # Classifier on the fused vector.
        self.fc1 = nn.Linear(fused_width, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.4)

    def forward(self, x, mutation_positions=None):
        """Run the network on a batch of flat feature rows.

        Parameters
        ----------
        x : tensor with one row per sample; columns 0-1105 are read.
        mutation_positions : optional per-sample position indices. When
            omitted they are derived from ``x`` via
            ``_get_mutation_position_from_input``. Values are clamped to 0..98.

        Returns
        -------
        (logit, imp_score, r_imp, s_imp)
            ``logit`` is the raw fc3 output; the other three are sigmoid
            outputs of the importance heads (1, 2 and 3 columns respectively).
        """
        batch = x.size(0)

        sequence = x[:, :1089]
        mutation_vec = x[:, 1089:1101]
        region_vec = x[:, 1101:1103]
        splice_vec = x[:, 1103:1106]

        feats = sequence.view(batch, 11, 99)
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            feats = self.relu(norm(conv(feats)))

        if mutation_positions is None:
            mutation_positions = _get_mutation_position_from_input(x)
        site = mutation_positions.clamp(0, 98).long()
        # Pick, for every channel, the activation at the mutation position.
        gather_index = site.view(batch, 1, 1).expand(batch, 256, 1)
        site_feats = feats.gather(2, gather_index).squeeze(2)

        imp_score = torch.sigmoid(self.importance_head(site_feats))
        pooled = self.global_pool(feats).squeeze(-1)
        r_imp = torch.sigmoid(self.region_importance_head(pooled))
        s_imp = torch.sigmoid(self.splice_importance_head(pooled))

        branches = [
            pooled,
            self.relu(self.mut_fc(mutation_vec)),
            self.relu(self.fc_region(region_vec)),
            self.relu(self.splice_fc(splice_vec)),
        ]
        hidden = torch.cat(branches, dim=1)
        hidden = self.dropout(self.relu(self.fc1(hidden)))
        hidden = self.dropout(self.relu(self.fc2(hidden)))
        logit = self.fc3(hidden)
        return logit, imp_score, r_imp, s_imp
107
+
108
+
109
+ # ══════════════════════════════════════════════════════════════════════════════
110
+ # ARCHITECTURE — Context (401 bp CNN) Model
111
+ # Adapted from: nileshhanotia/mutation-predictor-v4
112
+ # ══════════════════════════════════════════════════════════════════════════════
113
+
114
class MutationContextCNN(nn.Module):
    """
    Context-window CNN for mutation pathogenicity (intended for 401 bp windows).

    Three Conv1d/BatchNorm/ReLU stages over 5 input channels, an adaptive
    average pool down to one position, then a 256 -> 128 -> 64 -> 1 dense head.
    Attribute names are the state-dict keys used when weights are loaded.

    NOTE(review): the layout is meant to mirror the mutation-predictor-v4
    model; confirm it against the real checkpoint's keys and shapes.
    """

    def __init__(self):
        """Create the convolutional trunk and the dense head."""
        super().__init__()
        # The paddings preserve sequence length through every convolution.
        self.conv1 = nn.Conv1d(5, 64, kernel_size=11, padding=5)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=7, padding=3)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=5, padding=2)
        self.bn3 = nn.BatchNorm1d(256)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(0.3)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, channels) to a (batch, 1) raw score."""
        # Conv1d wants channels before length, so swap the last two axes.
        feats = x.permute(0, 2, 1)

        conv_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in conv_stages:
            feats = self.relu(norm(conv(feats)))

        feats = self.pool(feats).squeeze(-1)
        for dense in (self.fc1, self.fc2):
            feats = self.drop(self.relu(dense(feats)))
        return self.fc3(feats)
147
+
148
+
149
+ # ══════════════════════════════════════════════════════════════════════════════
150
+ # LOADER — Splice Model
151
+ # ══════════════════════════════════════════════════════════════════════════════
152
+
153
def load_splice_model():
    """
    Download and load MutationPredictorCNN_v2 from
    nileshhanotia/mutation-predictor-splice.

    Loading logic adapted from:
        nileshhanotia/mutation-predictor-splice-app app.py
            ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
            sd = ckpt["model_state_dict"]

    The checkpoint must be a mapping with a "model_state_dict" entry. The
    embedding widths are read from the shapes of "fc_region.weight" and
    "splice_fc.weight", the model is built to match, the weights are loaded
    (strictly) and the model is put in eval mode.

    Side effect: the module-level ``_splice_model`` handle is set to the
    loaded model, or reset to None if any step fails.

    Returns
    -------
    (model | None, {"loaded": bool, "error_message": str})
        Failures are not raised; ``error_message`` holds the formatted
        traceback and ``loaded`` stays False.
    """
    global _splice_model

    status = {"loaded": False, "error_message": ""}

    try:
        from huggingface_hub import hf_hub_download  # local import for clarity

        MODEL_REPO = "nileshhanotia/mutation-predictor-splice"
        MODEL_FILENAME = "mutation_predictor_splice.pt"

        print(f"[splice] Downloading {MODEL_FILENAME} from {MODEL_REPO} …")
        ckpt_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILENAME,
            token=HF_TOKEN,
        )

        print(f"[splice] Loading checkpoint from {ckpt_path} …")
        # SECURITY NOTE: weights_only=False performs a full unpickle of the
        # downloaded file, which can execute arbitrary code. Only point this
        # at a repository you trust.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        sd = ckpt["model_state_dict"]

        # Infer architecture hyper-params from the state dict (exact pattern from app.py)
        fc_region_out = sd["fc_region.weight"].shape[0]
        splice_fc_out = sd["splice_fc.weight"].shape[0]

        model = MutationPredictorCNN_v2(
            fc_region_out=fc_region_out,
            splice_fc_out=splice_fc_out,
        )
        model.load_state_dict(sd)
        model.eval()

        # The metric is informational only. Format it defensively so that a
        # None or non-numeric "val_accuracy" cannot raise here and discard a
        # model that has already loaded correctly.
        val_acc = ckpt.get("val_accuracy", float("nan"))
        try:
            val_acc_text = f"{float(val_acc):.4f}"
        except (TypeError, ValueError):
            val_acc_text = repr(val_acc)
        print(f"[splice] ✓ Loaded. val_accuracy={val_acc_text} | "
              f"fc_region_out={fc_region_out} | splice_fc_out={splice_fc_out}")

        _splice_model = model
        status["loaded"] = True

    except Exception:
        # Boundary handler: report the failure through the status dict
        # instead of propagating it to the caller.
        tb = traceback.format_exc()
        print(f"[splice] ✗ FAILED to load:\n{tb}")
        status["error_message"] = tb
        _splice_model = None

    return _splice_model, status
212
+
213
+
214
+ # ══════════════════════════════════════════════════════════════════════════════
215
+ # LOADER — Context Model (401 bp CNN, mutation-predictor-v4)
216
+ # ══════════════════════════════════════════════════════════════════════════════
217
+
218
def load_context_model():
    """
    Download and load the 401 bp context CNN from
    nileshhanotia/mutation-predictor-v4.

    Loading logic adapted from:
        nileshhanotia/mutation-pathogenicity-app app.py
            checkpoint = torch.load(MODEL_PATH, map_location=device)
            model.load_state_dict(checkpoint["model_state_dict"])

    Several common checkpoint filenames are tried in order; the first one
    that downloads is used. The checkpoint may be a wrapper dict with a
    "model_state_dict" or "state_dict" entry, or the state dict itself.

    Extra keys in the checkpoint are tolerated (and logged), but any
    MutationContextCNN parameter that the checkpoint does not provide is
    treated as a failure: otherwise that layer would keep its random
    initialisation while the model is reported as loaded.

    Side effect: the module-level ``_context_model`` handle is set to the
    loaded model, or reset to None if any step fails.

    Returns
    -------
    (model | None, {"loaded": bool, "error_message": str})
        Failures are not raised; ``error_message`` holds the formatted
        traceback and ``loaded`` stays False.
    """
    global _context_model

    status = {"loaded": False, "error_message": ""}

    try:
        from huggingface_hub import hf_hub_download

        MODEL_REPO = "nileshhanotia/mutation-predictor-v4"
        # Try common checkpoint filenames used in HF spaces
        CANDIDATE_FILENAMES = [
            "pytorch_model.pth",
            "mutation_predictor_v4.pt",
            "model.pt",
            "model.pth",
            "checkpoint.pth",
        ]

        ckpt_path = None
        last_error = ""
        for fname in CANDIDATE_FILENAMES:
            try:
                print(f"[context] Trying {fname} from {MODEL_REPO} …")
                ckpt_path = hf_hub_download(
                    repo_id=MODEL_REPO,
                    filename=fname,
                    token=HF_TOKEN,
                )
                print(f"[context] Found: {fname}")
                break
            except Exception as e:
                # Best effort: remember why this candidate failed and move on.
                last_error = str(e)

        if ckpt_path is None:
            raise FileNotFoundError(
                f"None of the candidate filenames found in {MODEL_REPO}. "
                f"Last error: {last_error}"
            )

        print(f"[context] Loading checkpoint from {ckpt_path} …")
        # SECURITY NOTE: weights_only=False performs a full unpickle of the
        # downloaded file, which can execute arbitrary code. Only point this
        # at a repository you trust.
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)

        # Support both raw state-dict and wrapped checkpoint
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            sd = checkpoint["model_state_dict"]
        elif isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            sd = checkpoint["state_dict"]
        else:
            sd = checkpoint  # assume it IS the state dict

        model = MutationContextCNN()
        # strict=False lets surplus checkpoint keys through, but the returned
        # report must be inspected: a missing key means a layer was NOT loaded.
        load_report = model.load_state_dict(sd, strict=False)
        missing = sorted(load_report.missing_keys)
        unexpected = sorted(load_report.unexpected_keys)
        if missing:
            raise RuntimeError(
                "Checkpoint does not match MutationContextCNN. "
                f"Missing keys: {missing}. Unexpected keys: {unexpected}."
            )
        if unexpected:
            print(f"[context] Ignoring unexpected checkpoint keys: {unexpected}")
        model.eval()

        print("[context] ✓ Loaded MutationContextCNN (401 bp).")
        _context_model = model
        status["loaded"] = True

    except Exception:
        # Boundary handler: report the failure through the status dict
        # instead of propagating it to the caller.
        tb = traceback.format_exc()
        print(f"[context] ✗ FAILED to load:\n{tb}")
        status["error_message"] = tb
        _context_model = None

    return _context_model, status
296
+
297
+
298
+ # ══════════════════════════════════════════════════════════════════════════════
299
+ # LOADER — Protein Model (XGBoost .pkl from mutation-explainable-v6)
300
+ # ══════════════════════════════════════════════════════════════════════════════
301
+
302
def load_protein_model():
    """
    Download model_v6.pkl from nileshhanotia/mutation-explainable-v6 and
    unpickle it.

    The file is read with joblib first; if that raises for any reason
    (including joblib not being installed), plain ``pickle`` is used instead.
    The object is expected to be a ready-to-use pickled estimator (XGBoost /
    sklearn-compatible, per the source repo) and is NOT loaded through
    XGBoost's Booster.load_model().

    Side effect: the module-level ``_protein_model`` handle is set to the
    loaded object, or reset to None if any step fails.

    Returns
    -------
    (model | None, {"loaded": bool, "error_message": str})
        Failures are not raised; ``error_message`` holds the formatted
        traceback and ``loaded`` stays False.
    """
    global _protein_model

    outcome = {"loaded": False, "error_message": ""}

    try:
        from huggingface_hub import hf_hub_download

        repo_id = "nileshhanotia/mutation-explainable-v6"
        pkl_name = "model_v6.pkl"

        print(f"[protein] Downloading {pkl_name} from {repo_id} …")
        local_path = hf_hub_download(
            repo_id=repo_id,
            filename=pkl_name,
            token=HF_TOKEN,
        )

        print(f"[protein] Loading pickle from {local_path} …")
        # NOTE(review): both joblib and pickle execute code embedded in the
        # file while loading; only use this with a repository you trust.
        try:
            import joblib
            estimator = joblib.load(local_path)
            print("[protein] Loaded via joblib.")
        except Exception:
            with open(local_path, "rb") as handle:
                estimator = pickle.load(handle)
            print("[protein] Loaded via pickle.")

        print(f"[protein] ✓ Loaded protein model: {type(estimator).__name__}")
        _protein_model = estimator
        outcome["loaded"] = True

    except Exception:
        failure = traceback.format_exc()
        print(f"[protein] ✗ FAILED to load:\n{failure}")
        outcome["error_message"] = failure
        _protein_model = None

    return _protein_model, outcome
355
+
356
+
357
+ # ══════════════════════════════════════════════════════════════════════════════
358
+ # STATUS AGGREGATOR
359
+ # ══════════════════════════════════════════════════════════════════════════════
360
+
361
def get_model_status() -> dict:
    """
    Run all three loaders and collect their status dicts.

    Despite the name this is not a passive query: it calls
    load_splice_model(), load_context_model() and load_protein_model(), in
    that order, and prints a banner plus a one-line summary per model.

    Returns
    -------
    {
        "splice": {"loaded": bool, "error_message": str},
        "context": {"loaded": bool, "error_message": str},
        "protein": {"loaded": bool, "error_message": str},
    }
    """
    rule = "=" * 60

    print(rule)
    print("PeVe — starting unified model loading")
    print(rule)

    # Dict entries are evaluated top to bottom, so the loaders run in the
    # order splice, context, protein. Only the status half of each
    # (model, status) pair is kept.
    status = {
        "splice": load_splice_model()[1],
        "context": load_context_model()[1],
        "protein": load_protein_model()[1],
    }

    # Summary report
    print("\n" + rule)
    print("PeVe — model loading complete")
    print(rule)
    for name, entry in status.items():
        icon = "✓" if entry["loaded"] else "✗"
        print(f" [{icon}] {name:10s} loaded={entry['loaded']}")
    print(rule + "\n")

    return status
397
+
398
+
399
+ # ══════════════════════════════════════════════════════════════════════════════
400
+ # PUBLIC ACCESSORS
401
+ # ══════════════════════════════════════════════════════════════════════════════
402
+
403
def get_splice_model():
    """Return the current splice model handle, or None when none is loaded."""
    handle = _splice_model
    return handle
406
+
407
+
408
def get_context_model():
    """Return the current context model handle, or None when none is loaded."""
    handle = _context_model
    return handle
411
+
412
+
413
def get_protein_model():
    """Return the current protein model handle, or None when none is loaded."""
    handle = _protein_model
    return handle
416
+
417
+
418
+ # ══════════════════════════════════════════════════════════════════════════════
419
+ # SELF-TEST
420
+ # ══════════════════════════════════════════════════════════════════════════════
421
+
422
if __name__ == "__main__":
    # Self-test: attempt to load every model and dump the combined status.
    print("Testing model loading...")
    print(get_model_status())