# whisper/modeling_whisper.py
# (Hugging Face Hub copy — commit c72fbfe "Update modeling_whisper.py" by 1NEYRON1, 27.2 kB)
import os
from typing import Dict, List, Optional, Union
import logging
import joblib
import numpy as np
import torch
import whisper
from sklearn.linear_model import Ridge
from sklearn.svm import SVR
from sklearn.tree import DecisionTreeRegressor
from torch import nn
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, AutoConfig
from .configuration_whisper import WhisperSSLEnsembleConfig
# Setup logger
# NOTE(review): logging.basicConfig() here configures the *root* logger at
# import time; that is application-level behavior in a library module —
# consider leaving log configuration to the embedding application.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --- Weak Learners ---
class WeakLearners(nn.Module):
    """
    Wrapper around scikit-learn regressors (Ridge, SVR, DecisionTree) used as
    weak learners in the stacking ensemble.

    The sklearn models always run on CPU: :meth:`forward` moves inputs to CPU
    for prediction and returns the stacked predictions on ``self.device``.
    Fitted models must be loaded via :meth:`load_fitted` (or trained via
    :meth:`fit`) before :meth:`forward` can be called.
    """

    # Default dimensions match Whisper base (512) and BERT base (768).
    def __init__(self, audio_dim: int = 512, text_dim: int = 768, device: str = "cpu") -> None:
        super().__init__()
        self.audio_dim = audio_dim
        self.text_dim = text_dim
        # Device is only used for the returned prediction tensor; sklearn is CPU-only.
        self.device = torch.device(device)
        # Placeholder (unfitted) sklearn models until load_fitted()/fit() runs.
        self.ridge_regressor = Ridge(alpha=1.0)
        self.svr = SVR()
        self.dtr = DecisionTreeRegressor()
        self.models = [self.ridge_regressor, self.svr, self.dtr]
        self.model_names = ["Ridge", "SVR", "DTR"]
        self.fitted = False  # Set to True after successful fit()/load_fitted()

    # Kept for a separate training script; not used by the inference pipeline.
    def fit(self, train_loader: torch.utils.data.DataLoader) -> None:
        """Fit all weak learners on embeddings yielded by ``train_loader``.

        Batches may be tuples/lists ``(audio_emb, text_emb, labels, ...)`` or
        dicts with keys ``audio_embedding``, optional ``text_embedding`` and
        ``label``. Missing text embeddings are zero-padded when text_dim > 0.

        Raises:
            ValueError: on unsupported batch formats or sample-count mismatch.
            RuntimeError: if the loader yields no data.
        """
        logger.info("Fitting weak learners...")
        all_audio_emb, all_text_emb, all_labels = [], [], []
        pbar = tqdm(train_loader, desc="Extracting Embeddings for Weak Learners", leave=False)
        for batch_data in pbar:
            # Flexible batch data handling (adjust if needed based on actual loader)
            if isinstance(batch_data, (list, tuple)) and len(batch_data) >= 3:
                audio_emb, text_emb, labels = (batch_data[0], batch_data[1], batch_data[2])
            elif isinstance(batch_data, dict):
                audio_emb = batch_data.get("audio_embedding")
                text_emb = batch_data.get("text_embedding")  # Might be None
                labels = batch_data.get("label")
                if audio_emb is None or labels is None:
                    raise ValueError("Expected dict with 'audio_embedding' and 'label' keys")
            else:
                raise ValueError("Unsupported batch data format from train_loader")
            all_audio_emb.append(audio_emb.detach().cpu().numpy())
            if text_emb is not None:
                all_text_emb.append(text_emb.detach().cpu().numpy())
            elif self.text_dim > 0:  # Pad if text expected but not provided in batch
                zeros_text = np.zeros((audio_emb.shape[0], self.text_dim))
                all_text_emb.append(zeros_text)
            all_labels.append(labels.detach().cpu().numpy())
        if not all_audio_emb or not all_labels:
            raise RuntimeError("No embeddings or labels collected. Check train_loader.")
        all_audio_emb = np.vstack(all_audio_emb)
        all_labels = np.concatenate(all_labels)
        if all_text_emb:
            all_text_emb = np.vstack(all_text_emb)
            if all_audio_emb.shape[0] != all_text_emb.shape[0]:
                raise ValueError(f"Sample count mismatch: Audio {all_audio_emb.shape[0]}, Text {all_text_emb.shape[0]}")
            combined_embeddings = np.hstack((all_audio_emb, all_text_emb))
            logger.info(f"Combined embedding shape for fitting: {combined_embeddings.shape}")
        else:
            combined_embeddings = all_audio_emb
            logger.info(f"Using only audio embeddings for fitting. Shape: {combined_embeddings.shape}")
        if combined_embeddings.shape[0] != len(all_labels):
            raise ValueError(f"Sample count mismatch: Embeddings {combined_embeddings.shape[0]}, Labels {len(all_labels)}")
        logger.info("Training sklearn models...")
        for model, name in zip(self.models, self.model_names):
            logger.info(f" Fitting {name}...")
            model.fit(combined_embeddings, all_labels)
        self.fitted = True
        logger.info("Weak learners fitting completed.")
        # Optional persistence (kept as a hint for the training script):
        # joblib.dump(self.models, 'fitted_weak_learners.joblib')

    def load_fitted(self, filepath: str) -> bool:
        """Load fitted sklearn models from a joblib file.

        Returns True on success, False otherwise (missing file, wrong content,
        or load error). On success the individual model attributes
        (``ridge_regressor``/``svr``/``dtr``) are rebound to the loaded models
        so they stay in sync with ``self.models`` — previously they kept
        pointing at the unfitted placeholders created in ``__init__``.
        """
        if not os.path.exists(filepath):
            logger.error(f"Weak learners file not found at {filepath}")
            return False
        try:
            loaded_sklearn_models = joblib.load(filepath)
            if isinstance(loaded_sklearn_models, list) and len(loaded_sklearn_models) == len(self.models):
                self.models = loaded_sklearn_models
                # Keep the named attributes consistent with the loaded list.
                # Assumes the saved order is [Ridge, SVR, DTR] — the order fit() uses.
                self.ridge_regressor, self.svr, self.dtr = self.models
                self.fitted = True
                logger.info(f" Weak learners loaded successfully from {filepath}.")
                return True
            logger.error(f"Loaded file '{filepath}' does not contain the expected list of {len(self.models)} sklearn models.")
            self.fitted = False
            return False
        except ImportError:
            # joblib.load may still raise ImportError if the pickled models
            # reference an unavailable sklearn submodule.
            logger.error("joblib is required to load weak learners but not installed. Run: pip install joblib")
            self.fitted = False
            return False
        except Exception as e:
            logger.error(f"Error loading weak learners from {filepath}: {e}")
            self.fitted = False
            return False

    def forward(self, audio_emb: torch.Tensor, text_emb: Optional[torch.Tensor]) -> torch.Tensor:
        """
        Generate predictions from the loaded weak learners.

        Args:
            audio_emb: (batch, audio_dim) tensor on any device.
            text_emb: optional (batch, text_dim) tensor; zero-padded when
                expected (text_dim > 0) but None, ignored when text_dim == 0.
        Returns:
            (batch, num_weak_learners) float tensor on ``self.device``.
        Raises:
            RuntimeError: if models were not loaded/fitted first.
            ValueError: on audio/text batch-size mismatch.
        """
        if not self.fitted:
            raise RuntimeError("Weak learners must be loaded using load_fitted() before calling forward.")
        # Prepare input for sklearn: move to CPU, detach, convert to numpy
        audio_np = audio_emb.detach().cpu().numpy()
        # Handle text embedding based on text_dim
        if self.text_dim > 0:
            if text_emb is None:
                # Text expected (text_dim > 0) but not provided: pad with zeros
                logger.warning("Text embedding is None in WeakLearners forward pass, but text_dim > 0. Padding with zeros.")
                zeros_text = np.zeros((audio_np.shape[0], self.text_dim))
                combined_embeddings = np.hstack((audio_np, zeros_text))
            else:
                # Text is expected and provided
                text_np = text_emb.detach().cpu().numpy()
                if audio_np.shape[0] != text_np.shape[0]:
                    raise ValueError("Batch size mismatch between audio and text embeddings in WeakLearners forward.")
                combined_embeddings = np.hstack((audio_np, text_np))
        else:
            # No text is expected (text_dim == 0)
            if text_emb is not None:
                logger.warning("Text embedding provided to WeakLearners forward pass, but text_dim is 0. Ignoring text.")
            combined_embeddings = audio_np
        # Get predictions from each loaded sklearn model
        all_preds = []
        for model in self.models:
            preds = model.predict(combined_embeddings)
            # Convert numpy array predictions to float tensor
            all_preds.append(torch.from_numpy(preds).float())
        # Stack to (batch_size, num_weak_learners) and move to the target device.
        stacked_preds = torch.stack(all_preds, dim=1).to(self.device)
        return stacked_preds
