File size: 6,703 Bytes
d0469c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
"""
Router Model Architecture for Smart ASR Routing.

Regression-based approach: predicts WER for each backend model.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Dict

from transformers import PreTrainedModel, PretrainedConfig, WhisperModel, WhisperFeatureExtractor
from transformers.modeling_outputs import ModelOutput


class AttentionPooling(nn.Module):
    """Learnable attention pooling for variable-length sequences.

    Each time step is scored by a single linear layer followed by ``tanh``
    (so scores lie in [-1, 1]). The scores are softmax-normalised over the
    time axis and used as weights for a weighted sum of the input frames.
    """

    def __init__(self, input_dim: int):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(input_dim, 1),
            nn.Tanh()
        )

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: [Batch, Time, Dim]
            mask: [Batch, Time] (1 for valid, 0 for pad). Any dtype works
                (bool, int or float) as long as padded entries compare
                equal to 0.
        Returns:
            pooled: [Batch, Dim]
        """
        scores = self.attention(x)  # [Batch, Time, 1]

        if mask is not None:
            # Use the most negative value representable in the score dtype.
            # A fixed -1e9 is not representable in float16, and masked_fill
            # raises on the overflow. Padded frames end up with zero softmax
            # weight either way. If every frame is masked, all scores are
            # equal and the weights become uniform.
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask.unsqueeze(-1) == 0, fill_value)

        weights = F.softmax(scores, dim=1)  # [Batch, Time, 1]
        return torch.sum(x * weights, dim=1)  # [Batch, Dim]


