File size: 1,620 Bytes
1007aeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import torchvision
import torch.nn as nn

def create_resnet50_model(num_classes: int = 2, seed: int = 42):
    """Create a frozen ResNet50 feature extractor with a fresh classifier head.

    The backbone is loaded with ImageNet-1K (V1) pretrained weights and every
    backbone parameter is frozen; only the newly created classifier head has
    trainable parameters.

    Args:
        num_classes (int, optional): Number of output features of the new
            classifier head. Defaults to 2.
        seed (int, optional): Value passed to ``torch.manual_seed`` right
            before the classifier head is created, so its initial weights are
            reproducible. Note that this reseeds torch's global RNG as a side
            effect. Defaults to 42.

    Returns:
        tuple: ``(model, transforms)`` where
            model (torch.nn.Module): ResNet50 whose ``fc`` layer is replaced by
                ``Dropout(p=0.3) -> Linear(2048-ish -> num_classes)``; all
                other parameters have ``requires_grad=False``.
            transforms (torchvision.transforms.Compose): Preprocessing pipeline
                (resize 256, center-crop 224, to-tensor, ImageNet mean/std
                normalization).
    """
    # 1. Build ResNet50 with pretrained weights in a single step. Loading the
    #    weights directly into the model we return avoids constructing a second,
    #    randomly initialised ResNet50 just to copy a state dict into it.
    #    IMAGENET1K_V1 is the checkpoint the legacy `pretrained=True` flag uses.
    weights_enum = getattr(torchvision.models, "ResNet50_Weights", None)
    if weights_enum is not None:
        # torchvision >= 0.13: `pretrained=` is deprecated in favour of `weights=`.
        model = torchvision.models.resnet50(weights=weights_enum.IMAGENET1K_V1)
    else:
        # Older torchvision only understands the legacy boolean flag.
        model = torchvision.models.resnet50(pretrained=True)

    # 2. Preprocessing pipeline (resize 256 / center-crop 224 / ImageNet stats),
    #    kept as an explicit Compose so the returned object type is unchanged.
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    # 3. Freeze all backbone parameters so only the new head is trainable.
    for param in model.parameters():
        param.requires_grad = False

    # 4. Replace the classifier head. Seeding immediately before creating the
    #    new layers makes the head's initial weights reproducible; the new
    #    layers' parameters require grad by default, unlike the frozen backbone.
    torch.manual_seed(seed)
    num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=num_features, out_features=num_classes),
    )

    return model, transforms