File size: 350 Bytes
f6d80b1
 
 
 
 
 
 
 
 
38a7b9d
1
2
3
4
5
6
7
8
9
10
import torch
from efficientnet_pytorch import EfficientNet

# Label names for the classifier. The length of this list is used as the
# number of output classes when the model is built in load_model(), so add
# entries here if the model is trained on more classes.
# NOTE(review): presumably the order must match the label indices used at
# training time -- confirm against the training code.
class_names = [
    "normal",
    "fracture",
]

def load_model(weights_path="efnet_fracture_classifier.pth", map_location="cpu"):
    """Build the EfficientNet-B0 classifier and load its saved weights.

    Args:
        weights_path: Path to the saved state dict. Defaults to
            "efnet_fracture_classifier.pth"; a relative path is resolved
            against the current working directory.
        map_location: Forwarded to ``torch.load`` to control where the
            stored tensors are placed. Defaults to "cpu".

    Returns:
        The EfficientNet model with the weights loaded, switched to
        evaluation mode.

    Raises:
        FileNotFoundError: If ``weights_path`` does not exist.
        RuntimeError: If the state dict does not match the model built
            here (raised by ``load_state_dict``).
    """
    # The output layer is sized from class_names so the architecture
    # matches the number of labels declared at module level.
    model = EfficientNet.from_name("efficientnet-b0", num_classes=len(class_names))

    # SECURITY NOTE: torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source; consider
    # passing weights_only=True on PyTorch versions that support it.
    state_dict = torch.load(weights_path, map_location=map_location)
    model.load_state_dict(state_dict)

    # Evaluation mode: the model is returned ready for inference.
    model.eval()
    return model