Update modeling_whisper.py
Browse files- modeling_whisper.py +89 -18
modeling_whisper.py
CHANGED
|
@@ -831,6 +831,61 @@ class WhisperSSLEnsemble(PreTrainedModel): # type: ignore
|
|
| 831 |
logger.warning(" Text processing will be disabled due to load failure.")
|
| 832 |
|
| 833 |
# В классе WhisperSSLEnsemble
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 834 |
def _load_ssl_ensemble_if_needed(self):
|
| 835 |
# Если модель уже загружена или не нужна, выходим
|
| 836 |
if self.ssl_ensemble_model is not None or not self.predict_mode:
|
|
@@ -838,31 +893,46 @@ class WhisperSSLEnsemble(PreTrainedModel): # type: ignore
|
|
| 838 |
|
| 839 |
logger.info("Lazily loading SSL Ensemble model...")
|
| 840 |
ssl_ensemble_config = self.config.ssl_ensemble_config
|
| 841 |
-
|
| 842 |
-
|
| 843 |
-
|
| 844 |
-
|
| 845 |
-
|
| 846 |
-
|
|
|
|
| 847 |
|
| 848 |
-
|
| 849 |
-
|
| 850 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 851 |
|
| 852 |
-
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
|
| 856 |
-
|
|
|
|
|
|
|
|
|
|
| 857 |
|
| 858 |
-
#
|
| 859 |
weak_learners = WeakLearners(
|
| 860 |
audio_dim=ssl_ensemble_config["audio_dim"],
|
| 861 |
text_dim=ssl_ensemble_config["text_dim"],
|
| 862 |
device=self._target_device.type
|
| 863 |
)
|
| 864 |
-
|
| 865 |
-
|
| 866 |
|
| 867 |
meta_learner = StackingMetaLearner(
|
| 868 |
weak_output_dim=len(weak_learners.models),
|
|
@@ -870,7 +940,7 @@ class WhisperSSLEnsemble(PreTrainedModel): # type: ignore
|
|
| 870 |
)
|
| 871 |
meta_learner.load_state_dict_from_file(meta_learner_path, device=self._target_device)
|
| 872 |
|
| 873 |
-
# С
|
| 874 |
self.ssl_ensemble_model = SSLEnsembleModel(
|
| 875 |
weak_learners=weak_learners,
|
| 876 |
stacking_meta_learner=meta_learner
|
|
@@ -880,6 +950,7 @@ class WhisperSSLEnsemble(PreTrainedModel): # type: ignore
|
|
| 880 |
logger.info(f"SSL Ensemble loaded successfully onto device {self._target_device}.")
|
| 881 |
|
| 882 |
except Exception as e:
|
|
|
|
| 883 |
logger.error(f"Failed to lazily load SSL Ensemble model: {e}", exc_info=True)
|
| 884 |
self.predict_mode = False
|
| 885 |
logger.warning(" Prediction with SSL Ensemble will be disabled.")
|
|
|
|
| 831 |
logger.warning(" Text processing will be disabled due to load failure.")
|
| 832 |
|
| 833 |
# В классе WhisperSSLEnsemble
|
| 834 |
+
# def _load_ssl_ensemble_if_needed(self):
|
| 835 |
+
# # Если модель уже загружена или не нужна, выходим
|
| 836 |
+
# if self.ssl_ensemble_model is not None or not self.predict_mode:
|
| 837 |
+
# return
|
| 838 |
+
|
| 839 |
+
# logger.info("Lazily loading SSL Ensemble model...")
|
| 840 |
+
# ssl_ensemble_config = self.config.ssl_ensemble_config
|
| 841 |
+
# try:
|
| 842 |
+
# # ШАГ 1: НАЙТИ ДИРЕКТОРИЮ, ГДЕ ЛЕЖАТ ВСЕ ФАЙЛЫ МОДЕЛИ
|
| 843 |
+
# # Это самый важный и правильный шаг.
|
| 844 |
+
# # __file__ - это путь к текущему файлу (modeling_whisper.py)
|
| 845 |
+
# # os.path.dirname() получает директорию из этого пути.
|
| 846 |
+
# model_dir = os.path.dirname(__file__)
|
| 847 |
+
|
| 848 |
+
# # ШАГ 2: ПОСТРОИТЬ АБСОЛЮТНЫЕ ПУТИ К ФАЙЛАМ ВЕСОВ
|
| 849 |
+
# weak_learners_filename = ssl_ensemble_config["weak_learners_path"]
|
| 850 |
+
# weak_learners_path = os.path.join(model_dir, weak_learners_filename)
|
| 851 |
+
|
| 852 |
+
# meta_learner_filename = ssl_ensemble_config["meta_learner_path"]
|
| 853 |
+
# meta_learner_path = os.path.join(model_dir, meta_learner_filename)
|
| 854 |
+
|
| 855 |
+
# logger.info(f"Attempting to load weak learners from: {weak_learners_path}")
|
| 856 |
+
# logger.info(f"Attempting to load meta learner from: {meta_learner_path}")
|
| 857 |
+
|
| 858 |
+
# # ШАГ 3: ЗАГРУЗИТЬ МОДЕЛИ ПО АБСОЛЮТНЫМ ПУТЯМ
|
| 859 |
+
# weak_learners = WeakLearners(
|
| 860 |
+
# audio_dim=ssl_ensemble_config["audio_dim"],
|
| 861 |
+
# text_dim=ssl_ensemble_config["text_dim"],
|
| 862 |
+
# device=self._target_device.type
|
| 863 |
+
# )
|
| 864 |
+
# if not weak_learners.load_fitted(weak_learners_path):
|
| 865 |
+
# raise RuntimeError(f"Failed to load weak learners from {weak_learners_path}")
|
| 866 |
+
|
| 867 |
+
# meta_learner = StackingMetaLearner(
|
| 868 |
+
# weak_output_dim=len(weak_learners.models),
|
| 869 |
+
# hidden_dim=ssl_ensemble_config["hidden_dim"]
|
| 870 |
+
# )
|
| 871 |
+
# meta_learner.load_state_dict_from_file(meta_learner_path, device=self._target_device)
|
| 872 |
+
|
| 873 |
+
# # СОЗДАНИЕ ИТОГОВОЙ МОДЕЛИ АНСАМБЛЯ
|
| 874 |
+
# self.ssl_ensemble_model = SSLEnsembleModel(
|
| 875 |
+
# weak_learners=weak_learners,
|
| 876 |
+
# stacking_meta_learner=meta_learner
|
| 877 |
+
# )
|
| 878 |
+
# self.ssl_ensemble_model.eval()
|
| 879 |
+
|
| 880 |
+
# logger.info(f"SSL Ensemble loaded successfully onto device {self._target_device}.")
|
| 881 |
+
|
| 882 |
+
# except Exception as e:
|
| 883 |
+
# logger.error(f"Failed to lazily load SSL Ensemble model: {e}", exc_info=True)
|
| 884 |
+
# self.predict_mode = False
|
| 885 |
+
# logger.warning(" Prediction with SSL Ensemble will be disabled.")
|
| 886 |
+
|
| 887 |
+
|
| 888 |
+
# В классе WhisperSSLEnsemble
|
| 889 |
def _load_ssl_ensemble_if_needed(self):
|
| 890 |
# Если модель уже загружена или не нужна, выходим
|
| 891 |
if self.ssl_ensemble_model is not None or not self.predict_mode:
|
|
|
|
| 893 |
|
| 894 |
logger.info("Lazily loading SSL Ensemble model...")
|
| 895 |
ssl_ensemble_config = self.config.ssl_ensemble_config
|
| 896 |
+
|
| 897 |
+
# Получаем директорию, где лежат все файлы модели
|
| 898 |
+
model_dir = os.path.dirname(__file__)
|
| 899 |
+
|
| 900 |
+
# Строим абсолютные пути к файлам, которые мы ищем
|
| 901 |
+
weak_learners_filename = ssl_ensemble_config["weak_learners_path"]
|
| 902 |
+
weak_learners_path = os.path.join(model_dir, weak_learners_filename)
|
| 903 |
|
| 904 |
+
meta_learner_filename = ssl_ensemble_config["meta_learner_path"]
|
| 905 |
+
meta_learner_path = os.path.join(model_dir, meta_learner_filename)
|
| 906 |
+
|
| 907 |
+
try:
|
| 908 |
+
# Улучшенная проверка на наличие файлов с отладкой
|
| 909 |
+
if not os.path.exists(weak_learners_path):
|
| 910 |
+
# Получаем список всех файлов в директории для отладки
|
| 911 |
+
files_in_dir = os.listdir(model_dir)
|
| 912 |
+
error_message = (
|
| 913 |
+
f"Weak learners file not found at the expected path: {weak_learners_path}\n"
|
| 914 |
+
f"Expected filename from config: '{weak_learners_filename}'\n"
|
| 915 |
+
f"Files found in the model directory ('{model_dir}'):\n{files_in_dir}"
|
| 916 |
+
)
|
| 917 |
+
raise FileNotFoundError(error_message)
|
| 918 |
|
| 919 |
+
if not os.path.exists(meta_learner_path):
|
| 920 |
+
files_in_dir = os.listdir(model_dir)
|
| 921 |
+
error_message = (
|
| 922 |
+
f"Meta learner file not found at the expected path: {meta_learner_path}\n"
|
| 923 |
+
f"Expected filename from config: '{meta_learner_filename}'\n"
|
| 924 |
+
f"Files found in the model directory ('{model_dir}'):\n{files_in_dir}"
|
| 925 |
+
)
|
| 926 |
+
raise FileNotFoundError(error_message)
|
| 927 |
|
| 928 |
+
# Если все проверки пройдены, загружаем модели
|
| 929 |
weak_learners = WeakLearners(
|
| 930 |
audio_dim=ssl_ensemble_config["audio_dim"],
|
| 931 |
text_dim=ssl_ensemble_config["text_dim"],
|
| 932 |
device=self._target_device.type
|
| 933 |
)
|
| 934 |
+
# load_fitted уже внутри себя печатает ошибку, но мы можем быть уверены, что файл есть
|
| 935 |
+
weak_learners.load_fitted(weak_learners_path)
|
| 936 |
|
| 937 |
meta_learner = StackingMetaLearner(
|
| 938 |
weak_output_dim=len(weak_learners.models),
|
|
|
|
| 940 |
)
|
| 941 |
meta_learner.load_state_dict_from_file(meta_learner_path, device=self._target_device)
|
| 942 |
|
| 943 |
+
# Создание итоговой модели ансамбля
|
| 944 |
self.ssl_ensemble_model = SSLEnsembleModel(
|
| 945 |
weak_learners=weak_learners,
|
| 946 |
stacking_meta_learner=meta_learner
|
|
|
|
| 950 |
logger.info(f"SSL Ensemble loaded successfully onto device {self._target_device}.")
|
| 951 |
|
| 952 |
except Exception as e:
|
| 953 |
+
# Этот блок теперь будет ловить наши детальные ошибки FileNotFoundError
|
| 954 |
logger.error(f"Failed to lazily load SSL Ensemble model: {e}", exc_info=True)
|
| 955 |
self.predict_mode = False
|
| 956 |
logger.warning(" Prediction with SSL Ensemble will be disabled.")
|