# dermtriage / src / model.py
# Last commit 84842ba (verified) by Kabirgrover: "updated for new HF space"
"""
MedSigLIP Classifier and ensemble utilities for skin lesion triage.
Contains:
- MedSigLIPClassifier: 7-class model (encoder + head, used by notebooks)
- BinaryGateHead: Binary gate scaffold (NB13, not deployed)
- load_medsig_encoder: Load just the frozen vision encoder
- build_classifier_head / load_classifier_head: NB09 7-class heads
"""
import torch
import torch.nn as nn
# ---------------------------------------------------------------------------
# Encoder + head loaders (used by ensemble pipeline)
# ---------------------------------------------------------------------------
def load_medsig_encoder(device="cpu"):
    """Return the frozen MedSigLIP-448 vision tower and its feature width.

    Cheaper than instantiating MedSigLIPClassifier because no classifier
    head is created; the ensemble path loads its NB09 heads separately.

    Args:
        device: Target device for the encoder (default ``"cpu"``).

    Returns:
        Tuple ``(vision_model, embed_dim)``: ``vision_model`` is in eval mode
        with ``requires_grad`` disabled on every parameter, and ``embed_dim``
        is ``config.vision_config.hidden_size`` of the loaded checkpoint.
    """
    from transformers import AutoModel

    pretrained = AutoModel.from_pretrained(
        "google/medsiglip-448", torch_dtype=torch.float32
    )
    encoder = pretrained.vision_model
    width = pretrained.config.vision_config.hidden_size
    # Only the vision tower is kept; release our handle on the full model.
    del pretrained

    encoder = encoder.to(device).eval()
    # Equivalent to setting param.requires_grad = False on each parameter.
    encoder.requires_grad_(False)
    return encoder, width
def build_classifier_head(embed_dim, hidden_dim, num_classes=7, dropout_rate=0.3):
    """Create the MLP classifier head shared by the NB09 models.

    Layout (Sequential indices in brackets, which fix the state_dict keys):
        [0] LayerNorm(embed_dim)
        [1] Dropout(dropout_rate)
        [2] Linear(embed_dim -> hidden_dim)
        [3] GELU
        [4] Dropout(dropout_rate)
        [5] Linear(hidden_dim -> num_classes)

    The same structure is used by MedSigLIPClassifier.classifier and by the
    NB09 training notebook for both the MedSigLIP-only and DermLIP-only heads.

    Args:
        embed_dim: Width of the incoming feature vector.
        hidden_dim: Width of the hidden projection.
        num_classes: Number of output logits (default 7).
        dropout_rate: Dropout probability for both Dropout layers.

    Returns:
        A freshly initialised ``nn.Sequential``.
    """
    # Linear layers are created in the same order as the layer list so that
    # random initialisation consumes the RNG identically.
    norm = nn.LayerNorm(embed_dim)
    hidden_proj = nn.Linear(embed_dim, hidden_dim)
    output_proj = nn.Linear(hidden_dim, num_classes)
    return nn.Sequential(
        norm,
        nn.Dropout(dropout_rate),
        hidden_proj,
        nn.GELU(),
        nn.Dropout(dropout_rate),
        output_proj,
    )
def load_classifier_head(checkpoint_path, device="cpu"):
    """Load an NB09 classifier head from a checkpoint file.

    Checkpoint dict expected keys:
        head_state_dict: OrderedDict of weights. Keys may carry a ``head.``
            prefix (module-wrapper save), which is stripped.
        temperature: calibration temperature from NB09 (documented as float)
        embed_dim (optional): int — inferred from weights if absent
        hidden_dim (optional): int — inferred from weights if absent

    The number of output classes is always inferred from the final Linear
    layer's weight, so heads with a class count other than 7 also load.

    Args:
        checkpoint_path: Path to a file written with ``torch.save``.
        device: Device used as ``map_location`` and for the returned head.

    Returns:
        (head: nn.Sequential in eval mode, temperature) — temperature is
        returned exactly as stored in the checkpoint.

    Raises:
        KeyError: if ``head_state_dict``/``temperature`` or the weight keys
            needed for dimension inference are missing.
        RuntimeError: from ``load_state_dict`` if keys or shapes don't match
            the ``build_classifier_head`` layout.
    """
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects; only load checkpoints from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    sd = ckpt["head_state_dict"]

    # Strip 'head.' prefix if NB09 saved with module wrapper
    if any(k.startswith("head.") for k in sd):
        sd = {k.removeprefix("head."): v for k, v in sd.items()}

    # Indices follow build_classifier_head's Sequential layout:
    #   0 = LayerNorm (weight: (embed_dim,))
    #   2 = hidden Linear (weight: (hidden_dim, embed_dim))
    #   5 = output Linear (weight: (num_classes, hidden_dim))
    embed_dim = ckpt.get("embed_dim") or sd["0.weight"].shape[0]
    hidden_dim = ckpt.get("hidden_dim") or sd["2.weight"].shape[0]
    # Previously the builder's default of 7 was always used, which made
    # load_state_dict fail for any head with a different class count.
    num_classes = sd["5.weight"].shape[0]

    head = build_classifier_head(embed_dim, hidden_dim, num_classes=num_classes)
    head.load_state_dict(sd)
    head = head.to(device).eval()
    return head, ckpt["temperature"]