# --- Stacking Model (Meta-Learner) ---
class StackingMetaLearner(nn.Module):
    """
    A simple feed-forward network (Linear -> ReLU -> Linear) that learns to
    combine the weak learners' predictions into a single score.

    The layer sizes must match the checkpoint loaded via
    :meth:`load_state_dict_from_file`.
    """

    def __init__(self, weak_output_dim: int = 3, hidden_dim: int = 256) -> None:
        super().__init__()
        # Check for valid dimensions
        if weak_output_dim <= 0 or hidden_dim <= 0:
            raise ValueError("weak_output_dim and hidden_dim must be positive integers.")
        self.fc1 = nn.Linear(weak_output_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, 1)  # Predict a single score

    def load_state_dict_from_file(self, filepath: str, device: torch.device) -> None:
        """Load weights from ``filepath``, move the model to ``device`` and
        switch to eval mode.

        Raises:
            FileNotFoundError: if ``filepath`` does not exist.
            RuntimeError: if the checkpoint cannot be loaded or applied.
        """
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Meta-learner state file not found at {filepath}")
        try:
            # Load the state dict onto the specified device directly.
            # weights_only=True blocks arbitrary-code execution from a
            # malicious pickle; fall back for torch versions lacking the kwarg.
            try:
                state_dict = torch.load(filepath, map_location=device, weights_only=True)
            except TypeError:
                state_dict = torch.load(filepath, map_location=device)
            self.load_state_dict(state_dict)
            # Move model to the device *after* loading state_dict
            self.to(device)
            self.eval()  # Set to evaluation mode after loading
            logger.info(f" StackingMetaLearner state loaded successfully from {filepath} onto device {device}.")
        except Exception as e:
            logger.error(f"Error loading StackingMetaLearner state_dict from {filepath}: {e}")
            raise RuntimeError(f"Failed to load StackingMetaLearner state: {e}") from e

    def forward(self, weak_outputs: torch.Tensor) -> torch.Tensor:
        """
        Args:
            weak_outputs (torch.Tensor): Tensor of shape (batch_size, weak_output_dim)
                containing predictions from weak learners.
        Returns:
            torch.Tensor: Final prediction of shape (batch_size, 1).
        """
        x = self.relu(self.fc1(weak_outputs))
        x = self.fc2(x)
        return x
