ckoozzzu commited on
Commit
caf00b5
·
verified ·
1 Parent(s): cb8c13e

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. MLBaseModelDriver.py +4 -1
MLBaseModelDriver.py CHANGED
@@ -64,7 +64,10 @@ class MLBaseModelDriver:
64
 
65
  model = model_class(input_dim=4)
66
  state_dict = torch.load(model_file, weights_only=False)
67
- model.load_state_dict(state_dict)
 
 
 
68
  model.eval()
69
 
70
  with open(scaler_file, 'rb') as f:
 
64
 
65
  model = model_class(input_dim=4)
66
  state_dict = torch.load(model_file, weights_only=False)
67
+ checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
68
+ model.load_state_dict(checkpoint['model_state_dict'])
69
+ self.preprocessor = checkpoint.get('preprocessor', None)
70
+ self.input_dim = checkpoint.get('input_dim', None)
71
  model.eval()
72
 
73
  with open(scaler_file, 'rb') as f: