File size: 774 Bytes
69beb24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
import torchvision

from torch import nn

def create_effnet_b2_instance(num_classes: int = 3, *, dropout: float = 0.3):
    """Build a pretrained EfficientNet-B2 with a fresh classification head.

    The model is loaded with torchvision's default EfficientNet-B2 weights.
    All of its pretrained parameters are frozen. The classifier is replaced
    by a new Dropout + Linear head, which is the only trainable part.

    Args:
        num_classes: Number of output classes for the new head. Must be >= 1.
        dropout: Dropout probability applied before the new linear layer.
            Defaults to 0.3, the value this helper has always used.

    Returns:
        A ``(transforms, model)`` tuple:
        - ``transforms``: the preprocessing transforms provided by the weights
          object (``weights.transforms()``).
        - ``model``: the modified EfficientNet-B2.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    # Validate before loading the pretrained weights, which may require a
    # download, so bad input fails fast and with a clear message.
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    # Preprocessing that matches how these weights were produced.
    transforms = weights.transforms()
    model = torchvision.models.efficientnet_b2(weights=weights)

    # Freeze every pretrained parameter, including the original head.
    # The replacement head built below uses newly created modules, and their
    # parameters default to requires_grad=True, so only that head is trained.
    for param in model.parameters():
        param.requires_grad = False

    # Read the head's input width from the existing final Linear layer
    # (1408 for EfficientNet-B2) instead of hard-coding it.
    in_features = model.classifier[-1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=dropout, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes),
    )
    return transforms, model