whisper / modeling_whisper.py
1NEYRON1's picture
Update modeling_whisper.py
70cb727 verified
raw
history blame
57.7 kB
# import os
# from typing import Dict, List, Optional, Union
# import logging
# import joblib
# import numpy as np
# import torch
# import whisper
# from sklearn.linear_model import Ridge
# from sklearn.svm import SVR
# from sklearn.tree import DecisionTreeRegressor
# from torch import nn
# from tqdm import tqdm
# from transformers import AutoModel, AutoTokenizer, PreTrainedModel, AutoConfig
# from .configuration_whisper import WhisperSSLEnsembleConfig
# # Setup logger
# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# logger = logging.getLogger(__name__)
# # --- Weak Learners ---
# class WeakLearners(nn.Module):
# """
# Wrapper for scikit-learn regression models to act as weak learners.
# Requires loading fitted models from a file before use in inference.
# Note: Sklearn models run on CPU. Data is moved accordingly during forward.
# """
# # Added default dimensions matching common Whisper base/BERT base
# def __init__(self, audio_dim: int = 512, text_dim: int = 768, device: str = "cpu") -> None:
# super().__init__()
# self.audio_dim = audio_dim
# self.text_dim = text_dim
# # Device primarily for moving predictions back, sklearn runs on CPU
# self.device = torch.device(device)
# # Initialize sklearn model structures (placeholders until loaded)
# self.ridge_regressor = Ridge(alpha=1.0)
# self.svr = SVR()
# self.dtr = DecisionTreeRegressor()
# self.models = [self.ridge_regressor, self.svr, self.dtr]
# self.model_names = ["Ridge", "SVR", "DTR"]
# self.fitted = False # Will be set to True after loading
# # Keep fit method for potential separate training script, but not used by WhisperModel/MultiModalWhisper directly
# def fit(self, train_loader: torch.utils.data.DataLoader) -> None:
# logger.info("Fitting weak learners...")
# all_audio_emb, all_text_emb, all_labels = [], [], []
# pbar = tqdm(train_loader, desc="Extracting Embeddings for Weak Learners", leave=False)
# for batch_data in pbar:
# # Flexible batch data handling (adjust if needed based on actual loader)
# if isinstance(batch_data, (list, tuple)) and len(batch_data) >= 3:
# audio_emb, text_emb, labels = (batch_data[0], batch_data[1], batch_data[2])
# elif isinstance(batch_data, dict):
# audio_emb = batch_data.get("audio_embedding")
# text_emb = batch_data.get("text_embedding") # Might be None
# labels = batch_data.get("label")
# if audio_emb is None or labels is None:
# raise ValueError("Expected dict with 'audio_embedding' and 'label' keys")
# else:
# raise ValueError("Unsupported batch data format from train_loader")
# all_audio_emb.append(audio_emb.detach().cpu().numpy())
# if text_emb is not None:
# all_text_emb.append(text_emb.detach().cpu().numpy())
# elif self.text_dim > 0: # Pad if text expected but not provided in batch
# zeros_text = np.zeros((audio_emb.shape[0], self.text_dim))
# all_text_emb.append(zeros_text)
# all_labels.append(labels.detach().cpu().numpy())
# if not all_audio_emb or not all_labels:
# raise RuntimeError("No embeddings or labels collected. Check train_loader.")
# all_audio_emb = np.vstack(all_audio_emb)
# all_labels = np.concatenate(all_labels)
# if all_text_emb:
# all_text_emb = np.vstack(all_text_emb)
# if all_audio_emb.shape[0] != all_text_emb.shape[0]:
# raise ValueError(f"Sample count mismatch: Audio {all_audio_emb.shape[0]}, Text {all_text_emb.shape[0]}")
# combined_embeddings = np.hstack((all_audio_emb, all_text_emb))
# logger.info(f"Combined embedding shape for fitting: {combined_embeddings.shape}")
# else:
# combined_embeddings = all_audio_emb
# logger.info(f"Using only audio embeddings for fitting. Shape: {combined_embeddings.shape}")
# if combined_embeddings.shape[0] != len(all_labels):
# raise ValueError(f"Sample count mismatch: Embeddings {combined_embeddings.shape[0]}, Labels {len(all_labels)}")
# logger.info("Training sklearn models...")
# for model, name in zip(self.models, self.model_names):
# logger.info(f" Fitting {name}...")
# model.fit(combined_embeddings, all_labels)
# self.fitted = True
# logger.info("Weak learners fitting completed.")
# # --- Add save functionality ---
# # joblib.dump(self.models, 'fitted_weak_learners.joblib')
# # logger.info("Fitted weak learners saved to fitted_weak_learners.joblib")
# def load_fitted(self, filepath: str) -> bool:
# """Loads fitted sklearn models from a joblib file."""
# if not os.path.exists(filepath):
# logger.error(f"Weak learners file not found at {filepath}")
# return False
# try:
# loaded_sklearn_models = joblib.load(filepath)
# if isinstance(loaded_sklearn_models, list) and len(loaded_sklearn_models) == len(self.models):
# self.models = loaded_sklearn_models
# self.fitted = True
# logger.info(f" Weak learners loaded successfully from {filepath}.")
# return True
# logger.error(f"Loaded file '{filepath}' does not contain the expected list of {len(self.models)} sklearn models.")
# self.fitted = False
# return False
# except ImportError:
# logger.error("joblib is required to load weak learners but not installed. Run: pip install joblib")
# self.fitted = False
# return False
# except Exception as e:
# logger.error(f"Error loading weak learners from {filepath}: {e}")
# self.fitted = False
# return False
# def forward(self, audio_emb: torch.Tensor, text_emb: Optional[torch.Tensor]) -> torch.Tensor:
# """
# Generate predictions from loaded weak learners.
# Input tensors (on any device) will be moved to CPU for sklearn.
# Output tensor (predictions) will be on self.device.
# """
# if not self.fitted:
# # Changed error message slightly
# raise RuntimeError("Weak learners must be loaded using load_fitted() before calling forward.")
# # Prepare input for sklearn: move to CPU, detach, convert to numpy
# audio_np = audio_emb.detach().cpu().numpy()
# # Handle text embedding based on text_dim
# if self.text_dim > 0:
# if text_emb is None:
# # If text is expected (text_dim > 0) but not provided, pad with zeros
# logger.warning("Text embedding is None in WeakLearners forward pass, but text_dim > 0. Padding with zeros.")
# zeros_text = np.zeros((audio_np.shape[0], self.text_dim))
# combined_embeddings = np.hstack((audio_np, zeros_text))
# else:
# # Text is expected and provided
# text_np = text_emb.detach().cpu().numpy()
# if audio_np.shape[0] != text_np.shape[0]:
# raise ValueError("Batch size mismatch between audio and text embeddings in WeakLearners forward.")
# combined_embeddings = np.hstack((audio_np, text_np))
# else:
# # No text is expected (text_dim == 0)
# if text_emb is not None:
# logger.warning("Text embedding provided to WeakLearners forward pass, but text_dim is 0. Ignoring text.")
# combined_embeddings = audio_np
# # Get predictions from each loaded sklearn model
# all_preds = []
# # Use self.models which now contains the loaded models
# for model in self.models:
# preds = model.predict(combined_embeddings)
# # Convert numpy array predictions to float tensor
# all_preds.append(torch.from_numpy(preds).float())
# # Stack predictions along a new dimension (dim=1) -> (batch_size, num_weak_learners)
# # Move the final stacked tensor to the designated device (e.g., GPU if specified)
# stacked_preds = torch.stack(all_preds, dim=1).to(self.device)
# return stacked_preds
# # --- Stacking Model (Meta-Learner) ---
# class StackingMetaLearner(nn.Module):
# """
# A simple feed-forward network that learns to combine predictions
# from weak learners. Structure needs to match the saved weights.
# """
# def __init__(self, weak_output_dim: int = 3, hidden_dim: int = 256) -> None:
# super().__init__()
# # Check for valid dimensions
# if weak_output_dim <= 0 or hidden_dim <= 0:
# raise ValueError("weak_output_dim and hidden_dim must be positive integers.")
# self.fc1 = nn.Linear(weak_output_dim, hidden_dim)
# self.relu = nn.ReLU()
# self.fc2 = nn.Linear(hidden_dim, 1) # Predict a single score
# def load_state_dict_from_file(self, filepath: str, device: torch.device) -> None:
# if not os.path.exists(filepath):
# raise FileNotFoundError(f"Meta-learner state file not found at {filepath}")
# try:
# # Load state dict onto the specified device directly
# state_dict = torch.load(filepath, map_location=device)
# self.load_state_dict(state_dict)
# # Move model to the device *after* loading state_dict
# self.to(device)
# self.eval() # Set to evaluation mode after loading
# logger.info(f" StackingMetaLearner state loaded successfully from {filepath} onto device {device}.")
# except Exception as e:
# logger.error(f"Error loading StackingMetaLearner state_dict from {filepath}: {e}")
# raise RuntimeError(f"Failed to load StackingMetaLearner state: {e}") from e
# def forward(self, weak_outputs: torch.Tensor) -> torch.Tensor:
# """
# Args:
# weak_outputs (torch.Tensor): Tensor of shape (batch_size, weak_output_dim)
# containing predictions from weak learners.
# Returns:
# torch.Tensor: Final prediction of shape (batch_size, 1).
# """
# x = self.relu(self.fc1(weak_outputs))
# x = self.fc2(x)
# return x
# # --- Main Ensemble Model (Wrapper used internally now) ---
# class SSLEnsembleModel(nn.Module):
# """
# Combines WeakLearners and a StackingMetaLearner.
# Assumes WeakLearners are loaded externally and MetaLearner state_dict is loaded externally.
# This class mainly defines the forward pass logic using the components.
# """
# def __init__(self, weak_learners: WeakLearners, stacking_meta_learner: StackingMetaLearner) -> None:
# super().__init__()
# if not isinstance(weak_learners, WeakLearners) or not weak_learners.fitted:
# raise ValueError("A pre-loaded WeakLearners instance must be provided to SSLEnsembleModel.")
# if not isinstance(stacking_meta_learner, StackingMetaLearner):
# raise ValueError("A StackingMetaLearner instance must be provided.")
# self.weak_learners = weak_learners
# self.stacking_meta_learner = stacking_meta_learner
# def forward(self, audio_emb: torch.Tensor, text_emb: Optional[torch.Tensor]) -> torch.Tensor:
# """
# Forward pass through the ensemble.
# 1. Get predictions from weak learners.
# 2. Pass weak learner predictions to the stacking meta-learner.
# """
# # Get predictions from weak learners (shape: batch_size, num_weak_outputs)
# # WeakLearners forward handles device movement internally
# weak_outputs = self.weak_learners(audio_emb, text_emb) # Output is on weak_learners.device
# # Ensure weak_outputs are on the same device as the meta-learner before passing
# weak_outputs = weak_outputs.to(next(self.stacking_meta_learner.parameters()).device)
# # Get final prediction from the meta-learner
# final_output = self.stacking_meta_learner(weak_outputs) # Output shape (batch_size, 1)
# return final_output
# # --- Unified Main Model Class ---
# class WhisperSSLEnsemble(PreTrainedModel): # type: ignore
# """
# Unified model using Whisper for audio, an optional transformer for text,
# and an optional SSL Ensemble (Weak + Meta learners) for final prediction.
# Inherits from BaseMultimodalModel.
# """
# config_class = WhisperSSLEnsembleConfig
# def __init__(self, config: WhisperSSLEnsembleConfig) -> None:
# # self,
# # whisper_variant: str = "base.en",
# # text_model_type: str = "bert-base-uncased",
# # weights: str = None, # Path to Whisper weights if not using standard variant download
# # device: str = None,
# # ssl_ensemble_config: Optional[Dict] = None,
# # super().__init__(weights) # Not calling super for BaseModel due to different init logic
# super().__init__(config)
# self.config = config
# whisper_variant = self.config.whisper_variant
# text_model_type = self.config.text_model_type
# weights = self.config.weights
# ssl_ensemble_config = self.config.ssl_ensemble_config
# # device = 'cpu'
# # if self.config.device is None:
# # device = "cuda" if torch.cuda.is_available() else "cpu"
# # else:
# # device = self.config.device
# # self.device = torch.device(device)
# self.predict_mode = False
# self.ssl_ensemble_model = None
# self.tokenizer = None
# self.text_model = None
# self._audio_embedding_dim = None
# self._text_embedding_dim = 0
# # logger.info(f"Initializing WhisperSSLEnsemble on device: {self.device}")
# try:
# # Determine if 'weights' is a path or a variant name for whisper.load_model
# # whisper.load_model can take a name like "base.en" or a path to a .pt file
# logger.info(f"Loading Whisper model: '{weights if weights else whisper_variant}'...")
# wm = whisper.load_model(weights if weights else whisper_variant, device=self.device)
# self.whisper_model = wm
# self.whisper_model.eval()
# self._audio_embedding_dim = self.whisper_model.encoder.ln_post.normalized_shape[0]
# logger.info(f" Whisper loaded. Audio embedding dimension: {self._audio_embedding_dim}")
# except Exception as e:
# logger.error(f"Error loading Whisper model: {e}")
# raise RuntimeError(f"Failed to load Whisper model '{weights or whisper_variant}'") from e
# self.use_text = text_model_type is not None and text_model_type.lower() != "none"
# if self.use_text:
# logger.info(f"Loading Text model '{text_model_type}'...")
# try:
# self.tokenizer = AutoTokenizer.from_pretrained(text_model_type)
# self.text_model = AutoModel.from_pretrained(text_model_type).to(self.device)
# self.text_model.eval()
# self._text_embedding_dim = self.text_model.config.hidden_size
# logger.info(f" Text model loaded. Text embedding dimension: {self._text_embedding_dim}")
# except Exception as e:
# logger.warning(f"Failed to load text model '{text_model_type}'. Error: {e}")
# logger.warning(" Text processing will be disabled.")
# self.use_text = False
# self._text_embedding_dim = 0
# else:
# logger.info("Text model type is 'none'. Text processing disabled.")
# self._text_embedding_dim = 0
# if ssl_ensemble_config is not None:
# logger.info("SSL Ensemble config provided. Initializing for prediction mode...")
# self.predict_mode = True
# required_keys = ["weak_learners_path", "meta_learner_path", "audio_dim", "text_dim", "hidden_dim"]
# if not all(key in ssl_ensemble_config for key in required_keys):
# raise ValueError(f"ssl_ensemble_config missing required keys: {required_keys}")
# cfg_audio_dim = ssl_ensemble_config["audio_dim"]
# cfg_text_dim = ssl_ensemble_config["text_dim"]
# if cfg_audio_dim != self._audio_embedding_dim:
# logger.warning(f"Ensemble config audio_dim ({cfg_audio_dim}) mismatches Whisper model dim ({self._audio_embedding_dim}).")
# # Validate text_dim from config against loaded text model
# if self.use_text and cfg_text_dim != self._text_embedding_dim:
# logger.warning(f"Ensemble config text_dim ({cfg_text_dim}) mismatches loaded text model dim ({self._text_embedding_dim}).")
# elif not self.use_text and cfg_text_dim != 0:
# logger.warning(f"Ensemble config expects text_dim={cfg_text_dim}, but no text model is loaded. WeakLearners must handle padding.")
# elif self.use_text and cfg_text_dim == 0: # Text model loaded, but ensemble config says text_dim=0
# logger.warning("Ensemble config expects text_dim=0, but a text model IS loaded. "
# "Text embeddings will be ignored by weak learners if their text_dim is 0.")
# weak_learners = WeakLearners(audio_dim=cfg_audio_dim, text_dim=cfg_text_dim, device=self.device)
# if not weak_learners.load_fitted(ssl_ensemble_config["weak_learners_path"]):
# raise RuntimeError("Failed to load weak learners for WhisperSSLEnsemble.")
# meta_learner = StackingMetaLearner(weak_output_dim=len(weak_learners.models), hidden_dim=ssl_ensemble_config["hidden_dim"])
# meta_learner.load_state_dict_from_file(ssl_ensemble_config["meta_learner_path"], device=self.device)
# self.ssl_ensemble_model = SSLEnsembleModel(weak_learners=weak_learners, stacking_meta_learner=meta_learner)
# self.ssl_ensemble_model.eval()
# logger.info("SSLEnsembleModel loaded successfully.")
# else:
# logger.info("No SSL Ensemble config provided. Model initialized in embedding mode.")
# logger.info("Call '.get_embeddings(audios, texts)' to get embeddings.")
# logger.info("Calling '.predict()' or '.forward()' will raise an error in this mode.")
# def preprocess_audio(self, audios: List[Union[np.ndarray, torch.Tensor]]) -> torch.Tensor:
# processed_mels = []
# for audio in audios:
# if isinstance(audio, torch.Tensor):
# audio = audio.cpu().numpy()
# if not isinstance(audio, np.ndarray):
# raise TypeError(f"Expected audio input to be numpy array or torch tensor, got {type(audio)}")
# if audio.dtype != np.float32:
# audio = audio.astype(np.float32)
# audio_proc = whisper.pad_or_trim(audio)
# mel = whisper.log_mel_spectrogram(audio_proc, device=self.whisper_model.device)
# processed_mels.append(mel)
# if not processed_mels:
# # Using Whisper's typical dimensions: 80 mel bins, 3000 frames (30s)
# return torch.empty((0, self.whisper_model.dims.n_mels, self.whisper_model.dims.n_audio_ctx), device=self.whisper_model.device)
# # Simpler: use known default values if whisper model object doesn't expose them easily
# # return torch.empty((0, 80, 3000), device=self.whisper_model.device)
# mels_batch = torch.stack(processed_mels)
# return mels_batch
# def preprocess_text(self, texts: List[str]) -> Optional[Dict[str, torch.Tensor]]:
# if not self.use_text or not self.tokenizer:
# return None
# if not texts:
# return {
# "input_ids": torch.empty((0, 0), dtype=torch.long, device=self.device),
# "attention_mask": torch.empty((0, 0), dtype=torch.long, device=self.device)
# }
# pad_token = self.tokenizer.pad_token if self.tokenizer.pad_token else "[PAD]"
# processed_texts = [t if isinstance(t, str) and t else pad_token for t in texts]
# inputs = self.tokenizer(processed_texts, return_tensors="pt", padding=True, truncation=True, max_length=128)
# inputs = {k: v.to(self.device) for k, v in inputs.items()}
# return inputs
# def get_embeddings(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> tuple:
# if self.use_text and texts is None:
# pass # Allow texts to be None, WeakLearners will handle if it expects text.
# if not self.use_text and texts is not None:
# logger.warning("Text input provided to get_embeddings but no text model loaded. Text will be ignored.")
# texts = None # Ensure texts list is None if not used by this WhisperSSLEnsemble instance
# mels = self.preprocess_audio(audios)
# # Handle empty batch input for embeddings
# if mels.numel() == 0: # No audio provided
# audio_emb = torch.empty((0, self._audio_embedding_dim), device=self.device) # type: ignore
# text_emb = None
# if self.use_text:
# text_emb = torch.empty((0, self._text_embedding_dim), device=self.device)
# return audio_emb, text_emb
# audio_emb = torch.empty((mels.shape[0], self._audio_embedding_dim), device=self.device) # type: ignore
# if mels.numel() > 0:
# with torch.no_grad():
# encoder_output = self.whisper_model.encoder(mels)
# audio_emb = encoder_output.mean(dim=1)
# text_emb = None
# if self.use_text:
# # Ensure text batch matches audio if texts are provided or if texts are None but audio exists (pad with empty strings)
# effective_texts = texts
# if texts is None: # If no texts provided, but we use_text, create dummy empty strings for tokenization matching audio batch size
# effective_texts = [""] * mels.shape[0]
# elif len(texts) != mels.shape[0] and mels.shape[0] > 0: # If texts length mismatch audio batch, this is an issue or needs defined behavior
# # For now, if texts are provided but mismatch, we might just process what's given or raise error.
# # Let's assume if texts are given, they match the audio batch. Or pad/truncate.
# # This part of logic may need refinement based on expected use case for mismatched batch sizes.
# # For simplicity, we proceed with `effective_texts` as is, assuming `preprocess_text` handles it.
# # If texts is None and use_text is true, effective_texts will be list of empty strings matching audio batch.
# pass
# tokenized_inputs = self.preprocess_text(effective_texts) # type: ignore
# if tokenized_inputs and tokenized_inputs["input_ids"].numel() > 0:
# with torch.no_grad():
# outputs = self.text_model(**tokenized_inputs) # type: ignore
# text_emb = outputs.last_hidden_state[:, 0, :]
# elif self._text_embedding_dim > 0: # Texts were None or empty list resulting in no tokenized_inputs, but text is expected
# text_emb = torch.empty((mels.shape[0], self._text_embedding_dim), device=self.device)
# audio_emb = audio_emb.to(self.device)
# if text_emb is not None:
# text_emb = text_emb.to(self.device)
# if text_emb.shape[0] != audio_emb.shape[0] and audio_emb.shape[0] > 0:
# pass
# return audio_emb, text_emb
# def forward(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> torch.Tensor:
# """
# Overrides BaseMultimodalModel.forward.
# Processes raw audio and text inputs through the full pipeline (embeddings + SSL ensemble)
# to get final scores as a Tensor. This method is called by self.predict().
# If not in predict_mode, this method will raise a RuntimeError.
# """
# if not self.predict_mode or self.ssl_ensemble_model is None:
# raise RuntimeError("Cannot call 'forward' for final predictions. Model was not initialized with 'ssl_ensemble_config' for prediction mode. "
# "Use 'get_embeddings()' if you only need embeddings.")
# # Determine text input strategy for get_embeddings based on ensemble and model config
# weak_learners_cfg_text_dim = self.ssl_ensemble_model.weak_learners.text_dim
# texts_for_embeddings = texts # Default to provided texts
# if weak_learners_cfg_text_dim > 0: # Ensemble's weak learners expect text features
# if not self.use_text: # This WhisperSSLEnsemble instance has no text model
# logger.warning("WhisperSSLEnsemble.forward: Ensemble's weak learners expect text (config text_dim > 0), "
# "but this model instance has no text processor (e.g., text_model_type='none' or load failed). "
# "Embeddings passed to weak learners will have text_emb=None. WeakLearners must handle this (e.g. by padding).")
# texts_for_embeddings = None # Ensure get_embeddings receives None for text
# elif texts is None: # Model has text processor, ensemble expects text, but no 'texts' arg provided
# raise ValueError("WhisperSSLEnsemble.forward: Ensemble's weak learners require text input (config text_dim > 0), but 'texts' argument is None.")
# elif weak_learners_cfg_text_dim == 0: # Ensemble's weak learners DO NOT expect text features
# if texts is not None:
# logger.warning("WhisperSSLEnsemble.forward: Ensemble's weak learners do not expect text (config text_dim == 0),"
# " but text input was provided. If a text model is loaded in this instance, text embeddings might be computed by "
# "get_embeddings but will then be ignored by the weak learners.")
# # 1. Get Embeddings
# audio_emb, text_emb = self.get_embeddings(audios, texts_for_embeddings)
# if audio_emb.shape[0] == 0: # Handle empty batch case after getting embeddings
# return torch.empty((0, 1), device=self.device) # Return empty scores tensor, shape (batch, 1)
# # 2. Pass embeddings through the SSLEnsembleModel
# with torch.no_grad():
# final_predictions_tensor = self.ssl_ensemble_model(audio_emb, text_emb) # Shape: (batch_size, 1)
# return final_predictions_tensor
# def predict(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> List[float]:
# """
# Overrides BaseMultimodalModel.predict.
# Generates final predictions using the full pipeline (Embeddings + SSL Ensemble).
# Requires the model to be initialized in prediction mode.
# """
# final_predictions_tensor = self.forward(audios, texts)
# if final_predictions_tensor.numel() == 0:
# return []
# scores = final_predictions_tensor.detach().cpu().numpy().flatten().tolist()
# return scores
# AutoConfig.register("whisper-bert", WhisperSSLEnsembleConfig)
# AutoModel.register(WhisperSSLEnsembleConfig, WhisperSSLEnsemble)
import os
from typing import Dict, List, Optional, Union
import logging
import joblib
import numpy as np
import torch
import whisper
from sklearn.linear_model import Ridge
from sklearn.svm import SVR
from sklearn.tree import DecisionTreeRegressor
from torch import nn
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, AutoConfig
# Assuming your configuration is in a file like 'configuration_whisper.py'
from .configuration_whisper import WhisperSSLEnsembleConfig
# Setup logger
# NOTE(review): calling basicConfig at import time configures the root logger
# and can override the host application's logging setup — consider leaving
# configuration to the importing application.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --- Weak Learners ---
class WeakLearners(nn.Module):
    """
    Adapter exposing a trio of scikit-learn regressors (Ridge, SVR, decision
    tree) as weak learners behind an ``nn.Module`` interface.

    Fitted estimators must be installed via :meth:`load_fitted` (or trained via
    :meth:`fit`) before :meth:`forward` may be called. The sklearn estimators
    always run on CPU; inputs are moved there and the stacked predictions are
    returned on ``self.device``.
    """

    def __init__(self, audio_dim: int = 512, text_dim: int = 768, device: str = "cpu") -> None:
        super().__init__()
        self.audio_dim = audio_dim
        self.text_dim = text_dim
        # Target device for the stacked prediction tensor only; the sklearn
        # estimators themselves are CPU-bound.
        self.device = torch.device(device)
        # Unfitted placeholders — replaced by load_fitted() or trained by fit().
        self.ridge_regressor = Ridge(alpha=1.0)
        self.svr = SVR()
        self.dtr = DecisionTreeRegressor()
        self.models = [self.ridge_regressor, self.svr, self.dtr]
        self.model_names = ["Ridge", "SVR", "DTR"]
        self.fitted = False  # Flipped to True once estimators are fitted/loaded

    def fit(self, train_loader: torch.utils.data.DataLoader) -> None:
        """Collect embeddings/labels from *train_loader* and fit every estimator.

        Accepts batches as 3+-tuples ``(audio, text, labels)`` or dicts with
        ``audio_embedding``/``text_embedding``/``label`` keys. Missing text is
        zero-padded when ``text_dim > 0``.

        Raises:
            ValueError: On an unsupported batch format or sample-count mismatch.
            RuntimeError: If the loader yields no data.
        """
        logger.info("Fitting weak learners...")
        audio_chunks, text_chunks, label_chunks = [], [], []
        progress = tqdm(train_loader, desc="Extracting Embeddings for Weak Learners", leave=False)
        for batch in progress:
            if isinstance(batch, (list, tuple)) and len(batch) >= 3:
                audio_part, text_part, label_part = batch[0], batch[1], batch[2]
            elif isinstance(batch, dict):
                audio_part = batch.get("audio_embedding")
                text_part = batch.get("text_embedding")  # may legitimately be absent
                label_part = batch.get("label")
                if audio_part is None or label_part is None:
                    raise ValueError("Expected dict with 'audio_embedding' and 'label' keys")
            else:
                raise ValueError("Unsupported batch data format from train_loader")
            audio_chunks.append(audio_part.detach().cpu().numpy())
            if text_part is not None:
                text_chunks.append(text_part.detach().cpu().numpy())
            elif self.text_dim > 0:
                # Text expected but absent in this batch: substitute zeros.
                text_chunks.append(np.zeros((audio_part.shape[0], self.text_dim)))
            label_chunks.append(label_part.detach().cpu().numpy())
        if not audio_chunks or not label_chunks:
            raise RuntimeError("No embeddings or labels collected. Check train_loader.")
        audio_mat = np.vstack(audio_chunks)
        labels_vec = np.concatenate(label_chunks)
        if text_chunks:
            text_mat = np.vstack(text_chunks)
            if audio_mat.shape[0] != text_mat.shape[0]:
                raise ValueError(f"Sample count mismatch: Audio {audio_mat.shape[0]}, Text {text_mat.shape[0]}")
            features = np.hstack((audio_mat, text_mat))
            logger.info(f"Combined embedding shape for fitting: {features.shape}")
        else:
            features = audio_mat
            logger.info(f"Using only audio embeddings for fitting. Shape: {features.shape}")
        if features.shape[0] != len(labels_vec):
            raise ValueError(f"Sample count mismatch: Embeddings {features.shape[0]}, Labels {len(labels_vec)}")
        logger.info("Training sklearn models...")
        for estimator, estimator_name in zip(self.models, self.model_names):
            logger.info(f" Fitting {estimator_name}...")
            estimator.fit(features, labels_vec)
        self.fitted = True
        logger.info("Weak learners fitting completed.")

    def load_fitted(self, filepath: str) -> bool:
        """Loads fitted sklearn models from a joblib file."""
        if not os.path.exists(filepath):
            logger.error(f"Weak learners file not found at {filepath}")
            return False
        try:
            candidates = joblib.load(filepath)
        except ImportError:
            # joblib.load can raise ImportError when unpickling classes whose
            # defining modules are missing from the environment.
            logger.error("joblib is required to load weak learners but not installed. Run: pip install joblib")
            self.fitted = False
            return False
        except Exception as e:
            logger.error(f"Error loading weak learners from {filepath}: {e}")
            self.fitted = False
            return False
        if isinstance(candidates, list) and len(candidates) == len(self.models):
            self.models = candidates
            self.fitted = True
            logger.info(f" Weak learners loaded successfully from {filepath}.")
            return True
        logger.error(f"Loaded file '{filepath}' does not contain the expected list of {len(self.models)} sklearn models.")
        self.fitted = False
        return False

    def forward(self, audio_emb: torch.Tensor, text_emb: Optional[torch.Tensor]) -> torch.Tensor:
        """
        Generate predictions from loaded weak learners.

        Inputs (on any device) are moved to CPU for sklearn; the returned
        tensor of shape ``(batch, num_weak_learners)`` lives on ``self.device``.

        Raises:
            RuntimeError: If the estimators have not been fitted/loaded.
            ValueError: On a batch-size mismatch between audio and text.
        """
        if not self.fitted:
            raise RuntimeError("Weak learners must be loaded using load_fitted() before calling forward.")
        audio_np = audio_emb.detach().cpu().numpy()
        if self.text_dim <= 0:
            # No text features expected; any provided text is dropped.
            if text_emb is not None:
                logger.warning("Text embedding provided to WeakLearners forward pass, but text_dim is 0. Ignoring text.")
            features = audio_np
        elif text_emb is None:
            # Text expected but missing: pad the feature matrix with zeros.
            logger.warning("Text embedding is None in WeakLearners forward pass, but text_dim > 0. Padding with zeros.")
            features = np.hstack((audio_np, np.zeros((audio_np.shape[0], self.text_dim))))
        else:
            text_np = text_emb.detach().cpu().numpy()
            if audio_np.shape[0] != text_np.shape[0]:
                raise ValueError("Batch size mismatch between audio and text embeddings in WeakLearners forward.")
            features = np.hstack((audio_np, text_np))
        # One prediction column per estimator, stacked along dim=1.
        per_model = [torch.from_numpy(m.predict(features)).float() for m in self.models]
        return torch.stack(per_model, dim=1).to(self.device)
