Upload folder using huggingface_hub
Browse files- MLBaseModelDriver.py +3 -3
MLBaseModelDriver.py
CHANGED
|
@@ -53,8 +53,8 @@ class ProcessedSynapse(TypedDict):
|
|
| 53 |
# --------- Класс MLBaseModelDriver ---------
|
| 54 |
class MLBaseModelDriver:
|
| 55 |
|
| 56 |
-
def __init__(self):
|
| 57 |
-
self.model_path = model_path #
|
| 58 |
self.model, self.label_encoders, self.scaler = self.load_model()
|
| 59 |
self.preprocessor = DataPreprocessor(self.label_encoders, self.scaler)
|
| 60 |
|
|
@@ -64,7 +64,7 @@ class MLBaseModelDriver:
|
|
| 64 |
model_class = self._import_model_class(model_class_file)
|
| 65 |
|
| 66 |
model = model_class(input_dim=4)
|
| 67 |
-
|
| 68 |
checkpoint = torch.load(self.model_path, map_location=torch.device("cpu"))
|
| 69 |
model.load_state_dict(checkpoint['model_state_dict'])
|
| 70 |
self.preprocessor = checkpoint.get('preprocessor', None)
|
|
|
|
| 53 |
# --------- Класс MLBaseModelDriver ---------
|
| 54 |
class MLBaseModelDriver:
|
| 55 |
|
| 56 |
+
def __init__(self, model_path: str):
|
| 57 |
+
self.model_path = model_path # сохраняем путь модели
|
| 58 |
self.model, self.label_encoders, self.scaler = self.load_model()
|
| 59 |
self.preprocessor = DataPreprocessor(self.label_encoders, self.scaler)
|
| 60 |
|
|
|
|
| 64 |
model_class = self._import_model_class(model_class_file)
|
| 65 |
|
| 66 |
model = model_class(input_dim=4)
|
| 67 |
+
# Здесь загружаем сохранённые веса модели
|
| 68 |
checkpoint = torch.load(self.model_path, map_location=torch.device("cpu"))
|
| 69 |
model.load_state_dict(checkpoint['model_state_dict'])
|
| 70 |
self.preprocessor = checkpoint.get('preprocessor', None)
|