Upload folder using huggingface_hub
Browse files- MLBaseModelDriver.py +4 -1
MLBaseModelDriver.py
CHANGED
|
@@ -64,7 +64,10 @@ class MLBaseModelDriver:
|
|
| 64 |
|
| 65 |
model = model_class(input_dim=4)
|
| 66 |
state_dict = torch.load(model_file, weights_only=False)
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
| 68 |
model.eval()
|
| 69 |
|
| 70 |
with open(scaler_file, 'rb') as f:
|
|
|
|
| 64 |
|
| 65 |
model = model_class(input_dim=4)
|
| 66 |
state_dict = torch.load(model_file, weights_only=False)
|
| 67 |
+
checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
|
| 68 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
| 69 |
+
self.preprocessor = checkpoint.get('preprocessor', None)
|
| 70 |
+
self.input_dim = checkpoint.get('input_dim', None)
|
| 71 |
model.eval()
|
| 72 |
|
| 73 |
with open(scaler_file, 'rb') as f:
|