# tiny-router / router_model.py
# (Hugging Face Hub upload metadata: uploaded by mazesmazes — "Upload model",
#  commit d0469c4, verified)
"""
Router Model Architecture for Smart ASR Routing.
Regression-based approach: predicts WER for each backend model.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Dict
from transformers import PreTrainedModel, PretrainedConfig, WhisperModel, WhisperFeatureExtractor
from transformers.modeling_outputs import ModelOutput
class AttentionPooling(nn.Module):
    """Learnable attention pooling for variable-length sequences.

    A single linear layer scores each timestep; a softmax over the time
    axis turns the scores into weights used for a weighted sum of the
    inputs, collapsing [Batch, Time, Dim] to [Batch, Dim].
    """
    def __init__(self, input_dim: int):
        """
        Args:
            input_dim: Feature dimension of the input sequence.
        """
        super().__init__()
        # Tanh bounds raw scores to [-1, 1] before the time softmax.
        self.attention = nn.Sequential(
            nn.Linear(input_dim, 1),
            nn.Tanh()
        )
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: [Batch, Time, Dim]
            mask: [Batch, Time] (1 for valid, 0 for pad), or None for no masking
        Returns:
            pooled: [Batch, Dim]
        """
        scores = self.attention(x)  # [Batch, Time, 1]
        if mask is not None:
            # Large negative fill -> ~zero softmax weight on padded positions.
            scores = scores.masked_fill(mask.unsqueeze(-1) == 0, -1e9)
        weights = F.softmax(scores, dim=1)  # [Batch, Time, 1]
        return torch.sum(x * weights, dim=1)  # [Batch, Dim]
class ASRRouterConfig(PretrainedConfig):
    """Configuration for ASRRouter model.

    Stores the router MLP dimensions and the number of backend ASR
    models that WER scores are predicted for.
    """

    model_type = "asr_router"

    def __init__(
        self,
        input_dim: int = 384,  # whisper-tiny encoder dim
        hidden_dim: int = 128,
        intermediate_dim: int = 64,
        dropout: float = 0.1,  # Lower dropout for regression
        num_models: int = 3,  # Number of backends to predict scores for
        **kwargs
    ):
        super().__init__(**kwargs)
        # Persist the MLP hyperparameters on the config instance.
        self.input_dim, self.hidden_dim = input_dim, hidden_dim
        self.intermediate_dim, self.dropout = intermediate_dim, dropout
        self.num_models = num_models
@dataclass
class RouterOutput(ModelOutput):
    """Output container returned by ASRRouterModel.forward."""
    # MSE loss against ground-truth WER labels; None when no labels were given.
    loss: Optional[torch.FloatTensor] = None
    # [Batch, num_models] predicted WER for each backend model.
    pred_wers: torch.FloatTensor = None  # Predicted WER for each model
