File size: 2,172 Bytes
bf07f10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch.nn as nn
import timm
import torchvision.models as tvmodels

def get_model(name: str, num_classes: int, pretrained: bool = True):
    """Loads and adapts model architecture."""
    name = name.lower()
    
    if name.startswith('swin'):
        model = timm.create_model('swin_small_patch4_window7_224', pretrained=pretrained)
        if hasattr(model, 'reset_classifier'):
            model.reset_classifier(num_classes=num_classes)
        else:
            model.head = nn.Linear(model.head.in_features, num_classes)
        return model
    
    if name.startswith('convnext'):
        model = timm.create_model('convnext_tiny', pretrained=pretrained)
        if hasattr(model, 'reset_classifier'):
            model.reset_classifier(num_classes=num_classes)
        else:
            model.head.fc = nn.Linear(model.head.fc.in_features, num_classes)
        return model
    
    if name.startswith('densenet'):
        model = tvmodels.densenet169(pretrained=pretrained)
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
        return model
    
    if name.startswith('mobilenet'):
        model = timm.create_model('mobilenetv2_100', pretrained=pretrained)
        if hasattr(model, 'reset_classifier'):
            model.reset_classifier(num_classes=num_classes)
        else:
            model.classifier = nn.Linear(model.classifier.in_features, num_classes)
        return model
    
    if name.startswith('efficientnet'):
        model = timm.create_model('efficientnet_b0', pretrained=pretrained)
        if hasattr(model, 'reset_classifier'):
            model.reset_classifier(num_classes=num_classes)
        else:
            model.classifier = nn.Linear(model.classifier.in_features, num_classes)
        return model
    
    if name.startswith('maxvit'):
        model = timm.create_model('maxvit_tiny_tf_224', pretrained=pretrained)
        if hasattr(model, 'reset_classifier'):
            model.reset_classifier(num_classes=num_classes)
        else:
            model.head.fc = nn.Linear(model.head.fc.in_features, num_classes)
        return model
    
    raise ValueError(f'Unknown model: {name}')