Spaces:
Sleeping
Sleeping
File size: 5,520 Bytes
ac024f3 84842ba ac024f3 84842ba ac024f3 84842ba ac024f3 84842ba | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 | """
MedSigLIP Classifier and ensemble utilities for skin lesion triage.
Contains:
- MedSigLIPClassifier: 7-class model (encoder + head, used by notebooks)
- BinaryGateHead: Binary gate scaffold (NB13, not deployed)
- load_medsig_encoder: Load just the frozen vision encoder
- build_classifier_head / load_classifier_head: NB09 7-class heads
"""
import torch
import torch.nn as nn
# ---------------------------------------------------------------------------
# Encoder + head loaders (used by ensemble pipeline)
# ---------------------------------------------------------------------------
def load_medsig_encoder(device="cpu"):
    """Return the frozen MedSigLIP-448 vision tower and its embedding width.

    Cheaper than constructing ``MedSigLIPClassifier`` because no classifier
    head is built; the ensemble path pairs this encoder with NB09 heads that
    are loaded separately.

    Args:
        device: Device the encoder is moved to (default ``"cpu"``).

    Returns:
        Tuple ``(vision_model, embed_dim)``: the pretrained model's
        ``vision_model`` sub-module in eval mode with every parameter frozen,
        and ``config.vision_config.hidden_size`` of the pretrained model.
    """
    from transformers import AutoModel

    pretrained = AutoModel.from_pretrained(
        "google/medsiglip-448", torch_dtype=torch.float32
    )
    encoder = pretrained.vision_model
    width = pretrained.config.vision_config.hidden_size
    # Drop our handle on the full model; only the vision sub-module is kept.
    del pretrained

    encoder = encoder.to(device)
    encoder.eval()
    # Freeze every encoder parameter in one call.
    encoder.requires_grad_(False)
    return encoder, width
def build_classifier_head(embed_dim, hidden_dim, num_classes=7, dropout_rate=0.3):
    """Create the MLP classifier head used by the NB09 checkpoints.

    Layer order (the indices become the ``state_dict`` keys, so they matter):
        0 LayerNorm(embed_dim) -> 1 Dropout -> 2 Linear(embed_dim, hidden_dim)
        -> 3 GELU -> 4 Dropout -> 5 Linear(hidden_dim, num_classes)

    Same structure as ``MedSigLIPClassifier.classifier`` and the heads trained
    in NB09 for both MedSigLIP-only and DermLIP-only features.

    Args:
        embed_dim: Width of the incoming feature vector.
        hidden_dim: Width of the hidden layer.
        num_classes: Number of output logits (default 7).
        dropout_rate: Probability used by both Dropout layers.

    Returns:
        An ``nn.Sequential`` mapping ``(..., embed_dim)`` to ``(..., num_classes)``.
    """
    layers = [
        nn.LayerNorm(embed_dim),
        nn.Dropout(dropout_rate),
        nn.Linear(embed_dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout_rate),
        nn.Linear(hidden_dim, num_classes),
    ]
    return nn.Sequential(*layers)
def load_classifier_head(checkpoint_path, device="cpu"):
    """Load an NB09 classifier head and its calibration temperature.

    Expected checkpoint dict keys:
        head_state_dict: weights for the ``build_classifier_head`` layout,
            optionally with every key prefixed by ``"head."``.
        temperature: calibration temperature from NB09 (returned as stored).
        embed_dim (optional): int; inferred from the LayerNorm weight if absent.
        hidden_dim (optional): int; inferred from the first Linear if absent.

    The number of output classes is read from the final Linear layer's weight,
    so heads trained with a class count other than 7 also load. When that
    weight is absent, 7 is assumed and ``load_state_dict`` reports the problem.

    Args:
        checkpoint_path: Anything accepted by ``torch.load``.
        device: Device tensors are mapped onto and the head is placed on.

    Returns:
        ``(head, temperature)``: the head (``nn.Sequential``) in eval mode on
        ``device``, and ``ckpt["temperature"]``.

    Raises:
        KeyError: if ``head_state_dict`` or ``temperature`` is missing, or a
            dimension is neither stored nor inferable from the weights.
    """
    # NOTE(review): weights_only=False unpickles arbitrary Python objects, so
    # only point this at trusted, locally produced checkpoints. Consider
    # weights_only=True once the checkpoint is confirmed to hold only
    # tensors and plain Python types.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    sd = ckpt["head_state_dict"]

    # Strip a 'head.' prefix if NB09 saved the head inside a wrapper module,
    # so keys match the bare Sequential ("0.weight", "2.weight", ...).
    if any(k.startswith("head.") for k in sd):
        sd = {k.removeprefix("head."): v for k, v in sd.items()}

    # Prefer dims stored in the checkpoint; otherwise read them from weight
    # shapes. LayerNorm weight is (embed_dim,); Linear weight is (out, in).
    embed_dim = ckpt.get("embed_dim") or sd["0.weight"].shape[0]
    hidden_dim = ckpt.get("hidden_dim") or sd["2.weight"].shape[0]
    # Infer the class count instead of hard-wiring 7. If the key is missing,
    # keep the old default so load_state_dict raises exactly as before.
    num_classes = sd["5.weight"].shape[0] if "5.weight" in sd else 7

    head = build_classifier_head(embed_dim, hidden_dim, num_classes=num_classes)
    head.load_state_dict(sd)
    head = head.to(device).eval()
    return head, ckpt["temperature"]
