| # import os | |
| # from typing import Dict, List, Optional, Union | |
| # import logging | |
| # import joblib | |
| # import numpy as np | |
| # import torch | |
| # import whisper | |
| # from sklearn.linear_model import Ridge | |
| # from sklearn.svm import SVR | |
| # from sklearn.tree import DecisionTreeRegressor | |
| # from torch import nn | |
| # from tqdm import tqdm | |
| # from transformers import AutoModel, AutoTokenizer, PreTrainedModel, AutoConfig | |
| # from .configuration_whisper import WhisperSSLEnsembleConfig | |
| # # Setup logger | |
| # logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
| # logger = logging.getLogger(__name__) | |
| # # --- Weak Learners --- | |
| # class WeakLearners(nn.Module): | |
| # """ | |
| # Wrapper for scikit-learn regression models to act as weak learners. | |
| # Requires loading fitted models from a file before use in inference. | |
| # Note: Sklearn models run on CPU. Data is moved accordingly during forward. | |
| # """ | |
| # # Added default dimensions matching common Whisper base/BERT base | |
| # def __init__(self, audio_dim: int = 512, text_dim: int = 768, device: str = "cpu") -> None: | |
| # super().__init__() | |
| # self.audio_dim = audio_dim | |
| # self.text_dim = text_dim | |
| # # Device primarily for moving predictions back, sklearn runs on CPU | |
| # self.device = torch.device(device) | |
| # # Initialize sklearn model structures (placeholders until loaded) | |
| # self.ridge_regressor = Ridge(alpha=1.0) | |
| # self.svr = SVR() | |
| # self.dtr = DecisionTreeRegressor() | |
| # self.models = [self.ridge_regressor, self.svr, self.dtr] | |
| # self.model_names = ["Ridge", "SVR", "DTR"] | |
| # self.fitted = False # Will be set to True after loading | |
| # # Keep fit method for potential separate training script, but not used by WhisperModel/MultiModalWhisper directly | |
| # def fit(self, train_loader: torch.utils.data.DataLoader) -> None: | |
| # logger.info("Fitting weak learners...") | |
| # all_audio_emb, all_text_emb, all_labels = [], [], [] | |
| # pbar = tqdm(train_loader, desc="Extracting Embeddings for Weak Learners", leave=False) | |
| # for batch_data in pbar: | |
| # # Flexible batch data handling (adjust if needed based on actual loader) | |
| # if isinstance(batch_data, (list, tuple)) and len(batch_data) >= 3: | |
| # audio_emb, text_emb, labels = (batch_data[0], batch_data[1], batch_data[2]) | |
| # elif isinstance(batch_data, dict): | |
| # audio_emb = batch_data.get("audio_embedding") | |
| # text_emb = batch_data.get("text_embedding") # Might be None | |
| # labels = batch_data.get("label") | |
| # if audio_emb is None or labels is None: | |
| # raise ValueError("Expected dict with 'audio_embedding' and 'label' keys") | |
| # else: | |
| # raise ValueError("Unsupported batch data format from train_loader") | |
| # all_audio_emb.append(audio_emb.detach().cpu().numpy()) | |
| # if text_emb is not None: | |
| # all_text_emb.append(text_emb.detach().cpu().numpy()) | |
| # elif self.text_dim > 0: # Pad if text expected but not provided in batch | |
| # zeros_text = np.zeros((audio_emb.shape[0], self.text_dim)) | |
| # all_text_emb.append(zeros_text) | |
| # all_labels.append(labels.detach().cpu().numpy()) | |
| # if not all_audio_emb or not all_labels: | |
| # raise RuntimeError("No embeddings or labels collected. Check train_loader.") | |
| # all_audio_emb = np.vstack(all_audio_emb) | |
| # all_labels = np.concatenate(all_labels) | |
| # if all_text_emb: | |
| # all_text_emb = np.vstack(all_text_emb) | |
| # if all_audio_emb.shape[0] != all_text_emb.shape[0]: | |
| # raise ValueError(f"Sample count mismatch: Audio {all_audio_emb.shape[0]}, Text {all_text_emb.shape[0]}") | |
| # combined_embeddings = np.hstack((all_audio_emb, all_text_emb)) | |
| # logger.info(f"Combined embedding shape for fitting: {combined_embeddings.shape}") | |
| # else: | |
| # combined_embeddings = all_audio_emb | |
| # logger.info(f"Using only audio embeddings for fitting. Shape: {combined_embeddings.shape}") | |
| # if combined_embeddings.shape[0] != len(all_labels): | |
| # raise ValueError(f"Sample count mismatch: Embeddings {combined_embeddings.shape[0]}, Labels {len(all_labels)}") | |
| # logger.info("Training sklearn models...") | |
| # for model, name in zip(self.models, self.model_names): | |
| # logger.info(f" Fitting {name}...") | |
| # model.fit(combined_embeddings, all_labels) | |
| # self.fitted = True | |
| # logger.info("Weak learners fitting completed.") | |
| # # --- Add save functionality --- | |
| # # joblib.dump(self.models, 'fitted_weak_learners.joblib') | |
| # # logger.info("Fitted weak learners saved to fitted_weak_learners.joblib") | |
| # def load_fitted(self, filepath: str) -> bool: | |
| # """Loads fitted sklearn models from a joblib file.""" | |
| # if not os.path.exists(filepath): | |
| # logger.error(f"Weak learners file not found at {filepath}") | |
| # return False | |
| # try: | |
| # loaded_sklearn_models = joblib.load(filepath) | |
| # if isinstance(loaded_sklearn_models, list) and len(loaded_sklearn_models) == len(self.models): | |
| # self.models = loaded_sklearn_models | |
| # self.fitted = True | |
| # logger.info(f" Weak learners loaded successfully from {filepath}.") | |
| # return True | |
| # logger.error(f"Loaded file '{filepath}' does not contain the expected list of {len(self.models)} sklearn models.") | |
| # self.fitted = False | |
| # return False | |
| # except ImportError: | |
| # logger.error("joblib is required to load weak learners but not installed. Run: pip install joblib") | |
| # self.fitted = False | |
| # return False | |
| # except Exception as e: | |
| # logger.error(f"Error loading weak learners from {filepath}: {e}") | |
| # self.fitted = False | |
| # return False | |
| # def forward(self, audio_emb: torch.Tensor, text_emb: Optional[torch.Tensor]) -> torch.Tensor: | |
| # """ | |
| # Generate predictions from loaded weak learners. | |
| # Input tensors (on any device) will be moved to CPU for sklearn. | |
| # Output tensor (predictions) will be on self.device. | |
| # """ | |
| # if not self.fitted: | |
| # # Changed error message slightly | |
| # raise RuntimeError("Weak learners must be loaded using load_fitted() before calling forward.") | |
| # # Prepare input for sklearn: move to CPU, detach, convert to numpy | |
| # audio_np = audio_emb.detach().cpu().numpy() | |
| # # Handle text embedding based on text_dim | |
| # if self.text_dim > 0: | |
| # if text_emb is None: | |
| # # If text is expected (text_dim > 0) but not provided, pad with zeros | |
| # logger.warning("Text embedding is None in WeakLearners forward pass, but text_dim > 0. Padding with zeros.") | |
| # zeros_text = np.zeros((audio_np.shape[0], self.text_dim)) | |
| # combined_embeddings = np.hstack((audio_np, zeros_text)) | |
| # else: | |
| # # Text is expected and provided | |
| # text_np = text_emb.detach().cpu().numpy() | |
| # if audio_np.shape[0] != text_np.shape[0]: | |
| # raise ValueError("Batch size mismatch between audio and text embeddings in WeakLearners forward.") | |
| # combined_embeddings = np.hstack((audio_np, text_np)) | |
| # else: | |
| # # No text is expected (text_dim == 0) | |
| # if text_emb is not None: | |
| # logger.warning("Text embedding provided to WeakLearners forward pass, but text_dim is 0. Ignoring text.") | |
| # combined_embeddings = audio_np | |
| # # Get predictions from each loaded sklearn model | |
| # all_preds = [] | |
| # # Use self.models which now contains the loaded models | |
| # for model in self.models: | |
| # preds = model.predict(combined_embeddings) | |
| # # Convert numpy array predictions to float tensor | |
| # all_preds.append(torch.from_numpy(preds).float()) | |
| # # Stack predictions along a new dimension (dim=1) -> (batch_size, num_weak_learners) | |
| # # Move the final stacked tensor to the designated device (e.g., GPU if specified) | |
| # stacked_preds = torch.stack(all_preds, dim=1).to(self.device) | |
| # return stacked_preds | |
| # # --- Stacking Model (Meta-Learner) --- | |
| # class StackingMetaLearner(nn.Module): | |
| # """ | |
| # A simple feed-forward network that learns to combine predictions | |
| # from weak learners. Structure needs to match the saved weights. | |
| # """ | |
| # def __init__(self, weak_output_dim: int = 3, hidden_dim: int = 256) -> None: | |
| # super().__init__() | |
| # # Check for valid dimensions | |
| # if weak_output_dim <= 0 or hidden_dim <= 0: | |
| # raise ValueError("weak_output_dim and hidden_dim must be positive integers.") | |
| # self.fc1 = nn.Linear(weak_output_dim, hidden_dim) | |
| # self.relu = nn.ReLU() | |
| # self.fc2 = nn.Linear(hidden_dim, 1) # Predict a single score | |
| # def load_state_dict_from_file(self, filepath: str, device: torch.device) -> None: | |
| # if not os.path.exists(filepath): | |
| # raise FileNotFoundError(f"Meta-learner state file not found at {filepath}") | |
| # try: | |
| # # Load state dict onto the specified device directly | |
| # state_dict = torch.load(filepath, map_location=device) | |
| # self.load_state_dict(state_dict) | |
| # # Move model to the device *after* loading state_dict | |
| # self.to(device) | |
| # self.eval() # Set to evaluation mode after loading | |
| # logger.info(f" StackingMetaLearner state loaded successfully from {filepath} onto device {device}.") | |
| # except Exception as e: | |
| # logger.error(f"Error loading StackingMetaLearner state_dict from {filepath}: {e}") | |
| # raise RuntimeError(f"Failed to load StackingMetaLearner state: {e}") from e | |
| # def forward(self, weak_outputs: torch.Tensor) -> torch.Tensor: | |
| # """ | |
| # Args: | |
| # weak_outputs (torch.Tensor): Tensor of shape (batch_size, weak_output_dim) | |
| # containing predictions from weak learners. | |
| # Returns: | |
| # torch.Tensor: Final prediction of shape (batch_size, 1). | |
| # """ | |
| # x = self.relu(self.fc1(weak_outputs)) | |
| # x = self.fc2(x) | |
| # return x | |
| # # --- Main Ensemble Model (Wrapper used internally now) --- | |
| # class SSLEnsembleModel(nn.Module): | |
| # """ | |
| # Combines WeakLearners and a StackingMetaLearner. | |
| # Assumes WeakLearners are loaded externally and MetaLearner state_dict is loaded externally. | |
| # This class mainly defines the forward pass logic using the components. | |
| # """ | |
| # def __init__(self, weak_learners: WeakLearners, stacking_meta_learner: StackingMetaLearner) -> None: | |
| # super().__init__() | |
| # if not isinstance(weak_learners, WeakLearners) or not weak_learners.fitted: | |
| # raise ValueError("A pre-loaded WeakLearners instance must be provided to SSLEnsembleModel.") | |
| # if not isinstance(stacking_meta_learner, StackingMetaLearner): | |
| # raise ValueError("A StackingMetaLearner instance must be provided.") | |
| # self.weak_learners = weak_learners | |
| # self.stacking_meta_learner = stacking_meta_learner | |
| # def forward(self, audio_emb: torch.Tensor, text_emb: Optional[torch.Tensor]) -> torch.Tensor: | |
| # """ | |
| # Forward pass through the ensemble. | |
| # 1. Get predictions from weak learners. | |
| # 2. Pass weak learner predictions to the stacking meta-learner. | |
| # """ | |
| # # Get predictions from weak learners (shape: batch_size, num_weak_outputs) | |
| # # WeakLearners forward handles device movement internally | |
| # weak_outputs = self.weak_learners(audio_emb, text_emb) # Output is on weak_learners.device | |
| # # Ensure weak_outputs are on the same device as the meta-learner before passing | |
| # weak_outputs = weak_outputs.to(next(self.stacking_meta_learner.parameters()).device) | |
| # # Get final prediction from the meta-learner | |
| # final_output = self.stacking_meta_learner(weak_outputs) # Output shape (batch_size, 1) | |
| # return final_output | |
| # # --- Unified Main Model Class --- | |
| # class WhisperSSLEnsemble(PreTrainedModel): # type: ignore | |
| # """ | |
| # Unified model using Whisper for audio, an optional transformer for text, | |
| # and an optional SSL Ensemble (Weak + Meta learners) for final prediction. | |
| # Inherits from BaseMultimodalModel. | |
| # """ | |
| # config_class = WhisperSSLEnsembleConfig | |
| # def __init__(self, config: WhisperSSLEnsembleConfig) -> None: | |
| # # self, | |
| # # whisper_variant: str = "base.en", | |
| # # text_model_type: str = "bert-base-uncased", | |
| # # weights: str = None, # Path to Whisper weights if not using standard variant download | |
| # # device: str = None, | |
| # # ssl_ensemble_config: Optional[Dict] = None, | |
| # # super().__init__(weights) # Not calling super for BaseModel due to different init logic | |
| # super().__init__(config) | |
| # self.config = config | |
| # whisper_variant = self.config.whisper_variant | |
| # text_model_type = self.config.text_model_type | |
| # weights = self.config.weights | |
| # ssl_ensemble_config = self.config.ssl_ensemble_config | |
| # # device = 'cpu' | |
| # # if self.config.device is None: | |
| # # device = "cuda" if torch.cuda.is_available() else "cpu" | |
| # # else: | |
| # # device = self.config.device | |
| # # self.device = torch.device(device) | |
| # self.predict_mode = False | |
| # self.ssl_ensemble_model = None | |
| # self.tokenizer = None | |
| # self.text_model = None | |
| # self._audio_embedding_dim = None | |
| # self._text_embedding_dim = 0 | |
| # # logger.info(f"Initializing WhisperSSLEnsemble on device: {self.device}") | |
| # try: | |
| # # Determine if 'weights' is a path or a variant name for whisper.load_model | |
| # # whisper.load_model can take a name like "base.en" or a path to a .pt file | |
| # logger.info(f"Loading Whisper model: '{weights if weights else whisper_variant}'...") | |
| # wm = whisper.load_model(weights if weights else whisper_variant, device=self.device) | |
| # self.whisper_model = wm | |
| # self.whisper_model.eval() | |
| # self._audio_embedding_dim = self.whisper_model.encoder.ln_post.normalized_shape[0] | |
| # logger.info(f" Whisper loaded. Audio embedding dimension: {self._audio_embedding_dim}") | |
| # except Exception as e: | |
| # logger.error(f"Error loading Whisper model: {e}") | |
| # raise RuntimeError(f"Failed to load Whisper model '{weights or whisper_variant}'") from e | |
| # self.use_text = text_model_type is not None and text_model_type.lower() != "none" | |
| # if self.use_text: | |
| # logger.info(f"Loading Text model '{text_model_type}'...") | |
| # try: | |
| # self.tokenizer = AutoTokenizer.from_pretrained(text_model_type) | |
| # self.text_model = AutoModel.from_pretrained(text_model_type).to(self.device) | |
| # self.text_model.eval() | |
| # self._text_embedding_dim = self.text_model.config.hidden_size | |
| # logger.info(f" Text model loaded. Text embedding dimension: {self._text_embedding_dim}") | |
| # except Exception as e: | |
| # logger.warning(f"Failed to load text model '{text_model_type}'. Error: {e}") | |
| # logger.warning(" Text processing will be disabled.") | |
| # self.use_text = False | |
| # self._text_embedding_dim = 0 | |
| # else: | |
| # logger.info("Text model type is 'none'. Text processing disabled.") | |
| # self._text_embedding_dim = 0 | |
| # if ssl_ensemble_config is not None: | |
| # logger.info("SSL Ensemble config provided. Initializing for prediction mode...") | |
| # self.predict_mode = True | |
| # required_keys = ["weak_learners_path", "meta_learner_path", "audio_dim", "text_dim", "hidden_dim"] | |
| # if not all(key in ssl_ensemble_config for key in required_keys): | |
| # raise ValueError(f"ssl_ensemble_config missing required keys: {required_keys}") | |
| # cfg_audio_dim = ssl_ensemble_config["audio_dim"] | |
| # cfg_text_dim = ssl_ensemble_config["text_dim"] | |
| # if cfg_audio_dim != self._audio_embedding_dim: | |
| # logger.warning(f"Ensemble config audio_dim ({cfg_audio_dim}) mismatches Whisper model dim ({self._audio_embedding_dim}).") | |
| # # Validate text_dim from config against loaded text model | |
| # if self.use_text and cfg_text_dim != self._text_embedding_dim: | |
| # logger.warning(f"Ensemble config text_dim ({cfg_text_dim}) mismatches loaded text model dim ({self._text_embedding_dim}).") | |
| # elif not self.use_text and cfg_text_dim != 0: | |
| # logger.warning(f"Ensemble config expects text_dim={cfg_text_dim}, but no text model is loaded. WeakLearners must handle padding.") | |
| # elif self.use_text and cfg_text_dim == 0: # Text model loaded, but ensemble config says text_dim=0 | |
| # logger.warning("Ensemble config expects text_dim=0, but a text model IS loaded. " | |
| # "Text embeddings will be ignored by weak learners if their text_dim is 0.") | |
| # weak_learners = WeakLearners(audio_dim=cfg_audio_dim, text_dim=cfg_text_dim, device=self.device) | |
| # if not weak_learners.load_fitted(ssl_ensemble_config["weak_learners_path"]): | |
| # raise RuntimeError("Failed to load weak learners for WhisperSSLEnsemble.") | |
| # meta_learner = StackingMetaLearner(weak_output_dim=len(weak_learners.models), hidden_dim=ssl_ensemble_config["hidden_dim"]) | |
| # meta_learner.load_state_dict_from_file(ssl_ensemble_config["meta_learner_path"], device=self.device) | |
| # self.ssl_ensemble_model = SSLEnsembleModel(weak_learners=weak_learners, stacking_meta_learner=meta_learner) | |
| # self.ssl_ensemble_model.eval() | |
| # logger.info("SSLEnsembleModel loaded successfully.") | |
| # else: | |
| # logger.info("No SSL Ensemble config provided. Model initialized in embedding mode.") | |
| # logger.info("Call '.get_embeddings(audios, texts)' to get embeddings.") | |
| # logger.info("Calling '.predict()' or '.forward()' will raise an error in this mode.") | |
| # def preprocess_audio(self, audios: List[Union[np.ndarray, torch.Tensor]]) -> torch.Tensor: | |
| # processed_mels = [] | |
| # for audio in audios: | |
| # if isinstance(audio, torch.Tensor): | |
| # audio = audio.cpu().numpy() | |
| # if not isinstance(audio, np.ndarray): | |
| # raise TypeError(f"Expected audio input to be numpy array or torch tensor, got {type(audio)}") | |
| # if audio.dtype != np.float32: | |
| # audio = audio.astype(np.float32) | |
| # audio_proc = whisper.pad_or_trim(audio) | |
| # mel = whisper.log_mel_spectrogram(audio_proc, device=self.whisper_model.device) | |
| # processed_mels.append(mel) | |
| # if not processed_mels: | |
| # # Using Whisper's typical dimensions: 80 mel bins, 3000 frames (30s) | |
| # return torch.empty((0, self.whisper_model.dims.n_mels, self.whisper_model.dims.n_audio_ctx), device=self.whisper_model.device) | |
| # # Simpler: use known default values if whisper model object doesn't expose them easily | |
| # # return torch.empty((0, 80, 3000), device=self.whisper_model.device) | |
| # mels_batch = torch.stack(processed_mels) | |
| # return mels_batch | |
| # def preprocess_text(self, texts: List[str]) -> Optional[Dict[str, torch.Tensor]]: | |
| # if not self.use_text or not self.tokenizer: | |
| # return None | |
| # if not texts: | |
| # return { | |
| # "input_ids": torch.empty((0, 0), dtype=torch.long, device=self.device), | |
| # "attention_mask": torch.empty((0, 0), dtype=torch.long, device=self.device) | |
| # } | |
| # pad_token = self.tokenizer.pad_token if self.tokenizer.pad_token else "[PAD]" | |
| # processed_texts = [t if isinstance(t, str) and t else pad_token for t in texts] | |
| # inputs = self.tokenizer(processed_texts, return_tensors="pt", padding=True, truncation=True, max_length=128) | |
| # inputs = {k: v.to(self.device) for k, v in inputs.items()} | |
| # return inputs | |
| # def get_embeddings(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> tuple: | |
| # if self.use_text and texts is None: | |
| # pass # Allow texts to be None, WeakLearners will handle if it expects text. | |
| # if not self.use_text and texts is not None: | |
| # logger.warning("Text input provided to get_embeddings but no text model loaded. Text will be ignored.") | |
| # texts = None # Ensure texts list is None if not used by this WhisperSSLEnsemble instance | |
| # mels = self.preprocess_audio(audios) | |
| # # Handle empty batch input for embeddings | |
| # if mels.numel() == 0: # No audio provided | |
| # audio_emb = torch.empty((0, self._audio_embedding_dim), device=self.device) # type: ignore | |
| # text_emb = None | |
| # if self.use_text: | |
| # text_emb = torch.empty((0, self._text_embedding_dim), device=self.device) | |
| # return audio_emb, text_emb | |
| # audio_emb = torch.empty((mels.shape[0], self._audio_embedding_dim), device=self.device) # type: ignore | |
| # if mels.numel() > 0: | |
| # with torch.no_grad(): | |
| # encoder_output = self.whisper_model.encoder(mels) | |
| # audio_emb = encoder_output.mean(dim=1) | |
| # text_emb = None | |
| # if self.use_text: | |
| # # Ensure text batch matches audio if texts are provided or if texts are None but audio exists (pad with empty strings) | |
| # effective_texts = texts | |
| # if texts is None: # If no texts provided, but we use_text, create dummy empty strings for tokenization matching audio batch size | |
| # effective_texts = [""] * mels.shape[0] | |
| # elif len(texts) != mels.shape[0] and mels.shape[0] > 0: # If texts length mismatch audio batch, this is an issue or needs defined behavior | |
| # # For now, if texts are provided but mismatch, we might just process what's given or raise error. | |
| # # Let's assume if texts are given, they match the audio batch. Or pad/truncate. | |
| # # This part of logic may need refinement based on expected use case for mismatched batch sizes. | |
| # # For simplicity, we proceed with `effective_texts` as is, assuming `preprocess_text` handles it. | |
| # # If texts is None and use_text is true, effective_texts will be list of empty strings matching audio batch. | |
| # pass | |
| # tokenized_inputs = self.preprocess_text(effective_texts) # type: ignore | |
| # if tokenized_inputs and tokenized_inputs["input_ids"].numel() > 0: | |
| # with torch.no_grad(): | |
| # outputs = self.text_model(**tokenized_inputs) # type: ignore | |
| # text_emb = outputs.last_hidden_state[:, 0, :] | |
| # elif self._text_embedding_dim > 0: # Texts were None or empty list resulting in no tokenized_inputs, but text is expected | |
| # text_emb = torch.empty((mels.shape[0], self._text_embedding_dim), device=self.device) | |
| # audio_emb = audio_emb.to(self.device) | |
| # if text_emb is not None: | |
| # text_emb = text_emb.to(self.device) | |
| # if text_emb.shape[0] != audio_emb.shape[0] and audio_emb.shape[0] > 0: | |
| # pass | |
| # return audio_emb, text_emb | |
| # def forward(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> torch.Tensor: | |
| # """ | |
| # Overrides BaseMultimodalModel.forward. | |
| # Processes raw audio and text inputs through the full pipeline (embeddings + SSL ensemble) | |
| # to get final scores as a Tensor. This method is called by self.predict(). | |
| # If not in predict_mode, this method will raise a RuntimeError. | |
| # """ | |
| # if not self.predict_mode or self.ssl_ensemble_model is None: | |
| # raise RuntimeError("Cannot call 'forward' for final predictions. Model was not initialized with 'ssl_ensemble_config' for prediction mode. " | |
| # "Use 'get_embeddings()' if you only need embeddings.") | |
| # # Determine text input strategy for get_embeddings based on ensemble and model config | |
| # weak_learners_cfg_text_dim = self.ssl_ensemble_model.weak_learners.text_dim | |
| # texts_for_embeddings = texts # Default to provided texts | |
| # if weak_learners_cfg_text_dim > 0: # Ensemble's weak learners expect text features | |
| # if not self.use_text: # This WhisperSSLEnsemble instance has no text model | |
| # logger.warning("WhisperSSLEnsemble.forward: Ensemble's weak learners expect text (config text_dim > 0), " | |
| # "but this model instance has no text processor (e.g., text_model_type='none' or load failed). " | |
| # "Embeddings passed to weak learners will have text_emb=None. WeakLearners must handle this (e.g. by padding).") | |
| # texts_for_embeddings = None # Ensure get_embeddings receives None for text | |
| # elif texts is None: # Model has text processor, ensemble expects text, but no 'texts' arg provided | |
| # raise ValueError("WhisperSSLEnsemble.forward: Ensemble's weak learners require text input (config text_dim > 0), but 'texts' argument is None.") | |
| # elif weak_learners_cfg_text_dim == 0: # Ensemble's weak learners DO NOT expect text features | |
| # if texts is not None: | |
| # logger.warning("WhisperSSLEnsemble.forward: Ensemble's weak learners do not expect text (config text_dim == 0)," | |
| # " but text input was provided. If a text model is loaded in this instance, text embeddings might be computed by " | |
| # "get_embeddings but will then be ignored by the weak learners.") | |
| # # 1. Get Embeddings | |
| # audio_emb, text_emb = self.get_embeddings(audios, texts_for_embeddings) | |
| # if audio_emb.shape[0] == 0: # Handle empty batch case after getting embeddings | |
| # return torch.empty((0, 1), device=self.device) # Return empty scores tensor, shape (batch, 1) | |
| # # 2. Pass embeddings through the SSLEnsembleModel | |
| # with torch.no_grad(): | |
| # final_predictions_tensor = self.ssl_ensemble_model(audio_emb, text_emb) # Shape: (batch_size, 1) | |
| # return final_predictions_tensor | |
| # def predict(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> List[float]: | |
| # """ | |
| # Overrides BaseMultimodalModel.predict. | |
| # Generates final predictions using the full pipeline (Embeddings + SSL Ensemble). | |
| # Requires the model to be initialized in prediction mode. | |
| # """ | |
| # final_predictions_tensor = self.forward(audios, texts) | |
| # if final_predictions_tensor.numel() == 0: | |
| # return [] | |
| # scores = final_predictions_tensor.detach().cpu().numpy().flatten().tolist() | |
| # return scores | |
| # AutoConfig.register("whisper-bert", WhisperSSLEnsembleConfig) | |
| # AutoModel.register(WhisperSSLEnsembleConfig, WhisperSSLEnsemble) | |
| import os | |
| from typing import Dict, List, Optional, Union | |
| import logging | |
| import joblib | |
| import numpy as np | |
| import torch | |
| import whisper | |
| from sklearn.linear_model import Ridge | |
| from sklearn.svm import SVR | |
| from sklearn.tree import DecisionTreeRegressor | |
| from torch import nn | |
| from tqdm import tqdm | |
| from transformers import AutoModel, AutoTokenizer, PreTrainedModel, AutoConfig | |
| # Assuming your configuration is in a file like 'configuration_whisper.py' | |
| from .configuration_whisper import WhisperSSLEnsembleConfig | |
# Setup logger
# NOTE(review): calling basicConfig at import time configures the root logger
# for every program that imports this module; libraries conventionally leave
# handler/format configuration to the application — confirm this is intended.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger, named after the module per the standard logging convention.
logger = logging.getLogger(__name__)
| # --- Weak Learners --- | |
class WeakLearners(nn.Module):
    """
    Wrapper for scikit-learn regression models (Ridge, SVR, DecisionTree)
    acting as the ensemble's weak learners.

    The sklearn estimators always run on CPU: input tensors are detached and
    converted to NumPy inside :meth:`forward`, and the stacked predictions are
    moved back onto ``self.device``. Fitted estimators must be obtained either
    by calling :meth:`fit` or by loading them with :meth:`load_fitted` before
    :meth:`forward` may be used.
    """

    def __init__(self, audio_dim: int = 512, text_dim: int = 768, device: str = "cpu") -> None:
        """
        Args:
            audio_dim: Expected audio embedding dimension (512 matches Whisper base).
            text_dim: Expected text embedding dimension (768 matches BERT base).
                Use 0 to disable text features entirely.
            device: Device onto which the prediction tensor returned by
                :meth:`forward` is moved; the sklearn models themselves stay on CPU.
        """
        super().__init__()
        self.audio_dim = audio_dim
        self.text_dim = text_dim
        # Target device for the prediction tensor returned by forward().
        self.device = torch.device(device)
        # Unfitted placeholder estimators; replaced when load_fitted() succeeds.
        self.ridge_regressor = Ridge(alpha=1.0)
        self.svr = SVR()
        self.dtr = DecisionTreeRegressor()
        self.models = [self.ridge_regressor, self.svr, self.dtr]
        self.model_names = ["Ridge", "SVR", "DTR"]
        self.fitted = False  # True only after fit() or load_fitted()

    def fit(self, train_loader: torch.utils.data.DataLoader) -> None:
        """Fit every weak learner on embeddings drawn from ``train_loader``.

        Batches may be ``(audio_emb, text_emb, labels)`` tuples/lists or dicts
        with ``'audio_embedding'``, optional ``'text_embedding'`` and
        ``'label'`` keys. Batches missing text are zero-padded when
        ``self.text_dim > 0`` so row counts stay aligned.

        Raises:
            ValueError: On an unsupported batch format or sample-count mismatch.
            RuntimeError: If the loader yields no embeddings or labels at all.
        """
        logger.info("Fitting weak learners...")
        all_audio_emb, all_text_emb, all_labels = [], [], []
        pbar = tqdm(train_loader, desc="Extracting Embeddings for Weak Learners", leave=False)
        for batch_data in pbar:
            if isinstance(batch_data, (list, tuple)) and len(batch_data) >= 3:
                audio_emb, text_emb, labels = (batch_data[0], batch_data[1], batch_data[2])
            elif isinstance(batch_data, dict):
                audio_emb = batch_data.get("audio_embedding")
                text_emb = batch_data.get("text_embedding")  # may legitimately be None
                labels = batch_data.get("label")
                if audio_emb is None or labels is None:
                    raise ValueError("Expected dict with 'audio_embedding' and 'label' keys")
            else:
                raise ValueError("Unsupported batch data format from train_loader")
            all_audio_emb.append(audio_emb.detach().cpu().numpy())
            if text_emb is not None:
                all_text_emb.append(text_emb.detach().cpu().numpy())
            elif self.text_dim > 0:
                # Text expected but absent from this batch: pad with zeros so the
                # text matrix keeps the same number of rows as the audio matrix.
                zeros_text = np.zeros((audio_emb.shape[0], self.text_dim))
                all_text_emb.append(zeros_text)
            all_labels.append(labels.detach().cpu().numpy())
        if not all_audio_emb or not all_labels:
            raise RuntimeError("No embeddings or labels collected. Check train_loader.")
        all_audio_emb = np.vstack(all_audio_emb)
        all_labels = np.concatenate(all_labels)
        if all_text_emb:
            all_text_emb = np.vstack(all_text_emb)
            if all_audio_emb.shape[0] != all_text_emb.shape[0]:
                raise ValueError(f"Sample count mismatch: Audio {all_audio_emb.shape[0]}, Text {all_text_emb.shape[0]}")
            combined_embeddings = np.hstack((all_audio_emb, all_text_emb))
            logger.info(f"Combined embedding shape for fitting: {combined_embeddings.shape}")
        else:
            combined_embeddings = all_audio_emb
            logger.info(f"Using only audio embeddings for fitting. Shape: {combined_embeddings.shape}")
        if combined_embeddings.shape[0] != len(all_labels):
            raise ValueError(f"Sample count mismatch: Embeddings {combined_embeddings.shape[0]}, Labels {len(all_labels)}")
        logger.info("Training sklearn models...")
        for model, name in zip(self.models, self.model_names):
            logger.info(f"  Fitting {name}...")
            model.fit(combined_embeddings, all_labels)
        self.fitted = True
        logger.info("Weak learners fitting completed.")

    def load_fitted(self, filepath: str) -> bool:
        """Load fitted sklearn models from a joblib file.

        Args:
            filepath: Path to a joblib file containing a list of exactly
                ``len(self.models)`` fitted estimators.

        Returns:
            True on success, False otherwise; failures are logged and leave
            ``self.fitted`` set to False.
        """
        if not os.path.exists(filepath):
            logger.error(f"Weak learners file not found at {filepath}")
            return False
        try:
            loaded_sklearn_models = joblib.load(filepath)
            if isinstance(loaded_sklearn_models, list) and len(loaded_sklearn_models) == len(self.models):
                self.models = loaded_sklearn_models
                # Keep the named aliases in sync so self.ridge_regressor / svr /
                # dtr do not keep pointing at the stale unfitted placeholders.
                # Assumes the file stores models in Ridge, SVR, DTR order —
                # the same order fit() and __init__ use.
                self.ridge_regressor, self.svr, self.dtr = loaded_sklearn_models
                self.fitted = True
                logger.info(f" Weak learners loaded successfully from {filepath}.")
                return True
            logger.error(f"Loaded file '{filepath}' does not contain the expected list of {len(self.models)} sklearn models.")
            self.fitted = False
            return False
        except ImportError as e:
            # joblib itself is imported at module load, so an ImportError here
            # means unpickling referenced a package missing from this environment
            # (e.g. the scikit-learn version the models were saved with).
            logger.error(f"Missing dependency while unpickling weak learners from {filepath}: {e}")
            self.fitted = False
            return False
        except Exception as e:
            logger.error(f"Error loading weak learners from {filepath}: {e}")
            self.fitted = False
            return False

    def forward(self, audio_emb: torch.Tensor, text_emb: Optional[torch.Tensor]) -> torch.Tensor:
        """Generate predictions from the loaded weak learners.

        Args:
            audio_emb: ``(batch, audio_dim)`` tensor on any device.
            text_emb: Optional ``(batch, text_dim)`` tensor. Ignored when
                ``self.text_dim == 0``; zero-padded when expected but None.

        Returns:
            ``(batch, num_weak_learners)`` float tensor on ``self.device``.

        Raises:
            RuntimeError: If the learners were never fitted/loaded.
            ValueError: On an audio/text batch-size mismatch.
        """
        if not self.fitted:
            raise RuntimeError("Weak learners must be loaded using load_fitted() before calling forward.")
        # sklearn runs on CPU: detach and convert inputs to NumPy first.
        audio_np = audio_emb.detach().cpu().numpy()
        if self.text_dim > 0:
            if text_emb is None:
                # Text expected but not provided: pad with zeros so the feature
                # width matches what the estimators were fitted on.
                logger.warning("Text embedding is None in WeakLearners forward pass, but text_dim > 0. Padding with zeros.")
                zeros_text = np.zeros((audio_np.shape[0], self.text_dim))
                combined_embeddings = np.hstack((audio_np, zeros_text))
            else:
                text_np = text_emb.detach().cpu().numpy()
                if audio_np.shape[0] != text_np.shape[0]:
                    raise ValueError("Batch size mismatch between audio and text embeddings in WeakLearners forward.")
                combined_embeddings = np.hstack((audio_np, text_np))
        else:
            if text_emb is not None:
                logger.warning("Text embedding provided to WeakLearners forward pass, but text_dim is 0. Ignoring text.")
            combined_embeddings = audio_np
        # One prediction vector per weak learner, stacked as columns:
        # shape (batch_size, num_weak_learners), then moved to the target device.
        all_preds = [torch.from_numpy(model.predict(combined_embeddings)).float() for model in self.models]
        return torch.stack(all_preds, dim=1).to(self.device)
