mazesmazes committed on
Commit
ae41cb4
·
verified ·
1 Parent(s): 527cc10

Model save

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags:
4
+ - generated_from_trainer
5
+ model-index:
6
+ - name: test_s2s_output
7
+ results: []
8
+ ---
9
+
10
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
11
+ should probably proofread and complete it, then remove this comment. -->
12
+
13
+ # test_s2s_output
14
+
15
+ This model is a fine-tuned version of [](https://huggingface.co/) on an unknown dataset.
16
+
17
+ ## Model description
18
+
19
+ More information needed
20
+
21
+ ## Intended uses & limitations
22
+
23
+ More information needed
24
+
25
+ ## Training and evaluation data
26
+
27
+ More information needed
28
+
29
+ ## Training procedure
30
+
31
+ ### Training hyperparameters
32
+
33
+ The following hyperparameters were used during training:
34
+ - learning_rate: 0.0001
35
+ - train_batch_size: 16
36
+ - eval_batch_size: 16
37
+ - seed: 42
38
+ - gradient_accumulation_steps: 2
39
+ - total_train_batch_size: 32
40
+ - optimizer: Use OptimizerNames.ADAMW_TORCH_FUSED with betas=(0.9,0.999) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
41
+ - lr_scheduler_type: polynomial
42
+ - lr_scheduler_warmup_steps: 500
43
+ - training_steps: 5
44
+
45
+ ### Training results
46
+
47
+
48
+
49
+ ### Framework versions
50
+
51
+ - Transformers 5.0.0
52
+ - Pytorch 2.8.0
53
+ - Datasets 3.6.0
54
+ - Tokenizers 0.22.2
alignment.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Forced alignment for word-level timestamps using Wav2Vec2."""
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
# Offset compensation for Wav2Vec2-BASE systematic bias (in seconds).
# Calibrated on librispeech-alignments dataset (n=25, MAE=48ms).
# Both values are *subtracted* from the raw frame-derived times in
# ForcedAligner.align, so a positive value shifts a boundary earlier
# and a negative value shifts it later.
START_OFFSET = 0.04  # Subtract from start times (shift earlier)
END_OFFSET = -0.04  # Subtract from end times (shift later)
10
+
11
+
12
+ def _get_device() -> str:
13
+ """Get best available device for non-transformers models."""
14
+ if torch.cuda.is_available():
15
+ return "cuda"
16
+ if torch.backends.mps.is_available():
17
+ return "mps"
18
+ return "cpu"
19
+
20
+
21
class ForcedAligner:
    """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2.

    Uses Viterbi trellis algorithm for optimal alignment path finding.
    """

    # Singleton state, populated lazily by get_instance() on first use.
    _bundle = None  # torchaudio pipeline bundle (WAV2VEC2_ASR_BASE_960H)
    _model = None  # the wav2vec2 acoustic model, in eval mode
    _labels = None  # output characters, index-aligned with emission columns
    _dictionary = None  # char -> emission column index

    @classmethod
    def get_instance(cls, device: str = "cuda"):
        """Get or create the forced alignment model (singleton).

        Args:
            device: Device to run model on ("cuda" or "cpu")

        Returns:
            Tuple of (model, labels, dictionary)
        """
        if cls._model is None:
            # Deferred import so torchaudio is only required when alignment is used.
            import torchaudio

            cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
            cls._model = cls._bundle.get_model().to(device)
            cls._model.eval()
            cls._labels = cls._bundle.get_labels()
            cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
        return cls._model, cls._labels, cls._dictionary

    @staticmethod
    def _get_trellis(emission: torch.Tensor, tokens: list[int], blank_id: int = 0) -> torch.Tensor:
        """Build trellis for forced alignment using forward algorithm.

        The trellis[t, j] represents the log probability of the best path that
        aligns the first j tokens to the first t frames.

        Args:
            emission: Log-softmax emission matrix of shape (num_frames, num_classes)
            tokens: List of target token indices
            blank_id: Index of the blank/CTC token (default 0)

        Returns:
            Trellis matrix of shape (num_frames + 1, num_tokens + 1)
        """
        num_frames = emission.size(0)
        num_tokens = len(tokens)

        # Everything unreachable starts at -inf; the empty alignment has log-prob 0.
        trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
        trellis[0, 0] = 0

        # Force alignment to use all tokens by preventing staying in blank
        # at the end when there are still tokens to emit.
        # NOTE(review): the +inf sentinel looks suspect — the loop below
        # recomputes trellis[t + 1, j] for every t and j (including column 0),
        # so these entries are overwritten before backtracking reads them.
        # Confirm this guard actually has the intended effect.
        if num_tokens > 1:
            trellis[-num_tokens + 1 :, 0] = float("inf")

        # Pure-Python O(num_frames * num_tokens) DP — fine for short utterances,
        # but a vectorized version would be needed for long audio.
        for t in range(num_frames):
            for j in range(num_tokens + 1):
                # Stay: emit blank and stay at j tokens
                stay = trellis[t, j] + emission[t, blank_id]

                # Move: emit token j and advance to j+1 tokens
                move = trellis[t, j - 1] + emission[t, tokens[j - 1]] if j > 0 else -float("inf")

                trellis[t + 1, j] = max(stay, move)  # Viterbi: take best path

        return trellis

    @staticmethod
    def _backtrack(
        trellis: torch.Tensor, emission: torch.Tensor, tokens: list[int], blank_id: int = 0
    ) -> list[tuple[int, float, float, float]]:
        """Backtrack through trellis to find optimal forced monotonic alignment.

        Guarantees:
        - All tokens are emitted exactly once
        - Strictly monotonic: each token's frames come after previous token's
        - No frame skipping or token teleporting

        Returns list of (token_id, start_frame, end_frame, peak_frame) for each token.
        The peak_frame is the frame with highest emission probability for that token.
        """
        num_frames = emission.size(0)
        num_tokens = len(tokens)

        if num_tokens == 0:
            return []

        # Find the best ending point (should be at num_tokens)
        # But verify trellis reached a valid state
        if trellis[num_frames, num_tokens] == -float("inf"):
            # Alignment failed - fall back to uniform distribution
            # (each token gets an equal slice of the frame axis).
            frames_per_token = num_frames / num_tokens
            return [
                (
                    tokens[i],
                    i * frames_per_token,
                    (i + 1) * frames_per_token,
                    (i + 0.5) * frames_per_token,
                )
                for i in range(num_tokens)
            ]

        # Backtrack: find where each token transition occurred
        # Store (frame, emission_score) for each token
        token_frames: list[list[tuple[int, float]]] = [[] for _ in range(num_tokens)]

        t = num_frames
        j = num_tokens

        while t > 0 and j > 0:
            # Check: did we transition from j-1 to j at frame t-1?
            stay_score = trellis[t - 1, j] + emission[t - 1, blank_id]
            move_score = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

            if move_score >= stay_score:
                # Token j-1 was emitted at frame t-1
                # Store frame and its emission probability
                # (insert(0, ...) keeps each token's frame list in ascending order
                # since we walk time backwards).
                emit_prob = emission[t - 1, tokens[j - 1]].exp().item()
                token_frames[j - 1].insert(0, (t - 1, emit_prob))
                j -= 1
            # Always decrement time (monotonic)
            t -= 1

        # Handle any remaining tokens at the start (edge case):
        # pin them to frame 0 with zero confidence.
        while j > 0:
            token_frames[j - 1].insert(0, (0, 0.0))
            j -= 1

        # Convert to spans with peak frame
        token_spans: list[tuple[int, float, float, float]] = []
        for token_idx, frames_with_scores in enumerate(token_frames):
            if not frames_with_scores:
                # Token never emitted - assign minimal span after previous
                if token_spans:
                    prev_end = token_spans[-1][2]
                    frames_with_scores = [(int(prev_end), 0.0)]
                else:
                    frames_with_scores = [(0, 0.0)]

            token_id = tokens[token_idx]
            frames = [f for f, _ in frames_with_scores]
            start_frame = float(min(frames))
            end_frame = float(max(frames)) + 1.0

            # Find peak frame (highest emission probability)
            peak_frame, _ = max(frames_with_scores, key=lambda x: x[1])

            token_spans.append((token_id, start_frame, end_frame, float(peak_frame)))

        return token_spans

    @classmethod
    def align(
        cls,
        audio: np.ndarray,
        text: str,
        sample_rate: int = 16000,
        _language: str = "eng",
        _batch_size: int = 16,
    ) -> list[dict]:
        """Align transcript to audio and return word-level timestamps.

        Uses Viterbi trellis algorithm for optimal forced alignment.

        Args:
            audio: Audio waveform as numpy array
            text: Transcript text to align
            sample_rate: Audio sample rate (default 16000)
            _language: ISO-639-3 language code (default "eng" for English, unused)
            _batch_size: Batch size for alignment model (unused)

        Returns:
            List of dicts with 'word', 'start', 'end' keys
        """
        import torchaudio

        device = _get_device()
        model, _labels, dictionary = cls.get_instance(device)
        assert cls._bundle is not None and dictionary is not None  # Initialized by get_instance

        # Convert audio to tensor (copy to ensure array is writable)
        if isinstance(audio, np.ndarray):
            waveform = torch.from_numpy(audio.copy()).float()
        else:
            waveform = audio.clone().float()

        # Ensure 2D (channels, time)
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        # Resample if needed (wav2vec2 expects 16kHz)
        if sample_rate != cls._bundle.sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, cls._bundle.sample_rate
            )

        waveform = waveform.to(device)

        # Get emissions from model
        with torch.inference_mode():
            emissions, _ = model(waveform)
            emissions = torch.log_softmax(emissions, dim=-1)

        emission = emissions[0].cpu()

        # Normalize text: uppercase, keep only valid characters
        # (characters outside the label set and not a space are silently dropped).
        transcript = text.upper()

        # Build tokens from transcript (including word separators)
        tokens = []
        for char in transcript:
            if char in dictionary:
                tokens.append(dictionary[char])
            elif char == " ":
                tokens.append(dictionary.get("|", dictionary.get(" ", 0)))

        if not tokens:
            return []

        # Build Viterbi trellis and backtrack for optimal path
        trellis = cls._get_trellis(emission, tokens, blank_id=0)
        alignment_path = cls._backtrack(trellis, emission, tokens, blank_id=0)

        # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
        frame_duration = 320 / cls._bundle.sample_rate

        # Apply separate offset compensation for start/end (Wav2Vec2 systematic bias)
        start_offset = START_OFFSET
        end_offset = END_OFFSET

        # Group aligned tokens into words based on pipe separator
        # Use peak emission frame for more accurate word boundaries.
        # NOTE(review): words come from text.split() while tokens come from the
        # filtered transcript — if a word consists entirely of dropped characters,
        # the word list and separator stream can drift; verify with punctuation-heavy input.
        words = text.split()
        word_timestamps = []
        first_char_peak = None
        last_char_peak = None
        word_idx = 0
        separator_id = dictionary.get("|", dictionary.get(" ", 0))

        for token_id, _start_frame, _end_frame, peak_frame in alignment_path:
            if token_id == separator_id:  # Word separator
                if (
                    first_char_peak is not None
                    and last_char_peak is not None
                    and word_idx < len(words)
                ):
                    # Use peak frames for word boundaries
                    start_time = max(0.0, first_char_peak * frame_duration - start_offset)
                    end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
                    word_timestamps.append(
                        {
                            "word": words[word_idx],
                            "start": start_time,
                            "end": end_time,
                        }
                    )
                    word_idx += 1
                first_char_peak = None
                last_char_peak = None
            else:
                if first_char_peak is None:
                    first_char_peak = peak_frame
                last_char_peak = peak_frame

        # Don't forget the last word
        if first_char_peak is not None and last_char_peak is not None and word_idx < len(words):
            start_time = max(0.0, first_char_peak * frame_duration - start_offset)
            end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
            word_timestamps.append(
                {
                    "word": words[word_idx],
                    "start": start_time,
                    "end": end_time,
                }
            )

        return word_timestamps
asr_config.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import transformers
4
+
5
+
6
class ASRConfig(transformers.PretrainedConfig):
    """Configuration class for the ASR model.

    Composes an audio-encoder sub-config and a text-LLM sub-config, plus
    projector, LoRA/training, and generation settings.
    """

    model_type = "asr_model"
    is_composition = True  # holds audio_config and text_config sub-configs

    # Generation defaults — merged into kwargs in __init__ when the caller
    # did not supply a value, then popped back out onto attributes.
    GENERATION_DEFAULTS = {
        "num_beams": 1,
        "max_new_tokens": 128,
        "min_new_tokens": 0,
        "repetition_penalty": 1.0,
        "length_penalty": 1.0,
        "no_repeat_ngram_size": 0,
        "use_cache": True,
        "do_sample": False,
        "temperature": None,
        "top_p": None,
        "top_k": None,
    }

    def __init__(
        self,
        # Model IDs
        audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
        text_model_id: str = "Qwen/Qwen3-0.6B",
        # Model settings
        attn_implementation: str = "sdpa",
        model_dtype: str = "bfloat16",
        system_prompt: str = "You are a helpful assistant.",
        enable_thinking: bool = False,
        # Encoder settings (auto-detected if None)
        encoder_dim: Optional[int] = None,
        llm_dim: Optional[int] = None,
        encoder_conv_layers: Optional[list] = None,
        audio_sample_rate: int = 16000,
        # Projector settings
        projector_type: str = "mlp",
        projector_pool_stride: int = 4,
        projector_hidden_dim: Optional[int] = None,
        projector_num_layers: int = 2,
        projector_init_std: float = 0.02,
        projector_dropout: float = 0.0,
        # MoE projector settings
        num_experts: int = 4,
        num_experts_per_tok: int = 2,
        router_aux_loss_coef: float = 0.01,
        # QFormer projector settings
        qformer_window_size: int = 15,
        qformer_hidden_size: Optional[int] = None,
        qformer_num_layers: int = 2,
        qformer_num_heads: int = 16,
        qformer_intermediate_size: Optional[int] = None,
        downsample_rate: int = 5,
        # Training settings (not saved to config.json for inference)
        use_specaugment: bool = False,
        num_time_masks: int = 2,
        time_mask_length: int = 10,
        num_freq_masks: int = 0,
        freq_mask_length: int = 10,
        use_lora: bool = False,
        lora_rank: int = 8,
        lora_alpha: int = 32,
        lora_dropout: float = 0.0,
        lora_target_modules: Optional[list] = None,
        freeze_projector: bool = False,
        label_smoothing: float = 0.0,
        # Audio Head settings (flow matching with pocket-tts)
        use_audio_head: bool = False,
        freeze_audio_head: bool = False,  # Freeze entire audio head
        lsd_decode_steps: int = 1,  # LSD decoding integration steps
        flow_temperature: float = 1.0,  # Sampling temperature for flow generation
        pocket_tts_weights: Optional[str] = None,  # Path to pretrained pocket-tts weights
        freeze_flow_net: bool = True,  # Freeze flow_net, only train llm_proj
        **kwargs,
    ):
        """Build the composite config.

        Parameters are grouped as in the signature comments: model IDs,
        encoder/projector/MoE/QFormer geometry, training-only knobs
        (SpecAugment, LoRA, freezing), audio-head (flow-matching) settings,
        and generation parameters via **kwargs (defaults from
        GENERATION_DEFAULTS). Sub-configs ``audio_config``/``text_config``
        may be passed in kwargs (as dicts or config objects); otherwise they
        are fetched from the Hub by model id.
        """
        # Merge generation defaults with kwargs (kwargs takes precedence)
        for key, default in self.GENERATION_DEFAULTS.items():
            if key not in kwargs:
                kwargs[key] = default

        # Core model settings
        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id
        self.attn_implementation = attn_implementation
        self.model_dtype = model_dtype
        self.system_prompt = system_prompt
        self.enable_thinking = enable_thinking

        # Encoder settings
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        # Default conv stack: two (in, kernel, stride) stages — avoids a
        # mutable default argument on the signature.
        self.encoder_conv_layers = encoder_conv_layers or [(1, 3, 1), (1, 3, 2)]
        self.audio_sample_rate = audio_sample_rate

        # Projector settings
        self.projector_type = projector_type
        self.projector_pool_stride = projector_pool_stride
        self.projector_hidden_dim = projector_hidden_dim
        self.projector_num_layers = projector_num_layers
        self.projector_init_std = projector_init_std
        self.projector_dropout = projector_dropout

        # MoE settings
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.router_aux_loss_coef = router_aux_loss_coef

        # QFormer settings
        self.qformer_window_size = qformer_window_size
        self.qformer_hidden_size = qformer_hidden_size
        self.qformer_num_layers = qformer_num_layers
        self.qformer_num_heads = qformer_num_heads
        self.qformer_intermediate_size = qformer_intermediate_size
        self.downsample_rate = downsample_rate

        # Training settings
        self.use_specaugment = use_specaugment
        self.num_time_masks = num_time_masks
        self.time_mask_length = time_mask_length
        self.num_freq_masks = num_freq_masks
        self.freq_mask_length = freq_mask_length
        self.use_lora = use_lora
        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        # Default LoRA targets: all attention and MLP projections.
        self.lora_target_modules = lora_target_modules or [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]
        self.freeze_projector = freeze_projector
        self.label_smoothing = label_smoothing

        # Audio Head settings (flow matching with pocket-tts)
        self.use_audio_head = use_audio_head
        self.freeze_audio_head = freeze_audio_head
        self.lsd_decode_steps = lsd_decode_steps
        self.flow_temperature = flow_temperature
        self.pocket_tts_weights = pocket_tts_weights
        self.freeze_flow_net = freeze_flow_net

        # Generation parameters (from kwargs after merge with defaults).
        # Popping them here keeps them out of the kwargs passed to
        # super().__init__ below.
        self.num_beams = kwargs.pop("num_beams")
        self.max_new_tokens = kwargs.pop("max_new_tokens")
        self.min_new_tokens = kwargs.pop("min_new_tokens")
        self.repetition_penalty = kwargs.pop("repetition_penalty")
        self.length_penalty = kwargs.pop("length_penalty")
        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size")
        self.use_cache = kwargs.pop("use_cache")
        self.do_sample = kwargs.pop("do_sample")
        self.temperature = kwargs.pop("temperature")
        self.top_p = kwargs.pop("top_p")
        self.top_k = kwargs.pop("top_k")

        # Load sub-configs: either passed in kwargs (dict when deserialized
        # from config.json, or a config object), or fetched from the Hub.
        self.audio_config = kwargs.pop("audio_config", None)
        if self.audio_config is None:
            self.audio_config = transformers.AutoConfig.from_pretrained(
                audio_model_id, trust_remote_code=True
            )
            self.audio_config.dtype = model_dtype
        elif isinstance(self.audio_config, dict) and self.audio_config.get("model_type"):
            # Rehydrate a dict back into the concrete config class for its model_type.
            config_class = transformers.AutoConfig.for_model(
                self.audio_config["model_type"]
            ).__class__
            self.audio_config = config_class(**self.audio_config)

        self.text_config = kwargs.pop("text_config", None)
        if self.text_config is None:
            self.text_config = transformers.AutoConfig.from_pretrained(
                text_model_id, trust_remote_code=True
            )
            self.text_config.dtype = model_dtype
        elif isinstance(self.text_config, dict):
            config_class = transformers.AutoConfig.for_model(
                self.text_config["model_type"]
            ).__class__
            self.text_config = config_class(**self.text_config)

        super().__init__(**kwargs)

        # Pipeline configuration — set after super().__init__ so these
        # always reflect this class's wiring regardless of incoming kwargs.
        self.encoder = self.audio_config
        self.auto_map = {
            "AutoConfig": "asr_config.ASRConfig",
            "AutoModel": "asr_modeling.ASRModel",
            "AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
            "AutoProcessor": "asr_processing.ASRProcessor",
        }
        self.custom_pipelines = {
            "automatic-speech-recognition": {
                "impl": "asr_pipeline.ASRPipeline",
                "pt": ["AutoModelForSpeechSeq2Seq"],
                "tf": [],
                "type": "audio",
            }
        }
        self.architectures = ["ASRModel"]
        self.pipeline_tag = "automatic-speech-recognition"
210
+
211
+
212
+ transformers.AutoConfig.register("asr_model", ASRConfig)
asr_modeling.py ADDED
@@ -0,0 +1,1110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+ from threading import Thread
4
+ from typing import Iterator, Optional, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from transformers import (
9
+ AutoConfig,
10
+ AutoModel,
11
+ AutoModelForCausalLM,
12
+ AutoTokenizer,
13
+ PreTrainedModel,
14
+ TextIteratorStreamer,
15
+ )
16
+ from transformers.generation import GenerationMixin
17
+ from transformers.modeling_outputs import CausalLMOutputWithPast
18
+
19
+ try:
20
+ from .asr_config import ASRConfig
21
+ from .projectors import PROJECTOR_CLASSES
22
+ except ImportError:
23
+ from asr_config import ASRConfig # type: ignore[no-redef]
24
+ from projectors import PROJECTOR_CLASSES # type: ignore[no-redef]
25
+
26
+
27
+ from torchaudio.transforms import SpecAugment
28
+
29
+
30
+ class ASRModel(PreTrainedModel, GenerationMixin):
31
+ """Audio-to-text model combining an audio encoder, projector, and language model."""
32
+
33
+ config_class = ASRConfig
34
+ base_model_prefix = "model"
35
+ main_input_name = "input_features"
36
+ _supports_flash_attn_2 = True
37
+ supports_gradient_checkpointing = True
38
+ _is_loading_from_pretrained: bool = False
39
+ _pretrained_model_path: Optional[str] = None
40
+
41
+ TRANSCRIBE_PROMPT = ""
42
+
43
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) -> "ASRModel":
        """Load model from pretrained, handling device placement correctly.

        Builds the model from config (which loads the frozen sub-models),
        then overlays projector weights from ``model.safetensors`` and,
        when ``config.use_lora`` is set, attaches saved or fresh LoRA
        adapters to the language model.
        """
        from safetensors.torch import load_file
        from transformers.utils.hub import cached_file

        config = kwargs.pop("config", None)
        if config is None:
            config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Set flag to avoid device_map="auto" in sub-model loaders.
        # NOTE(review): these are class-level flags, so concurrent loads of
        # two ASRModels are not safe — confirm single-threaded loading.
        cls._is_loading_from_pretrained = True
        cls._pretrained_model_path = pretrained_model_name_or_path

        try:
            model = cls(config, **kwargs)

            # Load projector weights from safetensors
            subfolder = kwargs.get("subfolder")
            revision = kwargs.get("revision")
            cache_kwargs = {}
            if subfolder:
                cache_kwargs["subfolder"] = subfolder
            if revision:
                cache_kwargs["revision"] = revision

            # cached_file returns None (instead of raising) when the file
            # is absent, thanks to _raise_exceptions_for_missing_entries.
            model_file = cached_file(
                pretrained_model_name_or_path,
                "model.safetensors",
                _raise_exceptions_for_missing_entries=False,
                **cache_kwargs,
            )

            if model_file is not None:
                state_dict = load_file(model_file)
                # strict=False: the checkpoint holds only the trainable parts
                # (projector etc.), not the frozen encoder/LLM weights.
                model.load_state_dict(state_dict, strict=False)

            # Load LoRA adapters if use_lora is enabled
            if getattr(config, "use_lora", False):
                # Check for adapter_config.json (required by PEFT to load adapters)
                adapter_config_file = cached_file(
                    pretrained_model_name_or_path,
                    "adapter_config.json",
                    _raise_exceptions_for_missing_entries=False,
                    **cache_kwargs,
                )
                if adapter_config_file is not None:
                    # Load saved adapter weights using the original repo_id/path
                    # PEFT handles Hub downloads and caching internally
                    from peft import PeftModel

                    model.language_model = PeftModel.from_pretrained(
                        model.language_model,
                        pretrained_model_name_or_path,
                        is_trainable=True,
                        **cache_kwargs,
                    )
                else:
                    # No saved adapters - initialize fresh LLM LoRA for training
                    from peft import LoraConfig, get_peft_model

                    lora_config = LoraConfig(
                        r=config.lora_rank,
                        lora_alpha=config.lora_alpha,
                        target_modules=config.lora_target_modules,
                        lora_dropout=config.lora_dropout,
                        bias="none",
                        task_type="CAUSAL_LM",
                    )
                    model.language_model = get_peft_model(model.language_model, lora_config)

            return model
        finally:
            # Always clear the class-level loading flags, even on failure.
            cls._is_loading_from_pretrained = False
            cls._pretrained_model_path = None
118
+
119
    def __init__(self, config: ASRConfig, **kwargs) -> None:
        """Assemble the composite model from config.

        Loads the frozen audio encoder and frozen LLM, wires up the
        tokenizer, generation config, feature extractor, trainable
        projector, and optional LoRA / SpecAugment / audio-head components.
        """
        super().__init__(config)

        self.system_prompt = config.system_prompt
        # e.g. "bfloat16" -> torch.bfloat16
        target_dtype = getattr(torch, config.model_dtype)

        # Audio encoder (frozen)
        self.audio_tower = self._load_audio_encoder(config, target_dtype)

        # Language model (frozen)
        self.language_model = self._load_language_model(config, target_dtype)

        # Initialize tokenizer and special tokens
        self._init_tokenizer(config)

        # Set up generation config with greedy decoding defaults
        self.generation_config = self.language_model.generation_config
        self.generation_config.max_new_tokens = config.max_new_tokens
        self.generation_config.min_new_tokens = config.min_new_tokens
        self.generation_config.num_beams = config.num_beams
        self.generation_config.do_sample = config.do_sample
        # Set sampling params from config (None means use model defaults)
        self.generation_config.temperature = config.temperature
        self.generation_config.top_p = config.top_p
        self.generation_config.top_k = config.top_k
        self.generation_config.use_cache = config.use_cache
        self.generation_config.length_penalty = config.length_penalty
        self.generation_config.repetition_penalty = config.repetition_penalty
        self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
        # Set EOS tokens, filtering out any that don't exist in the tokenizer.
        # NOTE(review): convert_tokens_to_ids typically returns the unk id
        # (not None) for unknown tokens — verify this filter actually drops
        # missing tokens for the tokenizers in use.
        eos_candidates = [
            self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
            self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
        ]
        self.generation_config.eos_token_id = [t for t in eos_candidates if t is not None]
        self.generation_config.pad_token_id = self.tokenizer.pad_token_id

        # Feature extractor for audio preprocessing
        self.feature_extractor = self._create_feature_extractor(config)

        # Audio projector (trainable unless freeze_projector is set)
        self.projector = self._create_projector(config, target_dtype)

        # Setup LoRA if enabled (Stage 2 fine-tuning)
        # Skip if loading from pretrained - from_pretrained will handle adapter loading
        if getattr(config, "use_lora", False) and not getattr(
            self.__class__, "_is_loading_from_pretrained", False
        ):
            self._setup_lora(config)

        # Freeze projector if specified (for Stage 2 LoRA-only training)
        if getattr(config, "freeze_projector", False):
            self.projector.requires_grad_(False)

        # SpecAugment for data augmentation during training
        if getattr(config, "use_specaugment", False):
            self.spec_augment = SpecAugment(
                n_time_masks=config.num_time_masks,
                time_mask_param=config.time_mask_length,
                n_freq_masks=config.num_freq_masks,
                freq_mask_param=config.freq_mask_length,
            )
        else:
            self.spec_augment = None

        # Audio head for S2S (flow matching)
        if getattr(config, "use_audio_head", False):
            from .audio_head import AudioHead

            # Place the head on the same device as the LLM weights.
            device = next(self.language_model.parameters()).device
            llm_dim = self.language_model.config.hidden_size

            self.audio_head = AudioHead(config, llm_dim=llm_dim).to(
                device=device, dtype=target_dtype
            )

            # Load pretrained pocket-tts flow_net if configured
            pocket_tts_weights = getattr(config, "pocket_tts_weights", None)
            freeze_flow_net = getattr(config, "freeze_flow_net", True)
            if pocket_tts_weights is not None or freeze_flow_net:
                # If freeze_flow_net is True but no weights specified, download from HF
                self.audio_head.load_pretrained_flow_net(
                    weights_path=pocket_tts_weights,
                    freeze=freeze_flow_net,
                )

            if getattr(config, "freeze_audio_head", False):
                self.audio_head.requires_grad_(False)
        else:
            self.audio_head = None

        # For model parallelism
        self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
212
+
213
+ def _create_feature_extractor(self, config: ASRConfig):
214
+ """Create the appropriate feature extractor for the audio encoder."""
215
+ from transformers import AutoFeatureExtractor
216
+
217
+ feature_extractor = AutoFeatureExtractor.from_pretrained(config.audio_model_id)
218
+ # Disable padding by default - use actual audio length
219
+ feature_extractor.padding = False
220
+ return feature_extractor
221
+
222
    @classmethod
    def _load_audio_encoder(cls, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
        """Load and freeze the audio encoder.

        Dispatches on the model id: Whisper checkpoints contribute only
        their encoder half, GLM-ASR checkpoints contribute their
        ``audio_tower``, anything else is loaded whole via AutoModel.
        The returned module has requires_grad=False and is in eval mode.
        """
        # NOTE(review): this uses the "torch_dtype" key while
        # _load_language_model uses "dtype" — confirm both spellings are
        # accepted by the installed transformers version.
        encoder_kwargs = {
            "attn_implementation": config.attn_implementation,
            "low_cpu_mem_usage": True,
            "torch_dtype": dtype,
        }

        if "whisper" in config.audio_model_id.lower():
            from transformers import WhisperModel

            # Only the encoder half is needed; drop the decoder immediately.
            full_model = WhisperModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
            encoder = full_model.encoder
            del full_model
        elif "glm" in config.audio_model_id.lower():
            # GLM-ASR models use audio_tower as the encoder
            # Requires transformers >= 5.x or installed from source
            from transformers import AutoModelForSeq2SeqLM

            full_model = AutoModelForSeq2SeqLM.from_pretrained(
                config.audio_model_id, trust_remote_code=True, **encoder_kwargs
            )
            # GLM stores encoder at audio_tower (GlmAsrEncoder)
            encoder = full_model.audio_tower
            # Clear references to free VRAM from the LLM decoder
            full_model.language_model = None
            full_model.multi_modal_projector = None
            del full_model
        else:
            encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)

        # Encoder is kept frozen throughout training.
        encoder.requires_grad_(False)
        encoder.eval()
        return encoder
257
+
258
+ @classmethod
259
+ def _load_language_model(cls, config: ASRConfig, dtype: torch.dtype) -> PreTrainedModel:
260
+ """Load and freeze the language model."""
261
+ decoder_kwargs = {
262
+ "attn_implementation": config.attn_implementation,
263
+ "trust_remote_code": True,
264
+ "low_cpu_mem_usage": True,
265
+ "dtype": dtype,
266
+ }
267
+
268
+ decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs)
269
+ decoder.config.use_cache = getattr(config, "use_cache", True)
270
+ decoder.requires_grad_(False)
271
+ decoder.eval()
272
+ return decoder
273
+
274
+ def _create_projector(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
275
+ """Create the trainable audio projector."""
276
+ # Auto-detect dimensions if not specified
277
+ if config.encoder_dim is None:
278
+ enc_cfg = self.audio_tower.config
279
+ config.encoder_dim = getattr(enc_cfg, "hidden_size", None) or getattr(
280
+ enc_cfg, "d_model", None
281
+ )
282
+ if config.encoder_dim is None:
283
+ raise ValueError("Could not auto-detect encoder_dim. Please specify in config.")
284
+
285
+ if config.llm_dim is None:
286
+ dec_cfg = self.language_model.config
287
+ config.llm_dim = getattr(dec_cfg, "hidden_size", None) or getattr(
288
+ dec_cfg, "d_model", None
289
+ )
290
+ if config.llm_dim is None:
291
+ raise ValueError("Could not auto-detect llm_dim. Please specify in config.")
292
+
293
+ # Select projector type based on config
294
+ projector_type = getattr(config, "projector_type", "mlp")
295
+ projector_class = PROJECTOR_CLASSES.get(projector_type)
296
+ if projector_class is None:
297
+ raise ValueError(
298
+ f"Unknown projector_type: {projector_type}. "
299
+ f"Valid options: {list(PROJECTOR_CLASSES.keys())}"
300
+ )
301
+ projector = projector_class(config)
302
+
303
+ # Move projector to same device as language model (important when using quantization)
304
+ device = next(self.language_model.parameters()).device
305
+ return projector.to(device=device, dtype=dtype)
306
+
307
+ def _setup_lora(self, config: ASRConfig):
308
+ """Apply LoRA adapters to the language model for Stage 2 fine-tuning."""
309
+ from peft import LoraConfig, get_peft_model
310
+
311
+ lora_config = LoraConfig(
312
+ r=config.lora_rank,
313
+ lora_alpha=config.lora_alpha,
314
+ target_modules=config.lora_target_modules,
315
+ lora_dropout=config.lora_dropout,
316
+ bias="none",
317
+ task_type="CAUSAL_LM",
318
+ )
319
+ self.language_model = get_peft_model(self.language_model, lora_config)
320
+
321
    def _init_tokenizer(self, config: ASRConfig):
        """Initialize tokenizer with audio token.

        Loads the text model's tokenizer, picks a pad token distinct from EOS
        when the vocabulary provides one, registers the ``<audio>`` placeholder
        (resizing LLM embeddings only if it was newly added), and syncs the
        special-token ids into the model/generation configs.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)

        # Set pad token: prefer a dedicated pad id over reusing EOS so that
        # padding and end-of-sequence remain distinguishable during training.
        if (
            self.tokenizer.pad_token is None
            or self.tokenizer.pad_token_id == self.tokenizer.eos_token_id
        ) and "<|finetune_right_pad_id|>" in self.tokenizer.get_vocab():
            self.tokenizer.pad_token = "<|finetune_right_pad_id|>"

        # Add audio token. Its embedding is never consumed by the LLM directly:
        # <audio> positions are overwritten with projected audio embeddings.
        existing_special = getattr(self.tokenizer, "additional_special_tokens", None) or []
        if "<audio>" not in existing_special:
            self.tokenizer.add_special_tokens(
                {"additional_special_tokens": existing_special + ["<audio>"]}
            )
            self.language_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=False)

        self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
        self.tokenizer.padding_side = "right"

        # Sync token IDs to configs so generation and saving agree with the tokenizer.
        for cfg in [self.config.text_config, self.language_model.config, self.generation_config]:
            if cfg is not None:
                cfg.pad_token_id = self.tokenizer.pad_token_id
                cfg.eos_token_id = self.tokenizer.eos_token_id
                cfg.bos_token_id = self.tokenizer.bos_token_id
