zarqankhn's picture
Update model.py
38a7b9d verified
raw
history blame contribute delete
350 Bytes
import torch
from efficientnet_pytorch import EfficientNet
# Label names for the classifier; load_model() sizes the network's output
# layer from the length of this list, so extend it if more classes are added.
# NOTE(review): presumably ordered by the model's output index -- confirm
# against the training code.
class_names = [
    "normal",
    "fracture",
]
def load_model(weights_path="efnet_fracture_classifier.pth"):
    """Build the EfficientNet-B0 classifier and load its trained weights.

    The network is created with one output per entry in ``class_names``,
    the checkpoint's state dict is loaded onto the CPU, and the model is
    switched to evaluation mode before being returned.

    Args:
        weights_path: Path to the saved ``state_dict`` checkpoint. Defaults
            to ``"efnet_fracture_classifier.pth"``; a relative path is
            resolved against the current working directory.

    Returns:
        The ``EfficientNet`` instance in eval mode, with its parameters on
        the CPU.

    Raises:
        FileNotFoundError: If ``weights_path`` does not exist.
        RuntimeError: If the checkpoint's keys or tensor shapes do not match
            the architecture (e.g. ``class_names`` has a different length
            than the head the checkpoint was trained with).
    """
    # from_name() only constructs the architecture; the weights come from
    # the local checkpoint loaded below.
    model = EfficientNet.from_name(
        "efficientnet-b0", num_classes=len(class_names)
    )

    # SECURITY NOTE: torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source; on
    # torch >= 1.13 consider passing weights_only=True.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)

    # Inference mode: disables dropout and freezes batch-norm statistics.
    model.eval()
    return model