Update modeling_whisper.py
Browse files- modeling_whisper.py +50 -14
modeling_whisper.py
CHANGED
|
@@ -769,22 +769,36 @@ class WhisperSSLEnsemble(PreTrainedModel): # type: ignore
|
|
| 769 |
ssl_ensemble_config = self.config.ssl_ensemble_config
|
| 770 |
|
| 771 |
# Determine if 'weights' is a path or a variant name for whisper.load_model
|
| 772 |
-
whisper_path_or_variant = whisper_weights_path if whisper_weights_path else whisper_variant
|
| 773 |
|
| 774 |
-
logger.info(f"Loading Whisper model: '{whisper_path_or_variant}'...")
|
| 775 |
-
try:
|
| 776 |
-
|
| 777 |
-
|
| 778 |
-
|
| 779 |
-
|
| 780 |
|
| 781 |
-
|
| 782 |
-
|
| 783 |
-
|
| 784 |
-
|
| 785 |
-
except Exception as e:
|
| 786 |
-
|
| 787 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 788 |
|
| 789 |
self.use_text = text_model_type is not None and text_model_type.lower() != "none"
|
| 790 |
if self.use_text:
|
|
@@ -852,7 +866,28 @@ class WhisperSSLEnsemble(PreTrainedModel): # type: ignore
|
|
| 852 |
self.to(self._target_device)
|
| 853 |
logger.info(f"WhisperSSLEnsemble initialization complete. Final model device: {self.device}")
|
| 854 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 855 |
def preprocess_audio(self, audios: List[Union[np.ndarray, torch.Tensor]]) -> torch.Tensor:
|
|
|
|
| 856 |
processed_mels = []
|
| 857 |
# Use self.whisper_model.device as the definitive device for mel spectrograms
|
| 858 |
# as whisper.load_model puts its tensors on that device.
|
|
@@ -893,6 +928,7 @@ class WhisperSSLEnsemble(PreTrainedModel): # type: ignore
|
|
| 893 |
return inputs
|
| 894 |
|
| 895 |
def get_embeddings(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> tuple:
|
|
|
|
| 896 |
if self.use_text and texts is None:
|
| 897 |
pass
|
| 898 |
|
|
|
|
| 769 |
ssl_ensemble_config = self.config.ssl_ensemble_config
|
| 770 |
|
| 771 |
# Determine if 'weights' is a path or a variant name for whisper.load_model
|
| 772 |
+
# whisper_path_or_variant = whisper_weights_path if whisper_weights_path else whisper_variant
|
| 773 |
|
| 774 |
+
# logger.info(f"Loading Whisper model: '{whisper_path_or_variant}'...")
|
| 775 |
+
# try:
|
| 776 |
+
# # Pass the _target_device directly to whisper.load_model
|
| 777 |
+
# wm = whisper.load_model(whisper_path_or_variant, device=self._target_device)
|
| 778 |
+
# # with torch.device("cpu"):
|
| 779 |
+
# # wm = whisper.load_model(whisper_path_or_variant, device='cpu')
|
| 780 |
|
| 781 |
+
# self.whisper_model = wm # Assign to self.whisper_model AFTER loading
|
| 782 |
+
# self.whisper_model.eval()
|
| 783 |
+
# self._audio_embedding_dim = self.whisper_model.encoder.ln_post.normalized_shape[0]
|
| 784 |
+
# logger.info(f" Whisper loaded. Audio embedding dimension: {self._audio_embedding_dim}. Actual Whisper device: {self.whisper_model.device}")
|
| 785 |
+
# except Exception as e:
|
| 786 |
+
# logger.error(f"Error loading Whisper model: {e}")
|
| 787 |
+
# raise RuntimeError(f"Failed to load Whisper model '{whisper_path_or_variant}'") from e
|
| 788 |
+
self.whisper_model = None
|
| 789 |
+
|
| 790 |
+
# Размеры эмбеддингов теперь должны быть явно указаны в конфиге,
|
| 791 |
+
# так как мы не можем узнать их из еще не загруженной модели.
|
| 792 |
+
# Убедитесь, что в вашем WhisperSSLEnsembleConfig есть эти поля.
|
| 793 |
+
if not hasattr(self.config, 'whisper_embedding_dim'):
|
| 794 |
+
raise ValueError("config.json must contain 'whisper_embedding_dim'")
|
| 795 |
+
if not hasattr(self.config, 'text_embedding_dim'):
|
| 796 |
+
raise ValueError("config.json must contain 'text_embedding_dim'")
|
| 797 |
+
|
| 798 |
+
self._audio_embedding_dim = self.config.whisper_embedding_dim
|
| 799 |
+
self._text_embedding_dim = 0 # Будет обновлено ниже, если есть текстовая модель
|
| 800 |
+
|
| 801 |
+
text_model_type = self.config.text_model_type
|
| 802 |
|
| 803 |
self.use_text = text_model_type is not None and text_model_type.lower() != "none"
|
| 804 |
if self.use_text:
|
|
|
|
| 866 |
self.to(self._target_device)
|
| 867 |
logger.info(f"WhisperSSLEnsemble initialization complete. Final model device: {self.device}")
|
| 868 |
|
| 869 |
+
def _load_whisper_if_needed(self):
|
| 870 |
+
"""
|
| 871 |
+
Lazily loads the whisper model on the first call to a method that needs it.
|
| 872 |
+
This avoids the 'meta' device conflict during __init__.
|
| 873 |
+
"""
|
| 874 |
+
if self.whisper_model is not None:
|
| 875 |
+
return
|
| 876 |
+
|
| 877 |
+
logger.info(f"Lazily loading Whisper model '{self.config.whisper_variant}' onto device '{self.device}'...")
|
| 878 |
+
try:
|
| 879 |
+
# К моменту вызова этого метода основная модель уже на своем финальном устройстве (self.device).
|
| 880 |
+
# Мы можем безопасно загрузить модель Whisper прямо на это устройство.
|
| 881 |
+
whisper_path_or_variant = self.config.whisper_weights_path or self.config.whisper_variant
|
| 882 |
+
self.whisper_model = whisper.load_model(whisper_path_or_variant, device=self.device)
|
| 883 |
+
self.whisper_model.eval()
|
| 884 |
+
logger.info("Whisper model loaded successfully.")
|
| 885 |
+
except Exception as e:
|
| 886 |
+
logger.error(f"Failed to lazily load Whisper model: {e}")
|
| 887 |
+
raise RuntimeError("Could not initialize the Whisper sub-component.") from e
|
| 888 |
+
|
| 889 |
def preprocess_audio(self, audios: List[Union[np.ndarray, torch.Tensor]]) -> torch.Tensor:
|
| 890 |
+
self._load_whisper_if_needed()
|
| 891 |
processed_mels = []
|
| 892 |
# Use self.whisper_model.device as the definitive device for mel spectrograms
|
| 893 |
# as whisper.load_model puts its tensors on that device.
|
|
|
|
| 928 |
return inputs
|
| 929 |
|
| 930 |
def get_embeddings(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> tuple:
|
| 931 |
+
self._load_whisper_if_needed()
|
| 932 |
if self.use_text and texts is None:
|
| 933 |
pass
|
| 934 |
|