# Food-Vision-101 / model.py
# Author: Jamshid15 — commit b912248 ("Add changes")
import torch
import torchvision
from torch import nn
def create_effnetb4_model(num_classes: int = 101):
    """Create an EfficientNet-B4 classifier with a new output head.

    Loads ``torchvision.models.efficientnet_b4`` with the default pretrained
    weights (``EfficientNet_B4_Weights.DEFAULT``). It then replaces the layer
    at index 1 of ``model.classifier`` with a freshly initialised
    ``nn.Linear`` that produces ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes for the new head. Defaults to
            101 (presumably the Food-101 label set — confirm against the
            training data).

    Returns:
        A ``(model, transforms)`` tuple. ``model`` is the modified
        EfficientNet-B4. ``transforms`` is the preprocessing object returned
        by ``weights.transforms()`` for the same weights.

    Raises:
        ValueError: If ``num_classes`` is less than 1.

    Note:
        Requesting pretrained weights may trigger a download on the first
        call if torchvision has not cached them yet.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    weights = torchvision.models.EfficientNet_B4_Weights.DEFAULT
    transforms = weights.transforms()
    model = torchvision.models.efficientnet_b4(weights=weights)

    # Read the input width from the layer being replaced instead of
    # hard-coding 1792, so the new head always matches the backbone's
    # feature size.
    old_head = model.classifier[1]
    model.classifier[1] = nn.Linear(
        in_features=old_head.in_features,
        out_features=num_classes,
    )
    return model, transforms