Update TractionModel.py
Browse files- app/TractionModel.py +2 -2
app/TractionModel.py
CHANGED
|
@@ -53,7 +53,7 @@ def create_model():
|
|
| 53 |
return model
|
| 54 |
|
| 55 |
|
| 56 |
-
def load_weights(model, path='model.pt'):
|
| 57 |
-
checkpoint = torch.load(path, map_location=torch.device(
|
| 58 |
model.load_state_dict(checkpoint)
|
| 59 |
return model
|
|
|
|
| 53 |
return model
|
| 54 |
|
| 55 |
|
| 56 |
+
def load_weights(model, path='model.pt', device_='cpu'):
|
| 57 |
+
checkpoint = torch.load(path, map_location=torch.device(device_))
|
| 58 |
model.load_state_dict(checkpoint)
|
| 59 |
return model
|