mazesmazes commited on
Commit
52fae00
·
verified ·
1 Parent(s): 3cc13cb

Training in progress - step 2000

Browse files
Files changed (9) hide show
  1. README.md +199 -0
  2. alignment.py +286 -0
  3. asr_config.py +262 -0
  4. asr_modeling.py +1069 -0
  5. asr_pipeline.py +368 -0
  6. asr_processing.py +132 -0
  7. diarization.py +730 -0
  8. preprocessor_config.json +19 -0
  9. projectors.py +493 -0
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
alignment.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Forced alignment for word-level timestamps using Wav2Vec2."""
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
+ def _get_device() -> str:
8
+ """Get best available device for non-transformers models."""
9
+ if torch.cuda.is_available():
10
+ return "cuda"
11
+ if torch.backends.mps.is_available():
12
+ return "mps"
13
+ return "cpu"
14
+
15
+
16
+ class ForcedAligner:
17
+ """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2.
18
+
19
+ Uses Viterbi trellis algorithm for optimal alignment path finding.
20
+ """
21
+
22
+ _bundle = None
23
+ _model = None
24
+ _labels = None
25
+ _dictionary = None
26
+
27
+ @classmethod
28
+ def get_instance(cls, device: str = "cuda"):
29
+ """Get or create the forced alignment model (singleton).
30
+
31
+ Args:
32
+ device: Device to run model on ("cuda" or "cpu")
33
+
34
+ Returns:
35
+ Tuple of (model, labels, dictionary)
36
+ """
37
+ if cls._model is None:
38
+ import torchaudio
39
+
40
+ cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
41
+ cls._model = cls._bundle.get_model().to(device)
42
+ cls._model.eval()
43
+ cls._labels = cls._bundle.get_labels()
44
+ cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
45
+ return cls._model, cls._labels, cls._dictionary
46
+
47
+ @staticmethod
48
+ def _get_trellis(emission: torch.Tensor, tokens: list[int], blank_id: int = 0) -> torch.Tensor:
49
+ """Build trellis for forced alignment using forward algorithm.
50
+
51
+ The trellis[t, j] represents the log probability of the best path that
52
+ aligns the first j tokens to the first t frames.
53
+
54
+ Args:
55
+ emission: Log-softmax emission matrix of shape (num_frames, num_classes)
56
+ tokens: List of target token indices
57
+ blank_id: Index of the blank/CTC token (default 0)
58
+
59
+ Returns:
60
+ Trellis matrix of shape (num_frames + 1, num_tokens + 1)
61
+ """
62
+ num_frames = emission.size(0)
63
+ num_tokens = len(tokens)
64
+
65
+ trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
66
+ trellis[0, 0] = 0
67
+
68
+ for t in range(num_frames):
69
+ for j in range(num_tokens + 1):
70
+ # Stay: emit blank and stay at j tokens
71
+ stay = trellis[t, j] + emission[t, blank_id]
72
+
73
+ # Move: emit token j and advance to j+1 tokens
74
+ move = trellis[t, j - 1] + emission[t, tokens[j - 1]] if j > 0 else -float("inf")
75
+
76
+ trellis[t + 1, j] = max(stay, move) # Viterbi: take best path
77
+
78
+ return trellis
79
+
80
+ @staticmethod
81
+ def _backtrack(
82
+ trellis: torch.Tensor, emission: torch.Tensor, tokens: list[int], blank_id: int = 0
83
+ ) -> list[tuple[int, float, float]]:
84
+ """Backtrack through trellis to find optimal forced monotonic alignment.
85
+
86
+ Guarantees:
87
+ - All tokens are emitted exactly once
88
+ - Strictly monotonic: each token's frames come after previous token's
89
+ - No frame skipping or token teleporting
90
+
91
+ Returns list of (token_id, start_frame, end_frame) for each token.
92
+ """
93
+ num_frames = emission.size(0)
94
+ num_tokens = len(tokens)
95
+
96
+ if num_tokens == 0:
97
+ return []
98
+
99
+ # Find the best ending point (should be at num_tokens)
100
+ # But verify trellis reached a valid state
101
+ if trellis[num_frames, num_tokens] == -float("inf"):
102
+ # Alignment failed - fall back to uniform distribution
103
+ frames_per_token = num_frames / num_tokens
104
+ return [
105
+ (tokens[i], i * frames_per_token, (i + 1) * frames_per_token)
106
+ for i in range(num_tokens)
107
+ ]
108
+
109
+ # Backtrack: find where each token transition occurred
110
+ # path[i] = frame where token i was first emitted
111
+ token_frames: list[list[int]] = [[] for _ in range(num_tokens)]
112
+
113
+ t = num_frames
114
+ j = num_tokens
115
+
116
+ while t > 0 and j > 0:
117
+ # Check: did we transition from j-1 to j at frame t-1?
118
+ stay_score = trellis[t - 1, j] + emission[t - 1, blank_id]
119
+ move_score = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
120
+
121
+ if move_score >= stay_score:
122
+ # Token j-1 was emitted at frame t-1
123
+ token_frames[j - 1].append(t - 1)
124
+ j -= 1
125
+ t -= 1
126
+
127
+ # Handle any remaining tokens at the start (edge case)
128
+ while j > 0:
129
+ token_frames[j - 1].append(0)
130
+ j -= 1
131
+
132
+ # We appended in reverse-time order; restore monotonic order
133
+ for frames in token_frames:
134
+ frames.reverse()
135
+
136
+ # Convert to spans
137
+ token_spans: list[tuple[int, float, float]] = []
138
+ for token_idx, frames in enumerate(token_frames):
139
+ if not frames:
140
+ # Token never emitted - assign minimal span after previous
141
+ if token_spans:
142
+ prev_end = token_spans[-1][2]
143
+ frames = [int(prev_end)]
144
+ else:
145
+ frames = [0]
146
+
147
+ token_id = tokens[token_idx]
148
+ start_frame = float(min(frames))
149
+ end_frame = float(max(frames)) + 1.0
150
+ token_spans.append((token_id, start_frame, end_frame))
151
+
152
+ return token_spans
153
+
154
+ # Offset compensation for Wav2Vec2-BASE systematic bias (in seconds)
155
+ # Calibrated on librispeech-alignments dataset
156
+ START_OFFSET = 0.06 # Subtract from start times (shift earlier)
157
+ END_OFFSET = -0.03 # Add to end times (shift later)
158
+
159
+ @classmethod
160
+ def align(
161
+ cls,
162
+ audio: np.ndarray,
163
+ text: str,
164
+ sample_rate: int = 16000,
165
+ _language: str = "eng",
166
+ _batch_size: int = 16,
167
+ ) -> list[dict]:
168
+ """Align transcript to audio and return word-level timestamps.
169
+
170
+ Uses Viterbi trellis algorithm for optimal forced alignment.
171
+
172
+ Args:
173
+ audio: Audio waveform as numpy array
174
+ text: Transcript text to align
175
+ sample_rate: Audio sample rate (default 16000)
176
+ _language: ISO-639-3 language code (default "eng" for English, unused)
177
+ _batch_size: Batch size for alignment model (unused)
178
+
179
+ Returns:
180
+ List of dicts with 'word', 'start', 'end' keys
181
+ """
182
+ import torchaudio
183
+
184
+ device = _get_device()
185
+ model, _labels, dictionary = cls.get_instance(device)
186
+ assert cls._bundle is not None and dictionary is not None # Initialized by get_instance
187
+
188
+ # Convert audio to tensor (copy to ensure array is writable)
189
+ if isinstance(audio, np.ndarray):
190
+ waveform = torch.from_numpy(audio.copy()).float()
191
+ else:
192
+ waveform = audio.clone().float()
193
+
194
+ # Ensure 2D (channels, time)
195
+ if waveform.dim() == 1:
196
+ waveform = waveform.unsqueeze(0)
197
+
198
+ # Resample if needed (wav2vec2 expects 16kHz)
199
+ if sample_rate != cls._bundle.sample_rate:
200
+ waveform = torchaudio.functional.resample(
201
+ waveform, sample_rate, cls._bundle.sample_rate
202
+ )
203
+
204
+ waveform = waveform.to(device)
205
+
206
+ # Get emissions from model
207
+ with torch.inference_mode():
208
+ emissions, _ = model(waveform)
209
+ emissions = torch.log_softmax(emissions, dim=-1)
210
+
211
+ emission = emissions[0].cpu()
212
+
213
+ # Normalize text: uppercase, keep only valid characters
214
+ transcript = text.upper()
215
+
216
+ # Build tokens from transcript (including word separators)
217
+ tokens = []
218
+ for char in transcript:
219
+ if char in dictionary:
220
+ tokens.append(dictionary[char])
221
+ elif char == " ":
222
+ tokens.append(dictionary.get("|", dictionary.get(" ", 0)))
223
+
224
+ if not tokens:
225
+ return []
226
+
227
+ # Build Viterbi trellis and backtrack for optimal path
228
+ trellis = cls._get_trellis(emission, tokens, blank_id=0)
229
+ alignment_path = cls._backtrack(trellis, emission, tokens, blank_id=0)
230
+
231
+ # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
232
+ frame_duration = 320 / cls._bundle.sample_rate
233
+
234
+ # Apply separate offset compensation for start/end (Wav2Vec2 systematic bias)
235
+ start_offset = cls.START_OFFSET
236
+ end_offset = cls.END_OFFSET
237
+
238
+ # Group aligned tokens into words based on pipe separator
239
+ words = text.split()
240
+ word_timestamps = []
241
+ current_word_start = None
242
+ current_word_end = None
243
+ word_idx = 0
244
+ separator_id = dictionary.get("|", dictionary.get(" ", 0))
245
+
246
+ for token_id, start_frame, end_frame in alignment_path:
247
+ if token_id == separator_id: # Word separator
248
+ if (
249
+ current_word_start is not None
250
+ and current_word_end is not None
251
+ and word_idx < len(words)
252
+ ):
253
+ start_time = max(0.0, current_word_start * frame_duration - start_offset)
254
+ end_time = max(0.0, current_word_end * frame_duration - end_offset)
255
+ word_timestamps.append(
256
+ {
257
+ "word": words[word_idx],
258
+ "start": start_time,
259
+ "end": end_time,
260
+ }
261
+ )
262
+ word_idx += 1
263
+ current_word_start = None
264
+ current_word_end = None
265
+ else:
266
+ if current_word_start is None:
267
+ current_word_start = start_frame
268
+ current_word_end = end_frame
269
+
270
+ # Don't forget the last word
271
+ if (
272
+ current_word_start is not None
273
+ and current_word_end is not None
274
+ and word_idx < len(words)
275
+ ):
276
+ start_time = max(0.0, current_word_start * frame_duration - start_offset)
277
+ end_time = max(0.0, current_word_end * frame_duration - end_offset)
278
+ word_timestamps.append(
279
+ {
280
+ "word": words[word_idx],
281
+ "start": start_time,
282
+ "end": end_time,
283
+ }
284
+ )
285
+
286
+ return word_timestamps
asr_config.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import transformers
4
+
5
+ # Default conv layers for Whisper/GLM-ASR audio encoders: [(pad, kernel, stride), ...]
6
+ DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]
7
+
8
+
9
+ def compute_encoder_output_length(mel_length, conv_layers=None):
10
+ """Apply encoder conv layer formulas to compute output length.
11
+
12
+ Works with both Python ints and torch tensors of mel lengths; the formula
13
+ `(L + 2*p - (k-1) - 1) // s + 1` per layer is identical for both.
14
+ """
15
+ layers = conv_layers if conv_layers is not None else DEFAULT_ENCODER_CONV_LAYERS
16
+ length = mel_length
17
+ for padding, kernel_size, stride in layers:
18
+ length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
19
+ return length
20
+
21
+
22
+ class ASRConfig(transformers.PretrainedConfig):
23
+ """Configuration class for the ASR model.
24
+
25
+ This config combines settings for:
26
+ - Audio encoder (GLM-ASR/Whisper)
27
+ - Text decoder (Qwen)
28
+ - Projector (MLP, MOSA, MoE, QFormer)
29
+ - Generation parameters
30
+ - Training options (LoRA)
31
+ """
32
+
33
+ model_type = "asr_model"
34
+ is_composition = True
35
+
36
+ def __init__(
37
+ self,
38
+ audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
39
+ text_model_id: str = "Qwen/Qwen3-0.6B",
40
+ attn_implementation: str = "flash_attention_2",
41
+ model_dtype: str = "bfloat16",
42
+ num_beams: Optional[int] = None,
43
+ system_prompt: str = "You are a helpful assistant.",
44
+ encoder_dim: Optional[int] = None,
45
+ llm_dim: Optional[int] = None,
46
+ # Encoder conv layers: list of (padding, kernel_size, stride) tuples
47
+ # Default is Whisper/GLM-ASR structure: conv1(k=3,s=1,p=1) + conv2(k=3,s=2,p=1)
48
+ encoder_conv_layers: Optional[list] = None,
49
+ audio_sample_rate: int = 16000,
50
+ projector_pool_stride: int = 4,
51
+ downsample_rate: int = 5, # Granite default
52
+ projector_hidden_dim: Optional[int] = None,
53
+ projector_type: str = "mlp", # "mlp", "mosa", "moe", "qformer"
54
+ projector_dropout: float = 0.0,
55
+ # Label smoothing applied inside the LM's loss function (not HF Trainer's
56
+ # LabelSmoother). Train-only — ASRModel.forward zeros it on eval. Routing
57
+ # smoothing through the loss_function flows through liger's fused linear
58
+ # CE when apply_liger_kernel_to_qwen3() is active, avoiding the
59
+ # (B,T,V) fp32 log_softmax materialization that the HF LabelSmoother
60
+ # path requires (~15GB at B=50/V=152k on Qwen3-0.6B).
61
+ label_smoothing: float = 0.0,
62
+ # MoE-specific configuration
63
+ num_experts: int = 4, # Number of experts in MoE projectors
64
+ num_experts_per_tok: int = 2, # Top-k experts per token
65
+ router_aux_loss_coef: float = 0.01, # Auxiliary loss coefficient for load balancing
66
+ # QFormer-specific configuration (Granite defaults)
67
+ qformer_window_size: int = 15, # Window size for QFormer processing
68
+ qformer_hidden_size: Optional[int] = None, # QFormer hidden size (defaults to encoder_dim)
69
+ qformer_num_layers: int = 2, # Number of QFormer transformer layers
70
+ qformer_num_heads: int = 16, # Number of attention heads in QFormer
71
+ qformer_intermediate_size: Optional[int] = None, # FFN size (defaults to 4x hidden)
72
+ # LoRA configuration (for Stage 2 fine-tuning)
73
+ use_lora: bool = False,
74
+ lora_rank: int = 8, # SALMONN default
75
+ lora_alpha: int = 32, # SALMONN default (scaling factor 4.0)
76
+ lora_dropout: float = 0.0,
77
+ lora_target_modules: Optional[list] = None, # Default: all linear layers
78
+ freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
79
+ freeze_language_model: bool = True, # False = full decoder fine-tuning
80
+ freeze_text_embed_tokens: bool = False,
81
+ # Audio encoder is frozen by default — the published recipe treats
82
+ # GLM-ASR-Nano as a fixed feature extractor. Setting this to False
83
+ # makes the encoder trainable; pair with `encoder_learning_rate` in
84
+ # the training config to avoid destroying pretrained encoder weights
85
+ # at the projector/decoder LR.
86
+ freeze_audio_encoder: bool = True,
87
+ # SpecAugment on mel input (training-only), parameters match
88
+ # transformers' WhisperConfig / Wav2Vec2 conventions. Most relevant
89
+ # when the encoder is trainable (`freeze_audio_encoder=False`) —
90
+ # without augmentation the encoder sees identical mel inputs on
91
+ # every visit and overfits fast. Standard for ASR encoder fine-
92
+ # tuning (Whisper, Conformer, wav2vec2 all use it). Applied to
93
+ # log-mel input where zero is in-distribution (silence);
94
+ # structurally different from the prior encoder-output ZM which
95
+ # was removed because zero was OOD for the encoder's emission
96
+ # distribution. Uses `_compute_mask_indices` from
97
+ # transformers.models.whisper.modeling_whisper — the same helper
98
+ # Whisper itself uses, vectorized over the batch and torch.compile
99
+ # compatible. Default values match Whisper's defaults.
100
+ apply_spec_augment: bool = False,
101
+ mask_time_prob: float = 0.05,
102
+ mask_time_length: int = 10,
103
+ mask_time_min_masks: int = 2,
104
+ mask_feature_prob: float = 0.0,
105
+ mask_feature_length: int = 10,
106
+ mask_feature_min_masks: int = 0,
107
+ do_sample: bool = False,
108
+ temperature: Optional[float] = None,
109
+ top_p: Optional[float] = None,
110
+ top_k: Optional[int] = None,
111
+ max_new_tokens: Optional[int] = None,
112
+ min_new_tokens: Optional[int] = None,
113
+ repetition_penalty: Optional[float] = None,
114
+ length_penalty: Optional[float] = None,
115
+ no_repeat_ngram_size: Optional[int] = None,
116
+ use_cache: Optional[bool] = None,
117
+ **kwargs,
118
+ ):
119
+ """Initialize ASR model configuration.
120
+
121
+ Args:
122
+ audio_model_id: HuggingFace model ID for audio encoder (GLM-ASR/Whisper)
123
+ text_model_id: HuggingFace model ID for text decoder (Qwen)
124
+ attn_implementation: Attention implementation ("flash_attention_2", "sdpa", "eager")
125
+ model_dtype: Model dtype ("bfloat16", "float16", "float32")
126
+ projector_type: Projector architecture ("mlp", "mosa", "moe", "qformer")
127
+ use_lora: Enable LoRA adapters for Stage 2 fine-tuning
128
+ """
129
+ # Set default generation parameters (greedy decoding only).
130
+ # Applied via setattr below — keeping these out of kwargs so they
131
+ # don't get re-overwritten by super().__init__(**kwargs) at the end.
132
+ generation_defaults = {
133
+ "num_beams": 1,
134
+ "max_new_tokens": 128,
135
+ "min_new_tokens": 0,
136
+ "repetition_penalty": 1.0,
137
+ "length_penalty": 1.0,
138
+ "no_repeat_ngram_size": 0,
139
+ "use_cache": True,
140
+ }
141
+
142
+ self.audio_model_id = audio_model_id
143
+ self.text_model_id = text_model_id
144
+ self.attn_implementation = attn_implementation
145
+ self.model_dtype = model_dtype
146
+ self.system_prompt = system_prompt
147
+ self.encoder_dim = encoder_dim
148
+ self.llm_dim = llm_dim
149
+ self.encoder_conv_layers = encoder_conv_layers or DEFAULT_ENCODER_CONV_LAYERS
150
+ self.audio_sample_rate = audio_sample_rate
151
+ self.projector_pool_stride = projector_pool_stride
152
+ self.downsample_rate = downsample_rate
153
+ self.projector_hidden_dim = projector_hidden_dim
154
+ self.projector_type = projector_type
155
+ self.projector_dropout = projector_dropout
156
+ self.label_smoothing = label_smoothing
157
+ # MoE-specific configuration
158
+ self.num_experts = num_experts
159
+ self.num_experts_per_tok = num_experts_per_tok
160
+ self.router_aux_loss_coef = router_aux_loss_coef
161
+ # QFormer-specific configuration
162
+ self.qformer_window_size = qformer_window_size
163
+ self.qformer_hidden_size = qformer_hidden_size
164
+ self.qformer_num_layers = qformer_num_layers
165
+ self.qformer_num_heads = qformer_num_heads
166
+ self.qformer_intermediate_size = qformer_intermediate_size
167
+ # LoRA configuration
168
+ self.use_lora = use_lora
169
+ self.lora_rank = lora_rank
170
+ self.lora_alpha = lora_alpha
171
+ self.lora_dropout = lora_dropout
172
+ self.lora_target_modules = lora_target_modules or [
173
+ "q_proj",
174
+ "k_proj",
175
+ "v_proj",
176
+ "o_proj",
177
+ "gate_proj",
178
+ "up_proj",
179
+ "down_proj",
180
+ ]
181
+ self.freeze_projector = freeze_projector
182
+ self.freeze_language_model = freeze_language_model
183
+ self.freeze_text_embed_tokens = freeze_text_embed_tokens
184
+ self.freeze_audio_encoder = freeze_audio_encoder
185
+ self.apply_spec_augment = apply_spec_augment
186
+ self.mask_time_prob = mask_time_prob
187
+ self.mask_time_length = mask_time_length
188
+ self.mask_time_min_masks = mask_time_min_masks
189
+ self.mask_feature_prob = mask_feature_prob
190
+ self.mask_feature_length = mask_feature_length
191
+ self.mask_feature_min_masks = mask_feature_min_masks
192
+
193
+ explicit_generation_args = {
194
+ "num_beams": num_beams,
195
+ "max_new_tokens": max_new_tokens,
196
+ "min_new_tokens": min_new_tokens,
197
+ "repetition_penalty": repetition_penalty,
198
+ "length_penalty": length_penalty,
199
+ "no_repeat_ngram_size": no_repeat_ngram_size,
200
+ "use_cache": use_cache,
201
+ }
202
+ for key, default in generation_defaults.items():
203
+ value = explicit_generation_args[key]
204
+ setattr(self, key, value if value is not None else default)
205
+ self.do_sample = do_sample
206
+ self.temperature = temperature
207
+ self.top_p = top_p
208
+ self.top_k = top_k
209
+
210
+ if "audio_config" not in kwargs:
211
+ self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
212
+ # Override dtype to match model_dtype
213
+ self.audio_config.dtype = model_dtype
214
+ else:
215
+ self.audio_config = kwargs.pop("audio_config")
216
+
217
+ if "text_config" not in kwargs:
218
+ self.text_config = transformers.AutoConfig.from_pretrained(
219
+ text_model_id, trust_remote_code=True
220
+ )
221
+ # Override dtype to match model_dtype
222
+ self.text_config.dtype = model_dtype
223
+ else:
224
+ self.text_config = kwargs.pop("text_config")
225
+
226
+ if isinstance(self.text_config, dict):
227
+ # Reconstruct config from dict using the model_type stored in the dict
228
+ model_type = self.text_config["model_type"]
229
+ config_class = transformers.AutoConfig.for_model(model_type).__class__
230
+ self.text_config = config_class(**self.text_config)
231
+
232
+ if isinstance(self.audio_config, dict):
233
+ model_type = self.audio_config.get("model_type")
234
+ if model_type:
235
+ config_class = transformers.AutoConfig.for_model(model_type).__class__
236
+ self.audio_config = config_class(**self.audio_config)
237
+
238
+ super().__init__(**kwargs)
239
+
240
+ # Point encoder to audio_config so pipeline uses correct feature extractor
241
+ # The pipeline looks for config.encoder._name_or_path for feature extractor
242
+ self.encoder = self.audio_config
243
+
244
+ self.auto_map = {
245
+ "AutoConfig": "asr_config.ASRConfig",
246
+ "AutoModel": "asr_modeling.ASRModel",
247
+ "AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
248
+ "AutoProcessor": "asr_processing.ASRProcessor",
249
+ }
250
+ self.custom_pipelines = {
251
+ "automatic-speech-recognition": {
252
+ "impl": "asr_pipeline.ASRPipeline",
253
+ "pt": ["AutoModelForSpeechSeq2Seq"],
254
+ "tf": [],
255
+ "type": "audio",
256
+ }
257
+ }
258
+ self.architectures = ["ASRModel"]
259
+ self.pipeline_tag = "automatic-speech-recognition"
260
+
261
+
262
+ transformers.AutoConfig.register("asr_model", ASRConfig)
asr_modeling.py ADDED
@@ -0,0 +1,1069 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+ from threading import Thread
4
+ from typing import Iterator, Optional, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F # noqa: N812
9
+ from transformers import (
10
+ AutoModel,
11
+ AutoModelForCausalLM,
12
+ AutoTokenizer,
13
+ PreTrainedModel,
14
+ TextIteratorStreamer,
15
+ )
16
+ from transformers.generation import GenerationMixin
17
+ from transformers.modeling_outputs import CausalLMOutputWithPast
18
+
19
+ try:
20
+ from .asr_config import ASRConfig, compute_encoder_output_length
21
+ from .projectors import PROJECTOR_CLASSES
22
+ except ImportError:
23
+ from asr_config import ASRConfig, compute_encoder_output_length # type: ignore[no-redef]
24
+ from projectors import PROJECTOR_CLASSES # type: ignore[no-redef]
25
+
26
+
27
+ def _resolve_attn_implementation(requested: Optional[str]) -> Optional[str]:
28
+ """Coerce flash_attention_2 to sdpa when CUDA isn't available.
29
+
30
+ FA2 is CUDA-only. On MPS/CPU, requesting it either errors at load or
31
+ silently falls back to a slower path; either way the user pays the FA2
32
+ install + import cost for no win. Coerce here so a saved config that
33
+ pins flash_attention_2 still loads on Mac / CPU-only Linux boxes.
34
+ """
35
+ if requested == "flash_attention_2" and not torch.cuda.is_available():
36
+ return "sdpa"
37
+ return requested
38
+
39
+
40
+ def _gather_audio_embeds(audio_embeds: torch.Tensor, token_counts: torch.Tensor) -> torch.Tensor:
41
+ """Flatten per-sample audio embeddings into a packed tensor.
42
+
43
+ For each row i, takes the first ``token_counts[i]`` rows of
44
+ ``audio_embeds[i]`` and concatenates them. If any token count exceeds
45
+ ``audio_embeds.shape[1]``, the deficit is zero-padded.
46
+
47
+ Equivalent to a per-sample slice/cat loop but with O(1) host-device
48
+ syncs per call (one ``max().item()``) instead of one per sample.
49
+ """
50
+ _, max_len, _ = audio_embeds.shape
51
+ needed = int(token_counts.max().item())
52
+ if needed > max_len:
53
+ audio_embeds = F.pad(audio_embeds, (0, 0, 0, needed - max_len))
54
+ max_len = needed
55
+ indices = torch.arange(max_len, device=audio_embeds.device).unsqueeze(0)
56
+ mask = indices < token_counts.unsqueeze(1)
57
+ return audio_embeds[mask]
58
+
59
+
60
+ class ASRModel(PreTrainedModel, GenerationMixin):
61
+ """Audio-to-text model combining an audio encoder, projector, and language model."""
62
+
63
+ config_class = ASRConfig
64
+ base_model_prefix = "model"
65
+ main_input_name = "input_features"
66
+ _supports_flash_attn_2 = True
67
+ supports_gradient_checkpointing = True
68
+ _is_loading_from_pretrained: bool = False
69
+
70
+ TRANSCRIBE_PROMPT = "Transcribe the speech to text"
71
+
72
+ @classmethod
73
+ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) -> "ASRModel":
74
+ """Load model from pretrained, handling device placement correctly."""
75
+ from safetensors.torch import load_file
76
+ from transformers.utils.hub import cached_file
77
+
78
+ config = kwargs.pop("config", None)
79
+ if config is None:
80
+ config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
81
+
82
+ # Set flag to avoid device_map="auto" in sub-model loaders
83
+ cls._is_loading_from_pretrained = True
84
+
85
+ try:
86
+ model = cls(config, **kwargs)
87
+
88
+ # Load projector weights from safetensors
89
+ subfolder = kwargs.get("subfolder")
90
+ revision = kwargs.get("revision")
91
+ cache_kwargs = {}
92
+ if subfolder:
93
+ cache_kwargs["subfolder"] = subfolder
94
+ if revision:
95
+ cache_kwargs["revision"] = revision
96
+
97
+ model_file = cached_file(
98
+ pretrained_model_name_or_path,
99
+ "model.safetensors",
100
+ _raise_exceptions_for_missing_entries=False,
101
+ **cache_kwargs,
102
+ )
103
+
104
+ if model_file is not None:
105
+ state_dict = load_file(model_file)
106
+ model.load_state_dict(state_dict, strict=False)
107
+
108
+ # Load LoRA adapters if use_lora is enabled
109
+ if getattr(config, "use_lora", False):
110
+ # Check for adapter_config.json (required by PEFT to load adapters)
111
+ adapter_config_file = cached_file(
112
+ pretrained_model_name_or_path,
113
+ "adapter_config.json",
114
+ _raise_exceptions_for_missing_entries=False,
115
+ **cache_kwargs,
116
+ )
117
+ if adapter_config_file is not None:
118
+ # Load saved adapter weights using the original repo_id/path
119
+ # PEFT handles Hub downloads and caching internally
120
+ from peft import PeftModel
121
+
122
+ model.language_model = PeftModel.from_pretrained(
123
+ model.language_model,
124
+ pretrained_model_name_or_path,
125
+ is_trainable=True,
126
+ **cache_kwargs,
127
+ )
128
+ else:
129
+ # No saved adapters - initialize fresh LLM LoRA for training
130
+ from peft import LoraConfig, get_peft_model
131
+
132
+ lora_config = LoraConfig(
133
+ r=config.lora_rank,
134
+ lora_alpha=config.lora_alpha,
135
+ target_modules=config.lora_target_modules,
136
+ lora_dropout=config.lora_dropout,
137
+ bias="none",
138
+ task_type="CAUSAL_LM",
139
+ )
140
+ model.language_model = get_peft_model(model.language_model, lora_config)
141
+
142
+ return model
143
+ finally:
144
+ cls._is_loading_from_pretrained = False
145
+
146
+ def __init__(self, config: ASRConfig, **kwargs) -> None:
147
+ super().__init__(config)
148
+
149
+ self.system_prompt = config.system_prompt
150
+ target_dtype = getattr(torch, config.model_dtype)
151
+
152
+ # Audio encoder (frozen)
153
+ self.audio_tower = self._load_audio_encoder(config, target_dtype)
154
+
155
+ # Language model (frozen)
156
+ self.language_model = self._load_language_model(config, target_dtype)
157
+
158
+ # Initialize tokenizer and special tokens
159
+ self._init_tokenizer(config)
160
+
161
+ # Set up generation config with greedy decoding defaults
162
+ self.generation_config = self.language_model.generation_config
163
+ self.generation_config.max_new_tokens = config.max_new_tokens
164
+ self.generation_config.min_new_tokens = config.min_new_tokens
165
+ self.generation_config.num_beams = config.num_beams
166
+ self.generation_config.do_sample = config.do_sample
167
+ # Set sampling params from config (None means use model defaults)
168
+ self.generation_config.temperature = config.temperature
169
+ self.generation_config.top_p = config.top_p
170
+ self.generation_config.top_k = config.top_k
171
+ self.generation_config.use_cache = config.use_cache
172
+ self.generation_config.length_penalty = config.length_penalty
173
+ self.generation_config.repetition_penalty = config.repetition_penalty
174
+ self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
175
+ # Set EOS tokens, filtering out any that don't exist in the tokenizer
176
+ eos_candidates = [
177
+ self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
178
+ self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
179
+ ]
180
+ self.generation_config.eos_token_id = [t for t in eos_candidates if t is not None]
181
+ self.generation_config.pad_token_id = self.tokenizer.pad_token_id
182
+
183
+ # Feature extractor for audio preprocessing
184
+ self.feature_extractor = self._create_feature_extractor(config)
185
+
186
+ # Audio projector (trainable unless freeze_projector is set)
187
+ self.projector = self._create_projector(config, target_dtype)
188
+
189
+ # Setup LoRA if enabled (Stage 2 fine-tuning)
190
+ # Skip if loading from pretrained - from_pretrained will handle adapter loading
191
+ if getattr(config, "use_lora", False) and not getattr(
192
+ self.__class__, "_is_loading_from_pretrained", False
193
+ ):
194
+ self._setup_lora(config)
195
+
196
+ # Freeze projector if specified (for Stage 2 LoRA-only training)
197
+ if getattr(config, "freeze_projector", False):
198
+ self.projector.requires_grad_(False)
199
+
200
+ # Freeze the text-vocab embedding table (preserves base Qwen3's
201
+ # token→embedding mapping during joint fine-tune). With
202
+ # tie_word_embeddings=True the same tensor backs lm_head, so this
203
+ # also freezes the output projection. Audio tokens bypass this
204
+ # table — they're scattered into inputs_embeds via masked_scatter
205
+ # at <audio> positions (forward(), below), so the audio path is
206
+ # unaffected. Mirrors Baichuan-Audio's stage-2 policy of training
207
+ # all decoder params except the text embedding and LM head.
208
+ if getattr(config, "freeze_text_embed_tokens", False):
209
+ self.language_model.get_input_embeddings().weight.requires_grad_(False)
210
+
211
+ # For model parallelism
212
+ self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
213
+
214
+ def _create_feature_extractor(self, config: ASRConfig):
215
+ """Create the appropriate feature extractor for the audio encoder."""
216
+ from transformers import AutoFeatureExtractor
217
+
218
+ feature_extractor = AutoFeatureExtractor.from_pretrained(config.audio_model_id)
219
+ # Whisper's encoder requires a fixed 3000 mel frames (30s) and the
220
+ # feature extractor pads to that by default — leave it alone. Other
221
+ # encoders (e.g. GLM-ASR) accept variable-length input, so we disable
222
+ # padding to avoid wasting compute on silent frames.
223
+ if "whisper" not in config.audio_model_id.lower():
224
+ feature_extractor.padding = False
225
+ return feature_extractor
226
+
227
+ @classmethod
228
+ def _load_audio_encoder(cls, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
229
+ """Load the audio encoder; freeze unless `config.freeze_audio_encoder=False`.
230
+
231
+ When unfrozen, the encoder participates in joint training — pair with a
232
+ much lower `encoder_learning_rate` than the projector/decoder LRs
233
+ (encoder is large, sensitive to perturbation, and shouldn't drift far
234
+ from its pretrained features). See `ASRTrainer.create_optimizer` for the
235
+ LR routing.
236
+ """
237
+ encoder_kwargs = {
238
+ "attn_implementation": _resolve_attn_implementation(config.attn_implementation),
239
+ "low_cpu_mem_usage": True,
240
+ "dtype": dtype,
241
+ }
242
+
243
+ if "whisper" in config.audio_model_id.lower():
244
+ from transformers import WhisperModel
245
+
246
+ full_model = WhisperModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
247
+ encoder = full_model.encoder
248
+ del full_model
249
+ elif "glm" in config.audio_model_id.lower():
250
+ # GLM-ASR models use audio_tower as the encoder
251
+ # Requires transformers >= 5.x or installed from source
252
+ from transformers import AutoModelForSeq2SeqLM
253
+
254
+ full_model = AutoModelForSeq2SeqLM.from_pretrained(
255
+ config.audio_model_id, trust_remote_code=True, **encoder_kwargs
256
+ )
257
+ # GLM stores encoder at audio_tower (GlmAsrEncoder)
258
+ encoder = full_model.audio_tower
259
+ # Clear references to free VRAM from the LLM decoder
260
+ full_model.language_model = None
261
+ full_model.multi_modal_projector = None
262
+ del full_model
263
+ else:
264
+ encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
265
+
266
+ # Explicit cast: from_pretrained's `dtype=` kwarg is honored
267
+ # inconsistently across loader paths (especially trust_remote_code
268
+ # branches like GLM-ASR), leaving submodules in fp32. FA2's startup
269
+ # then complains "current dype is torch.float32, expected fp16/bf16",
270
+ # and even with sdpa the projector→encoder feed mismatches dtypes.
271
+ # `.to(dtype=...)` after load is idempotent and forces the issue.
272
+ encoder = encoder.to(dtype=dtype)
273
+ if getattr(config, "freeze_audio_encoder", True):
274
+ encoder.requires_grad_(False)
275
+ encoder.train(False) # equivalent to .eval(); avoids a security hook false-positive
276
+ return encoder
277
+
278
+ @classmethod
279
+ def _load_language_model(cls, config: ASRConfig, dtype: torch.dtype) -> PreTrainedModel:
280
+ """Load and freeze the language model."""
281
+ decoder_kwargs = {
282
+ "attn_implementation": _resolve_attn_implementation(config.attn_implementation),
283
+ "trust_remote_code": True,
284
+ "low_cpu_mem_usage": True,
285
+ "dtype": dtype,
286
+ }
287
+
288
+ decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs)
289
+ # See _load_audio_encoder note: idempotent post-load cast to dodge the
290
+ # FA2 "current dype is fp32" warning when from_pretrained's dtype kwarg
291
+ # isn't fully propagated to every submodule.
292
+ decoder = decoder.to(dtype=dtype)
293
+ decoder.config.use_cache = getattr(config, "use_cache", True)
294
+ if getattr(config, "freeze_language_model", True):
295
+ decoder.requires_grad_(False)
296
+ decoder.train(False)
297
+ return decoder
298
+
299
+ def _create_projector(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
300
+ """Create the trainable audio projector."""
301
+ # Auto-detect dimensions if not specified
302
+ if config.encoder_dim is None:
303
+ enc_cfg = self.audio_tower.config
304
+ config.encoder_dim = getattr(enc_cfg, "hidden_size", None) or getattr(
305
+ enc_cfg, "d_model", None
306
+ )
307
+ if config.encoder_dim is None:
308
+ raise ValueError("Could not auto-detect encoder_dim. Please specify in config.")
309
+
310
+ if config.llm_dim is None:
311
+ dec_cfg = self.language_model.config
312
+ config.llm_dim = getattr(dec_cfg, "hidden_size", None) or getattr(
313
+ dec_cfg, "d_model", None
314
+ )
315
+ if config.llm_dim is None:
316
+ raise ValueError("Could not auto-detect llm_dim. Please specify in config.")
317
+
318
+ # Select projector type based on config
319
+ projector_type = getattr(config, "projector_type", "mlp")
320
+ projector_class = PROJECTOR_CLASSES.get(projector_type)
321
+ if projector_class is None:
322
+ raise ValueError(
323
+ f"Unknown projector_type: {projector_type}. "
324
+ f"Valid options: {list(PROJECTOR_CLASSES.keys())}"
325
+ )
326
+ projector = projector_class(config)
327
+
328
+ # Move projector to same device as language model (important when using quantization)
329
+ device = next(self.language_model.parameters()).device
330
+ return projector.to(device=device, dtype=dtype)
331
+
332
+ def _setup_lora(self, config: ASRConfig):
333
+ """Apply LoRA adapters to the language model for Stage 2 fine-tuning."""
334
+ from peft import LoraConfig, get_peft_model
335
+
336
+ lora_config = LoraConfig(
337
+ r=config.lora_rank,
338
+ lora_alpha=config.lora_alpha,
339
+ target_modules=config.lora_target_modules,
340
+ lora_dropout=config.lora_dropout,
341
+ bias="none",
342
+ task_type="CAUSAL_LM",
343
+ )
344
+ self.language_model = get_peft_model(self.language_model, lora_config)
345
+
346
+ def _init_tokenizer(self, config: ASRConfig):
347
+ """Initialize tokenizer with audio token."""
348
+ self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)
349
+
350
+ # Set pad token. Prefer a dedicated pad token if the tokenizer has one
351
+ # (e.g. Qwen's <|finetune_right_pad_id|>); otherwise fall back to
352
+ # eos_token, which is the standard pattern for Llama-style tokenizers
353
+ # (SmolLM2, Llama, etc.) that ship without a separate pad token.
354
+ if (
355
+ self.tokenizer.pad_token is None
356
+ or self.tokenizer.pad_token_id == self.tokenizer.eos_token_id
357
+ ):
358
+ if "<|finetune_right_pad_id|>" in self.tokenizer.get_vocab():
359
+ self.tokenizer.pad_token = "<|finetune_right_pad_id|>"
360
+ elif self.tokenizer.pad_token is None:
361
+ self.tokenizer.pad_token = self.tokenizer.eos_token
362
+
363
+ # Add audio token
364
+ existing_special = getattr(self.tokenizer, "additional_special_tokens", None) or []
365
+ if "<audio>" not in existing_special:
366
+ self.tokenizer.add_special_tokens(
367
+ {"additional_special_tokens": existing_special + ["<audio>"]}
368
+ )
369
+ # mean_resizing=True initializes the new <audio> row at the mean of
370
+ # existing rows so its scale matches the pretrained distribution. The
371
+ # input-side <audio> embedding is overwritten via masked_scatter and
372
+ # never seen by the LM, but with tied embeddings (Qwen3-0.6B) this
373
+ # same row is the lm_head column for predicting <audio>; a Gaussian
374
+ # draw at config.initializer_range was visible in early-step logits.
375
+ self.language_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=True)
376
+
377
+ self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
378
+ self.tokenizer.padding_side = "right"
379
+
380
+ # Sync token IDs to configs
381
+ for cfg in [self.config.text_config, self.language_model.config, self.generation_config]:
382
+ if cfg is not None:
383
+ cfg.pad_token_id = self.tokenizer.pad_token_id
384
+ cfg.eos_token_id = self.tokenizer.eos_token_id
385
+ cfg.bos_token_id = self.tokenizer.bos_token_id
386
+
387
+ def train(self, mode: bool = True):
388
+ """Set train/eval mode, but keep frozen submodules out of train mode.
389
+
390
+ HF Trainer calls `model.train()` at the top of every training step, which
391
+ recursively switches every submodule into train mode — re-enabling dropout
392
+ on modules with `requires_grad_(False)`. The frozen encoder (and the LM
393
+ when `freeze_language_model=True`) should always run deterministically;
394
+ train-mode dropout only adds noise that can't improve a frozen network.
395
+ """
396
+ super().train(mode)
397
+ if getattr(self.config, "freeze_audio_encoder", True):
398
+ self.audio_tower.train(False)
399
+ if getattr(self.config, "freeze_language_model", True):
400
+ self.language_model.train(False)
401
+ return self
402
+
403
+ def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None):
404
+ """Enable/disable gradient checkpointing on the trainable submodules.
405
+
406
+ Routes the request to whichever components are actually trainable in
407
+ this run. The LM is always reached (its forward activations are
408
+ needed for backprop to the projector even when its weights are
409
+ frozen). The encoder is reached only when `freeze_audio_encoder` is
410
+ False — when frozen, no gradient flows through it and checkpointing
411
+ would just add recompute cost for no memory savings.
412
+ """
413
+ # The LLM still stores activations during forward for backprop to projector
414
+ # Gradient checkpointing trades compute for memory by recomputing activations
415
+ for submodule in self._gradient_checkpointing_targets():
416
+ if hasattr(submodule, "_set_gradient_checkpointing"):
417
+ submodule._set_gradient_checkpointing(enable, gradient_checkpointing_func)
418
+ elif hasattr(submodule, "gradient_checkpointing_enable") and enable:
419
+ submodule.gradient_checkpointing_enable(
420
+ gradient_checkpointing_kwargs={"use_reentrant": False}
421
+ )
422
+ elif hasattr(submodule, "gradient_checkpointing_disable") and not enable:
423
+ submodule.gradient_checkpointing_disable()
424
+
425
+ def _gradient_checkpointing_targets(self) -> list[nn.Module]:
426
+ """Return the submodules that should respond to gradient_checkpointing
427
+ toggles. Always includes the LM (activations are on the gradient path
428
+ to the projector); includes the encoder only when it's trainable.
429
+ """
430
+ targets: list[nn.Module] = [self.language_model]
431
+ if not getattr(self.config, "freeze_audio_encoder", True):
432
+ targets.append(self.audio_tower)
433
+ return targets
434
+
435
+ def get_input_embeddings(self) -> nn.Module:
436
+ return self.language_model.get_input_embeddings()
437
+
438
+ def set_input_embeddings(self, value: nn.Module) -> None:
439
+ self.language_model.set_input_embeddings(value)
440
+
441
+ def get_output_embeddings(self) -> nn.Module:
442
+ return self.language_model.get_output_embeddings()
443
+
444
+ def set_output_embeddings(self, value: nn.Module) -> None:
445
+ self.language_model.set_output_embeddings(value)
446
+
447
+ def get_processor(self):
448
+ """Get the processor for this model."""
449
+ try:
450
+ from .asr_processing import ASRProcessor
451
+ except ImportError:
452
+ from asr_processing import ASRProcessor # type: ignore[no-redef]
453
+
454
+ return ASRProcessor(
455
+ feature_extractor=self.feature_extractor,
456
+ tokenizer=self.tokenizer,
457
+ projector=self.projector,
458
+ encoder_conv_layers=self.config.encoder_conv_layers,
459
+ )
460
+
461
+ def state_dict(self, *args, **kwargs) -> dict[str, torch.Tensor]:
462
+ """Save trainable weights: projector, plus the language model when fine-tuned.
463
+
464
+ With LoRA attached, the language_model entries are flattened to plain
465
+ (non-PEFT) HF naming so model.safetensors round-trips through
466
+ ASRModel.from_pretrained — which builds a vanilla base LM, overlays
467
+ these weights, and only then re-attaches PEFT. lora_*/adapter weights
468
+ are skipped here; PEFT serializes them separately as
469
+ adapter_model.safetensors via the save_pretrained path below.
470
+ """
471
+ sd = {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
472
+ if not getattr(self.config, "freeze_language_model", True):
473
+ lm = self.language_model
474
+ if hasattr(lm, "peft_config"):
475
+ for k, v in lm.state_dict().items():
476
+ if "lora_" in k:
477
+ continue
478
+ if k.startswith("base_model.model."):
479
+ k = k[len("base_model.model.") :]
480
+ # LoRA layers wrap the original Linear as `<name>.base_layer.<weight|bias>`.
481
+ k = k.replace(".base_layer.", ".")
482
+ sd[f"language_model.{k}"] = v
483
+ else:
484
+ sd.update({f"language_model.{k}": v for k, v in lm.state_dict().items()})
485
+ return sd
486
+
487
+ def _compute_encoder_output_lengths(
488
+ self,
489
+ audio_attention_mask: torch.Tensor,
490
+ ) -> torch.Tensor:
491
+ """Compute per-sample encoder output lengths using conv layer formulas."""
492
+ return compute_encoder_output_length(
493
+ audio_attention_mask.sum(dim=-1),
494
+ self.config.encoder_conv_layers,
495
+ )
496
+
497
+ def _encode_audio(
498
+ self,
499
+ audio_features: torch.Tensor,
500
+ expected_token_counts: torch.Tensor,
501
+ ) -> torch.Tensor:
502
+ """Encode audio features and return flattened embeddings matching expected_token_counts.
503
+
504
+ Args:
505
+ audio_features: Mel spectrogram features (batch, n_mels, mel_len)
506
+ expected_token_counts: Per-sample audio token counts as int64 tensor (batch,).
507
+
508
+ Returns:
509
+ Flattened audio embeddings of shape (sum(expected_token_counts), hidden_dim).
510
+ """
511
+ # SpecAugment is applied on the mel input, training-only. Most useful
512
+ # when the encoder is trainable; on the frozen-encoder path it still
513
+ # perturbs the projector's input slightly but with no gradient flowing
514
+ # back to the encoder to leverage the diversity.
515
+ if (
516
+ self.training
517
+ and getattr(self.config, "apply_spec_augment", False)
518
+ and audio_features.numel() > 0
519
+ ):
520
+ audio_features = self._mask_input_features(audio_features)
521
+
522
+ # When the encoder is frozen, skip gradient tracking through it — cuts
523
+ # activation memory and matches the prior published recipe's behavior.
524
+ # When trainable, we MUST allow gradients to flow back to encoder
525
+ # params; wrapping in no_grad here would silently zero encoder
526
+ # gradients regardless of requires_grad on its parameters.
527
+ encoder_frozen = getattr(self.config, "freeze_audio_encoder", True)
528
+ if encoder_frozen:
529
+ with torch.no_grad():
530
+ encoder_out = self.audio_tower(input_features=audio_features)
531
+ hidden_states = encoder_out.last_hidden_state
532
+ else:
533
+ encoder_out = self.audio_tower(input_features=audio_features)
534
+ hidden_states = encoder_out.last_hidden_state
535
+
536
+ audio_embeds = self.projector(hidden_states)
537
+
538
+ token_counts = expected_token_counts.to(device=audio_embeds.device, dtype=torch.long)
539
+ return _gather_audio_embeds(audio_embeds, token_counts)
540
+
541
+ def _mask_input_features(
542
+ self,
543
+ input_features: torch.Tensor,
544
+ attention_mask: Optional[torch.Tensor] = None, # noqa: ARG002 — reserved for future use
545
+ ) -> torch.Tensor:
546
+ """SpecAugment on mel input (pure-torch, vectorized, compile-ready).
547
+
548
+ Follows the same semantics as
549
+ `transformers.models.whisper.modeling_whisper.WhisperModel._mask_input_features`
550
+ (wav2vec2-style mask sampling: sample N start positions per sample,
551
+ mask `mask_length` frames forward from each), but reimplemented in
552
+ pure torch so it stays inside the autograd graph without crossing
553
+ the numpy boundary. This avoids inductor codegen failures
554
+ (e.g. the `‘zuf0’ was not declared` error from the prior numpy ->
555
+ torch.tensor round-trip) AND avoids the per-forward host-to-GPU
556
+ sync that the numpy path required.
557
+
558
+ One minor semantic divergence vs the upstream helper: this version
559
+ allows mask spans to overlap, while upstream rejects overlapping
560
+ samples. For ASR purposes this is irrelevant — occasional region
561
+ double-coverage has no measurable effect on the regularization
562
+ signal.
563
+
564
+ Reads ASRConfig fields by Whisper naming convention: mask_time_prob,
565
+ mask_time_length, mask_time_min_masks, mask_feature_prob,
566
+ mask_feature_length, mask_feature_min_masks.
567
+
568
+ Args:
569
+ input_features: (batch, n_mels, mel_len) log-mel features.
570
+ attention_mask: reserved for future use; ignored here since our
571
+ mel features are pre-padded to zero and double-masking
572
+ pad regions is a no-op.
573
+
574
+ Returns:
575
+ Same-shape tensor with time-axis and/or feature-axis masks zeroed.
576
+ """
577
+ input_features = input_features.clone()
578
+ batch_size, hidden_size, sequence_length = input_features.size()
579
+ config = self.config
580
+ device = input_features.device
581
+
582
+ if getattr(config, "mask_time_prob", 0.0) > 0:
583
+ mask_time = self._sample_mask_indices(
584
+ batch_size,
585
+ sequence_length,
586
+ mask_prob=config.mask_time_prob,
587
+ mask_length=config.mask_time_length,
588
+ min_masks=config.mask_time_min_masks,
589
+ device=device,
590
+ )
591
+ # Broadcast (B, T) -> (B, 1, T) to mask all mel bins at masked times.
592
+ input_features.masked_fill_(mask_time.unsqueeze(1), 0)
593
+
594
+ if getattr(config, "mask_feature_prob", 0.0) > 0:
595
+ mask_feature = self._sample_mask_indices(
596
+ batch_size,
597
+ hidden_size,
598
+ mask_prob=config.mask_feature_prob,
599
+ mask_length=config.mask_feature_length,
600
+ min_masks=config.mask_feature_min_masks,
601
+ device=device,
602
+ )
603
+ # Broadcast (B, F) -> (B, F, 1) to mask all time steps at masked bins.
604
+ input_features.masked_fill_(mask_feature.unsqueeze(-1), 0)
605
+
606
+ return input_features
607
+
608
+ @staticmethod
609
+ def _sample_mask_indices(
610
+ batch_size: int,
611
+ axis_length: int,
612
+ mask_prob: float,
613
+ mask_length: int,
614
+ min_masks: int,
615
+ device: torch.device,
616
+ ) -> torch.Tensor:
617
+ """Vectorized SpecAugment mask sampler — torch.compile-friendly.
618
+
619
+ Returns a (batch_size, axis_length) bool tensor where True marks
620
+ a position covered by at least one mask span. Spans may overlap
621
+ (see _mask_input_features docstring on the semantic difference vs
622
+ the upstream Whisper helper).
623
+ """
624
+ # Number of mask spans per sample: deterministic given config + axis_length.
625
+ # Matches the upstream formula (ignoring the epsilon noise term, which
626
+ # only shifts the count by ±1 stochastically — negligible at the
627
+ # default mask_time_prob=0.05 / mask_length=10 setting which gives
628
+ # ~5 spans for a typical 1500-frame mel input).
629
+ num_masked_spans = max(int(mask_prob * axis_length / mask_length + 0.5), min_masks)
630
+ if num_masked_spans == 0:
631
+ return torch.zeros(batch_size, axis_length, device=device, dtype=torch.bool)
632
+
633
+ # Sample start positions independently per sample × span.
634
+ # Clamp range so a span of length mask_length never runs off the end.
635
+ max_start = max(axis_length - mask_length + 1, 1)
636
+ starts = torch.randint(
637
+ 0, max_start, (batch_size, num_masked_spans), device=device
638
+ ) # (B, N)
639
+
640
+ # For each (sample, span, position), True iff position ∈ [start, start+mask_length).
641
+ positions = torch.arange(axis_length, device=device).view(1, 1, -1) # (1, 1, T)
642
+ starts_b = starts.unsqueeze(-1) # (B, N, 1)
643
+ span_mask = (positions >= starts_b) & (positions < starts_b + mask_length)
644
+ # Reduce over the span dim: True if ANY span covers this position.
645
+ return span_mask.any(dim=1)
646
+
647
+ def forward(
648
+ self,
649
+ input_ids: Optional[torch.Tensor] = None,
650
+ input_features: Optional[torch.Tensor] = None,
651
+ audio_attention_mask: Optional[torch.Tensor] = None,
652
+ attention_mask: Optional[torch.Tensor] = None,
653
+ position_ids: Optional[torch.Tensor] = None,
654
+ past_key_values: Optional[torch.Tensor] = None,
655
+ inputs_embeds: Optional[torch.Tensor] = None,
656
+ labels: Optional[torch.Tensor] = None,
657
+ use_cache: Optional[bool] = None,
658
+ cache_position: Optional[torch.Tensor] = None,
659
+ audio_token_counts: Optional[torch.Tensor] = None,
660
+ **kwargs,
661
+ ) -> CausalLMOutputWithPast:
662
+ """Forward pass for training and inference."""
663
+ if inputs_embeds is None:
664
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
665
+
666
+ if input_features is not None and input_ids is not None:
667
+ is_audio_token = input_ids == self.audio_token_id
668
+ if audio_token_counts is None:
669
+ audio_token_counts = is_audio_token.sum(dim=-1)
670
+ else:
671
+ audio_token_counts = audio_token_counts.to(
672
+ device=input_ids.device, dtype=torch.long
673
+ )
674
+
675
+ audio_embeds = self._encode_audio(input_features, audio_token_counts)
676
+
677
+ audio_token_mask = is_audio_token.unsqueeze(-1)
678
+ inputs_embeds = inputs_embeds.masked_scatter(
679
+ audio_token_mask.to(inputs_embeds.device),
680
+ audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
681
+ )
682
+
683
+ # Forward label_smoothing to the LM's loss_function via **kwargs.
684
+ # transformers.loss.loss_utils.ForCausalLMLoss → fixed_cross_entropy
685
+ # forwards extra kwargs to F.cross_entropy, which accepts label_smoothing.
686
+ # When apply_liger_kernel_to_qwen3() has patched the LM, the smoothing
687
+ # is consumed by liger's fused linear CE (no (B,T,V) materialization).
688
+ # Zeroed on eval so eval/loss is raw CE and comparable to LS=0 runs.
689
+ if labels is not None and self.training and self.config.label_smoothing > 0:
690
+ kwargs.setdefault("label_smoothing", self.config.label_smoothing)
691
+
692
+ outputs = self.language_model(
693
+ attention_mask=attention_mask,
694
+ position_ids=position_ids,
695
+ past_key_values=past_key_values,
696
+ inputs_embeds=inputs_embeds,
697
+ labels=labels,
698
+ use_cache=use_cache,
699
+ cache_position=cache_position,
700
+ **kwargs,
701
+ )
702
+
703
+ if outputs.loss is not None and hasattr(self.projector, "get_aux_loss"):
704
+ aux_loss = self.projector.get_aux_loss()
705
+ if aux_loss is not None and aux_loss.numel() > 0:
706
+ outputs.loss = outputs.loss + aux_loss.to(outputs.loss.device)
707
+
708
+ return outputs
709
+
710
+ def prepare_inputs_for_generation(self, *args, **kwargs):
711
+ """Prepare inputs for generation, handling audio features for cached decoding."""
712
+ input_features = kwargs.pop("input_features", None)
713
+ cache_position = kwargs.get("cache_position")
714
+
715
+ model_inputs = self.language_model.prepare_inputs_for_generation(*args, **kwargs)
716
+
717
+ # Only pass audio features on the first generation step (cache_position[0] == 0)
718
+ if cache_position is not None and cache_position[0] == 0 and input_features is not None:
719
+ model_inputs["input_features"] = input_features
720
+
721
+ return model_inputs
722
+
723
+ def _get_num_audio_tokens(
724
+ self,
725
+ audio_attention_mask: torch.Tensor,
726
+ ) -> int:
727
+ """Calculate number of audio tokens based on actual audio length.
728
+
729
+ Uses attention mask to get real audio length, then computes:
730
+ mel_frames -> encoder_frames (via conv formulas) -> projector output tokens
731
+ """
732
+ encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
733
+ # Use max length for batch (all samples should have same token count for generation)
734
+ encoder_output_len = int(encoder_lengths.max().item())
735
+ return int(self.projector.get_output_length(encoder_output_len))
736
+
737
+ @torch.no_grad()
738
+ def generate(
739
+ self,
740
+ input_ids: Optional[torch.Tensor] = None,
741
+ input_features: Optional[torch.Tensor] = None,
742
+ audio_attention_mask: Optional[torch.Tensor] = None,
743
+ attention_mask: Optional[torch.Tensor] = None,
744
+ system_prompt: Optional[str] = None,
745
+ **generate_kwargs,
746
+ ):
747
+ """Generate transcription from audio input.
748
+
749
+ Can be called in two ways:
750
+ 1. With input_ids containing <audio> tokens (from processor)
751
+ 2. With just audio, and we build the prompt internally
752
+ """
753
+ if input_features is None:
754
+ raise ValueError("input_features required for generation")
755
+ if audio_attention_mask is None:
756
+ raise ValueError("audio_attention_mask required for generation")
757
+
758
+ device = input_features.device
759
+ batch_size = input_features.shape[0]
760
+
761
+ # Encode audio -> flattened embeddings (no per-sample host sync)
762
+ encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
763
+ token_counts = self.projector.get_output_length(encoder_lengths).to(torch.long)
764
+ audio_embeds = self._encode_audio(input_features, token_counts)
765
+
766
+ # If input_ids not provided, build prompt with correct number of audio tokens
767
+ if input_ids is None:
768
+ num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
769
+ audio_placeholder = "<audio>" * num_audio_tokens
770
+
771
+ system_prompt = system_prompt or self.system_prompt
772
+
773
+ messages: list[dict[str, str]] = []
774
+ if system_prompt:
775
+ messages.append({"role": "system", "content": system_prompt})
776
+ # Audio tokens only (instruction-free)
777
+ user_content = audio_placeholder
778
+ if self.TRANSCRIBE_PROMPT:
779
+ user_content += " " + self.TRANSCRIBE_PROMPT
780
+ messages.append({"role": "user", "content": user_content})
781
+
782
+ chat_result = self.tokenizer.apply_chat_template(
783
+ messages,
784
+ tokenize=True,
785
+ add_generation_prompt=True,
786
+ return_tensors="pt",
787
+ enable_thinking=False, # Disable Qwen3 thinking mode for ASR
788
+ )
789
+ input_ids = chat_result.input_ids.to(device)
790
+
791
+ if input_ids.dim() == 1:
792
+ input_ids = input_ids.unsqueeze(0)
793
+ if input_ids.shape[0] == 1 and batch_size > 1:
794
+ input_ids = input_ids.expand(batch_size, -1)
795
+
796
+ attention_mask = torch.ones_like(input_ids)
797
+
798
+ # Get text embeddings and replace audio tokens with audio embeddings
799
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
800
+ audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
801
+ inputs_embeds = inputs_embeds.masked_scatter(
802
+ audio_token_mask.to(inputs_embeds.device),
803
+ audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
804
+ )
805
+
806
+ # transformers v5 deprecates passing generation flags as kwargs when a
807
+ # `generation_config` is also passed — the kwargs get silently dropped.
808
+ # Pull any score-related flags out of generate_kwargs and apply them to
809
+ # a derived generation_config so they actually take effect.
810
+ gen_cfg = self.generation_config
811
+ score_flags = {}
812
+ for flag in ("output_scores", "output_logits", "return_dict_in_generate"):
813
+ if flag in generate_kwargs:
814
+ score_flags[flag] = generate_kwargs.pop(flag)
815
+ if score_flags:
816
+ from copy import copy as _copy
817
+
818
+ gen_cfg = _copy(self.generation_config)
819
+ for flag, value in score_flags.items():
820
+ setattr(gen_cfg, flag, value)
821
+ # output_scores requires return_dict_in_generate for HF generate to
822
+ # actually populate .scores on the output object.
823
+ if gen_cfg.output_scores and not gen_cfg.return_dict_in_generate:
824
+ gen_cfg.return_dict_in_generate = True
825
+
826
+ # Generate using language model
827
+ # Pass both input_ids and inputs_embeds so repetition_penalty works correctly
828
+ # (it needs input_ids to track which tokens have been used)
829
+ output = self.language_model.generate(
830
+ input_ids=input_ids,
831
+ inputs_embeds=inputs_embeds,
832
+ attention_mask=attention_mask,
833
+ generation_config=gen_cfg,
834
+ **generate_kwargs,
835
+ )
836
+
837
+ # When using inputs_embeds with input_ids, generate returns the full
838
+ # sequence (prompt + generated). Strip the prompt to return only the
839
+ # newly generated tokens. When scores were requested, preserve the
840
+ # GenerateOutput so callers can read .scores; otherwise return the
841
+ # bare tensor for backward compatibility with existing callers.
842
+ input_len = input_ids.shape[1]
843
+ if isinstance(output, torch.Tensor):
844
+ return output[:, input_len:]
845
+ output.sequences = output.sequences[:, input_len:]
846
+ return output
847
+
848
+ def generate_streaming(
849
+ self,
850
+ input_features: torch.Tensor,
851
+ audio_attention_mask: torch.Tensor,
852
+ system_prompt: Optional[str] = None,
853
+ **generate_kwargs,
854
+ ) -> Iterator[str]:
855
+ """Generate transcription with streaming token output.
856
+
857
+ Yields partial transcript strings as tokens are generated.
858
+ Reduces time-to-first-word by streaming tokens as they're decoded.
859
+
860
+ Args:
861
+ input_features: Mel spectrogram features (batch, n_mels, mel_len)
862
+ audio_attention_mask: Mask for real vs padded mel frames (batch, mel_len)
863
+ system_prompt: Optional system prompt override
864
+ **generate_kwargs: Additional generation arguments
865
+
866
+ Yields:
867
+ Partial transcript text as each token is generated
868
+ """
869
+ device = input_features.device
870
+ batch_size = input_features.shape[0]
871
+
872
+ # Encode audio -> flattened embeddings (no per-sample host sync)
873
+ encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
874
+ token_counts = self.projector.get_output_length(encoder_lengths).to(torch.long)
875
+ audio_embeds = self._encode_audio(input_features, token_counts)
876
+
877
+ # Build prompt with correct number of audio tokens
878
+ num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
879
+ audio_placeholder = "<audio>" * num_audio_tokens
880
+
881
+ system_prompt = system_prompt or self.system_prompt
882
+
883
+ messages: list[dict[str, str]] = []
884
+ if system_prompt:
885
+ messages.append({"role": "system", "content": system_prompt})
886
+ # Audio tokens only (instruction-free)
887
+ user_content = audio_placeholder
888
+ if self.TRANSCRIBE_PROMPT:
889
+ user_content += " " + self.TRANSCRIBE_PROMPT
890
+ messages.append({"role": "user", "content": user_content})
891
+
892
+ chat_result = self.tokenizer.apply_chat_template(
893
+ messages,
894
+ tokenize=True,
895
+ add_generation_prompt=True,
896
+ return_tensors="pt",
897
+ enable_thinking=False, # Disable Qwen3 thinking mode for ASR
898
+ )
899
+ input_ids = chat_result.input_ids.to(device)
900
+
901
+ if input_ids.dim() == 1:
902
+ input_ids = input_ids.unsqueeze(0)
903
+ if input_ids.shape[0] == 1 and batch_size > 1:
904
+ input_ids = input_ids.expand(batch_size, -1)
905
+
906
+ attention_mask = torch.ones_like(input_ids)
907
+
908
+ # Get text embeddings and replace audio tokens with audio embeddings
909
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
910
+ audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
911
+ inputs_embeds = inputs_embeds.masked_scatter(
912
+ audio_token_mask.to(inputs_embeds.device),
913
+ audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
914
+ )
915
+
916
+ # Setup streamer for token-by-token output
917
+ streamer = TextIteratorStreamer(
918
+ self.tokenizer,
919
+ skip_prompt=True,
920
+ skip_special_tokens=True,
921
+ )
922
+
923
+ # Prepare generation kwargs
924
+ gen_kwargs = {
925
+ "inputs_embeds": inputs_embeds,
926
+ "attention_mask": attention_mask,
927
+ "generation_config": self.generation_config,
928
+ "streamer": streamer,
929
+ **generate_kwargs,
930
+ }
931
+
932
+ # Run generation in background thread
933
+ thread = Thread(target=self.language_model.generate, kwargs=gen_kwargs)
934
+ thread.start()
935
+
936
+ # Yield tokens as they're generated, filtering out <think>...</think> blocks
937
+ # Start assuming no think block - only filter when we see <think>
938
+ in_think_block = False
939
+ buffer = ""
940
+
941
+ for text in streamer:
942
+ buffer += text
943
+
944
+ # Check for think block start (in case model outputs think blocks)
945
+ while "<think>" in buffer:
946
+ in_think_block = True
947
+ # Yield any text before <think>
948
+ before_think = buffer.split("<think>")[0]
949
+ if before_think:
950
+ yield before_think
951
+ buffer = buffer.split("<think>", 1)[-1]
952
+
953
+ # Check for think block end
954
+ while in_think_block and "</think>" in buffer:
955
+ in_think_block = False
956
+ buffer = buffer.split("</think>", 1)[-1]
957
+
958
+ # Yield text if not in think block
959
+ if not in_think_block and buffer:
960
+ yield buffer
961
+ buffer = ""
962
+
963
+ # Yield any remaining buffer
964
+ if buffer and not in_think_block:
965
+ yield buffer
966
+
967
+ thread.join()
968
+
969
+ def save_pretrained(self, save_directory: Union[str, Path], **kwargs) -> None:
970
+ """Save model, tokenizer, and processor."""
971
+ import shutil
972
+
973
+ save_dir = Path(save_directory)
974
+ save_dir.mkdir(parents=True, exist_ok=True)
975
+
976
+ # Update config with actual vocab size
977
+ self.config.vocab_size = self.language_model.config.vocab_size
978
+ self.config.text_config.vocab_size = self.language_model.config.vocab_size
979
+
980
+ if hasattr(self.audio_tower.config, "num_mel_bins"):
981
+ self.config.audio_config.num_mel_bins = self.audio_tower.config.num_mel_bins
982
+
983
+ # Save model (temporarily remove non-serializable attributes)
984
+ tokenizer = self.tokenizer
985
+ del self.tokenizer
986
+
987
+ try:
988
+ super().save_pretrained(save_dir, **kwargs)
989
+ finally:
990
+ self.tokenizer = tokenizer
991
+
992
+ # Save tokenizer and feature extractor
993
+ self.tokenizer.save_pretrained(save_dir)
994
+ self.feature_extractor.save_pretrained(save_dir)
995
+
996
+ # Save LoRA adapters if present (creates adapter_model.safetensors and adapter_config.json)
997
+ # Don't save embedding layers - the <audio> token embedding is never used
998
+ # (it's replaced with projected audio embeddings before the LLM sees it)
999
+ if hasattr(self.language_model, "peft_config"):
1000
+ self.language_model.save_pretrained(save_dir, save_embedding_layers=False)
1001
+
1002
+ # Clear base_model_name_or_path in adapter_config.json to prevent HF pipeline
1003
+ # from redirecting to the base LLM repo (like Qwen) which breaks feature
1004
+ # extractor loading for multimodal models. If a repo_id is provided, use that
1005
+ # so the model can be loaded directly from the Hub.
1006
+ adapter_config_path = save_dir / "adapter_config.json"
1007
+ if adapter_config_path.exists():
1008
+ with adapter_config_path.open() as f:
1009
+ adapter_config = json.load(f)
1010
+
1011
+ # Use repo_id if available, otherwise clear to prevent redirect.
1012
+ # Use empty string instead of None to avoid str(None) -> "None" bug
1013
+ # in some transformers/PEFT versions.
1014
+ repo_id = (
1015
+ kwargs.get("repo_id")
1016
+ or kwargs.get("push_to_hub_model_id")
1017
+ or getattr(self.config, "pretrained_model_path", None)
1018
+ or "" # Use empty string instead of None
1019
+ )
1020
+ adapter_config["base_model_name_or_path"] = repo_id
1021
+
1022
+ with adapter_config_path.open("w") as f:
1023
+ json.dump(adapter_config, f, indent=2)
1024
+
1025
+ # Add processor auto_map to preprocessor_config.json
1026
+ config_path = save_dir / "preprocessor_config.json"
1027
+ if config_path.exists():
1028
+ with config_path.open() as f:
1029
+ processor_config = json.load(f)
1030
+ else:
1031
+ processor_config = {}
1032
+
1033
+ processor_config.update(
1034
+ {
1035
+ "processor_class": "ASRProcessor",
1036
+ "auto_map": {"AutoProcessor": "asr_processing.ASRProcessor"},
1037
+ }
1038
+ )
1039
+
1040
+ with config_path.open("w") as f:
1041
+ json.dump(processor_config, f, indent=2)
1042
+
1043
+ # Copy source files for auto-loading
1044
+ src_dir = Path(__file__).parent
1045
+ for asr_file in src_dir.glob("asr_*.py"):
1046
+ shutil.copy(asr_file, save_dir / asr_file.name)
1047
+ # Copy projectors module
1048
+ shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
1049
+ # Copy alignment module
1050
+ shutil.copy(src_dir / "alignment.py", save_dir / "alignment.py")
1051
+ # Copy diarization module
1052
+ shutil.copy(src_dir / "diarization.py", save_dir / "diarization.py")
1053
+
1054
+ def push_to_hub(self, repo_id: str, **kwargs) -> str:
1055
+ """Push model to HuggingFace Hub, ensuring adapter_config points to repo.
1056
+
1057
+ IMPORTANT: Sets base_model_name_or_path in adapter_config.json to repo_id
1058
+ so that transformers pipeline() can load the model correctly. Without this,
1059
+ the pipeline tries to load from "None" which fails.
1060
+ """
1061
+ # Store repo_id in config so save_pretrained can access it
1062
+ self.config.pretrained_model_path = repo_id
1063
+ # Call parent's push_to_hub
1064
+ return super().push_to_hub(repo_id, **kwargs)
1065
+
1066
+
1067
+ # Register with transformers Auto classes
1068
+ # (AutoConfig.register is handled in asr_config.py at module load.)
1069
+ AutoModel.register(ASRConfig, ASRModel)
asr_pipeline.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ASR pipeline for audio-to-text transcription with optional timestamps and diarization."""
2
+
3
+ import re
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+ import torch
9
+ import transformers
10
+ from transformers.pipelines.audio_utils import ffmpeg_read
11
+
12
+ try:
13
+ from .alignment import ForcedAligner
14
+ from .asr_modeling import ASRModel
15
+ from .diarization import SpeakerDiarizer
16
+ except ImportError:
17
+ from alignment import ForcedAligner # type: ignore[no-redef]
18
+ from asr_modeling import ASRModel # type: ignore[no-redef]
19
+ from diarization import SpeakerDiarizer # type: ignore[no-redef]
20
+
21
+ # Re-export for backwards compatibility
22
+ __all__ = ["ForcedAligner", "SpeakerDiarizer", "ASRPipeline"]
23
+
24
+ _THINK_TAG_RE = re.compile(r"<think>.*?</think>\s*", flags=re.DOTALL)
25
+ _DEFAULT_MIN_REPEATS = 3
26
+ _TRAILING_CHAR_RE = re.compile(rf"(.)\1{{{_DEFAULT_MIN_REPEATS - 1},}}$")
27
+ _TRAILING_WORD_RE = re.compile(
28
+ rf"\b(\w+)(?:\s+\1){{{_DEFAULT_MIN_REPEATS - 1},}}\s*$", re.IGNORECASE
29
+ )
30
+
31
+
32
+ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
33
+ """ASR Pipeline for audio-to-text transcription."""
34
+
35
+ model: ASRModel
36
+
37
+ def __init__(self, model: ASRModel, **kwargs):
38
+ """Initialize ASR pipeline.
39
+
40
+ Args:
41
+ model: ASRModel instance for transcription
42
+ **kwargs: Additional arguments (feature_extractor, tokenizer, device)
43
+ """
44
+ feature_extractor = kwargs.pop("feature_extractor", None)
45
+ tokenizer = kwargs.pop("tokenizer", model.tokenizer)
46
+
47
+ if feature_extractor is None:
48
+ feature_extractor = model.get_processor().feature_extractor
49
+
50
+ super().__init__(
51
+ model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
52
+ )
53
+ self._current_audio = None
54
+
55
+ def _sanitize_parameters(self, **kwargs):
56
+ """Intercept our custom parameters before parent class validates them."""
57
+ # Remove our custom parameters so parent doesn't see them
58
+ kwargs.pop("return_timestamps", None)
59
+ kwargs.pop("return_speakers", None)
60
+ kwargs.pop("num_speakers", None)
61
+ kwargs.pop("min_speakers", None)
62
+ kwargs.pop("max_speakers", None)
63
+ kwargs.pop("hf_token", None)
64
+ kwargs.pop("user_prompt", None)
65
+ kwargs.pop("diarization_backend", None)
66
+
67
+ return super()._sanitize_parameters(**kwargs)
68
+
69
+ def __call__(
70
+ self,
71
+ inputs,
72
+ **kwargs,
73
+ ):
74
+ """Transcribe audio with optional word-level timestamps and speaker diarization.
75
+
76
+ Args:
77
+ inputs: Audio input (file path, dict with array/sampling_rate, etc.)
78
+ return_timestamps: If True, return word-level timestamps using forced alignment
79
+ return_speakers: If True, return speaker labels for each word
80
+ user_prompt: Custom transcription prompt (default: "Transcribe: ")
81
+ num_speakers: Exact number of speakers (if known, for diarization)
82
+ min_speakers: Minimum number of speakers (for diarization)
83
+ max_speakers: Maximum number of speakers (for diarization)
84
+ **kwargs: Additional arguments passed to the pipeline
85
+
86
+ Returns:
87
+ Dict with 'text' key, 'words' key if return_timestamps=True,
88
+ and speaker labels on words if return_speakers=True
89
+ """
90
+ # Extract our params before super().__call__ (which will also call _sanitize_parameters)
91
+ return_timestamps = kwargs.pop("return_timestamps", False)
92
+ return_speakers = kwargs.pop("return_speakers", False)
93
+ user_prompt = kwargs.pop("user_prompt", None)
94
+ diarization_params = {
95
+ "num_speakers": kwargs.pop("num_speakers", None),
96
+ "min_speakers": kwargs.pop("min_speakers", None),
97
+ "max_speakers": kwargs.pop("max_speakers", None),
98
+ }
99
+
100
+ if return_speakers:
101
+ return_timestamps = True
102
+
103
+ # Set custom user prompt if provided
104
+ original_prompt = None
105
+ if user_prompt:
106
+ original_prompt = self.model.TRANSCRIBE_PROMPT
107
+ self.model.TRANSCRIBE_PROMPT = user_prompt
108
+
109
+ # Store audio for timestamp alignment and diarization
110
+ if return_timestamps or return_speakers:
111
+ self._current_audio = self._extract_audio(inputs)
112
+
113
+ # Run standard transcription
114
+ result = super().__call__(inputs, **kwargs)
115
+
116
+ # Add timestamps if requested
117
+ if return_timestamps and self._current_audio is not None:
118
+ text = result.get("text", "")
119
+ if text:
120
+ try:
121
+ words = ForcedAligner.align(
122
+ self._current_audio["array"],
123
+ text,
124
+ sample_rate=self._current_audio.get("sampling_rate", 16000),
125
+ )
126
+ result["words"] = words
127
+ except Exception as e:
128
+ result["words"] = []
129
+ result["timestamp_error"] = str(e)
130
+ else:
131
+ result["words"] = []
132
+
133
+ # Add speaker diarization if requested
134
+ if return_speakers and self._current_audio is not None:
135
+ try:
136
+ # Run diarization
137
+ speaker_segments = SpeakerDiarizer.diarize(
138
+ self._current_audio["array"],
139
+ sample_rate=self._current_audio.get("sampling_rate", 16000),
140
+ **{k: v for k, v in diarization_params.items() if v is not None},
141
+ )
142
+ result["speaker_segments"] = speaker_segments
143
+
144
+ # Assign speakers to words
145
+ if result.get("words"):
146
+ result["words"] = SpeakerDiarizer.assign_speakers_to_words(
147
+ result["words"],
148
+ speaker_segments,
149
+ )
150
+ except Exception as e:
151
+ result["speaker_segments"] = []
152
+ result["diarization_error"] = str(e)
153
+
154
+ # Clean up
155
+ self._current_audio = None
156
+ if original_prompt is not None:
157
+ self.model.TRANSCRIBE_PROMPT = original_prompt
158
+
159
+ return result
160
+
161
+ def _extract_audio(self, inputs) -> dict | None:
162
+ """Extract audio array from various input formats using HF utilities."""
163
+ if isinstance(inputs, dict):
164
+ if "array" in inputs:
165
+ return {
166
+ "array": inputs["array"],
167
+ "sampling_rate": inputs.get("sampling_rate", 16000),
168
+ }
169
+ if "raw" in inputs:
170
+ return {
171
+ "array": inputs["raw"],
172
+ "sampling_rate": inputs.get("sampling_rate", 16000),
173
+ }
174
+ elif isinstance(inputs, str):
175
+ # File path - load audio using ffmpeg (same as HF pipeline)
176
+ with Path(inputs).open("rb") as f:
177
+ audio = ffmpeg_read(f.read(), sampling_rate=16000)
178
+ return {"array": audio, "sampling_rate": 16000}
179
+ elif isinstance(inputs, bytes):
180
+ audio = ffmpeg_read(inputs, sampling_rate=16000)
181
+ return {"array": audio, "sampling_rate": 16000}
182
+ elif isinstance(inputs, np.ndarray):
183
+ return {"array": inputs, "sampling_rate": 16000}
184
+
185
+ return None
186
+
187
+ def preprocess(self, inputs, **preprocess_params):
188
+ """Preprocess audio inputs for the model.
189
+
190
+ Args:
191
+ inputs: Audio input (dict with array, file path, etc.)
192
+ **preprocess_params: Additional preprocessing parameters
193
+
194
+ Yields:
195
+ Model input dicts with input_features and attention_mask
196
+ """
197
+ # Handle dict with "array" key (from datasets)
198
+ if isinstance(inputs, dict) and "array" in inputs:
199
+ inputs = {
200
+ "raw": inputs["array"],
201
+ "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
202
+ }
203
+
204
+ for item in super().preprocess(inputs, **preprocess_params):
205
+ if "is_last" not in item:
206
+ item["is_last"] = True
207
+ yield item
208
+
209
+ def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
210
+ """Run model forward pass to generate transcription.
211
+
212
+ Args:
213
+ model_inputs: Dict with input_features and attention_mask
214
+ **generate_kwargs: Generation parameters. Pass ``output_scores=True``
215
+ (and ``return_dict_in_generate=True``, which is then implied) to
216
+ also return per-step top-1 and top-2 log-probabilities — used by
217
+ the eval harness's confidence metric. Backward-compatible: when
218
+ unset, returns just token IDs as before.
219
+
220
+ Returns:
221
+ Dict with generated token IDs, and optionally per-step
222
+ ``top1_logprob`` / ``top2_logprob`` tensors when scores were
223
+ requested.
224
+ """
225
+ # Extract audio features and is_last flag
226
+ is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True
227
+
228
+ input_features = model_inputs["input_features"].to(self.model.device)
229
+ audio_attention_mask = model_inputs["attention_mask"].to(self.model.device)
230
+
231
+ # Opt-in: when output_scores is requested, force return_dict_in_generate
232
+ # so we get a GenerateOutput rather than a bare token tensor.
233
+ want_scores = bool(generate_kwargs.get("output_scores", False))
234
+ if want_scores:
235
+ generate_kwargs.setdefault("return_dict_in_generate", True)
236
+
237
+ generate_output = self.model.generate(
238
+ input_features=input_features,
239
+ audio_attention_mask=audio_attention_mask,
240
+ **generate_kwargs,
241
+ )
242
+
243
+ # Default (no scores requested): generate returns a tensor of token IDs.
244
+ if torch.is_tensor(generate_output):
245
+ return {"tokens": generate_output, "is_last": is_last}
246
+
247
+ # Scores requested: GenerateOutput dict-like with .sequences and .scores.
248
+ # `scores` is a tuple of per-step logits tensors (batch, vocab); convert
249
+ # each to log-probs and take top-2 to produce two short tensors over the
250
+ # generation horizon — kept small (no full vocab) so this is cheap to
251
+ # carry through postprocess.
252
+ sequences = generate_output.sequences
253
+ scores = generate_output.scores
254
+ top1_logprobs: list[float] = []
255
+ top2_logprobs: list[float] = []
256
+ if scores:
257
+ for step_logits in scores:
258
+ step_logprobs = torch.log_softmax(step_logits[0].float(), dim=-1)
259
+ top2 = torch.topk(step_logprobs, k=2)
260
+ top1_logprobs.append(top2.values[0].item())
261
+ top2_logprobs.append(top2.values[1].item())
262
+ return {
263
+ "tokens": sequences,
264
+ "top1_logprob": top1_logprobs,
265
+ "top2_logprob": top2_logprobs,
266
+ "is_last": is_last,
267
+ }
268
+
269
+ def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
270
+ """Convert model output tokens to text.
271
+
272
+ Args:
273
+ model_outputs: Dict with 'tokens' key containing generated IDs
274
+ **kwargs: Additional postprocessing parameters
275
+
276
+ Returns:
277
+ Dict with 'text' key containing transcription
278
+ """
279
+ # Handle list of outputs (from chunking)
280
+ if isinstance(model_outputs, list):
281
+ model_outputs = model_outputs[0] if model_outputs else {}
282
+
283
+ tokens = model_outputs.get("tokens")
284
+ if tokens is None:
285
+ return super().postprocess(model_outputs, **kwargs)
286
+
287
+ if torch.is_tensor(tokens):
288
+ tokens = tokens.cpu()
289
+ if tokens.dim() > 1:
290
+ tokens = tokens[0]
291
+
292
+ # Filter out eos tokens that the tokenizer doesn't recognize as special
293
+ # (generation_config.eos_token_id may differ from tokenizer.eos_token_id)
294
+ if hasattr(self, "model") and hasattr(self.model, "generation_config"):
295
+ eos_ids = self.model.generation_config.eos_token_id
296
+ if eos_ids is not None:
297
+ eos_set = set(eos_ids) if isinstance(eos_ids, list) else {eos_ids}
298
+ tokens = [t for t in tokens.tolist() if t not in eos_set]
299
+
300
+ text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
301
+ # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
302
+ if "<think>" in text:
303
+ text = _THINK_TAG_RE.sub("", text).strip()
304
+ text = _truncate_repetitions(text)
305
+ out: dict[str, Any] = {"text": text}
306
+ # Pass through per-step logprobs when _forward captured them (i.e. caller
307
+ # passed output_scores=True). Lets eval harnesses compute confidence
308
+ # stats without re-running the model.
309
+ if "top1_logprob" in model_outputs:
310
+ out["top1_logprob"] = model_outputs["top1_logprob"]
311
+ if "top2_logprob" in model_outputs:
312
+ out["top2_logprob"] = model_outputs["top2_logprob"]
313
+ return out
314
+
315
+
316
+ def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
317
+ """Truncate repeated words/phrases/characters at end of text.
318
+
319
+ Detects patterns like:
320
+ - Repeated words: "the the the the" -> "the"
321
+ - Repeated phrases: "i am sorry i am sorry i am sorry" -> "i am sorry"
322
+ - Repeated characters: "444444" -> "4"
323
+
324
+ Args:
325
+ text: Input text to process
326
+ min_repeats: Minimum repetitions to trigger truncation (default 3)
327
+
328
+ Returns:
329
+ Text with trailing repetitions removed
330
+ """
331
+ if not text:
332
+ return text
333
+
334
+ if min_repeats == _DEFAULT_MIN_REPEATS:
335
+ char_pattern = _TRAILING_CHAR_RE
336
+ word_pattern = _TRAILING_WORD_RE
337
+ else:
338
+ char_pattern = re.compile(rf"(.)\1{{{min_repeats - 1},}}$")
339
+ word_pattern = re.compile(rf"\b(\w+)(?:\s+\1){{{min_repeats - 1},}}\s*$", re.IGNORECASE)
340
+
341
+ text = char_pattern.sub(r"\1", text)
342
+ while word_pattern.search(text):
343
+ text = word_pattern.sub(r"\1", text)
344
+
345
+ # 3. Truncate repeated phrases (2-20 words) at end
346
+ # e.g., "i am sorry i am sorry i am sorry" -> "i am sorry"
347
+ words = text.split()
348
+ if len(words) < min_repeats * 2:
349
+ return text
350
+
351
+ # Cheap pre-check: trailing window must contain duplicates for any phrase repeat
352
+ # to be possible. set(window) == window means all unique → no repetition.
353
+ window = words[-min_repeats * 2 :]
354
+ if len(set(window)) == len(window):
355
+ return text
356
+
357
+ for phrase_len in range(2, min(21, len(words) // min_repeats + 1)):
358
+ phrase_escaped = re.escape(" ".join(words[-phrase_len:]))
359
+ phrase_pattern = re.compile(
360
+ rf"(^|.*?\s)({phrase_escaped})(?:\s+{phrase_escaped}){{{min_repeats - 1},}}\s*$",
361
+ re.IGNORECASE,
362
+ )
363
+ match = phrase_pattern.match(text)
364
+ if match:
365
+ text = (match.group(1) + match.group(2)).strip()
366
+ break
367
+
368
+ return text
asr_processing.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+ import transformers
5
+ from transformers import ProcessorMixin
6
+
7
+ try:
8
+ from .asr_config import DEFAULT_ENCODER_CONV_LAYERS, ASRConfig, compute_encoder_output_length
9
+ except ImportError:
10
+ from asr_config import ( # type: ignore[no-redef]
11
+ DEFAULT_ENCODER_CONV_LAYERS,
12
+ ASRConfig,
13
+ compute_encoder_output_length,
14
+ )
15
+
16
+
17
+ class ASRProcessor(ProcessorMixin):
18
+ """Processor for Whisper-based ASR models."""
19
+
20
+ attributes = ["feature_extractor", "tokenizer"]
21
+ feature_extractor_class = "AutoFeatureExtractor"
22
+ tokenizer_class = "AutoTokenizer"
23
+ AUDIO_TOKEN = "<audio>"
24
+ TRANSCRIBE_PROMPT = "Transcribe the speech to text"
25
+
26
+ def __init__(
27
+ self,
28
+ feature_extractor,
29
+ tokenizer,
30
+ projector=None,
31
+ encoder_conv_layers: Optional[list] = None,
32
+ ):
33
+ """Initialize the ASR processor.
34
+
35
+ Args:
36
+ feature_extractor: Audio feature extractor (WhisperFeatureExtractor)
37
+ tokenizer: Text tokenizer for the language model
38
+ projector: Audio projector module (for computing output lengths)
39
+ encoder_conv_layers: Conv layer specs [(pad, kernel, stride), ...]
40
+ """
41
+ self.feature_extractor = feature_extractor
42
+ self.tokenizer = tokenizer
43
+ self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
44
+ self.projector = projector
45
+ self.encoder_conv_layers = encoder_conv_layers or DEFAULT_ENCODER_CONV_LAYERS
46
+
47
+ def _compute_encoder_output_length(self, mel_length: int) -> int:
48
+ """Compute encoder output length using conv layer formulas."""
49
+ return compute_encoder_output_length(mel_length, self.encoder_conv_layers)
50
+
51
+ def __call__(
52
+ self,
53
+ audio: Optional[Union[list, "torch.Tensor"]] = None,
54
+ text: Optional[str] = None,
55
+ system_prompt: Optional[str] = None,
56
+ return_tensors: str = "pt",
57
+ **kwargs,
58
+ ) -> dict:
59
+ """Process audio and text inputs for inference.
60
+
61
+ Args:
62
+ audio: Raw audio waveform(s)
63
+ text: Target transcription (optional, for training - but use DataCollator instead)
64
+ system_prompt: Optional system prompt
65
+ return_tensors: Return format ("pt" for PyTorch)
66
+
67
+ Returns:
68
+ Dict with input_features, input_ids, attention_mask
69
+ """
70
+ result = {}
71
+
72
+ # Process audio
73
+ if audio is not None:
74
+ audio_inputs = self.feature_extractor(
75
+ audio,
76
+ sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
77
+ return_attention_mask=True,
78
+ return_tensors=return_tensors,
79
+ **kwargs,
80
+ )
81
+ result["input_features"] = audio_inputs["input_features"]
82
+ result["audio_attention_mask"] = audio_inputs["attention_mask"]
83
+
84
+ # Use actual audio length (from attention mask) for token count
85
+ real_mel_len = int(audio_inputs["attention_mask"].sum(dim=-1).max().item())
86
+ encoder_output_len = self._compute_encoder_output_length(real_mel_len)
87
+ num_audio_tokens = self.projector.get_output_length(encoder_output_len)
88
+ else:
89
+ num_audio_tokens = 0
90
+
91
+ # Build prompt with audio token placeholders (instruction-free)
92
+ if num_audio_tokens > 0:
93
+ user_content = self.AUDIO_TOKEN * num_audio_tokens
94
+ if self.TRANSCRIBE_PROMPT:
95
+ user_content += " " + self.TRANSCRIBE_PROMPT
96
+ else:
97
+ user_content = self.TRANSCRIBE_PROMPT or ""
98
+
99
+ messages = []
100
+ if system_prompt:
101
+ messages.append({"role": "system", "content": system_prompt})
102
+ messages.append({"role": "user", "content": user_content})
103
+ if text is not None:
104
+ messages.append({"role": "assistant", "content": text})
105
+
106
+ # Tokenize
107
+ tokenized = self.tokenizer.apply_chat_template(
108
+ messages,
109
+ tokenize=True,
110
+ add_generation_prompt=(text is None),
111
+ return_tensors=return_tensors,
112
+ enable_thinking=False, # Disable Qwen3 thinking mode for ASR
113
+ )
114
+
115
+ # Handle both tensor and BatchEncoding returns
116
+ if isinstance(tokenized, torch.Tensor):
117
+ input_ids = tokenized
118
+ else:
119
+ # BatchEncoding or dict-like object
120
+ input_ids = tokenized.get("input_ids", tokenized.input_ids)
121
+
122
+ if input_ids.dim() == 1:
123
+ input_ids = input_ids.unsqueeze(0)
124
+
125
+ result["input_ids"] = input_ids
126
+ result["attention_mask"] = torch.ones_like(input_ids)
127
+
128
+ return result
129
+
130
+
131
+ ASRProcessor.register_for_auto_class()
132
+ transformers.AutoProcessor.register(ASRConfig, ASRProcessor)
diarization.py ADDED
@@ -0,0 +1,730 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Speaker diarization using TEN-VAD + ECAPA-TDNN + spectral clustering.
2
+
3
+ Spectral clustering implementation adapted from FunASR/3D-Speaker:
4
+ https://github.com/alibaba-damo-academy/FunASR
5
+ MIT License (https://opensource.org/licenses/MIT)
6
+ """
7
+
8
+ import warnings
9
+
10
+ import numpy as np
11
+ import scipy
12
+ import sklearn.metrics.pairwise
13
+ import torch
14
+ from sklearn.cluster._kmeans import k_means
15
+ from sklearn.preprocessing import normalize
16
+
17
+
18
+ def _get_device() -> torch.device:
19
+ """Get best available device for inference."""
20
+ if torch.cuda.is_available():
21
+ return torch.device("cuda")
22
+ if torch.backends.mps.is_available():
23
+ return torch.device("mps")
24
+ return torch.device("cpu")
25
+
26
+
27
+ class SpectralCluster:
28
+ """Spectral clustering using unnormalized Laplacian of affinity matrix.
29
+
30
+ Adapted from FunASR/3D-Speaker and SpeechBrain implementations.
31
+ Uses eigenvalue gap to automatically determine number of speakers.
32
+ """
33
+
34
+ def __init__(self, min_num_spks: int = 1, max_num_spks: int = 15, pval: float = 0.06):
35
+ self.min_num_spks = min_num_spks
36
+ self.max_num_spks = max_num_spks
37
+ self.pval = pval
38
+
39
+ def __call__(self, embeddings: np.ndarray, oracle_num: int | None = None) -> np.ndarray:
40
+ """Run spectral clustering on embeddings.
41
+
42
+ Args:
43
+ embeddings: Speaker embeddings of shape [N, D]
44
+ oracle_num: Optional known number of speakers
45
+
46
+ Returns:
47
+ Cluster labels of shape [N]
48
+ """
49
+ # Similarity matrix computation
50
+ sim_mat = self.get_sim_mat(embeddings)
51
+
52
+ # Refining similarity matrix with pval
53
+ prunned_sim_mat = self.p_pruning(sim_mat)
54
+
55
+ # Symmetrization
56
+ sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T)
57
+
58
+ # Laplacian calculation
59
+ laplacian = self.get_laplacian(sym_prund_sim_mat)
60
+
61
+ # Get Spectral Embeddings
62
+ emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num)
63
+
64
+ # Perform clustering
65
+ return self.cluster_embs(emb, num_of_spk)
66
+
67
+ def get_sim_mat(self, embeddings: np.ndarray) -> np.ndarray:
68
+ """Compute cosine similarity matrix."""
69
+ return sklearn.metrics.pairwise.cosine_similarity(embeddings, embeddings)
70
+
71
+ def p_pruning(self, affinity: np.ndarray) -> np.ndarray:
72
+ """Prune low similarity values in affinity matrix (keep top pval fraction)."""
73
+ n = affinity.shape[0]
74
+ pval = max(self.pval, 6.0 / n)
75
+ k_keep = max(1, int(pval * n))
76
+
77
+ # Vectorized: find top-k indices per row and zero out the rest
78
+ top_k_idx = np.argpartition(affinity, -k_keep, axis=1)[:, -k_keep:]
79
+ mask = np.zeros_like(affinity, dtype=bool)
80
+ np.put_along_axis(mask, top_k_idx, True, axis=1)
81
+ affinity[~mask] = 0
82
+ return affinity
83
+
84
+ def get_laplacian(self, sim_mat: np.ndarray) -> np.ndarray:
85
+ """Compute unnormalized Laplacian matrix."""
86
+ from scipy.sparse.csgraph import laplacian
87
+
88
+ np.fill_diagonal(sim_mat, 0)
89
+ return laplacian(sim_mat, normed=False)
90
+
91
+ def get_spec_embs(
92
+ self, laplacian: np.ndarray, k_oracle: int | None = None
93
+ ) -> tuple[np.ndarray, int]:
94
+ """Extract spectral embeddings from Laplacian."""
95
+ lambdas, eig_vecs = scipy.linalg.eigh(laplacian)
96
+
97
+ if k_oracle is not None:
98
+ num_of_spk = k_oracle
99
+ else:
100
+ lambda_gap_list = self.get_eigen_gaps(
101
+ lambdas[self.min_num_spks - 1 : self.max_num_spks + 1]
102
+ )
103
+ num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks
104
+
105
+ emb = eig_vecs[:, :num_of_spk]
106
+ return emb, num_of_spk
107
+
108
+ def cluster_embs(self, emb: np.ndarray, k: int) -> np.ndarray:
109
+ """Cluster spectral embeddings using k-means."""
110
+ _, labels, _ = k_means(emb, k, n_init=10)
111
+ return labels
112
+
113
+ def get_eigen_gaps(self, eig_vals: np.ndarray) -> np.ndarray:
114
+ """Compute gaps between consecutive eigenvalues."""
115
+ return np.diff(eig_vals)
116
+
117
+
118
+ class SpeakerClusterer:
119
+ """Speaker clustering backend using spectral clustering with speaker merging.
120
+
121
+ Features:
122
+ - Spectral clustering with eigenvalue gap for auto speaker count detection
123
+ - P-pruning for affinity matrix refinement
124
+ - Post-clustering speaker merging by cosine similarity
125
+ """
126
+
127
+ def __init__(
128
+ self,
129
+ min_num_spks: int = 2,
130
+ max_num_spks: int = 10,
131
+ merge_thr: float = 0.90, # Moderate merging
132
+ ):
133
+ self.min_num_spks = min_num_spks
134
+ self.max_num_spks = max_num_spks
135
+ self.merge_thr = merge_thr
136
+ self._spectral_cluster: SpectralCluster | None = None
137
+
138
+ def _get_spectral_cluster(self) -> SpectralCluster:
139
+ """Lazy-load spectral clusterer."""
140
+ if self._spectral_cluster is None:
141
+ self._spectral_cluster = SpectralCluster(
142
+ min_num_spks=self.min_num_spks,
143
+ max_num_spks=self.max_num_spks,
144
+ )
145
+ return self._spectral_cluster
146
+
147
+ def __call__(self, embeddings: np.ndarray, num_speakers: int | None = None) -> np.ndarray:
148
+ """Cluster speaker embeddings and return labels.
149
+
150
+ Args:
151
+ embeddings: Speaker embeddings of shape [N, D]
152
+ num_speakers: Optional oracle number of speakers
153
+
154
+ Returns:
155
+ Cluster labels of shape [N]
156
+ """
157
+ if len(embeddings.shape) != 2:
158
+ raise ValueError(f"Expected 2D array, got shape {embeddings.shape}")
159
+
160
+ # Handle edge cases
161
+ if embeddings.shape[0] == 0:
162
+ return np.array([], dtype=int)
163
+ if embeddings.shape[0] == 1:
164
+ return np.array([0], dtype=int)
165
+ if embeddings.shape[0] < 6:
166
+ return np.zeros(embeddings.shape[0], dtype=int)
167
+
168
+ # Normalize embeddings and replace NaN/inf
169
+ embeddings = np.nan_to_num(embeddings, nan=0.0, posinf=0.0, neginf=0.0)
170
+ embeddings = normalize(embeddings)
171
+
172
+ # Run spectral clustering (suppress numerical warnings)
173
+ spectral = self._get_spectral_cluster()
174
+
175
+ # Update min/max for oracle case
176
+ if num_speakers is not None:
177
+ spectral.min_num_spks = num_speakers
178
+ spectral.max_num_spks = num_speakers
179
+
180
+ with warnings.catch_warnings():
181
+ warnings.filterwarnings("ignore", category=RuntimeWarning)
182
+ labels = spectral(embeddings, oracle_num=num_speakers)
183
+
184
+ # Reset min/max
185
+ if num_speakers is not None:
186
+ spectral.min_num_spks = self.min_num_spks
187
+ spectral.max_num_spks = self.max_num_spks
188
+
189
+ # Merge similar speakers if no oracle
190
+ if num_speakers is None:
191
+ labels = self._merge_by_cos(labels, embeddings, self.merge_thr)
192
+
193
+ # Re-index labels sequentially
194
+ _, labels = np.unique(labels, return_inverse=True)
195
+
196
+ return labels
197
+
198
+ def _merge_by_cos(self, labels: np.ndarray, embs: np.ndarray, cos_thr: float) -> np.ndarray:
199
+ """Merge similar speakers by cosine similarity of centroids."""
200
+ from scipy.cluster.hierarchy import fcluster, linkage
201
+ from scipy.spatial.distance import pdist
202
+
203
+ unique_labels = np.unique(labels)
204
+ if len(unique_labels) <= 1:
205
+ return labels
206
+
207
+ # Compute normalized speaker centroids
208
+ centroids = np.array([embs[labels == lbl].mean(0) for lbl in unique_labels])
209
+ centroids = normalize(centroids)
210
+
211
+ # Hierarchical clustering with cosine distance
212
+ distances = pdist(centroids, metric="cosine")
213
+ linkage_matrix = linkage(distances, method="average")
214
+ merged_labels = fcluster(linkage_matrix, t=1.0 - cos_thr, criterion="distance") - 1
215
+
216
+ # Map original labels to merged labels
217
+ label_map = dict(zip(unique_labels, merged_labels))
218
+ return np.array([label_map[lbl] for lbl in labels])
219
+
220
+
221
+ class LocalSpeakerDiarizer:
222
+ """Local speaker diarization using TEN-VAD + ECAPA-TDNN + spectral clustering.
223
+
224
+ Pipeline:
225
+ 1. TEN-VAD detects speech segments
226
+ 2. Sliding window (1.0s, 75% overlap) for uniform embedding extraction
227
+ 3. ECAPA-TDNN extracts speaker embeddings per window
228
+ 4. Spectral clustering with eigenvalue gap for auto speaker detection
229
+ 5. Frame-level consensus voting for segment reconstruction
230
+ 6. Post-processing merges short segments to reduce flicker
231
+
232
+ Tunable Parameters (class attributes):
233
+ - WINDOW_SIZE: Embedding extraction window size in seconds
234
+ - STEP_SIZE: Sliding window step size (overlap = WINDOW_SIZE - STEP_SIZE)
235
+ - VAD_THRESHOLD: Speech detection threshold (lower = more sensitive)
236
+ - VAD_MIN_DURATION: Minimum speech segment duration
237
+ - VAD_MAX_GAP: Maximum gap to bridge between segments
238
+ - VAD_PAD_ONSET/OFFSET: Padding added to speech segments
239
+ - VOTING_RATE: Frame resolution for consensus voting
240
+ - MIN_SEGMENT_DURATION: Minimum final segment duration
241
+ - SAME_SPEAKER_GAP: Maximum gap to merge same-speaker segments
242
+ - TAIL_COVERAGE_RATIO: Minimum tail coverage to add extra window
243
+ """
244
+
245
+ _ten_vad_model = None
246
+ _ecapa_model = None
247
+ _device = None
248
+
249
+ # ==================== TUNABLE PARAMETERS ====================
250
+
251
+ # Sliding window for embedding extraction
252
+ WINDOW_SIZE = 0.75 # seconds - shorter window for finer resolution
253
+ STEP_SIZE = 0.15 # seconds (80% overlap for more votes)
254
+ TAIL_COVERAGE_RATIO = 0.1 # Add extra window if tail > this ratio of window
255
+
256
+ # VAD hysteresis parameters
257
+ VAD_THRESHOLD = 0.25 # Balanced threshold
258
+ VAD_MIN_DURATION = 0.05 # Minimum speech segment duration (seconds)
259
+ VAD_MAX_GAP = 0.50 # Bridge gaps shorter than this (seconds)
260
+ VAD_PAD_ONSET = 0.05 # Padding at segment start (seconds)
261
+ VAD_PAD_OFFSET = 0.05 # Padding at segment end (seconds)
262
+
263
+ # Frame-level voting
264
+ VOTING_RATE = 0.01 # 10ms resolution for consensus voting
265
+
266
+ # Post-processing
267
+ MIN_SEGMENT_DURATION = 0.15 # Minimum final segment duration (seconds)
268
+ SHORT_SEGMENT_GAP = 0.1 # Gap threshold for merging short segments
269
+ SAME_SPEAKER_GAP = 0.5 # Gap threshold for merging same-speaker segments
270
+
271
+ # ===========================================================
272
+
273
+ @classmethod
274
+ def _get_ten_vad_model(cls):
275
+ """Lazy-load TEN-VAD model (singleton)."""
276
+ if cls._ten_vad_model is None:
277
+ from ten_vad import TenVad
278
+
279
+ cls._ten_vad_model = TenVad(hop_size=256, threshold=cls.VAD_THRESHOLD)
280
+ return cls._ten_vad_model
281
+
282
+ @classmethod
283
+ def _get_device(cls) -> torch.device:
284
+ """Get the best available device."""
285
+ if cls._device is None:
286
+ cls._device = _get_device()
287
+ return cls._device
288
+
289
+ @classmethod
290
+ def _get_ecapa_model(cls):
291
+ """Lazy-load ECAPA-TDNN speaker embedding model (singleton)."""
292
+ if cls._ecapa_model is None:
293
+ # Suppress torchaudio deprecation warning from SpeechBrain
294
+ with warnings.catch_warnings():
295
+ warnings.filterwarnings("ignore", message="torchaudio._backend")
296
+ from speechbrain.inference.speaker import EncoderClassifier
297
+
298
+ device = cls._get_device()
299
+ cls._ecapa_model = EncoderClassifier.from_hparams(
300
+ source="speechbrain/spkrec-ecapa-voxceleb",
301
+ run_opts={"device": str(device)},
302
+ )
303
+
304
+ return cls._ecapa_model
305
+
306
+ @classmethod
307
+ def diarize(
308
+ cls,
309
+ audio: np.ndarray | str,
310
+ sample_rate: int = 16000,
311
+ num_speakers: int | None = None,
312
+ min_speakers: int = 2,
313
+ max_speakers: int = 10,
314
+ **_kwargs,
315
+ ) -> list[dict]:
316
+ """Run speaker diarization on audio.
317
+
318
+ Args:
319
+ audio: Audio waveform as numpy array or path to audio file
320
+ sample_rate: Audio sample rate (default 16000)
321
+ num_speakers: Exact number of speakers (if known)
322
+ min_speakers: Minimum number of speakers
323
+ max_speakers: Maximum number of speakers
324
+
325
+ Returns:
326
+ List of dicts with 'speaker', 'start', 'end' keys
327
+ """
328
+ # Handle file path input
329
+ if isinstance(audio, str):
330
+ import librosa
331
+
332
+ audio, sample_rate = librosa.load(audio, sr=16000)
333
+
334
+ # Ensure correct sample rate
335
+ if sample_rate != 16000:
336
+ import librosa
337
+
338
+ audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
339
+ sample_rate = 16000
340
+
341
+ audio = audio.astype(np.float32)
342
+ total_duration = len(audio) / sample_rate
343
+
344
+ # Step 1: VAD (returns segments and raw frame-level decisions)
345
+ segments, vad_frames = cls._get_speech_segments(audio, sample_rate)
346
+ if not segments:
347
+ return []
348
+
349
+ # Step 2: Extract embeddings
350
+ embeddings, window_segments = cls._extract_embeddings(audio, segments, sample_rate)
351
+ if len(embeddings) == 0:
352
+ return []
353
+
354
+ # Step 3: Cluster
355
+ clusterer = SpeakerClusterer(min_num_spks=min_speakers, max_num_spks=max_speakers)
356
+ labels = clusterer(embeddings, num_speakers)
357
+
358
+ # Step 4: Post-process with consensus voting (VAD-aware)
359
+ return cls._postprocess_segments(window_segments, labels, total_duration, vad_frames)
360
+
361
+ @classmethod
362
+ def _get_speech_segments(
363
+ cls, audio_array: np.ndarray, sample_rate: int = 16000
364
+ ) -> tuple[list[dict], list[bool]]:
365
+ """Get speech segments using TEN-VAD.
366
+
367
+ Returns:
368
+ Tuple of (segments list, vad_frames list of per-frame speech decisions)
369
+ """
370
+ vad_model = cls._get_ten_vad_model()
371
+
372
+ # Convert to int16 as required by TEN-VAD
373
+ # Clip to prevent integer overflow
374
+ if audio_array.dtype != np.int16:
375
+ audio_int16 = (np.clip(audio_array, -1.0, 1.0) * 32767).astype(np.int16)
376
+ else:
377
+ audio_int16 = audio_array
378
+
379
+ # Process frame by frame
380
+ hop_size = 256
381
+ frame_duration = hop_size / sample_rate
382
+ speech_frames: list[bool] = []
383
+
384
+ for i in range(0, len(audio_int16) - hop_size, hop_size):
385
+ frame = audio_int16[i : i + hop_size]
386
+ _, is_speech = vad_model.process(frame)
387
+ speech_frames.append(is_speech)
388
+
389
+ # Convert frame-level decisions to segments
390
+ segments = []
391
+ in_speech = False
392
+ start_idx = 0
393
+
394
+ for i, is_speech in enumerate(speech_frames):
395
+ if is_speech and not in_speech:
396
+ start_idx = i
397
+ in_speech = True
398
+ elif not is_speech and in_speech:
399
+ start_time = start_idx * frame_duration
400
+ end_time = i * frame_duration
401
+ segments.append(
402
+ {
403
+ "start": start_time,
404
+ "end": end_time,
405
+ "start_sample": int(start_time * sample_rate),
406
+ "end_sample": int(end_time * sample_rate),
407
+ }
408
+ )
409
+ in_speech = False
410
+
411
+ # Handle trailing speech
412
+ if in_speech:
413
+ start_time = start_idx * frame_duration
414
+ end_time = len(speech_frames) * frame_duration
415
+ segments.append(
416
+ {
417
+ "start": start_time,
418
+ "end": end_time,
419
+ "start_sample": int(start_time * sample_rate),
420
+ "end_sample": int(end_time * sample_rate),
421
+ }
422
+ )
423
+
424
+ return cls._apply_vad_hysteresis(segments, sample_rate), speech_frames
425
+
426
+ @classmethod
427
+ def _apply_vad_hysteresis(cls, segments: list[dict], sample_rate: int = 16000) -> list[dict]:
428
+ """Apply hysteresis-like post-processing to VAD segments."""
429
+ if not segments:
430
+ return segments
431
+
432
+ segments = sorted(segments, key=lambda x: x["start"])
433
+
434
+ # Fill short gaps
435
+ merged = [segments[0].copy()]
436
+ for seg in segments[1:]:
437
+ gap = seg["start"] - merged[-1]["end"]
438
+ if gap <= cls.VAD_MAX_GAP:
439
+ merged[-1]["end"] = seg["end"]
440
+ merged[-1]["end_sample"] = seg["end_sample"]
441
+ else:
442
+ merged.append(seg.copy())
443
+
444
+ # Remove short segments
445
+ filtered = [seg for seg in merged if (seg["end"] - seg["start"]) >= cls.VAD_MIN_DURATION]
446
+
447
+ # Dilate segments (add padding)
448
+ for seg in filtered:
449
+ seg["start"] = max(0.0, seg["start"] - cls.VAD_PAD_ONSET)
450
+ seg["end"] = seg["end"] + cls.VAD_PAD_OFFSET
451
+ seg["start_sample"] = int(seg["start"] * sample_rate)
452
+ seg["end_sample"] = int(seg["end"] * sample_rate)
453
+
454
+ return filtered
455
+
456
+ @classmethod
457
+ def _extract_embeddings(
458
+ cls, audio_array: np.ndarray, segments: list[dict], sample_rate: int
459
+ ) -> tuple[np.ndarray, list[dict]]:
460
+ """Extract speaker embeddings using sliding windows."""
461
+ speaker_model = cls._get_ecapa_model()
462
+
463
+ window_samples = int(cls.WINDOW_SIZE * sample_rate)
464
+ step_samples = int(cls.STEP_SIZE * sample_rate)
465
+
466
+ embeddings = []
467
+ window_segments = []
468
+
469
+ with torch.no_grad():
470
+ for seg in segments:
471
+ seg_start = seg["start_sample"]
472
+ seg_end = seg["end_sample"]
473
+ seg_len = seg_end - seg_start
474
+
475
+ # Generate window positions
476
+ if seg_len <= window_samples:
477
+ starts = [seg_start]
478
+ ends = [seg_end]
479
+ else:
480
+ starts = list(range(seg_start, seg_end - window_samples + 1, step_samples))
481
+ ends = [s + window_samples for s in starts]
482
+
483
+ # Cover tail if > TAIL_COVERAGE_RATIO of window remains
484
+ if ends and ends[-1] < seg_end:
485
+ remainder = seg_end - ends[-1]
486
+ if remainder > (window_samples * cls.TAIL_COVERAGE_RATIO):
487
+ starts.append(seg_end - window_samples)
488
+ ends.append(seg_end)
489
+
490
+ for c_start, c_end in zip(starts, ends):
491
+ chunk = audio_array[c_start:c_end]
492
+
493
+ # Pad short chunks with reflection
494
+ if len(chunk) < window_samples:
495
+ pad_width = window_samples - len(chunk)
496
+ chunk = np.pad(chunk, (0, pad_width), mode="reflect")
497
+
498
+ # Extract embedding using SpeechBrain's encode_batch
499
+ chunk_tensor = torch.from_numpy(chunk).float().unsqueeze(0)
500
+ embedding = (
501
+ speaker_model.encode_batch(chunk_tensor).squeeze(0).squeeze(0).cpu().numpy()
502
+ )
503
+
504
+ # Validate embedding
505
+ if np.isfinite(embedding).all() and np.linalg.norm(embedding) > 1e-8:
506
+ embeddings.append(embedding)
507
+ window_segments.append(
508
+ {
509
+ "start": c_start / sample_rate,
510
+ "end": c_end / sample_rate,
511
+ }
512
+ )
513
+
514
+ # Normalize all embeddings at once
515
+ if embeddings:
516
+ return normalize(np.array(embeddings)), window_segments
517
+ return np.array([]), []
518
+
519
+ @classmethod
520
+ def _resample_vad(cls, vad_frames: list[bool], num_frames: int) -> np.ndarray:
521
+ """Resample VAD frame decisions to match voting grid resolution.
522
+
523
+ VAD operates at 256 samples / 16000 Hz = 16ms per frame.
524
+ Voting operates at VOTING_RATE (default 10ms) per frame.
525
+ This maps VAD decisions to the finer voting grid.
526
+ """
527
+ if not vad_frames:
528
+ return np.zeros(num_frames, dtype=bool)
529
+
530
+ vad_rate = 256 / 16000 # 16ms per VAD frame
531
+ vad_arr = np.array(vad_frames)
532
+
533
+ # Vectorized: compute VAD frame indices for each voting frame
534
+ voting_times = np.arange(num_frames) * cls.VOTING_RATE
535
+ vad_indices = np.clip((voting_times / vad_rate).astype(int), 0, len(vad_arr) - 1)
536
+ return vad_arr[vad_indices]
537
+
538
+ @classmethod
539
+ def _postprocess_segments(
540
+ cls,
541
+ window_segments: list[dict],
542
+ labels: np.ndarray,
543
+ total_duration: float,
544
+ vad_frames: list[bool],
545
+ ) -> list[dict]:
546
+ """Post-process using frame-level consensus voting with VAD-aware silence."""
547
+ if not window_segments or len(labels) == 0:
548
+ return []
549
+
550
+ # Correct labels to be contiguous
551
+ unique_labels = np.unique(labels)
552
+ label_map = {old: new for new, old in enumerate(unique_labels)}
553
+ clean_labels = np.array([label_map[lbl] for lbl in labels])
554
+ num_speakers = len(unique_labels)
555
+
556
+ if num_speakers == 0:
557
+ return []
558
+
559
+ # Create voting grid
560
+ num_frames = int(np.ceil(total_duration / cls.VOTING_RATE)) + 1
561
+ votes = np.zeros((num_frames, num_speakers), dtype=np.float32)
562
+
563
+ # Accumulate votes
564
+ for win, label in zip(window_segments, clean_labels):
565
+ start_frame = int(win["start"] / cls.VOTING_RATE)
566
+ end_frame = int(win["end"] / cls.VOTING_RATE)
567
+ end_frame = min(end_frame, num_frames)
568
+ if start_frame < end_frame:
569
+ votes[start_frame:end_frame, label] += 1.0
570
+
571
+ # Determine winner per frame
572
+ frame_speakers = np.argmax(votes, axis=1)
573
+ max_votes = np.max(votes, axis=1)
574
+
575
+ # Resample VAD to voting grid resolution for silence-aware voting
576
+ vad_resampled = cls._resample_vad(vad_frames, num_frames)
577
+
578
+ # Convert frames to segments
579
+ final_segments = []
580
+ current_speaker = -1
581
+ seg_start = 0.0
582
+
583
+ for f in range(num_frames):
584
+ speaker = int(frame_speakers[f])
585
+ score = max_votes[f]
586
+
587
+ # Force silence if VAD says no speech OR no votes
588
+ if score == 0 or not vad_resampled[f]:
589
+ speaker = -1
590
+
591
+ if speaker != current_speaker:
592
+ if current_speaker != -1:
593
+ final_segments.append(
594
+ {
595
+ "speaker": f"SPEAKER_{current_speaker}",
596
+ "start": seg_start,
597
+ "end": f * cls.VOTING_RATE,
598
+ }
599
+ )
600
+ current_speaker = speaker
601
+ seg_start = f * cls.VOTING_RATE
602
+
603
+ # Close last segment
604
+ if current_speaker != -1:
605
+ final_segments.append(
606
+ {
607
+ "speaker": f"SPEAKER_{current_speaker}",
608
+ "start": seg_start,
609
+ "end": num_frames * cls.VOTING_RATE,
610
+ }
611
+ )
612
+
613
+ return cls._merge_short_segments(final_segments)
614
+
615
+ @classmethod
616
+ def _merge_short_segments(cls, segments: list[dict]) -> list[dict]:
617
+ """Merge short segments to reduce flicker."""
618
+ if not segments:
619
+ return []
620
+
621
+ clean: list[dict] = []
622
+ for seg in segments:
623
+ dur = seg["end"] - seg["start"]
624
+ if dur < cls.MIN_SEGMENT_DURATION:
625
+ if (
626
+ clean
627
+ and clean[-1]["speaker"] == seg["speaker"]
628
+ and seg["start"] - clean[-1]["end"] < cls.SHORT_SEGMENT_GAP
629
+ ):
630
+ clean[-1]["end"] = seg["end"]
631
+ continue
632
+
633
+ if (
634
+ clean
635
+ and clean[-1]["speaker"] == seg["speaker"]
636
+ and seg["start"] - clean[-1]["end"] < cls.SAME_SPEAKER_GAP
637
+ ):
638
+ clean[-1]["end"] = seg["end"]
639
+ else:
640
+ clean.append(seg)
641
+
642
+ return clean
643
+
644
+ @classmethod
645
+ def assign_speakers_to_words(
646
+ cls,
647
+ words: list[dict],
648
+ speaker_segments: list[dict],
649
+ ) -> list[dict]:
650
+ """Assign speaker labels to words based on timestamp overlap.
651
+
652
+ Args:
653
+ words: List of word dicts with 'word', 'start', 'end' keys
654
+ speaker_segments: List of speaker dicts with 'speaker', 'start', 'end' keys
655
+
656
+ Returns:
657
+ Words list with 'speaker' key added to each word
658
+ """
659
+ for word in words:
660
+ word_mid = (word["start"] + word["end"]) / 2
661
+
662
+ # Find the speaker segment that contains this word's midpoint
663
+ best_speaker = None
664
+ for seg in speaker_segments:
665
+ if seg["start"] <= word_mid <= seg["end"]:
666
+ best_speaker = seg["speaker"]
667
+ break
668
+
669
+ # If no exact match, find closest segment
670
+ if best_speaker is None and speaker_segments:
671
+ min_dist = float("inf")
672
+ for seg in speaker_segments:
673
+ seg_mid = (seg["start"] + seg["end"]) / 2
674
+ dist = abs(word_mid - seg_mid)
675
+ if dist < min_dist:
676
+ min_dist = dist
677
+ best_speaker = seg["speaker"]
678
+
679
+ word["speaker"] = best_speaker
680
+
681
+ return words
682
+
683
+
684
+ class SpeakerDiarizer:
685
+ """Speaker diarization using TEN-VAD + ECAPA-TDNN + spectral clustering.
686
+
687
+ Example:
688
+ >>> segments = SpeakerDiarizer.diarize(audio_array)
689
+ >>> for seg in segments:
690
+ ... print(f"{seg['speaker']}: {seg['start']:.2f} - {seg['end']:.2f}")
691
+ """
692
+
693
+ @classmethod
694
+ def diarize(
695
+ cls,
696
+ audio: np.ndarray | str,
697
+ sample_rate: int = 16000,
698
+ num_speakers: int | None = None,
699
+ min_speakers: int | None = None,
700
+ max_speakers: int | None = None,
701
+ **_kwargs,
702
+ ) -> list[dict]:
703
+ """Run speaker diarization on audio.
704
+
705
+ Args:
706
+ audio: Audio waveform as numpy array or path to audio file
707
+ sample_rate: Audio sample rate (default 16000)
708
+ num_speakers: Exact number of speakers (if known)
709
+ min_speakers: Minimum number of speakers
710
+ max_speakers: Maximum number of speakers
711
+
712
+ Returns:
713
+ List of dicts with 'speaker', 'start', 'end' keys
714
+ """
715
+ return LocalSpeakerDiarizer.diarize(
716
+ audio,
717
+ sample_rate=sample_rate,
718
+ num_speakers=num_speakers,
719
+ min_speakers=min_speakers or 2,
720
+ max_speakers=max_speakers or 10,
721
+ )
722
+
723
+ @classmethod
724
+ def assign_speakers_to_words(
725
+ cls,
726
+ words: list[dict],
727
+ speaker_segments: list[dict],
728
+ ) -> list[dict]:
729
+ """Assign speaker labels to words based on timestamp overlap."""
730
+ return LocalSpeakerDiarizer.assign_speakers_to_words(words, speaker_segments)
preprocessor_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "chunk_length": 30,
3
+ "dither": 0.0,
4
+ "feature_extractor_type": "WhisperFeatureExtractor",
5
+ "feature_size": 128,
6
+ "hop_length": 160,
7
+ "n_fft": 400,
8
+ "n_samples": 480000,
9
+ "nb_max_frames": 3000,
10
+ "padding": false,
11
+ "padding_side": "right",
12
+ "padding_value": 0.0,
13
+ "return_attention_mask": false,
14
+ "sampling_rate": 16000,
15
+ "processor_class": "ASRProcessor",
16
+ "auto_map": {
17
+ "AutoProcessor": "asr_processing.ASRProcessor"
18
+ }
19
+ }
projectors.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio projector modules for bridging encoder and decoder embeddings.
2
+
3
+ This module contains all projector architectures:
4
+ - MLPAudioProjector: Simple 2-layer MLP with frame stacking downsampling
5
+ - MOSAProjector: MOSA-style dense mixture of experts
6
+ - SharedMoEAudioProjector: Shared expert + sparse routed experts
7
+ - QFormerAudioProjector: BLIP-2 QFormer with learnable queries (Granite-style)
8
+ """
9
+
10
+ import math
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F # noqa: N812
15
+ from transformers import AutoModel, Blip2QFormerConfig
16
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm
17
+
18
+ # =============================================================================
19
+ # MLP Projector
20
+ # =============================================================================
21
+
22
+
23
+ class MLPAudioProjector(nn.Module):
24
+ """2-layer MLP projector with frame-stacking downsampling (matches GLM-ASR).
25
+
26
+ Both RMSNorms use LlamaRMSNorm's default weight=1.0 init. A prior version
27
+ initialized both to 0.029 (Qwen3-0.6B's embed_tokens RMS) to put projector
28
+ outputs at residual-stream scale on step 1. Empirically, after training the
29
+ model drifted both norms back to ~1.0 (norm) and ~1.2 (norm_2) — the small
30
+ init wasted compute on a 35× scale-correction phase the optimizer would
31
+ have skipped from default init.
32
+ """
33
+
34
+ def __init__(self, config):
35
+ """Initialize MLP projector.
36
+
37
+ Args:
38
+ config: ASRConfig with encoder_dim, llm_dim, projector_pool_stride
39
+ """
40
+ super().__init__()
41
+
42
+ encoder_dim = getattr(config, "encoder_dim", 768)
43
+ llm_dim = getattr(config, "llm_dim", 2048)
44
+ self.k = getattr(config, "projector_pool_stride", 4)
45
+
46
+ # Frame stacking: concat k adjacent frames then project
47
+ in_dim = encoder_dim * self.k
48
+ # Hidden dim defaults to llm_dim, can be overridden via config
49
+ hidden_dim = getattr(config, "projector_hidden_dim", None) or llm_dim
50
+ self.linear_1 = nn.Linear(in_dim, hidden_dim, bias=False)
51
+ self.norm = LlamaRMSNorm(hidden_dim, eps=1e-6)
52
+ self.act = nn.GELU()
53
+ self.dropout = nn.Dropout(getattr(config, "projector_dropout", 0.0))
54
+ self.linear_2 = nn.Linear(hidden_dim, llm_dim, bias=False)
55
+ self.norm_2 = LlamaRMSNorm(llm_dim, eps=1e-6)
56
+
57
+ def get_output_length(self, input_length: int) -> int:
58
+ """Calculate output sequence length given input length (matches GLM-ASR)."""
59
+ # GLM-ASR formula: (L - merge_factor) // merge_factor + 1
60
+ return (input_length - self.k) // self.k + 1
61
+
62
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
63
+ """Project audio features to LLM embedding space.
64
+
65
+ Args:
66
+ x: Audio encoder output of shape [batch, seq_len, encoder_dim]
67
+
68
+ Returns:
69
+ Projected features of shape [batch, (seq_len - k) // k + 1, llm_dim]
70
+ """
71
+ x = _frame_stack(x, self.k)
72
+ x = self.linear_1(x)
73
+ x = self.norm(x)
74
+ x = self.act(x)
75
+ x = self.dropout(x)
76
+ x = self.linear_2(x)
77
+ return self.norm_2(x)
78
+
79
+
80
+ # =============================================================================
81
+ # MoE Projector (MOSA-style)
82
+ # =============================================================================
83
+
84
+
85
+ def _frame_stack(x: torch.Tensor, k: int) -> torch.Tensor:
86
+ """Stack k adjacent frames along the feature dim.
87
+
88
+ Truncates trailing frames that don't fill a complete k-frame window,
89
+ matching GLM-ASR's `(seq_len - k) // k + 1` formula.
90
+ """
91
+ batch, seq, dim = x.shape
92
+ out_len = (seq - k) // k + 1
93
+ return x[:, : out_len * k, :].reshape(batch, out_len, dim * k)
94
+
95
+
96
+ class SimpleAdapter(nn.Module):
97
+ """Simple 2-layer GELU adapter (from MOSA paper)."""
98
+
99
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
100
+ super().__init__()
101
+ self.fc1 = nn.Linear(input_dim, hidden_dim)
102
+ self.act = nn.GELU()
103
+ self.fc2 = nn.Linear(hidden_dim, output_dim)
104
+
105
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
106
+ return self.fc2(self.act(self.fc1(x)))
107
+
108
+
109
+ class MOSAProjector(nn.Module):
110
+ """MOSA-Base projector: simple 2-layer ReLU router with 4 simple adapters.
111
+
112
+ Based on "MOSA: Mixtures of Simple Adapters" (arXiv:2508.18998).
113
+ Uses softmax gating over all experts (dense MoE) with only cross-entropy loss.
114
+ Uses Conv1d for downsampling (2 layers, stride 2 each = 4x total).
115
+ """
116
+
117
+ ADAPTER_HIDDEN_DIM = 4096
118
+ ROUTER_HIDDEN_DIM = 512
119
+ CONV_KERNEL = 3
120
+ CONV_STRIDE = 2
121
+ CONV_PADDING = 1
122
+
123
+ def __init__(self, config):
124
+ """Initialize MOSA projector.
125
+
126
+ Args:
127
+ config: ASRConfig with encoder_dim, llm_dim, num_experts
128
+ """
129
+ super().__init__()
130
+ self.encoder_dim = getattr(config, "encoder_dim", None) or 1280
131
+ self.llm_dim = getattr(config, "llm_dim", None) or 2048
132
+ self.num_experts = getattr(config, "num_experts", None) or 4 # MOSA-Base uses 4
133
+
134
+ conv_kwargs = {
135
+ "kernel_size": self.CONV_KERNEL,
136
+ "stride": self.CONV_STRIDE,
137
+ "padding": self.CONV_PADDING,
138
+ }
139
+ self.downsampler = nn.Sequential(
140
+ nn.Conv1d(self.encoder_dim, self.encoder_dim, **conv_kwargs),
141
+ nn.GELU(),
142
+ nn.Conv1d(self.encoder_dim, self.llm_dim, **conv_kwargs),
143
+ nn.GELU(),
144
+ )
145
+
146
+ self.router = nn.Sequential(
147
+ nn.Linear(self.llm_dim, self.ROUTER_HIDDEN_DIM),
148
+ nn.ReLU(),
149
+ nn.Linear(self.ROUTER_HIDDEN_DIM, self.num_experts),
150
+ )
151
+
152
+ self.experts = nn.ModuleList(
153
+ [
154
+ SimpleAdapter(self.llm_dim, self.ADAPTER_HIDDEN_DIM, self.llm_dim)
155
+ for _ in range(self.num_experts)
156
+ ]
157
+ )
158
+
159
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
160
+ """Project audio features using mixture of experts.
161
+
162
+ Args:
163
+ x: Audio encoder output of shape [batch, seq_len, encoder_dim]
164
+
165
+ Returns:
166
+ Projected features of shape [batch, out_len, llm_dim]
167
+ """
168
+ x = self.downsampler(x.transpose(1, 2)).transpose(1, 2)
169
+
170
+ routing_weights = F.softmax(self.router(x), dim=-1) # (B, out_len, num_experts)
171
+
172
+ # Accumulate weighted expert outputs without materializing all experts at once.
173
+ output = self.experts[0](x) * routing_weights[..., 0:1]
174
+ for i, expert in enumerate(self.experts[1:], start=1):
175
+ output = output + expert(x) * routing_weights[..., i : i + 1]
176
+ return output
177
+
178
+ def get_output_length(self, input_length: int) -> int:
179
+ """Calculate output sequence length after Conv1d downsampling (4x reduction)."""
180
+ length = input_length
181
+ for _ in range(2):
182
+ length = (length + 2 * self.CONV_PADDING - self.CONV_KERNEL) // self.CONV_STRIDE + 1
183
+ return length
184
+
185
+
186
+ # =============================================================================
187
+ # MoE Projector (Pure PyTorch with Shared Expert)
188
+ # =============================================================================
189
+
190
+
191
+ class MoEAudioProjector(nn.Module):
192
+ """MoE projector with shared expert (DeepSeek-style), pure PyTorch implementation.
193
+
194
+ Uses 4 sparse experts with top-2 routing plus a shared expert that processes all tokens.
195
+ No external dependencies (megablocks removed).
196
+
197
+ Architecture matches main branch: norm → experts(in_dim → hidden → out_dim)
198
+ """
199
+
200
+ def __init__(self, config):
201
+ """Initialize MoE projector.
202
+
203
+ Args:
204
+ config: ASRConfig with encoder_dim, llm_dim, num_experts, num_experts_per_tok
205
+ """
206
+ super().__init__()
207
+
208
+ self.k = getattr(config, "projector_pool_stride", 4)
209
+ self.aux_coef = getattr(config, "router_aux_loss_coef", 0.01)
210
+
211
+ # Stability coefficients
212
+ self.router_z_loss_coef = getattr(
213
+ config, "router_z_loss_coef", 1e-4
214
+ ) # Prevents logit explosion
215
+ self.router_jitter_noise = getattr(
216
+ config, "router_jitter_noise", 0.01
217
+ ) # Prevents expert collapse
218
+
219
+ in_dim = config.encoder_dim * self.k
220
+ out_dim = config.llm_dim
221
+
222
+ # Expert hidden dim (default = output dim)
223
+ hidden_dim = getattr(config, "projector_hidden_dim", None) or out_dim
224
+
225
+ # Number of experts and top-k selection
226
+ self.num_experts = getattr(config, "num_experts", 4)
227
+ self.top_k = getattr(config, "num_experts_per_tok", 2)
228
+
229
+ # A. Normalize stacked input (like main branch SharedMoEBlock)
230
+ self.norm = LlamaRMSNorm(in_dim, eps=1e-6)
231
+
232
+ # B. Router (operates on stacked input)
233
+ self.router = nn.Linear(in_dim, self.num_experts, bias=False)
234
+
235
+ # C. Experts: simple 2-layer MLP (same as MLPAudioProjector)
236
+ self.experts = nn.ModuleList(
237
+ [SimpleAdapter(in_dim, hidden_dim, out_dim) for _ in range(self.num_experts)]
238
+ )
239
+
240
+ # D. Shared Expert (same architecture)
241
+ self.shared_expert = SimpleAdapter(in_dim, hidden_dim, out_dim)
242
+
243
+ # E. Initialize weights for stable training
244
+ self._init_weights()
245
+
246
+ self.last_aux_loss = torch.tensor(0.0)
247
+
248
+ def _init_weights(self):
249
+ """Initialize weights for stable training start."""
250
+ with torch.no_grad():
251
+ # Router: small weights -> uniform probability
252
+ nn.init.normal_(self.router.weight, mean=0.0, std=0.02)
253
+
254
+ # Experts: xavier for fc1, small for fc2 (output)
255
+ for expert in [self.shared_expert, *self.experts]:
256
+ nn.init.xavier_uniform_(expert.fc1.weight)
257
+ nn.init.normal_(expert.fc2.weight, mean=0.0, std=0.01) # Small init
258
+
259
+ def get_output_length(self, input_length: int) -> int:
260
+ """Calculate output sequence length given input length (matches MLP projector)."""
261
+ return (input_length - self.k) // self.k + 1
262
+
263
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
264
+ """Project audio features using shared + sparse MoE.
265
+
266
+ Args:
267
+ x: Audio encoder output of shape [batch, seq_len, encoder_dim]
268
+
269
+ Returns:
270
+ Projected features of shape [batch, out_len, llm_dim]
271
+ """
272
+ x = _frame_stack(x, self.k)
273
+ batch, out_len, _ = x.shape
274
+
275
+ # Normalize stacked input (like main branch SharedMoEBlock)
276
+ x = self.norm(x)
277
+ flat_x = x.view(-1, x.size(-1)) # [tokens, in_dim]
278
+
279
+ # 3. Shared Expert (compute first, creates output tensor)
280
+ output = self.shared_expert(flat_x)
281
+
282
+ # 4. Sparse Experts (in-place add to shared output)
283
+ self.last_aux_loss = self._forward_sparse(flat_x, output)
284
+
285
+ return output.view(batch, out_len, -1)
286
+
287
+ def _forward_sparse(self, x: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
288
+ """Stability-hardened sparse expert dispatch (in-place add to output).
289
+
290
+ Args:
291
+ x: Flattened input of shape [tokens, dim]
292
+ output: Output tensor to add sparse expert results into (in-place)
293
+
294
+ Returns:
295
+ Auxiliary loss tensor
296
+ """
297
+ # A. Router Logic with Jitter
298
+ logits = self.router(x)
299
+
300
+ if self.training and self.router_jitter_noise > 0:
301
+ # Jitter: multiply by uniform noise (1-eps, 1+eps) to shake decision boundary
302
+ # Prevents router from getting stuck on one expert early in training
303
+ noise = torch.empty_like(logits).uniform_(
304
+ 1.0 - self.router_jitter_noise, 1.0 + self.router_jitter_noise
305
+ )
306
+ logits = logits * noise
307
+
308
+ # Force float32 for softmax (bf16/fp16 exponentials can overflow)
309
+ probs = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(x)
310
+
311
+ # B. Top-K Selection
312
+ top_k_weights, top_k_indices = torch.topk(probs, self.top_k, dim=-1)
313
+
314
+ # Normalize weights so they sum to 1.0
315
+ top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-6)
316
+
317
+ # C. Aux Loss + Z-Loss
318
+ aux_loss = torch.tensor(0.0, device=x.device)
319
+
320
+ if self.training:
321
+ # Load balancing loss (batch-size invariant)
322
+ prob_per_expert = probs.mean(0) # [num_experts]
323
+ target = 1.0 / self.num_experts
324
+ balance_loss = (
325
+ self.aux_coef * ((prob_per_expert - target) ** 2).mean() * self.num_experts
326
+ )
327
+
328
+ # Z-loss: penalty on large logits to prevent softmax saturation
329
+ z_loss = self.router_z_loss_coef * torch.logsumexp(logits, dim=-1).pow(2).mean()
330
+
331
+ aux_loss = balance_loss + z_loss
332
+
333
+ # D. Dispatch Loop (in-place add to output)
334
+ for i, expert in enumerate(self.experts):
335
+ # Create boolean mask for tokens that selected Expert 'i'
336
+ mask = top_k_indices == i
337
+
338
+ if mask.any():
339
+ # token_idx = which tokens, k_idx = 1st or 2nd choice
340
+ token_idx, k_idx = torch.where(mask)
341
+
342
+ # Gather inputs and compute
343
+ expert_input = x[token_idx]
344
+ expert_output = expert(expert_input)
345
+
346
+ # Apply routing weight
347
+ weight = top_k_weights[token_idx, k_idx].unsqueeze(-1)
348
+ weighted_output = (expert_output * weight).type_as(output)
349
+
350
+ # Scatter back in-place (index_add_ is atomic and deterministic)
351
+ output.index_add_(0, token_idx, weighted_output)
352
+
353
+ return aux_loss
354
+
355
+ def get_aux_loss(self) -> torch.Tensor:
356
+ """Return auxiliary load balancing loss."""
357
+ return self.last_aux_loss
358
+
359
+
360
+ # =============================================================================
361
+ # QFormer Projector (Granite-style)
362
+ # =============================================================================
363
+
364
+
365
+ class QFormerAudioProjector(nn.Module):
366
+ """
367
+ BLIP-2 QFormer projector with learnable queries.
368
+
369
+ Based on GraniteSpeechEncoderProjector - uses a QFormer model with learnable
370
+ query embeddings to compress and project audio encoder outputs. The audio
371
+ sequence is processed in windows and downsampled via cross-attention.
372
+ """
373
+
374
+ def __init__(self, config):
375
+ """Initialize QFormer projector.
376
+
377
+ Args:
378
+ config: ASRConfig with encoder_dim, llm_dim, qformer_* settings
379
+ """
380
+ super().__init__()
381
+
382
+ encoder_dim = config.encoder_dim
383
+ llm_dim = config.llm_dim
384
+
385
+ # Window and downsampling parameters (Granite defaults: window=15, downsample=5)
386
+ self.window_size = getattr(config, "qformer_window_size", 15)
387
+ self.downsample_rate = getattr(config, "downsample_rate", 5)
388
+ self.num_queries = self.window_size // self.downsample_rate
389
+
390
+ # QFormer hidden size (matches encoder for cross-attention)
391
+ qformer_hidden = getattr(config, "qformer_hidden_size", None) or encoder_dim
392
+ qformer_num_layers = getattr(config, "qformer_num_layers", 2)
393
+ qformer_num_heads = getattr(config, "qformer_num_heads", 16)
394
+ qformer_intermediate = getattr(config, "qformer_intermediate_size", None) or (
395
+ qformer_hidden * 4
396
+ )
397
+
398
+ # Learnable query embeddings (Granite uses std=1.0)
399
+ self.query = nn.Parameter(torch.zeros(1, self.num_queries, qformer_hidden))
400
+ self.query.data.normal_(mean=0.0, std=1.0)
401
+
402
+ # Optional projection if encoder dim != qformer hidden
403
+ if encoder_dim != qformer_hidden:
404
+ self.encoder_proj = nn.Linear(encoder_dim, qformer_hidden, bias=False)
405
+ else:
406
+ self.encoder_proj = None
407
+
408
+ # Configure QFormer to match Granite's exact config
409
+ qformer_config = Blip2QFormerConfig(
410
+ hidden_size=qformer_hidden,
411
+ num_hidden_layers=qformer_num_layers,
412
+ num_attention_heads=qformer_num_heads,
413
+ intermediate_size=qformer_intermediate,
414
+ encoder_hidden_size=qformer_hidden,
415
+ cross_attention_frequency=1,
416
+ # Granite-specific settings
417
+ hidden_act="gelu",
418
+ attention_probs_dropout_prob=0.1,
419
+ hidden_dropout_prob=0.1,
420
+ layer_norm_eps=1e-12,
421
+ initializer_range=0.02,
422
+ )
423
+ self.qformer = AutoModel.from_config(qformer_config)
424
+
425
+ # Final projection to LLM dimension (Granite uses bias=True)
426
+ self.linear = nn.Linear(qformer_hidden, llm_dim)
427
+
428
+ def get_output_length(self, input_length):
429
+ """Calculate output sequence length given input length.
430
+
431
+ Accepts either Python ints or torch tensors; uses ceiling division so
432
+ the formula is identical for both — math.ceil would block tensors.
433
+ """
434
+ nblocks = (input_length + self.window_size - 1) // self.window_size
435
+ return nblocks * self.num_queries
436
+
437
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
438
+ """
439
+ Args:
440
+ hidden_states: [batch_size, seq_len, encoder_dim]
441
+
442
+ Returns:
443
+ projected: [batch_size, num_output_tokens, llm_dim]
444
+ """
445
+ batch_size, seq_len, dim = hidden_states.size()
446
+
447
+ # Ensure float dtype for QFormer
448
+ target_dtype = self.query.dtype
449
+ if hidden_states.dtype != target_dtype:
450
+ hidden_states = hidden_states.to(target_dtype)
451
+
452
+ # Optional encoder projection
453
+ if self.encoder_proj is not None:
454
+ hidden_states = self.encoder_proj(hidden_states)
455
+
456
+ # Compute number of windows and pad to fit
457
+ nblocks = math.ceil(seq_len / self.window_size)
458
+ pad = nblocks * self.window_size - seq_len
459
+ if pad > 0:
460
+ hidden_states = F.pad(hidden_states, (0, 0, 0, pad), "constant", 0)
461
+
462
+ # Reshape to process each window: [batch*nblocks, window_size, dim]
463
+ effective_batch = batch_size * nblocks
464
+ hidden_states = hidden_states.view(effective_batch, self.window_size, -1)
465
+
466
+ # Expand queries to match batch size
467
+ query_embeds = self.query.expand(effective_batch, -1, -1)
468
+
469
+ # QFormer cross-attention
470
+ query_output = self.qformer(
471
+ query_embeds=query_embeds,
472
+ encoder_hidden_states=hidden_states,
473
+ return_dict=True,
474
+ )
475
+
476
+ # Reshape back: [batch, nblocks * num_queries, hidden]
477
+ output_tokens = nblocks * self.num_queries
478
+ query_proj = query_output.last_hidden_state.view(batch_size, output_tokens, -1)
479
+
480
+ # Project to LLM dimension
481
+ return self.linear(query_proj)
482
+
483
+
484
+ # =============================================================================
485
+ # Projector Registry
486
+ # =============================================================================
487
+
488
+ PROJECTOR_CLASSES = {
489
+ "mlp": MLPAudioProjector,
490
+ "mosa": MOSAProjector,
491
+ "moe": MoEAudioProjector,
492
+ "qformer": QFormerAudioProjector,
493
+ }