# Upload note: renamed "model_loader (1).py" to model_loader.py (commit cdaf2d3).
"""
model_loader.py
===============
Authoritative architecture definitions for all three models,
matched exactly to the trained checkpoint shapes.
splice β†’ MutationPredictorCNN_v2 (input: 1106-dim flat vector)
v4 β†’ MutationPredictorCNN_v4 (input: seq/mut/region/splice tensors)
classic β†’ MutationPredictorClassic (input: 1103-dim flat vector, from classic repo)
"""
from __future__ import annotations
import logging
import os
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
# Module-level logger; configure handlers/levels from the application side.
logger = logging.getLogger("mutation_xai.loader")
# ═══════════════════════════════════════════════════════════════════════════════
# Shared constants
# ═══════════════════════════════════════════════════════════════════════════════
# Nucleotide -> channel index for the 5-channel one-hot encoding (N = unknown;
# any character outside this map is also bucketed into channel 4).
NUCL = {"A": 0, "T": 1, "G": 2, "C": 3, "N": 4}
# (ref_base, mut_base) -> index in the 12-dim substitution one-hot fed to mut_fc.
MUT_TYPES = {
    ("A","T"):0, ("A","C"):1, ("A","G"):2,
    ("T","A"):3, ("T","C"):4, ("T","G"):5,
    ("C","A"):6, ("C","T"):7, ("C","G"):8,
    ("G","A"):9, ("G","T"):10,("G","C"):11,
}
# Canonical bases. NOTE(review): not referenced elsewhere in this module —
# presumably used by callers enumerating candidate substitutions; verify.
ALL_BASES = ["A", "T", "C", "G"]
# ═══════════════════════════════════════════════════════════════════════════════
# β‘  SPLICE MODEL β€” MutationPredictorCNN_v2
# ═══════════════════════════════════════════════════════════════════════════════
def _get_mutation_position_from_input(x_flat: torch.Tensor) -> torch.Tensor:
"""Infer mutation position from input tensor (sequence difference mask)."""
return x_flat[:, 990:1089].argmax(dim=1)
class MutationPredictorCNN_v2(nn.Module):
"""Splice-aware CNN β€” exact architecture from mutation-predictor-splice."""
def __init__(self, fc_region_out: int = 8, splice_fc_out: int = 16):
super().__init__()
fc1_in = 256 + 32 + fc_region_out + splice_fc_out
self.conv1 = nn.Conv1d(11, 64, kernel_size=7, padding=3)
self.bn1 = nn.BatchNorm1d(64)
self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
self.bn2 = nn.BatchNorm1d(128)
self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
self.bn3 = nn.BatchNorm1d(256)
self.global_pool = nn.AdaptiveAvgPool1d(1)
self.mut_fc = nn.Linear(12, 32)
self.importance_head = nn.Linear(256, 1)
self.region_importance_head = nn.Linear(256, 2)
self.fc_region = nn.Linear(2, fc_region_out)
self.splice_fc = nn.Linear(3, splice_fc_out)
self.splice_importance_head = nn.Linear(256, 3)
self.fc1 = nn.Linear(fc1_in, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 1)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.4)
def forward(self, x: torch.Tensor,
mutation_positions: Optional[torch.Tensor] = None):
bs = x.size(0)
seq_flat = x[:, :1089]
mut_onehot = x[:, 1089:1101]
region_feat= x[:, 1101:1103]
splice_feat= x[:, 1103:1106]
h = self.relu(self.bn1(self.conv1(seq_flat.view(bs, 11, 99))))
h = self.relu(self.bn2(self.conv2(h)))
conv_out = self.relu(self.bn3(self.conv3(h)))
if mutation_positions is None:
mutation_positions = _get_mutation_position_from_input(x)
pos_idx = mutation_positions.clamp(0, 98).long()
pe = pos_idx.view(bs, 1, 1).expand(bs, 256, 1)
mut_feat = conv_out.gather(2, pe).squeeze(2)
imp_score = torch.sigmoid(self.importance_head(mut_feat))
pooled = self.global_pool(conv_out).squeeze(-1)
r_imp = torch.sigmoid(self.region_importance_head(pooled))
s_imp = torch.sigmoid(self.splice_importance_head(pooled))
m = self.relu(self.mut_fc(mut_onehot))
r = self.relu(self.fc_region(region_feat))
s = self.relu(self.splice_fc(splice_feat))
fused = torch.cat([pooled, m, r, s], dim=1)
out = self.dropout(self.relu(self.fc1(fused)))
out = self.dropout(self.relu(self.fc2(out)))
logit = self.fc3(out)
return logit, imp_score, r_imp, s_imp
# ═══════════════════════════════════════════════════════════════════════════════
# β‘‘ V4 MODEL β€” MutationPredictorCNN_v4
# ═══════════════════════════════════════════════════════════════════════════════
class MutationPredictorCNN_v4(nn.Module):
"""V4 model β€” takes separate (seq, mut, region, splice) tensor inputs."""
def __init__(self):
super().__init__()
self.conv1 = nn.Conv1d(11, 64, 7, padding=3)
self.conv2 = nn.Conv1d(64, 128, 5, padding=2)
self.conv3 = nn.Conv1d(128, 256, 3, padding=1)
self.pool = nn.AdaptiveAvgPool1d(1)
self.mut_fc = nn.Linear(12, 32)
self.region_fc= nn.Linear(2, 8)
self.splice_fc= nn.Linear(3, 16)
self.fc1 = nn.Linear(312, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 1)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.3)
def forward(self, seq: torch.Tensor, mut: torch.Tensor,
region: torch.Tensor, splice: torch.Tensor):
x = self.relu(self.conv1(seq))
x = self.relu(self.conv2(x))
x = self.relu(self.conv3(x))
x = self.pool(x).squeeze(-1)
m = self.relu(self.mut_fc(mut))
r = self.relu(self.region_fc(region))
s = self.relu(self.splice_fc(splice))
x = torch.cat([x, m, r, s], dim=1)
x = self.dropout(self.relu(self.fc1(x)))
x = self.relu(self.fc2(x))
return self.fc3(x)
# ═══════════════════════════════════════════════════════════════════════════════
# β‘’ CLASSIC MODEL β€” MutationPredictorClassic
# Mirrors the architecture in the explainable-space repo's model.py
# Input: 1103-dim flat vector (99Γ—5 ref + 99Γ—5 mut + 99 diff = 1089, + 12 mut_oh + 2 region = 1103; a 1106-dim variant appends 3 splice flags)
# Outputs: logit, importance (one sigmoid scalar per sample from the pooled conv features), region_imp (2,)
# ═══════════════════════════════════════════════════════════════════════════════
class MutationPredictorClassic(nn.Module):
    """Classic explainable model from mutation-pathogenicity-predictor.

    Accepts a flat feature vector of 1103 or 1106 dims:
        [0:1089]    sequence block (reshaped to 11 channels x 99 positions)
        [1089:1101] mutation one-hot (12)
        [1101:1103] region flags (2)
        [1103:1106] splice flags (3) — optional; zero-filled when absent

    Returns (logit, importance, region_importance). ``importance`` is the
    sigmoid of a Linear(256, 1) head over the *pooled* conv features — one
    scalar per sample, not a per-position profile.
    """

    def __init__(self, input_dim: int = 1103):
        super().__init__()
        # `input_dim` is kept for interface compatibility; forward() adapts
        # to 1103- and 1106-dim inputs at runtime by slicing.
        # Sequence trunk: 11 channels x 99 positions.
        self.conv1 = nn.Conv1d(11, 64, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.pool = nn.AdaptiveAvgPool1d(1)
        # Auxiliary heads over the pooled 256-dim features.
        self.importance_head = nn.Linear(256, 1)
        self.region_importance_head = nn.Linear(256, 2)
        # Side branches for the flat auxiliary features.
        self.mut_fc = nn.Linear(12, 32)
        self.region_fc = nn.Linear(2, 8)
        self.splice_fc = nn.Linear(3, 16)
        # 256 + 32 + 8 + 16 = 312 fused features.
        self.fc1 = nn.Linear(312, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

    def forward(self, x: torch.Tensor):
        """
        x: (batch, 1103) or (batch, 1106)
            [0:1089]    = sequence block, viewed as (11, 99) for the convs
            [1089:1101] = mutation onehot (12)
            [1101:1103] = region flags (2)
            [1103:1106] = splice flags (3) — absent in the 1103-dim variant
        """
        bs = x.size(0)
        seq_flat = x[:, :1089]
        mut_onehot = x[:, 1089:1101]
        region_feat = x[:, 1101:1103]
        if x.size(1) >= 1106:
            splice_feat = x[:, 1103:1106]
        else:
            # 1103-dim variant carries no splice flags; substitute zeros.
            # FIX: match the input's dtype (not just device) so torch.cat
            # below cannot fail for non-default float inputs (e.g. a model
            # converted with .double()).
            splice_feat = torch.zeros(bs, 3, device=x.device, dtype=x.dtype)
        h = self.relu(self.bn1(self.conv1(seq_flat.view(bs, 11, 99))))
        h = self.relu(self.bn2(self.conv2(h)))
        conv_out = self.relu(self.bn3(self.conv3(h)))
        pooled = self.pool(conv_out).squeeze(-1)
        imp = torch.sigmoid(self.importance_head(pooled))
        r_imp = torch.sigmoid(self.region_importance_head(pooled))
        m = self.relu(self.mut_fc(mut_onehot))
        r = self.relu(self.region_fc(region_feat))
        s = self.relu(self.splice_fc(splice_feat))
        fused = torch.cat([pooled, m, r, s], dim=1)
        out = self.dropout(self.relu(self.fc1(fused)))
        out = self.relu(self.fc2(out))
        logit = self.fc3(out)
        return logit, imp, r_imp
# ═══════════════════════════════════════════════════════════════════════════════
# ENCODERS
# ═══════════════════════════════════════════════════════════════════════════════
def _encode_seq_11ch(seq: str, n: int = 99) -> torch.Tensor:
    """One-hot encode *seq* as an (n, 5) tensor; channel order A/T/G/C/N.

    The input is upper-cased, right-padded with 'N', and truncated to *n*;
    any character not in NUCL falls into the 'N' channel (index 4).
    """
    window = (seq.upper() + "N" * n)[:n]
    encoded = torch.zeros(n, 5)
    for pos, base in enumerate(window):
        encoded[pos, NUCL.get(base, 4)] = 1.0
    return encoded
def encode_for_v2(ref_seq: str, mut_seq: str,
                  exon_flag: int = 0, intron_flag: int = 0,
                  donor_flag: int = 0, acceptor_flag: int = 0,
                  region_flag: int = 0) -> torch.Tensor:
    """
    Build the 1106-dim input vector used by both splice and classic models.

    Layout (channel-major, matching ``seq_flat.view(bs, 11, 99)`` in the
    models' forward passes and the [990:1089] diff-mask slice read by
    ``_get_mutation_position_from_input``):
        [0:495]     ref one-hot, 5 channels x 99 positions
        [495:990]   mut one-hot, 5 channels x 99 positions
        [990:1089]  diff mask, 1 channel x 99 positions
        [1089:1101] mutation onehot (12)
        [1101:1103] region flags [exon, intron]
        [1103:1106] splice flags [donor, acceptor, region]
    """
    n = 99
    base_idx = {"A": 0, "T": 1, "G": 2, "C": 3, "N": 4}  # mirrors NUCL

    def one_hot(seq: str) -> torch.Tensor:
        # (n, 5) one-hot; pad with 'N', unknown characters -> channel 4.
        window = (seq.upper() + "N" * n)[:n]
        enc = torch.zeros(n, 5)
        for i, c in enumerate(window):
            enc[i, base_idx.get(c, 4)] = 1.0
        return enc

    ref_enc = one_hot(ref_seq)          # (99, 5)
    mut_enc = one_hot(mut_seq)          # (99, 5)
    diff = torch.zeros(n, 1)
    ref_base = mut_base = None
    for i in range(min(len(ref_seq), len(mut_seq), n)):
        if ref_seq[i].upper() != mut_seq[i].upper():
            diff[i, 0] = 1.0
            if ref_base is None:        # remember the first substitution only
                ref_base = ref_seq[i].upper()
                mut_base = mut_seq[i].upper()
    moh = torch.zeros(12)
    if ref_base and mut_base:
        sub_idx = {                     # mirrors MUT_TYPES
            ("A", "T"): 0, ("A", "C"): 1, ("A", "G"): 2,
            ("T", "A"): 3, ("T", "C"): 4, ("T", "G"): 5,
            ("C", "A"): 6, ("C", "T"): 7, ("C", "G"): 8,
            ("G", "A"): 9, ("G", "T"): 10, ("G", "C"): 11,
        }.get((ref_base, mut_base))
        if sub_idx is not None:
            moh[sub_idx] = 1.0
    # BUGFIX: flatten channel-major (channel x position), not position-major.
    # The previous torch.cat([ref, mut, diff], dim=1).flatten() interleaved
    # the 11 features per position, so slice [990:1089] was NOT the diff mask
    # and view(bs, 11, 99) scrambled the channels the convs expect.
    sf = torch.cat([ref_enc.t().reshape(-1),
                    mut_enc.t().reshape(-1),
                    diff.t().reshape(-1)])          # 1089 values, channel-major
    rt = torch.tensor([float(exon_flag), float(intron_flag)])
    st = torch.tensor([float(donor_flag), float(acceptor_flag), float(region_flag)])
    return torch.cat([sf, moh, rt, st])
def encode_for_v4(ref_seq: str, mut_seq: str,
                  exon_flag: int = 0, intron_flag: int = 0,
                  donor_flag: int = 0, acceptor_flag: int = 0,
                  region_flag: int = 0):
    """
    Split the flat 1106-dim encoding into the four tensors that
    MutationPredictorCNN_v4.forward expects.

    Returns (seq, mut_oh, region, splice) of shapes
    (1, 11, 99), (1, 12), (1, 2), (1, 3) — each with a leading batch dim.
    """
    flat = encode_for_v2(ref_seq, mut_seq, exon_flag, intron_flag,
                         donor_flag, acceptor_flag, region_flag)
    pieces = (flat[:1089].view(11, 99),   # conv input channels
              flat[1089:1101],            # mutation one-hot
              flat[1101:1103],            # region flags
              flat[1103:1106])            # splice flags
    return tuple(p.unsqueeze(0) for p in pieces)
def find_mutation_pos(ref_seq: str, mut_seq: str) -> int:
    """Return the 0-indexed first case-insensitive mismatch within the
    99-position window shared by both sequences, or -1 when none differs."""
    limit = min(len(ref_seq), len(mut_seq), 99)
    mismatches = (i for i in range(limit)
                  if ref_seq[i].upper() != mut_seq[i].upper())
    return next(mismatches, -1)
# ═══════════════════════════════════════════════════════════════════════════════
# MODEL REGISTRY β€” loads all three models once at startup
# ═══════════════════════════════════════════════════════════════════════════════
# Hugging Face Hub repos hosting each trained checkpoint.
SPLICE_REPO = "nileshhanotia/mutation-predictor-splice"
V4_REPO = "nileshhanotia/mutation-predictor-v4"
CLASSIC_REPO = "nileshhanotia/mutation-pathogenicity-predictor"
# Primary checkpoint filenames; _try_filenames falls back to alternatives
# (e.g. "model.pt", "pytorch_model.pt") when these are absent.
SPLICE_FILENAME = "mutation_predictor_splice.pt"
V4_FILENAME = "mutation_predictor_splice_v4.pt"
CLASSIC_FILENAME = "mutation_predictor.pt" # common name; fallback tried
def _load_ckpt(repo: str, filename: str,
               token: Optional[str] = None) -> dict:
    """Download checkpoint, return state dict or full ckpt dict.

    Args:
        repo: Hub repo id, e.g. "user/model".
        filename: checkpoint filename within the repo.
        token: optional HF auth token for private repos.
    """
    path = hf_hub_download(repo_id=repo, filename=filename, token=token)
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects,
    # which can execute code on load — only load checkpoints from trusted
    # repos. Needed here when the checkpoint stores more than raw tensors.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    return ckpt
def _try_filenames(repo: str, candidates: list[str],
                   token: Optional[str] = None) -> dict:
    """Return the first checkpoint among *candidates* that loads from *repo*.

    Raises:
        FileNotFoundError: if every candidate fails; chained to the last
        underlying error so auth/network problems remain diagnosable.
    """
    last_err: Optional[Exception] = None
    for fn in candidates:
        try:
            return _load_ckpt(repo, fn, token)
        except Exception as exc:
            # FIX: previously every failure (including auth errors) was
            # swallowed silently; record and log it before trying the next.
            last_err = exc
            logger.debug("Checkpoint %s/%s not usable: %s", repo, fn, exc)
    raise FileNotFoundError(
        f"None of {candidates} found in repo {repo}") from last_err
class ModelRegistry:
    """Lazy singleton that loads each model exactly once.

    Each model is downloaded from the HF Hub on first property access and
    cached on the instance. Validation accuracy is read from the checkpoint's
    ``val_accuracy`` key when present (0.0 otherwise).
    """

    def __init__(self, hf_token: Optional[str] = None):
        # Token forwarded to hf_hub_download; needed only for private repos.
        self._token = hf_token
        # Lazily-populated model instances (None until first access).
        self._splice: Optional[MutationPredictorCNN_v2] = None
        self._v4: Optional[MutationPredictorCNN_v4] = None
        self._classic: Optional[MutationPredictorClassic] = None
        # Validation accuracies read from the checkpoints (0.0 if absent).
        self._splice_val_acc = 0.0
        self._v4_val_acc = 0.0
        self._classic_val_acc = 0.0

    # ── individual loaders ────────────────────────────────────────────────────
    def _load_splice(self) -> MutationPredictorCNN_v2:
        """Build the splice model with layer widths inferred from the weights."""
        logger.info("Loading splice model …")
        ckpt = _try_filenames(SPLICE_REPO,
                              [SPLICE_FILENAME, "model.pt", "pytorch_model.pt"],
                              self._token)
        # Checkpoint may be a wrapper dict or a bare state dict.
        sd = ckpt.get("model_state_dict", ckpt)
        # Read the two configurable layer widths from the stored weights so
        # the constructed module matches the checkpoint shapes exactly.
        fc_region_out = sd["fc_region.weight"].shape[0]
        splice_fc_out = sd["splice_fc.weight"].shape[0]
        m = MutationPredictorCNN_v2(fc_region_out=fc_region_out,
                                    splice_fc_out=splice_fc_out)
        # Strict load: any shape mismatch should fail loudly for this model.
        m.load_state_dict(sd)
        m.eval()
        self._splice_val_acc = float(ckpt.get("val_accuracy", 0))
        logger.info(f"Splice model ready (val_acc={self._splice_val_acc:.4f})")
        return m

    def _load_v4(self) -> MutationPredictorCNN_v4:
        """Build the v4 model; tolerates minor key drift between checkpoints."""
        logger.info("Loading v4 model …")
        ckpt = _try_filenames(V4_REPO,
                              [V4_FILENAME, "mutation_predictor_splice_v4.pt",
                               "model.pt", "pytorch_model.pt"],
                              self._token)
        # Checkpoint may be a wrapper dict or a bare state dict.
        sd = ckpt.get("model_state_dict", ckpt)
        m = MutationPredictorCNN_v4()
        # Strict=False so we survive minor shape drift between checkpoints;
        # missing keys keep their random init, so watch the warning below.
        missing, unexpected = m.load_state_dict(sd, strict=False)
        if missing:
            logger.warning(f"V4 missing keys: {missing[:6]}")
        m.eval()
        self._v4_val_acc = float(ckpt.get("val_accuracy", 0))
        logger.info(f"V4 model ready (val_acc={self._v4_val_acc:.4f})")
        return m

    def _load_classic(self) -> MutationPredictorClassic:
        """Build the classic model; tolerates minor key drift between checkpoints."""
        logger.info("Loading classic model …")
        ckpt = _try_filenames(CLASSIC_REPO,
                              [CLASSIC_FILENAME, "model.pt",
                               "mutation_predictor_classic.pt",
                               "pytorch_model.pt"],
                              self._token)
        # Checkpoint may be a wrapper dict or a bare state dict.
        sd = ckpt.get("model_state_dict", ckpt)
        # NOTE(review): input_dim detection from the conv weights is not
        # actually implemented — the default-constructed shapes are assumed
        # to match the checkpoint; strict=False papers over any drift.
        m = MutationPredictorClassic()
        missing, unexpected = m.load_state_dict(sd, strict=False)
        if missing:
            logger.warning(f"Classic missing keys: {missing[:6]}")
        m.eval()
        self._classic_val_acc = float(ckpt.get("val_accuracy", 0))
        logger.info(f"Classic model ready (val_acc={self._classic_val_acc:.4f})")
        return m

    # ── properties ────────────────────────────────────────────────────────────
    @property
    def splice(self) -> MutationPredictorCNN_v2:
        # Load on first access, then reuse the cached instance.
        if self._splice is None:
            self._splice = self._load_splice()
        return self._splice

    @property
    def v4(self) -> MutationPredictorCNN_v4:
        # Load on first access, then reuse the cached instance.
        if self._v4 is None:
            self._v4 = self._load_v4()
        return self._v4

    @property
    def classic(self) -> MutationPredictorClassic:
        # Load on first access, then reuse the cached instance.
        if self._classic is None:
            self._classic = self._load_classic()
        return self._classic

    @property
    def val_accs(self) -> dict:
        """Checkpoint-reported validation accuracies (0.0 for unloaded models)."""
        return {
            "splice": self._splice_val_acc,
            "v4": self._v4_val_acc,
            "classic": self._classic_val_acc,
        }