349
+
350
+ def _init_weights(self, _module):
351
+ """Weight initialization (projector weights are initialized in MoEAudioProjector)."""
352
+ pass
353
+
354
+ def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None):
355
+ """Enable/disable gradient checkpointing for the language model."""
356
+ # The LLM still stores activations during forward for backprop to projector
357
+ # Gradient checkpointing trades compute for memory by recomputing activations
358
+ if hasattr(self.language_model, "_set_gradient_checkpointing"):
359
+ self.language_model._set_gradient_checkpointing(enable, gradient_checkpointing_func)
360
+ elif hasattr(self.language_model, "gradient_checkpointing_enable") and enable:
361
+ self.language_model.gradient_checkpointing_enable(
362
+ gradient_checkpointing_kwargs={"use_reentrant": False}
363
+ )
364
+ elif hasattr(self.language_model, "gradient_checkpointing_disable") and not enable:
365
+ self.language_model.gradient_checkpointing_disable()
366
+
367
+ def get_input_embeddings(self) -> nn.Module:
368
+ return self.language_model.get_input_embeddings()
369
+
370
+ def set_input_embeddings(self, value: nn.Module) -> None:
371
+ self.language_model.set_input_embeddings(value)
372
+
373
+ def get_output_embeddings(self) -> nn.Module:
374
+ return self.language_model.get_output_embeddings()
375
+
376
+ def set_output_embeddings(self, value: nn.Module) -> None:
377
+ self.language_model.set_output_embeddings(value)
378
+
379
+ def get_processor(self):
380
+ """Get the processor for this model."""
381
+ try:
382
+ from .asr_processing import ASRProcessor
383
+ except ImportError:
384
+ from asr_processing import ASRProcessor # type: ignore[no-redef]
385
+
386
+ return ASRProcessor(
387
+ feature_extractor=self.feature_extractor,
388
+ tokenizer=self.tokenizer,
389
+ projector=self.projector,
390
+ encoder_conv_layers=self.config.encoder_conv_layers,
391
+ )
392
+
393
+ def state_dict(self, *args, **kwargs) -> dict[str, torch.Tensor]:
394
+ """Save trainable weights (projector + audio_head if present)."""
395
+ state = {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
396
+ if self.audio_head is not None:
397
+ state.update({f"audio_head.{k}": v for k, v in self.audio_head.state_dict().items()})
398
+ return state
399
+
400
+ def _compute_encoder_output_lengths(
401
+ self,
402
+ audio_attention_mask: torch.Tensor,
403
+ ) -> torch.Tensor:
404
+ """Compute per-sample encoder output lengths using conv layer formulas.
405
+
406
+ Args:
407
+ audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
408
+
409
+ Returns:
410
+ Tensor of encoder output lengths per sample (batch,)
411
+ """
412
+ # Get mel frame lengths from attention mask
413
+ lengths = audio_attention_mask.sum(dim=-1)
414
+
415
+ # Apply conv layer formulas: output = (input + 2*pad - (kernel-1) - 1) // stride + 1
416
+ for padding, kernel_size, stride in self.config.encoder_conv_layers:
417
+ lengths = (lengths + 2 * padding - (kernel_size - 1) - 1) // stride + 1
418
+
419
+ return lengths
420
+
421
    def _encode_audio(
        self,
        audio_features: torch.Tensor,
        audio_attention_mask: torch.Tensor,
        expected_token_counts: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode audio and project to LLM embedding space.

        Args:
            audio_features: Mel spectrogram features (batch, n_mels, mel_len)
            audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
            expected_token_counts: Expected number of audio tokens per sample from input_ids.
                If provided, output will match these counts exactly (padding/truncating as needed).

        Returns:
            Flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
        """
        # Encoder is frozen, so skip autograd bookkeeping for its forward pass.
        with torch.no_grad():
            encoder_out = self.audio_tower(input_features=audio_features)
            hidden_states = encoder_out.last_hidden_state

        # Project to LLM space — the trainable path starts here.
        audio_embeds = self.projector(hidden_states)

        # Use expected token counts if provided (from input_ids), otherwise compute from audio
        if expected_token_counts is not None:
            token_counts = expected_token_counts
        else:
            # Compute per-sample encoder output lengths using conv formulas
            encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
            token_counts = torch.tensor(
                [
                    self.projector.get_output_length(int(length.item()))
                    for length in encoder_lengths
                ],
                device=audio_embeds.device,
            )

        # Extract embeddings matching expected token counts per sample
        batch_size = audio_embeds.shape[0]
        hidden_dim = audio_embeds.shape[2]

        result_embeds = []
        for i in range(batch_size):
            count = int(token_counts[i].item())
            # Front-aligned slice; yields fewer than `count` rows only when the
            # projector produced a shorter sequence than expected.
            sample_embeds = audio_embeds[i, :count, :]  # Take first 'count' embeddings
            # Pad with zeros if we don't have enough embeddings
            if sample_embeds.shape[0] < count:
                padding = torch.zeros(
                    count - sample_embeds.shape[0],
                    hidden_dim,
                    device=audio_embeds.device,
                    dtype=audio_embeds.dtype,
                )
                sample_embeds = torch.cat([sample_embeds, padding], dim=0)
            result_embeds.append(sample_embeds)

        # Flatten across the batch so row counts line up with the number of
        # <audio> placeholder tokens in input_ids (consumed via masked_scatter).
        return torch.cat(result_embeds, dim=0)
479
+
480
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        input_features: Optional[torch.Tensor] = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        latent_targets: Optional[torch.Tensor] = None,
        latent_lengths: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Forward pass for training and inference.

        Splices projected audio embeddings into the text embedding sequence at
        <audio> placeholder positions, then delegates to the wrapped LLM. When
        an audio head is configured and ``latent_targets`` are provided, a loss
        from the audio head (computed on the LLM's last hidden states) is added
        to — or used in place of — the LLM's own loss.
        """
        # Get text embeddings if not provided
        if inputs_embeds is None:
            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        if input_features is not None and input_ids is not None:
            # Apply SpecAugment during training if enabled
            if self.training and self.spec_augment is not None:
                input_features = self.spec_augment(input_features)

            # Count expected audio tokens from input_ids (ground truth from collator)
            audio_token_counts = (input_ids == self.audio_token_id).sum(dim=-1)

            # Encode audio -> flattened (total_audio_tokens, hidden_dim)
            audio_embeds = self._encode_audio(
                input_features, audio_attention_mask, audio_token_counts
            )

            # Replace <audio> token placeholders with audio embeddings using masked_scatter
            audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)

            inputs_embeds = inputs_embeds.masked_scatter(
                audio_token_mask.to(inputs_embeds.device),
                audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
            )

        # Request hidden states if training audio head with latent targets
        if self.audio_head is not None and latent_targets is not None:
            kwargs["output_hidden_states"] = True

        # Remove TRL-specific keys that shouldn't go to the LLM
        kwargs.pop("prompts", None)
        kwargs.pop("prompt_attention_mask", None)
        # NOTE(review): "assistant_mask" is read below but not popped here, so it
        # is still forwarded to the LLM via **kwargs — confirm the LLM tolerates
        # this extra keyword.

        # Run through language model (let it compute loss if labels provided)
        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        # Add auxiliary loss from MoE projectors if available
        if outputs.loss is not None and hasattr(self.projector, "get_aux_loss"):
            aux_loss = self.projector.get_aux_loss()
            if aux_loss is not None and aux_loss.numel() > 0:
                outputs.loss = outputs.loss + aux_loss.to(outputs.loss.device)

        # Compute audio head loss if training S2S with latent targets
        if self.audio_head is not None and latent_targets is not None:
            if outputs.hidden_states is None:
                raise ValueError(
                    "LLM did not return hidden_states for audio head. "
                    "Ensure output_hidden_states=True is passed to the LLM."
                )
            hidden_states = outputs.hidden_states[-1]  # Last layer hidden states

            # Extract only assistant-position hidden states using assistant_mask
            # This mask identifies text output positions (where LLM generates response)
            assistant_mask = kwargs.get("assistant_mask")
            if assistant_mask is not None:
                batch_size = hidden_states.shape[0]

                # Extract assistant hidden states for each sample
                assistant_hidden_list = []
                assistant_lengths = []
                for i in range(batch_size):
                    mask_i = assistant_mask[i]  # [seq_len]
                    hidden_i = hidden_states[i][mask_i]  # [num_assistant_tokens, hidden_dim]
                    assistant_hidden_list.append(hidden_i)
                    assistant_lengths.append(hidden_i.shape[0])

                # Pad sequences while preserving gradients
                # Use pad_sequence which maintains gradient flow
                hidden_states = torch.nn.utils.rnn.pad_sequence(
                    assistant_hidden_list, batch_first=True, padding_value=0.0
                )
                # Note: latent_lengths stays as original Mimi latent lengths for masking
                # audio_head._compute_loss handles interpolation between different seq lengths

            # No detach needed: LLM is frozen (requires_grad=False), so gradients
            # naturally stop there. Hidden states keep their grad_fn for proper backprop.
            audio_head_loss = self.audio_head(
                hidden_states,
                latent_targets=latent_targets,
                latent_lengths=latent_lengths,
            )

            # Combine with LLM loss if present (e.g., joint ASR+S2S training)
            if outputs.loss is not None:
                total_loss = outputs.loss + audio_head_loss
            else:
                total_loss = audio_head_loss

            # Return new output object (direct assignment doesn't work with Accelerator/DDP)
            from transformers.modeling_outputs import CausalLMOutputWithPast

            return CausalLMOutputWithPast(
                loss=total_loss,
                logits=outputs.logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        return outputs
606
+
607
+ def prepare_inputs_for_generation(self, *args, **kwargs):
608
+ """Prepare inputs for generation, handling audio features for cached decoding."""
609
+ input_features = kwargs.pop("input_features", None)
610
+ cache_position = kwargs.get("cache_position")
611
+
612
+ model_inputs = self.language_model.prepare_inputs_for_generation(*args, **kwargs)
613
+
614
+ # Only pass audio features on the first generation step (cache_position[0] == 0)
615
+ if cache_position is not None and cache_position[0] == 0 and input_features is not None:
616
+ model_inputs["input_features"] = input_features
617
+
618
+ return model_inputs
619
+
620
+ def _get_num_audio_tokens(
621
+ self,
622
+ audio_attention_mask: torch.Tensor,
623
+ ) -> int:
624
+ """Calculate number of audio tokens based on actual audio length.
625
+
626
+ Uses attention mask to get real audio length, then computes:
627
+ mel_frames -> encoder_frames (via conv formulas) -> projector output tokens
628
+ """
629
+ encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
630
+ # Use max length for batch (all samples should have same token count for generation)
631
+ encoder_output_len = int(encoder_lengths.max().item())
632
+ return int(self.projector.get_output_length(encoder_output_len))
633
+
634
    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        input_features: Optional[torch.Tensor] = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        system_prompt: Optional[str] = None,
        **generate_kwargs,
    ) -> torch.Tensor:
        """Generate transcription from audio input.

        Can be called in two ways:
        1. With input_ids containing <audio> tokens (from processor)
        2. With just audio, and we build the prompt internally

        Returns:
            Only the newly generated token ids (prompt tokens stripped),
            shape (batch, new_tokens).
        """
        if input_features is None:
            raise ValueError("input_features required for generation")
        if audio_attention_mask is None:
            raise ValueError("audio_attention_mask required for generation")

        device = input_features.device
        batch_size = input_features.shape[0]

        # Encode audio -> flattened embeddings
        audio_embeds = self._encode_audio(input_features, audio_attention_mask)

        # If input_ids not provided, build prompt with correct number of audio tokens
        if input_ids is None:
            num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
            audio_placeholder = "<audio>" * num_audio_tokens

            system_prompt = system_prompt or self.system_prompt

            messages: list[dict[str, str]] = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            # Audio tokens only (instruction-free)
            user_content = audio_placeholder
            if self.TRANSCRIBE_PROMPT:
                user_content += " " + self.TRANSCRIBE_PROMPT
            messages.append({"role": "user", "content": user_content})

            chat_result = self.tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
                enable_thinking=getattr(self.config, "enable_thinking", False),
            )
            input_ids = chat_result.input_ids.to(device)

            if input_ids.dim() == 1:
                input_ids = input_ids.unsqueeze(0)
            # Broadcast the single built prompt across the batch.
            if input_ids.shape[0] == 1 and batch_size > 1:
                input_ids = input_ids.expand(batch_size, -1)

            # A freshly built prompt contains no padding, so all-ones is correct.
            attention_mask = torch.ones_like(input_ids)

        # Get text embeddings and replace audio tokens with audio embeddings
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
        inputs_embeds = inputs_embeds.masked_scatter(
            audio_token_mask.to(inputs_embeds.device),
            audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
        )

        # Generate using language model
        # Pass both input_ids and inputs_embeds so repetition_penalty works correctly
        # (it needs input_ids to track which tokens have been used)
        output = self.language_model.generate(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            generation_config=self.generation_config,
            **generate_kwargs,
        )

        # When using inputs_embeds with input_ids, generate returns full sequence
        # Strip the input tokens to return only generated tokens
        sequences = output if isinstance(output, torch.Tensor) else output.sequences
        input_len = input_ids.shape[1]
        return sequences[:, input_len:]
717
+
718
    def generate_streaming(
        self,
        input_features: torch.Tensor,
        audio_attention_mask: torch.Tensor,
        system_prompt: Optional[str] = None,
        **generate_kwargs,
    ) -> Iterator[str]:
        """Generate transcription with streaming token output.

        Yields partial transcript strings as tokens are generated.
        Reduces time-to-first-word by streaming tokens as they're decoded.

        Args:
            input_features: Mel spectrogram features (batch, n_mels, mel_len)
            audio_attention_mask: Mask for real vs padded mel frames (batch, mel_len)
            system_prompt: Optional system prompt override
            **generate_kwargs: Additional generation arguments

        Yields:
            Partial transcript text as each token is generated
        """
        device = input_features.device
        batch_size = input_features.shape[0]

        # Encode audio -> flattened embeddings
        audio_embeds = self._encode_audio(input_features, audio_attention_mask)

        # Build prompt with correct number of audio tokens
        num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
        audio_placeholder = "<audio>" * num_audio_tokens

        system_prompt = system_prompt or self.system_prompt

        messages: list[dict[str, str]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        # Audio tokens only (instruction-free)
        user_content = audio_placeholder
        if self.TRANSCRIBE_PROMPT:
            user_content += " " + self.TRANSCRIBE_PROMPT
        messages.append({"role": "user", "content": user_content})

        chat_result = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            enable_thinking=getattr(self.config, "enable_thinking", False),
        )
        input_ids = chat_result.input_ids.to(device)

        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        if input_ids.shape[0] == 1 and batch_size > 1:
            input_ids = input_ids.expand(batch_size, -1)

        attention_mask = torch.ones_like(input_ids)

        # Get text embeddings and replace audio tokens with audio embeddings
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
        inputs_embeds = inputs_embeds.masked_scatter(
            audio_token_mask.to(inputs_embeds.device),
            audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
        )

        # Setup streamer for token-by-token output
        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )

        # Prepare generation kwargs
        gen_kwargs = {
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
            "generation_config": self.generation_config,
            "streamer": streamer,
            **generate_kwargs,
        }

        # Run generation in background thread
        thread = Thread(target=self.language_model.generate, kwargs=gen_kwargs)
        thread.start()

        # Yield tokens as they're generated, filtering out <think>...</think> blocks
        # Start assuming no think block - only filter when we see <think>
        # NOTE(review): a tag split across streamed chunks (e.g. "<thi" + "nk>")
        # will not match these substring checks and could leak through — confirm
        # the streamer emits whole tags, or buffer until a tag completes.
        in_think_block = False
        buffer = ""

        for text in streamer:
            buffer += text

            # Check for think block start (in case model outputs think blocks)
            while "<think>" in buffer:
                in_think_block = True
                # Yield any text before <think>
                before_think = buffer.split("<think>")[0]
                if before_think:
                    yield before_think
                buffer = buffer.split("<think>", 1)[-1]

            # Check for think block end
            while in_think_block and "</think>" in buffer:
                in_think_block = False
                buffer = buffer.split("</think>", 1)[-1]

            # Yield text if not in think block
            if not in_think_block and buffer:
                yield buffer
                buffer = ""

        # Yield any remaining buffer
        if buffer and not in_think_block:
            yield buffer

        # Ensure the generation thread has fully finished before returning.
        thread.join()
836
+
837
+ @torch.no_grad()
838
+ def generate_text_only(
839
+ self,
840
+ messages: list[dict[str, str]],
841
+ max_new_tokens: int = 256,
842
+ **generate_kwargs,
843
+ ) -> str:
844
+ """Generate text using only the LLM (no audio encoding).
845
+
846
+ Used for SIFT-style response generation from metadata prompts.
847
+
848
+ Args:
849
+ messages: List of chat messages [{"role": "user", "content": "..."}]
850
+ max_new_tokens: Maximum tokens to generate
851
+ **generate_kwargs: Additional generation arguments
852
+
853
+ Returns:
854
+ Generated text response
855
+ """
856
+ device = next(self.language_model.parameters()).device
857
+
858
+ # Apply chat template
859
+ input_ids = self.tokenizer.apply_chat_template(
860
+ messages,
861
+ tokenize=True,
862
+ add_generation_prompt=True,
863
+ return_tensors="pt",
864
+ enable_thinking=getattr(self.config, "enable_thinking", False),
865
+ ).to(device)
866
+
867
+ if input_ids.dim() == 1:
868
+ input_ids = input_ids.unsqueeze(0)
869
+
870
+ attention_mask = torch.ones_like(input_ids)
871
+
872
+ # Generate using language model directly
873
+ output = self.language_model.generate(
874
+ input_ids=input_ids,
875
+ attention_mask=attention_mask,
876
+ max_new_tokens=max_new_tokens,
877
+ do_sample=False,
878
+ pad_token_id=self.tokenizer.pad_token_id,
879
+ eos_token_id=self.tokenizer.eos_token_id,
880
+ **generate_kwargs,
881
+ )
882
+
883
+ # Decode only the new tokens
884
+ new_tokens = output[0, input_ids.shape[1] :]
885
+ response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
886
+ return response.strip()
887
+
888
    @torch.no_grad()
    def generate_with_audio(
        self,
        input_features: torch.Tensor,
        audio_attention_mask: torch.Tensor,
        **generate_kwargs,
    ) -> dict[str, torch.Tensor | list[str]]:
        """Generate text and audio for Speech-to-Speech.

        Args:
            input_features: Mel spectrogram features (batch, n_mels, mel_len)
            audio_attention_mask: Mask for real vs padded mel frames (batch, mel_len)
            **generate_kwargs: Additional generation arguments

        Returns:
            Dict with:
                - text: Decoded text strings (list of str)
                - audio: Audio waveform at 24kHz (batch, samples)

        Raises:
            ValueError: If no audio head was configured on this model.
        """
        if self.audio_head is None:
            raise ValueError("Audio head not configured. Set use_audio_head=True in config.")

        device = input_features.device
        batch_size = input_features.shape[0]

        # Encode audio -> flattened embeddings
        audio_embeds = self._encode_audio(input_features, audio_attention_mask)

        # Build prompt with correct number of audio tokens
        num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
        audio_placeholder = "<audio>" * num_audio_tokens

        messages: list[dict[str, str]] = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        user_content = audio_placeholder
        if self.TRANSCRIBE_PROMPT:
            user_content += " " + self.TRANSCRIBE_PROMPT
        messages.append({"role": "user", "content": user_content})

        chat_result = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            enable_thinking=getattr(self.config, "enable_thinking", False),
        )
        input_ids = chat_result.input_ids.to(device)

        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        if input_ids.shape[0] == 1 and batch_size > 1:
            input_ids = input_ids.expand(batch_size, -1)

        attention_mask = torch.ones_like(input_ids)

        # Get text embeddings and replace audio tokens with audio embeddings
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
        inputs_embeds = inputs_embeds.masked_scatter(
            audio_token_mask.to(inputs_embeds.device),
            audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
        )

        # Generate with hidden states
        output = self.language_model.generate(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            generation_config=self.generation_config,
            output_hidden_states=True,
            return_dict_in_generate=True,
            **generate_kwargs,
        )

        # Extract generated text
        text_ids = output.sequences[:, input_ids.shape[1] :]
        text = self.tokenizer.batch_decode(text_ids, skip_special_tokens=True)

        # Extract hidden states from generation steps and concatenate
        # output.hidden_states is tuple of (step,) where each step is tuple of (layer,)
        # Each layer tensor is (batch, 1, hidden_dim) for generated tokens
        # NOTE(review): the first step's last-layer tensor may cover the whole
        # prompt rather than a single token — confirm the audio head expects
        # that prefix included in the concatenation below.
        last_layer_states = []
        for step_hidden in output.hidden_states:
            # step_hidden is tuple of (num_layers,) tensors
            # Get last layer: shape (batch, 1, hidden_dim)
            last_layer_states.append(step_hidden[-1])

        # Concatenate across generation steps: (batch, gen_seq_len, hidden_dim)
        hidden_states = torch.cat(last_layer_states, dim=1)

        # Generate Mimi latents from LLM hidden states via flow matching
        latents = self.audio_head(hidden_states)

        # Load Mimi decoder if not already loaded (lazy to save VRAM when unused)
        if self.audio_head.mimi is None:
            self.audio_head.load_mimi_decoder(device=device)

        # Decode latents to audio waveform
        audio = self.audio_head.decode_to_audio(latents)

        return {
            "text": text,
            "audio": audio,
        }
993
+
994
    def save_pretrained(self, save_directory: Union[str, Path], **kwargs) -> None:
        """Save model, tokenizer, and processor.

        Also rewrites adapter/processor configs so the saved directory (or Hub
        repo) is self-contained, and copies the source modules required for
        trust_remote_code auto-loading.
        """
        import shutil
        from pathlib import Path as PathlibPath

        save_dir = PathlibPath(save_directory)
        save_dir.mkdir(parents=True, exist_ok=True)

        # Update config with actual vocab size (may have grown via <audio> token)
        self.config.vocab_size = self.language_model.config.vocab_size
        self.config.text_config.vocab_size = self.language_model.config.vocab_size

        if hasattr(self.audio_tower.config, "num_mel_bins"):
            self.config.audio_config.num_mel_bins = self.audio_tower.config.num_mel_bins

        # Save model (temporarily remove non-serializable attributes)
        tokenizer = self.tokenizer
        del self.tokenizer

        try:
            super().save_pretrained(save_dir, **kwargs)
        finally:
            # Restore the attribute even if the parent save raised.
            self.tokenizer = tokenizer

        # Save tokenizer and feature extractor
        self.tokenizer.save_pretrained(save_dir)
        self.feature_extractor.save_pretrained(save_dir)

        # Save LoRA adapters if present (creates adapter_model.safetensors and adapter_config.json)
        # Don't save embedding layers - the <audio> token embedding is never used
        # (it's replaced with projected audio embeddings before the LLM sees it)
        if hasattr(self.language_model, "peft_config"):
            self.language_model.save_pretrained(save_dir, save_embedding_layers=False)

        # Clear base_model_name_or_path in adapter_config.json to prevent HF pipeline
        # from redirecting to the base LLM repo (like Qwen) which breaks feature
        # extractor loading for multimodal models. If a repo_id is provided, use that
        # so the model can be loaded directly from the Hub.
        adapter_config_path = save_dir / "adapter_config.json"
        if adapter_config_path.exists():
            with adapter_config_path.open() as f:
                adapter_config = json.load(f)

            # Use repo_id if available, otherwise clear to prevent redirect.
            # Use empty string instead of None to avoid str(None) -> "None" bug
            # in some transformers/PEFT versions.
            repo_id = (
                kwargs.get("repo_id")
                or kwargs.get("push_to_hub_model_id")
                or getattr(self.config, "pretrained_model_path", None)
                or ""  # Use empty string instead of None
            )
            adapter_config["base_model_name_or_path"] = repo_id

            with adapter_config_path.open("w") as f:
                json.dump(adapter_config, f, indent=2)

        # Add processor auto_map to preprocessor_config.json so AutoProcessor
        # resolves to the bundled ASRProcessor implementation.
        config_path = save_dir / "preprocessor_config.json"
        if config_path.exists():
            with config_path.open() as f:
                processor_config = json.load(f)
        else:
            processor_config = {}

        processor_config.update(
            {
                "processor_class": "ASRProcessor",
                "auto_map": {"AutoProcessor": "asr_processing.ASRProcessor"},
            }
        )

        with config_path.open("w") as f:
            json.dump(processor_config, f, indent=2)

        # Copy source files for auto-loading
        src_dir = PathlibPath(__file__).parent
        for asr_file in src_dir.glob("asr_*.py"):
            shutil.copy(asr_file, save_dir / asr_file.name)
        # Copy projectors module
        shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
        # Copy alignment module
        shutil.copy(src_dir / "alignment.py", save_dir / "alignment.py")
        # Copy diarization module
        shutil.copy(src_dir / "diarization.py", save_dir / "diarization.py")
        # Copy audio head for S2S (optional file)
        audio_head_path = src_dir / "audio_head.py"
        if audio_head_path.exists():
            shutil.copy(audio_head_path, save_dir / "audio_head.py")
        # Copy modules directory (for audio head dependencies)
        modules_dir = src_dir / "modules"
        if modules_dir.exists():
            save_modules_dir = save_dir / "modules"
            save_modules_dir.mkdir(exist_ok=True)
            for module_file in modules_dir.glob("*.py"):
                shutil.copy(module_file, save_modules_dir / module_file.name)
1090
+
1091
    def push_to_hub(self, repo_id: str, **kwargs) -> str:
        """Push model to HuggingFace Hub, ensuring adapter_config points to repo.

        IMPORTANT: Sets base_model_name_or_path in adapter_config.json to repo_id
        so that transformers pipeline() can load the model correctly. Without this,
        the pipeline tries to load from "None" which fails.
        """
        # Store repo_id in config so save_pretrained can access it
        # (save_pretrained reads config.pretrained_model_path when rewriting
        # adapter_config.json's base_model_name_or_path).
        self.config.pretrained_model_path = repo_id
        # Call parent's push_to_hub, which will invoke our save_pretrained.
        return super().push_to_hub(repo_id, **kwargs)
