zarqankhn commited on
Commit
f6d80b1
·
verified ·
1 Parent(s): 75253d8

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +10 -0
model.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from efficientnet_pytorch import EfficientNet
3
+
4
+ class_names = ['normal', 'fracture'] # update if more classes
5
+
6
+ def load_model():
7
+ model = EfficientNet.from_name("efficientnet-b0", num_classes=len(class_names))
8
+ model.load_state_dict(torch.load("efnet_fracture_classifier.pth", map_location="cpu"))
9
+ model.eval()
10
+ return model