Spaces:
Build error
Build error
File size: 1,167 Bytes
dbe7aa3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 | import torch
import torchvision
from torch import nn
from helper_functions import set_seeds
def create_effnetb2_model(output_classes: int = 3,
                          seed: int = 42):
    """Create a frozen, pretrained EfficientNet-B2 feature extractor.

    Loads torchvision's default pretrained EfficientNet-B2 weights, freezes
    every pretrained parameter, and replaces the classifier head with a new
    trainable ``Dropout -> Linear`` head sized for ``output_classes``.

    Args:
        output_classes: Number of output units of the new classifier head.
        seed: Seed passed to ``set_seeds`` immediately before the new head is
            built, so the head's random initial weights are reproducible.

    Returns:
        A ``(model, transforms)`` tuple:
            model: The EfficientNet-B2 model with frozen base layers and the
                new (trainable) classifier head.
            transforms: The preprocessing transforms bundled with the
                pretrained weights (``effnetb2_weights.transforms()``); use
                them on inputs so they match what the backbone was trained on.
    """
    # 1. Set up the pretrained EffNetB2 weights.
    effnetb2_weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT

    # 2. Get the transforms that belong to those weights.
    transforms = effnetb2_weights.transforms()

    # 3. Set up the pretrained model instance.
    model = torchvision.models.efficientnet_b2(weights=effnetb2_weights)

    # 4. Freeze all pretrained layers so they are not updated during training.
    for param in model.parameters():
        param.requires_grad = False

    # 5. Replace the classification head.
    # Read the feature width from the existing head (Dropout, Linear) instead
    # of hard-coding 1408, so it always matches the backbone's output.
    in_features = model.classifier[1].in_features

    # Seed right before constructing the head: nn.Linear draws its initial
    # weights from the RNG at construction time. Use the caller's seed rather
    # than a hard-coded value (set_seeds comes from helper_functions;
    # presumably it seeds torch's RNGs).
    set_seeds(seed)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=in_features,
                  out_features=output_classes,
                  bias=True),
    )
    # Parameters of the freshly created head default to requires_grad=True,
    # so only the head is trainable.
    return model, transforms
|