"""Load the EfficientNet-B0 fracture classifier from a saved state dict."""

import torch
from efficientnet_pytorch import EfficientNet

# Output labels; len(class_names) sets the size of the classifier head.
# NOTE(review): order presumably must match the label indices used during
# training — confirm. Update (and use a matching checkpoint) if more classes
# are added, otherwise load_state_dict will fail on the head's shape.
class_names = ['normal', 'fracture']

# Checkpoint file loaded by default (relative to the current working directory).
DEFAULT_WEIGHTS_PATH = "efnet_fracture_classifier.pth"


def load_model(
    weights_path=DEFAULT_WEIGHTS_PATH,
    map_location="cpu",
) -> EfficientNet:
    """Build the EfficientNet-B0 classifier and load trained weights into it.

    The architecture is created with ``EfficientNet.from_name`` (no pretrained
    weights) with one output per entry in ``class_names``; the weights then
    come entirely from the checkpoint at ``weights_path``.

    Args:
        weights_path: Path to a saved ``state_dict`` for this architecture.
            Defaults to ``DEFAULT_WEIGHTS_PATH``.
        map_location: Forwarded to ``torch.load``; controls the device the
            stored tensors are mapped onto. Defaults to ``"cpu"``.

    Returns:
        The model with weights loaded, switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``weights_path`` does not exist.
        RuntimeError: If the checkpoint's keys/shapes do not match the model
            (e.g. it was trained with a different number of classes).
    """
    model = EfficientNet.from_name("efficientnet-b0", num_classes=len(class_names))
    # NOTE(review): torch.load unpickles the file — only load trusted
    # checkpoints. On torch >= 1.13 consider passing weights_only=True.
    state_dict = torch.load(weights_path, map_location=map_location)
    model.load_state_dict(state_dict)
    # Inference mode for layers like dropout/batch-norm.
    model.eval()
    return model