1102
+
1103
+ def create_or_update_model_card(self, output_dir: Union[str, Path]) -> None:
1104
+ """No-op for model card creation - we use MODEL_CARD.md in repo instead."""
1105
+ pass
1106
+
1107
+
1108
# Register with transformers Auto classes so AutoConfig/AutoModel can resolve
# the custom "asr_model" model type (e.g. when loading with trust_remote_code).
AutoConfig.register("asr_model", ASRConfig)
AutoModel.register(ASRConfig, ASRModel)
asr_pipeline.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ASR pipeline for audio-to-text transcription with optional timestamps and diarization."""
2
+
3
+ import re
4
+ from pathlib import Path
5
+ from typing import Any, Iterator, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ import transformers
10
+
11
+ try:
12
+ from .alignment import ForcedAligner
13
+ from .asr_modeling import ASRModel
14
+ from .diarization import SpeakerDiarizer
15
+ except ImportError:
16
+ from alignment import ForcedAligner # type: ignore[no-redef]
17
+ from asr_modeling import ASRModel # type: ignore[no-redef]
18
+ from diarization import SpeakerDiarizer # type: ignore[no-redef]
19
+
20
+ # Re-export for backwards compatibility
21
+ __all__ = ["ForcedAligner", "SpeakerDiarizer", "ASRPipeline", "strip_thinking"]
22
+
23
+ # Default TTS voice for Kokoro
24
+ DEFAULT_TTS_VOICE = "af_heart"
25
+ TTS_SAMPLE_RATE = 24000
26
+
27
+
28
+ def strip_thinking(text: str) -> str:
29
+ """Remove <think>...</think> tags from model output.
30
+
31
+ Args:
32
+ text: Model output text that may contain thinking tags
33
+
34
+ Returns:
35
+ Text with thinking content removed
36
+ """
37
+ if not text:
38
+ return text
39
+ text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL)
40
+ return text.strip()
41
+
42
+
43
+ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
44
+ """ASR Pipeline for audio-to-text transcription."""
45
+
46
+ model: ASRModel
47
+
48
+ def __init__(self, model: ASRModel, **kwargs):
49
+ """Initialize ASR pipeline.
50
+
51
+ Args:
52
+ model: ASRModel instance for transcription
53
+ **kwargs: Additional arguments (feature_extractor, tokenizer, device)
54
+ """
55
+ feature_extractor = kwargs.pop("feature_extractor", None)
56
+ tokenizer = kwargs.pop("tokenizer", model.tokenizer)
57
+
58
+ if feature_extractor is None:
59
+ feature_extractor = model.get_processor().feature_extractor
60
+
61
+ super().__init__(
62
+ model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
63
+ )
64
+ self._current_audio = None
65
+ self._tts_pipeline = None
66
+
67
+ @property
68
+ def tts_pipeline(self):
69
+ """Lazy-load Kokoro TTS pipeline on first use."""
70
+ if self._tts_pipeline is None:
71
+ try:
72
+ from kokoro import KPipeline
73
+
74
+ self._tts_pipeline = KPipeline(lang_code="a", repo_id="hexgrad/Kokoro-82M")
75
+ except ImportError as e:
76
+ raise ImportError(
77
+ "Kokoro TTS is required for audio output. "
78
+ "Install with: pip install kokoro>=0.9.2\n"
79
+ "Also requires espeak-ng: apt-get install espeak-ng"
80
+ ) from e
81
+ return self._tts_pipeline
82
+
83
+ def text_to_speech(self, text: str, voice: str = DEFAULT_TTS_VOICE) -> dict[str, Any]:
84
+ """Convert text to speech using Kokoro TTS.
85
+
86
+ Args:
87
+ text: Text to synthesize
88
+ voice: Kokoro voice ID (default: "af_heart")
89
+
90
+ Returns:
91
+ Dict with 'audio' (numpy array) and 'sample_rate' keys
92
+ """
93
+ if not text or not text.strip():
94
+ return {"audio": np.array([], dtype=np.float32), "sample_rate": TTS_SAMPLE_RATE}
95
+
96
+ # Generate audio chunks and concatenate
97
+ audio_chunks = []
98
+ for _, _, audio in self.tts_pipeline(text, voice=voice):
99
+ audio_chunks.append(audio)
100
+
101
+ audio = np.concatenate(audio_chunks) if audio_chunks else np.array([], dtype=np.float32)
102
+ return {"audio": audio, "sample_rate": TTS_SAMPLE_RATE}
103
+
104
+ def transcribe_streaming(
105
+ self,
106
+ inputs: Union[str, bytes, np.ndarray, dict],
107
+ system_prompt: str | None = None,
108
+ ) -> Iterator[str]:
109
+ """Transcribe audio with streaming token output for low-latency applications.
110
+
111
+ Yields partial transcript strings as tokens are generated, reducing
112
+ time-to-first-word compared to waiting for full transcription.
113
+
114
+ Args:
115
+ inputs: Audio input in any supported format:
116
+ - str: File path to audio file
117
+ - bytes: Raw audio bytes
118
+ - np.ndarray: Audio samples as numpy array
119
+ - dict: {"array": np.ndarray, "sampling_rate": int}
120
+ system_prompt: Optional system prompt override (uses model's default if not provided)
121
+
122
+ Yields:
123
+ Partial transcript text as each token is generated
124
+
125
+ Example:
126
+ >>> for partial in pipeline.transcribe_streaming("audio.wav"):
127
+ ... print(partial, end="", flush=True)
128
+ """
129
+ # Extract audio array from various input formats
130
+ audio_data = self._extract_audio(inputs)
131
+ if audio_data is None:
132
+ return
133
+
134
+ audio_array = audio_data["array"]
135
+ sample_rate = audio_data.get("sampling_rate", 16000)
136
+
137
+ # Preprocess audio through feature extractor
138
+ model_inputs = self.feature_extractor(
139
+ audio_array,
140
+ sampling_rate=sample_rate,
141
+ return_tensors="pt",
142
+ return_attention_mask=True,
143
+ )
144
+
145
+ # Get model dtype and device, cast inputs to match
146
+ device = self.model.device
147
+ model_dtype = next(self.model.parameters()).dtype
148
+ input_features = model_inputs.input_features.to(device, dtype=model_dtype)
149
+ attention_mask = model_inputs.attention_mask.to(device)
150
+
151
+ # Stream tokens from model
152
+ yield from self.model.generate_streaming(
153
+ input_features=input_features,
154
+ audio_attention_mask=attention_mask,
155
+ system_prompt=system_prompt,
156
+ )
157
+
158
+ def transcribe_streaming_with_audio(
159
+ self,
160
+ inputs: Union[str, bytes, np.ndarray, dict],
161
+ voice: str = DEFAULT_TTS_VOICE,
162
+ system_prompt: str | None = None,
163
+ ) -> Iterator[dict[str, Any]]:
164
+ """Transcribe audio with streaming text AND audio output.
165
+
166
+ Yields partial text as tokens are generated, and audio chunks
167
+ as complete sentences are detected. This enables low-latency
168
+ voice agents that can start speaking before transcription completes.
169
+
170
+ Args:
171
+ inputs: Audio input (same formats as transcribe_streaming)
172
+ voice: Kokoro TTS voice ID
173
+ system_prompt: Optional system prompt override (uses model's default if not provided)
174
+
175
+ Yields:
176
+ Dicts with either:
177
+ - {"type": "text", "text": str, "interim": bool} for text updates
178
+ - {"type": "audio", "audio": np.ndarray, "sample_rate": int} for audio chunks
179
+
180
+ Example:
181
+ >>> for chunk in pipeline.transcribe_streaming_with_audio(audio):
182
+ ... if chunk["type"] == "text":
183
+ ... print(chunk["text"], end="", flush=True)
184
+ ... elif chunk["type"] == "audio":
185
+ ... play_audio(chunk["audio"], chunk["sample_rate"])
186
+ """
187
+ import re
188
+
189
+ sentence_buffer = ""
190
+ full_text = ""
191
+
192
+ # Sentence-ending patterns (handles ., !, ?, and common abbreviations)
193
+ sentence_end_pattern = re.compile(r"[.!?](?:\s|$)")
194
+
195
+ for token_text in self.transcribe_streaming(inputs, system_prompt=system_prompt):
196
+ full_text += token_text
197
+ sentence_buffer += token_text
198
+
199
+ # Yield text update
200
+ yield {"type": "text", "text": full_text, "interim": True}
201
+
202
+ # Check for complete sentence
203
+ match = sentence_end_pattern.search(sentence_buffer)
204
+ if match:
205
+ # Extract complete sentence(s)
206
+ end_pos = match.end()
207
+ complete_text = sentence_buffer[:end_pos].strip()
208
+ sentence_buffer = sentence_buffer[end_pos:]
209
+
210
+ # Generate audio for the complete sentence
211
+ if complete_text:
212
+ try:
213
+ tts_result = self.text_to_speech(complete_text, voice=voice)
214
+ if tts_result["audio"] is not None and len(tts_result["audio"]) > 0:
215
+ yield {
216
+ "type": "audio",
217
+ "audio": tts_result["audio"],
218
+ "sample_rate": tts_result["sample_rate"],
219
+ }
220
+ except Exception:
221
+ pass # Skip audio on TTS errors
222
+
223
+ # Final text update (not interim)
224
+ yield {"type": "text", "text": full_text, "interim": False}
225
+
226
+ # Generate audio for any remaining text
227
+ remaining = sentence_buffer.strip()
228
+ if remaining:
229
+ try:
230
+ tts_result = self.text_to_speech(remaining, voice=voice)
231
+ if tts_result["audio"] is not None and len(tts_result["audio"]) > 0:
232
+ yield {
233
+ "type": "audio",
234
+ "audio": tts_result["audio"],
235
+ "sample_rate": tts_result["sample_rate"],
236
+ }
237
+ except Exception:
238
+ pass
239
+
240
+ def _sanitize_parameters(self, **kwargs):
241
+ """Intercept our custom parameters before parent class validates them."""
242
+ # Remove our custom parameters so parent doesn't see them
243
+ kwargs.pop("return_timestamps", None)
244
+ kwargs.pop("return_speakers", None)
245
+ kwargs.pop("num_speakers", None)
246
+ kwargs.pop("min_speakers", None)
247
+ kwargs.pop("max_speakers", None)
248
+ kwargs.pop("hf_token", None)
249
+ kwargs.pop("user_prompt", None)
250
+ kwargs.pop("system_prompt", None)
251
+ kwargs.pop("diarization_backend", None)
252
+ # TTS parameters
253
+ kwargs.pop("return_audio", None)
254
+ kwargs.pop("tts_voice", None)
255
+
256
+ return super()._sanitize_parameters(**kwargs)
257
+
258
+ def __call__(
259
+ self,
260
+ inputs,
261
+ **kwargs,
262
+ ):
263
+ """Transcribe audio with optional word-level timestamps and speaker diarization.
264
+
265
+ Args:
266
+ inputs: Audio input (file path, dict with array/sampling_rate, etc.)
267
+ return_timestamps: If True, return word-level timestamps using forced alignment
268
+ return_speakers: If True, return speaker labels for each word
269
+ return_audio: If True, synthesize transcription as speech using Kokoro TTS
270
+ tts_voice: Kokoro voice ID for TTS output (default: "af_heart")
271
+ user_prompt: Custom transcription prompt (default: "Transcribe: ")
272
+ system_prompt: Custom system prompt override (uses model's default if not provided)
273
+ num_speakers: Exact number of speakers (if known, for diarization)
274
+ min_speakers: Minimum number of speakers (for diarization)
275
+ max_speakers: Maximum number of speakers (for diarization)
276
+ **kwargs: Additional arguments passed to the pipeline
277
+
278
+ Returns:
279
+ Dict with 'text' key, 'words' key if return_timestamps=True,
280
+ speaker labels on words if return_speakers=True,
281
+ and 'audio'/'sample_rate' keys if return_audio=True
282
+ """
283
+ # Extract our params before super().__call__ (which will also call _sanitize_parameters)
284
+ return_timestamps = kwargs.pop("return_timestamps", False)
285
+ return_speakers = kwargs.pop("return_speakers", False)
286
+ return_audio = kwargs.pop("return_audio", False)
287
+ tts_voice = kwargs.pop("tts_voice", DEFAULT_TTS_VOICE)
288
+ user_prompt = kwargs.pop("user_prompt", None)
289
+ system_prompt = kwargs.pop("system_prompt", None)
290
+ diarization_params = {
291
+ "num_speakers": kwargs.pop("num_speakers", None),
292
+ "min_speakers": kwargs.pop("min_speakers", None),
293
+ "max_speakers": kwargs.pop("max_speakers", None),
294
+ }
295
+
296
+ if return_speakers:
297
+ return_timestamps = True
298
+
299
+ # Set custom user prompt if provided
300
+ original_prompt = None
301
+ if user_prompt:
302
+ original_prompt = self.model.TRANSCRIBE_PROMPT
303
+ self.model.TRANSCRIBE_PROMPT = user_prompt
304
+
305
+ # Set custom system prompt if provided
306
+ original_system_prompt = None
307
+ if system_prompt:
308
+ original_system_prompt = self.model.system_prompt
309
+ self.model.system_prompt = system_prompt
310
+
311
+ # Store audio for timestamp alignment and diarization
312
+ if return_timestamps or return_speakers:
313
+ self._current_audio = self._extract_audio(inputs)
314
+
315
+ # Run standard transcription
316
+ result = super().__call__(inputs, **kwargs)
317
+
318
+ # Add timestamps if requested
319
+ if return_timestamps and self._current_audio is not None:
320
+ text = result.get("text", "")
321
+ if text:
322
+ try:
323
+ words = ForcedAligner.align(
324
+ self._current_audio["array"],
325
+ text,
326
+ sample_rate=self._current_audio.get("sampling_rate", 16000),
327
+ )
328
+ result["words"] = words
329
+ except Exception as e:
330
+ result["words"] = []
331
+ result["timestamp_error"] = str(e)
332
+ else:
333
+ result["words"] = []
334
+
335
+ # Add speaker diarization if requested
336
+ if return_speakers and self._current_audio is not None:
337
+ try:
338
+ # Run diarization
339
+ speaker_segments = SpeakerDiarizer.diarize(
340
+ self._current_audio["array"],
341
+ sample_rate=self._current_audio.get("sampling_rate", 16000),
342
+ **{k: v for k, v in diarization_params.items() if v is not None},
343
+ )
344
+ result["speaker_segments"] = speaker_segments
345
+
346
+ # Assign speakers to words
347
+ if result.get("words"):
348
+ result["words"] = SpeakerDiarizer.assign_speakers_to_words(
349
+ result["words"],
350
+ speaker_segments,
351
+ )
352
+ except Exception as e:
353
+ result["speaker_segments"] = []
354
+ result["diarization_error"] = str(e)
355
+
356
+ # Synthesize transcription as speech if requested
357
+ if return_audio:
358
+ text = result.get("text", "")
359
+ try:
360
+ tts_result = self.text_to_speech(text, voice=tts_voice)
361
+ result["audio"] = tts_result["audio"]
362
+ result["sample_rate"] = tts_result["sample_rate"]
363
+ except Exception as e:
364
+ result["audio"] = np.array([], dtype=np.float32)
365
+ result["sample_rate"] = TTS_SAMPLE_RATE
366
+ result["tts_error"] = str(e)
367
+
368
+ # Clean up
369
+ self._current_audio = None
370
+ if original_prompt is not None:
371
+ self.model.TRANSCRIBE_PROMPT = original_prompt
372
+ if original_system_prompt is not None:
373
+ self.model.system_prompt = original_system_prompt
374
+
375
+ return result
376
+
377
+ def _extract_audio(self, inputs) -> dict | None:
378
+ """Extract audio array from various input formats.
379
+
380
+ Supported input formats:
381
+ - str: File path to audio file
382
+ - bytes: Encoded audio (mp3, wav, etc.) - decoded via ffmpeg
383
+ - np.ndarray: Audio samples as float32 array
384
+ - dict with "array": Audio samples as numpy array
385
+ - dict with "raw": Alias for "array" (HF pipeline compat)
386
+ - dict with "raw_bytes": Raw PCM bytes (requires "dtype", optional "sampling_rate")
387
+
388
+ For raw PCM bytes (e.g., from pipecat), use:
389
+ {"raw_bytes": pcm_bytes, "dtype": "int16", "sampling_rate": 16000}
390
+ """
391
+ from transformers.pipelines.audio_utils import ffmpeg_read
392
+
393
+ if isinstance(inputs, dict):
394
+ if "array" in inputs:
395
+ return {
396
+ "array": inputs["array"],
397
+ "sampling_rate": inputs.get("sampling_rate", 16000),
398
+ }
399
+ if "raw" in inputs:
400
+ return {
401
+ "array": inputs["raw"],
402
+ "sampling_rate": inputs.get("sampling_rate", 16000),
403
+ }
404
+ if "raw_bytes" in inputs:
405
+ # Raw PCM bytes - convert to float32 array
406
+ dtype = inputs.get("dtype", "int16")
407
+ sample_rate = inputs.get("sampling_rate", 16000)
408
+ audio = np.frombuffer(inputs["raw_bytes"], dtype=dtype).astype(np.float32)
409
+ # Normalize based on dtype
410
+ if dtype == "int16":
411
+ audio = audio / 32768.0
412
+ elif dtype == "int32":
413
+ audio = audio / 2147483648.0
414
+ return {"array": audio, "sampling_rate": sample_rate}
415
+ elif isinstance(inputs, str):
416
+ # File path - load audio using ffmpeg (same as HF pipeline)
417
+ with Path(inputs).open("rb") as f:
418
+ audio = ffmpeg_read(f.read(), sampling_rate=16000)
419
+ return {"array": audio, "sampling_rate": 16000}
420
+ elif isinstance(inputs, bytes):
421
+ audio = ffmpeg_read(inputs, sampling_rate=16000)
422
+ return {"array": audio, "sampling_rate": 16000}
423
+ elif isinstance(inputs, np.ndarray):
424
+ return {"array": inputs, "sampling_rate": 16000}
425
+
426
+ return None
427
+
428
+ def preprocess(self, inputs, **preprocess_params):
429
+ """Preprocess audio inputs for the model.
430
+
431
+ Args:
432
+ inputs: Audio input (dict with array, file path, etc.)
433
+ **preprocess_params: Additional preprocessing parameters
434
+
435
+ Yields:
436
+ Model input dicts with input_features and attention_mask
437
+ """
438
+ # Handle dict with "array" key (from datasets)
439
+ if isinstance(inputs, dict) and "array" in inputs:
440
+ inputs = {
441
+ "raw": inputs["array"],
442
+ "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
443
+ }
444
+
445
+ for item in super().preprocess(inputs, **preprocess_params):
446
+ if "is_last" not in item:
447
+ item["is_last"] = True
448
+ yield item
449
+
450
+ def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
451
+ """Run model forward pass to generate transcription.
452
+
453
+ Args:
454
+ model_inputs: Dict with input_features and attention_mask
455
+ **generate_kwargs: Generation parameters
456
+
457
+ Returns:
458
+ Dict with generated token IDs
459
+ """
460
+ # Extract audio features and is_last flag
461
+ is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True
462
+
463
+ input_features = model_inputs["input_features"].to(self.model.device)
464
+ audio_attention_mask = model_inputs["attention_mask"].to(self.model.device)
465
+
466
+ generated_ids = self.model.generate(
467
+ input_features=input_features,
468
+ audio_attention_mask=audio_attention_mask,
469
+ **generate_kwargs,
470
+ )
471
+
472
+ return {"tokens": generated_ids, "is_last": is_last}
473
+
474
+ def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
475
+ """Convert model output tokens to text.
476
+
477
+ Args:
478
+ model_outputs: Dict with 'tokens' key containing generated IDs
479
+ **kwargs: Additional postprocessing parameters
480
+
481
+ Returns:
482
+ Dict with 'text' key containing transcription
483
+ """
484
+ # Handle list of outputs (from chunking)
485
+ if isinstance(model_outputs, list):
486
+ model_outputs = model_outputs[0] if model_outputs else {}
487
+
488
+ tokens = model_outputs.get("tokens")
489
+ if tokens is None:
490
+ return super().postprocess(model_outputs, **kwargs)
491
+
492
+ if torch.is_tensor(tokens):
493
+ tokens = tokens.cpu()
494
+ if tokens.dim() > 1:
495
+ tokens = tokens[0]
496
+
497
+ # Filter out eos tokens that the tokenizer doesn't recognize as special
498
+ # (generation_config.eos_token_id may differ from tokenizer.eos_token_id)
499
+ if hasattr(self, "model") and hasattr(self.model, "generation_config"):
500
+ eos_ids = self.model.generation_config.eos_token_id
501
+ if eos_ids is not None:
502
+ eos_set = set(eos_ids) if isinstance(eos_ids, list) else {eos_ids}
503
+ tokens = [t for t in tokens.tolist() if t not in eos_set]
504
+
505
+ text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
506
+ # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
507
+ text = strip_thinking(text)
508
+ # Truncate repetitions at end of text
509
+ text = _truncate_repetitions(text)
510
+ return {"text": text}
511
+
512
+
513
+ def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
514
+ """Truncate repeated words/phrases/characters at end of text.
515
+
516
+ Detects patterns like:
517
+ - Repeated words: "the the the the" -> "the"
518
+ - Repeated phrases: "i am sorry i am sorry i am sorry" -> "i am sorry"
519
+ - Repeated characters: "444444" -> "4"
520
+
521
+ Args:
522
+ text: Input text to process
523
+ min_repeats: Minimum repetitions to trigger truncation (default 3)
524
+
525
+ Returns:
526
+ Text with trailing repetitions removed
527
+ """
528
+ if not text:
529
+ return text
530
+
531
+ # 1. Truncate repeated characters at end (e.g., "444444" -> "4")
532
+ char_pattern = re.compile(r"(.)\1{" + str(min_repeats - 1) + r",}$")
533
+ text = char_pattern.sub(r"\1", text)
534
+
535
+ # 2. Truncate repeated words at end (e.g., "the the the" -> "the")
536
+ word_pattern = re.compile(
537
+ r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$", re.IGNORECASE
538
+ )
539
+ while word_pattern.search(text):
540
+ text = word_pattern.sub(r"\1", text)
541
+
542
+ # 3. Truncate repeated phrases (2-20 words) at end
543
+ # e.g., "i am sorry i am sorry i am sorry" -> "i am sorry"
544
+ words = text.split()
545
+ if len(words) >= min_repeats * 2:
546
+ # Try phrase lengths from 2 to 20 words
547
+ for phrase_len in range(2, min(21, len(words) // min_repeats + 1)):
548
+ # Check if the last phrase_len words repeat
549
+ phrase = " ".join(words[-phrase_len:])
550
+ # Build pattern to match repeated phrases at end
551
+ phrase_escaped = re.escape(phrase)
552
+ phrase_pattern = re.compile(
553
+ r"(^|.*?\s)("
554
+ + phrase_escaped
555
+ + r")(?:\s+"
556
+ + phrase_escaped
557
+ + r"){"
558
+ + str(min_repeats - 1)
559
+ + r",}\s*$",
560
+ re.IGNORECASE,
561
+ )
562
+ match = phrase_pattern.match(text)
563
+ if match:
564
+ # Keep prefix + one instance of the phrase
565
+ text = (match.group(1) + match.group(2)).strip()
566
+ words = text.split()
567
+ break
568
+
569
+ return text
asr_processing.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+ import transformers
5
+ from transformers import ProcessorMixin
6
+
7
+ try:
8
+ from .asr_config import ASRConfig
9
+ except ImportError:
10
+ from asr_config import ASRConfig # type: ignore[no-redef]
11
+
12
+
13
+ class ASRProcessor(ProcessorMixin):
14
+ """Processor for Whisper-based ASR models."""
15
+
16
+ attributes = ["feature_extractor", "tokenizer"]
17
+ feature_extractor_class = "AutoFeatureExtractor"
18
+ tokenizer_class = "AutoTokenizer"
19
+ AUDIO_TOKEN = "<audio>"
20
+ TRANSCRIBE_PROMPT = ""
21
+ # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
22
+ DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]
23
+
24
+ def __init__(
25
+ self,
26
+ feature_extractor,
27
+ tokenizer,
28
+ projector=None,
29
+ encoder_conv_layers: Optional[list] = None,
30
+ ):
31
+ """Initialize the ASR processor.
32
+
33
+ Args:
34
+ feature_extractor: Audio feature extractor (WhisperFeatureExtractor)
35
+ tokenizer: Text tokenizer for the language model
36
+ projector: Audio projector module (for computing output lengths)
37
+ encoder_conv_layers: Conv layer specs [(pad, kernel, stride), ...]
38
+ """
39
+ self.feature_extractor = feature_extractor
40
+ self.tokenizer = tokenizer
41
+ self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
42
+ self.projector = projector
43
+ self.encoder_conv_layers = encoder_conv_layers or self.DEFAULT_ENCODER_CONV_LAYERS
44
+
45
+ def _compute_encoder_output_length(self, mel_length: int) -> int:
46
+ """Compute encoder output length using conv layer formulas."""
47
+ length = mel_length
48
+ for padding, kernel_size, stride in self.encoder_conv_layers:
49
+ length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
50
+ return length
51
+
52
+ def __call__(
53
+ self,
54
+ audio: Optional[Union[list, "torch.Tensor"]] = None,
55
+ text: Optional[str] = None,
56
+ system_prompt: Optional[str] = None,
57
+ return_tensors: str = "pt",
58
+ **kwargs,
59
+ ) -> dict:
60
+ """Process audio and text inputs for inference.
61
+
62
+ Args:
63
+ audio: Raw audio waveform(s)
64
+ text: Target transcription (optional, for training - but use DataCollator instead)
65
+ system_prompt: Optional system prompt
66
+ return_tensors: Return format ("pt" for PyTorch)
67
+
68
+ Returns:
69
+ Dict with input_features, input_ids, attention_mask
70
+ """
71
+ result = {}
72
+
73
+ # Process audio
74
+ if audio is not None:
75
+ audio_inputs = self.feature_extractor(
76
+ audio,
77
+ sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
78
+ return_attention_mask=True,
79
+ return_tensors=return_tensors,
80
+ **kwargs,
81
+ )
82
+ result["input_features"] = audio_inputs["input_features"]
83
+ result["audio_attention_mask"] = audio_inputs["attention_mask"]
84
+
85
+ # Use actual audio length (from attention mask) for token count
86
+ real_mel_len = int(audio_inputs["attention_mask"].sum(dim=-1).max().item())
87
+ encoder_output_len = self._compute_encoder_output_length(real_mel_len)
88
+ num_audio_tokens = self.projector.get_output_length(encoder_output_len)
89
+ else:
90
+ num_audio_tokens = 0
91
+
92
+ # Build prompt with audio token placeholders (instruction-free)
93
+ if num_audio_tokens > 0:
94
+ user_content = self.AUDIO_TOKEN * num_audio_tokens
95
+ if self.TRANSCRIBE_PROMPT:
96
+ user_content += " " + self.TRANSCRIBE_PROMPT
97
+ else:
98
+ user_content = self.TRANSCRIBE_PROMPT or ""
99
+
100
+ messages = []
101
+ if system_prompt:
102
+ messages.append({"role": "system", "content": system_prompt})
103
+ messages.append({"role": "user", "content": user_content})
104
+ if text is not None:
105
+ messages.append({"role": "assistant", "content": text})
106
+
107
+ # Tokenize
108
+ tokenized = self.tokenizer.apply_chat_template(
109
+ messages,
110
+ tokenize=True,
111
+ add_generation_prompt=(text is None),
112
+ return_tensors=return_tensors,
113
+ enable_thinking=False, # Disable Qwen3 thinking mode for ASR
114
+ )
115
+
116
+ # Handle both tensor and BatchEncoding returns
117
+ if isinstance(tokenized, torch.Tensor):
118
+ input_ids = tokenized
119
+ else:
120
+ # BatchEncoding or dict-like object
121
+ input_ids = tokenized.get("input_ids", tokenized.input_ids)
122
+
123
+ if input_ids.dim() == 1:
124
+ input_ids = input_ids.unsqueeze(0)
125
+
126
+ result["input_ids"] = input_ids
127
+ result["attention_mask"] = torch.ones_like(input_ids)
128
+
129
+ return result
130
+
131
+
132
+ ASRProcessor.register_for_auto_class()
133
+ transformers.AutoProcessor.register(ASRConfig, ASRProcessor)
audio_head.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Flow matching audio head for speech-to-speech.
2
+
3
+ Generates audio from LLM hidden states via flow matching:
4
+ LLM hidden -> llm_proj -> flow_net (LSD decode) -> Mimi latents -> Mimi decoder -> audio
5
+
6
+ Supports two modes:
7
+ 1. Training from scratch with 512-dim Mimi embeddings (latent_proj_in/out)
8
+ 2. Using pretrained pocket-tts flow_net with 32-dim normalized latents
9
+ """
10
+
11
+ import logging
12
+ from functools import partial
13
+ from typing import Optional
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+ from .modules.mlp import SimpleMLPAdaLN
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
def lsd_decode(
    v_t,
    x_0: torch.Tensor,
    num_steps: int = 1,
) -> torch.Tensor:
    """Lagrangian Self-Distillation decoding.

    Integrates the flow velocity field from noise to latents with a fixed
    number of equal Euler-like steps.

    Args:
        v_t: Velocity function v(s, t, x) -> velocity.
        x_0: Initial noise, shape [N, latent_dim].
        num_steps: Number of integration steps.

    Returns:
        Decoded latents, shape [N, latent_dim].
    """
    state = x_0
    for step in range(num_steps):
        # Interval endpoints for this step, broadcast over the batch.
        s = torch.full_like(x_0[..., :1], step / num_steps)
        t = torch.full_like(x_0[..., :1], (step + 1) / num_steps)
        state = state + v_t(s, t, state) / num_steps
    return state
+ return current
49
+
50
+
51
+ class AudioHead(nn.Module):
52
+ """Flow matching head: LLM hidden -> Mimi latents -> audio.
53
+
54
+ Architecture:
55
+ - llm_proj: Linear projection from LLM hidden dim to flow conditioning
56
+ - latent_proj_in/out: Project between Mimi 512-dim and flow 32-dim
57
+ - flow_net: SimpleMLPAdaLN that predicts flow velocity
58
+ - Mimi decoder for latent -> audio
59
+
60
+ Args:
61
+ config: ASRConfig with:
62
+ - llm_dim: LLM hidden dimension (default: 2048)
63
+ - lsd_decode_steps: Number of LSD integration steps (default: 1)
64
+ - flow_temperature: Sampling temperature for noise (default: 1.0)
65
+ """
66
+
67
+ # Architecture dimensions
68
+ COND_DIM = 1024 # Conditioning dimension
69
+ LATENT_DIM = 32 # Flow latent dimension (matches Mimi's 32 codebooks)
70
+ MIMI_DIM = 512 # Mimi encoder output dimension
71
+ FLOW_DIM = 512 # Flow network hidden dimension
72
+ FLOW_DEPTH = 6 # Number of residual blocks
73
+
74
+ def __init__(self, config, llm_dim: int = None):
75
+ super().__init__()
76
+ # llm_dim can be passed directly or from config
77
+ self.llm_dim = llm_dim or getattr(config, "llm_dim", None) or 2048
78
+ self.cond_dim = self.COND_DIM
79
+ self.latent_dim = self.LATENT_DIM
80
+ self.mimi_dim = self.MIMI_DIM
81
+ self.lsd_steps = getattr(config, "lsd_decode_steps", 1)
82
+ self.temp = getattr(config, "flow_temperature", 1.0)
83
+
84
+ # LLM -> conditioning projection
85
+ self.llm_proj = nn.Linear(self.llm_dim, self.cond_dim, bias=False)
86
+
87
+ # Mimi embedding projections
88
+ # Projects 512-dim Mimi embeddings to 32-dim flow latents and back
89
+ self.latent_proj_in = nn.Linear(self.mimi_dim, self.latent_dim, bias=False)
90
+ self.latent_proj_out = nn.Linear(self.latent_dim, self.mimi_dim, bias=False)
91
+
92
+ # Flow network
93
+ self.flow_net = SimpleMLPAdaLN(
94
+ in_channels=self.latent_dim,
95
+ model_channels=self.FLOW_DIM,
96
+ out_channels=self.latent_dim,
97
+ cond_channels=self.cond_dim,
98
+ num_res_blocks=self.FLOW_DEPTH,
99
+ num_time_conds=2,
100
+ )
101
+
102
+ # Normalization buffers for pretrained pocket-tts flow_net
103
+ # When using pretrained weights, the flow operates in normalized 32-dim space
104
+ self.register_buffer("emb_mean", torch.zeros(self.latent_dim))
105
+ self.register_buffer("emb_std", torch.ones(self.latent_dim))
106
+ self._use_pretrained_normalization = False
107
+
108
+ # Mimi decoder components (loaded separately via load_mimi_decoder)
109
+ self.mimi = None
110
+
111
+ def load_mimi_decoder(self, device: torch.device = None, dtype: torch.dtype = None):
112
+ """Load Mimi model for decoding latents to audio."""
113
+ from transformers import MimiModel
114
+
115
+ self.mimi = MimiModel.from_pretrained("kyutai/mimi")
116
+ self.mimi.requires_grad_(False)
117
+ self.mimi.eval()
118
+
119
+ if device is not None:
120
+ self.mimi = self.mimi.to(device)
121
+ if dtype is not None:
122
+ self.mimi = self.mimi.to(dtype)
123
+
124
+ logger.info("Loaded Mimi decoder from kyutai/mimi")
125
+
126
+ def load_pretrained_flow_net(
127
+ self,
128
+ weights_path: Optional[str] = None,
129
+ freeze: bool = True,
130
+ ):
131
+ """Load pretrained pocket-tts flow_net weights.
132
+
133
+ This enables using the pretrained flow matching network from pocket-tts,
134
+ which operates in normalized 32-dim latent space.
135
+
136
+ Args:
137
+ weights_path: Path to safetensors file. If None, downloads from HuggingFace.
138
+ freeze: Whether to freeze flow_net weights (default: True, only train llm_proj)
139
+ """
140
+ import safetensors.torch
141
+
142
+ if weights_path is None:
143
+ from huggingface_hub import hf_hub_download
144
+
145
+ weights_path = hf_hub_download(
146
+ repo_id="kyutai/pocket-tts", filename="tts_b6369a24.safetensors"
147
+ )
148
+
149
+ state = safetensors.torch.load_file(weights_path)
150
+
151
+ # Extract flow_net weights
152
+ flow_state = {}
153
+ for k, v in state.items():
154
+ if k.startswith("flow_lm.flow_net."):
155
+ new_key = k.replace("flow_lm.flow_net.", "")
156
+ flow_state[new_key] = v
157
+
158
+ self.flow_net.load_state_dict(flow_state)
159
+ logger.info(f"Loaded pretrained flow_net from {weights_path}")
160
+
161
+ # Load normalization buffers
162
+ if "flow_lm.emb_mean" in state:
163
+ self.emb_mean.copy_(state["flow_lm.emb_mean"])
164
+ if "flow_lm.emb_std" in state:
165
+ self.emb_std.copy_(state["flow_lm.emb_std"])
166
+ # Enable normalization for generate
167
+ self._use_pretrained_normalization = True
168
+ logger.info("Loaded emb_mean and emb_std for normalization")
169
+
170
+ if freeze:
171
+ self.flow_net.requires_grad_(False)
172
+ logger.info("Froze flow_net weights (only llm_proj will train)")
173
+
174
+ def forward(
175
+ self,
176
+ hidden_states: torch.Tensor,
177
+ latent_targets: Optional[torch.Tensor] = None,
178
+ latent_lengths: Optional[torch.Tensor] = None,
179
+ ) -> torch.Tensor:
180
+ """Forward pass for training or inference.
181
+
182
+ Args:
183
+ hidden_states: LLM hidden states, shape [batch, seq_len, llm_dim]
184
+ latent_targets: Target Mimi latents for training, shape [batch, seq_len, 512]
185
+ latent_lengths: Actual lengths per sample, shape [batch]
186
+
187
+ Returns:
188
+ Training: scalar flow matching loss
189
+ Inference: generated Mimi latents, shape [batch, seq_len, 512]
190
+ """
191
+ # Project LLM hidden states to conditioning
192
+ cond = self.llm_proj(hidden_states)
193
+
194
+ if latent_targets is not None:
195
+ return self._compute_loss(cond, latent_targets, latent_lengths)
196
+ return self._generate(cond)
197
+
198
    def _compute_loss(
        self,
        cond: torch.Tensor,
        targets: torch.Tensor,
        lengths: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Compute flow matching loss with reconstruction term.

        The loss has two components:
        1. Flow matching loss: MSE between predicted and target velocities in 32-dim space
        2. Reconstruction loss: MSE between reconstructed and original 512-dim embeddings
           (this ensures latent_proj_out is trained)

        Args:
            cond: Conditioning from LLM, shape [batch, cond_seq_len, cond_dim]
            targets: Mimi embeddings, shape [batch, target_seq_len, 512]
            lengths: Optional lengths for masking

        Returns:
            Scalar loss: flow_loss + 0.1 * recon_loss (0.0 for empty/fully
            masked inputs).
        """
        # Debug: check inputs for NaN/Inf — logged only, not corrected.
        if torch.isnan(cond).any() or torch.isinf(cond).any():
            logger.warning(
                f"NaN/Inf in cond! shape={cond.shape}, nan={torch.isnan(cond).sum()}, inf={torch.isinf(cond).sum()}"
            )
        if torch.isnan(targets).any() or torch.isinf(targets).any():
            logger.warning(f"NaN/Inf in targets! shape={targets.shape}")

        batch, cond_seq_len, _ = cond.shape
        target_seq_len = targets.shape[1]
        device = cond.device
        dtype = cond.dtype

        # Handle empty sequences: zero loss, kept differentiable so callers
        # can still backprop through a no-op step.
        if cond_seq_len == 0 or target_seq_len == 0:
            return torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True)

        # Project 512-dim Mimi embeddings to 32-dim flow latents
        targets_proj = self.latent_proj_in(targets)

        # Compute reconstruction loss to train latent_proj_out
        # This ensures the projection learns a good inverse mapping
        targets_reconstructed = self.latent_proj_out(targets_proj)

        # Interpolate targets to match conditioning sequence length.
        # NOTE(review): linear interpolation assumes cond and target sequences
        # cover the same time span at different rates — confirm upstream.
        targets_for_interp = targets
        if target_seq_len != cond_seq_len:
            targets_proj = targets_proj.transpose(1, 2)
            targets_proj = torch.nn.functional.interpolate(
                targets_proj, size=cond_seq_len, mode="linear", align_corners=False
            )
            targets_proj = targets_proj.transpose(1, 2).contiguous()

            # Also interpolate original targets for reconstruction loss
            targets_for_interp = targets.transpose(1, 2)
            targets_for_interp = torch.nn.functional.interpolate(
                targets_for_interp, size=cond_seq_len, mode="linear", align_corners=False
            )
            targets_for_interp = targets_for_interp.transpose(1, 2).contiguous()

            # Interpolate reconstructed targets to match
            targets_reconstructed = targets_reconstructed.transpose(1, 2)
            targets_reconstructed = torch.nn.functional.interpolate(
                targets_reconstructed, size=cond_seq_len, mode="linear", align_corners=False
            )
            targets_reconstructed = targets_reconstructed.transpose(1, 2).contiguous()

            # Rescale per-sample lengths into the conditioning timebase so the
            # mask below lines up with the interpolated targets.
            if lengths is not None:
                scale = cond_seq_len / target_seq_len
                lengths = (lengths.float() * scale).long()

        seq_len = cond_seq_len
        # x_1: clean data endpoint of the flow trajectory.
        x_1 = targets_proj

        # Random timesteps for each sample/position (match input dtype)
        t = torch.rand(batch, seq_len, 1, device=device, dtype=dtype)

        # Sample noise (x_0: noise endpoint of the trajectory)
        x_0 = torch.randn_like(x_1)

        # Linear interpolation: x_t = (1-t) * x_0 + t * x_1
        x_t = (1 - t) * x_0 + t * x_1

        # Target velocity: dx/dt = x_1 - x_0
        v_target = x_1 - x_0

        # Flatten for flow_net: [batch * seq_len, dim]
        cond_flat = cond.view(-1, self.cond_dim)
        t_flat = t.view(-1, 1)
        x_t_flat = x_t.view(-1, self.latent_dim)

        # Predict velocity. flow_net was built with num_time_conds=2; both
        # time conditions receive the same timestep here.
        v_pred = self.flow_net(cond_flat, t_flat, t_flat, x_t_flat)
        v_pred = v_pred.view(batch, seq_len, self.latent_dim)

        # Compute masked losses
        if lengths is not None:
            positions = torch.arange(seq_len, device=device).unsqueeze(0)
            mask = positions < lengths.unsqueeze(1)

            # Check if mask is all False (no valid positions)
            if not mask.any():
                return torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True)

            flow_mask = mask.unsqueeze(-1).expand_as(v_pred)
            recon_mask = mask.unsqueeze(-1).expand_as(targets_reconstructed)

            flow_loss = ((v_pred - v_target) ** 2)[flow_mask].mean()
            recon_loss = ((targets_reconstructed - targets_for_interp) ** 2)[recon_mask].mean()
        else:
            flow_loss = ((v_pred - v_target) ** 2).mean()
            recon_loss = ((targets_reconstructed - targets_for_interp) ** 2).mean()

        # Combined loss (reconstruction loss weighted at 0.1 to not dominate)
        return flow_loss + 0.1 * recon_loss
311
+
312
+ def _generate(self, cond: torch.Tensor) -> torch.Tensor:
313
+ """Generate Mimi embeddings via LSD decoding.
314
+
315
+ Args:
316
+ cond: Conditioning from LLM, shape [batch, seq_len, cond_dim]
317
+
318
+ Returns:
319
+ Generated Mimi embeddings, shape [batch, seq_len, 512]
320
+ """
321
+ batch, seq_len, _ = cond.shape
322
+ device = cond.device
323
+ dtype = cond.dtype
324
+
325
+ # Handle empty sequences
326
+ if seq_len == 0:
327
+ return torch.empty(batch, 0, self.mimi_dim, device=device, dtype=dtype)
328
+
329
+ # Clamp temperature to non-negative to avoid complex numbers from sqrt
330
+ temp = max(0.0, self.temp)
331
+
332
+ latents = []
333
+ for t in range(seq_len):
334
+ cond_t = cond[:, t]
335
+
336
+ # Sample initial noise in 32-dim flow space
337
+ noise = torch.randn(batch, self.latent_dim, device=device, dtype=dtype)
338
+ noise = noise * (temp**0.5)
339
+
340
+ def velocity_fn(cond_fixed, s, t, x):
341
+ return self.flow_net(cond_fixed, s, t, x)
342
+
343
+ conditioned_flow = partial(velocity_fn, cond_t)
344
+ latent = lsd_decode(conditioned_flow, noise, self.lsd_steps)
345
+ latents.append(latent)
346
+
347
+ latents = torch.stack(latents, dim=1)
348
+
349
+ # Denormalize if using pretrained pocket-tts normalization
350
+ if self._use_pretrained_normalization:
351
+ latents = latents * self.emb_std + self.emb_mean
352
+
353
+ # Project back to 512-dim Mimi embedding space
354
+ return self.latent_proj_out(latents)
355
+
356
+ def decode_to_audio(self, latents: torch.Tensor) -> torch.Tensor:
357
+ """Decode Mimi latents to audio waveform.
358
+
359
+ Note: HuggingFace MimiModel.decode() expects discrete codes, not continuous
360
+ embeddings. We bypass the quantizer and call upsample → decoder_transformer
361
+ → decoder directly to decode from continuous latents.
362
+
363
+ Args:
364
+ latents: Mimi latents, shape [batch, seq_len, 512]
365
+
366
+ Returns:
367
+ Audio waveform, shape [batch, samples]
368
+ """
369
+ if self.mimi is None:
370
+ raise RuntimeError("Mimi decoder not loaded. Call load_mimi_decoder() first.")
371
+
372
+ # [batch, seq, 512] → [batch, 512, seq]
373
+ latents = latents.transpose(1, 2)
374
+
375
+ with torch.no_grad():
376
+ # Upsample latents (2x temporal upsampling)
377
+ emb = self.mimi.upsample(latents)
378
+
379
+ # Decoder transformer expects [batch, seq, dim]
380
+ emb = emb.transpose(1, 2)
381
+ decoder_out = self.mimi.decoder_transformer(emb)
382
+ emb = getattr(decoder_out, "last_hidden_state", decoder_out[0])
383
+
384
+ # Final decoder expects [batch, dim, seq]
385
+ emb = emb.transpose(1, 2)
386
+ audio = self.mimi.decoder(emb)
387
+
388
+ return audio.squeeze(1)
389
+
390
+ def get_output_length(self, input_length: int) -> int:
391
+ """Estimate output audio frames from input hidden state length.
392
+
393
+ For Mimi at 12.5 Hz frame rate with 24kHz audio:
394
+ Each latent frame = 24000 / 12.5 = 1920 audio samples
395
+ """
396
+ return input_length * 1920
chat_template.jinja ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
{# ───── defaults ───── #}
{%- if enable_thinking is not defined -%}
{%- set enable_thinking = true -%}
{%- endif -%}
{# Initialize so conversations without a system message do not hit
   undefined-variable errors in the "/system_override" check below. #}
{%- set system_message = "" -%}
{%- set custom_instructions = "" -%}

{# ───── reasoning mode ───── #}
{%- if enable_thinking -%}
{%- set reasoning_mode = "/think" -%}
{%- else -%}
{%- set reasoning_mode = "/no_think" -%}
{%- endif -%}

{# ───── header (system message) ───── #}
{{- "<|im_start|>system\n" -}}

{%- if messages[0].role == "system" -%}
{%- set system_message = messages[0].content -%}
{%- if "/no_think" in system_message -%}
{%- set reasoning_mode = "/no_think" -%}
{%- elif "/think" in system_message -%}
{%- set reasoning_mode = "/think" -%}
{%- endif -%}
{%- set custom_instructions = system_message.replace("/no_think", "").replace("/think", "").rstrip() -%}
{%- endif -%}

{%- if "/system_override" in system_message -%}
{{- custom_instructions.replace("/system_override", "").rstrip() -}}
{{- "<|im_end|>\n" -}}
{%- else -%}
{{- "## Metadata\n\n" -}}
{{- "Knowledge Cutoff Date: June 2025\n" -}}
{%- set today = strftime_now("%d %B %Y") -%}
{{- "Today Date: " ~ today ~ "\n" -}}
{{- "Reasoning Mode: " + reasoning_mode + "\n\n" -}}

{{- "## Custom Instructions\n\n" -}}
{%- if custom_instructions -%}
{{- custom_instructions + "\n\n" -}}
{%- elif reasoning_mode == "/think" -%}
{{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracking, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> Thought section </think> Solution section. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion.\n\n" -}}
{%- else -%}
{{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}}
{%- endif -%}

{%- if xml_tools or python_tools or tools -%}
{{- "### Tools\n\n" -}}
{%- if xml_tools or tools -%}
{%- if tools -%}
{%- set xml_tools = tools -%}
{%- endif -%}
{%- set ns = namespace(xml_tool_string="You may call one or more functions to assist with the user query.\nYou are provided with function signatures within <tools></tools> XML tags:\n\n<tools>\n") -%}
{%- for tool in xml_tools[:] -%} {# The slicing makes sure that xml_tools is a list #}
{%- set ns.xml_tool_string = ns.xml_tool_string ~ (tool | string) ~ "\n" -%}
{%- endfor -%}
{%- set xml_tool_string = ns.xml_tool_string + "</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>" -%}
{{- xml_tool_string -}}
{%- endif -%}
{%- if python_tools -%}
{%- set ns = namespace(python_tool_string="When you send a message containing Python code between '<code>' and '</code>' tags, it will be executed in a stateful Jupyter notebook environment, and you will then be given the output to continued reasoning in an agentic loop.\n\nYou can use the following tools in your python code like regular functions:\n<tools>\n") -%}
{%- for tool in python_tools[:] -%} {# The slicing makes sure that python_tools is a list #}
{%- set ns.python_tool_string = ns.python_tool_string ~ (tool | string) ~ "\n" -%}
{%- endfor -%}
{%- set python_tool_string = ns.python_tool_string + "</tools>\n\nThe state persists between code executions: so variables that you define in one step are still available thereafter." -%}
{{- python_tool_string -}}
{%- endif -%}
{{- "\n\n" -}}
{%- endif -%}
{# Close the system block unconditionally — previously this was emitted only
   when tools were present, leaving tool-less prompts unterminated. #}
{{- "<|im_end|>\n" -}}
{%- endif -%}
{# ───── main loop ───── #}
{%- for message in messages -%}
{%- set content = message.content if message.content is string else "" -%}
{%- if message.role == "user" -%}
{{ "<|im_start|>" + message.role + "\n" + content + "<|im_end|>\n" }}
{%- elif message.role == "assistant" -%}
{% generation %}
{%- if reasoning_mode == "/think" -%}
{{ "<|im_start|>assistant\n" + content.lstrip("\n") + "<|im_end|>\n" }}
{%- else -%}
{{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" + content.lstrip("\n") + "<|im_end|>\n" }}
{%- endif -%}
{% endgeneration %}
{%- elif message.role == "tool" -%}
{{ "<|im_start|>" + "user\n" + content + "<|im_end|>\n" }}
{%- endif -%}
{%- endfor -%}
{# ───── generation prompt ───── #}
{%- if add_generation_prompt -%}
{%- if reasoning_mode == "/think" -%}
{{ "<|im_start|>assistant\n" }}
{%- else -%}
{{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" }}
{%- endif -%}
{%- endif -%}
diarization.py ADDED
@@ -0,0 +1,759 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Speaker diarization using TEN-VAD + ECAPA-TDNN + spectral clustering.
2
+
3
+ Spectral clustering implementation adapted from FunASR/3D-Speaker:
4
+ https://github.com/alibaba-damo-academy/FunASR
5
+ MIT License (https://opensource.org/licenses/MIT)
6
+ """
7
+
8
+ import warnings
9
+
10
+ import numpy as np
11
+ import scipy
12
+ import sklearn.metrics.pairwise
13
+ import torch
14
+ from sklearn.cluster._kmeans import k_means
15
+ from sklearn.preprocessing import normalize
16
+
17
+
18
+ def _get_device() -> torch.device:
19
+ """Get best available device for inference."""
20
+ if torch.cuda.is_available():
21
+ return torch.device("cuda")
22
+ if torch.backends.mps.is_available():
23
+ return torch.device("mps")
24
+ return torch.device("cpu")
25
+
26
+
27
+ class SpectralCluster:
28
+ """Spectral clustering using unnormalized Laplacian of affinity matrix.
29
+
30
+ Adapted from FunASR/3D-Speaker and SpeechBrain implementations.
31
+ Uses eigenvalue gap to automatically determine number of speakers.
32
+ """
33
+
34
+ def __init__(self, min_num_spks: int = 1, max_num_spks: int = 15, pval: float = 0.06):
35
+ self.min_num_spks = min_num_spks
36
+ self.max_num_spks = max_num_spks
37
+ self.pval = pval
38
+
39
+ def __call__(self, embeddings: np.ndarray, oracle_num: int | None = None) -> np.ndarray:
40
+ """Run spectral clustering on embeddings.
41
+
42
+ Args:
43
+ embeddings: Speaker embeddings of shape [N, D]
44
+ oracle_num: Optional known number of speakers
45
+
46
+ Returns:
47
+ Cluster labels of shape [N]
48
+ """
49
+ # Similarity matrix computation
50
+ sim_mat = self.get_sim_mat(embeddings)
51
+
52
+ # Refining similarity matrix with pval
53
+ prunned_sim_mat = self.p_pruning(sim_mat)
54
+
55
+ # Symmetrization
56
+ sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T)
57
+
58
+ # Laplacian calculation
59
+ laplacian = self.get_laplacian(sym_prund_sim_mat)
60
+
61
+ # Get Spectral Embeddings
62
+ emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num)
63
+
64
+ # Perform clustering
65
+ return self.cluster_embs(emb, num_of_spk)
66
+
67
+ def get_sim_mat(self, embeddings: np.ndarray) -> np.ndarray:
68
+ """Compute cosine similarity matrix."""
69
+ return sklearn.metrics.pairwise.cosine_similarity(embeddings, embeddings)
70
+
71
+ def p_pruning(self, affinity: np.ndarray) -> np.ndarray:
72
+ """Prune low similarity values in affinity matrix (keep top pval fraction)."""
73
+ n = affinity.shape[0]
74
+ pval = max(self.pval, 6.0 / n)
75
+ k_keep = max(1, int(pval * n))
76
+
77
+ # Vectorized: find top-k indices per row and zero out the rest
78
+ top_k_idx = np.argpartition(affinity, -k_keep, axis=1)[:, -k_keep:]
79
+ mask = np.zeros_like(affinity, dtype=bool)
80
+ np.put_along_axis(mask, top_k_idx, True, axis=1)
81
+ affinity[~mask] = 0
82
+ return affinity
83
+
84
+ def get_laplacian(self, sim_mat: np.ndarray) -> np.ndarray:
85
+ """Compute unnormalized Laplacian matrix."""
86
+ from scipy.sparse.csgraph import laplacian
87
+
88
+ np.fill_diagonal(sim_mat, 0)
89
+ return laplacian(sim_mat, normed=False)
90
+
91
+ def get_spec_embs(
92
+ self, laplacian: np.ndarray, k_oracle: int | None = None
93
+ ) -> tuple[np.ndarray, int]:
94
+ """Extract spectral embeddings from Laplacian.
95
+
96
+ Uses the eigengap heuristic to estimate the number of clusters:
97
+ The number of clusters k is chosen where the gap between consecutive
98
+ eigenvalues is largest, indicating a transition from "cluster" eigenvalues
99
+ (near 0) to "noise" eigenvalues.
100
+ """
101
+ lambdas, eig_vecs = scipy.linalg.eigh(laplacian)
102
+
103
+ num_of_spk = k_oracle if k_oracle is not None else self._estimate_num_speakers(lambdas)
104
+
105
+ emb = eig_vecs[:, :num_of_spk]
106
+ return emb, num_of_spk
107
+
108
+ def _estimate_num_speakers(self, lambdas: np.ndarray) -> int:
109
+ """Estimate number of speakers using refined eigengap heuristic.
110
+
111
+ For spectral clustering, we look for the largest gap in eigenvalues.
112
+ The eigenvalues corresponding to clusters are close to 0, and there
113
+ should be a significant jump to the remaining eigenvalues.
114
+ """
115
+ # Consider eigenvalues from index 1 to max_num_spks (skip first, it's always ~0)
116
+ # We need gaps between positions, so look at indices 1 to max_num_spks+1
117
+ max_idx = min(self.max_num_spks + 1, len(lambdas))
118
+ relevant_lambdas = lambdas[1:max_idx] # Skip first eigenvalue
119
+
120
+ if len(relevant_lambdas) < 2:
121
+ return self.min_num_spks
122
+
123
+ # Compute absolute gaps (not ratios - ratios are unstable near 0)
124
+ gaps = np.diff(relevant_lambdas)
125
+
126
+ # Find the largest gap - the index gives us (k-1) since we skipped first
127
+ # Add 1 to convert from gap index to number of speakers
128
+ # Add 1 again because we skipped the first eigenvalue
129
+ max_gap_idx = int(np.argmax(gaps))
130
+ num_of_spk = max_gap_idx + 2 # +1 for gap->count, +1 for skipped eigenvalue
131
+
132
+ # Clamp between min and max
133
+ return max(self.min_num_spks, min(num_of_spk, self.max_num_spks))
134
+
135
+ def cluster_embs(self, emb: np.ndarray, k: int) -> np.ndarray:
136
+ """Cluster spectral embeddings using k-means."""
137
+ _, labels, _ = k_means(emb, k, n_init=10)
138
+ return labels
139
+
140
+ def get_eigen_gaps(self, eig_vals: np.ndarray) -> np.ndarray:
141
+ """Compute gaps between consecutive eigenvalues."""
142
+ return np.diff(eig_vals)
143
+
144
+
145
class SpeakerClusterer:
    """Speaker clustering backend using spectral clustering with speaker merging.

    Features:
    - Spectral clustering with eigenvalue gap for auto speaker count detection
    - P-pruning for affinity matrix refinement
    - Post-clustering speaker merging by cosine similarity
    """

    def __init__(
        self,
        min_num_spks: int = 2,
        max_num_spks: int = 10,
        merge_thr: float = 0.90,  # Moderate merging
    ):
        """
        Args:
            min_num_spks: Minimum number of speakers to consider.
            max_num_spks: Maximum number of speakers to consider.
            merge_thr: Centroid cosine similarity above which two clusters
                are merged into one speaker.
        """
        self.min_num_spks = min_num_spks
        self.max_num_spks = max_num_spks
        self.merge_thr = merge_thr
        self._spectral_cluster: SpectralCluster | None = None

    def _get_spectral_cluster(self) -> SpectralCluster:
        """Lazy-load spectral clusterer."""
        if self._spectral_cluster is None:
            self._spectral_cluster = SpectralCluster(
                min_num_spks=self.min_num_spks,
                max_num_spks=self.max_num_spks,
            )
        return self._spectral_cluster

    def __call__(self, embeddings: np.ndarray, num_speakers: int | None = None) -> np.ndarray:
        """Cluster speaker embeddings and return labels.

        Args:
            embeddings: Speaker embeddings of shape [N, D]
            num_speakers: Optional oracle number of speakers

        Returns:
            Cluster labels of shape [N]

        Raises:
            ValueError: If ``embeddings`` is not a 2D array.
        """
        if len(embeddings.shape) != 2:
            raise ValueError(f"Expected 2D array, got shape {embeddings.shape}")

        # Handle edge cases: too few windows to cluster meaningfully.
        if embeddings.shape[0] == 0:
            return np.array([], dtype=int)
        if embeddings.shape[0] == 1:
            return np.array([0], dtype=int)
        if embeddings.shape[0] < 6:
            return np.zeros(embeddings.shape[0], dtype=int)

        # Normalize embeddings and replace NaN/inf
        embeddings = np.nan_to_num(embeddings, nan=0.0, posinf=0.0, neginf=0.0)
        embeddings = normalize(embeddings)

        spectral = self._get_spectral_cluster()

        # Pin the speaker count when an oracle value is supplied.
        if num_speakers is not None:
            spectral.min_num_spks = num_speakers
            spectral.max_num_spks = num_speakers

        # Run spectral clustering, suppressing numerical warnings.
        # Uses the module-level `warnings` import (the redundant local
        # re-import has been removed).
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=RuntimeWarning)
            labels = spectral(embeddings, oracle_num=num_speakers)

        # Restore the configured bounds after an oracle run.
        if num_speakers is not None:
            spectral.min_num_spks = self.min_num_spks
            spectral.max_num_spks = self.max_num_spks

        # Merge similar speakers if no oracle
        if num_speakers is None:
            labels = self._merge_by_cos(labels, embeddings, self.merge_thr)

        # Re-index labels sequentially
        _, labels = np.unique(labels, return_inverse=True)

        return labels

    def _merge_by_cos(self, labels: np.ndarray, embs: np.ndarray, cos_thr: float) -> np.ndarray:
        """Merge similar speakers by cosine similarity of centroids."""
        from scipy.cluster.hierarchy import fcluster, linkage
        from scipy.spatial.distance import pdist

        unique_labels = np.unique(labels)
        if len(unique_labels) <= 1:
            return labels

        # Compute normalized speaker centroids
        centroids = np.array([embs[labels == lbl].mean(0) for lbl in unique_labels])
        centroids = normalize(centroids)

        # Agglomerate centroids whose cosine distance is below 1 - cos_thr.
        distances = pdist(centroids, metric="cosine")
        linkage_matrix = linkage(distances, method="average")
        merged_labels = fcluster(linkage_matrix, t=1.0 - cos_thr, criterion="distance") - 1

        # Map original labels to merged labels
        label_map = dict(zip(unique_labels, merged_labels))
        return np.array([label_map[lbl] for lbl in labels])
248
+
249
+
250
+ class LocalSpeakerDiarizer:
251
+ """Local speaker diarization using TEN-VAD + ECAPA-TDNN + spectral clustering.
252
+
253
+ Pipeline:
254
+ 1. TEN-VAD detects speech segments
255
+ 2. Sliding window (1.0s, 75% overlap) for uniform embedding extraction
256
+ 3. ECAPA-TDNN extracts speaker embeddings per window
257
+ 4. Spectral clustering with eigenvalue gap for auto speaker detection
258
+ 5. Frame-level consensus voting for segment reconstruction
259
+ 6. Post-processing merges short segments to reduce flicker
260
+
261
+ Tunable Parameters (class attributes):
262
+ - WINDOW_SIZE: Embedding extraction window size in seconds
263
+ - STEP_SIZE: Sliding window step size (overlap = WINDOW_SIZE - STEP_SIZE)
264
+ - VAD_THRESHOLD: Speech detection threshold (lower = more sensitive)
265
+ - VAD_MIN_DURATION: Minimum speech segment duration
266
+ - VAD_MAX_GAP: Maximum gap to bridge between segments
267
+ - VAD_PAD_ONSET/OFFSET: Padding added to speech segments
268
+ - VOTING_RATE: Frame resolution for consensus voting
269
+ - MIN_SEGMENT_DURATION: Minimum final segment duration
270
+ - SAME_SPEAKER_GAP: Maximum gap to merge same-speaker segments
271
+ - TAIL_COVERAGE_RATIO: Minimum tail coverage to add extra window
272
+ """
273
+
274
+ _ten_vad_model = None
275
+ _ecapa_model = None
276
+ _device = None
277
+
278
+ # ==================== TUNABLE PARAMETERS ====================
279
+
280
+ # Sliding window for embedding extraction
281
+ WINDOW_SIZE = 0.75 # seconds - shorter window for finer resolution
282
+ STEP_SIZE = 0.15 # seconds (80% overlap for more votes)
283
+ TAIL_COVERAGE_RATIO = 0.1 # Add extra window if tail > this ratio of window
284
+
285
+ # VAD hysteresis parameters
286
+ VAD_THRESHOLD = 0.25 # Balanced threshold
287
+ VAD_MIN_DURATION = 0.05 # Minimum speech segment duration (seconds)
288
+ VAD_MAX_GAP = 0.50 # Bridge gaps shorter than this (seconds)
289
+ VAD_PAD_ONSET = 0.05 # Padding at segment start (seconds)
290
+ VAD_PAD_OFFSET = 0.05 # Padding at segment end (seconds)
291
+
292
+ # Frame-level voting
293
+ VOTING_RATE = 0.01 # 10ms resolution for consensus voting
294
+
295
+ # Post-processing
296
+ MIN_SEGMENT_DURATION = 0.15 # Minimum final segment duration (seconds)
297
+ SHORT_SEGMENT_GAP = 0.1 # Gap threshold for merging short segments
298
+ SAME_SPEAKER_GAP = 0.5 # Gap threshold for merging same-speaker segments
299
+
300
+ # ===========================================================
301
+
302
+ @classmethod
303
+ def _get_ten_vad_model(cls):
304
+ """Lazy-load TEN-VAD model (singleton)."""
305
+ if cls._ten_vad_model is None:
306
+ from ten_vad import TenVad
307
+
308
+ cls._ten_vad_model = TenVad(hop_size=256, threshold=cls.VAD_THRESHOLD)
309
+ return cls._ten_vad_model
310
+
311
+ @classmethod
312
+ def _get_device(cls) -> torch.device:
313
+ """Get the best available device."""
314
+ if cls._device is None:
315
+ cls._device = _get_device()
316
+ return cls._device
317
+
318
+ @classmethod
319
+ def _get_ecapa_model(cls):
320
+ """Lazy-load ECAPA-TDNN speaker embedding model (singleton)."""
321
+ if cls._ecapa_model is None:
322
+ # Suppress torchaudio deprecation warning from SpeechBrain
323
+ with warnings.catch_warnings():
324
+ warnings.filterwarnings("ignore", message="torchaudio._backend")
325
+ from speechbrain.inference.speaker import EncoderClassifier
326
+
327
+ device = cls._get_device()
328
+ cls._ecapa_model = EncoderClassifier.from_hparams(
329
+ source="speechbrain/spkrec-ecapa-voxceleb",
330
+ run_opts={"device": str(device)},
331
+ )
332
+
333
+ return cls._ecapa_model
334
+
335
+ @classmethod
336
+ def diarize(
337
+ cls,
338
+ audio: np.ndarray | str,
339
+ sample_rate: int = 16000,
340
+ num_speakers: int | None = None,
341
+ min_speakers: int = 2,
342
+ max_speakers: int = 10,
343
+ **_kwargs,
344
+ ) -> list[dict]:
345
+ """Run speaker diarization on audio.
346
+
347
+ Args:
348
+ audio: Audio waveform as numpy array or path to audio file
349
+ sample_rate: Audio sample rate (default 16000)
350
+ num_speakers: Exact number of speakers (if known)
351
+ min_speakers: Minimum number of speakers
352
+ max_speakers: Maximum number of speakers
353
+
354
+ Returns:
355
+ List of dicts with 'speaker', 'start', 'end' keys
356
+ """
357
+ # Handle file path input
358
+ if isinstance(audio, str):
359
+ import librosa
360
+
361
+ audio, sample_rate = librosa.load(audio, sr=16000)
362
+
363
+ # Ensure correct sample rate
364
+ if sample_rate != 16000:
365
+ import librosa
366
+
367
+ audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
368
+ sample_rate = 16000
369
+
370
+ audio = audio.astype(np.float32)
371
+ total_duration = len(audio) / sample_rate
372
+
373
+ # Step 1: VAD (returns segments and raw frame-level decisions)
374
+ segments, vad_frames = cls._get_speech_segments(audio, sample_rate)
375
+ if not segments:
376
+ return []
377
+
378
+ # Step 2: Extract embeddings
379
+ embeddings, window_segments = cls._extract_embeddings(audio, segments, sample_rate)
380
+ if len(embeddings) == 0:
381
+ return []
382
+
383
+ # Step 3: Cluster
384
+ clusterer = SpeakerClusterer(min_num_spks=min_speakers, max_num_spks=max_speakers)
385
+ labels = clusterer(embeddings, num_speakers)
386
+
387
+ # Step 4: Post-process with consensus voting (VAD-aware)
388
+ return cls._postprocess_segments(window_segments, labels, total_duration, vad_frames)
389
+
390
+ @classmethod
391
+ def _get_speech_segments(
392
+ cls, audio_array: np.ndarray, sample_rate: int = 16000
393
+ ) -> tuple[list[dict], list[bool]]:
394
+ """Get speech segments using TEN-VAD.
395
+
396
+ Returns:
397
+ Tuple of (segments list, vad_frames list of per-frame speech decisions)
398
+ """
399
+ vad_model = cls._get_ten_vad_model()
400
+
401
+ # Convert to int16 as required by TEN-VAD
402
+ # Clip to prevent integer overflow
403
+ if audio_array.dtype != np.int16:
404
+ audio_int16 = (np.clip(audio_array, -1.0, 1.0) * 32767).astype(np.int16)
405
+ else:
406
+ audio_int16 = audio_array
407
+
408
+ # Process frame by frame
409
+ hop_size = 256
410
+ frame_duration = hop_size / sample_rate
411
+ speech_frames: list[bool] = []
412
+
413
+ for i in range(0, len(audio_int16) - hop_size, hop_size):
414
+ frame = audio_int16[i : i + hop_size]
415
+ _, is_speech = vad_model.process(frame)
416
+ speech_frames.append(is_speech)
417
+
418
+ # Convert frame-level decisions to segments
419
+ segments = []
420
+ in_speech = False
421
+ start_idx = 0
422
+
423
+ for i, is_speech in enumerate(speech_frames):
424
+ if is_speech and not in_speech:
425
+ start_idx = i
426
+ in_speech = True
427
+ elif not is_speech and in_speech:
428
+ start_time = start_idx * frame_duration
429
+ end_time = i * frame_duration
430
+ segments.append(
431
+ {
432
+ "start": start_time,
433
+ "end": end_time,
434
+ "start_sample": int(start_time * sample_rate),
435
+ "end_sample": int(end_time * sample_rate),
436
+ }
437
+ )
438
+ in_speech = False
439
+
440
+ # Handle trailing speech
441
+ if in_speech:
442
+ start_time = start_idx * frame_duration
443
+ end_time = len(speech_frames) * frame_duration
444
+ segments.append(
445
+ {
446
+ "start": start_time,
447
+ "end": end_time,
448
+ "start_sample": int(start_time * sample_rate),
449
+ "end_sample": int(end_time * sample_rate),
450
+ }
451
+ )
452
+
453
+ return cls._apply_vad_hysteresis(segments, sample_rate), speech_frames
454
+
455
+ @classmethod
456
+ def _apply_vad_hysteresis(cls, segments: list[dict], sample_rate: int = 16000) -> list[dict]:
457
+ """Apply hysteresis-like post-processing to VAD segments."""
458
+ if not segments:
459
+ return segments
460
+
461
+ segments = sorted(segments, key=lambda x: x["start"])
462
+
463
+ # Fill short gaps
464
+ merged = [segments[0].copy()]
465
+ for seg in segments[1:]:
466
+ gap = seg["start"] - merged[-1]["end"]
467
+ if gap <= cls.VAD_MAX_GAP:
468
+ merged[-1]["end"] = seg["end"]
469
+ merged[-1]["end_sample"] = seg["end_sample"]
470
+ else:
471
+ merged.append(seg.copy())
472
+
473
+ # Remove short segments
474
+ filtered = [seg for seg in merged if (seg["end"] - seg["start"]) >= cls.VAD_MIN_DURATION]
475
+
476
+ # Dilate segments (add padding)
477
+ for seg in filtered:
478
+ seg["start"] = max(0.0, seg["start"] - cls.VAD_PAD_ONSET)
479
+ seg["end"] = seg["end"] + cls.VAD_PAD_OFFSET
480
+ seg["start_sample"] = int(seg["start"] * sample_rate)
481
+ seg["end_sample"] = int(seg["end"] * sample_rate)
482
+
483
+ return filtered
484
+
485
    @classmethod
    def _extract_embeddings(
        cls, audio_array: np.ndarray, segments: list[dict], sample_rate: int
    ) -> tuple[np.ndarray, list[dict]]:
        """Extract speaker embeddings using sliding windows.

        Args:
            audio_array: Full waveform samples (cast to float32 per window).
            segments: Speech segments with 'start_sample'/'end_sample' keys.
            sample_rate: Sample rate of ``audio_array`` in Hz.

        Returns:
            Tuple of (row-normalized embedding matrix, list of window dicts
            with 'start'/'end' in seconds); both are empty when no window
            yields a valid embedding.
        """
        speaker_model = cls._get_ecapa_model()

        window_samples = int(cls.WINDOW_SIZE * sample_rate)
        step_samples = int(cls.STEP_SIZE * sample_rate)

        embeddings = []
        window_segments = []

        with torch.no_grad():
            for seg in segments:
                seg_start = seg["start_sample"]
                seg_end = seg["end_sample"]
                seg_len = seg_end - seg_start

                # Generate window positions: a single window for short
                # segments, otherwise a sliding window with STEP_SIZE stride.
                if seg_len <= window_samples:
                    starts = [seg_start]
                    ends = [seg_end]
                else:
                    starts = list(range(seg_start, seg_end - window_samples + 1, step_samples))
                    ends = [s + window_samples for s in starts]

                # Cover tail if > TAIL_COVERAGE_RATIO of window remains
                if ends and ends[-1] < seg_end:
                    remainder = seg_end - ends[-1]
                    if remainder > (window_samples * cls.TAIL_COVERAGE_RATIO):
                        starts.append(seg_end - window_samples)
                        ends.append(seg_end)

                for c_start, c_end in zip(starts, ends):
                    chunk = audio_array[c_start:c_end]

                    # Pad short chunks with reflection
                    # NOTE(review): np.pad(mode="reflect") raises when
                    # pad_width >= len(chunk); assumes VAD never yields
                    # chunks that short — confirm for tiny segments.
                    if len(chunk) < window_samples:
                        pad_width = window_samples - len(chunk)
                        chunk = np.pad(chunk, (0, pad_width), mode="reflect")

                    # Extract embedding using SpeechBrain's encode_batch
                    chunk_tensor = torch.from_numpy(chunk).float().unsqueeze(0)
                    embedding = (
                        speaker_model.encode_batch(chunk_tensor).squeeze(0).squeeze(0).cpu().numpy()
                    )

                    # Validate embedding: reject NaN/inf and near-zero vectors
                    # that would destabilize normalization and clustering.
                    if np.isfinite(embedding).all() and np.linalg.norm(embedding) > 1e-8:
                        embeddings.append(embedding)
                        window_segments.append(
                            {
                                "start": c_start / sample_rate,
                                "end": c_end / sample_rate,
                            }
                        )

        # Normalize all embeddings at once
        if embeddings:
            return normalize(np.array(embeddings)), window_segments
        return np.array([]), []
547
+
548
+ @classmethod
549
+ def _resample_vad(cls, vad_frames: list[bool], num_frames: int) -> np.ndarray:
550
+ """Resample VAD frame decisions to match voting grid resolution.
551
+
552
+ VAD operates at 256 samples / 16000 Hz = 16ms per frame.
553
+ Voting operates at VOTING_RATE (default 10ms) per frame.
554
+ This maps VAD decisions to the finer voting grid.
555
+ """
556
+ if not vad_frames:
557
+ return np.zeros(num_frames, dtype=bool)
558
+
559
+ vad_rate = 256 / 16000 # 16ms per VAD frame
560
+ vad_arr = np.array(vad_frames)
561
+
562
+ # Vectorized: compute VAD frame indices for each voting frame
563
+ voting_times = np.arange(num_frames) * cls.VOTING_RATE
564
+ vad_indices = np.clip((voting_times / vad_rate).astype(int), 0, len(vad_arr) - 1)
565
+ return vad_arr[vad_indices]
566
+
567
    @classmethod
    def _postprocess_segments(
        cls,
        window_segments: list[dict],
        labels: np.ndarray,
        total_duration: float,
        vad_frames: list[bool],
    ) -> list[dict]:
        """Post-process using frame-level consensus voting with VAD-aware silence.

        Args:
            window_segments: Embedding windows with 'start'/'end' in seconds.
            labels: Cluster label per window (parallel to ``window_segments``).
            total_duration: Total audio duration in seconds.
            vad_frames: Per-frame VAD decisions at the native VAD frame rate.

        Returns:
            Speaker segments as dicts with 'speaker', 'start', 'end' keys,
            after short-segment merging.
        """
        if not window_segments or len(labels) == 0:
            return []

        # Correct labels to be contiguous
        unique_labels = np.unique(labels)
        label_map = {old: new for new, old in enumerate(unique_labels)}
        clean_labels = np.array([label_map[lbl] for lbl in labels])
        num_speakers = len(unique_labels)

        if num_speakers == 0:
            return []

        # Create voting grid
        num_frames = int(np.ceil(total_duration / cls.VOTING_RATE)) + 1
        votes = np.zeros((num_frames, num_speakers), dtype=np.float32)

        # Accumulate votes: each window votes for its cluster over the voting
        # frames it covers; overlapping windows vote multiple times.
        for win, label in zip(window_segments, clean_labels):
            start_frame = int(win["start"] / cls.VOTING_RATE)
            end_frame = int(win["end"] / cls.VOTING_RATE)
            end_frame = min(end_frame, num_frames)
            if start_frame < end_frame:
                votes[start_frame:end_frame, label] += 1.0

        # Determine winner per frame
        frame_speakers = np.argmax(votes, axis=1)
        max_votes = np.max(votes, axis=1)

        # Resample VAD to voting grid resolution for silence-aware voting
        vad_resampled = cls._resample_vad(vad_frames, num_frames)

        # Convert frames to segments: scan the frame-level speaker sequence
        # and emit a segment at every speaker change (-1 marks silence).
        final_segments = []
        current_speaker = -1
        seg_start = 0.0

        for f in range(num_frames):
            speaker = int(frame_speakers[f])
            score = max_votes[f]

            # Force silence if VAD says no speech OR no votes
            if score == 0 or not vad_resampled[f]:
                speaker = -1

            if speaker != current_speaker:
                if current_speaker != -1:
                    final_segments.append(
                        {
                            "speaker": f"SPEAKER_{current_speaker}",
                            "start": seg_start,
                            "end": f * cls.VOTING_RATE,
                        }
                    )
                current_speaker = speaker
                seg_start = f * cls.VOTING_RATE

        # Close last segment
        if current_speaker != -1:
            final_segments.append(
                {
                    "speaker": f"SPEAKER_{current_speaker}",
                    "start": seg_start,
                    "end": num_frames * cls.VOTING_RATE,
                }
            )

        return cls._merge_short_segments(final_segments)
643
+
644
+ @classmethod
645
+ def _merge_short_segments(cls, segments: list[dict]) -> list[dict]:
646
+ """Merge short segments to reduce flicker."""
647
+ if not segments:
648
+ return []
649
+
650
+ clean: list[dict] = []
651
+ for seg in segments:
652
+ dur = seg["end"] - seg["start"]
653
+ if dur < cls.MIN_SEGMENT_DURATION:
654
+ if (
655
+ clean
656
+ and clean[-1]["speaker"] == seg["speaker"]
657
+ and seg["start"] - clean[-1]["end"] < cls.SHORT_SEGMENT_GAP
658
+ ):
659
+ clean[-1]["end"] = seg["end"]
660
+ continue
661
+
662
+ if (
663
+ clean
664
+ and clean[-1]["speaker"] == seg["speaker"]
665
+ and seg["start"] - clean[-1]["end"] < cls.SAME_SPEAKER_GAP
666
+ ):
667
+ clean[-1]["end"] = seg["end"]
668
+ else:
669
+ clean.append(seg)
670
+
671
+ return clean
672
+
673
+ @classmethod
674
+ def assign_speakers_to_words(
675
+ cls,
676
+ words: list[dict],
677
+ speaker_segments: list[dict],
678
+ ) -> list[dict]:
679
+ """Assign speaker labels to words based on timestamp overlap.
680
+
681
+ Args:
682
+ words: List of word dicts with 'word', 'start', 'end' keys
683
+ speaker_segments: List of speaker dicts with 'speaker', 'start', 'end' keys
684
+
685
+ Returns:
686
+ Words list with 'speaker' key added to each word
687
+ """
688
+ for word in words:
689
+ word_mid = (word["start"] + word["end"]) / 2
690
+
691
+ # Find the speaker segment that contains this word's midpoint
692
+ best_speaker = None
693
+ for seg in speaker_segments:
694
+ if seg["start"] <= word_mid <= seg["end"]:
695
+ best_speaker = seg["speaker"]
696
+ break
697
+
698
+ # If no exact match, find closest segment
699
+ if best_speaker is None and speaker_segments:
700
+ min_dist = float("inf")
701
+ for seg in speaker_segments:
702
+ seg_mid = (seg["start"] + seg["end"]) / 2
703
+ dist = abs(word_mid - seg_mid)
704
+ if dist < min_dist:
705
+ min_dist = dist
706
+ best_speaker = seg["speaker"]
707
+
708
+ word["speaker"] = best_speaker
709
+
710
+ return words
711
+
712
+
713
class SpeakerDiarizer:
    """Speaker diarization using TEN-VAD + ECAPA-TDNN + spectral clustering.

    Thin facade that forwards to the local diarization pipeline.

    Example:
        >>> segments = SpeakerDiarizer.diarize(audio_array)
        >>> for seg in segments:
        ...     print(f"{seg['speaker']}: {seg['start']:.2f} - {seg['end']:.2f}")
    """

    @classmethod
    def diarize(
        cls,
        audio: np.ndarray | str,
        sample_rate: int = 16000,
        num_speakers: int | None = None,
        min_speakers: int | None = None,
        max_speakers: int | None = None,
        **_kwargs,
    ) -> list[dict]:
        """Run speaker diarization on audio.

        Args:
            audio: Audio waveform as numpy array or path to audio file
            sample_rate: Audio sample rate (default 16000)
            num_speakers: Exact number of speakers (if known)
            min_speakers: Minimum number of speakers
            max_speakers: Maximum number of speakers

        Returns:
            List of dicts with 'speaker', 'start', 'end' keys
        """
        # Substitute pipeline defaults when the bounds are unset (falsy).
        lower_bound = min_speakers if min_speakers else 2
        upper_bound = max_speakers if max_speakers else 10
        return LocalSpeakerDiarizer.diarize(
            audio,
            sample_rate=sample_rate,
            num_speakers=num_speakers,
            min_speakers=lower_bound,
            max_speakers=upper_bound,
        )

    @classmethod
    def assign_speakers_to_words(
        cls,
        words: list[dict],
        speaker_segments: list[dict],
    ) -> list[dict]:
        """Assign speaker labels to words based on timestamp overlap."""
        return LocalSpeakerDiarizer.assign_speakers_to_words(words, speaker_segments)
modules/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Modules for flow matching audio synthesis."""
2
+
3
+ from .mlp import SimpleMLPAdaLN
4
+
5
+ __all__ = ["SimpleMLPAdaLN"]
modules/mlp.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Flow matching MLP with adaptive layer normalization.
2
+
3
+ Adapted from pocket-tts, originally from:
4
+ https://github.com/LTH14/mar/blob/fe470ac24afbee924668d8c5c83e9fec60af3a73/models/diffloss.py
5
+
6
+ Reference: https://arxiv.org/abs/2406.11838
7
+ """
8
+
9
+ import math
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+
15
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Apply adaptive normalization modulation: ``x * (1 + scale) + shift``."""
    return shift + x * (scale + 1)
18
+
19
+
20
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize by the eps-stabilized variance of the last dimension,
        # apply the learned per-channel gain, and cast back to the input dtype.
        input_dtype = x.dtype
        variance = x.var(dim=-1, keepdim=True) + self.eps
        gain = self.alpha.to(variance) * torch.rsqrt(variance)
        return (x * gain).to(input_dtype)
32
+
33
+
34
class LayerNorm(nn.Module):
    """LayerNorm that supports JVP (for flow matching gradients)."""

    def __init__(self, channels: int, eps: float = 1e-6, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        # Affine parameters exist only when requested; forward checks for
        # their presence rather than storing a flag.
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(channels))
            self.bias = nn.Parameter(torch.zeros(channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Standardize over the last dimension using the biased variance.
        centered = x - x.mean(dim=-1, keepdim=True)
        normed = centered / torch.sqrt(x.var(dim=-1, unbiased=False, keepdim=True) + self.eps)
        if hasattr(self, "weight"):
            normed = normed * self.weight + self.bias
        return normed
51
+
52
+
53
class TimestepEmbedder(nn.Module):
    """Embeds scalar timesteps into vector representations."""

    def __init__(
        self,
        hidden_size: int,
        frequency_embedding_size: int = 256,
        max_period: int = 10000,
    ):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        # Two-layer SiLU MLP with an RMSNorm on the output embedding.
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
            RMSNorm(hidden_size),
        )
        # Geometric frequency ladder; registered as a buffer so it follows
        # device moves and checkpointing.
        half = frequency_embedding_size // 2
        exponents = torch.arange(start=0, end=half) / half
        self.register_buffer("freqs", torch.exp(-math.log(max_period) * exponents))

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # Sinusoidal features of the timestep, then project through the MLP.
        phase = t * self.freqs.to(t.dtype)
        features = torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
        return self.mlp(features)
78
+
79
+
80
class ResBlock(nn.Module):
    """Residual block with adaptive layer normalization."""

    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels
        self.in_ln = LayerNorm(channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        )
        # Produces the shift/scale/gate triple from the conditioning vector.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, 3 * channels, bias=True),
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Conditioning drives both the normalization (shift/scale) and the
        # gate on the residual branch.
        shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
        hidden = self.mlp(modulate(self.in_ln(x), shift, scale))
        return x + gate * hidden
102
+
103
+
104
class FinalLayer(nn.Module):
    """Final layer with adaptive normalization (DiT-style)."""

    def __init__(self, model_channels: int, out_channels: int):
        super().__init__()
        # Non-affine norm: the shift/scale come from the conditioning instead.
        self.norm_final = LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(model_channels, 2 * model_channels, bias=True),
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        return self.linear(modulate(self.norm_final(x), shift, scale))
120
+
121
+
122
class SimpleMLPAdaLN(nn.Module):
    """MLP for flow matching with adaptive layer normalization.

    Takes conditioning from an AR transformer and predicts flow velocity.

    Args:
        in_channels: Input/output latent dimension (e.g., 256 for Mimi)
        model_channels: Hidden dimension of the MLP
        out_channels: Output dimension (same as in_channels for flow matching)
        cond_channels: Conditioning dimension from LLM
        num_res_blocks: Number of residual blocks
        num_time_conds: Number of time conditions (2 for start/end time in LSD)
    """

    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        out_channels: int,
        cond_channels: int,
        num_res_blocks: int,
        num_time_conds: int = 2,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.num_time_conds = num_time_conds

        # The forward pass hard-codes a (start, end) time pair below.
        assert num_time_conds == 2, "LSD requires exactly 2 time conditions (start, end)"

        # One independent timestep embedder per time condition.
        self.time_embed = nn.ModuleList(
            [TimestepEmbedder(model_channels) for _ in range(num_time_conds)]
        )
        self.cond_embed = nn.Linear(cond_channels, model_channels)
        self.input_proj = nn.Linear(in_channels, model_channels)

        self.res_blocks = nn.ModuleList([ResBlock(model_channels) for _ in range(num_res_blocks)])
        self.final_layer = FinalLayer(model_channels, out_channels)

    def forward(
        self,
        c: torch.Tensor,
        s: torch.Tensor,
        t: torch.Tensor,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Predict flow velocity.

        Args:
            c: Conditioning from LLM, shape [N, cond_channels]
            s: Start time, shape [N, 1]
            t: Target time, shape [N, 1]
            x: Noisy latent, shape [N, in_channels]

        Returns:
            Predicted velocity, shape [N, out_channels]
        """
        x = self.input_proj(x)

        # Combine time embeddings (average of start and end time embeddings)
        ts = [s, t]
        t_combined = sum(self.time_embed[i](ts[i]) for i in range(self.num_time_conds))
        t_combined = t_combined / self.num_time_conds

        # Add conditioning
        c = self.cond_embed(c)
        y = t_combined + c

        # Residual blocks: y modulates each block via AdaLN.
        for block in self.res_blocks:
            x = block(x, y)

        return self.final_layer(x, y)
preprocessor_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "chunk_length": 30,
3
+ "dither": 0.0,
4
+ "feature_extractor_type": "WhisperFeatureExtractor",
5
+ "feature_size": 128,
6
+ "hop_length": 160,
7
+ "n_fft": 400,
8
+ "n_samples": 480000,
9
+ "nb_max_frames": 3000,
10
+ "padding": false,
11
+ "padding_side": "right",
12
+ "padding_value": 0.0,
13
+ "return_attention_mask": false,
14
+ "sampling_rate": 16000,
15
+ "processor_class": "ASRProcessor",
16
+ "auto_map": {
17
+ "AutoProcessor": "asr_processing.ASRProcessor"
18
+ }
19
+ }
projectors.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio projector modules for bridging encoder and decoder embeddings.
2
+
3
+ This module contains all projector architectures:
4
+ - MLPAudioProjector: Simple 2-layer MLP with frame stacking downsampling
5
+ - MOSAProjector: MOSA-style dense mixture of experts
6
+ - SharedMoEAudioProjector: Shared expert + sparse routed experts
7
+ - QFormerAudioProjector: BLIP-2 QFormer with learnable queries (Granite-style)
8
+ """
9
+
10
+ import math
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F # noqa: N812
15
+ from transformers import AutoModel, Blip2QFormerConfig
16
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm
17
+
18
+ # =============================================================================
19
+ # MLP Projector
20
+ # =============================================================================
21
+
22
+
23
class MLPAudioProjector(nn.Module):
    """2-layer MLP projector with frame-stacking downsampling (matches GLM-ASR)."""

    def __init__(self, config):
        """Initialize MLP projector.

        Args:
            config: ASRConfig with encoder_dim, llm_dim, projector_pool_stride
        """
        super().__init__()

        enc_dim = getattr(config, "encoder_dim", 768)
        llm_dim = getattr(config, "llm_dim", 2048)
        self.k = getattr(config, "projector_pool_stride", 4)

        # Concatenate k adjacent encoder frames, then project through a
        # two-layer MLP with RMSNorm + GELU in between.
        stacked_dim = enc_dim * self.k
        # Hidden dim defaults to llm_dim unless overridden via config.
        hidden_dim = getattr(config, "projector_hidden_dim", None) or llm_dim
        self.linear_1 = nn.Linear(stacked_dim, hidden_dim, bias=False)
        self.norm = LlamaRMSNorm(hidden_dim, eps=1e-6)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(hidden_dim, llm_dim, bias=False)

    def get_output_length(self, input_length: int) -> int:
        """Calculate output sequence length given input length (matches GLM-ASR)."""
        # GLM-ASR formula: (L - merge_factor) // merge_factor + 1
        return (input_length - self.k) // self.k + 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project audio features to LLM embedding space.

        Args:
            x: Audio encoder output of shape [batch, seq_len, encoder_dim]

        Returns:
            Projected features of shape [batch, (seq_len - k) // k + 1, llm_dim]
        """
        batch, seq_len, feat_dim = x.shape
        # Drop trailing frames that do not fill a complete k-frame window,
        # then fold each group of k frames into the channel dimension.
        n_out = self.get_output_length(seq_len)
        stacked = x[:, : n_out * self.k, :].reshape(batch, n_out, feat_dim * self.k)

        return self.linear_2(self.act(self.norm(self.linear_1(stacked))))
