# Author: Youseff1987
# Uploaded model.py with huggingface_hub
# Commit: 74b427f (verified)
# model.py
import torch.nn as nn
from torchvision.models import efficientnet_b7, EfficientNet_B7_Weights
def load_model(device='cpu', *, num_classes=15):
    """Build a pretrained EfficientNet-B7 with a freshly initialised output layer.

    The network is created from ``EfficientNet_B7_Weights.DEFAULT``. Then the
    linear layer at ``model.classifier[1]`` is replaced with a new ``nn.Linear``
    that emits ``num_classes`` logits. torchvision may download the pretrained
    weights the first time this is called if they are not already cached.

    Args:
        device: Target passed to ``nn.Module.to`` (e.g. ``'cpu'``, ``'cuda'``
            or a ``torch.device``).
        num_classes: Number of outputs of the new head. Defaults to 15, the
            value this function previously hard-coded.

    Returns:
        The model moved to ``device``. The module mode is not changed here, so
        call ``.eval()`` before running inference.

    Raises:
        ValueError: If ``num_classes`` is smaller than 1.
    """
    # Validate before building the model so a bad argument does not trigger
    # a large weight download first.
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    weights = EfficientNet_B7_Weights.DEFAULT
    model = efficientnet_b7(weights=weights)

    # Read the feature width from the existing head instead of hard-coding
    # 2560. This keeps the replacement layer consistent with the backbone.
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, num_classes)
    return model.to(device)
def get_preprocess():
    """Return the inference transform bundled with the default B7 weights.

    The transform comes straight from ``EfficientNet_B7_Weights.DEFAULT``, so
    inputs are preprocessed the same way the pretrained backbone expects.
    """
    default_weights = EfficientNet_B7_Weights.DEFAULT
    preprocess = default_weights.transforms()
    return preprocess