File size: 379 Bytes
74b427f |
1 2 3 4 5 6 7 8 9 10 11 12 |
# model.py
import torch.nn as nn
from torchvision.models import efficientnet_b7, EfficientNet_B7_Weights
def load_model(device="cpu", *, num_classes=15) -> nn.Module:
    """Build an EfficientNet-B7 with a fresh classification head.

    The backbone is initialised from ``EfficientNet_B7_Weights.DEFAULT``
    (torchvision fetches the checkpoint on first use). The final linear
    layer of ``model.classifier`` is then replaced by a new, randomly
    initialised ``nn.Linear`` that emits ``num_classes`` logits.

    Args:
        device: Target passed to ``Module.to`` (e.g. ``"cpu"``, ``"cuda"``
            or a ``torch.device``).
        num_classes: Number of output units of the new head. Defaults to 15,
            the value this function previously hard-coded.

    Returns:
        The model, moved to ``device``.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")
    weights = EfficientNet_B7_Weights.DEFAULT
    model = efficientnet_b7(weights=weights)
    # Take the head's input width from the layer being replaced rather than
    # hard-coding 2560, so it always matches the backbone's feature size.
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, num_classes)
    # NOTE(review): the new head is untrained; callers presumably load a
    # fine-tuned state_dict after this — confirm.
    return model.to(device)
def get_preprocess():
    """Return the preprocessing transform bundled with the default B7 weights.

    Uses ``EfficientNet_B7_Weights.DEFAULT``, the same weights enum member
    that ``load_model`` initialises the backbone from. Inputs are therefore
    prepared the way those pretrained weights expect.
    """
    default_weights = EfficientNet_B7_Weights.DEFAULT
    transform = default_weights.transforms()
    return transform