saneshashank commited on
Commit
6da2fc9
·
verified ·
1 Parent(s): 006971c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -13,7 +13,9 @@ with open("imagenet_classes.json", "r") as f:
13
  # Load model
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
  model = resnet50(num_classes=1000, drop_path_rate=0.0, use_blurpool=True)
16
- model.load_state_dict(torch.load("best_resnet50_imagenet_1k.pt", map_location=device))
 
 
17
  model.to(device)
18
  model.eval()
19
 
 
13
  # Load model
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
  model = resnet50(num_classes=1000, drop_path_rate=0.0, use_blurpool=True)
16
+ # model.load_state_dict(torch.load("best_resnet50_imagenet_1k.pt", map_location=device))
17
+ checkpoint = torch.load('best_resnet50_imagenet_1k.pt')
18
+ model.load_state_dict(checkpoint['model_state_dict'])
19
  model.to(device)
20
  model.eval()
21