import torch
import torch.nn as nn

DROPOUT = 0.6
# Default hidden-layer widths of the Derm Foundation MLP head.
HIDDEN_DIMS: tuple[int, ...] = (512, 256, 128)


class DermFoundationMLPHead(nn.Sequential):
    """
    MLP head used after Derm Foundation embeddings.

    Default architecture (hidden_dims=(512, 256, 128), dropout=0.6):
        Linear(input_dim, 512) -> ReLU -> Dropout(0.6)
        Linear(512, 256)       -> ReLU -> Dropout(0.6)
        Linear(256, 128)       -> ReLU -> Dropout(0.6)
        Linear(128, num_classes)

    Each hidden block contributes three modules (Linear, ReLU, Dropout), so
    with the defaults the Linear layers sit at Sequential indices 0, 3, 6, 9.
    These indices match the state-dict keys of checkpoints saved from the
    original fixed head.

    Args:
        input_dim: Size of the incoming embedding vector (in_features of the
            first Linear).
        num_classes: Number of output logits.
        hidden_dims: Widths of the hidden Linear layers, in order.
        dropout: Dropout probability applied after every hidden ReLU.
    """

    def __init__(
        self,
        input_dim: int,
        num_classes: int,
        *,
        hidden_dims: tuple[int, ...] = HIDDEN_DIMS,
        dropout: float = DROPOUT,
    ):
        layers: list[nn.Module] = []
        in_features = input_dim
        for width in hidden_dims:
            layers += [nn.Linear(in_features, width), nn.ReLU(), nn.Dropout(dropout)]
            in_features = width
        layers.append(nn.Linear(in_features, num_classes))
        super().__init__(*layers)


def _linear_weight_shapes(state_dict: dict) -> list[torch.Size]:
    """Return the shapes of top-level Linear weights, ordered by Sequential index."""
    indexed = []
    for key, tensor in state_dict.items():
        prefix, sep, leaf = key.partition(".")
        # Only Linear layers carry a 2-D "<index>.weight" entry in this head;
        # ReLU and Dropout have no parameters.
        if sep and leaf == "weight" and prefix.isdigit() and tensor.dim() == 2:
            indexed.append((int(prefix), tensor.shape))
    indexed.sort(key=lambda item: item[0])
    return [shape for _, shape in indexed]


def build_mlp_head_from_checkpoint(
    checkpoint_path: str,
    device: torch.device,
) -> tuple[nn.Module, dict]:
    """
    Load derm_foundation_mlp_head.pt and rebuild the head in eval mode.

    Expected checkpoint format:
        {
            "model_state_dict": model.state_dict(),
            ...
        }

    Input size, output size and hidden widths are all inferred from the
    Linear weights in the state dict. The state dict is then loaded strictly,
    so any mismatch with the rebuilt architecture raises.

    Args:
        checkpoint_path: Path to the checkpoint file.
        device: Device the checkpoint tensors and the head are mapped to.

    Returns:
        A tuple ``(head, checkpoint)``: the loaded head in eval mode, and the
        full checkpoint object as returned by ``torch.load``.

    Raises:
        TypeError: If the checkpoint is not a dict.
        KeyError: If "model_state_dict" is missing, or the state dict holds
            no "<index>.weight" Linear entries.
        RuntimeError: From ``load_state_dict`` if keys or shapes do not match.
    """
    # SECURITY: torch.load unpickles the file. Only load trusted checkpoints.
    # Depending on the installed PyTorch version, weights_only may default to
    # True (restricted unpickling) or False (arbitrary objects).
    checkpoint = torch.load(
        checkpoint_path,
        map_location=device,
    )
    if not isinstance(checkpoint, dict):
        raise TypeError(
            f"Expected checkpoint dict in {checkpoint_path!r}, "
            f"got {type(checkpoint).__name__}"
        )
    if "model_state_dict" not in checkpoint:
        raise KeyError(
            f"Checkpoint {checkpoint_path!r} has no 'model_state_dict' "
            f"(keys: {sorted(map(str, checkpoint))})"
        )
    state_dict = checkpoint["model_state_dict"]

    shapes = _linear_weight_shapes(state_dict)
    if not shapes:
        raise KeyError(
            f"No '<index>.weight' Linear entries in model_state_dict of "
            f"{checkpoint_path!r}"
        )
    # nn.Linear weights are (out_features, in_features).
    input_dim = int(shapes[0][1])
    num_classes = int(shapes[-1][0])
    hidden_dims = tuple(int(shape[0]) for shape in shapes[:-1])

    head = DermFoundationMLPHead(
        input_dim=input_dim,
        num_classes=num_classes,
        hidden_dims=hidden_dims,
    ).to(device)
    head.load_state_dict(state_dict, strict=True)
    head.eval()
    return head, checkpoint