foodvision_mini / model.py
yussaaa's picture
typo fix
8484686
raw
history blame contribute delete
790 Bytes
import gradio as gr
import os
from torch import nn
import torchvision
def create_EffNetB2_model(num_classes: int = 3,
                          seed: int = 42):
    """Build a frozen EfficientNet-B2 with a fresh, trainable classifier head.

    Loads torchvision's EfficientNet-B2 with its default pretrained weights,
    turns off gradients for every pretrained parameter, and replaces the
    classifier with a new ``Dropout`` + ``Linear`` head. Only that new head
    is left trainable.

    Args:
        num_classes: Number of output units of the new classifier head.
        seed: Value passed to ``torch.manual_seed`` immediately before the
            head is created, so the head's initial weights are reproducible.
            Note that this reseeds the global torch RNG.

    Returns:
        A ``(model, transform)`` tuple: the modified EfficientNet-B2 and the
        preprocessing transform that ships with the pretrained weights.
    """
    # Local import: the enclosing module only brings `nn` in from torch.
    import torch

    weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    transform = weights.transforms()
    model = torchvision.models.efficientnet_b2(weights=weights)

    # Freeze the pretrained network. This must run before the head is
    # swapped so the new head keeps the default requires_grad=True.
    for param in model.parameters():
        param.requires_grad = False

    # Take the feature width from the pretrained head's final Linear layer
    # instead of hard-coding it (1408 for EfficientNet-B2).
    in_features = model.classifier[-1].in_features

    # Seed right before the head is built so its random initialisation is
    # determined by `seed`.
    torch.manual_seed(seed)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.2, inplace=True),
        nn.Linear(in_features=in_features,
                  out_features=num_classes,
                  bias=True),
    )
    return model, transform