import os
from typing import Dict, List, Optional, Union
import logging

import joblib
import numpy as np
import torch
import whisper
from sklearn.linear_model import Ridge
from sklearn.svm import SVR
from sklearn.tree import DecisionTreeRegressor
from torch import nn
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, AutoConfig

from .configuration_whisper import WhisperSSLEnsembleConfig

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class WeakLearners(nn.Module):
    """
    Wrapper around scikit-learn regression models that act as weak learners.
    Fitted models must be loaded via `load_fitted()` before inference.
    Note: sklearn models run on CPU, so inputs are moved to CPU in `forward`
    and predictions are returned on `self.device`.
    """

    def __init__(self, audio_dim: int = 512, text_dim: int = 768, device: Union[str, torch.device] = "cpu") -> None:
        super().__init__()
        self.audio_dim = audio_dim
        self.text_dim = text_dim
        self.device = torch.device(device)

        self.ridge_regressor = Ridge(alpha=1.0)
        self.svr = SVR()
        self.dtr = DecisionTreeRegressor()

        self.models = [self.ridge_regressor, self.svr, self.dtr]
        self.model_names = ["Ridge", "SVR", "DTR"]
        self.fitted = False

    def fit(self, train_loader: torch.utils.data.DataLoader) -> None:
        logger.info("Fitting weak learners...")
        all_audio_emb, all_text_emb, all_labels = [], [], []
        pbar = tqdm(train_loader, desc="Extracting Embeddings for Weak Learners", leave=False)
        for batch_data in pbar:
            # Accept either (audio_emb, text_emb, labels) tuples or dict batches.
            if isinstance(batch_data, (list, tuple)) and len(batch_data) >= 3:
                audio_emb, text_emb, labels = batch_data[0], batch_data[1], batch_data[2]
            elif isinstance(batch_data, dict):
                audio_emb = batch_data.get("audio_embedding")
                text_emb = batch_data.get("text_embedding")
                labels = batch_data.get("label")
                if audio_emb is None or labels is None:
                    raise ValueError("Expected dict with 'audio_embedding' and 'label' keys")
            else:
                raise ValueError("Unsupported batch data format from train_loader")

            all_audio_emb.append(audio_emb.detach().cpu().numpy())
            if text_emb is not None:
                all_text_emb.append(text_emb.detach().cpu().numpy())
            elif self.text_dim > 0:
                # Missing text in a batch: pad with zeros so the feature width stays constant.
                zeros_text = np.zeros((audio_emb.shape[0], self.text_dim))
                all_text_emb.append(zeros_text)
            all_labels.append(labels.detach().cpu().numpy())

        if not all_audio_emb or not all_labels:
            raise RuntimeError("No embeddings or labels collected. Check train_loader.")

        all_audio_emb = np.vstack(all_audio_emb)
        all_labels = np.concatenate(all_labels)

        if all_text_emb:
            all_text_emb = np.vstack(all_text_emb)
            if all_audio_emb.shape[0] != all_text_emb.shape[0]:
                raise ValueError(f"Sample count mismatch: Audio {all_audio_emb.shape[0]}, Text {all_text_emb.shape[0]}")
            combined_embeddings = np.hstack((all_audio_emb, all_text_emb))
            logger.info(f"Combined embedding shape for fitting: {combined_embeddings.shape}")
        else:
            combined_embeddings = all_audio_emb
            logger.info(f"Using only audio embeddings for fitting. Shape: {combined_embeddings.shape}")

        if combined_embeddings.shape[0] != len(all_labels):
            raise ValueError(f"Sample count mismatch: Embeddings {combined_embeddings.shape[0]}, Labels {len(all_labels)}")

        logger.info("Training sklearn models...")
        for model, name in zip(self.models, self.model_names):
            logger.info(f"  Fitting {name}...")
            model.fit(combined_embeddings, all_labels)
        self.fitted = True
        logger.info("Weak learners fitting completed.")

    def load_fitted(self, filepath: str) -> bool:
        """Load fitted sklearn models from a joblib file."""
        if not os.path.exists(filepath):
            logger.error(f"Weak learners file not found at {filepath}")
            return False
        try:
            loaded_sklearn_models = joblib.load(filepath)
            if isinstance(loaded_sklearn_models, list) and len(loaded_sklearn_models) == len(self.models):
                self.models = loaded_sklearn_models
                self.fitted = True
                logger.info(f"Weak learners loaded successfully from {filepath}.")
                return True

            logger.error(f"Loaded file '{filepath}' does not contain the expected list of {len(self.models)} sklearn models.")
            self.fitted = False
            return False
        except Exception as e:
            logger.error(f"Error loading weak learners from {filepath}: {e}")
            self.fitted = False
            return False

    def forward(self, audio_emb: torch.Tensor, text_emb: Optional[torch.Tensor]) -> torch.Tensor:
        """
        Generate predictions from the loaded weak learners.
        Input tensors (on any device) are moved to CPU for sklearn; the output
        tensor of shape (batch_size, n_weak_learners) is returned on self.device.
        """
        if not self.fitted:
            raise RuntimeError("Weak learners must be loaded using load_fitted() before calling forward.")

        audio_np = audio_emb.detach().cpu().numpy()

        if self.text_dim > 0:
            if text_emb is None:
                logger.warning("Text embedding is None in WeakLearners forward pass, but text_dim > 0. Padding with zeros.")
                zeros_text = np.zeros((audio_np.shape[0], self.text_dim))
                combined_embeddings = np.hstack((audio_np, zeros_text))
            else:
                text_np = text_emb.detach().cpu().numpy()
                if audio_np.shape[0] != text_np.shape[0]:
                    raise ValueError("Batch size mismatch between audio and text embeddings in WeakLearners forward.")
                combined_embeddings = np.hstack((audio_np, text_np))
        else:
            if text_emb is not None:
                logger.warning("Text embedding provided to WeakLearners forward pass, but text_dim is 0. Ignoring text.")
            combined_embeddings = audio_np

        # One prediction column per weak learner (Ridge, SVR, DTR).
        all_preds = []
        for model in self.models:
            preds = model.predict(combined_embeddings)
            all_preds.append(torch.from_numpy(preds).float())

        # Shape: (batch_size, n_weak_learners), moved to the configured device.
        stacked_preds = torch.stack(all_preds, dim=1).to(self.device)
        return stacked_preds


class StackingMetaLearner(nn.Module):
    """
    A simple feed-forward network that learns to combine predictions
    from weak learners. Its structure must match the saved weights.
    """

    def __init__(self, weak_output_dim: int = 3, hidden_dim: int = 256) -> None:
        super().__init__()
        if weak_output_dim <= 0 or hidden_dim <= 0:
            raise ValueError("weak_output_dim and hidden_dim must be positive integers.")
        self.fc1 = nn.Linear(weak_output_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, 1)

    def load_state_dict_from_file(self, filepath: str, device: torch.device) -> None:
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Meta-learner state file not found at {filepath}")
        try:
            state_dict = torch.load(filepath, map_location=device)
            self.load_state_dict(state_dict)
            self.to(device)
            self.eval()
            logger.info(f"StackingMetaLearner state loaded successfully from {filepath} onto device {device}.")
        except Exception as e:
            logger.error(f"Error loading StackingMetaLearner state_dict from {filepath}: {e}")
            raise RuntimeError(f"Failed to load StackingMetaLearner state: {e}") from e

    def forward(self, weak_outputs: torch.Tensor) -> torch.Tensor:
        """
        Args:
            weak_outputs (torch.Tensor): Tensor of shape (batch_size, weak_output_dim)
                containing predictions from weak learners.
        Returns:
            torch.Tensor: Final prediction of shape (batch_size, 1).
        """
        x = self.relu(self.fc1(weak_outputs))
        x = self.fc2(x)
        return x
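

# How the meta-learner weights consumed by `load_state_dict_from_file` might be
# produced is not shown in this module. The sketch below is an illustrative
# assumption, not the original training code: it presumes precomputed
# weak-learner outputs of shape (N, 3) and regression labels of shape (N,).
def _train_meta_learner_sketch(meta_learner: StackingMetaLearner,
                               weak_outputs: torch.Tensor,
                               labels: torch.Tensor,
                               epochs: int = 100,
                               lr: float = 1e-3) -> None:
    optimizer = torch.optim.Adam(meta_learner.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    meta_learner.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        preds = meta_learner(weak_outputs)         # (N, 1)
        loss = loss_fn(preds, labels.view(-1, 1))  # align shapes for MSE
        loss.backward()
        optimizer.step()
    # Persist for later loading, e.g.:
    # torch.save(meta_learner.state_dict(), "meta_learner.pt")  # placeholder path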


class SSLEnsembleModel(nn.Module):
    """
    Combines WeakLearners and a StackingMetaLearner.
    Assumes the WeakLearners are already fitted/loaded and the MetaLearner
    state_dict is loaded externally; this class mainly defines the forward
    pass that chains the two components.
    """

    def __init__(self, weak_learners: WeakLearners, stacking_meta_learner: StackingMetaLearner) -> None:
        super().__init__()
        if not isinstance(weak_learners, WeakLearners) or not weak_learners.fitted:
            raise ValueError("A pre-loaded WeakLearners instance must be provided to SSLEnsembleModel.")
        if not isinstance(stacking_meta_learner, StackingMetaLearner):
            raise ValueError("A StackingMetaLearner instance must be provided.")

        self.weak_learners = weak_learners
        self.stacking_meta_learner = stacking_meta_learner

    def forward(self, audio_emb: torch.Tensor, text_emb: Optional[torch.Tensor]) -> torch.Tensor:
        """
        Forward pass through the ensemble:
        1. Get predictions from the weak learners.
        2. Pass those predictions to the stacking meta-learner.
        """
        weak_outputs = self.weak_learners(audio_emb, text_emb)

        # Ensure the weak learner outputs sit on the meta-learner's device.
        weak_outputs = weak_outputs.to(next(self.stacking_meta_learner.parameters()).device)

        final_output = self.stacking_meta_learner(weak_outputs)
        return final_output


class WhisperSSLEnsemble(PreTrainedModel):
    """
    Unified model using Whisper for audio, an optional transformer for text,
    and an optional SSL Ensemble (weak + meta learners) for final prediction.
    Inherits from transformers.PreTrainedModel.
    """
    config_class = WhisperSSLEnsembleConfig

    def __init__(self, config: WhisperSSLEnsembleConfig) -> None:
        super().__init__(config)
        self.config = config
        whisper_variant = self.config.whisper_variant
        text_model_type = self.config.text_model_type
        weights = self.config.weights
        ssl_ensemble_config = self.config.ssl_ensemble_config

        self.predict_mode = False
        self.ssl_ensemble_model = None
        self.tokenizer = None
        self.text_model = None
        self._audio_embedding_dim = None
        self._text_embedding_dim = 0

        try:
            logger.info(f"Loading Whisper model: '{weights if weights else whisper_variant}'...")
            wm = whisper.load_model(weights if weights else whisper_variant, device=self.device)
            self.whisper_model = wm
            self.whisper_model.eval()
            # The encoder's final LayerNorm width equals the audio embedding size.
            self._audio_embedding_dim = self.whisper_model.encoder.ln_post.normalized_shape[0]
            logger.info(f"Whisper loaded. Audio embedding dimension: {self._audio_embedding_dim}")
        except Exception as e:
            logger.error(f"Error loading Whisper model: {e}")
            raise RuntimeError(f"Failed to load Whisper model '{weights or whisper_variant}'") from e

        self.use_text = text_model_type is not None and text_model_type.lower() != "none"
        if self.use_text:
            logger.info(f"Loading Text model '{text_model_type}'...")
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(text_model_type)
                self.text_model = AutoModel.from_pretrained(text_model_type).to(self.device)
                self.text_model.eval()
                self._text_embedding_dim = self.text_model.config.hidden_size
                logger.info(f"Text model loaded. Text embedding dimension: {self._text_embedding_dim}")
            except Exception as e:
                logger.warning(f"Failed to load text model '{text_model_type}'. Error: {e}")
                logger.warning("Text processing will be disabled.")
                self.use_text = False
                self._text_embedding_dim = 0
        else:
            logger.info("Text model type is 'none'. Text processing disabled.")
            self._text_embedding_dim = 0

        if ssl_ensemble_config is not None:
            logger.info("SSL Ensemble config provided. Initializing for prediction mode...")
            self.predict_mode = True
            required_keys = ["weak_learners_path", "meta_learner_path", "audio_dim", "text_dim", "hidden_dim"]
            if not all(key in ssl_ensemble_config for key in required_keys):
                raise ValueError(f"ssl_ensemble_config missing required keys: {required_keys}")

            cfg_audio_dim = ssl_ensemble_config["audio_dim"]
            cfg_text_dim = ssl_ensemble_config["text_dim"]

            if cfg_audio_dim != self._audio_embedding_dim:
                logger.warning(f"Ensemble config audio_dim ({cfg_audio_dim}) mismatches Whisper model dim ({self._audio_embedding_dim}).")

            if self.use_text and cfg_text_dim != self._text_embedding_dim:
                logger.warning(f"Ensemble config text_dim ({cfg_text_dim}) mismatches loaded text model dim ({self._text_embedding_dim}).")
            elif not self.use_text and cfg_text_dim != 0:
                logger.warning(f"Ensemble config expects text_dim={cfg_text_dim}, but no text model is loaded. WeakLearners must handle padding.")
            elif self.use_text and cfg_text_dim == 0:
                logger.warning("Ensemble config expects text_dim=0, but a text model IS loaded. "
                               "Text embeddings will be computed but ignored by the weak learners.")

            weak_learners = WeakLearners(audio_dim=cfg_audio_dim, text_dim=cfg_text_dim, device=self.device)
            if not weak_learners.load_fitted(ssl_ensemble_config["weak_learners_path"]):
                raise RuntimeError("Failed to load weak learners for WhisperSSLEnsemble.")

            meta_learner = StackingMetaLearner(weak_output_dim=len(weak_learners.models), hidden_dim=ssl_ensemble_config["hidden_dim"])
            meta_learner.load_state_dict_from_file(ssl_ensemble_config["meta_learner_path"], device=self.device)

            self.ssl_ensemble_model = SSLEnsembleModel(weak_learners=weak_learners, stacking_meta_learner=meta_learner)
            self.ssl_ensemble_model.eval()
            logger.info("SSLEnsembleModel loaded successfully.")
        else:
            logger.info("No SSL Ensemble config provided. Model initialized in embedding mode.")
            logger.info("Call '.get_embeddings(audios, texts)' to get embeddings.")
            logger.info("Calling '.predict()' or '.forward()' will raise an error in this mode.")

    def preprocess_audio(self, audios: List[Union[np.ndarray, torch.Tensor]]) -> torch.Tensor:
        """Convert raw waveforms to a batch of Whisper log-mel spectrograms."""
        processed_mels = []
        for audio in audios:
            if isinstance(audio, torch.Tensor):
                audio = audio.cpu().numpy()
            if not isinstance(audio, np.ndarray):
                raise TypeError(f"Expected audio input to be numpy array or torch tensor, got {type(audio)}")
            if audio.dtype != np.float32:
                audio = audio.astype(np.float32)

            # Pad/trim to Whisper's fixed 30-second window before the mel transform.
            audio_proc = whisper.pad_or_trim(audio)
            mel = whisper.log_mel_spectrogram(audio_proc, device=self.whisper_model.device)
            processed_mels.append(mel)

        if not processed_mels:
            # Whisper mels have two frames per audio context position (3000 frames for 30 s).
            n_frames = 2 * self.whisper_model.dims.n_audio_ctx
            return torch.empty((0, self.whisper_model.dims.n_mels, n_frames), device=self.whisper_model.device)

        mels_batch = torch.stack(processed_mels)
        return mels_batch

    def preprocess_text(self, texts: List[str]) -> Optional[Dict[str, torch.Tensor]]:
        if not self.use_text or not self.tokenizer:
            return None
        if not texts:
            return {
                "input_ids": torch.empty((0, 0), dtype=torch.long, device=self.device),
                "attention_mask": torch.empty((0, 0), dtype=torch.long, device=self.device),
            }

        # Replace empty or non-string entries with the pad token so tokenization never fails.
        pad_token = self.tokenizer.pad_token if self.tokenizer.pad_token else "[PAD]"
        processed_texts = [t if isinstance(t, str) and t else pad_token for t in texts]

        inputs = self.tokenizer(processed_texts, return_tensors="pt", padding=True, truncation=True, max_length=128)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        return inputs

    def get_embeddings(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> tuple:
        if not self.use_text and texts is not None:
            logger.warning("Text input provided to get_embeddings but no text model loaded. Text will be ignored.")
            texts = None

        mels = self.preprocess_audio(audios)

        # Empty batch: return empty tensors with the correct feature dimensions.
        if mels.numel() == 0:
            audio_emb = torch.empty((0, self._audio_embedding_dim), device=self.device)
            text_emb = None
            if self.use_text:
                text_emb = torch.empty((0, self._text_embedding_dim), device=self.device)
            return audio_emb, text_emb

        # Mean-pool the Whisper encoder output over time to get one vector per clip.
        with torch.no_grad():
            encoder_output = self.whisper_model.encoder(mels)
            audio_emb = encoder_output.mean(dim=1)

        text_emb = None
        if self.use_text:
            effective_texts = texts
            if texts is None:
                # No text provided: embed empty strings so batch sizes stay aligned.
                effective_texts = [""] * mels.shape[0]
            elif len(texts) != mels.shape[0]:
                raise ValueError(f"Batch size mismatch: {mels.shape[0]} audio clips but {len(texts)} texts.")

            tokenized_inputs = self.preprocess_text(effective_texts)
            if tokenized_inputs and tokenized_inputs["input_ids"].numel() > 0:
                with torch.no_grad():
                    outputs = self.text_model(**tokenized_inputs)
                    # Use the [CLS] (first token) representation as the sentence embedding.
                    text_emb = outputs.last_hidden_state[:, 0, :]
            elif self._text_embedding_dim > 0:
                # Tokenization produced nothing usable; fall back to zero embeddings.
                text_emb = torch.zeros((mels.shape[0], self._text_embedding_dim), device=self.device)

        audio_emb = audio_emb.to(self.device)
        if text_emb is not None:
            text_emb = text_emb.to(self.device)

        return audio_emb, text_emb

    def forward(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> torch.Tensor:
        """
        Full prediction pipeline: processes raw audio and text inputs through
        embedding extraction and the SSL ensemble to produce final scores as a
        tensor. This method is called by self.predict(). If the model is not in
        predict mode, it raises a RuntimeError.
        """
        if not self.predict_mode or self.ssl_ensemble_model is None:
            raise RuntimeError("Cannot call 'forward' for final predictions. Model was not initialized with 'ssl_ensemble_config' for prediction mode. "
                               "Use 'get_embeddings()' if you only need embeddings.")

        # Reconcile the text inputs with what the weak learners were trained on.
        weak_learners_cfg_text_dim = self.ssl_ensemble_model.weak_learners.text_dim
        texts_for_embeddings = texts

        if weak_learners_cfg_text_dim > 0:
            if not self.use_text:
                logger.warning("WhisperSSLEnsemble.forward: Ensemble's weak learners expect text (config text_dim > 0), "
                               "but this model instance has no text processor (e.g., text_model_type='none' or load failed). "
                               "Embeddings passed to weak learners will have text_emb=None; WeakLearners pads with zeros.")
                texts_for_embeddings = None
            elif texts is None:
                raise ValueError("WhisperSSLEnsemble.forward: Ensemble's weak learners require text input (config text_dim > 0), but 'texts' argument is None.")
        elif weak_learners_cfg_text_dim == 0:
            if texts is not None:
                logger.warning("WhisperSSLEnsemble.forward: Ensemble's weak learners do not expect text (config text_dim == 0), "
                               "but text input was provided. Any text embeddings computed by get_embeddings will be ignored by the weak learners.")

        audio_emb, text_emb = self.get_embeddings(audios, texts_for_embeddings)

        if audio_emb.shape[0] == 0:
            return torch.empty((0, 1), device=self.device)

        with torch.no_grad():
            final_predictions_tensor = self.ssl_ensemble_model(audio_emb, text_emb)

        return final_predictions_tensor

    def predict(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> List[float]:
        """
        Generate final predictions using the full pipeline (embeddings + SSL
        ensemble) and return them as a list of floats. Requires the model to be
        initialized in prediction mode.
        """
        final_predictions_tensor = self.forward(audios, texts)

        if final_predictions_tensor.numel() == 0:
            return []

        scores = final_predictions_tensor.detach().cpu().numpy().flatten().tolist()
        return scores


AutoConfig.register("whisper-bert", WhisperSSLEnsembleConfig)
AutoModel.register(WhisperSSLEnsembleConfig, WhisperSSLEnsemble)
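

# Illustrative usage sketch. The checkpoint paths and dummy inputs below are
# placeholders for demonstration, not artifacts that ship with this module, and
# passing these fields as keyword arguments to WhisperSSLEnsembleConfig is an
# assumption about configuration_whisper.py.
if __name__ == "__main__":
    config = WhisperSSLEnsembleConfig(
        whisper_variant="base",                # any openai-whisper variant name
        text_model_type="bert-base-uncased",
        weights=None,
        ssl_ensemble_config={
            "weak_learners_path": "weak_learners.joblib",  # placeholder path
            "meta_learner_path": "meta_learner.pt",        # placeholder path
            "audio_dim": 512,   # must match the Whisper encoder width ("base" is 512)
            "text_dim": 768,    # must match the text model hidden size
            "hidden_dim": 256,  # must match the saved meta-learner
        },
    )
    model = WhisperSSLEnsemble(config)
    dummy_audio = [np.zeros(16000, dtype=np.float32)]  # 1 s of silence at 16 kHz
    scores = model.predict(dummy_audio, texts=["hello world"])
    print(scores)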