Spaces:
Running
Running
File size: 1,548 Bytes
cb92718 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 | import torch
import torch.nn as nn
# Dropout probability passed to every nn.Dropout layer in DermFoundationMLPHead.
DROPOUT = 0.6
class DermFoundationMLPHead(nn.Sequential):
    """
    MLP classification head applied to Derm Foundation embeddings.

    Layer stack (nn.Sequential index in brackets):
        [0] Linear(input_dim, 512) -> [1] ReLU -> [2] Dropout(DROPOUT)
        [3] Linear(512, 256)       -> [4] ReLU -> [5] Dropout(DROPOUT)
        [6] Linear(256, 128)       -> [7] ReLU -> [8] Dropout(DROPOUT)
        [9] Linear(128, num_classes)

    Args:
        input_dim: Number of input features of the first Linear layer.
        num_classes: Number of output features of the final Linear layer.
    """

    def __init__(self, input_dim: int, num_classes: int):
        # Consecutive pairs of widths define the three hidden Linear layers.
        widths = (input_dim, 512, 256, 128)

        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            # Modules are created in the same order they appear in the stack,
            # so Sequential indices (and therefore state-dict keys) are 0..8.
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
            stack.append(nn.Dropout(DROPOUT))

        # Output projection lands at index 9; no activation or dropout follows.
        stack.append(nn.Linear(widths[-1], num_classes))

        super().__init__(*stack)
def build_mlp_head_from_checkpoint(
    checkpoint_path: str,
    device: torch.device,
) -> tuple[nn.Module, dict]:
    """
    Rebuild a DermFoundationMLPHead from a saved checkpoint
    (e.g. derm_foundation_mlp_head.pt).

    Expected checkpoint format:
        {
            "model_state_dict": model.state_dict(),
            ...
        }

    The head's input width and class count are not stored separately; they
    are read back from the shapes of the first ("0.weight") and last
    ("9.weight") Linear weights in the state dict.

    Args:
        checkpoint_path: Path of the checkpoint file passed to torch.load.
        device: Device the tensors are mapped to and the head is moved to.

    Returns:
        A ``(head, checkpoint)`` tuple: the head with weights loaded and
        switched to eval mode, and the full loaded checkpoint object so
        callers can read any extra entries stored next to the weights.

    Raises:
        KeyError: If the checkpoint has no "model_state_dict" entry, or the
            state dict lacks the "0.weight" / "9.weight" tensors.
        RuntimeError: From ``load_state_dict(strict=True)`` when the stored
            keys or shapes do not match the head.
    """
    # NOTE(review, security): torch.load deserializes via pickle. Depending on
    # the installed PyTorch version's default for `weights_only`, loading an
    # untrusted file can execute arbitrary code -- only load trusted
    # checkpoints. The call is intentionally left unchanged here.
    checkpoint = torch.load(
        checkpoint_path,
        map_location=device,
    )

    try:
        state_dict = checkpoint["model_state_dict"]
    except KeyError:
        raise KeyError(
            f"Checkpoint {checkpoint_path!r} has no 'model_state_dict' entry; "
            "expected a dict of the form {'model_state_dict': ..., ...}"
        ) from None

    # Keys follow nn.Sequential indexing: 0 is the first Linear, 9 the last.
    first_weight_key = "0.weight"
    last_weight_key = "9.weight"
    absent = [
        key for key in (first_weight_key, last_weight_key)
        if key not in state_dict
    ]
    if absent:
        raise KeyError(
            f"State dict in {checkpoint_path!r} is missing {absent}; "
            "it does not look like a DermFoundationMLPHead state dict"
        )

    # Linear weights are stored as (out_features, in_features).
    input_dim = int(state_dict[first_weight_key].shape[1])
    num_classes = int(state_dict[last_weight_key].shape[0])

    head = DermFoundationMLPHead(
        input_dim=input_dim,
        num_classes=num_classes,
    ).to(device)

    # strict=True: fail loudly on any missing/unexpected key or shape mismatch.
    head.load_state_dict(state_dict, strict=True)
    # Inference mode: disables the Dropout layers.
    head.eval()
    return head, checkpoint