1NEYRON1 committed on
Commit
70cb727
·
verified ·
1 Parent(s): 9d626a4

Update modeling_whisper.py

Browse files
Files changed (1) hide show
  1. modeling_whisper.py +31 -34
modeling_whisper.py CHANGED
@@ -887,52 +887,48 @@ class WhisperSSLEnsemble(PreTrainedModel): # type: ignore
887
 
888
  # В классе WhisperSSLEnsemble
889
  def _load_ssl_ensemble_if_needed(self):
890
- # Если модель уже загружена или не нужна, выходим
891
  if self.ssl_ensemble_model is not None or not self.predict_mode:
892
  return
893
 
894
  logger.info("Lazily loading SSL Ensemble model...")
895
  ssl_ensemble_config = self.config.ssl_ensemble_config
896
 
897
- # Получаем директорию, где лежат все файлы модели
898
- model_dir = os.path.dirname(__file__)
899
-
900
- # Строим абсолютные пути к файлам, которые мы ищем
901
- weak_learners_filename = ssl_ensemble_config["weak_learners_path"]
902
- weak_learners_path = os.path.join(model_dir, weak_learners_filename)
903
-
904
- meta_learner_filename = ssl_ensemble_config["meta_learner_path"]
905
- meta_learner_path = os.path.join(model_dir, meta_learner_filename)
906
-
907
  try:
908
- # Улучшенная проверка на наличие файлов с отладкой
909
- if not os.path.exists(weak_learners_path):
910
- # Получаем список всех файлов в директории для отладки
911
- files_in_dir = os.listdir(model_dir)
912
- error_message = (
913
- f"Weak learners file not found at the expected path: {weak_learners_path}\n"
914
- f"Expected filename from config: '{weak_learners_filename}'\n"
915
- f"Files found in the model directory ('{model_dir}'):\n{files_in_dir}"
916
- )
917
- raise FileNotFoundError(error_message)
918
-
919
- if not os.path.exists(meta_learner_path):
920
- files_in_dir = os.listdir(model_dir)
921
- error_message = (
922
- f"Meta learner file not found at the expected path: {meta_learner_path}\n"
923
- f"Expected filename from config: '{meta_learner_filename}'\n"
924
- f"Files found in the model directory ('{model_dir}'):\n{files_in_dir}"
925
- )
926
- raise FileNotFoundError(error_message)
 
 
 
 
 
 
927
 
928
- # Если все проверки пройдены, загружаем модели
929
  weak_learners = WeakLearners(
930
  audio_dim=ssl_ensemble_config["audio_dim"],
931
  text_dim=ssl_ensemble_config["text_dim"],
932
  device=self._target_device.type
933
  )
934
- # load_fitted уже внутри себя печатает ошибку, но мы можем быть уверены, что файл есть
935
- weak_learners.load_fitted(weak_learners_path)
 
936
 
937
  meta_learner = StackingMetaLearner(
938
  weak_output_dim=len(weak_learners.models),
@@ -950,10 +946,11 @@ class WhisperSSLEnsemble(PreTrainedModel): # type: ignore
950
  logger.info(f"SSL Ensemble loaded successfully onto device {self._target_device}.")
951
 
952
  except Exception as e:
953
- # Этот блок теперь будет ловить наши детальные ошибки FileNotFoundError
954
  logger.error(f"Failed to lazily load SSL Ensemble model: {e}", exc_info=True)
955
  self.predict_mode = False
956
  logger.warning(" Prediction with SSL Ensemble will be disabled.")
 
 
957
 
958
  def preprocess_audio(self, audios: List[Union[np.ndarray, torch.Tensor]]) -> torch.Tensor:
959
  self._load_whisper_if_needed()
 
887
 
888
  # В классе WhisperSSLEnsemble
889
  def _load_ssl_ensemble_if_needed(self):
 
890
  if self.ssl_ensemble_model is not None or not self.predict_mode:
891
  return
892
 
893
  logger.info("Lazily loading SSL Ensemble model...")
894
  ssl_ensemble_config = self.config.ssl_ensemble_config
895
 
 
 
 
 
 
 
 
 
 
 
896
  try:
897
+ # НОВЫЙ ПОДХОД: Скачиваем файлы напрямую из репозитория
898
+ from huggingface_hub import hf_hub_download
899
+
900
+ # Получаем имя репозитория из конфига
901
+ repo_id = getattr(self.config, '_name_or_path', '1NEYRON1/whisper')
902
+
903
+ # Скачиваем файлы весов напрямую из репозитория
904
+ weak_learners_filename = ssl_ensemble_config["weak_learners_path"]
905
+ meta_learner_filename = ssl_ensemble_config["meta_learner_path"]
906
+
907
+ logger.info(f"Downloading {weak_learners_filename} from {repo_id}...")
908
+ weak_learners_path = hf_hub_download(
909
+ repo_id=repo_id,
910
+ filename=weak_learners_filename
911
+ )
912
+
913
+ logger.info(f"Downloading {meta_learner_filename} from {repo_id}...")
914
+ meta_learner_path = hf_hub_download(
915
+ repo_id=repo_id,
916
+ filename=meta_learner_filename
917
+ )
918
+
919
+ logger.info(f"Files downloaded successfully:")
920
+ logger.info(f" Weak learners: {weak_learners_path}")
921
+ logger.info(f" Meta learner: {meta_learner_path}")
922
 
923
+ # Теперь загружаем модели из скачанных файлов
924
  weak_learners = WeakLearners(
925
  audio_dim=ssl_ensemble_config["audio_dim"],
926
  text_dim=ssl_ensemble_config["text_dim"],
927
  device=self._target_device.type
928
  )
929
+
930
+ if not weak_learners.load_fitted(weak_learners_path):
931
+ raise RuntimeError(f"Failed to load weak learners from {weak_learners_path}")
932
 
933
  meta_learner = StackingMetaLearner(
934
  weak_output_dim=len(weak_learners.models),
 
946
  logger.info(f"SSL Ensemble loaded successfully onto device {self._target_device}.")
947
 
948
  except Exception as e:
 
949
  logger.error(f"Failed to lazily load SSL Ensemble model: {e}", exc_info=True)
950
  self.predict_mode = False
951
  logger.warning(" Prediction with SSL Ensemble will be disabled.")
952
+
953
+
954
 
955
  def preprocess_audio(self, audios: List[Union[np.ndarray, torch.Tensor]]) -> torch.Tensor:
956
  self._load_whisper_if_needed()