penguin218 commited on
Commit
5088d0f
·
verified ·
1 Parent(s): 1a35878

update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -12,7 +12,8 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
12
  # 加载模型
13
  weights_path = "best_model.pth"
14
  model, _, _ = get_mobilenet_model(num_classes=16)
15
- model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=False))
 
16
  model.to(device)
17
  model.eval()
18
 
 
12
  # 加载模型
13
  weights_path = "best_model.pth"
14
  model, _, _ = get_mobilenet_model(num_classes=16)
15
+ checkpoint = torch.load(weights_path, map_location=device, weights_only=False)
16
+ model.load_state_dict(checkpoint['state_dict']) # 注意这里是 checkpoint['state_dict']
17
  model.to(device)
18
  model.eval()
19