File size: 940 Bytes
4dbe437
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36


import torch
import torchvision
from torch import nn

def create_effnetb2_model(num_classes: int = 3,
                          seed: int = 42):
  """Create a frozen pretrained EfficientNet-B2 with a fresh classifier head.

  Loads torchvision's EfficientNet-B2 with ``EfficientNet_B2_Weights.DEFAULT``,
  sets ``requires_grad = False`` on every pretrained parameter, then replaces
  ``model.classifier`` with a new ``Dropout -> Linear`` head that produces
  ``num_classes`` outputs. The new head's parameters are created after the
  freeze, so they are the only trainable ones.

  Args:
    num_classes: Number of output units of the new Linear layer.
    seed: Passed to ``torch.manual_seed`` immediately before the new head is
      built, so the head's initial weights are reproducible. Note that this
      reseeds torch's global RNG as a side effect.

  Returns:
    A tuple ``(model, transforms)``: the modified EfficientNet-B2 and the
    preprocessing transforms obtained from the same weights object.
  """
  # 1. Set up the pretrained EffNetB2 weights.
  effnetb2_weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT

  # 2. Get the preprocessing transforms associated with those weights.
  effnetb2_transforms = effnetb2_weights.transforms()

  # 3. Set up the pretrained model instance.
  effnetb2 = torchvision.models.efficientnet_b2(weights=effnetb2_weights)

  # 4. Freeze all pretrained layers (this stops them from training).
  for param in effnetb2.parameters():
    param.requires_grad = False

  # Read the head's input width from the existing classifier's Linear layer
  # instead of hard-coding 1408, so it always matches the backbone output.
  # NOTE(review): relies on torchvision's classifier being
  # Sequential(Dropout, Linear).
  in_features = effnetb2.classifier[1].in_features

  # Seed right before creating the new head so its random init is reproducible.
  torch.manual_seed(seed)

  # 5. Change the output (head) classifier. Newly created layers default to
  # requires_grad=True, so only this head will be trained.
  effnetb2.classifier = nn.Sequential(
    nn.Dropout(p=0.3, inplace=True),
    nn.Linear(in_features=in_features,
              out_features=num_classes,
              bias=True),
  )
  return effnetb2, effnetb2_transforms