|
|
""" |
|
|
Router Model Architecture for Smart ASR Routing. |
|
|
|
|
|
Regression-based approach: predicts WER for each backend model. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from dataclasses import dataclass |
|
|
from typing import Optional, Dict |
|
|
|
|
|
from transformers import PreTrainedModel, PretrainedConfig, WhisperModel, WhisperFeatureExtractor |
|
|
from transformers.modeling_outputs import ModelOutput |
|
|
|
|
|
|
|
|
class AttentionPooling(nn.Module):
    """Learnable attention pooling for variable-length sequences.

    A single Linear+Tanh scorer assigns a scalar score to each time
    step; the pooled output is the softmax-weighted sum over time.
    """

    def __init__(self, input_dim: int):
        super().__init__()
        # Scores each frame independently: [B, T, D] -> [B, T, 1].
        self.attention = nn.Sequential(
            nn.Linear(input_dim, 1),
            nn.Tanh()
        )

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: [Batch, Time, Dim]
            mask: optional [Batch, Time] (1 for valid, 0 for pad)
        Returns:
            pooled: [Batch, Dim]
        """
        scores = self.attention(x)

        if mask is not None:
            # Large negative fill drives padded positions to ~0 weight
            # after the softmax below.
            scores = scores.masked_fill(mask.unsqueeze(-1) == 0, -1e9)

        weights = F.softmax(scores, dim=1)
        return torch.sum(x * weights, dim=1)
|
|
|
|
|
|
|
|
class ASRRouterConfig(PretrainedConfig):
    """HuggingFace-style configuration for the ASRRouter model.

    Holds the MLP sizing hyperparameters (input/hidden/intermediate
    widths), the dropout rate, and the number of backend ASR models
    the router scores.
    """
    model_type = "asr_router"

    def __init__(
        self,
        input_dim: int = 384,
        hidden_dim: int = 128,
        intermediate_dim: int = 64,
        dropout: float = 0.1,
        num_models: int = 3,
        **kwargs
    ):
        # Collect the router-specific hyperparameters, then hand the
        # remaining kwargs to the base PretrainedConfig machinery before
        # attaching them to the instance.
        hyperparams = {
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "intermediate_dim": intermediate_dim,
            "dropout": dropout,
            "num_models": num_models,
        }
        super().__init__(**kwargs)
        for name, value in hyperparams.items():
            setattr(self, name, value)
|
|
|
|
|
|
|
|
@dataclass
class RouterOutput(ModelOutput):
    # MSE training loss; None when forward() received no labels.
    loss: Optional[torch.FloatTensor] = None
    # Predicted WER per backend model; last dimension is num_models.
    pred_wers: Optional[torch.FloatTensor] = None
