File size: 2,563 Bytes
4d48c23
 
6498090
4d48c23
 
6498090
4d48c23
3899ece
4e34900
6498090
3899ece
 
 
 
 
6498090
3899ece
 
 
 
 
6498090
3899ece
 
 
 
 
 
4e34900
 
 
3899ece
 
 
4d48c23
 
3899ece
 
 
 
 
4d48c23
3899ece
 
4d48c23
 
 
6498090
 
4e34900
 
6498090
4d48c23
 
 
 
 
 
4e34900
3899ece
a0f945e
4d48c23
3899ece
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import warnings

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download

class SimpleCNN(nn.Module):
    """Small convolutional classifier available in three size variants.

    Every variant stacks "3x3 conv -> ReLU -> 2x2 max-pool" stages (three for
    ``'f'`` and ``'c'``, four for ``'q'``) followed by ``fc1 -> ReLU ->
    dropout -> fc2``. The convolutions use padding=1/stride=1, so only the
    pools shrink the feature map. The flattened sizes baked into ``fc1``
    (28x28 after three pools, 14x14 after four) therefore correspond to a
    224x224 input. Presumably that is the expected resolution; TODO confirm.

    Args:
        model_type: Variant key, one of ``'f'``, ``'c'`` or ``'q'``.
        num_classes: Number of logits produced by ``fc2``.

    Raises:
        ValueError: If ``model_type`` is not a known variant.
    """

    # (variant key, conv output channels per stage, spatial side fed to fc1,
    #  fc1 width, dropout probability). Checked in this order with ``==``.
    _VARIANTS = (
        ('f', (16, 32, 64), 28, 256, 0.5),
        ('c', (32, 64, 128), 28, 512, 0.5),
        ('q', (64, 128, 256, 512), 14, 1024, 0.3),
    )

    def __init__(self, model_type='f', num_classes=6):
        super().__init__()
        self.num_classes = num_classes
        self.model_type = model_type

        for key, widths, side, hidden, drop_p in self._VARIANTS:
            if model_type == key:
                break
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        # Layers are registered in the same order as a hand-written
        # conv1..convN, fc1, dropout sequence, so state_dict keys and
        # parameter initialisation order stay stable.
        in_ch = 3
        for stage, out_ch in enumerate(widths, start=1):
            setattr(
                self,
                f"conv{stage}",
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            )
            in_ch = out_ch
        self.fc1 = nn.Linear(in_ch * side * side, hidden)
        self.dropout = nn.Dropout(drop_p)

        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc2 = nn.Linear(hidden, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``."""
        stages = [self.conv1, self.conv2, self.conv3]
        if self.model_type == 'q':
            stages.append(self.conv4)
        for conv in stages:
            x = self.pool(self.relu(conv(x)))
        # Flatten everything except the batch dimension for the dense head.
        x = x.view(x.size(0), -1)
        x = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(x)


def load_model(version='c', device='cpu'):
    """Download a Vbai-DPA 2.1 checkpoint and return it as an eval-mode SimpleCNN.

    Args:
        version: Variant letter passed to ``SimpleCNN`` after lower-casing
            (so ``'F'``, ``'C'``, ``'Q'`` are accepted too).
        device: Device the model is moved to and ``torch.load`` maps tensors to.

    Returns:
        A ``SimpleCNN`` with 6 output classes, weights loaded, in eval mode.

    Raises:
        ValueError: If ``version`` is not a known variant. This is raised
            before any download is attempted.

    Warns:
        UserWarning: If the checkpoint's keys do not match the model. Any
            missing layer would otherwise silently keep its random init
            because loading is non-strict.
    """
    model_type = version.lower()
    filename = f"Vbai-2.1{model_type}.pt"

    # Build the model first: an unknown variant fails fast with ValueError
    # instead of after a network round-trip for a file that does not exist.
    model = SimpleCNN(model_type=model_type, num_classes=6).to(device)

    weights_path = hf_hub_download(
        repo_id="Neurazum/Vbai-DPA-2.1",
        filename=filename,
        repo_type="model",
    )

    # NOTE(review): torch.load unpickles the downloaded file. On torch
    # versions where weights_only defaults to False, this can execute
    # arbitrary code from the remote repo. Consider weights_only=True once
    # the checkpoint format is confirmed to be a plain state_dict.
    state_dict = torch.load(weights_path, map_location=device)

    # strict=False is kept for compatibility with existing checkpoints, but
    # the mismatch report is no longer discarded.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"Checkpoint {filename!r} does not match SimpleCNN({model_type!r}): "
            f"missing keys={list(incompatible.missing_keys)}, "
            f"unexpected keys={list(incompatible.unexpected_keys)}",
            stacklevel=2,
        )
    model.eval()
    return model