Spaces:
Runtime error
Runtime error
File size: 1,463 Bytes
8478d83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
from typing import Callable, Optional, Tuple

import torch
import torch.nn as nn
import torchvision
# Create an EffNetB2 feature extractor
def create_effnet_b2(num_of_class: int = 3,
                     transform: Optional[Callable] = None,
                     seed: int = 42
                     ) -> Tuple[nn.Module, Optional[Callable]]:
    """Creates an EfficientNetB2 feature extractor model and transforms.

    The pretrained base layers are frozen; only the freshly created
    classifier head has trainable parameters.

    Args:
        num_of_class (int, optional): number of classes in the classifier head.
            Defaults to 3.
        transform (callable, optional): image transform to pair with the model.
            It is returned unchanged. Defaults to None.
        seed (int, optional): random seed value, set right before the new
            classifier head is created. Defaults to 42.

    Returns:
        model (torch.nn.Module): EffNetB2 feature extractor model.
        transform: the ``transform`` argument, passed through as-is.
    """
    # 1. Get the base model with pretrained weights.
    #    The model is not moved to any device here; callers do that themselves.
    #    `pretrained=True` is deprecated since torchvision 0.13, so prefer the
    #    `weights=` API when this torchvision provides it. IMAGENET1K_V1 is the
    #    checkpoint the legacy flag maps to.
    weights_enum = getattr(torchvision.models, "EfficientNet_B2_Weights", None)
    if weights_enum is not None:
        model = torchvision.models.efficientnet_b2(
            weights=weights_enum.IMAGENET1K_V1
        )
    else:
        model = torchvision.models.efficientnet_b2(pretrained=True)

    # 2. Freeze the base model layers
    for param in model.parameters():
        param.requires_grad = False

    # Read the head's input width from the stock classifier (Dropout, Linear)
    # instead of hard-coding it (1408 for EfficientNet-B2).
    in_features = model.classifier[1].in_features

    # 3. Set the seeds so the new head's initialisation is reproducible
    torch.manual_seed(seed)

    # 4. Change the classifier head (new layers are trainable by default)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features, num_of_class, bias=True),
    )
    return model, transform
# Example usage (create_effnet_b2 returns a (model, transform) tuple):
# model, transform = create_effnet_b2(
#     num_of_class=3,
#     transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]),
#     seed=42,
# )
# print(model)
|