1NEYRON1 commited on
Commit
f496fb7
·
verified ·
1 Parent(s): 8d11de1

Update modeling_whisper.py

Browse files
Files changed (1) hide show
  1. modeling_whisper.py +89 -18
modeling_whisper.py CHANGED
@@ -831,6 +831,61 @@ class WhisperSSLEnsemble(PreTrainedModel): # type: ignore
831
  logger.warning(" Text processing will be disabled due to load failure.")
832
 
833
  # В классе WhisperSSLEnsemble
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
834
  def _load_ssl_ensemble_if_needed(self):
835
  # Если модель уже загружена или не нужна, выходим
836
  if self.ssl_ensemble_model is not None or not self.predict_mode:
@@ -838,31 +893,46 @@ class WhisperSSLEnsemble(PreTrainedModel): # type: ignore
838
 
839
  logger.info("Lazily loading SSL Ensemble model...")
840
  ssl_ensemble_config = self.config.ssl_ensemble_config
841
- try:
842
- # ШАГ 1: НАЙТИ ДИРЕКТОРИЮ, ГДЕ ЛЕЖАТ ВСЕ ФАЙЛЫ МОДЕЛИ
843
- # Это самый важный и правильный шаг.
844
- # __file__ - это путь к текущему файлу (modeling_whisper.py)
845
- # os.path.dirname() получает директорию из этого пути.
846
- model_dir = os.path.dirname(__file__)
 
847
 
848
- # ШАГ 2: ПОСТРОИТЬ АБСОЛЮТНЫЕ ПУТИ К ФАЙЛАМ ВЕСОВ
849
- weak_learners_filename = ssl_ensemble_config["weak_learners_path"]
850
- weak_learners_path = os.path.join(model_dir, weak_learners_filename)
 
 
 
 
 
 
 
 
 
 
 
851
 
852
- meta_learner_filename = ssl_ensemble_config["meta_learner_path"]
853
- meta_learner_path = os.path.join(model_dir, meta_learner_filename)
854
-
855
- logger.info(f"Attempting to load weak learners from: {weak_learners_path}")
856
- logger.info(f"Attempting to load meta learner from: {meta_learner_path}")
 
 
 
857
 
858
- # ШАГ 3: ЗАГРУЗИТЬ МОДЕЛИ ПО АБСОЛЮТНЫМ ПУТЯМ
859
  weak_learners = WeakLearners(
860
  audio_dim=ssl_ensemble_config["audio_dim"],
861
  text_dim=ssl_ensemble_config["text_dim"],
862
  device=self._target_device.type
863
  )
864
- if not weak_learners.load_fitted(weak_learners_path):
865
- raise RuntimeError(f"Failed to load weak learners from {weak_learners_path}")
866
 
867
  meta_learner = StackingMetaLearner(
868
  weak_output_dim=len(weak_learners.models),
@@ -870,7 +940,7 @@ class WhisperSSLEnsemble(PreTrainedModel): # type: ignore
870
  )
871
  meta_learner.load_state_dict_from_file(meta_learner_path, device=self._target_device)
872
 