72
+
73
+
74
+ # =============================================================================
75
+ # MoE Projector (MOSA-style)
76
+ # =============================================================================
77
+
78
+
79
class SimpleAdapter(nn.Module):
    """Simple 2-layer GELU adapter (from MOSA paper)."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # fc1 -> GELU -> fc2
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)
90
+
91
+
92
class SwiGLU(nn.Module):
    """SwiGLU activation with gated linear units (used in LLaMA, Mistral, etc.)."""

    def __init__(self, dim: int, hidden_dim: int, bias: bool = False):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)  # Gate
        self.w2 = nn.Linear(dim, hidden_dim, bias=bias)  # Value
        self.w3 = nn.Linear(hidden_dim, dim, bias=bias)  # Output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # silu(gate) element-wise scales the value path before projecting back.
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)
103
+
104
+
105
class AsymmetricSwiGLU(nn.Module):
    """SwiGLU that handles different input and output dimensions."""

    def __init__(
        self, in_features: int, hidden_features: int, out_features: int, bias: bool = False
    ):
        super().__init__()
        self.w1 = nn.Linear(in_features, hidden_features, bias=bias)  # Gate
        self.w2 = nn.Linear(in_features, hidden_features, bias=bias)  # Value
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)  # Output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gated hidden representation, then project to the output dimension.
        gated = F.silu(self.w1(x)) * self.w2(x)
        return self.w3(gated)
118
+
119
+
120
class MOSAProjector(nn.Module):
    """MOSA-Base projector: simple 2-layer ReLU router with 4 simple adapters.

    Based on "MOSA: Mixtures of Simple Adapters" (arXiv:2508.18998).
    Uses softmax gating over all experts (dense MoE) with only cross-entropy loss.
    Uses Conv1d for downsampling (2 layers, stride 2 each = 4x total).
    """

    def __init__(self, config):
        """Initialize MOSA projector.

        Args:
            config: ASRConfig with encoder_dim, llm_dim, num_experts
        """
        super().__init__()
        self.encoder_dim = getattr(config, "encoder_dim", None) or 1280
        self.llm_dim = getattr(config, "llm_dim", None) or 2048
        self.num_experts = getattr(config, "num_experts", None) or 4  # MOSA-Base uses 4
        adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096
        router_hidden = getattr(config, "router_hidden_dim", None) or 512

        # Conv1d downsampler: two stride-2 layers give a 4x length reduction
        # while projecting encoder_dim -> llm_dim.
        self.downsampler = nn.Sequential(
            nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
            nn.Conv1d(self.encoder_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
        )

        # Shallow ReLU router: llm_dim -> router_hidden -> num_experts.
        self.router = nn.Sequential(
            nn.Linear(self.llm_dim, router_hidden),
            nn.ReLU(),
            nn.Linear(router_hidden, self.num_experts),
        )

        # Dense expert bank of simple 2-layer GELU adapters (llm_dim -> llm_dim).
        self.experts = nn.ModuleList(
            [
                SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim)
                for _ in range(self.num_experts)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project audio features using mixture of experts.

        Args:
            x: Audio encoder output of shape [batch, seq_len, encoder_dim]

        Returns:
            Projected features of shape [batch, out_len, llm_dim]
        """
        # Conv1d is channels-first: [B, S, D] -> [B, D, S] -> ... -> [B, S', D]
        features = self.downsampler(x.transpose(1, 2)).transpose(1, 2)

        # Soft routing weights over all experts (dense mixture).
        gate = F.softmax(self.router(features), dim=-1)  # (B, S', E)

        # Run every expert and blend the outputs by the routing weights.
        stacked = torch.stack([expert(features) for expert in self.experts])  # (E, B, S', D)
        return torch.einsum("ebsd, bse -> bsd", stacked, gate)

    def get_output_length(self, input_length: int) -> int:
        """Calculate output sequence length after Conv1d downsampling (4x reduction)."""
        # Each Conv1d(stride=2, kernel=3, padding=1): out = (in + 2 - 3) // 2 + 1.
        length = input_length
        for _ in range(2):
            length = (length + 2 * 1 - 3) // 2 + 1
        return length
