import torch
import torch.nn as nn


DROPOUT = 0.6


class DermFoundationMLPHead(nn.Sequential):
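    """
    MLP classification head applied on top of Derm Foundation embeddings.

    Architecture:
        Linear(input_dim, 512) -> ReLU -> Dropout(0.6)
        Linear(512, 256)       -> ReLU -> Dropout(0.6)
        Linear(256, 128)       -> ReLU -> Dropout(0.6)
        Linear(128, num_classes)

    Because this is an nn.Sequential, state_dict keys are prefixed with the
    layer index; the four Linear layers sit at indices 0, 3, 6 and 9.
    """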

    def __init__(self, input_dim: int, num_classes: int):
        super().__init__(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Dropout(DROPOUT),

            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(DROPOUT),

            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(DROPOUT),

            nn.Linear(128, num_classes),
        )
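

# Hypothetical helper, not part of the original module: a closed-form count
# of the head's trainable parameters, as a worked check of the architecture.
# Each nn.Linear(d_in, d_out) holds a (d_out, d_in) weight plus a d_out bias;
# ReLU and Dropout add no parameters. The hidden sizes mirror the class
# docstring above; the function name is illustrative only.
def mlp_head_num_parameters(input_dim: int, num_classes: int) -> int:
    dims = (input_dim, 512, 256, 128, num_classes)
    # Sum weight + bias sizes over consecutive (d_in, d_out) pairs.
    return sum(d_in * d_out + d_out for d_in, d_out in zip(dims[:-1], dims[1:]))


# Example: mlp_head_num_parameters(10, 3) == 170243, which should equal
# sum(p.numel() for p in DermFoundationMLPHead(10, 3).parameters()).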


def build_mlp_head_from_checkpoint(
    checkpoint_path: str,
    device: torch.device,
) -> tuple[nn.Module, dict]:
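    """
    Load a trained MLP head checkpoint (e.g. derm_foundation_mlp_head.pt).

    Expected checkpoint format:
        {
            "model_state_dict": model.state_dict(),
            ...
        }

    The input and output sizes are inferred from the saved weights; the head
    is moved to ``device`` and switched to eval mode (dropout disabled).

    Returns:
        (head, checkpoint): the loaded head and the full checkpoint dict, so
        any extra entries saved alongside the weights remain accessible.
    """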
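    # Note: from PyTorch 2.6 onward torch.load defaults to weights_only=True,
    # which only unpickles tensors and basic Python types. If the extra
    # checkpoint entries hold other objects, loading a trusted file requires
    # passing weights_only=False explicitly.
    checkpoint = torch.load(
        checkpoint_path,
        map_location=device,
    )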

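    state_dict = checkpoint["model_state_dict"]

    # nn.Sequential names parameters by layer index: the first Linear is at
    # index 0 and the output Linear at index 9 (each hidden Linear is followed
    # by ReLU and Dropout). nn.Linear stores its weight as
    # (out_features, in_features).
    input_dim = int(state_dict["0.weight"].shape[1])
    num_classes = int(state_dict["9.weight"].shape[0])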

    head = DermFoundationMLPHead(
        input_dim=input_dim,
        num_classes=num_classes,
    ).to(device)

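    # strict=True raises on any missing or unexpected key, so a checkpoint
    # from a different architecture fails loudly instead of loading partially.
    head.load_state_dict(state_dict, strict=True)
    # eval() disables dropout for inference.
    head.eval()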

    return head, checkpoint
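

# Hypothetical counterpart to build_mlp_head_from_checkpoint, not part of the
# original module (the name and keyword arguments are illustrative). A sketch
# of writing a checkpoint in the format the loader above expects: the head's
# weights under "model_state_dict", plus any extra metadata the caller passes.
def save_mlp_head_checkpoint(head: nn.Module, path, **extra) -> None:
    """Save ``head`` so that build_mlp_head_from_checkpoint can rebuild it."""
    if "model_state_dict" in extra:
        raise ValueError("'model_state_dict' is reserved for the head's weights")
    # Keeping extra entries to tensors and plain Python types lets the file
    # load under torch.load(..., weights_only=True) as well.
    torch.save({"model_state_dict": head.state_dict(), **extra}, path)


# Example (hypothetical path):
#     save_mlp_head_checkpoint(head, "derm_foundation_mlp_head.pt", class_names=names)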