Update modeling_whisper.py
Browse files
- modeling_whisper.py +31 -34
modeling_whisper.py
CHANGED
|
@@ -887,52 +887,48 @@ class WhisperSSLEnsemble(PreTrainedModel): # type: ignore
|
|
| 887 |
|
| 888 |
# В классе WhisperSSLEnsemble
|
| 889 |
def _load_ssl_ensemble_if_needed(self):
|
| 890 |
-
# Если модель уже загружена или не нужна, выходим
|
| 891 |
if self.ssl_ensemble_model is not None or not self.predict_mode:
|
| 892 |
return
|
| 893 |
|
| 894 |
logger.info("Lazily loading SSL Ensemble model...")
|
| 895 |
ssl_ensemble_config = self.config.ssl_ensemble_config
|
| 896 |
|
| 897 |
-
# Получаем директорию, где лежат все файлы модели
|
| 898 |
-
model_dir = os.path.dirname(__file__)
|
| 899 |
-
|
| 900 |
-
# Строим абсолютные пути к файлам, которые мы ищем
|
| 901 |
-
weak_learners_filename = ssl_ensemble_config["weak_learners_path"]
|
| 902 |
-
weak_learners_path = os.path.join(model_dir, weak_learners_filename)
|
| 903 |
-
|
| 904 |
-
meta_learner_filename = ssl_ensemble_config["meta_learner_path"]
|
| 905 |
-
meta_learner_path = os.path.join(model_dir, meta_learner_filename)
|
| 906 |
-
|
| 907 |
try:
|
| 908 |
-
#
|
| 909 |
-
|
| 910 |
-
|
| 911 |
-
|
| 912 |
-
|
| 913 |
-
|
| 914 |
-
|
| 915 |
-
|
| 916 |
-
|
| 917 |
-
|
| 918 |
-
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 927 |
|
| 928 |
-
#
|
| 929 |
weak_learners = WeakLearners(
|
| 930 |
audio_dim=ssl_ensemble_config["audio_dim"],
|
| 931 |
text_dim=ssl_ensemble_config["text_dim"],
|
| 932 |
device=self._target_device.type
|
| 933 |
)
|
| 934 |
-
|
| 935 |
-
weak_learners.load_fitted(weak_learners_path)
|
|
|
|
| 936 |
|
| 937 |
meta_learner = StackingMetaLearner(
|
| 938 |
weak_output_dim=len(weak_learners.models),
|
|
@@ -950,10 +946,11 @@ class WhisperSSLEnsemble(PreTrainedModel): # type: ignore
|
|
| 950 |
logger.info(f"SSL Ensemble loaded successfully onto device {self._target_device}.")
|
| 951 |
|
| 952 |
except Exception as e:
|
| 953 |
-
# Этот блок теперь будет ловить наши детальные ошибки FileNotFoundError
|
| 954 |
logger.error(f"Failed to lazily load SSL Ensemble model: {e}", exc_info=True)
|
| 955 |
self.predict_mode = False
|
| 956 |
logger.warning(" Prediction with SSL Ensemble will be disabled.")
|
|
|
|
|
|
|
| 957 |
|
| 958 |
def preprocess_audio(self, audios: List[Union[np.ndarray, torch.Tensor]]) -> torch.Tensor:
|
| 959 |
self._load_whisper_if_needed()
|
|
|
|
| 887 |
|
| 888 |
# В классе WhisperSSLEnsemble
|
| 889 |
def _load_ssl_ensemble_if_needed(self):
|
|
|
|
| 890 |
if self.ssl_ensemble_model is not None or not self.predict_mode:
|
| 891 |
return
|
| 892 |
|
| 893 |
logger.info("Lazily loading SSL Ensemble model...")
|
| 894 |
ssl_ensemble_config = self.config.ssl_ensemble_config
|
| 895 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 896 |
try:
|
| 897 |
+
# НОВЫЙ ПОДХОД: Скачиваем файлы напрямую из репозитория
|
| 898 |
+
from huggingface_hub import hf_hub_download
|
| 899 |
+
|
| 900 |
+
# Получаем имя репозитория из конфига
|
| 901 |
+
repo_id = getattr(self.config, '_name_or_path', '1NEYRON1/whisper')
|
| 902 |
+
|
| 903 |
+
# Скачиваем файлы весов напрямую из репозитория
|
| 904 |
+
weak_learners_filename = ssl_ensemble_config["weak_learners_path"]
|
| 905 |
+
meta_learner_filename = ssl_ensemble_config["meta_learner_path"]
|
| 906 |
+
|
| 907 |
+
logger.info(f"Downloading {weak_learners_filename} from {repo_id}...")
|
| 908 |
+
weak_learners_path = hf_hub_download(
|
| 909 |
+
repo_id=repo_id,
|
| 910 |
+
filename=weak_learners_filename
|
| 911 |
+
)
|
| 912 |
+
|
| 913 |
+
logger.info(f"Downloading {meta_learner_filename} from {repo_id}...")
|
| 914 |
+
meta_learner_path = hf_hub_download(
|
| 915 |
+
repo_id=repo_id,
|
| 916 |
+
filename=meta_learner_filename
|
| 917 |
+
)
|
| 918 |
+
|
| 919 |
+
logger.info(f"Files downloaded successfully:")
|
| 920 |
+
logger.info(f" Weak learners: {weak_learners_path}")
|
| 921 |
+
logger.info(f" Meta learner: {meta_learner_path}")
|
| 922 |
|
| 923 |
+
# Теперь загружаем модели из скачанных файлов
|
| 924 |
weak_learners = WeakLearners(
|
| 925 |
audio_dim=ssl_ensemble_config["audio_dim"],
|
| 926 |
text_dim=ssl_ensemble_config["text_dim"],
|
| 927 |
device=self._target_device.type
|
| 928 |
)
|
| 929 |
+
|
| 930 |
+
if not weak_learners.load_fitted(weak_learners_path):
|
| 931 |
+
raise RuntimeError(f"Failed to load weak learners from {weak_learners_path}")
|
| 932 |
|
| 933 |
meta_learner = StackingMetaLearner(
|
| 934 |
weak_output_dim=len(weak_learners.models),
|
|
|
|
| 946 |
logger.info(f"SSL Ensemble loaded successfully onto device {self._target_device}.")
|
| 947 |
|
| 948 |
except Exception as e:
|
|
|
|
| 949 |
logger.error(f"Failed to lazily load SSL Ensemble model: {e}", exc_info=True)
|
| 950 |
self.predict_mode = False
|
| 951 |
logger.warning(" Prediction with SSL Ensemble will be disabled.")
|
| 952 |
+
|
| 953 |
+
|
| 954 |
|
| 955 |
def preprocess_audio(self, audios: List[Union[np.ndarray, torch.Tensor]]) -> torch.Tensor:
|
| 956 |
self._load_whisper_if_needed()
|