File size: 1,167 Bytes
dbe7aa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import torchvision

from torch import nn
from helper_functions import set_seeds

def create_effnetb2_model(output_classes: int = 3,
                          seed: int = 42):
    """Create a pretrained EfficientNet-B2 feature extractor.

    Loads torchvision's default pretrained EfficientNet-B2 weights and freezes
    every parameter of the pretrained network. It then replaces the classifier
    head with a freshly initialised ``Dropout -> Linear`` head sized for
    ``output_classes``. Only the new head's parameters are trainable.

    Args:
        output_classes: Number of output units (target classes) for the new
            classification head.
        seed: Seed passed to ``set_seeds`` immediately before the new head
            is created, so the head's initial weights are reproducible.

    Returns:
        A ``(model, transforms)`` tuple:
          model: the EfficientNet-B2 model with frozen base layers and the
            new classifier head.
          transforms: the preprocessing transforms associated with the
            pretrained weights (``EfficientNet_B2_Weights.DEFAULT.transforms()``).
    """
    # 1. Set up the pretrained EffNetB2 weights.
    effnetb2_weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT

    # 2. Get the preprocessing transforms that match those weights.
    transforms = effnetb2_weights.transforms()

    # 3. Set up the pretrained model instance.
    model = torchvision.models.efficientnet_b2(weights=effnetb2_weights)

    # 4. Freeze the base layers - this stops all pretrained parameters from
    # receiving gradient updates.
    for param in model.parameters():
        param.requires_grad = False

    # 5. Replace the classification head.
    # Read the head's input width from the existing Linear layer instead of
    # hard-coding it (1408 for EfficientNet-B2).
    in_features = model.classifier[-1].in_features
    # Seed right before creating the head so its initial weights are
    # reproducible, using the caller-supplied seed rather than a fixed 42.
    set_seeds(seed)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=in_features,
                  out_features=output_classes,
                  bias=True),
    )
    return model, transforms