"""MedSigLIP classifier and ensemble utilities for skin lesion triage.

Contains:
    - MedSigLIPClassifier: 7-class model (encoder + head, used by notebooks)
    - BinaryGateHead: Binary gate scaffold (NB13, not deployed)
    - load_medsig_encoder: Load just the frozen vision encoder
    - build_classifier_head / load_classifier_head: NB09 7-class heads
"""

import torch
import torch.nn as nn

# Hugging Face model id shared by every loader in this module.
_MEDSIGLIP_MODEL_ID = "google/medsiglip-448"


def _load_vision_tower():
    """Download/load MedSigLIP and return ``(vision_model, embed_dim)``.

    Only the vision sub-model is returned; the enclosing full model goes out
    of scope when this function returns. The returned module is left on
    whatever device/mode ``from_pretrained`` produced and its parameters are
    NOT frozen here; callers decide that.
    """
    # Imported lazily so the head utilities in this module stay usable
    # without `transformers` installed.
    from transformers import AutoModel

    full_model = AutoModel.from_pretrained(
        _MEDSIGLIP_MODEL_ID,
        torch_dtype=torch.float32,
    )
    return full_model.vision_model, full_model.config.vision_config.hidden_size


# ---------------------------------------------------------------------------
# Encoder + head loaders (used by ensemble pipeline)
# ---------------------------------------------------------------------------


def load_medsig_encoder(device="cpu"):
    """Load the frozen MedSigLIP-448 vision encoder only.

    Lighter than MedSigLIPClassifier: skips creating a classifier head.
    Used by the ensemble path where NB09 heads are loaded separately.

    Args:
        device: Device the encoder is moved to.

    Returns:
        (vision_model, embed_dim): the encoder in eval mode with every
        parameter's ``requires_grad`` set to False, and its hidden size.
    """
    vision_model, embed_dim = _load_vision_tower()
    vision_model = vision_model.to(device).eval()
    vision_model.requires_grad_(False)
    return vision_model, embed_dim


def build_classifier_head(embed_dim, hidden_dim, num_classes=7, dropout_rate=0.3):
    """Build a 7-class classifier head (same architecture as NB09 heads).

    This is the same Sequential structure used in
    MedSigLIPClassifier.classifier and in the NB09 training notebook for both
    MedSigLIP-only and DermLIP-only heads.

    Args:
        embed_dim: Size of the input feature vector.
        hidden_dim: Width of the single hidden layer.
        num_classes: Number of output logits.
        dropout_rate: Dropout probability used before each Linear layer.

    Returns:
        nn.Sequential: LayerNorm -> Dropout -> Linear -> GELU -> Dropout ->
        Linear. The layer order fixes the state-dict keys ("0.*", "2.*",
        "5.*") that load_classifier_head relies on.
    """
    return nn.Sequential(
        nn.LayerNorm(embed_dim),
        nn.Dropout(dropout_rate),
        nn.Linear(embed_dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout_rate),
        nn.Linear(hidden_dim, num_classes),
    )


def load_classifier_head(checkpoint_path, device="cpu"):
    """Load an NB09 7-class head from checkpoint.

    Checkpoint dict expected keys:
        head_state_dict: OrderedDict of weights
        temperature: float (calibration temperature from NB09)
        embed_dim (optional): int, inferred from weights if absent
        hidden_dim (optional): int, inferred from weights if absent

    Args:
        checkpoint_path: Path (or file-like object) accepted by ``torch.load``.
        device: Device used for ``map_location`` and for the returned head.

    Returns:
        (head: nn.Sequential, temperature: float), with the head in eval mode.

    Raises:
        KeyError: if ``head_state_dict`` or ``temperature`` is missing.
        RuntimeError: if the weights do not fit the 7-class head layout.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which
    # can execute code. Only load checkpoints from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state = ckpt["head_state_dict"]

    # Strip 'head.' prefix if NB09 saved with module wrapper
    if any(key.startswith("head.") for key in state):
        state = {key.removeprefix("head."): value for key, value in state.items()}

    # Read dims from checkpoint or infer from weight shapes:
    # "0.weight" is the LayerNorm weight, shape (embed_dim,);
    # "2.weight" is the first Linear weight, shape (hidden_dim, embed_dim).
    embed_dim = ckpt.get("embed_dim") or state["0.weight"].shape[0]
    hidden_dim = ckpt.get("hidden_dim") or state["2.weight"].shape[0]

    head = build_classifier_head(embed_dim, hidden_dim)
    head.load_state_dict(state)
    head = head.to(device).eval()
    return head, ckpt["temperature"]


# ---------------------------------------------------------------------------
# Full classifier (used by notebooks and MedSigLIP-only fallback)
# ---------------------------------------------------------------------------


class MedSigLIPClassifier(nn.Module):
    """MedSigLIP vision encoder followed by a small MLP classification head.

    Args:
        num_classes: Number of output logits.
        dropout_rate: Dropout probability inside the head.
        freeze_encoder: When True, encoder parameters get
            ``requires_grad=False`` and the encoder runs without autograd.
            When False, gradients flow through the encoder during training.
        hidden_dim: Width of the head's hidden layer.
    """

    def __init__(self, num_classes=7, dropout_rate=0.3, freeze_encoder=True,
                 hidden_dim=512):
        super().__init__()
        self.vision_model, self.embed_dim = _load_vision_tower()

        # Remembered so forward() knows whether the encoder may track grads.
        self.freeze_encoder = freeze_encoder
        if freeze_encoder:
            self.vision_model.requires_grad_(False)

        # Same layer layout (and therefore the same state-dict keys) as the
        # NB09 heads produced by build_classifier_head.
        self.classifier = build_classifier_head(
            self.embed_dim,
            hidden_dim,
            num_classes=num_classes,
            dropout_rate=dropout_rate,
        )

    def forward(self, pixel_values):
        """Encode ``pixel_values`` and return the head's class logits.

        Uses the encoder's ``pooler_output`` when present, otherwise the mean
        of ``last_hidden_state`` over dim 1.
        """
        # An unconditional no_grad() here would make freeze_encoder=False a
        # no-op (the encoder could never be trained). Autograd is disabled
        # only for a frozen encoder; otherwise the caller's ambient grad mode
        # is respected. getattr keeps instances pickled before the attribute
        # existed on the old always-frozen behaviour.
        frozen = getattr(self, "freeze_encoder", True)
        with torch.set_grad_enabled(torch.is_grad_enabled() and not frozen):
            outputs = self.vision_model(pixel_values=pixel_values)
            features = getattr(outputs, "pooler_output", None)
            if features is None:
                features = outputs.last_hidden_state.mean(dim=1)
        return self.classifier(features)


class BinaryGateHead(nn.Module):
    """Binary malignancy gate for skin lesion triage.

    Takes the pooler_output (1152-d) from MedSigLIP's vision encoder and
    outputs a single logit: positive = malignant, negative = benign.

    Scaffolded for NB13 LoRA fine-tuning. Not used by the bridge path
    (which sums 7-class malignant probabilities instead).

    Args:
        embed_dim: Size of the input feature vector.
        hidden_dim: Width of the hidden layer.
        dropout_rate: Dropout probability used before each Linear layer.
    """

    def __init__(self, embed_dim=1152, hidden_dim=256, dropout_rate=0.3):
        super().__init__()
        self.gate = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Dropout(dropout_rate),
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, features):
        """Forward pass.

        Args:
            features: (B, embed_dim) pooler_output from vision encoder.

        Returns:
            (B,) logits, one per sample.
        """
        # The final Linear emits (B, 1); drop the trailing singleton dim.
        return self.gate(features).squeeze(-1)