File size: 531 Bytes
c751fed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from torch.nn import Dropout, Linear
from torchvision.models import EfficientNet_B2_Weights, efficientnet_b2
import torch
def create_model(num_classes=3, dropout=0.3):
    """Build an EfficientNet-B2 with a frozen backbone and a fresh classifier head.

    Loads torchvision's ``EfficientNet_B2_Weights.DEFAULT`` pretrained weights
    (torchvision may need to download them on first use — not controlled here).
    It then freezes every parameter under ``model.features`` and replaces
    ``model.classifier`` with a new ``Dropout`` + ``Linear`` head. The new head
    is created after freezing, so its parameters stay trainable.

    Args:
        num_classes: Number of output logits produced by the new head.
            Defaults to 3, the value previously hard-coded.
        dropout: Dropout probability applied before the linear layer.
            Defaults to 0.3, the value previously hard-coded.

    Returns:
        A ``(model, transform)`` tuple. ``transform`` is the preprocessing
        pipeline returned by ``weights.transforms()`` for the same weights.
    """
    weights = EfficientNet_B2_Weights.DEFAULT
    model = efficientnet_b2(weights=weights)
    transform = weights.transforms()

    # Freeze the whole backbone in one call; Module.requires_grad_ recurses
    # into every parameter of every submodule.
    model.features.requires_grad_(False)

    # Take the head's input width from the stock classifier's final Linear
    # layer instead of hard-coding 1408. This keeps it in sync with the
    # backbone that was actually loaded.
    in_features = model.classifier[-1].in_features
    model.classifier = torch.nn.Sequential(
        Dropout(p=dropout, inplace=True),
        Linear(in_features=in_features, out_features=num_classes),
    )
    return model, transform