| # --- Stacking Model (Meta-Learner) --- | |
class StackingMetaLearner(nn.Module):
    """
    Meta-learner of the stacking ensemble: a two-layer MLP that maps the
    weak learners' prediction vector to a single score.

    Layer attribute names (fc1/relu/fc2) are part of the state_dict format
    and must not be renamed, or saved checkpoints will fail to load.
    """
    def __init__(self, weak_output_dim: int = 3, hidden_dim: int = 256) -> None:
        super().__init__()
        if min(weak_output_dim, hidden_dim) <= 0:
            raise ValueError("weak_output_dim and hidden_dim must be positive integers.")
        self.fc1 = nn.Linear(weak_output_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, 1)

    def load_state_dict_from_file(self, filepath: str, device: torch.device) -> None:
        """Load weights from `filepath`, move the module to `device`, switch to eval mode."""
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Meta-learner state file not found at {filepath}")
        try:
            self.load_state_dict(torch.load(filepath, map_location=device))
            self.to(device)  # Ensure the meta learner itself is on the device
            self.eval()
            logger.info(f" StackingMetaLearner state loaded successfully from {filepath} onto device {device}.")
        except Exception as e:
            logger.error(f"Error loading StackingMetaLearner state_dict from {filepath}: {e}")
            raise RuntimeError(f"Failed to load StackingMetaLearner state: {e}") from e

    def forward(self, weak_outputs: torch.Tensor) -> torch.Tensor:
        """Apply fc1 -> ReLU -> fc2; input (batch, weak_output_dim) -> output (batch, 1)."""
        return self.fc2(self.relu(self.fc1(weak_outputs)))
| # --- Main Ensemble Model (Wrapper used internally now) --- | |
class SSLEnsembleModel(nn.Module):
    """
    Full stacking ensemble: fitted WeakLearners feeding a StackingMetaLearner.

    Both components must be loaded externally before construction; this class
    only validates them and defines the combined forward pass.
    """
    def __init__(self, weak_learners: WeakLearners, stacking_meta_learner: StackingMetaLearner) -> None:
        super().__init__()
        # Refuse to assemble the ensemble from unloaded or wrongly-typed parts.
        if not (isinstance(weak_learners, WeakLearners) and weak_learners.fitted):
            raise ValueError("A pre-loaded WeakLearners instance must be provided to SSLEnsembleModel.")
        if not isinstance(stacking_meta_learner, StackingMetaLearner):
            raise ValueError("A StackingMetaLearner instance must be provided.")
        self.weak_learners = weak_learners
        self.stacking_meta_learner = stacking_meta_learner

    def forward(self, audio_emb: torch.Tensor, text_emb: Optional[torch.Tensor]) -> torch.Tensor:
        """Run weak learners, move their output to the meta-learner's device, combine."""
        meta_device = next(self.stacking_meta_learner.parameters()).device
        weak_outputs = self.weak_learners(audio_emb, text_emb).to(meta_device)
        return self.stacking_meta_learner(weak_outputs)
