File size: 473 Bytes
87b3ac1
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch
import torchvision
from torch import nn

def create_effnetb2(classes: int = 3, seed: int = 42):
    """Build a pretrained EfficientNet-B2 with a frozen backbone and a new head.

    Loads torchvision's default pretrained EfficientNet-B2 weights, freezes
    every pretrained parameter, and replaces the classifier with a fresh
    dropout + linear head producing ``classes`` outputs. Only the new head
    has ``requires_grad=True``.

    Args:
        classes: Number of output units of the new linear head.
        seed: Passed to ``torch.manual_seed`` immediately before the new head
            is constructed, so its initial weights are reproducible. Note that
            this reseeds torch's global RNG as a side effect.

    Returns:
        A ``(model, transforms)`` tuple: the modified model, and the
        preprocessing transforms bundled with the pretrained weights.
    """
    weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    transforms = weights.transforms()
    model = torchvision.models.efficientnet_b2(weights=weights)

    # Freeze all pretrained parameters; the classifier built below is created
    # afterwards, so its parameters remain trainable.
    for param in model.parameters():
        param.requires_grad = False

    # Take the head's input width from the existing final Linear layer rather
    # than hard-coding it (1408 for B2), so it always matches the backbone.
    in_features = model.classifier[-1].in_features

    # Seed right before creating the new layers so the head init is reproducible.
    torch.manual_seed(seed)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=in_features, out_features=classes),
    )
    return model, transforms