|
|
|
|
|
|
|
|
class ASRRouterModel(PreTrainedModel):
    """
    Regression Router.

    Consumes 384-dimensional Whisper encoder embeddings and emits an
    estimated WER (0.0+, unbounded) for each backend model. A Softplus
    head keeps every estimate non-negative while still allowing WER > 1.0.
    """
    config_class = ASRRouterConfig

    # Output index -> human-readable backend model name.
    MODEL_ID_MAP = {0: "kyutai", 1: "granite", 2: "tiny_audio"}

    def __init__(self, config: ASRRouterConfig):
        super().__init__(config)

        # Two GELU/LayerNorm stages followed by a linear head with one
        # output per backend model. Attribute name and layer order are
        # kept stable: checkpoint state-dict keys depend on them.
        layers = [
            nn.Linear(config.input_dim, config.hidden_dim),
            nn.GELU(),
            nn.LayerNorm(config.hidden_dim),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim, config.intermediate_dim),
            nn.GELU(),
            nn.LayerNorm(config.intermediate_dim),
            nn.Linear(config.intermediate_dim, config.num_models),
        ]
        self.network = nn.Sequential(*layers)

        self.post_init()

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ) -> RouterOutput:
        """Predict per-backend WERs; compute MSE loss when labels are given."""
        # Softplus guarantees non-negative WER estimates.
        wer_estimates = F.softplus(self.network(embeddings))

        if labels is None:
            return RouterOutput(loss=None, pred_wers=wer_estimates)

        return RouterOutput(
            loss=F.mse_loss(wer_estimates, labels),
            pred_wers=wer_estimates,
        )

    def predict_proba(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Get predicted WERs for each model (inference only, no grad)."""
        with torch.no_grad():
            return self.forward(embeddings).pred_wers
|
|
|
|
|
|
|
|
class RouterWithFeatureExtractor:
    """
    Production-ready router with attention pooling and memory optimizations.

    Wraps an ASRRouterModel together with the Whisper-tiny encoder and
    feature extractor so raw waveforms can be routed end to end.
    """

    def __init__(self, router: ASRRouterModel, device: str = "cpu"):
        """
        Args:
            router: trained ASRRouterModel used to score backend models.
            device: torch device string, e.g. "cpu" or "cuda".
        """
        self.device = device
        self.router = router.to(device)
        self.router.eval()

        # NOTE(review): this pooling layer is freshly constructed here, so
        # its weights are random unless loaded externally — confirm trained
        # pooling weights are restored before relying on its output.
        self.attention_pooling = AttentionPooling(input_dim=384).to(device)
        self.attention_pooling.eval()

        print("Loading Whisper Encoder...")
        full_whisper = WhisperModel.from_pretrained("openai/whisper-tiny")
        self.whisper_encoder = full_whisper.encoder.to(device)
        self.whisper_encoder.eval()

        # Only the encoder is needed; drop the decoder to cut resident memory.
        del full_whisper.decoder
        del full_whisper
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")

    def extract_features(self, waveform: torch.Tensor) -> torch.Tensor:
        """Extract embeddings using Attention Pooling for variable lengths.

        Args:
            waveform: [Time] or [Batch, Time] mono audio, assumed to be
                sampled at 16 kHz — TODO confirm with callers.
        Returns:
            Pooled encoder embeddings of shape [Batch, 384].
        """
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        # The feature extractor expects a list of 1-D numpy arrays.
        audio_np = [w.cpu().numpy() for w in waveform]

        inputs = self.feature_extractor(
            audio_np,
            sampling_rate=16000,
            return_tensors="pt",
            return_attention_mask=True
        )

        input_features = inputs.input_features.to(self.device)
        attention_mask = inputs.attention_mask.to(self.device)

        with torch.no_grad():
            last_hidden_state = self.whisper_encoder(input_features).last_hidden_state

        # The encoder output has a different time length than the mel-frame
        # mask, so resize the mask to the encoder's output length before
        # pooling (nearest keeps it binary).
        mask_resized = F.interpolate(
            attention_mask.unsqueeze(1).float(),
            size=last_hidden_state.shape[1],
            mode='nearest'
        ).squeeze(1)

        return self.attention_pooling(last_hidden_state, mask_resized)

    def predict(self, waveform: torch.Tensor) -> Dict:
        """Select the model with the lowest predicted WER.

        Returns:
            dict with "selected_model" (name), "predicted_wers"
            (name -> WER), and "confidence" (1 - best WER, floored at 0).
        """
        embeddings = self.extract_features(waveform)

        with torch.no_grad():
            output = self.router(embeddings)
            pred_wers = output.pred_wers[0].cpu().numpy()

        # Build the score table from the router's own index->name map so it
        # stays consistent with ASRRouterModel.MODEL_ID_MAP and scales with
        # config.num_models, instead of hard-coding three backend names.
        scores = {
            self.router.MODEL_ID_MAP[i]: float(wer)
            for i, wer in enumerate(pred_wers)
        }

        best_name, best_wer = min(scores.items(), key=lambda item: item[1])

        return {
            "selected_model": best_name,
            "predicted_wers": scores,
            "confidence": max(0.0, 1.0 - best_wer)
        }
|
|
|
|
|
|
|
|
|
|
|
# Register the custom config and model with the Transformers Auto* class
# machinery so they can be loaded generically (e.g. via AutoModel) once
# the code is shared alongside saved weights.
ASRRouterConfig.register_for_auto_class()

ASRRouterModel.register_for_auto_class("AutoModel")
|
|
|