urameez commited on
Commit
00d2fe8
·
verified ·
1 Parent(s): 4ffc6a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -6
app.py CHANGED
@@ -13,12 +13,18 @@ class_names=["pizza","steak","sushi"]
13
  effnetb2,effnetb2_transforms=create_effnetb2_model(num_classes=len(class_names))
14
 
15
  # Load save weights
16
- effnetb2.load_state_dict(
17
- torch.load(
18
- f="effnetb2_feature_extractor_food101_mini.pth",
19
- map_location=torch.device("cpu") # load the model to the CPU
20
- )
21
- )
 
 
 
 
 
 
22
 
23
  ### 3. Predict function ###
24
  def predict(img)->Tuple[Dict,float]:
 
13
  effnetb2,effnetb2_transforms=create_effnetb2_model(num_classes=len(class_names))
14
 
15
  # Load save weights
16
+ # effnetb2.load_state_dict(
17
+ # torch.load(
18
+ # f="effnetb2_feature_extractor_food101_mini.pth",
19
+ # map_location=torch.device("cpu") # load the model to the CPU
20
+ # )
21
+ # )
22
+
23
+ # Load the state_dict (no DataParallel prefix)
24
+ effnetb2.load_state_dict(torch.load("effnetb2_feature_extractor_food101_mini.pth", map_location=torch.device("cpu")))
25
+
26
+ # Wrap with DataParallel if needed
27
+ # effnetb2 = DataParallel(effnetb2).to("device")
28
 
29
  ### 3. Predict function ###
30
  def predict(img)->Tuple[Dict,float]: