File size: 1,181 Bytes
ce714e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43

import torch
import torchvision
from torch import nn

def create_effnet_b2_model(num_classes: int = 3,
                           seed: int = 42,
                           dropout: float = 0.3):
    """Create an EfficientNet-B2 feature extractor and its image transforms.

    Every parameter of the pretrained torchvision backbone
    (``EfficientNet_B2_Weights.DEFAULT``) is frozen. The classifier head is
    then replaced with a freshly initialized one, which is the only
    trainable part of the returned model.

    Args:
        num_classes (int, optional): Number of output classes for the new
            classifier head. Defaults to 3.
        seed (int, optional): Value passed to ``torch.manual_seed`` right
            before the new head is built, so its initial weights are
            reproducible. Note that this reseeds torch's global RNG.
            Defaults to 42.
        dropout (float, optional): Dropout probability used in the new
            head. Defaults to 0.3.

    Returns:
        tuple: ``(model, transforms)``. ``model`` is the EfficientNet-B2
        ``torch.nn.Module`` with the new head. ``transforms`` is the
        preprocessing returned by ``weights.transforms()`` for these
        pretrained weights.
    """
    # 1. Pretrained weights; they also provide the matching preprocessing.
    weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    transforms = weights.transforms()

    # 2. Build the pretrained model (torchvision may download the weights
    #    on first use).
    model = torchvision.models.efficientnet_b2(weights=weights)

    # 3. Freeze all existing parameters so only the new head is trained.
    for param in model.parameters():
        param.requires_grad = False

    # 4. Take the head's input width from the loaded model instead of
    #    hard-coding it, so it stays correct if the backbone changes.
    in_features = model.classifier[-1].in_features

    # 5. Replace the head. Seed immediately before creating the Linear
    #    layer so its random initialization is reproducible.
    torch.manual_seed(seed)
    model.classifier = nn.Sequential(
        nn.Dropout(p=dropout, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes),
    )

    return model, transforms