# ---------------------------------------------------------------------------
# Full classifier (used by notebooks and MedSigLIP-only fallback)
# ---------------------------------------------------------------------------
class MedSigLIPClassifier(nn.Module):
    """MedSigLIP-448 vision encoder followed by an MLP classifier head.

    The encoder comes from the ``google/medsiglip-448`` checkpoint loaded via
    ``transformers.AutoModel``; only its ``vision_model`` is kept. The head
    uses the same Sequential layout as ``build_classifier_head`` with a
    512-wide hidden layer, so its state_dict keys are ``classifier.<idx>.*``.

    Args:
        num_classes: Number of output logits (default 7).
        dropout_rate: Dropout probability used inside the head.
        freeze_encoder: If True, encoder parameters get
            ``requires_grad=False`` and the encoder forward pass runs without
            autograd tracking. If False, gradients flow into the encoder so
            it can be fine-tuned together with the head.
    """

    def __init__(self, num_classes=7, dropout_rate=0.3, freeze_encoder=True):
        super().__init__()
        from transformers import AutoModel

        full_model = AutoModel.from_pretrained(
            "google/medsiglip-448",
            torch_dtype=torch.float32,
        )
        self.vision_model = full_model.vision_model
        self.embed_dim = full_model.config.vision_config.hidden_size
        # Only the vision tower is kept; release our handle on the full model.
        del full_model

        # Remembered so forward() knows whether to track encoder gradients.
        self.freeze_encoder = freeze_encoder
        if freeze_encoder:
            for param in self.vision_model.parameters():
                param.requires_grad = False

        # Same layout as build_classifier_head(self.embed_dim, 512, ...).
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Dropout(dropout_rate),
            nn.Linear(self.embed_dim, 512),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, num_classes),
        )

    def forward(self, pixel_values):
        """Return classifier logits for a batch of preprocessed images.

        Args:
            pixel_values: image tensor accepted by the vision model
                (presumably (B, 3, 448, 448) from the MedSigLIP processor —
                confirm against the caller).

        Returns:
            Logits from the classifier head, one row per image.
        """
        # A frozen encoder runs untracked, which saves activation memory.
        # A trainable encoder follows the caller's grad mode, so it actually
        # receives gradients during training.
        track_encoder_grads = torch.is_grad_enabled() and not self.freeze_encoder
        with torch.set_grad_enabled(track_encoder_grads):
            outputs = self.vision_model(pixel_values=pixel_values)

        # Prefer the pooled embedding; otherwise mean-pool over the token axis.
        if getattr(outputs, "pooler_output", None) is not None:
            features = outputs.pooler_output
        else:
            features = outputs.last_hidden_state.mean(dim=1)
        return self.classifier(features)
class BinaryGateHead(nn.Module):
    """Single-logit malignancy gate for skin lesion triage.

    Maps an encoder feature vector (default width 1152, intended to be the
    pooler_output of MedSigLIP's vision encoder) to one logit: positive
    means malignant, negative means benign.

    Scaffolded for NB13 LoRA fine-tuning. The bridge path does not use it;
    it sums the 7-class malignant probabilities instead.

    Args:
        embed_dim: Width of the incoming feature vector.
        hidden_dim: Width of the intermediate projection.
        dropout_rate: Dropout probability applied before each Linear layer.
    """

    def __init__(self, embed_dim=1152, hidden_dim=256, dropout_rate=0.3):
        super().__init__()
        # Layer order is fixed: state_dict keys are gate.<index>.* and the
        # Linear layers must be created in this order.
        layers = [
            nn.LayerNorm(embed_dim),
            nn.Dropout(dropout_rate),
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, 1),
        ]
        self.gate = nn.Sequential(*layers)

    def forward(self, features):
        """Score a batch of feature vectors.

        Args:
            features: (B, embed_dim) tensor, presumably the encoder's
                pooler_output.

        Returns:
            (B,) tensor of logits, one per sample.
        """
        logits = self.gate(features)  # trailing dim of size 1 from the last Linear
        return logits.squeeze(dim=-1)