# ---------------------------------------------------------------------------
# Full classifier (used by notebooks and MedSigLIP-only fallback)
# ---------------------------------------------------------------------------
class MedSigLIPClassifier(nn.Module):
    """MedSigLIP-448 vision encoder followed by an MLP classification head.

    Args:
        num_classes: Number of output logits (default 7).
        dropout_rate: Probability used by both Dropout layers in the head.
        freeze_encoder: If True, every encoder parameter is created with
            ``requires_grad=False`` and the encoder runs without autograd.
        hidden_dim: Width of the head's hidden layer (default 512, the value
            previously hard-coded, so existing checkpoints still match).
    """

    def __init__(self, num_classes=7, dropout_rate=0.3, freeze_encoder=True,
                 hidden_dim=512):
        super().__init__()
        from transformers import AutoModel

        full_model = AutoModel.from_pretrained(
            "google/medsiglip-448",
            torch_dtype=torch.float32,
        )
        self.vision_model = full_model.vision_model
        self.embed_dim = full_model.config.vision_config.hidden_size
        # Only the vision sub-module is kept on this instance.
        del full_model

        if freeze_encoder:
            for param in self.vision_model.parameters():
                param.requires_grad = False

        # Same layout as build_classifier_head, so state_dict keys
        # (classifier.0.*, classifier.2.*, classifier.5.*) line up.
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Dropout(dropout_rate),
            nn.Linear(self.embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, pixel_values):
        """Return class logits for a batch of preprocessed images.

        Args:
            pixel_values: Image tensor, forwarded unchanged to the vision
                encoder as ``pixel_values=``.

        Returns:
            The classifier's logits computed from the pooled image features.
        """
        # Build an autograd graph through the encoder only when some encoder
        # parameter is trainable (freeze_encoder=False, or adapters added
        # later). An unconditional no_grad() here silently blocks encoder
        # fine-tuning. Never *enable* grad when the caller has disabled it.
        encoder_trainable = any(
            p.requires_grad for p in self.vision_model.parameters()
        )
        with torch.set_grad_enabled(torch.is_grad_enabled() and encoder_trainable):
            outputs = self.vision_model(pixel_values=pixel_values)

        # Prefer the encoder's pooled embedding. Otherwise mean over dim 1 of
        # last_hidden_state (presumably the token axis).
        pooled = getattr(outputs, "pooler_output", None)
        if pooled is not None:
            features = pooled
        else:
            features = outputs.last_hidden_state.mean(dim=1)
        return self.classifier(features)
class BinaryGateHead(nn.Module):
    """Single-logit malignancy gate on top of MedSigLIP image embeddings.

    Intended input is the vision encoder's pooled embedding (1152-d by
    default). The output is one logit per sample: positive means malignant,
    negative means benign.

    Scaffold for the NB13 LoRA fine-tuning experiments. The bridge path does
    not use it; it sums the 7-class malignant probabilities instead.

    Args:
        embed_dim: Width of the incoming feature vector (default 1152).
        hidden_dim: Width of the hidden layer (default 256).
        dropout_rate: Probability used by both Dropout layers.
    """

    def __init__(self, embed_dim=1152, hidden_dim=256, dropout_rate=0.3):
        super().__init__()
        # The attribute name ``gate`` and this layer order fix the state_dict
        # keys (gate.0.*, gate.2.*, gate.5.*).
        normalize = [nn.LayerNorm(embed_dim), nn.Dropout(dropout_rate)]
        hidden = [
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
        ]
        project = [nn.Linear(hidden_dim, 1)]
        self.gate = nn.Sequential(*normalize, *hidden, *project)

    def forward(self, features):
        """Score a batch of embeddings.

        Args:
            features: ``(B, embed_dim)`` pooled embeddings from the encoder.

        Returns:
            ``(B,)`` tensor of logits, one per sample.
        """
        logits = self.gate(features)  # trailing size-1 logit dimension
        return logits.squeeze(-1)
|