# File size: 350 Bytes
f6d80b1 38a7b9d | 1 2 3 4 5 6 7 8 9 10 | import torch
from efficientnet_pytorch import EfficientNet
# Labels known to the classifier. load_model() sizes the network's output
# layer from len(class_names), so the saved checkpoint must have been trained
# with the same number of classes. Presumably index i names output unit i —
# confirm against the prediction code. Extend this list if more classes are added.
class_names = [
    'normal',
    'fracture',
]
def load_model(weights_path="efnet_fracture_classifier.pth", map_location="cpu"):
    """Build the EfficientNet-B0 fracture classifier and load its trained weights.

    The architecture is created with ``EfficientNet.from_name``, with one
    output per entry in the module-level ``class_names`` list. The model is
    then populated from a saved ``state_dict`` and switched to eval mode.

    Args:
        weights_path: Path to the saved state dict. The default is a relative
            path, so it is resolved against the current working directory
            (unchanged from the original behaviour).
        map_location: Forwarded to ``torch.load``. The default ``"cpu"`` lets
            a checkpoint that was saved from a GPU load on a CPU-only machine.

    Returns:
        The EfficientNet model with weights loaded, in evaluation mode.

    Raises:
        FileNotFoundError: If ``weights_path`` does not exist.
        RuntimeError: If the checkpoint's keys or tensor shapes do not match
            the model, e.g. ``class_names`` changed without retraining
            (``load_state_dict`` is strict by default).
    """
    model = EfficientNet.from_name("efficientnet-b0", num_classes=len(class_names))

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from a trusted source. On torch>=1.13,
    # consider passing weights_only=True, since only a state dict is expected.
    state_dict = torch.load(weights_path, map_location=map_location)
    model.load_state_dict(state_dict)

    # Disable dropout and use running BatchNorm statistics for inference.
    model.eval()
    return model