File size: 379 Bytes
74b427f
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
# model.py
import torch.nn as nn
from torchvision.models import efficientnet_b7, EfficientNet_B7_Weights

def load_model(device='cpu', *, num_classes=15):
    """Build an EfficientNet-B7 with default pretrained weights and a new head.

    The backbone is created with torchvision's ``EfficientNet_B7_Weights.DEFAULT``
    weights. The layer at ``classifier[1]`` is then replaced by a freshly
    initialised ``nn.Linear`` that emits ``num_classes`` logits.

    Args:
        device: Target device passed to ``nn.Module.to``, e.g. ``'cpu'``,
            ``'cuda'`` or a ``torch.device``.
        num_classes: Number of outputs of the replacement head. Defaults to
            15, the value previously hard-coded here.

    Returns:
        The model, moved to ``device``.
    """
    weights = EfficientNet_B7_Weights.DEFAULT
    model = efficientnet_b7(weights=weights)
    # Take the head's input width from the layer being replaced instead of
    # hard-coding 2560, so the new Linear always matches the backbone output.
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, num_classes)
    return model.to(device)

def get_preprocess():
    """Return the preprocessing transform bundled with the default B7 weights.

    This uses the same ``EfficientNet_B7_Weights.DEFAULT`` entry that
    ``load_model`` loads. Inputs prepared with it therefore match what the
    pretrained backbone was set up for.
    """
    default_weights = EfficientNet_B7_Weights.DEFAULT
    preprocess = default_weights.transforms()
    return preprocess