Spaces: Running on Zero

Commit: Update app.py
Browse files

app.py — CHANGED
@@ -69,8 +69,8 @@ def load_checkpoint_for_inference(filepath, model_class):
     # (You need the Model Class definition handy)
     model = model_class()
 
-    # Load the checkpoint file
-    checkpoint = torch.load(filepath)
+    # Load the checkpoint file with map_location to handle device mismatch
+    checkpoint = torch.load(filepath, map_location=torch.device('cpu'))
 
     # Load the state dictionary into the model instance
     model.load_state_dict(checkpoint)
@@ -78,8 +78,7 @@ def load_checkpoint_for_inference(filepath, model_class):
     # Set the model to evaluation mode for inference
     model.eval()
 
-    #
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # Move the model to the device (GPU/CPU)
     model.to(device)
 
     print(f"Checkpoint loaded.")