| import torch | |
| from efficientnet_pytorch import EfficientNet | |
# Labels the classifier distinguishes; load_model() sizes the network's
# output layer as len(class_names). Presumably label i corresponds to output
# logit i -- confirm against the training code. If classes are added here,
# the saved checkpoint must have been trained with the same number of outputs.
class_names = ["normal", "fracture"]
def load_model(weights_path="efnet_fracture_classifier.pth", map_location="cpu"):
    """Build the EfficientNet-B0 fracture classifier and load its weights.

    The architecture is created with ``EfficientNet.from_name`` using one
    output per entry in the module-level ``class_names`` list. The saved
    ``state_dict`` at ``weights_path`` is then loaded into it.

    Args:
        weights_path: Path (str or os.PathLike) to the saved state dict.
            Defaults to the original hard-coded file name. A relative path is
            resolved against the current working directory.
        map_location: Forwarded to ``torch.load``. The default ``"cpu"`` lets a
            checkpoint saved on a GPU be loaded on a CPU-only machine.

    Returns:
        The model, with ``model.eval()`` already called.

    Raises:
        FileNotFoundError: If ``weights_path`` does not exist.
        RuntimeError: From ``load_state_dict`` when the checkpoint's keys or
            tensor shapes do not match the built architecture (for example,
            ``class_names`` changed without retraining).
    """
    # num_classes is read at call time, so it always tracks class_names.
    model = EfficientNet.from_name("efficientnet-b0", num_classes=len(class_names))
    # NOTE(review): torch.load unpickles the file -- only load checkpoints you
    # trust. On torch >= 1.13, consider passing weights_only=True, which
    # restricts unpickling to tensors and plain containers.
    state_dict = torch.load(weights_path, map_location=map_location)
    model.load_state_dict(state_dict)
    # Switch dropout/batch-norm layers to inference behaviour.
    model.eval()
    return model