# --- Main Ensemble Model (Wrapper used internally now) ---
class SSLEnsembleModel(nn.Module):
    """
    Two-stage ensemble: weak-learner predictions feed a stacking meta-learner.

    Both components are expected to be fully loaded (weak learners fitted,
    meta-learner weights restored) before this wrapper is constructed; the
    class only wires their forward passes together.
    """

    def __init__(self, weak_learners: WeakLearners, stacking_meta_learner: StackingMetaLearner) -> None:
        super().__init__()
        # Fail fast on a half-configured ensemble.
        weak_ok = isinstance(weak_learners, WeakLearners) and weak_learners.fitted
        if not weak_ok:
            raise ValueError("A pre-loaded WeakLearners instance must be provided to SSLEnsembleModel.")
        if not isinstance(stacking_meta_learner, StackingMetaLearner):
            raise ValueError("A StackingMetaLearner instance must be provided.")
        self.weak_learners = weak_learners
        self.stacking_meta_learner = stacking_meta_learner

    def forward(self, audio_emb: torch.Tensor, text_emb: Optional[torch.Tensor]) -> torch.Tensor:
        """
        Run the weak learners on the embeddings, then combine their outputs
        with the meta-learner. Returns a tensor of shape (batch_size, 1).
        """
        # Weak learners handle their own CPU round-trip and return predictions
        # of shape (batch_size, num_weak_learners) on their configured device.
        base_preds = self.weak_learners(audio_emb, text_emb)
        # Align device with the meta-learner's parameters before the final pass.
        meta_device = next(self.stacking_meta_learner.parameters()).device
        return self.stacking_meta_learner(base_preds.to(meta_device))