196
+
197
+
198
+ # =============================================================================
199
+ # MoE Projector (Pure PyTorch with Shared Expert)
200
+ # =============================================================================
201
+
202
+
203
+ class MoEAudioProjector(nn.Module):
204
+ """MoE projector with shared expert (DeepSeek-style), pure PyTorch implementation.
205
+
206
+ Uses 4 sparse experts with top-2 routing plus a shared expert that processes all tokens.
207
+ No external dependencies (megablocks removed).
208
+
209
+ Architecture matches main branch: norm → experts(in_dim → hidden → out_dim)
210
+ """
211
+
212
+ def __init__(self, config):
213
+ """Initialize MoE projector.
214
+
215
+ Args:
216
+ config: ASRConfig with encoder_dim, llm_dim, num_experts, num_experts_per_tok
217
+ """
218
+ super().__init__()
219
+
220
+ self.k = getattr(config, "projector_pool_stride", 4)
221
+ self.aux_coef = getattr(config, "router_aux_loss_coef", 0.01)
222
+
223
+ # Stability coefficients
224
+ self.router_z_loss_coef = getattr(
225
+ config, "router_z_loss_coef", 1e-4
226
+ ) # Prevents logit explosion
227
+ self.router_jitter_noise = getattr(
228
+ config, "router_jitter_noise", 0.01
229
+ ) # Prevents expert collapse
230
+
231
+ in_dim = config.encoder_dim * self.k
232
+ out_dim = config.llm_dim
233
+
234
+ # Expert hidden dim (default = output dim)
235
+ hidden_dim = getattr(config, "projector_hidden_dim", None) or out_dim
236
+
237
+ # Number of experts and top-k selection
238
+ self.num_experts = getattr(config, "num_experts", 4)
239
+ self.top_k = getattr(config, "num_experts_per_tok", 2)
240
+
241
+ # A. Normalize stacked input (like main branch SharedMoEBlock)
242
+ self.norm = LlamaRMSNorm(in_dim, eps=1e-6)
243
+
244
+ # B. Router (operates on stacked input)
245
+ self.router = nn.Linear(in_dim, self.num_experts, bias=False)
246
+
247
+ # C. Experts: simple 2-layer MLP (same as MLPAudioProjector)
248
+ self.experts = nn.ModuleList(
249
+ [SimpleAdapter(in_dim, hidden_dim, out_dim) for _ in range(self.num_experts)]
250
+ )
251
+
252
+ # D. Shared Expert (same architecture)
253
+ self.shared_expert = SimpleAdapter(in_dim, hidden_dim, out_dim)
254
+
255
+ # E. Initialize weights for stable training
256
+ self._init_weights()
257
+
258
+ self.last_aux_loss = torch.tensor(0.0)
259
+
260
+ def _init_weights(self):
261
+ """Initialize weights for stable training start."""
262
+ with torch.no_grad():
263
+ # Router: small weights -> uniform probability
264
+ nn.init.normal_(self.router.weight, mean=0.0, std=0.02)
265
+
266
+ # Experts: xavier for fc1, small for fc2 (output)
267
+ for expert in [self.shared_expert, *self.experts]:
268
+ nn.init.xavier_uniform_(expert.fc1.weight)
269
+ nn.init.normal_(expert.fc2.weight, mean=0.0, std=0.01) # Small init
270
+
271
+ def get_output_length(self, input_length: int) -> int:
272
+ """Calculate output sequence length given input length (matches MLP projector)."""
273
+ return (input_length - self.k) // self.k + 1
274
+
275
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Project audio features through the shared expert plus sparse MoE.

    Args:
        x: Audio encoder output of shape [batch, seq_len, encoder_dim]

    Returns:
        Projected features of shape [batch, out_len, llm_dim]
    """
    # Frame stacking: group k consecutive frames into one wider feature vector.
    batch, seq_len, feat = x.shape
    out_len = (seq_len - self.k) // self.k + 1
    stacked = x[:, : out_len * self.k, :].reshape(batch, out_len, feat * self.k)

    # RMS-normalize the stacked frames, then flatten to [tokens, in_dim].
    normed = self.norm(stacked)
    tokens = normed.view(-1, normed.size(-1))

    # The shared expert runs on every token and allocates the output buffer;
    # the sparse experts then accumulate into it in place and report the
    # auxiliary routing loss, which we stash for the training loop.
    out = self.shared_expert(tokens)
    self.last_aux_loss = self._forward_sparse(tokens, out)

    return out.view(batch, out_len, -1)
301
+
302
def _forward_sparse(self, x: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
    """Stability-hardened sparse expert dispatch (in-place add to ``output``).

    Args:
        x: Flattened input of shape [tokens, dim]
        output: Output tensor to add sparse expert results into (in-place)

    Returns:
        Auxiliary loss tensor (balance loss + z-loss while training,
        otherwise a zero scalar)
    """
    # A. Router logits for every token.
    # NOTE(review): self.router_jitter_noise, self.aux_coef and
    # self.router_z_loss_coef appear to be set in __init__ (not visible in
    # this chunk) — confirm they exist on all construction paths.
    logits = self.router(x)

    if self.training and self.router_jitter_noise > 0:
        # Jitter: multiply by uniform noise (1-eps, 1+eps) to shake decision boundary
        # Prevents router from getting stuck on one expert early in training
        noise = torch.empty_like(logits).uniform_(
            1.0 - self.router_jitter_noise, 1.0 + self.router_jitter_noise
        )
        logits = logits * noise

    # Force float32 for softmax (bf16/fp16 exponentials can overflow),
    # then cast back to the input dtype.
    probs = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(x)

    # B. Top-K selection: weights and expert indices per token.
    top_k_weights, top_k_indices = torch.topk(probs, self.top_k, dim=-1)

    # Renormalize the selected weights to sum to 1.0 (epsilon guards div-by-zero).
    top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-6)

    # C. Auxiliary losses (training only); defaults to a zero scalar.
    aux_loss = torch.tensor(0.0, device=x.device)

    if self.training:
        # Load balancing loss (batch-size invariant): penalize deviation of
        # each expert's mean routing probability from the uniform target.
        prob_per_expert = probs.mean(0)  # [num_experts]
        target = 1.0 / self.num_experts
        balance_loss = (
            self.aux_coef * ((prob_per_expert - target) ** 2).mean() * self.num_experts
        )

        # Z-loss: penalty on large logits to prevent softmax saturation
        z_loss = self.router_z_loss_coef * torch.logsumexp(logits, dim=-1).pow(2).mean()

        aux_loss = balance_loss + z_loss

    # D. Dispatch loop: each expert processes only the tokens that selected
    # it; its weighted output is accumulated into `output` in place.
    for i, expert in enumerate(self.experts):
        # Boolean mask [tokens, top_k]: True where expert 'i' was selected.
        mask = top_k_indices == i

        if mask.any():
            # token_idx = which tokens, k_idx = 1st or 2nd choice
            token_idx, k_idx = torch.where(mask)

            # Gather inputs and compute
            expert_input = x[token_idx]
            expert_output = expert(expert_input)

            # Apply routing weight
            weight = top_k_weights[token_idx, k_idx].unsqueeze(-1)
            weighted_output = (expert_output * weight).type_as(output)

            # Scatter back in-place (index_add_ is atomic and deterministic)
            output.index_add_(0, token_idx, weighted_output)

    return aux_loss
369
+
370
def get_aux_loss(self) -> torch.Tensor:
    """Return the auxiliary routing loss recorded by the most recent forward()."""
    aux = self.last_aux_loss
    return aux
373
+
374
+
375
+ # =============================================================================
376
+ # QFormer Projector (Granite-style)
377
+ # =============================================================================
378
+
379
+
380
class QFormerAudioProjector(nn.Module):
    """
    BLIP-2 QFormer projector with learnable queries.

    Based on GraniteSpeechEncoderProjector - uses a QFormer model with learnable
    query embeddings to compress and project audio encoder outputs. The audio
    sequence is processed in fixed-size windows and downsampled via
    cross-attention: each window of ``window_size`` frames is compressed to
    ``num_queries`` output tokens.
    """

    def __init__(self, config):
        """Initialize QFormer projector.

        Args:
            config: ASRConfig with encoder_dim, llm_dim, qformer_* settings
        """
        super().__init__()

        encoder_dim = config.encoder_dim
        llm_dim = config.llm_dim

        # Window and downsampling parameters (Granite defaults: window=15, downsample=5)
        self.window_size = getattr(config, "qformer_window_size", 15)
        self.downsample_rate = getattr(config, "downsample_rate", 5)
        # Tokens emitted per window; with the defaults this is 15 // 5 = 3.
        self.num_queries = self.window_size // self.downsample_rate

        # QFormer hidden size (matches encoder for cross-attention)
        qformer_hidden = getattr(config, "qformer_hidden_size", None) or encoder_dim
        qformer_num_layers = getattr(config, "qformer_num_layers", 2)
        qformer_num_heads = getattr(config, "qformer_num_heads", 16)
        qformer_intermediate = getattr(config, "qformer_intermediate_size", None) or (
            qformer_hidden * 4
        )

        # Learnable query embeddings (Granite uses std=1.0)
        self.query = nn.Parameter(torch.zeros(1, self.num_queries, qformer_hidden))
        self.query.data.normal_(mean=0.0, std=1.0)

        # Optional projection if encoder dim != qformer hidden
        if encoder_dim != qformer_hidden:
            self.encoder_proj = nn.Linear(encoder_dim, qformer_hidden, bias=False)
        else:
            self.encoder_proj = None

        # Configure QFormer to match Granite's exact config
        qformer_config = Blip2QFormerConfig(
            hidden_size=qformer_hidden,
            num_hidden_layers=qformer_num_layers,
            num_attention_heads=qformer_num_heads,
            intermediate_size=qformer_intermediate,
            encoder_hidden_size=qformer_hidden,
            # Cross-attend to the encoder in every layer, not every other one.
            cross_attention_frequency=1,
            # Granite-specific settings
            hidden_act="gelu",
            attention_probs_dropout_prob=0.1,
            hidden_dropout_prob=0.1,
            layer_norm_eps=1e-12,
            initializer_range=0.02,
        )
        self.qformer = AutoModel.from_config(qformer_config)

        # Final projection to LLM dimension (Granite uses bias=True)
        self.linear = nn.Linear(qformer_hidden, llm_dim)

    def get_output_length(self, input_length: int) -> int:
        """Calculate output sequence length given input length."""
        # QFormer uses window-based processing emitting num_queries per window.
        nblocks = math.ceil(input_length / self.window_size)
        return nblocks * self.num_queries

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: [batch_size, seq_len, encoder_dim]

        Returns:
            projected: [batch_size, num_output_tokens, llm_dim], where
            num_output_tokens = ceil(seq_len / window_size) * num_queries
        """
        batch_size, seq_len, dim = hidden_states.size()

        # Ensure float dtype for QFormer (match the learnable queries).
        target_dtype = self.query.dtype
        if hidden_states.dtype != target_dtype:
            hidden_states = hidden_states.to(target_dtype)

        # Optional encoder projection
        if self.encoder_proj is not None:
            hidden_states = self.encoder_proj(hidden_states)

        # Compute number of windows and zero-pad the tail so the sequence
        # divides evenly into windows.
        # NOTE(review): no attention mask is passed below, so queries in the
        # last window attend to the zero padding — confirm this matches the
        # Granite reference behavior.
        nblocks = math.ceil(seq_len / self.window_size)
        pad = nblocks * self.window_size - seq_len
        if pad > 0:
            hidden_states = F.pad(hidden_states, (0, 0, 0, pad), "constant", 0)

        # Reshape to process each window: [batch*nblocks, window_size, dim]
        effective_batch = batch_size * nblocks
        hidden_states = hidden_states.view(effective_batch, self.window_size, -1)

        # Expand queries to match the effective (per-window) batch size.
        query_embeds = self.query.expand(effective_batch, -1, -1)

        # QFormer cross-attention: queries attend to their window's frames.
        query_output = self.qformer(
            query_embeds=query_embeds,
            encoder_hidden_states=hidden_states,
            return_dict=True,
        )

        # Reshape back: [batch, nblocks * num_queries, hidden]
        output_tokens = nblocks * self.num_queries
        query_proj = query_output.last_hidden_state.view(batch_size, output_tokens, -1)

        # Project to LLM dimension
        return self.linear(query_proj)
494
+
495
+
496
+ # =============================================================================
497
+ # Projector Registry
498
+ # =============================================================================
499
+
500
# Registry mapping the projector-type config string to its implementation
# class; keys are the values accepted by the config's projector selection.
PROJECTOR_CLASSES = {
    "mlp": MLPAudioProjector,
    "mosa": MOSAProjector,
    "moe": MoEAudioProjector,
    "qformer": QFormerAudioProjector,
}
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4aeaf198f783cbf58d8cd59812baac429ffe49147bf9648f6618de20b8d4a4c
3
+ size 17209003
tokenizer_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backend": "tokenizers",
3
+ "bos_token": null,
4
+ "clean_up_tokenization_spaces": true,
5
+ "eos_token": "<|im_end|>",
6
+ "extra_special_tokens": [
7
+ "<audio>"
8
+ ],
9
+ "fast": false,
10
+ "is_local": false,
11
+ "model_input_names": [
12
+ "input_ids",
13
+ "attention_mask"
14
+ ],
15
+ "model_max_length": 131072,
16
+ "model_specific_special_tokens": {},
17
+ "pad_token": "<|finetune_right_pad_id|>",
18
+ "tokenizer_class": "TokenizersBackend"
19
+ }