# --- Stacking Model (Meta-Learner) ---
class StackingMetaLearner(nn.Module):
    """
    A simple feed-forward network that learns to combine predictions
    from weak learners. Structure needs to match the saved weights.
    """
    def __init__(self, weak_output_dim: int = 3, hidden_dim: int = 256) -> None:
        """
        Args:
            weak_output_dim: Number of weak-learner predictions per sample
                (the input feature count).
            hidden_dim: Width of the single hidden layer.

        Raises:
            ValueError: If either dimension is not a positive integer.
        """
        super().__init__()
        if weak_output_dim <= 0 or hidden_dim <= 0:
            raise ValueError("weak_output_dim and hidden_dim must be positive integers.")
        self.fc1 = nn.Linear(weak_output_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, 1)  # Predict a single score per sample
    def load_state_dict_from_file(self, filepath: str, device: torch.device) -> None:
        """
        Load saved weights from *filepath*, move the module to *device*, and
        switch to eval mode.

        Raises:
            FileNotFoundError: If *filepath* does not exist.
            RuntimeError: If loading or applying the state dict fails.
        """
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Meta-learner state file not found at {filepath}")
        try:
            # weights_only=True prevents arbitrary-code execution from a
            # maliciously crafted checkpoint; a plain state dict contains only
            # tensors, so this is safe for legitimate files.
            state_dict = torch.load(filepath, map_location=device, weights_only=True)
            self.load_state_dict(state_dict)
            self.to(device)  # Ensure the meta learner itself is on the device
            self.eval()
            logger.info(f" StackingMetaLearner state loaded successfully from {filepath} onto device {device}.")
        except Exception as e:
            logger.error(f"Error loading StackingMetaLearner state_dict from {filepath}: {e}")
            raise RuntimeError(f"Failed to load StackingMetaLearner state: {e}") from e
    def forward(self, weak_outputs: torch.Tensor) -> torch.Tensor:
        """
        Args:
            weak_outputs: Tensor of shape (batch_size, weak_output_dim)
                containing predictions from weak learners.

        Returns:
            Final prediction of shape (batch_size, 1).
        """
        x = self.relu(self.fc1(weak_outputs))
        return self.fc2(x)
