ckoozzzu commited on
Commit
a59766f
·
verified ·
1 Parent(s): ef2d7f4

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. MLBaseModelDriver.py +3 -3
MLBaseModelDriver.py CHANGED
@@ -53,8 +53,8 @@ class ProcessedSynapse(TypedDict):
53
  # --------- Класс MLBaseModelDriver ---------
54
  class MLBaseModelDriver:
55
 
56
- def __init__(self):
57
- self.model_path = model_path # сначала сохраняем путь
58
  self.model, self.label_encoders, self.scaler = self.load_model()
59
  self.preprocessor = DataPreprocessor(self.label_encoders, self.scaler)
60
 
@@ -64,7 +64,7 @@ class MLBaseModelDriver:
64
  model_class = self._import_model_class(model_class_file)
65
 
66
  model = model_class(input_dim=4)
67
- state_dict = torch.load(model_file, weights_only=False)
68
  checkpoint = torch.load(self.model_path, map_location=torch.device("cpu"))
69
  model.load_state_dict(checkpoint['model_state_dict'])
70
  self.preprocessor = checkpoint.get('preprocessor', None)
 
53
  # --------- Класс MLBaseModelDriver ---------
54
  class MLBaseModelDriver:
55
 
56
+ def __init__(self, model_path: str):
57
+ self.model_path = model_path # сохраняем путь модели
58
  self.model, self.label_encoders, self.scaler = self.load_model()
59
  self.preprocessor = DataPreprocessor(self.label_encoders, self.scaler)
60
 
 
64
  model_class = self._import_model_class(model_class_file)
65
 
66
  model = model_class(input_dim=4)
67
+ # Здесь загружаем сохранённые веса модели
68
  checkpoint = torch.load(self.model_path, map_location=torch.device("cpu"))
69
  model.load_state_dict(checkpoint['model_state_dict'])
70
  self.preprocessor = checkpoint.get('preprocessor', None)