File size: 942 Bytes
c62c87b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import torchvision
from torch import nn

def create_model(num_classes: int = 6, seed: int = 1):
    """Create an EfficientNet-B2 model with a frozen backbone and a new head.

    Loads torchvision's EfficientNet-B2 with its default pretrained weights,
    freezes every pretrained parameter, and replaces the classifier with a
    fresh ``Dropout -> Linear`` head. The head is built after freezing, so its
    parameters are the only ones left with ``requires_grad=True``.

    Args:
        num_classes: Number of output units of the new classifier head.
        seed: Seed for the CPU RNG used while initialising the new head, so
            its initial weights are reproducible. The caller's CPU RNG state
            is restored afterwards.

    Returns:
        A ``(model, transform)`` tuple, where ``transform`` is the
        preprocessing pipeline bundled with the pretrained weights.
    """
    # Pretrained backbone plus the preprocessing its weights expect.
    weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    model = torchvision.models.efficientnet_b2(weights=weights)
    transform = weights.transforms()

    # Freeze the pretrained parameters so only the new head is trained.
    for param in model.parameters():
        param.requires_grad = False

    # Read the head's input width from the stock (Dropout, Linear) classifier
    # rather than hard-coding 1408.
    in_features = model.classifier[1].in_features

    # Seed only the head's initialisation: fork_rng restores the CPU RNG state
    # on exit, so the caller's global random stream is left untouched.
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(seed)
        model.classifier = nn.Sequential(
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(in_features=in_features, out_features=num_classes),
        )

    return model, transform