# --- Main Ensemble Model (Wrapper used internally now) ---
class SSLEnsembleModel(nn.Module):
"""
Combines WeakLearners and a StackingMetaLearner.
Assumes WeakLearners are loaded externally and MetaLearner state_dict is loaded externally.
This class mainly defines the forward pass logic using the components.
"""
def __init__(self, weak_learners: WeakLearners, stacking_meta_learner: StackingMetaLearner) -> None:
super().__init__()
if not isinstance(weak_learners, WeakLearners) or not weak_learners.fitted:
raise ValueError("A pre-loaded WeakLearners instance must be provided to SSLEnsembleModel.")
if not isinstance(stacking_meta_learner, StackingMetaLearner):
raise ValueError("A StackingMetaLearner instance must be provided.")
self.weak_learners = weak_learners
self.stacking_meta_learner = stacking_meta_learner
def forward(self, audio_emb: torch.Tensor, text_emb: Optional[torch.Tensor]) -> torch.Tensor:
weak_outputs = self.weak_learners(audio_emb, text_emb)
weak_outputs = weak_outputs.to(next(self.stacking_meta_learner.parameters()).device) # Ensure device match
final_output = self.stacking_meta_learner(weak_outputs)
return final_output
# --- Unified Main Model Class ---
class WhisperSSLEnsemble(PreTrainedModel): # type: ignore
    """
    Unified model using Whisper for audio, an optional transformer for text,
    and an optional SSL Ensemble (Weak + Meta learners) for final prediction.

    All heavy sub-components (Whisper, the text encoder, the SSL ensemble)
    are loaded lazily on first use — never in __init__ — so instantiating
    this class is cheap and safe.
    """
    config_class = WhisperSSLEnsembleConfig

    def __init__(self, config: WhisperSSLEnsembleConfig) -> None:
        """Record configuration and target device; defer all model loading.

        Args:
            config: WhisperSSLEnsembleConfig carrying sub-model identifiers,
                embedding dimensions, and the optional ssl_ensemble_config
                dict that enables prediction mode.
        """
        super().__init__(config)
        self.config = config
        # 1. Resolve the target device safely: honor config.device when set,
        #    otherwise prefer CUDA if available.
        target_device_str = config.device if config.device else ("cuda" if torch.cuda.is_available() else "cpu")
        self._target_device = torch.device(target_device_str)
        logger.info(f"Initializing WhisperSSLEnsemble, targeting device: {self._target_device}")
        # 2. Initialize every sub-component as None — no loading here.
        self.whisper_model = None
        self.text_model = None
        self.tokenizer = None
        self.ssl_ensemble_model = None
        # 3. Read plain parameters from the config; these embedding sizes are
        #    needed by other parts of the code (empty-batch placeholders etc.).
        self._audio_embedding_dim = config.whisper_embedding_dim
        self._text_embedding_dim = config.text_embedding_dim
        # Decide whether the optional sub-models will ever be used.
        self.use_text = config.text_model_type is not None and config.text_model_type.lower() != "none"
        self.predict_mode = config.ssl_ensemble_config is not None
        if self.use_text:
            logger.info(f"Text model '{config.text_model_type}' is configured and will be loaded on demand.")
        if self.predict_mode:
            logger.info("SSL Ensemble is configured and will be loaded on demand for prediction.")
        # 4. Final .to() call: the model has no parameters yet, so this
        #    effectively just records the device.
        self.to(self._target_device)
        logger.info(f"WhisperSSLEnsemble initialization complete. Final model device: {self._target_device}")

    def _load_whisper_if_needed(self):
        """Lazily load the Whisper model onto the target device (idempotent).

        Raises:
            RuntimeError: if the Whisper sub-component cannot be loaded.
        """
        if self.whisper_model is not None:
            return
        whisper_path_or_variant = self.config.whisper_weights_path or self.config.whisper_variant
        logger.info(f"Lazily loading Whisper model '{whisper_path_or_variant}'...")
        try:
            # Step 1: force creation on CPU to avoid the meta-device context.
            with torch.device("cpu"):
                # Load without specifying a device so the model stays on CPU.
                whisper_model_cpu = whisper.load_model(whisper_path_or_variant)
            # Step 2: move the fully materialized model to the target device.
            self.whisper_model = whisper_model_cpu.to(self._target_device)
            self.whisper_model.eval()
            logger.info(f"Whisper model loaded successfully onto device '{self._target_device}'.")
        except Exception as e:
            logger.error(f"Failed to lazily load Whisper model: {e}", exc_info=True)
            raise RuntimeError("Could not initialize the Whisper sub-component.") from e

    def _load_text_model_if_needed(self):
        """Lazily load the HF text encoder + tokenizer; soft-fail by disabling text."""
        if not self.use_text or self.text_model is not None:
            return
        text_model_type = self.config.text_model_type
        logger.info(f"Lazily loading Text model and tokenizer '{text_model_type}'...")
        try:
            # The tokenizer can be loaded right away.
            self.tokenizer = AutoTokenizer.from_pretrained(text_model_type)
            # Step 1: force the model onto CPU; low_cpu_mem_usage is disabled
            # explicitly to avoid meta-device initialization.
            text_model_cpu = AutoModel.from_pretrained(
                text_model_type,
                low_cpu_mem_usage=False
            )
            # Step 2: move the materialized CPU model to the target device.
            self.text_model = text_model_cpu.to(self._target_device)
            self.text_model.eval()
            logger.info(f"Text model and tokenizer loaded successfully onto device '{self._target_device}'.")
            # Consistency check between configured and actual hidden size;
            # the loaded model's dimension wins.
            loaded_dim = self.text_model.config.hidden_size
            if self.use_text and loaded_dim != self._text_embedding_dim:
                logger.warning(f"Configured text_embedding_dim ({self._text_embedding_dim}) mismatches loaded model's dim ({loaded_dim}).")
                self._text_embedding_dim = loaded_dim
        except Exception as e:
            logger.error(f"Failed to lazily load text model '{text_model_type}': {e}", exc_info=True)
            # Soft-fail: keep the model usable for audio-only workflows.
            self.use_text = False
            logger.warning(" Text processing will be disabled due to load failure.")

    # NOTE: an earlier _load_ssl_ensemble_if_needed variant resolved the weight
    # files relative to this module's directory (os.path.dirname(__file__));
    # it was superseded by the hf_hub_download-based loader below.

    def _load_ssl_ensemble_if_needed(self):
        """Lazily download and assemble the SSL ensemble (weak + meta learners).

        On any failure, prediction mode is disabled instead of raising, so the
        model can still be used for embedding extraction.
        """
        if self.ssl_ensemble_model is not None or not self.predict_mode:
            return
        logger.info("Lazily loading SSL Ensemble model...")
        ssl_ensemble_config = self.config.ssl_ensemble_config
        try:
            # Download the weight files directly from the hub repository.
            from huggingface_hub import hf_hub_download
            # Repository id comes from the config when available; falls back
            # to the canonical repo this file ships in.
            repo_id = getattr(self.config, '_name_or_path', '1NEYRON1/whisper')
            # File names of the two weight artifacts inside the repo.
            weak_learners_filename = ssl_ensemble_config["weak_learners_path"]
            meta_learner_filename = ssl_ensemble_config["meta_learner_path"]
            logger.info(f"Downloading {weak_learners_filename} from {repo_id}...")
            weak_learners_path = hf_hub_download(
                repo_id=repo_id,
                filename=weak_learners_filename
            )
            logger.info(f"Downloading {meta_learner_filename} from {repo_id}...")
            meta_learner_path = hf_hub_download(
                repo_id=repo_id,
                filename=meta_learner_filename
            )
            logger.info(f"Files downloaded successfully:")
            logger.info(f" Weak learners: {weak_learners_path}")
            logger.info(f" Meta learner: {meta_learner_path}")
            # Load the sklearn weak learners from the downloaded file.
            weak_learners = WeakLearners(
                audio_dim=ssl_ensemble_config["audio_dim"],
                text_dim=ssl_ensemble_config["text_dim"],
                device=self._target_device.type
            )
            if not weak_learners.load_fitted(weak_learners_path):
                raise RuntimeError(f"Failed to load weak learners from {weak_learners_path}")
            # Meta-learner input width must match the number of weak learners.
            meta_learner = StackingMetaLearner(
                weak_output_dim=len(weak_learners.models),
                hidden_dim=ssl_ensemble_config["hidden_dim"]
            )
            meta_learner.load_state_dict_from_file(meta_learner_path, device=self._target_device)
            # Assemble the final ensemble model.
            self.ssl_ensemble_model = SSLEnsembleModel(
                weak_learners=weak_learners,
                stacking_meta_learner=meta_learner
            )
            self.ssl_ensemble_model.eval()
            logger.info(f"SSL Ensemble loaded successfully onto device {self._target_device}.")
        except Exception as e:
            logger.error(f"Failed to lazily load SSL Ensemble model: {e}", exc_info=True)
            # Soft-fail: embeddings still work, only forward()/predict() are disabled.
            self.predict_mode = False
            logger.warning(" Prediction with SSL Ensemble will be disabled.")

    def preprocess_audio(self, audios: List[Union[np.ndarray, torch.Tensor]]) -> torch.Tensor:
        """Convert raw waveforms into a batched log-mel spectrogram tensor.

        Args:
            audios: list of 1-D waveforms (numpy arrays or torch tensors);
                presumably 16 kHz mono as Whisper expects — TODO confirm with
                callers.

        Returns:
            Stacked log-mel spectrograms on the Whisper model's device.

        Raises:
            TypeError: if an element is neither a numpy array nor a tensor.
        """
        self._load_whisper_if_needed()
        processed_mels = []
        # Use self.whisper_model.device as the definitive device for mel spectrograms
        # as whisper.load_model puts its tensors on that device.
        target_device_for_mels = self.whisper_model.device
        for audio in audios:
            if isinstance(audio, torch.Tensor):
                audio = audio.cpu().numpy()
            if not isinstance(audio, np.ndarray):
                raise TypeError(f"Expected audio input to be numpy array or torch tensor, got {type(audio)}")
            if audio.dtype != np.float32:
                audio = audio.astype(np.float32)
            # Pad/trim to Whisper's fixed input window before the mel transform.
            audio_proc = whisper.pad_or_trim(audio)
            mel = whisper.log_mel_spectrogram(audio_proc, device=target_device_for_mels)
            processed_mels.append(mel)
        if not processed_mels:
            # NOTE(review): this empty-batch placeholder uses n_audio_ctx as the
            # time dimension; actual mel frame counts may differ — confirm this
            # shape is never consumed beyond the numel()==0 check downstream.
            return torch.empty((0, self.whisper_model.dims.n_mels, self.whisper_model.dims.n_audio_ctx), device=target_device_for_mels)
        mels_batch = torch.stack(processed_mels)
        return mels_batch

    def preprocess_text(self, texts: List[str]) -> Optional[Dict[str, torch.Tensor]]:
        """Tokenize texts for the text encoder.

        Returns None when text processing is unavailable; otherwise a dict of
        input_ids/attention_mask tensors on the target device. Empty or
        non-string entries are replaced by the tokenizer's pad token.
        """
        if not self.use_text or not self.tokenizer:
            return None
        if not texts:
            # Return empty tensors with correct device and shape for batching
            return {
                "input_ids": torch.empty((0, 0), dtype=torch.long, device=self._target_device),
                "attention_mask": torch.empty((0, 0), dtype=torch.long, device=self._target_device)
            }
        pad_token = self.tokenizer.pad_token if self.tokenizer.pad_token else "[PAD]"
        processed_texts = [t if isinstance(t, str) and t else pad_token for t in texts]
        inputs = self.tokenizer(processed_texts, return_tensors="pt", padding=True, truncation=True, max_length=128)
        inputs = {k: v.to(self._target_device) for k, v in inputs.items()}
        return inputs

    def get_embeddings(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> tuple:
        """Compute (audio_emb, text_emb) for a batch.

        Audio embeddings are the time-mean of Whisper encoder states; text
        embeddings are the first-token hidden state of the text encoder, or
        None when text processing is disabled.
        """
        self._load_whisper_if_needed()
        if self.use_text and texts is None:
            pass  # No-op: missing texts are substituted with "" further below.
        if not self.use_text and texts is not None:
            logger.warning("Text input provided to get_embeddings but no text model loaded. Text will be ignored.")
            texts = None
        mels = self.preprocess_audio(audios)
        # Handle empty batch input for embeddings
        if mels.numel() == 0:
            audio_emb = torch.empty((0, self._audio_embedding_dim), device=self._target_device)
            text_emb = None
            if self.use_text:
                text_emb = torch.empty((0, self._text_embedding_dim), device=self._target_device)
            return audio_emb, text_emb
        # Audio embeddings
        # NOTE(review): this pre-allocated tensor is immediately overwritten by
        # the encoder mean below, so it is effectively dead.
        audio_emb = torch.empty((mels.shape[0], self._audio_embedding_dim), device=self._target_device)
        with torch.no_grad():
            encoder_output = self.whisper_model.encoder(mels)
            # Mean-pool over the time axis -> (batch, embedding_dim).
            audio_emb = encoder_output.mean(dim=1)
        # Text embeddings
        text_emb = None
        if self.use_text:
            effective_texts = texts
            if texts is None:
                effective_texts = [""] * mels.shape[0] # Match audio batch size with empty strings if no text given
            tokenized_inputs = self.preprocess_text(effective_texts)
            if tokenized_inputs and tokenized_inputs["input_ids"].numel() > 0:
                with torch.no_grad():
                    # NOTE(review): assumes self.text_model is already loaded
                    # (forward() triggers the lazy load); calling get_embeddings
                    # directly with use_text=True and no prior load would fail here.
                    outputs = self.text_model(**tokenized_inputs)
                    # First-token ([CLS]-style) embedding.
                    text_emb = outputs.last_hidden_state[:, 0, :]
            elif self._text_embedding_dim > 0: # If text expected but tokenization resulted in empty output
                text_emb = torch.empty((mels.shape[0], self._text_embedding_dim), device=self._target_device)
        # Ensure embeddings are on the model's primary device.
        audio_emb = audio_emb.to(self._target_device)
        if text_emb is not None:
            text_emb = text_emb.to(self._target_device)
            # Check batch-size consistency between audio and text embeddings.
            if text_emb.shape[0] != audio_emb.shape[0]:
                logger.warning(f"Batch size mismatch after text embedding computation: Audio {audio_emb.shape[0]}, Text {text_emb.shape[0]}. This might lead to issues in ensemble.")
        return audio_emb, text_emb

    def forward(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> torch.Tensor:
        """Full prediction pass: embeddings -> weak learners -> meta-learner.

        Returns:
            Tensor of shape (batch, 1) with the ensemble's scores.

        Raises:
            RuntimeError: if the model was not configured for prediction mode,
                or the ensemble could not be loaded.
            ValueError: if the ensemble requires text but `texts` is None.
        """
        if not self.predict_mode:
            raise RuntimeError("Cannot call 'forward' for final predictions. Model was not initialized with 'ssl_ensemble_config' for prediction mode. "
                               "Use 'get_embeddings()' if you only need embeddings.")
        # Lazy-load everything the prediction path needs.
        self._load_ssl_ensemble_if_needed()
        self._load_whisper_if_needed()
        self._load_text_model_if_needed()
        if self.ssl_ensemble_model is None:
            # Reached when the lazy ensemble load failed and flipped
            # predict_mode off above.
            raise RuntimeError("Cannot call 'forward' for final predictions. Model was not initialized with 'ssl_ensemble_config' for prediction mode. "
                               "Use 'get_embeddings()' if you only need embeddings.")
        weak_learners_cfg_text_dim = self.ssl_ensemble_model.weak_learners.text_dim
        texts_for_embeddings = texts
        if weak_learners_cfg_text_dim > 0:
            if not self.use_text:
                logger.warning("WhisperSSLEnsemble.forward: Ensemble's weak learners expect text (config text_dim > 0), "
                               "but this model instance has no text processor (e.g., text_model_type='none' or load failed). "
                               "Embeddings passed to weak learners will have text_emb=None. WeakLearners must handle this (e.g. by padding).")
                texts_for_embeddings = None
            elif texts is None:
                # If text_dim > 0 and no text provided, raise error. If this is acceptable, you'd pad here.
                raise ValueError("WhisperSSLEnsemble.forward: Ensemble's weak learners require text input (config text_dim > 0), but 'texts' argument is None.")
        elif weak_learners_cfg_text_dim == 0:
            if texts is not None:
                logger.warning("WhisperSSLEnsemble.forward: Ensemble's weak learners do not expect text (config text_dim == 0),"
                               " but text input was provided. If a text model is loaded in this instance, text embeddings might be computed by "
                               "get_embeddings but will then be ignored by the weak learners.")
        audio_emb, text_emb = self.get_embeddings(audios, texts_for_embeddings)
        if audio_emb.shape[0] == 0:
            return torch.empty((0, 1), device=self._target_device)
        with torch.no_grad():
            final_predictions_tensor = self.ssl_ensemble_model(audio_emb, text_emb)
        return final_predictions_tensor

    def predict(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> List[float]:
        """Convenience wrapper around forward() returning plain Python floats."""
        final_predictions_tensor = self.forward(audios, texts)
        if final_predictions_tensor.numel() == 0:
            return []
        scores = final_predictions_tensor.detach().cpu().numpy().flatten().tolist()
        return scores
# Register the custom config/model pair so AutoConfig.from_pretrained and
# AutoModel.from_pretrained can resolve the "whisper-bert" model_type.
AutoConfig.register("whisper-bert", WhisperSSLEnsembleConfig)
AutoModel.register(WhisperSSLEnsembleConfig, WhisperSSLEnsemble)