class ASRRouterConfig(PretrainedConfig):
    """Hyper-parameters for the ASR regression router.

    Args:
        input_dim: Width of the pooled audio embedding fed to the router
            (384 corresponds to the whisper-tiny encoder).
        hidden_dim: Width of the first hidden layer.
        intermediate_dim: Width of the second hidden layer.
        dropout: Dropout probability (kept low because this is a
            regression head).
        num_models: How many backend models the router predicts a score for.
        **kwargs: Passed through unchanged to ``PretrainedConfig``.
    """

    model_type = "asr_router"

    def __init__(
        self,
        input_dim: int = 384,
        hidden_dim: int = 128,
        intermediate_dim: int = 64,
        dropout: float = 0.1,
        num_models: int = 3,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Layer widths of the MLP.
        self.input_dim, self.hidden_dim, self.intermediate_dim = (
            input_dim,
            hidden_dim,
            intermediate_dim,
        )

        # Regularisation strength and number of output columns.
        self.dropout = dropout
        self.num_models = num_models


@dataclass
class RouterOutput(ModelOutput):
    """Output container returned by ``ASRRouterModel.forward``.

    Attributes:
        loss: Mean-squared error between predicted and target WERs; ``None``
            when no labels were passed to ``forward``.
        pred_wers: Non-negative predicted WER for each backend model.
    """
    loss: Optional[torch.FloatTensor] = None  # Only populated when labels are given
    pred_wers: Optional[torch.FloatTensor] = None  # Predicted WER for each model


class ASRRouterModel(PreTrainedModel):
    """
    Regression Router.

    Input: pooled Whisper encoder embeddings of width ``config.input_dim``
    (384 for whisper-tiny).
    Output: Estimated WER (0.0+, unbounded) for each backend model, one value
    per backend (``config.num_models`` outputs).
    Uses Softplus activation to ensure non-negative outputs while allowing WER > 1.0.
    """
    config_class = ASRRouterConfig

    # Output column index -> backend name.
    MODEL_ID_MAP = {0: "kyutai", 1: "granite", 2: "tiny_audio"}

    def __init__(self, config: ASRRouterConfig):
        super().__init__(config)

        self.network = nn.Sequential(
            nn.Linear(config.input_dim, config.hidden_dim),
            nn.GELU(),
            nn.LayerNorm(config.hidden_dim),  # Better for batch_size=1
            nn.Dropout(config.dropout),

            nn.Linear(config.hidden_dim, config.intermediate_dim),
            nn.GELU(),
            nn.LayerNorm(config.intermediate_dim),

            nn.Linear(config.intermediate_dim, config.num_models)
        )

        self.post_init()

    def _predict_wers(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Run the MLP and map its raw outputs to non-negative WER estimates."""
        # Softplus for unbounded positive WER (WER can exceed 1.0)
        return F.softplus(self.network(embeddings))

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: Optional[torch.Tensor] = None,  # Actual WERs from ground truth
    ) -> RouterOutput:
        """Predict per-backend WERs and, if targets are given, the MSE loss.

        Args:
            embeddings: Pooled encoder embeddings whose last dimension is
                ``config.input_dim``.
            labels: Optional ground-truth WERs with the same shape as the
                predictions.
        Returns:
            RouterOutput with ``pred_wers`` and, when ``labels`` is given,
            ``loss``.
        """
        pred_wers = self._predict_wers(embeddings)

        loss = None
        if labels is not None:
            # Targets often arrive as float64 (e.g. built from NumPy) or on a
            # different device. Align them with the predictions so mse_loss
            # and its backward pass do not fail on a dtype/device mismatch.
            labels = labels.to(device=pred_wers.device, dtype=pred_wers.dtype)
            loss = F.mse_loss(pred_wers, labels)

        return RouterOutput(loss=loss, pred_wers=pred_wers)

    def predict_proba(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Get predicted WERs for each model without tracking gradients.

        Despite the name, the values are WER estimates (>= 0, possibly > 1),
        not probabilities. Dropout follows the module's current train/eval
        mode.
        """
        with torch.no_grad():
            return self._predict_wers(embeddings)


class RouterWithFeatureExtractor:
    """
    Production-ready router with attention pooling and memory optimizations.

    Pipeline: raw waveform -> WhisperFeatureExtractor -> Whisper encoder ->
    masked AttentionPooling -> ASRRouterModel. ``predict`` then picks the
    backend with the lowest predicted WER.

    NOTE(review): unless a pre-built ``attention_pooling`` module is passed
    in, the pooling layer is freshly constructed here. It therefore has
    PyTorch's default random initialisation, and embeddings differ between
    instances unless the global RNG is seeded. Confirm this matches how the
    router was trained.
    """
    def __init__(
        self,
        router: ASRRouterModel,
        device: str = "cpu",
        *,
        whisper_model_name: str = "openai/whisper-tiny",
        attention_pooling: Optional[AttentionPooling] = None,
    ):
        """
        Args:
            router: Regression router; moved to ``device`` and switched to
                eval mode.
            device: Torch device string used for every module.
            whisper_model_name: Checkpoint used for both the encoder and the
                feature extractor.
            attention_pooling: Optional pre-built (e.g. trained) pooling
                module. When omitted, a new one sized to
                ``router.config.input_dim`` is created.
        """
        self.device = device
        self.router = router.to(device)
        self.router.eval()

        # Attention pooling for variable-length sequences. Its output width
        # must equal the router's input width, so size it from the config
        # instead of hard-coding 384.
        if attention_pooling is None:
            attention_pooling = AttentionPooling(input_dim=router.config.input_dim)
        self.attention_pooling = attention_pooling.to(device)
        self.attention_pooling.eval()

        # Memory Optimization: load the full model, keep the encoder and
        # drop the decoder.
        print("Loading Whisper Encoder...")
        full_whisper = WhisperModel.from_pretrained(whisper_model_name)
        self.whisper_encoder = full_whisper.encoder.to(device)
        self.whisper_encoder.eval()

        del full_whisper.decoder
        del full_whisper
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.feature_extractor = WhisperFeatureExtractor.from_pretrained(whisper_model_name)

    def extract_features(self, waveform: torch.Tensor) -> torch.Tensor:
        """Extract embeddings using Attention Pooling for variable lengths.

        Args:
            waveform: 16 kHz audio, either ``[Samples]`` or
                ``[Batch, Samples]``.
        Returns:
            ``[Batch, Dim]`` pooled encoder embeddings.
        """
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        # WhisperFeatureExtractor requires a list of 1D numpy arrays.
        # detach() lets callers pass tensors that require grad.
        audio_np = [w.detach().cpu().numpy() for w in waveform]

        inputs = self.feature_extractor(
            audio_np,
            sampling_rate=16000,
            return_tensors="pt",
            return_attention_mask=True
        )

        input_features = inputs.input_features.to(self.device)
        attention_mask = inputs.attention_mask.to(self.device)

        with torch.no_grad():
            last_hidden_state = self.whisper_encoder(input_features).last_hidden_state

            # The feature-extractor mask is at input-frame resolution, which
            # need not match the encoder's output length. Nearest-neighbour
            # resampling gives each encoder step the validity of a
            # corresponding input frame.
            mask_resized = F.interpolate(
                attention_mask.unsqueeze(1).float(),
                size=last_hidden_state.shape[1],
                mode='nearest'
            ).squeeze(1)

            # Attention Pooling over valid (unpadded) steps only
            return self.attention_pooling(last_hidden_state, mask_resized)

    def predict(self, waveform: torch.Tensor) -> Dict:
        """Select the model with the lowest predicted WER.

        Only the first item of a batched ``waveform`` is scored.

        Returns:
            dict with ``selected_model`` (backend name), ``predicted_wers``
            (backend name -> predicted WER) and ``confidence``
            (``1 - best predicted WER``, clamped at 0).
        """
        embeddings = self.extract_features(waveform)

        with torch.no_grad():
            output = self.router(embeddings)
            pred_wers = output.pred_wers[0].cpu().numpy()

        # Name the output columns via the router's own index -> backend map
        # so the names here cannot drift from the model definition.
        scores = {
            name: float(pred_wers[index])
            for index, name in self.router.MODEL_ID_MAP.items()
        }

        # On ties, min() keeps the earliest entry of MODEL_ID_MAP.
        best_name, best_wer = min(scores.items(), key=lambda item: item[1])

        return {
            "selected_model": best_name,
            "predicted_wers": scores,
            "confidence": max(0.0, 1.0 - best_wer)  # Clamp since WER can exceed 1.0
        }


# Register the custom classes with transformers' auto-class machinery at
# import time: the config with the method's default auto class (no argument
# given), and the model under "AutoModel".
# NOTE(review): presumably this lets a saved checkpoint be reloaded through
# the Auto* API with remote code enabled — confirm against how it is loaded.
ASRRouterConfig.register_for_auto_class()
ASRRouterModel.register_for_auto_class("AutoModel")