File size: 2,722 Bytes
397dad3
 
260781f
397dad3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260781f
 
 
 
 
397dad3
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download

class SkinCNN(nn.Module):
    """Small convolutional classifier producing ``num_classes`` scores (7 by default).

    Architecture: a four-stage 3x3-conv feature extractor followed by a
    five-layer fully connected head.

    Input is expected as ``(N, 3, H, W)``. The layer comments assume 28x28
    images. Four 2x2 max-pools shrink 28 -> 14 -> 7 -> 3 -> 1, so the
    flattened feature vector has exactly 256 entries, which is what the head's
    first ``Linear`` expects. Only inputs that also pool down to 1x1 (16 to 31
    pixels per side) are shape-compatible.

    Submodule order is significant. ``load_model`` restores a saved
    ``state_dict`` whose keys (``features.<i>.*``, ``classifier.<i>.*``) are
    tied to each layer's position inside the two ``nn.Sequential`` containers.
    Keep the sequence of layers below stable.
    """

    def __init__(self, num_classes: int = 7):
        super().__init__()

        # Feature extractor. Each stage is described by:
        #   (output channels, number of conv+ReLU pairs, BatchNorm2d after pool?)
        # Spatial size per stage for a 28x28 input: 28->14, 14->7, 7->3, 3->1.
        # Note the BatchNorm comes *after* the pool, and the last stage has none.
        stages = (
            (32, 1, True),
            (64, 2, True),
            (128, 2, True),
            (256, 2, False),
        )
        conv_layers: list[nn.Module] = []
        channels = 3
        for width, conv_count, normalize in stages:
            for _ in range(conv_count):
                conv_layers.append(nn.Conv2d(channels, width, kernel_size=3, padding=1))
                conv_layers.append(nn.ReLU())
                channels = width
            conv_layers.append(nn.MaxPool2d(kernel_size=2))
            if normalize:
                conv_layers.append(nn.BatchNorm2d(width))
        self.features = nn.Sequential(*conv_layers)

        # Classifier head: flatten plus dropout on the input, then four
        # Linear -> ReLU -> BatchNorm1d steps (256 -> 256 -> 128 -> 64 -> 32),
        # then a final Linear projection to ``num_classes``.
        dims = (256 * 1 * 1, 256, 128, 64, 32)
        head_layers: list[nn.Module] = [nn.Flatten(), nn.Dropout(0.2)]
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            head_layers.extend((nn.Linear(d_in, d_out), nn.ReLU(), nn.BatchNorm1d(d_out)))
        head_layers.append(nn.Linear(dims[-1], num_classes))
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images to unnormalized class scores (no softmax applied)."""
        return self.classifier(self.features(x))


def load_model(
    weights_path: str = "model.pth",
    device: str | torch.device | None = None,
    *,
    repo_id: str = "iamhmh/derm-cnn-ham10000",
    revision: str | None = None,
) -> tuple[nn.Module, torch.device]:
    """Download pretrained ``SkinCNN`` weights from the Hugging Face Hub and load them.

    Args:
        weights_path: Name of the weights file *inside the Hub repository*
            (passed to ``hf_hub_download`` as ``filename``). This is not a
            local filesystem path.
        device: Target device, e.g. ``"cpu"``, ``"cuda"``, or a
            ``torch.device``. Defaults to CUDA when available, otherwise CPU.
        repo_id: Hub repository that hosts the weights.
        revision: Optional branch, tag, or commit hash to download from.
            ``None`` uses the Hub's default revision. Pinning a commit hash
            makes the fetched weights reproducible.

    Returns:
        ``(model, device)``. ``model`` is a 7-class ``SkinCNN`` with the
        downloaded weights, moved to ``device`` and switched to eval mode.
        ``device`` is the resolved ``torch.device``.

    Raises:
        RuntimeError: From ``load_state_dict`` if the checkpoint's keys or
            tensor shapes do not match ``SkinCNN(num_classes=7)``.
        Errors from ``hf_hub_download`` (network, missing file) and from
        ``torch.load`` (including a file that contains more than tensors and
        plain containers) propagate unchanged.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # hf_hub_download returns the path of the locally cached copy of the file.
    local_path = hf_hub_download(
        repo_id=repo_id,
        filename=weights_path,
        revision=revision,
    )

    model = SkinCNN(num_classes=7)
    # The checkpoint comes from a remote repository. Restrict unpickling to
    # tensors and primitive containers so a tampered file cannot execute
    # arbitrary code at load time.
    state_dict = torch.load(local_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()

    return model, device