"""
Router Model Architecture for Smart ASR Routing.

Regression-based approach: predicts WER for each backend model.
"""

from dataclasses import dataclass
from typing import Optional, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig, WhisperModel, WhisperFeatureExtractor
from transformers.modeling_outputs import ModelOutput


class AttentionPooling(nn.Module):
    """Learnable attention pooling for variable-length sequences.

    Each time step is scored by a ``Linear -> Tanh`` head. The scores are
    softmax-normalized over the time axis, ignoring padded steps when a mask is
    given. The result is the attention-weighted sum of the inputs over time.
    """

    def __init__(self, input_dim: int):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(input_dim, 1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: [Batch, Time, Dim]
            mask: [Batch, Time] (1 for valid, 0 for pad)

        Returns:
            pooled: [Batch, Dim]
        """
        scores = self.attention(x)  # [Batch, Time, 1]
        if mask is not None:
            # Fill padded steps with the most negative finite value of the score
            # dtype. A literal like -1e9 overflows float16, which makes
            # masked_fill raise. If every step of a row is masked, all scores
            # are equal and the softmax degrades to uniform weights (no NaN).
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask.unsqueeze(-1) == 0, fill_value)
        weights = F.softmax(scores, dim=1)  # [Batch, Time, 1]
        return torch.sum(x * weights, dim=1)  # [Batch, Dim]


class ASRRouterConfig(PretrainedConfig):
    """Configuration for ASRRouter model.

    Args:
        input_dim: Width of the pooled encoder embedding fed to the router.
        hidden_dim: Width of the first hidden layer.
        intermediate_dim: Width of the second hidden layer.
        dropout: Dropout probability applied after the first hidden layer.
        num_models: Number of backends to predict a WER for (output width).
    """

    model_type = "asr_router"

    def __init__(
        self,
        input_dim: int = 384,  # whisper-tiny encoder dim
        hidden_dim: int = 128,
        intermediate_dim: int = 64,
        dropout: float = 0.1,  # Lower dropout for regression
        num_models: int = 3,  # Number of backends to predict scores for
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.dropout = dropout
        self.num_models = num_models


@dataclass
class RouterOutput(ModelOutput):
    """Output of :class:`ASRRouterModel`.

    ``loss`` is the MSE loss, set only when labels were passed.
    ``pred_wers`` holds the predicted WER per backend, shape [Batch, num_models].
    """

    loss: Optional[torch.FloatTensor] = None
    pred_wers: Optional[torch.FloatTensor] = None  # Predicted WER for each model


class ASRRouterModel(PreTrainedModel):
    """
    Regression Router.

    Input: pooled Whisper encoder embeddings of width ``config.input_dim``
    (384 for whisper-tiny).
    Output: Estimated WER (0.0+, unbounded) for each backend model.

    Uses a Softplus activation so outputs are non-negative while still allowing
    WER > 1.0.
    """

    config_class = ASRRouterConfig
    # Output column index -> backend name.
    MODEL_ID_MAP = {0: "kyutai", 1: "granite", 2: "tiny_audio"}

    def __init__(self, config: ASRRouterConfig):
        super().__init__(config)

        self.network = nn.Sequential(
            nn.Linear(config.input_dim, config.hidden_dim),
            nn.GELU(),
            nn.LayerNorm(config.hidden_dim),  # Better for batch_size=1
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim, config.intermediate_dim),
            nn.GELU(),
            nn.LayerNorm(config.intermediate_dim),
            nn.Linear(config.intermediate_dim, config.num_models),
        )

        self.post_init()

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: Optional[torch.Tensor] = None,  # Actual WERs from ground truth
    ) -> RouterOutput:
        """Predict per-backend WERs and, if labels are given, the MSE loss.

        Args:
            embeddings: Pooled encoder embeddings, [Batch, config.input_dim].
            labels: Optional ground-truth WERs, [Batch, config.num_models].

        Returns:
            RouterOutput with ``pred_wers`` and (when labels are given) ``loss``.
        """
        # Softplus for unbounded positive WER (WER can exceed 1.0)
        pred_wers = F.softplus(self.network(embeddings))

        loss = None
        if labels is not None:
            # Cast labels to the predictions' dtype and device. WER targets
            # often arrive as float64, and mixing float64 targets with float32
            # predictions in mse_loss is known to fail in the backward pass.
            target = labels.to(device=pred_wers.device, dtype=pred_wers.dtype)
            loss = F.mse_loss(pred_wers, target)

        return RouterOutput(loss=loss, pred_wers=pred_wers)

    def predict_proba(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Get predicted WERs for each model (gradient-free).

        Despite the name, this returns non-negative WER estimates, not
        probabilities. It does not switch the module to eval mode, so call
        ``.eval()`` first if dropout must be disabled.
        """
        with torch.no_grad():
            return F.softplus(self.network(embeddings))


class RouterWithFeatureExtractor:
    """
    Production-ready router with attention pooling and memory optimizations.

    Pipeline: raw 16 kHz waveform -> Whisper log-mel features -> Whisper
    encoder -> attention pooling -> :class:`ASRRouterModel` WER regression.
    """

    def __init__(
        self,
        router: ASRRouterModel,
        device: str = "cpu",
        attention_pooling: Optional[AttentionPooling] = None,
        whisper_model_name: str = "openai/whisper-tiny",
    ):
        """
        Args:
            router: Trained regression router.
            device: Torch device string for all sub-modules.
            attention_pooling: Optional pooling module whose output width matches
                ``router.config.input_dim``. If omitted, a freshly initialized
                (untrained) pooling head is created. No weights are loaded into
                it here, so pass a trained module for reproducible embeddings.
            whisper_model_name: Hugging Face checkpoint used for both the
                encoder and the feature extractor.
        """
        self.device = device
        self.router = router.to(device)
        self.router.eval()

        # Attention pooling for variable-length sequences. Its output width
        # must equal the router's input width.
        if attention_pooling is None:
            attention_pooling = AttentionPooling(input_dim=router.config.input_dim)
        self.attention_pooling = attention_pooling.to(device)
        self.attention_pooling.eval()

        # Memory Optimization: Load full model, extract encoder, delete rest
        print("Loading Whisper Encoder...")
        full_whisper = WhisperModel.from_pretrained(whisper_model_name)
        self.whisper_encoder = full_whisper.encoder.to(device)
        self.whisper_encoder.eval()
        del full_whisper.decoder
        del full_whisper
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.feature_extractor = WhisperFeatureExtractor.from_pretrained(whisper_model_name)

    def extract_features(self, waveform: torch.Tensor) -> torch.Tensor:
        """Extract embeddings using Attention Pooling for variable lengths.

        Args:
            waveform: [Samples] or [Batch, Samples], assumed 16 kHz mono audio.

        Returns:
            Pooled embeddings, [Batch, encoder_dim].
        """
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        # Convert batch tensor to list of 1D numpy arrays (required by WhisperFeatureExtractor)
        audio_np = [w.cpu().numpy() for w in waveform]

        inputs = self.feature_extractor(
            audio_np,
            sampling_rate=16000,
            return_tensors="pt",
            return_attention_mask=True,
        )

        input_features = inputs.input_features.to(self.device)
        attention_mask = inputs.attention_mask.to(self.device)

        with torch.no_grad():
            last_hidden_state = self.whisper_encoder(input_features).last_hidden_state

        # Resize the feature-frame mask to the encoder output's time length.
        mask_resized = F.interpolate(
            attention_mask.unsqueeze(1).float(),
            size=last_hidden_state.shape[1],
            mode="nearest",
        ).squeeze(1)

        # Attention Pooling (runs with autograd enabled; callers doing pure
        # inference should wrap this call in torch.no_grad()).
        return self.attention_pooling(last_hidden_state, mask_resized)

    def predict(self, waveform: torch.Tensor) -> Dict:
        """Select the model with the lowest predicted WER.

        Only the first item of a batched waveform is scored.

        Returns:
            Dict with ``selected_model`` (backend name), ``predicted_wers``
            (backend name -> predicted WER) and ``confidence``
            (``1 - best WER``, clamped at 0).
        """
        # Nothing here needs gradients, so skip building an autograd graph
        # through the pooling head as well as the router.
        with torch.no_grad():
            embeddings = self.extract_features(waveform)
            output = self.router(embeddings)
        pred_wers = output.pred_wers[0].cpu().numpy()

        # Name each output column via the router's id map rather than
        # hard-coding three backends. Otherwise a router with a different
        # num_models would raise IndexError or silently drop outputs.
        id_map = self.router.MODEL_ID_MAP
        scores = {
            id_map.get(idx, f"model_{idx}"): float(wer)
            for idx, wer in enumerate(pred_wers)
        }

        best_name, best_wer = min(scores.items(), key=lambda item: item[1])
        return {
            "selected_model": best_name,
            "predicted_wers": scores,
            "confidence": max(0.0, 1.0 - best_wer),  # Clamp since WER can exceed 1.0
        }


# Register for auto classes
ASRRouterConfig.register_for_auto_class()
ASRRouterModel.register_for_auto_class("AutoModel")