| # --- Unified Main Model Class --- | |
class WhisperSSLEnsemble(PreTrainedModel): # type: ignore
    """
    Unified model using Whisper for audio, an optional transformer for text,
    and an optional SSL Ensemble (Weak + Meta learners) for final prediction.
    Inherits from transformers.PreTrainedModel.

    All heavy sub-components (Whisper, text encoder, SSL ensemble) are loaded
    lazily via the _load_*_if_needed() helpers so that construction is cheap
    and safe (no downloads, no CUDA allocations in __init__).
    """
    config_class = WhisperSSLEnsembleConfig
    def __init__(self, config: WhisperSSLEnsembleConfig) -> None:
        super().__init__(config)
        self.config = config
        # 1. Resolve the target device (safe: nothing is allocated here).
        target_device_str = config.device if config.device else ("cuda" if torch.cuda.is_available() else "cpu")
        self._target_device = torch.device(target_device_str)
        logger.info(f"Initializing WhisperSSLEnsemble, targeting device: {self._target_device}")
        # 2. Initialize every sub-component as None (safe) — no loading here!
        #    They are materialized on demand by the _load_*_if_needed helpers.
        self.whisper_model = None
        self.text_model = None
        self.tokenizer = None
        self.ssl_ensemble_model = None
        # 3. Read plain scalar parameters from the config (safe).
        #    These dimensions are needed by other parts of the code
        #    (e.g. empty-batch placeholders in get_embeddings).
        self._audio_embedding_dim = config.whisper_embedding_dim
        self._text_embedding_dim = config.text_embedding_dim
        # Decide up front whether the optional sub-models will be used at all.
        self.use_text = config.text_model_type is not None and config.text_model_type.lower() != "none"
        self.predict_mode = config.ssl_ensemble_config is not None
        if self.use_text:
            logger.info(f"Text model '{config.text_model_type}' is configured and will be loaded on demand.")
        if self.predict_mode:
            logger.info("SSL Ensemble is configured and will be loaded on demand for prediction.")
        # 4. Final .to() call (safe): the model has no parameters yet,
        #    so this only records the device on the module.
        self.to(self._target_device)
        logger.info(f"WhisperSSLEnsemble initialization complete. Final model device: {self._target_device}")

    def _load_whisper_if_needed(self):
        """Lazily load the Whisper model onto the target device. Idempotent."""
        if self.whisper_model is not None:
            return
        whisper_path_or_variant = self.config.whisper_weights_path or self.config.whisper_variant
        logger.info(f"Lazily loading Whisper model '{whisper_path_or_variant}'...")
        try:
            # Step 1: force construction on CPU to avoid the meta-device context.
            with torch.device("cpu"):
                # Load without a device argument so the model stays on CPU.
                whisper_model_cpu = whisper.load_model(whisper_path_or_variant)
            # Step 2: move the fully materialized model to the target device.
            self.whisper_model = whisper_model_cpu.to(self._target_device)
            self.whisper_model.eval()
            logger.info(f"Whisper model loaded successfully onto device '{self._target_device}'.")
        except Exception as e:
            logger.error(f"Failed to lazily load Whisper model: {e}", exc_info=True)
            raise RuntimeError("Could not initialize the Whisper sub-component.") from e

    def _load_text_model_if_needed(self):
        """Lazily load the text tokenizer/encoder. On failure, disables text processing
        instead of raising (deliberate best-effort). Idempotent."""
        if not self.use_text or self.text_model is not None:
            return
        text_model_type = self.config.text_model_type
        logger.info(f"Lazily loading Text model and tokenizer '{text_model_type}'...")
        try:
            # The tokenizer can be loaded straight away.
            self.tokenizer = AutoTokenizer.from_pretrained(text_model_type)
            # Step 1: force model construction on CPU.
            # low_cpu_mem_usage is disabled explicitly to avoid meta-device init.
            text_model_cpu = AutoModel.from_pretrained(
                text_model_type,
                low_cpu_mem_usage=False
            )
            # Step 2: move the live CPU model to the target device.
            self.text_model = text_model_cpu.to(self._target_device)
            self.text_model.eval()
            logger.info(f"Text model and tokenizer loaded successfully onto device '{self._target_device}'.")
            # Consistency check between configured and actual hidden size;
            # the loaded model's dimension wins.
            loaded_dim = self.text_model.config.hidden_size
            if self.use_text and loaded_dim != self._text_embedding_dim:
                logger.warning(f"Configured text_embedding_dim ({self._text_embedding_dim}) mismatches loaded model's dim ({loaded_dim}).")
                self._text_embedding_dim = loaded_dim
        except Exception as e:
            logger.error(f"Failed to lazily load text model '{text_model_type}': {e}", exc_info=True)
            self.use_text = False
            logger.warning(" Text processing will be disabled due to load failure.")

    # NOTE: an earlier, commented-out variant of _load_ssl_ensemble_if_needed that
    # resolved the weight files relative to os.path.dirname(__file__) was removed;
    # the active implementation below downloads them from the Hugging Face Hub.
    def _load_ssl_ensemble_if_needed(self):
        """Lazily build the SSL ensemble (weak learners + meta-learner) from weight
        files downloaded from the Hugging Face Hub. On failure, disables prediction
        mode instead of raising. Idempotent."""
        if self.ssl_ensemble_model is not None or not self.predict_mode:
            return
        logger.info("Lazily loading SSL Ensemble model...")
        ssl_ensemble_config = self.config.ssl_ensemble_config
        try:
            # Download the weight files directly from the model repository.
            from huggingface_hub import hf_hub_download
            # Repository id comes from the config; falls back to the published repo.
            repo_id = getattr(self.config, '_name_or_path', '1NEYRON1/whisper')
            # File names of the serialized learners inside the repository.
            weak_learners_filename = ssl_ensemble_config["weak_learners_path"]
            meta_learner_filename = ssl_ensemble_config["meta_learner_path"]
            logger.info(f"Downloading {weak_learners_filename} from {repo_id}...")
            weak_learners_path = hf_hub_download(
                repo_id=repo_id,
                filename=weak_learners_filename
            )
            logger.info(f"Downloading {meta_learner_filename} from {repo_id}...")
            meta_learner_path = hf_hub_download(
                repo_id=repo_id,
                filename=meta_learner_filename
            )
            logger.info(f"Files downloaded successfully:")
            logger.info(f" Weak learners: {weak_learners_path}")
            logger.info(f" Meta learner: {meta_learner_path}")
            # Instantiate and load the sklearn weak learners from the cached files.
            weak_learners = WeakLearners(
                audio_dim=ssl_ensemble_config["audio_dim"],
                text_dim=ssl_ensemble_config["text_dim"],
                device=self._target_device.type
            )
            if not weak_learners.load_fitted(weak_learners_path):
                raise RuntimeError(f"Failed to load weak learners from {weak_learners_path}")
            meta_learner = StackingMetaLearner(
                weak_output_dim=len(weak_learners.models),
                hidden_dim=ssl_ensemble_config["hidden_dim"]
            )
            meta_learner.load_state_dict_from_file(meta_learner_path, device=self._target_device)
            # Assemble the final ensemble wrapper.
            self.ssl_ensemble_model = SSLEnsembleModel(
                weak_learners=weak_learners,
                stacking_meta_learner=meta_learner
            )
            self.ssl_ensemble_model.eval()
            logger.info(f"SSL Ensemble loaded successfully onto device {self._target_device}.")
        except Exception as e:
            logger.error(f"Failed to lazily load SSL Ensemble model: {e}", exc_info=True)
            self.predict_mode = False
            logger.warning(" Prediction with SSL Ensemble will be disabled.")

    def preprocess_audio(self, audios: List[Union[np.ndarray, torch.Tensor]]) -> torch.Tensor:
        """Convert raw waveforms into a batched log-mel spectrogram tensor for Whisper.

        Accepts numpy arrays or torch tensors; anything else raises TypeError.
        Returns an empty, correctly-shaped tensor for an empty input list.
        """
        self._load_whisper_if_needed()
        processed_mels = []
        # Use self.whisper_model.device as the definitive device for mel spectrograms
        # as whisper.load_model puts its tensors on that device.
        target_device_for_mels = self.whisper_model.device
        for audio in audios:
            if isinstance(audio, torch.Tensor):
                audio = audio.cpu().numpy()
            if not isinstance(audio, np.ndarray):
                raise TypeError(f"Expected audio input to be numpy array or torch tensor, got {type(audio)}")
            if audio.dtype != np.float32:
                audio = audio.astype(np.float32)
            # Pad/trim to Whisper's fixed context length before the mel transform.
            audio_proc = whisper.pad_or_trim(audio)
            mel = whisper.log_mel_spectrogram(audio_proc, device=target_device_for_mels)
            processed_mels.append(mel)
        if not processed_mels:
            # Empty batch: preserve the expected (0, n_mels, n_audio_ctx) shape.
            return torch.empty((0, self.whisper_model.dims.n_mels, self.whisper_model.dims.n_audio_ctx), device=target_device_for_mels)
        mels_batch = torch.stack(processed_mels)
        return mels_batch

    def preprocess_text(self, texts: List[str]) -> Optional[Dict[str, torch.Tensor]]:
        """Tokenize a batch of texts. Returns None when text processing is disabled
        (no text model configured or tokenizer not yet loaded)."""
        if not self.use_text or not self.tokenizer:
            return None
        if not texts:
            # Return empty tensors with correct device and shape for batching
            return {
                "input_ids": torch.empty((0, 0), dtype=torch.long, device=self._target_device),
                "attention_mask": torch.empty((0, 0), dtype=torch.long, device=self._target_device)
            }
        # Replace empty / non-string entries with the pad token so tokenization never fails.
        pad_token = self.tokenizer.pad_token if self.tokenizer.pad_token else "[PAD]"
        processed_texts = [t if isinstance(t, str) and t else pad_token for t in texts]
        inputs = self.tokenizer(processed_texts, return_tensors="pt", padding=True, truncation=True, max_length=128)
        inputs = {k: v.to(self._target_device) for k, v in inputs.items()}
        return inputs

    def get_embeddings(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> tuple:
        """Compute (audio_emb, text_emb) for a batch.

        audio_emb is the mean-pooled Whisper encoder output; text_emb is the
        CLS-token embedding of the text model, or None when text is disabled.
        NOTE(review): this method does not call _load_text_model_if_needed();
        callers (e.g. forward) are expected to trigger the text-model load first.
        """
        self._load_whisper_if_needed()
        # Intentional no-op: texts may legitimately be None even with text enabled
        # (empty strings are substituted below to keep batch sizes aligned).
        if self.use_text and texts is None:
            pass
        if not self.use_text and texts is not None:
            logger.warning("Text input provided to get_embeddings but no text model loaded. Text will be ignored.")
            texts = None
        mels = self.preprocess_audio(audios)
        # Handle empty batch input for embeddings
        if mels.numel() == 0:
            audio_emb = torch.empty((0, self._audio_embedding_dim), device=self._target_device)
            text_emb = None
            if self.use_text:
                text_emb = torch.empty((0, self._text_embedding_dim), device=self._target_device)
            return audio_emb, text_emb
        # Audio embeddings
        # Ensure audio_emb tensor is initialized on the correct device for subsequent operations
        audio_emb = torch.empty((mels.shape[0], self._audio_embedding_dim), device=self._target_device)
        with torch.no_grad():
            encoder_output = self.whisper_model.encoder(mels)
            # Mean-pool the encoder output over time to get one vector per clip.
            audio_emb = encoder_output.mean(dim=1)
        # Text embeddings
        text_emb = None
        if self.use_text:
            effective_texts = texts
            if texts is None:
                effective_texts = [""] * mels.shape[0] # Match audio batch size with empty strings if no text given
            tokenized_inputs = self.preprocess_text(effective_texts)
            if tokenized_inputs and tokenized_inputs["input_ids"].numel() > 0:
                with torch.no_grad():
                    outputs = self.text_model(**tokenized_inputs)
                    # CLS-token embedding as the sentence representation.
                    text_emb = outputs.last_hidden_state[:, 0, :]
            elif self._text_embedding_dim > 0: # If text expected but tokenization resulted in empty output
                text_emb = torch.empty((mels.shape[0], self._text_embedding_dim), device=self._target_device)
        # Ensure embeddings are on the model's primary device (self.device)
        audio_emb = audio_emb.to(self._target_device)
        if text_emb is not None:
            text_emb = text_emb.to(self._target_device)
            # Add a check for batch size consistency between audio and text embeddings
            if text_emb.shape[0] != audio_emb.shape[0]:
                logger.warning(f"Batch size mismatch after text embedding computation: Audio {audio_emb.shape[0]}, Text {text_emb.shape[0]}. This might lead to issues in ensemble.")
        return audio_emb, text_emb

    def forward(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> torch.Tensor:
        """End-to-end prediction: embeddings -> weak learners -> meta-learner.

        Requires the model to have been configured with ssl_ensemble_config
        (predict_mode). Returns a (batch, 1) tensor of scores.
        """
        if not self.predict_mode:
            raise RuntimeError("Cannot call 'forward' for final predictions. Model was not initialized with 'ssl_ensemble_config' for prediction mode. "
                               "Use 'get_embeddings()' if you only need embeddings.")
        # Trigger all lazy loads before computing anything.
        self._load_ssl_ensemble_if_needed()
        self._load_whisper_if_needed()
        self._load_text_model_if_needed()
        # _load_ssl_ensemble_if_needed may have failed and disabled predict_mode.
        if self.ssl_ensemble_model is None:
            raise RuntimeError("Cannot call 'forward' for final predictions. Model was not initialized with 'ssl_ensemble_config' for prediction mode. "
                               "Use 'get_embeddings()' if you only need embeddings.")
        # Reconcile what the fitted ensemble expects (its text_dim) with what
        # this instance can actually provide (use_text) and what the caller gave.
        weak_learners_cfg_text_dim = self.ssl_ensemble_model.weak_learners.text_dim
        texts_for_embeddings = texts
        if weak_learners_cfg_text_dim > 0:
            if not self.use_text:
                logger.warning("WhisperSSLEnsemble.forward: Ensemble's weak learners expect text (config text_dim > 0), "
                               "but this model instance has no text processor (e.g., text_model_type='none' or load failed). "
                               "Embeddings passed to weak learners will have text_emb=None. WeakLearners must handle this (e.g. by padding).")
                texts_for_embeddings = None
            elif texts is None:
                # If text_dim > 0 and no text provided, raise error. If this is acceptable, you'd pad here.
                raise ValueError("WhisperSSLEnsemble.forward: Ensemble's weak learners require text input (config text_dim > 0), but 'texts' argument is None.")
        elif weak_learners_cfg_text_dim == 0:
            if texts is not None:
                logger.warning("WhisperSSLEnsemble.forward: Ensemble's weak learners do not expect text (config text_dim == 0),"
                               " but text input was provided. If a text model is loaded in this instance, text embeddings might be computed by "
                               "get_embeddings but will then be ignored by the weak learners.")
        audio_emb, text_emb = self.get_embeddings(audios, texts_for_embeddings)
        if audio_emb.shape[0] == 0:
            # Empty input batch -> empty (0, 1) prediction tensor.
            return torch.empty((0, 1), device=self._target_device)
        with torch.no_grad():
            final_predictions_tensor = self.ssl_ensemble_model(audio_emb, text_emb)
        return final_predictions_tensor

    def predict(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> List[float]:
        """Convenience wrapper around forward(); returns plain Python floats."""
        final_predictions_tensor = self.forward(audios, texts)
        if final_predictions_tensor.numel() == 0:
            return []
        scores = final_predictions_tensor.detach().cpu().numpy().flatten().tolist()
        return scores
# Register the custom config/model pair so AutoConfig.from_pretrained and
# AutoModel.from_pretrained can resolve the "whisper-bert" model type.
AutoConfig.register("whisper-bert", WhisperSSLEnsembleConfig)
AutoModel.register(WhisperSSLEnsembleConfig, WhisperSSLEnsemble)