873
- # СОЗДАНИЕ ИТОГОВОЙ МОДЕЛИ АНСАМБЛЯ
874
  self.ssl_ensemble_model = SSLEnsembleModel(
875
  weak_learners=weak_learners,
876
  stacking_meta_learner=meta_learner
@@ -880,6 +950,7 @@ class WhisperSSLEnsemble(PreTrainedModel): # type: ignore
880
  logger.info(f"SSL Ensemble loaded successfully onto device {self._target_device}.")
881
 
882
  except Exception as e:
 
883
  logger.error(f"Failed to lazily load SSL Ensemble model: {e}", exc_info=True)
884
  self.predict_mode = False
885
  logger.warning(" Prediction with SSL Ensemble will be disabled.")
 
831
  logger.warning(" Text processing will be disabled due to load failure.")
832
 
833
  # В классе WhisperSSLEnsemble
834
+ # def _load_ssl_ensemble_if_needed(self):
835
+ # # Если модель уже загружена или не нужна, выходим
836
+ # if self.ssl_ensemble_model is not None or not self.predict_mode:
837
+ # return
838
+
839
+ # logger.info("Lazily loading SSL Ensemble model...")
840
+ # ssl_ensemble_config = self.config.ssl_ensemble_config
841
+ # try:
842
+ # # ШАГ 1: НАЙТИ ДИРЕКТОРИЮ, ГДЕ ЛЕЖАТ ВСЕ ФАЙЛЫ МОДЕЛИ
843
+ # # Это самый важный и правильный шаг.
844
+ # # __file__ - это путь к текущему файлу (modeling_whisper.py)
845
+ # # os.path.dirname() получает директорию из этого пути.
846
+ # model_dir = os.path.dirname(__file__)
847
+
848
+ # # ШАГ 2: ПОСТРОИТЬ АБСОЛЮТНЫЕ ПУТИ К ФАЙЛАМ ВЕСОВ
849
+ # weak_learners_filename = ssl_ensemble_config["weak_learners_path"]
850
+ # weak_learners_path = os.path.join(model_dir, weak_learners_filename)
851
+
852
+ # meta_learner_filename = ssl_ensemble_config["meta_learner_path"]
853
+ # meta_learner_path = os.path.join(model_dir, meta_learner_filename)
854
+
855
+ # logger.info(f"Attempting to load weak learners from: {weak_learners_path}")
856
+ # logger.info(f"Attempting to load meta learner from: {meta_learner_path}")
857
+
858
+ # # ШАГ 3: ЗАГРУЗИТЬ МОДЕЛИ ПО АБСОЛЮТНЫМ ПУТЯМ
859
+ # weak_learners = WeakLearners(
860
+ # audio_dim=ssl_ensemble_config["audio_dim"],
861
+ # text_dim=ssl_ensemble_config["text_dim"],
862
+ # device=self._target_device.type
863
+ # )
864
+ # if not weak_learners.load_fitted(weak_learners_path):
865
+ # raise RuntimeError(f"Failed to load weak learners from {weak_learners_path}")
866
+
867
+ # meta_learner = StackingMetaLearner(
868
+ # weak_output_dim=len(weak_learners.models),
869
+ # hidden_dim=ssl_ensemble_config["hidden_dim"]
870
+ # )
871
+ # meta_learner.load_state_dict_from_file(meta_learner_path, device=self._target_device)
872
+
873
+ # # СОЗДАНИЕ ИТОГОВОЙ МОДЕЛИ АНСАМБЛЯ
874
+ # self.ssl_ensemble_model = SSLEnsembleModel(
875
+ # weak_learners=weak_learners,
876
+ # stacking_meta_learner=meta_learner
877
+ # )
878
+ # self.ssl_ensemble_model.eval()
879
+
880
+ # logger.info(f"SSL Ensemble loaded successfully onto device {self._target_device}.")
881
+
882
+ # except Exception as e:
883
+ # logger.error(f"Failed to lazily load SSL Ensemble model: {e}", exc_info=True)
884
+ # self.predict_mode = False
885
+ # logger.warning(" Prediction with SSL Ensemble will be disabled.")
886
+
887
+
888
+ # В классе WhisperSSLEnsemble
889
  def _load_ssl_ensemble_if_needed(self):
890
  # Если модель уже загружена или не нужна, выходим
891
  if self.ssl_ensemble_model is not None or not self.predict_mode:
 
893
 
894
  logger.info("Lazily loading SSL Ensemble model...")
895
  ssl_ensemble_config = self.config.ssl_ensemble_config
896
+
897
+ # Получаем директорию, где лежат все файлы модели
898
+ model_dir = os.path.dirname(__file__)
899
+
900
+ # Строим абсолютные пути к файлам, которые мы ищем
901
+ weak_learners_filename = ssl_ensemble_config["weak_learners_path"]
902
+ weak_learners_path = os.path.join(model_dir, weak_learners_filename)
903
 
904
+ meta_learner_filename = ssl_ensemble_config["meta_learner_path"]
905
+ meta_learner_path = os.path.join(model_dir, meta_learner_filename)
906
+
907
+ try:
908
+ # Улучшенная проверка на наличие файлов с отладкой
909
+ if not os.path.exists(weak_learners_path):
910
+ # Получаем список всех файлов в директории для отладки
911
+ files_in_dir = os.listdir(model_dir)
912
+ error_message = (
913
+ f"Weak learners file not found at the expected path: {weak_learners_path}\n"
914
+ f"Expected filename from config: '{weak_learners_filename}'\n"
915
+ f"Files found in the model directory ('{model_dir}'):\n{files_in_dir}"
916
+ )
917
+ raise FileNotFoundError(error_message)
918
 
919
+ if not os.path.exists(meta_learner_path):
920
+ files_in_dir = os.listdir(model_dir)
921
+ error_message = (
922
+ f"Meta learner file not found at the expected path: {meta_learner_path}\n"
923
+ f"Expected filename from config: '{meta_learner_filename}'\n"
924
+ f"Files found in the model directory ('{model_dir}'):\n{files_in_dir}"
925
+ )
926
+ raise FileNotFoundError(error_message)
927
 
928
+ # Если все проверки пройдены, загружаем модели
929
  weak_learners = WeakLearners(
930
  audio_dim=ssl_ensemble_config["audio_dim"],
931
  text_dim=ssl_ensemble_config["text_dim"],
932
  device=self._target_device.type
933
  )
934
+ # load_fitted уже внутри себя печатает ошибку, но мы можем быть уверены, что файл есть
935
+ weak_learners.load_fitted(weak_learners_path)
936
 
937
  meta_learner = StackingMetaLearner(
938
  weak_output_dim=len(weak_learners.models),
 
940
  )
941
  meta_learner.load_state_dict_from_file(meta_learner_path, device=self._target_device)
942
 
943
+ # Создание итоговой модели ансамбля
944
  self.ssl_ensemble_model = SSLEnsembleModel(
945
  weak_learners=weak_learners,
946
  stacking_meta_learner=meta_learner
 
950
  logger.info(f"SSL Ensemble loaded successfully onto device {self._target_device}.")
951
 
952
  except Exception as e:
953
+ # Этот блок теперь будет ловить наши детальные ошибки FileNotFoundError
954
  logger.error(f"Failed to lazily load SSL Ensemble model: {e}", exc_info=True)
955
  self.predict_mode = False
956
  logger.warning(" Prediction with SSL Ensemble will be disabled.")