1NEYRON1 committed on
Commit
e43f6b0
·
verified ·
1 Parent(s): be002f7

Update modeling_whisper.py

Browse files
Files changed (1) hide show
  1. modeling_whisper.py +50 -14
modeling_whisper.py CHANGED
@@ -769,22 +769,36 @@ class WhisperSSLEnsemble(PreTrainedModel): # type: ignore
769
  ssl_ensemble_config = self.config.ssl_ensemble_config
770
 
771
  # Determine if 'weights' is a path or a variant name for whisper.load_model
772
- whisper_path_or_variant = whisper_weights_path if whisper_weights_path else whisper_variant
773
 
774
- logger.info(f"Loading Whisper model: '{whisper_path_or_variant}'...")
775
- try:
776
- # Pass the _target_device directly to whisper.load_model
777
- wm = whisper.load_model(whisper_path_or_variant, device=self._target_device)
778
- # with torch.device("cpu"):
779
- # wm = whisper.load_model(whisper_path_or_variant, device='cpu')
780
 
781
- self.whisper_model = wm # Assign to self.whisper_model AFTER loading
782
- self.whisper_model.eval()
783
- self._audio_embedding_dim = self.whisper_model.encoder.ln_post.normalized_shape[0]
784
- logger.info(f" Whisper loaded. Audio embedding dimension: {self._audio_embedding_dim}. Actual Whisper device: {self.whisper_model.device}")
785
- except Exception as e:
786
- logger.error(f"Error loading Whisper model: {e}")
787
- raise RuntimeError(f"Failed to load Whisper model '{whisper_path_or_variant}'") from e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
788
 
789
  self.use_text = text_model_type is not None and text_model_type.lower() != "none"
790
  if self.use_text:
@@ -852,7 +866,28 @@ class WhisperSSLEnsemble(PreTrainedModel): # type: ignore
852
  self.to(self._target_device)
853
  logger.info(f"WhisperSSLEnsemble initialization complete. Final model device: {self.device}")
854
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
855
  def preprocess_audio(self, audios: List[Union[np.ndarray, torch.Tensor]]) -> torch.Tensor:
 
856
  processed_mels = []
857
  # Use self.whisper_model.device as the definitive device for mel spectrograms
858
  # as whisper.load_model puts its tensors on that device.
@@ -893,6 +928,7 @@ class WhisperSSLEnsemble(PreTrainedModel): # type: ignore
893
  return inputs
894
 
895
  def get_embeddings(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> tuple:
 
896
  if self.use_text and texts is None:
897
  pass
898
 
 
769
  ssl_ensemble_config = self.config.ssl_ensemble_config
770
 
771
  # Determine if 'weights' is a path or a variant name for whisper.load_model
772
+ # whisper_path_or_variant = whisper_weights_path if whisper_weights_path else whisper_variant
773
 
774
+ # logger.info(f"Loading Whisper model: '{whisper_path_or_variant}'...")
775
+ # try:
776
+ # # Pass the _target_device directly to whisper.load_model
777
+ # wm = whisper.load_model(whisper_path_or_variant, device=self._target_device)
778
+ # # with torch.device("cpu"):
779
+ # # wm = whisper.load_model(whisper_path_or_variant, device='cpu')
780
 
781
+ # self.whisper_model = wm # Assign to self.whisper_model AFTER loading
782
+ # self.whisper_model.eval()
783
+ # self._audio_embedding_dim = self.whisper_model.encoder.ln_post.normalized_shape[0]
784
+ # logger.info(f" Whisper loaded. Audio embedding dimension: {self._audio_embedding_dim}. Actual Whisper device: {self.whisper_model.device}")
785
+ # except Exception as e:
786
+ # logger.error(f"Error loading Whisper model: {e}")
787
+ # raise RuntimeError(f"Failed to load Whisper model '{whisper_path_or_variant}'") from e
788
+ self.whisper_model = None
789
+
790
+ # Embedding sizes must now be specified explicitly in the config,
791
+ # since we cannot infer them from a model that has not been loaded yet.
792
+ # Make sure your WhisperSSLEnsembleConfig contains these fields.
793
+ if not hasattr(self.config, 'whisper_embedding_dim'):
794
+ raise ValueError("config.json must contain 'whisper_embedding_dim'")
795
+ if not hasattr(self.config, 'text_embedding_dim'):
796
+ raise ValueError("config.json must contain 'text_embedding_dim'")
797
+
798
+ self._audio_embedding_dim = self.config.whisper_embedding_dim
799
+ self._text_embedding_dim = 0 # Will be updated below if a text model is present
800
+
801
+ text_model_type = self.config.text_model_type
802
 
803
  self.use_text = text_model_type is not None and text_model_type.lower() != "none"
804
  if self.use_text:
 
866
  self.to(self._target_device)
867
  logger.info(f"WhisperSSLEnsemble initialization complete. Final model device: {self.device}")
868
 
869
+ def _load_whisper_if_needed(self):
870
+ """
871
+ Lazily loads the whisper model on the first call to a method that needs it.
872
+ This avoids the 'meta' device conflict during __init__.
873
+ """
874
+ if self.whisper_model is not None:
875
+ return
876
+
877
+ logger.info(f"Lazily loading Whisper model '{self.config.whisper_variant}' onto device '{self.device}'...")
878
+ try:
879
+ # К моменту вызова этого метода основная модель уже на своем финальном устройстве (self.device).
880
+ # Мы можем безопасно загрузить модель Whisper прямо на это устройство.
881
+ whisper_path_or_variant = self.config.whisper_weights_path or self.config.whisper_variant
882
+ self.whisper_model = whisper.load_model(whisper_path_or_variant, device=self.device)
883
+ self.whisper_model.eval()
884
+ logger.info("Whisper model loaded successfully.")
885
+ except Exception as e:
886
+ logger.error(f"Failed to lazily load Whisper model: {e}")
887
+ raise RuntimeError("Could not initialize the Whisper sub-component.") from e
888
+
889
  def preprocess_audio(self, audios: List[Union[np.ndarray, torch.Tensor]]) -> torch.Tensor:
890
+ self._load_whisper_if_needed()
891
  processed_mels = []
892
  # Use self.whisper_model.device as the definitive device for mel spectrograms
893
  # as whisper.load_model puts its tensors on that device.
 
928
  return inputs
929
 
930
  def get_embeddings(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> tuple:
931
+ self._load_whisper_if_needed()
932
  if self.use_text and texts is None:
933
  pass
934