# --- Unified Main Model Class ---
class WhisperSSLEnsemble(PreTrainedModel):  # type: ignore
    """
    Unified model using Whisper for audio, an optional transformer for text,
    and an optional SSL Ensemble (Weak + Meta learners) for final prediction.
    Inherits from transformers.PreTrainedModel.

    Two operating modes, chosen by ``config.ssl_ensemble_config``:
      * embedding mode (None): only get_embeddings() is usable; forward()
        and predict() raise RuntimeError.
      * prediction mode (dict): the full pipeline (Whisper encoder ->
        optional text encoder -> weak learners -> stacking meta-learner)
        produces one score per input.
    """
    config_class = WhisperSSLEnsembleConfig

    def __init__(self, config: WhisperSSLEnsembleConfig) -> None:
        """Build the model from a WhisperSSLEnsembleConfig.

        Loads the Whisper encoder (mandatory), an optional text encoder, and —
        when ``config.ssl_ensemble_config`` is given — the fitted weak
        learners and meta-learner checkpoints for prediction mode.
        """
        # Legacy (pre-HF-config) constructor arguments, kept for reference:
        # self,
        # whisper_variant: str = "base.en",
        # text_model_type: str = "bert-base-uncased",
        # weights: str = None, # Path to Whisper weights if not using standard variant download
        # device: str = None,
        # ssl_ensemble_config: Optional[Dict] = None,
        # super().__init__(weights) # Not calling super for BaseModel due to different init logic
        super().__init__(config)
        self.config = config
        whisper_variant = self.config.whisper_variant
        text_model_type = self.config.text_model_type
        weights = self.config.weights
        ssl_ensemble_config = self.config.ssl_ensemble_config
        # device = 'cpu'
        # if self.config.device is None:
        #     device = "cuda" if torch.cuda.is_available() else "cpu"
        # else:
        #     device = self.config.device
        # self.device = torch.device(device)
        self.predict_mode = False         # True once an ensemble config is loaded
        self.ssl_ensemble_model = None    # SSLEnsembleModel, set in prediction mode
        self.tokenizer = None             # set only when a text model loads
        self.text_model = None
        self._audio_embedding_dim = None  # filled in after Whisper loads
        self._text_embedding_dim = 0      # 0 means "no text features"
        # logger.info(f"Initializing WhisperSSLEnsemble on device: {self.device}")
        try:
            # Determine if 'weights' is a path or a variant name for whisper.load_model
            # whisper.load_model can take a name like "base.en" or a path to a .pt file
            # NOTE(review): self.device here is PreTrainedModel's device property,
            # read before any submodule/parameter is registered — confirm it
            # resolves as intended at this point.
            logger.info(f"Loading Whisper model: '{weights if weights else whisper_variant}'...")
            wm = whisper.load_model(weights if weights else whisper_variant, device=self.device)
            self.whisper_model = wm
            self.whisper_model.eval()
            # Encoder output width, read off the encoder's final LayerNorm.
            self._audio_embedding_dim = self.whisper_model.encoder.ln_post.normalized_shape[0]
            logger.info(f" Whisper loaded. Audio embedding dimension: {self._audio_embedding_dim}")
        except Exception as e:
            logger.error(f"Error loading Whisper model: {e}")
            raise RuntimeError(f"Failed to load Whisper model '{weights or whisper_variant}'") from e
        self.use_text = text_model_type is not None and text_model_type.lower() != "none"
        if self.use_text:
            logger.info(f"Loading Text model '{text_model_type}'...")
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(text_model_type)
                self.text_model = AutoModel.from_pretrained(text_model_type).to(self.device)
                self.text_model.eval()
                self._text_embedding_dim = self.text_model.config.hidden_size
                logger.info(f" Text model loaded. Text embedding dimension: {self._text_embedding_dim}")
            except Exception as e:
                # A text-model failure degrades gracefully to audio-only mode.
                logger.warning(f"Failed to load text model '{text_model_type}'. Error: {e}")
                logger.warning(" Text processing will be disabled.")
                self.use_text = False
                self._text_embedding_dim = 0
        else:
            logger.info("Text model type is 'none'. Text processing disabled.")
            self._text_embedding_dim = 0
        if ssl_ensemble_config is not None:
            logger.info("SSL Ensemble config provided. Initializing for prediction mode...")
            self.predict_mode = True
            required_keys = ["weak_learners_path", "meta_learner_path", "audio_dim", "text_dim", "hidden_dim"]
            if not all(key in ssl_ensemble_config for key in required_keys):
                raise ValueError(f"ssl_ensemble_config missing required keys: {required_keys}")
            cfg_audio_dim = ssl_ensemble_config["audio_dim"]
            cfg_text_dim = ssl_ensemble_config["text_dim"]
            # Dimension mismatches are warnings, not errors: the weak learners
            # were fitted offline and their configured dims are authoritative.
            if cfg_audio_dim != self._audio_embedding_dim:
                logger.warning(f"Ensemble config audio_dim ({cfg_audio_dim}) mismatches Whisper model dim ({self._audio_embedding_dim}).")
            # Validate text_dim from config against loaded text model
            if self.use_text and cfg_text_dim != self._text_embedding_dim:
                logger.warning(f"Ensemble config text_dim ({cfg_text_dim}) mismatches loaded text model dim ({self._text_embedding_dim}).")
            elif not self.use_text and cfg_text_dim != 0:
                logger.warning(f"Ensemble config expects text_dim={cfg_text_dim}, but no text model is loaded. WeakLearners must handle padding.")
            elif self.use_text and cfg_text_dim == 0:  # Text model loaded, but ensemble config says text_dim=0
                logger.warning("Ensemble config expects text_dim=0, but a text model IS loaded. "
                               "Text embeddings will be ignored by weak learners if their text_dim is 0.")
            weak_learners = WeakLearners(audio_dim=cfg_audio_dim, text_dim=cfg_text_dim, device=self.device)
            if not weak_learners.load_fitted(ssl_ensemble_config["weak_learners_path"]):
                raise RuntimeError("Failed to load weak learners for WhisperSSLEnsemble.")
            meta_learner = StackingMetaLearner(weak_output_dim=len(weak_learners.models), hidden_dim=ssl_ensemble_config["hidden_dim"])
            meta_learner.load_state_dict_from_file(ssl_ensemble_config["meta_learner_path"], device=self.device)
            self.ssl_ensemble_model = SSLEnsembleModel(weak_learners=weak_learners, stacking_meta_learner=meta_learner)
            self.ssl_ensemble_model.eval()
            logger.info("SSLEnsembleModel loaded successfully.")
        else:
            logger.info("No SSL Ensemble config provided. Model initialized in embedding mode.")
            logger.info("Call '.get_embeddings(audios, texts)' to get embeddings.")
            logger.info("Calling '.predict()' or '.forward()' will raise an error in this mode.")

    def preprocess_audio(self, audios: List[Union[np.ndarray, torch.Tensor]]) -> torch.Tensor:
        """Convert raw waveforms into a batched log-mel spectrogram tensor.

        Each waveform is padded/trimmed to Whisper's fixed window before the
        mel transform. Returns a (batch, n_mels, n_frames) tensor on the
        Whisper model's device; an empty batch yields a zero-row tensor.

        Raises:
            TypeError: if an element is neither numpy array nor torch tensor.
        """
        processed_mels = []
        for audio in audios:
            if isinstance(audio, torch.Tensor):
                audio = audio.cpu().numpy()
            if not isinstance(audio, np.ndarray):
                raise TypeError(f"Expected audio input to be numpy array or torch tensor, got {type(audio)}")
            # Whisper's mel transform expects float32 PCM.
            if audio.dtype != np.float32:
                audio = audio.astype(np.float32)
            audio_proc = whisper.pad_or_trim(audio)
            mel = whisper.log_mel_spectrogram(audio_proc, device=self.whisper_model.device)
            processed_mels.append(mel)
        if not processed_mels:
            # Using Whisper's typical dimensions: 80 mel bins, 3000 frames (30s)
            # NOTE(review): dims.n_audio_ctx is the encoder context length, which
            # may differ from the mel frame count — confirm the intended
            # empty-batch shape matches what torch.stack produces below.
            return torch.empty((0, self.whisper_model.dims.n_mels, self.whisper_model.dims.n_audio_ctx), device=self.whisper_model.device)
        # Simpler: use known default values if whisper model object doesn't expose them easily
        # return torch.empty((0, 80, 3000), device=self.whisper_model.device)
        mels_batch = torch.stack(processed_mels)
        return mels_batch

    def preprocess_text(self, texts: List[str]) -> Optional[Dict[str, torch.Tensor]]:
        """Tokenize ``texts`` for the text encoder.

        Returns None when text processing is disabled, otherwise a dict of
        tensors on self.device. Empty/non-string entries are replaced with the
        tokenizer's pad token so batch positions stay aligned.
        """
        if not self.use_text or not self.tokenizer:
            return None
        if not texts:
            # Empty batch: keep the contract (dict of tensors) with zero rows.
            return {
                "input_ids": torch.empty((0, 0), dtype=torch.long, device=self.device),
                "attention_mask": torch.empty((0, 0), dtype=torch.long, device=self.device)
            }
        pad_token = self.tokenizer.pad_token if self.tokenizer.pad_token else "[PAD]"
        processed_texts = [t if isinstance(t, str) and t else pad_token for t in texts]
        inputs = self.tokenizer(processed_texts, return_tensors="pt", padding=True, truncation=True, max_length=128)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        return inputs

    def get_embeddings(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> tuple:
        """Compute ``(audio_emb, text_emb)`` for a batch.

        audio_emb: (batch, audio_dim) mean-pooled Whisper encoder states.
        text_emb:  (batch, text_dim) CLS embeddings, or None when text
                   processing is disabled.

        NOTE(review): the text fallback below uses torch.empty, which returns
        *uninitialized* memory rather than zeros — confirm downstream
        consumers never read those values as features.
        """
        if self.use_text and texts is None:
            pass  # Allow texts to be None, WeakLearners will handle if it expects text.
        if not self.use_text and texts is not None:
            logger.warning("Text input provided to get_embeddings but no text model loaded. Text will be ignored.")
            texts = None  # Ensure texts list is None if not used by this WhisperSSLEnsemble instance
        mels = self.preprocess_audio(audios)
        # Handle empty batch input for embeddings
        if mels.numel() == 0:  # No audio provided
            audio_emb = torch.empty((0, self._audio_embedding_dim), device=self.device)  # type: ignore
            text_emb = None
            if self.use_text:
                text_emb = torch.empty((0, self._text_embedding_dim), device=self.device)
            return audio_emb, text_emb
        # Placeholder; overwritten by the encoder pass just below.
        audio_emb = torch.empty((mels.shape[0], self._audio_embedding_dim), device=self.device)  # type: ignore
        if mels.numel() > 0:
            with torch.no_grad():
                encoder_output = self.whisper_model.encoder(mels)
                # Mean-pool over the time axis to get one vector per clip.
                audio_emb = encoder_output.mean(dim=1)
        text_emb = None
        if self.use_text:
            # Ensure text batch matches audio if texts are provided or if texts are None but audio exists (pad with empty strings)
            effective_texts = texts
            if texts is None:  # If no texts provided, but we use_text, create dummy empty strings for tokenization matching audio batch size
                effective_texts = [""] * mels.shape[0]
            elif len(texts) != mels.shape[0] and mels.shape[0] > 0:
                # Mismatched lengths are currently passed through unchanged and
                # preprocess_text is assumed to cope; behavior for this case
                # (pad/truncate/raise) may need refinement per the use case.
                pass
            tokenized_inputs = self.preprocess_text(effective_texts)  # type: ignore
            if tokenized_inputs and tokenized_inputs["input_ids"].numel() > 0:
                with torch.no_grad():
                    outputs = self.text_model(**tokenized_inputs)  # type: ignore
                    # CLS-token embedding as the sentence representation.
                    text_emb = outputs.last_hidden_state[:, 0, :]
            elif self._text_embedding_dim > 0:  # Texts were None or empty list resulting in no tokenized_inputs, but text is expected
                text_emb = torch.empty((mels.shape[0], self._text_embedding_dim), device=self.device)
        audio_emb = audio_emb.to(self.device)
        if text_emb is not None:
            text_emb = text_emb.to(self.device)
            if text_emb.shape[0] != audio_emb.shape[0] and audio_emb.shape[0] > 0:
                # Batch-size mismatch between modalities is currently tolerated.
                pass
        return audio_emb, text_emb

    def forward(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> torch.Tensor:
        """
        Process raw audio and optional text through the full pipeline
        (embeddings + SSL ensemble) to get final scores as a Tensor of shape
        (batch_size, 1). Called by self.predict().

        Raises:
            RuntimeError: if the model was not initialized in prediction mode.
            ValueError: if the ensemble requires text but 'texts' is None
                while a text model is loaded.
        """
        if not self.predict_mode or self.ssl_ensemble_model is None:
            raise RuntimeError("Cannot call 'forward' for final predictions. Model was not initialized with 'ssl_ensemble_config' for prediction mode. "
                               "Use 'get_embeddings()' if you only need embeddings.")
        # Determine text input strategy for get_embeddings based on ensemble and model config
        weak_learners_cfg_text_dim = self.ssl_ensemble_model.weak_learners.text_dim
        texts_for_embeddings = texts  # Default to provided texts
        if weak_learners_cfg_text_dim > 0:  # Ensemble's weak learners expect text features
            if not self.use_text:  # This WhisperSSLEnsemble instance has no text model
                logger.warning("WhisperSSLEnsemble.forward: Ensemble's weak learners expect text (config text_dim > 0), "
                               "but this model instance has no text processor (e.g., text_model_type='none' or load failed). "
                               "Embeddings passed to weak learners will have text_emb=None. WeakLearners must handle this (e.g. by padding).")
                texts_for_embeddings = None  # Ensure get_embeddings receives None for text
            elif texts is None:  # Model has text processor, ensemble expects text, but no 'texts' arg provided
                raise ValueError("WhisperSSLEnsemble.forward: Ensemble's weak learners require text input (config text_dim > 0), but 'texts' argument is None.")
        elif weak_learners_cfg_text_dim == 0:  # Ensemble's weak learners DO NOT expect text features
            if texts is not None:
                logger.warning("WhisperSSLEnsemble.forward: Ensemble's weak learners do not expect text (config text_dim == 0),"
                               " but text input was provided. If a text model is loaded in this instance, text embeddings might be computed by "
                               "get_embeddings but will then be ignored by the weak learners.")
        # 1. Get Embeddings
        audio_emb, text_emb = self.get_embeddings(audios, texts_for_embeddings)
        if audio_emb.shape[0] == 0:  # Handle empty batch case after getting embeddings
            return torch.empty((0, 1), device=self.device)  # Return empty scores tensor, shape (batch, 1)
        # 2. Pass embeddings through the SSLEnsembleModel
        with torch.no_grad():
            final_predictions_tensor = self.ssl_ensemble_model(audio_emb, text_emb)  # Shape: (batch_size, 1)
        return final_predictions_tensor

    def predict(self, audios: List[Union[np.ndarray, torch.Tensor]], texts: Optional[List[str]] = None) -> List[float]:
        """
        Generate final scores (one float per input) using the full pipeline
        (Embeddings + SSL Ensemble). Requires prediction mode; delegates the
        mode check to self.forward(). Returns [] for an empty batch.
        """
        final_predictions_tensor = self.forward(audios, texts)
        if final_predictions_tensor.numel() == 0:
            return []
        scores = final_predictions_tensor.detach().cpu().numpy().flatten().tolist()
        return scores
# Register the config/model pair so AutoConfig/AutoModel can resolve the
# custom "whisper-bert" model type when loading from the Hub with
# trust_remote_code.
AutoConfig.register("whisper-bert", WhisperSSLEnsembleConfig)
AutoModel.register(WhisperSSLEnsembleConfig, WhisperSSLEnsemble)