Spaces:
Sleeping
Sleeping
| """ | |
| MedSigLIP Classifier and ensemble utilities for skin lesion triage. | |
| Contains: | |
| - MedSigLIPClassifier: 7-class model (encoder + head, used by notebooks) | |
| - BinaryGateHead: Binary gate scaffold (NB13, not deployed) | |
| - load_medsig_encoder: Load just the frozen vision encoder | |
| - build_classifier_head / load_classifier_head: NB09 7-class heads | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| # --------------------------------------------------------------------------- | |
| # Encoder + head loaders (used by ensemble pipeline) | |
| # --------------------------------------------------------------------------- | |
def load_medsig_encoder(device="cpu"):
    """Return the frozen MedSigLIP-448 vision tower and its feature width.

    Cheaper than constructing MedSigLIPClassifier because no classifier head
    is created; the ensemble path attaches separately-loaded NB09 heads.

    Args:
        device: target device for the encoder (anything ``Module.to`` accepts).

    Returns:
        (vision_model, embed_dim): the encoder in eval mode with every
        parameter's ``requires_grad`` set to False, and
        ``config.vision_config.hidden_size`` of the checkpoint.
    """
    from transformers import AutoModel

    backbone = AutoModel.from_pretrained(
        "google/medsiglip-448", torch_dtype=torch.float32
    )
    encoder = backbone.vision_model
    hidden_size = backbone.config.vision_config.hidden_size
    # Drop our reference to the full model; only the vision tower is kept.
    del backbone

    encoder = encoder.to(device)
    encoder.eval()
    encoder.requires_grad_(False)
    return encoder, hidden_size
def build_classifier_head(embed_dim, hidden_dim, num_classes=7, dropout_rate=0.3):
    """Create the two-layer MLP head shared with NB09 and MedSigLIPClassifier.

    Module indices are part of the checkpoint key format and must not move:
        0 LayerNorm(embed_dim)
        1 Dropout(dropout_rate)
        2 Linear(embed_dim, hidden_dim)
        3 GELU
        4 Dropout(dropout_rate)
        5 Linear(hidden_dim, num_classes)

    Args:
        embed_dim: width of the incoming feature vector.
        hidden_dim: width of the hidden layer.
        num_classes: number of output logits (default 7).
        dropout_rate: probability for both Dropout layers.

    Returns:
        nn.Sequential mapping (..., embed_dim) features to (..., num_classes).
    """

    def dropout_then_linear(in_features, out_features):
        # Each projection in this head is preceded by its own dropout.
        return [nn.Dropout(dropout_rate), nn.Linear(in_features, out_features)]

    stages = [
        nn.LayerNorm(embed_dim),
        *dropout_then_linear(embed_dim, hidden_dim),
        nn.GELU(),
        *dropout_then_linear(hidden_dim, num_classes),
    ]
    return nn.Sequential(*stages)
def load_classifier_head(checkpoint_path, device="cpu"):
    """Load an NB09 classifier head from a checkpoint.

    Checkpoint dict expected keys:
        head_state_dict: state dict for the Sequential produced by
            ``build_classifier_head``; keys may carry a ``head.`` prefix.
        temperature: calibration temperature from NB09 (a Python number or
            a one-element tensor).
        embed_dim (optional): int — inferred from weights if absent.
        hidden_dim (optional): int — inferred from weights if absent.
    The number of output classes is always inferred from the final Linear
    layer, so heads that are not 7-class load correctly as well.

    Args:
        checkpoint_path: anything ``torch.load`` accepts (path or file-like).
        device: used as ``map_location`` and as the device of the returned head.

    Returns:
        (head: nn.Sequential in eval mode, temperature: float)

    Raises:
        KeyError: if ``head_state_dict``/``temperature`` is missing, or a
            weight needed to infer a dimension is absent.
        RuntimeError: if the state dict does not fit the rebuilt head.
    """
    # NOTE(review): weights_only=False unpickles arbitrary Python objects —
    # only load checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    sd = ckpt["head_state_dict"]

    # Strip 'head.' prefix if NB09 saved with module wrapper
    if any(k.startswith("head.") for k in sd):
        sd = {k.removeprefix("head."): v for k, v in sd.items()}

    # Sequential indices from build_classifier_head:
    #   0 = LayerNorm(embed_dim)        -> weight shape (embed_dim,)
    #   2 = Linear(embed_dim, hidden)   -> weight shape (hidden_dim, embed_dim)
    #   5 = Linear(hidden, num_classes) -> weight shape (num_classes, hidden_dim)
    embed_dim = ckpt.get("embed_dim") or sd["0.weight"].shape[0]
    hidden_dim = ckpt.get("hidden_dim") or sd["2.weight"].shape[0]
    # Previously fixed at the default of 7, which made load_state_dict fail
    # for heads trained with a different number of classes.
    num_classes = sd["5.weight"].shape[0]

    head = build_classifier_head(embed_dim, hidden_dim, num_classes=num_classes)
    head.load_state_dict(sd)
    head = head.to(device).eval()
    # Normalize to a plain float so callers never receive a device tensor.
    return head, float(ckpt["temperature"])
| # --------------------------------------------------------------------------- | |
| # Full classifier (used by notebooks and MedSigLIP-only fallback) | |
| # --------------------------------------------------------------------------- | |
class MedSigLIPClassifier(nn.Module):
    """MedSigLIP-448 vision encoder followed by an MLP classification head.

    Args:
        num_classes: number of output logits (default 7).
        dropout_rate: dropout probability used inside the head.
        freeze_encoder: if True, encoder parameters get ``requires_grad=False``
            and the encoder forward pass runs without autograd tracking. If
            False, gradients flow into the encoder so it can be fine-tuned.
    """

    def __init__(self, num_classes=7, dropout_rate=0.3, freeze_encoder=True):
        super().__init__()
        from transformers import AutoModel

        full_model = AutoModel.from_pretrained(
            "google/medsiglip-448",
            torch_dtype=torch.float32,
        )
        self.vision_model = full_model.vision_model
        self.embed_dim = full_model.config.vision_config.hidden_size
        del full_model  # keep only the vision tower
        if freeze_encoder:
            for param in self.vision_model.parameters():
                param.requires_grad = False
        # Same layout (and state-dict keys) as the NB09 heads.
        self.classifier = build_classifier_head(
            self.embed_dim, 512, num_classes=num_classes, dropout_rate=dropout_rate
        )

    def forward(self, pixel_values):
        """Classify a batch of preprocessed images.

        Args:
            pixel_values: image tensor passed to the vision encoder as
                ``pixel_values`` (presumably (B, 3, 448, 448) — confirm
                against the processor).

        Returns:
            Logits of shape (B, num_classes).
        """
        # Only skip autograd for the encoder when none of its parameters are
        # trainable. An unconditional no_grad() would make freeze_encoder=False
        # (or trainable params added after construction) silently untrainable.
        encoder_trainable = any(
            p.requires_grad for p in self.vision_model.parameters()
        )
        with torch.set_grad_enabled(torch.is_grad_enabled() and encoder_trainable):
            outputs = self.vision_model(pixel_values=pixel_values)

        if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
            features = outputs.pooler_output
        else:
            # Fall back to mean-pooling over the token axis.
            features = outputs.last_hidden_state.mean(dim=1)
        return self.classifier(features)
class BinaryGateHead(nn.Module):
    """Single-logit malignancy gate on top of MedSigLIP vision features.

    Expects the vision encoder's pooler_output (``embed_dim`` defaults to
    1152) and emits one logit per sample: positive = malignant,
    negative = benign.

    This is scaffolding for NB13 LoRA fine-tuning. The bridge path does not
    use it; it sums the 7-class malignant probabilities instead.
    """

    def __init__(self, embed_dim=1152, hidden_dim=256, dropout_rate=0.3):
        super().__init__()
        # LayerNorm -> Dropout -> Linear -> GELU -> Dropout -> Linear(hidden, 1)
        stages = [nn.LayerNorm(embed_dim)]
        stages += [nn.Dropout(dropout_rate), nn.Linear(embed_dim, hidden_dim)]
        stages.append(nn.GELU())
        stages += [nn.Dropout(dropout_rate), nn.Linear(hidden_dim, 1)]
        self.gate = nn.Sequential(*stages)

    def forward(self, features):
        """Score a batch of encoder features.

        Args:
            features: (B, embed_dim) pooler_output from the vision encoder.

        Returns:
            (B,) logits, one per sample.
        """
        column = self.gate(features)  # (B, 1)
        return column.squeeze(-1)