class ASRRouterModel(PreTrainedModel):
    """
    Regression Router.
    Maps 384-dimensional Whisper encoder embeddings to one estimated WER
    (0.0+, unbounded) per backend model. Softplus keeps predictions
    non-negative while still allowing WER > 1.0.
    """
    config_class = ASRRouterConfig
    MODEL_ID_MAP = {0: "kyutai", 1: "granite", 2: "tiny_audio"}

    def __init__(self, config: ASRRouterConfig):
        super().__init__(config)
        # Small MLP head: input -> hidden -> intermediate -> one score per backend.
        layers = [
            nn.Linear(config.input_dim, config.hidden_dim),
            nn.GELU(),
            nn.LayerNorm(config.hidden_dim),  # Better for batch_size=1
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim, config.intermediate_dim),
            nn.GELU(),
            nn.LayerNorm(config.intermediate_dim),
            nn.Linear(config.intermediate_dim, config.num_models),
        ]
        self.network = nn.Sequential(*layers)
        self.post_init()

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: Optional[torch.Tensor] = None,  # Actual WERs from ground truth
    ) -> RouterOutput:
        """Predict per-backend WERs; compute MSE loss when labels are given."""
        # Softplus for unbounded positive WER (WER can exceed 1.0)
        pred_wers = F.softplus(self.network(embeddings))
        loss = F.mse_loss(pred_wers, labels) if labels is not None else None
        return RouterOutput(loss=loss, pred_wers=pred_wers)

    def predict_proba(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Get predicted WERs for each model (inference only, no grad)."""
        with torch.no_grad():
            return F.softplus(self.network(embeddings))
class RouterWithFeatureExtractor:
    """
    Production-ready router with attention pooling and memory optimizations.

    Pipeline: raw waveform -> Whisper log-mel features -> Whisper-tiny
    encoder hidden states -> attention-pooled embedding -> per-backend
    WER predictions from an ASRRouterModel.
    """
    def __init__(self, router: ASRRouterModel, device: str = "cpu"):
        """
        Args:
            router: Trained ASRRouterModel used to score backends.
            device: Torch device string (e.g. "cpu", "cuda").
        """
        self.device = device
        self.router = router.to(device)
        self.router.eval()
        # Attention pooling for variable-length sequences.
        # NOTE(review): this module is freshly constructed here with random
        # weights and no checkpoint is loaded for it in this file — confirm
        # trained pooling weights are restored elsewhere before relying on it.
        self.attention_pooling = AttentionPooling(input_dim=384).to(device)
        self.attention_pooling.eval()
        # Memory Optimization: Load full model, extract encoder, delete rest
        print("Loading Whisper Encoder...")
        full_whisper = WhisperModel.from_pretrained("openai/whisper-tiny")
        self.whisper_encoder = full_whisper.encoder.to(device)
        self.whisper_encoder.eval()
        # Drop the decoder so only the encoder stays resident in memory.
        del full_whisper.decoder
        del full_whisper
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        self.feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")

    def extract_features(self, waveform: torch.Tensor) -> torch.Tensor:
        """Extract embeddings using Attention Pooling for variable lengths.

        Args:
            waveform: 1D [Samples] or 2D [Batch, Samples] tensor of 16 kHz audio.
        Returns:
            [Batch, 384] attention-pooled encoder embeddings.
        """
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        # Convert batch tensor to list of 1D numpy arrays (required by WhisperFeatureExtractor)
        audio_np = [w.cpu().numpy() for w in waveform]
        inputs = self.feature_extractor(
            audio_np,
            sampling_rate=16000,
            return_tensors="pt",
            return_attention_mask=True
        )
        input_features = inputs.input_features.to(self.device)
        attention_mask = inputs.attention_mask.to(self.device)
        with torch.no_grad():
            last_hidden_state = self.whisper_encoder(input_features).last_hidden_state
            # Resize mask to match encoder output temporal dimension
            mask_resized = F.interpolate(
                attention_mask.unsqueeze(1).float(),
                size=last_hidden_state.shape[1],
                mode='nearest'
            ).squeeze(1)
            # Attention Pooling
            return self.attention_pooling(last_hidden_state, mask_resized)

    def predict(self, waveform: torch.Tensor) -> Dict:
        """Select the model with the lowest predicted WER.

        Returns:
            Dict with "selected_model" (backend name), "predicted_wers"
            (name -> float WER), and "confidence" (1 - best WER, clamped to >= 0).
        """
        embeddings = self.extract_features(waveform)
        with torch.no_grad():
            output = self.router(embeddings)
        pred_wers = output.pred_wers[0].cpu().numpy()
        # Build scores from the router's id->name map so backend naming stays
        # consistent with ASRRouterModel instead of duplicating literals here.
        scores = {
            name: float(pred_wers[idx])
            for idx, name in self.router.MODEL_ID_MAP.items()
        }
        best_model = min(scores.items(), key=lambda x: x[1])
        return {
            "selected_model": best_model[0],
            "predicted_wers": scores,
            "confidence": max(0.0, 1.0 - best_model[1])  # Clamp since WER can exceed 1.0
        }
# Register for auto classes so the custom "asr_router" architecture can be
# resolved through the transformers Auto* loading machinery.
ASRRouterConfig.register_for_auto_class()
ASRRouterModel.register_for_auto_class("AutoModel")