FoodVisionBig / model.py
Edesak's picture
init commit
423bf0c
raw
history blame contribute delete
537 Bytes
import torch
from torch.nn import Dropout, Linear
from torchvision.models import EfficientNet_B2_Weights, efficientnet_b2
def create_model(num_classes: int = 101):
    """Build an EfficientNet-B2 feature extractor with a fresh classifier head.

    The backbone is loaded with ``EfficientNet_B2_Weights.DEFAULT`` weights,
    every parameter under ``model.features`` is frozen, and the stock
    classifier is replaced by ``Dropout(p=0.3)`` followed by a ``Linear``
    layer with ``num_classes`` outputs.

    Args:
        num_classes: Number of output units of the new classification
            layer. Defaults to 101, the value previously hard-coded here.

    Returns:
        A ``(model, transform)`` tuple: the modified network and the
        preprocessing transform obtained from ``weights.transforms()``.
    """
    weights = EfficientNet_B2_Weights.DEFAULT
    model = efficientnet_b2(weights=weights)
    transform = weights.transforms()

    # Freeze the whole convolutional backbone in one call;
    # Module.requires_grad_ applies to all parameters of all submodules.
    model.features.requires_grad_(False)

    # Read the head's input width from the layer being replaced instead of
    # hard-coding it (1408 for EfficientNet-B2), so backbone and head cannot
    # drift apart.
    in_features = model.classifier[-1].in_features

    # The new head is created after freezing, so its parameters keep the
    # default requires_grad=True and remain trainable.
    model.classifier = torch.nn.Sequential(
        Dropout(p=0.3, inplace=True),
        Linear(in_features=in_features, out_features=num_classes),
    )
    return model, transform