File size: 5,520 Bytes
ac024f3
84842ba
ac024f3
84842ba
 
 
 
 
ac024f3
 
 
 
 
 
84842ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac024f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84842ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""
MedSigLIP Classifier and ensemble utilities for skin lesion triage.

Contains:
    - MedSigLIPClassifier: 7-class model (encoder + head, used by notebooks)
    - BinaryGateHead: Binary gate scaffold (NB13, not deployed)
    - load_medsig_encoder: Load just the frozen vision encoder
    - build_classifier_head / load_classifier_head: NB09 7-class heads
"""

import torch
import torch.nn as nn


# ---------------------------------------------------------------------------
# Encoder + head loaders (used by ensemble pipeline)
# ---------------------------------------------------------------------------

def load_medsig_encoder(device="cpu"):
    """Return only the frozen MedSigLIP-448 vision encoder and its width.

    This is a lighter alternative to MedSigLIPClassifier: no classifier head
    is created. The ensemble path uses it and loads its NB09 heads separately.

    Args:
        device: Target device for the encoder (anything accepted by ``.to``).

    Returns:
        (vision_model, embed_dim): the encoder in eval mode with every
        parameter's ``requires_grad`` set to False, and the encoder's
        hidden size read from the model config.
    """
    from transformers import AutoModel

    siglip = AutoModel.from_pretrained("google/medsiglip-448", torch_dtype=torch.float32)
    hidden_size = siglip.config.vision_config.hidden_size
    encoder = siglip.vision_model
    # Drop our reference to the full model; only the vision sub-module is kept.
    del siglip

    encoder = encoder.to(device)
    encoder.eval()
    # Module.requires_grad_ flips requires_grad on every parameter in one call.
    encoder.requires_grad_(False)
    return encoder, hidden_size


def build_classifier_head(embed_dim, hidden_dim, num_classes=7, dropout_rate=0.3):
    """Create the NB09-style MLP head used for 7-class lesion classification.

    Layout (positional indices matter, because checkpoints are keyed by them):
        0 LayerNorm(embed_dim) -> 1 Dropout -> 2 Linear(embed_dim, hidden_dim)
        -> 3 GELU -> 4 Dropout -> 5 Linear(hidden_dim, num_classes)

    MedSigLIPClassifier.classifier and the NB09 MedSigLIP-only / DermLIP-only
    heads share this structure.

    Args:
        embed_dim: Width of the incoming feature vectors.
        hidden_dim: Width of the hidden projection.
        num_classes: Number of output logits.
        dropout_rate: Dropout probability used at indices 1 and 4.

    Returns:
        nn.Sequential mapping (..., embed_dim) -> (..., num_classes).
    """
    projection = (
        nn.LayerNorm(embed_dim),
        nn.Dropout(dropout_rate),
        nn.Linear(embed_dim, hidden_dim),
    )
    output = (
        nn.GELU(),
        nn.Dropout(dropout_rate),
        nn.Linear(hidden_dim, num_classes),
    )
    return nn.Sequential(*projection, *output)


def load_classifier_head(checkpoint_path, device="cpu"):
    """Load an NB09 classifier head (7-class in practice) from a checkpoint.

    Checkpoint dict expected keys:
        head_state_dict: OrderedDict of weights; keys may carry a "head."
            prefix, which is stripped
        temperature: float (calibration temperature from NB09)
        embed_dim (optional): int — inferred from weights if absent
        hidden_dim (optional): int — inferred from weights if absent
        num_classes (optional): int — inferred from weights if absent

    Args:
        checkpoint_path: Path (or file-like object) accepted by ``torch.load``.
        device: Device the tensors are mapped to and the head is moved to.

    Returns:
        (head: nn.Sequential in eval mode, temperature: the stored value)

    Raises:
        KeyError: if ``head_state_dict`` / ``temperature`` are missing, or the
            weights needed to infer a missing dimension are absent.
        RuntimeError: from ``load_state_dict`` if the weights do not fit the
            inferred architecture.
    """
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects and
    # can execute code embedded in the file — only load trusted checkpoints.
    # The documented contents (tensors, ints, floats) should also load with
    # weights_only=True; consider switching once that is confirmed.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)

    sd = ckpt["head_state_dict"]
    # Strip 'head.' prefix if NB09 saved with module wrapper
    if any(k.startswith("head.") for k in sd):
        sd = {k.removeprefix("head."): v for k, v in sd.items()}

    # Positional keys follow build_classifier_head's Sequential:
    #   0 = LayerNorm(embed_dim)            -> weight shape (embed_dim,)
    #   2 = Linear(embed_dim, hidden_dim)   -> weight shape (hidden_dim, embed_dim)
    #   5 = Linear(hidden_dim, num_classes) -> weight shape (num_classes, hidden_dim)
    # num_classes must be inferred too; otherwise the builder's default of 7
    # makes load_state_dict fail for any head with a different output width.
    embed_dim = ckpt.get("embed_dim") or sd["0.weight"].shape[0]
    hidden_dim = ckpt.get("hidden_dim") or sd["2.weight"].shape[0]
    num_classes = ckpt.get("num_classes") or sd["5.weight"].shape[0]

    head = build_classifier_head(embed_dim, hidden_dim, num_classes=num_classes)
    head.load_state_dict(sd)
    head = head.to(device).eval()

    return head, ckpt["temperature"]


# ---------------------------------------------------------------------------
# Full classifier (used by notebooks and MedSigLIP-only fallback)
# ---------------------------------------------------------------------------

class MedSigLIPClassifier(nn.Module):
    """MedSigLIP-448 vision encoder with an MLP classification head.

    Args:
        num_classes: Number of logits produced by the head.
        dropout_rate: Dropout probability used inside the head.
        freeze_encoder: If True, every encoder parameter gets
            ``requires_grad=False`` and the encoder forward runs without
            autograd, so only the head trains. If False, gradients flow
            into the encoder as well.
    """

    def __init__(self, num_classes=7, dropout_rate=0.3, freeze_encoder=True):
        super().__init__()
        from transformers import AutoModel

        full_model = AutoModel.from_pretrained(
            "google/medsiglip-448",
            torch_dtype=torch.float32,
        )
        self.vision_model = full_model.vision_model
        self.embed_dim = full_model.config.vision_config.hidden_size
        del full_model

        if freeze_encoder:
            for param in self.vision_model.parameters():
                param.requires_grad = False

        # Same layer layout as build_classifier_head(self.embed_dim, 512, ...),
        # so the "classifier.<idx>.*" state_dict keys line up with that helper.
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Dropout(dropout_rate),
            nn.Linear(self.embed_dim, 512),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, num_classes),
        )

    def forward(self, pixel_values):
        """Encode ``pixel_values`` and return the head's logits.

        Uses the encoder's ``pooler_output`` when it is present and not None.
        Otherwise it mean-pools ``last_hidden_state`` over dim 1.
        """
        # A blanket no_grad here would make freeze_encoder=False (or any
        # trainable adapter added later) silently useless. Only disable
        # autograd when nothing in the encoder is trainable. The
        # is_grad_enabled() term keeps a caller's outer no_grad in force.
        encoder_trainable = any(p.requires_grad for p in self.vision_model.parameters())
        with torch.set_grad_enabled(torch.is_grad_enabled() and encoder_trainable):
            outputs = self.vision_model(pixel_values=pixel_values)
            features = getattr(outputs, "pooler_output", None)
            if features is None:
                features = outputs.last_hidden_state.mean(dim=1)
        return self.classifier(features)


class BinaryGateHead(nn.Module):
    """Single-logit malignancy gate on top of MedSigLIP vision features.

    Maps a pooled encoder feature vector (1152-d by default) to one logit:
    positive means malignant, negative means benign.

    Scaffold for NB13 LoRA fine-tuning. The bridge path does not use it; it
    sums the 7-class malignant probabilities instead.

    Args:
        embed_dim: Width of the incoming features.
        hidden_dim: Width of the hidden projection.
        dropout_rate: Dropout probability applied before each Linear layer.
    """

    def __init__(self, embed_dim=1152, hidden_dim=256, dropout_rate=0.3):
        super().__init__()
        # The attribute name ``gate`` and positional indices 0..5 define the
        # state_dict keys ("gate.0.weight", ...), so both are kept fixed.
        stages = [
            nn.LayerNorm(embed_dim),
            nn.Dropout(dropout_rate),
            nn.Linear(embed_dim, hidden_dim),
        ]
        stages.extend([nn.GELU(), nn.Dropout(dropout_rate), nn.Linear(hidden_dim, 1)])
        self.gate = nn.Sequential(*stages)

    def forward(self, features):
        """Score a batch of feature vectors.

        Args:
            features: (B, embed_dim) pooler_output from vision encoder.

        Returns:
            (B,) raw logits, one per sample (no sigmoid applied).
        """
        logits = self.gate(features)
        # Drop the trailing size-1 output dimension.
        return logits[..., 0]