cameron-d commited on
Commit
d98797f
·
verified ·
1 Parent(s): 7869d69

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -69,8 +69,8 @@ def load_checkpoint_for_inference(filepath, model_class):
69
  # (You need the Model Class definition handy)
70
  model = model_class()
71
 
72
- # Load the checkpoint file
73
- checkpoint = torch.load(filepath)
74
 
75
  # Load the state dictionary into the model instance
76
  model.load_state_dict(checkpoint)
@@ -78,8 +78,7 @@ def load_checkpoint_for_inference(filepath, model_class):
78
  # Set the model to evaluation mode for inference
79
  model.eval()
80
 
81
- # Optional: Move the model to the appropriate device (GPU/CPU)
82
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
83
  model.to(device)
84
 
85
  print(f"Checkpoint loaded.")
 
69
  # (You need the Model Class definition handy)
70
  model = model_class()
71
 
72
+ # Load the checkpoint file with map_location to handle device mismatch
73
+ checkpoint = torch.load(filepath, map_location=torch.device('cpu'))
74
 
75
  # Load the state dictionary into the model instance
76
  model.load_state_dict(checkpoint)
 
78
  # Set the model to evaluation mode for inference
79
  model.eval()
80
 
81
+ # Move the model to the device (GPU/CPU)
 
82
  model.to(device)
83
 
84
  print(f"Checkpoint loaded.")