File size: 1,463 Bytes
8478d83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45

import torch
import torch.nn as nn
import torchvision


# Create an EffNetB2 feature extractor
def create_effnet_b2(num_of_class: int=3,
                     transform: torchvision.transforms=None,
                     seed: int=42
                     ):
    """Creates an EfficientNetB2 feature extractor model.

    The pretrained base layers are frozen and the classifier head is replaced
    with a freshly initialised dropout + linear layer sized for
    ``num_of_class`` outputs.

    Args:
        num_of_class (int, optional): number of classes in the classifier head.
            Defaults to 3.
        transform (optional): image transform to pair with the model. It is
            not used or modified here; it is returned as given.
            Defaults to None.
        seed (int, optional): random seed value, set just before the new
            classifier head is created so its initial weights are
            reproducible. Defaults to 42.

    Returns:
        tuple: ``(model, transform)`` where ``model`` is the EffNetB2 feature
        extractor (torch.nn.Module) and ``transform`` is the argument that
        was passed in, unchanged.
    """

    # 1. Get the base model with pretrained weights
    # NOTE(review): newer torchvision releases deprecate `pretrained=True` in
    # favour of the `weights=` argument; kept as-is so older installs keep
    # working -- confirm against the pinned torchvision version.
    model = torchvision.models.efficientnet_b2(pretrained=True)

    # 2. Freeze the base model layers
    for param in model.parameters():
        param.requires_grad = False

    # Read the input width of the stock head before replacing it, rather than
    # hard-coding it (1408 for EfficientNet-B2).
    in_features = model.classifier[1].in_features

    # 3. Set the seeds
    torch.manual_seed(seed)

    # 4. Change the classifier head (created after freezing, so its
    #    parameters keep the default requires_grad=True)
    model.classifier = nn.Sequential(nn.Dropout(p=0.3, inplace=True),
                                     nn.Linear(in_features, num_of_class, bias=True)
                                    )

    return model, transform

# Example usage (create_effnet_b2 returns a (model, transform) tuple):
# mymodel, mytransform = create_effnet_b2(num_of_class=3,
#                                         transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]),
#                                         seed=42)
# print(mymodel)