Shilpaj commited on
Commit
db0bc4b
·
verified ·
1 Parent(s): 7e6224d

Fix: Use GPU

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -21,7 +21,7 @@ def load_model(model_path: str):
21
  model = models.resnet50(weights=None)
22
 
23
  # Load custom weights from a .pth file with CPU mapping
24
- state_dict = torch.load(model_path, map_location=torch.device('cpu'))
25
 
26
  # Filter out unexpected keys
27
  filtered_state_dict = {k: v for k, v in state_dict['model_state_dict'].items() if k in model.state_dict()}
 
21
  model = models.resnet50(weights=None)
22
 
23
  # Load custom weights from a .pth file with CPU mapping
24
+ state_dict = torch.load(model_path)
25
 
26
  # Filter out unexpected keys
27
  filtered_state_dict = {k: v for k, v in state_dict['model_state_dict'].items() if k in model.state_dict()}