ckoozzzu commited on
Commit
912ac89
·
verified ·
1 Parent(s): 10c0cd2

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. MLBaseModelDriver.py +1 -1
MLBaseModelDriver.py CHANGED
@@ -61,7 +61,7 @@ class MLBaseModelDriver:
61
  model_class = self._import_model_class(model_class_file)
62
 
63
  model = model_class(input_dim=4)
64
- state_dict = torch.load(model_file, weights_only=False)
65
  model.load_state_dict(state_dict)
66
  model.eval()
67
 
 
61
  model_class = self._import_model_class(model_class_file)
62
 
63
  model = model_class(input_dim=4)
64
+ state_dict = torch.load(model_file, map_location=torch.device("cpu"), weights_only=False)
65
  model.load_state_dict(state_dict)
66
  model.eval()
67