File size: 2,141 Bytes
604f286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""Model factory for binary classification: ResNet-50, DenseNet-121, ViT-B/16."""

import timm
import torch.nn as nn


# Map public model-name keys to the timm identifiers used to instantiate them.
# Keys and values coincide today; the indirection keeps the public names stable
# if the underlying timm identifiers ever need to change.
SUPPORTED_MODELS = {
    "resnet50": "resnet50",
    "densenet121": "densenet121",
    "vit_base_patch16_224": "vit_base_patch16_224",
}


def _replace_relu_with_gelu(module):
    for name, child in module.named_children():
        if isinstance(child, nn.ReLU):
            setattr(module, name, nn.GELU())
        else:
            _replace_relu_with_gelu(child)


def create_model(model_name="resnet50", pretrained=True, dropout=0.3, modified=False):
    """Create a single-logit (sigmoid-head) binary classification model.

    Args:
        model_name: One of 'resnet50', 'densenet121', 'vit_base_patch16_224'.
        pretrained: Use ImageNet-pretrained weights.
        dropout: Dropout rate before the final classifier (forwarded to timm
            as ``drop_rate``; also used in the modified MLP head).
        modified: If True, replace ReLU with GELU in ResNet-50 and swap its
            linear head for a two-layer MLP head. Ignored for other models.

    Returns:
        model: nn.Module with a single-output (sigmoid) head.

    Raises:
        ValueError: If *model_name* is not in ``SUPPORTED_MODELS``.
    """
    if model_name not in SUPPORTED_MODELS:
        raise ValueError(f"Unknown model: {model_name}. Choose from {list(SUPPORTED_MODELS)}")

    model = timm.create_model(
        SUPPORTED_MODELS[model_name],
        pretrained=pretrained,
        num_classes=1,  # single logit, paired with a sigmoid/BCE-with-logits loss
        drop_rate=dropout,
    )

    # Optional architecture tweak: GELU activations plus a small MLP head.
    # Only defined for ResNet-50, whose timm classifier attribute is ``fc``.
    if modified and model_name == "resnet50":
        _replace_relu_with_gelu(model)
        in_features = model.get_classifier().in_features
        hidden = in_features // 2
        model.fc = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )

    return model


def freeze_backbone(model):
    """Disable gradients on every parameter outside the classification head."""
    # Identity-based membership: parameters are compared by object, not value.
    head_param_ids = {id(p) for p in model.get_classifier().parameters()}
    for p in model.parameters():
        if id(p) in head_param_ids:
            continue  # leave the classifier trainable
        p.requires_grad = False


def unfreeze_backbone(model):
    """Re-enable gradients on every parameter of *model*."""
    for p in model.parameters():
        p.requires_grad_(True)