mazesmazes committed · Commit 64278ca · verified · 1 Parent(s): c88c23a

Assembled S2S model (base + AudioHead)
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,199 @@
+---
+library_name: transformers
+tags: []
+---
+
+# Model Card for Model ID
+
+<!-- Provide a quick summary of what the model is/does. -->
+
+
+
+## Model Details
+
+### Model Description
+
+<!-- Provide a longer summary of what this model is. -->
+
+This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+- **Developed by:** [More Information Needed]
+- **Funded by [optional]:** [More Information Needed]
+- **Shared by [optional]:** [More Information Needed]
+- **Model type:** [More Information Needed]
+- **Language(s) (NLP):** [More Information Needed]
+- **License:** [More Information Needed]
+- **Finetuned from model [optional]:** [More Information Needed]
+
+### Model Sources [optional]
+
+<!-- Provide the basic links for the model. -->
+
+- **Repository:** [More Information Needed]
+- **Paper [optional]:** [More Information Needed]
+- **Demo [optional]:** [More Information Needed]
+
+## Uses
+
+<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+### Direct Use
+
+<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+[More Information Needed]
+
+### Downstream Use [optional]
+
+<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+[More Information Needed]
+
+### Out-of-Scope Use
+
+<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+[More Information Needed]
+
+## Bias, Risks, and Limitations
+
+<!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+[More Information Needed]
+
+### Recommendations
+
+<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+## How to Get Started with the Model
+
+Use the code below to get started with the model.
+
+[More Information Needed]
+
+## Training Details
+
+### Training Data
+
+<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+[More Information Needed]
+
+### Training Procedure
+
+<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+#### Preprocessing [optional]
+
+[More Information Needed]
+
+
+#### Training Hyperparameters
+
+- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+#### Speeds, Sizes, Times [optional]
+
+<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+[More Information Needed]
+
+## Evaluation
+
+<!-- This section describes the evaluation protocols and provides the results. -->
+
+### Testing Data, Factors & Metrics
+
+#### Testing Data
+
+<!-- This should link to a Dataset Card if possible. -->
+
+[More Information Needed]
+
+#### Factors
+
+<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+[More Information Needed]
+
+#### Metrics
+
+<!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+[More Information Needed]
+
+### Results
+
+[More Information Needed]
+
+#### Summary
+
+
+
+## Model Examination [optional]
+
+<!-- Relevant interpretability work for the model goes here -->
+
+[More Information Needed]
+
+## Environmental Impact
+
+<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+- **Hardware Type:** [More Information Needed]
+- **Hours used:** [More Information Needed]
+- **Cloud Provider:** [More Information Needed]
+- **Compute Region:** [More Information Needed]
+- **Carbon Emitted:** [More Information Needed]
+
+## Technical Specifications [optional]
+
+### Model Architecture and Objective
+
+[More Information Needed]
+
+### Compute Infrastructure
+
+[More Information Needed]
+
+#### Hardware
+
+[More Information Needed]
+
+#### Software
+
+[More Information Needed]
+
+## Citation [optional]
+
+<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+**BibTeX:**
+
+[More Information Needed]
+
+**APA:**
+
+[More Information Needed]
+
+## Glossary [optional]
+
+<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+[More Information Needed]
+
+## More Information [optional]
+
+[More Information Needed]
+
+## Model Card Authors [optional]
+
+[More Information Needed]
+
+## Model Card Contact
+
+[More Information Needed]
alignment.py ADDED
@@ -0,0 +1,295 @@
+"""Forced alignment for word-level timestamps using Wav2Vec2."""
+
+import numpy as np
+import torch
+
+# Offset compensation for Wav2Vec2-BASE systematic bias (in seconds)
+# Calibrated on librispeech-alignments dataset (n=25, MAE=48ms)
+START_OFFSET = 0.04  # Subtracted from start times (shifts them earlier)
+END_OFFSET = -0.04  # Subtracted from end times (shifts them later)
+
+
+def _get_device() -> str:
+    """Get the best available device for non-transformers models."""
+    if torch.cuda.is_available():
+        return "cuda"
+    if torch.backends.mps.is_available():
+        return "mps"
+    return "cpu"
+
+
+class ForcedAligner:
+    """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2.
+
+    Uses a Viterbi trellis to find the optimal alignment path.
+    """
+
+    _bundle = None
+    _model = None
+    _labels = None
+    _dictionary = None
+
+    @classmethod
+    def get_instance(cls, device: str = "cuda"):
+        """Get or create the forced alignment model (singleton).
+
+        Args:
+            device: Device to run the model on ("cuda" or "cpu")
+
+        Returns:
+            Tuple of (model, labels, dictionary)
+        """
+        if cls._model is None:
+            import torchaudio
+
+            cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
+            cls._model = cls._bundle.get_model().to(device)
+            cls._model.eval()
+            cls._labels = cls._bundle.get_labels()
+            cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
+        return cls._model, cls._labels, cls._dictionary
+
+    @staticmethod
+    def _get_trellis(emission: torch.Tensor, tokens: list[int], blank_id: int = 0) -> torch.Tensor:
+        """Build the trellis for forced alignment using the Viterbi recurrence.
+
+        trellis[t, j] is the log probability of the best path that aligns the
+        first j tokens to the first t frames.
+
+        Args:
+            emission: Log-softmax emission matrix of shape (num_frames, num_classes)
+            tokens: List of target token indices
+            blank_id: Index of the blank/CTC token (default 0)
+
+        Returns:
+            Trellis matrix of shape (num_frames + 1, num_tokens + 1)
+        """
+        num_frames = emission.size(0)
+        num_tokens = len(tokens)
+
+        trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
+        trellis[0, 0] = 0
+
+        # Force the alignment to use all tokens by preventing staying in blank
+        # at the end when there are still tokens to emit
+        if num_tokens > 1:
+            trellis[-num_tokens + 1 :, 0] = float("inf")
+
+        for t in range(num_frames):
+            for j in range(num_tokens + 1):
+                # Stay: emit blank and remain at j tokens
+                stay = trellis[t, j] + emission[t, blank_id]
+
+                # Move: emit token j and advance to j+1 tokens
+                move = trellis[t, j - 1] + emission[t, tokens[j - 1]] if j > 0 else -float("inf")
+
+                trellis[t + 1, j] = max(stay, move)  # Viterbi: take the best path
+
+        return trellis
+
+    @staticmethod
+    def _backtrack(
+        trellis: torch.Tensor, emission: torch.Tensor, tokens: list[int], blank_id: int = 0
+    ) -> list[tuple[int, float, float, float]]:
+        """Backtrack through the trellis to find the optimal forced monotonic alignment.
+
+        Guarantees:
+        - All tokens are emitted exactly once
+        - Strictly monotonic: each token's frames come after the previous token's
+        - No frame skipping or token teleporting
+
+        Returns a list of (token_id, start_frame, end_frame, peak_frame) for each token.
+        The peak_frame is the frame with the highest emission probability for that token.
+        """
+        num_frames = emission.size(0)
+        num_tokens = len(tokens)
+
+        if num_tokens == 0:
+            return []
+
+        # The best ending point should be at num_tokens, but verify the
+        # trellis reached a valid state
+        if trellis[num_frames, num_tokens] == -float("inf"):
+            # Alignment failed - fall back to a uniform distribution
+            frames_per_token = num_frames / num_tokens
+            return [
+                (
+                    tokens[i],
+                    i * frames_per_token,
+                    (i + 1) * frames_per_token,
+                    (i + 0.5) * frames_per_token,
+                )
+                for i in range(num_tokens)
+            ]
+
+        # Backtrack: find where each token transition occurred.
+        # Store (frame, emission_score) for each token.
+        token_frames: list[list[tuple[int, float]]] = [[] for _ in range(num_tokens)]
+
+        t = num_frames
+        j = num_tokens
+
+        while t > 0 and j > 0:
+            # Check: did we transition from j-1 to j at frame t-1?
+            stay_score = trellis[t - 1, j] + emission[t - 1, blank_id]
+            move_score = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
+
+            if move_score >= stay_score:
+                # Token j-1 was emitted at frame t-1;
+                # store the frame and its emission probability
+                emit_prob = emission[t - 1, tokens[j - 1]].exp().item()
+                token_frames[j - 1].insert(0, (t - 1, emit_prob))
+                j -= 1
+            # Always decrement time (monotonic)
+            t -= 1
+
+        # Handle any remaining tokens at the start (edge case)
+        while j > 0:
+            token_frames[j - 1].insert(0, (0, 0.0))
+            j -= 1
+
+        # Convert to spans with a peak frame
+        token_spans: list[tuple[int, float, float, float]] = []
+        for token_idx, frames_with_scores in enumerate(token_frames):
+            if not frames_with_scores:
+                # Token never emitted - assign a minimal span after the previous one
+                if token_spans:
+                    prev_end = token_spans[-1][2]
+                    frames_with_scores = [(int(prev_end), 0.0)]
+                else:
+                    frames_with_scores = [(0, 0.0)]
+
+            token_id = tokens[token_idx]
+            frames = [f for f, _ in frames_with_scores]
+            start_frame = float(min(frames))
+            end_frame = float(max(frames)) + 1.0
+
+            # Find the peak frame (highest emission probability)
+            peak_frame, _ = max(frames_with_scores, key=lambda x: x[1])
+
+            token_spans.append((token_id, start_frame, end_frame, float(peak_frame)))
+
+        return token_spans
+
+    @classmethod
+    def align(
+        cls,
+        audio: np.ndarray | torch.Tensor,
+        text: str,
+        sample_rate: int = 16000,
+    ) -> list[dict]:
+        """Align a transcript to audio and return word-level timestamps.
+
+        Uses the Viterbi trellis for optimal forced alignment.
+
+        Args:
+            audio: Audio waveform as a numpy array or torch tensor
+            text: Transcript text to align
+            sample_rate: Audio sample rate (default 16000)
+
+        Returns:
+            List of dicts with 'word', 'start', 'end' keys
+        """
+        import torchaudio
+
+        device = _get_device()
+        model, _labels, dictionary = cls.get_instance(device)
+        assert cls._bundle is not None and dictionary is not None  # Initialized by get_instance
+
+        # Convert audio to a tensor (copy to ensure the array is writable)
+        if isinstance(audio, np.ndarray):
+            waveform = torch.from_numpy(audio.copy()).float()
+        else:
+            waveform = audio.clone().float()
+
+        # Ensure 2D (channels, time)
+        if waveform.dim() == 1:
+            waveform = waveform.unsqueeze(0)
+
+        # Resample if needed (wav2vec2 expects 16kHz)
+        if sample_rate != cls._bundle.sample_rate:
+            waveform = torchaudio.functional.resample(
+                waveform, sample_rate, cls._bundle.sample_rate
+            )
+
+        waveform = waveform.to(device)
+
+        # Get emissions from the model
+        with torch.inference_mode():
+            emissions, _ = model(waveform)
+            emissions = torch.log_softmax(emissions, dim=-1)
+
+        emission = emissions[0].cpu()
+
+        # Normalize text: uppercase, keep only valid characters
+        transcript = text.upper()
+
+        # Build tokens from the transcript (including word separators)
+        tokens = []
+        for char in transcript:
+            if char in dictionary:
+                tokens.append(dictionary[char])
+            elif char == " ":
+                tokens.append(dictionary.get("|", dictionary.get(" ", 0)))
+
+        if not tokens:
+            return []
+
+        # Build the Viterbi trellis and backtrack for the optimal path
+        trellis = cls._get_trellis(emission, tokens, blank_id=0)
+        alignment_path = cls._backtrack(trellis, emission, tokens, blank_id=0)
+
+        # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
+        frame_duration = 320 / cls._bundle.sample_rate
+
+        # Apply separate offset compensation for start/end (Wav2Vec2 systematic bias)
+        start_offset = START_OFFSET
+        end_offset = END_OFFSET
+
+        # Group aligned tokens into words based on the pipe separator,
+        # using peak emission frames for more accurate word boundaries
+        words = text.split()
+        word_timestamps = []
+        first_char_peak = None
+        last_char_peak = None
+        word_idx = 0
+        separator_id = dictionary.get("|", dictionary.get(" ", 0))
+
+        for token_id, _start_frame, _end_frame, peak_frame in alignment_path:
+            if token_id == separator_id:  # Word separator
+                if (
+                    first_char_peak is not None
+                    and last_char_peak is not None
+                    and word_idx < len(words)
+                ):
+                    # Use peak frames for word boundaries
+                    start_time = max(0.0, first_char_peak * frame_duration - start_offset)
+                    end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
+                    word_timestamps.append(
+                        {
+                            "word": words[word_idx],
+                            "start": start_time,
+                            "end": end_time,
+                        }
+                    )
+                    word_idx += 1
+                first_char_peak = None
+                last_char_peak = None
+            else:
+                if first_char_peak is None:
+                    first_char_peak = peak_frame
+                last_char_peak = peak_frame
+
+        # Don't forget the last word
+        if first_char_peak is not None and last_char_peak is not None and word_idx < len(words):
+            start_time = max(0.0, first_char_peak * frame_duration - start_offset)
+            end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
+            word_timestamps.append(
+                {
+                    "word": words[word_idx],
+                    "start": start_time,
+                    "end": end_time,
+                }
+            )
+
+        return word_timestamps
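The stay/move recurrence implemented by `_get_trellis` and `_backtrack` in `alignment.py` can be traced on a toy example. Below is a minimal pure-Python sketch (no torch; the emission matrix, class indices, and probabilities are invented for illustration, and the end-of-trellis masking trick is omitted):

```python
import math

NEG_INF = -math.inf


def toy_trellis(emission, tokens, blank_id=0):
    """Viterbi trellis over log-prob emissions; mirrors the stay/move recurrence."""
    T, N = len(emission), len(tokens)
    trellis = [[NEG_INF] * (N + 1) for _ in range(T + 1)]
    trellis[0][0] = 0.0
    for t in range(T):
        for j in range(N + 1):
            stay = trellis[t][j] + emission[t][blank_id]  # emit blank, stay at j
            move = trellis[t][j - 1] + emission[t][tokens[j - 1]] if j > 0 else NEG_INF
            trellis[t + 1][j] = max(stay, move)  # Viterbi: keep the best path
    return trellis


def toy_backtrack(trellis, emission, tokens, blank_id=0):
    """Walk back from (T, N), recording the frame at which each token was emitted."""
    t, j = len(emission), len(tokens)
    frames = [None] * len(tokens)
    while t > 0 and j > 0:
        stay = trellis[t - 1][j] + emission[t - 1][blank_id]
        move = trellis[t - 1][j - 1] + emission[t - 1][tokens[j - 1]]
        if move >= stay:
            frames[j - 1] = t - 1  # token j-1 was emitted at frame t-1
            j -= 1
        t -= 1  # always step back in time (monotonic)
    return frames


# 4 frames, 3 classes (blank=0); class 1 is likely early, class 2 late
log = math.log
emission = [
    [log(0.1), log(0.8), log(0.1)],
    [log(0.6), log(0.3), log(0.1)],
    [log(0.1), log(0.1), log(0.8)],
    [log(0.7), log(0.1), log(0.2)],
]
tokens = [1, 2]
trellis = toy_trellis(emission, tokens)
frames = toy_backtrack(trellis, emission, tokens)
print(frames)  # [0, 2]: token 1 emitted at frame 0, token 2 at frame 2
```

Each token is claimed at exactly one frame here (the move step decrements `j`), which corresponds to the per-token spans that the real `_backtrack` converts into start/end/peak frames.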
asr_config.py ADDED
@@ -0,0 +1,166 @@
+from typing import Optional
+
+import transformers
+
+
+class ASRConfig(transformers.PretrainedConfig):
+    """Configuration class for the ASR model."""
+
+    model_type = "asr_model"
+    is_composition = True
+
+    # Generation defaults
+    GENERATION_DEFAULTS = {
+        "num_beams": 1,
+        "max_new_tokens": 128,
+        "min_new_tokens": 0,
+        "repetition_penalty": 1.0,
+        "length_penalty": 1.0,
+        "no_repeat_ngram_size": 0,
+        "use_cache": True,
+        "do_sample": False,
+        "temperature": None,
+        "top_p": None,
+        "top_k": None,
+    }
+
+    def __init__(
+        self,
+        # Model IDs
+        audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
+        text_model_id: str = "Qwen/Qwen3-0.6B",
+        # Model settings
+        attn_implementation: str = "sdpa",
+        model_dtype: str = "bfloat16",
+        system_prompt: str = "You are a helpful assistant.",
+        enable_thinking: bool = False,
+        # Encoder settings (auto-detected if None)
+        encoder_dim: Optional[int] = None,
+        llm_dim: Optional[int] = None,
+        encoder_conv_layers: Optional[list] = None,
+        audio_sample_rate: int = 16000,
+        # Projector settings
+        projector_type: str = "mlp",
+        projector_pool_stride: int = 4,
+        projector_hidden_dim: Optional[int] = None,
+        # Training settings (not saved to config.json for inference)
+        use_specaugment: bool = False,
+        num_time_masks: int = 2,
+        time_mask_length: int = 10,
+        num_freq_masks: int = 0,
+        freq_mask_length: int = 10,
+        freeze_projector: bool = False,
+        label_smoothing: float = 0.0,
+        # Audio Head settings (trainable AR decoder + NeuCodec)
+        use_audio_head: bool = False,
+        freeze_audio_head: bool = False,
+        max_audio_tokens: int = 500,
+        decoder_dim: int = 512,
+        decoder_layers: int = 6,
+        decoder_heads: int = 8,
+        neucodec_model_id: str = "neuphonic/neucodec",
+        **kwargs,
+    ):
+        # Merge generation defaults with kwargs (kwargs take precedence)
+        for key, default in self.GENERATION_DEFAULTS.items():
+            if key not in kwargs:
+                kwargs[key] = default
+
+        # Core model settings
+        self.audio_model_id = audio_model_id
+        self.text_model_id = text_model_id
+        self.attn_implementation = attn_implementation
+        self.model_dtype = model_dtype
+        self.system_prompt = system_prompt
+        self.enable_thinking = enable_thinking
+
+        # Encoder settings
+        self.encoder_dim = encoder_dim
+        self.llm_dim = llm_dim
+        self.encoder_conv_layers = encoder_conv_layers or [(1, 3, 1), (1, 3, 2)]
+        self.audio_sample_rate = audio_sample_rate
+
+        # Projector settings
+        self.projector_type = projector_type
+        self.projector_pool_stride = projector_pool_stride
+        self.projector_hidden_dim = projector_hidden_dim
+
+        # Training settings
+        self.use_specaugment = use_specaugment
+        self.num_time_masks = num_time_masks
+        self.time_mask_length = time_mask_length
+        self.num_freq_masks = num_freq_masks
+        self.freq_mask_length = freq_mask_length
+        self.freeze_projector = freeze_projector
+        self.label_smoothing = label_smoothing
+
+        # Audio Head settings (trainable AR decoder + NeuCodec)
+        self.use_audio_head = use_audio_head
+        self.freeze_audio_head = freeze_audio_head
+        self.max_audio_tokens = max_audio_tokens
+        self.decoder_dim = decoder_dim
+        self.decoder_layers = decoder_layers
+        self.decoder_heads = decoder_heads
+        self.neucodec_model_id = neucodec_model_id
+
+        # Generation parameters (from kwargs after the merge with defaults)
+        self.num_beams = kwargs.pop("num_beams")
+        self.max_new_tokens = kwargs.pop("max_new_tokens")
+        self.min_new_tokens = kwargs.pop("min_new_tokens")
+        self.repetition_penalty = kwargs.pop("repetition_penalty")
+        self.length_penalty = kwargs.pop("length_penalty")
+        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size")
+        self.use_cache = kwargs.pop("use_cache")
+        self.do_sample = kwargs.pop("do_sample")
+        self.temperature = kwargs.pop("temperature")
+        self.top_p = kwargs.pop("top_p")
+        self.top_k = kwargs.pop("top_k")
+
+        # Load sub-configs
+        self.audio_config = kwargs.pop("audio_config", None)
+        if self.audio_config is None:
+            self.audio_config = transformers.AutoConfig.from_pretrained(
+                audio_model_id, trust_remote_code=True
+            )
+            self.audio_config.dtype = model_dtype
+        elif isinstance(self.audio_config, dict) and self.audio_config.get("model_type"):
+            config_class = transformers.AutoConfig.for_model(
+                self.audio_config["model_type"]
+            ).__class__
+            self.audio_config = config_class(**self.audio_config)
+
+        self.text_config = kwargs.pop("text_config", None)
+        if self.text_config is None:
+            self.text_config = transformers.AutoConfig.from_pretrained(
+                text_model_id, trust_remote_code=True
+            )
+            self.text_config.dtype = model_dtype
+        elif isinstance(self.text_config, dict):
+            config_class = transformers.AutoConfig.for_model(
+                self.text_config["model_type"]
+            ).__class__
+            self.text_config = config_class(**self.text_config)
+
+        super().__init__(**kwargs)
+
+        # Pipeline configuration
+        self.encoder = self.audio_config
+        self.auto_map = {
+            "AutoConfig": "asr_config.ASRConfig",
+            "AutoModel": "asr_modeling.ASRModel",
+            "AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
+            "AutoProcessor": "asr_processing.ASRProcessor",
+        }
+        self.custom_pipelines = {
+            "automatic-speech-recognition": {
+                "impl": "asr_pipeline.ASRPipeline",
+                "pt": ["AutoModelForSpeechSeq2Seq"],
+                "tf": [],
+                "type": "audio",
+            }
+        }
+        self.architectures = ["ASRModel"]
+        self.pipeline_tag = "automatic-speech-recognition"
+
+
+transformers.AutoConfig.register("asr_model", ASRConfig)
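The defaults-merge at the top of `ASRConfig.__init__` (caller-supplied kwargs take precedence over `GENERATION_DEFAULTS`) reduces to the pattern below. This is a standalone sketch with a trimmed-down defaults dict, independent of transformers:

```python
# Trimmed copy of the generation defaults above (illustrative subset)
GENERATION_DEFAULTS = {
    "num_beams": 1,
    "max_new_tokens": 128,
    "do_sample": False,
}


def merge_generation_kwargs(**kwargs):
    """Fill in a default only when the caller did not supply that key."""
    for key, default in GENERATION_DEFAULTS.items():
        if key not in kwargs:
            kwargs[key] = default
    return kwargs


# The caller overrides one key; the others fall back to the defaults
merged = merge_generation_kwargs(max_new_tokens=256)
assert merged == {"max_new_tokens": 256, "num_beams": 1, "do_sample": False}
```

Because `__init__` later does `kwargs.pop(...)` for every generation key, this merge guarantees each pop succeeds whether or not the caller passed that key.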
asr_modeling.py ADDED
@@ -0,0 +1,817 @@
+import json
+from pathlib import Path
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    PreTrainedModel,
+)
+from transformers.generation import GenerationMixin
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+try:
+    from .asr_config import ASRConfig
+    from .projectors import PROJECTOR_CLASSES
+except ImportError:
+    from asr_config import ASRConfig  # type: ignore[no-redef]
+    from projectors import PROJECTOR_CLASSES  # type: ignore[no-redef]
+
+
+from torchaudio.transforms import SpecAugment
+
+
+class ASRModel(PreTrainedModel, GenerationMixin):
+    """Audio-to-text model combining an audio encoder, projector, and language model."""
+
+    config_class = ASRConfig
+    base_model_prefix = "model"
+    main_input_name = "input_features"
+    _supports_flash_attn_2 = True
+    supports_gradient_checkpointing = True
+    _is_loading_from_pretrained: bool = False
+    _pretrained_model_path: Optional[str] = None
+
+    TRANSCRIBE_PROMPT = ""
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) -> "ASRModel":
+        """Load the model from pretrained weights, handling device placement correctly."""
+        from safetensors.torch import load_file
+        from transformers.utils.hub import cached_file
+
+        config = kwargs.pop("config", None)
+        if config is None:
+            config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        # Set a flag to avoid device_map="auto" in sub-model loaders
+        cls._is_loading_from_pretrained = True
+        cls._pretrained_model_path = pretrained_model_name_or_path
+
+        try:
+            model = cls(config, **kwargs)
+
+            # Load projector weights from safetensors
+            subfolder = kwargs.get("subfolder")
+            revision = kwargs.get("revision")
+            cache_kwargs = {}
+            if subfolder:
+                cache_kwargs["subfolder"] = subfolder
+            if revision:
+                cache_kwargs["revision"] = revision
+
+            model_file = cached_file(
+                pretrained_model_name_or_path,
+                "model.safetensors",
+                _raise_exceptions_for_missing_entries=False,
+                **cache_kwargs,
+            )
+
+            if model_file is not None:
+                state_dict = load_file(model_file)
+                model.load_state_dict(state_dict, strict=False)
+
+            return model
+        finally:
+            cls._is_loading_from_pretrained = False
+            cls._pretrained_model_path = None
+
+    def __init__(self, config: ASRConfig, **kwargs) -> None:
+        super().__init__(config)
+
+        self.system_prompt = config.system_prompt
+        target_dtype = getattr(torch, config.model_dtype)
+
+        # Audio encoder (frozen)
+        self.audio_tower = self._load_audio_encoder(config, target_dtype)
+
+        # Language model (frozen)
+        self.language_model = self._load_language_model(config, target_dtype)
+
+        # Initialize tokenizer and special tokens
+        self._init_tokenizer(config)
+
+        # Set up the generation config with greedy decoding defaults
+        self.generation_config = self.language_model.generation_config
+        self.generation_config.max_new_tokens = config.max_new_tokens
+        self.generation_config.min_new_tokens = config.min_new_tokens
+        self.generation_config.num_beams = config.num_beams
+        self.generation_config.do_sample = config.do_sample
+        # Set sampling params from config (None means use model defaults)
+        self.generation_config.temperature = config.temperature
+        self.generation_config.top_p = config.top_p
+        self.generation_config.top_k = config.top_k
+        self.generation_config.use_cache = config.use_cache
+        self.generation_config.length_penalty = config.length_penalty
+        self.generation_config.repetition_penalty = config.repetition_penalty
+        self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
+        # Set EOS tokens, filtering out any that don't exist in the tokenizer
+        eos_candidates = [
+            self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
+            self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
+        ]
+        self.generation_config.eos_token_id = [t for t in eos_candidates if t is not None]
+        self.generation_config.pad_token_id = self.tokenizer.pad_token_id
+
+        # Feature extractor for audio preprocessing
+        self.feature_extractor = self._create_feature_extractor(config)
+
+        # Audio projector (trainable unless freeze_projector is set)
+        self.projector = self._create_projector(config, target_dtype)
+
+        # Learned padding embedding for audio tokens (used when projector output is short).
+        # A learned embedding instead of zeros keeps values in the embedding distribution.
+        self.audio_pad_embedding = nn.Parameter(torch.randn(1, config.llm_dim) * 0.02)
+
+        # Freeze projector if specified
+        if getattr(config, "freeze_projector", False):
+            self.projector.requires_grad_(False)
+
+        # SpecAugment for data augmentation during training
+        if getattr(config, "use_specaugment", False):
+            self.spec_augment = SpecAugment(
+                n_time_masks=config.num_time_masks,
+                time_mask_param=config.time_mask_length,
+                n_freq_masks=config.num_freq_masks,
+                freq_mask_param=config.freq_mask_length,
+            )
+        else:
+            self.spec_augment = None
+
+        # Audio head for S2S (trainable AR decoder + NeuCodec)
+        if getattr(config, "use_audio_head", False):
+            try:
+                from .audio_head import AudioHead, AudioHeadConfig
+            except ImportError:
+                from audio_head import AudioHead, AudioHeadConfig  # type: ignore[no-redef]
+
+            device = next(self.language_model.parameters()).device
+
+            audio_head_config = AudioHeadConfig(
+                decoder_dim=config.decoder_dim,
+                decoder_layers=config.decoder_layers,
+                decoder_heads=config.decoder_heads,
+                text_vocab_size=len(self.tokenizer),
+                max_audio_tokens=config.max_audio_tokens,
+                neucodec_model_id=getattr(config, "neucodec_model_id", "neuphonic/neucodec"),
+                temperature=getattr(config, "audio_head_temperature", 1.0),
+                top_k=getattr(config, "audio_head_top_k", 50),
+            )
+            self.audio_head = AudioHead(audio_head_config).to(
+                device=device, dtype=target_dtype
+            )
+
+            if getattr(config, "freeze_audio_head", False):
+                self.audio_head.requires_grad_(False)
+        else:
+            self.audio_head = None
+
+        # Silero VAD for interruption detection (Freeze-Omni style),
+        # loaded lazily on first use to avoid startup cost
+        self._vad_model = None
+        self._vad_utils = None
+
+        # For model parallelism
+        self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
+
+    def _tie_weights(self):
+        """No-op: AudioHead manages its own embeddings."""
+        pass
+
+    def _create_feature_extractor(self, config: ASRConfig):
+        """Create the appropriate feature extractor for the audio encoder."""
+        from transformers import AutoFeatureExtractor
+
+        feature_extractor = AutoFeatureExtractor.from_pretrained(config.audio_model_id)
+        # Disable padding by default - use the actual audio length
+        feature_extractor.padding = False
+        return feature_extractor
+
+    @classmethod
+    def _load_audio_encoder(cls, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
+        """Load and freeze the audio encoder."""
+        encoder_kwargs = {
+            "attn_implementation": config.attn_implementation,
+            "low_cpu_mem_usage": True,
+            "torch_dtype": dtype,
+        }
+
+        if "whisper" in config.audio_model_id.lower():
+            from transformers import WhisperModel
+
+            full_model = WhisperModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
+            encoder = full_model.encoder
+            del full_model
+        elif "glm" in config.audio_model_id.lower():
+            # GLM-ASR models use audio_tower as the encoder.
+            # Requires transformers >= 5.x or installed from source.
+            from transformers import AutoModelForSeq2SeqLM
+
+            full_model = AutoModelForSeq2SeqLM.from_pretrained(
+                config.audio_model_id, trust_remote_code=True, **encoder_kwargs
+            )
+            # GLM stores the encoder at audio_tower (GlmAsrEncoder)
+            encoder = full_model.audio_tower
216
+ # Clear references to free VRAM from the LLM decoder
217
+ full_model.language_model = None
218
+ full_model.multi_modal_projector = None
219
+ del full_model
220
+ else:
221
+ encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
222
+
223
+ encoder.requires_grad_(False)
224
+ encoder.eval()
225
+ return encoder
226
+
227
+ @classmethod
228
+ def _load_language_model(cls, config: ASRConfig, dtype: torch.dtype) -> PreTrainedModel:
229
+ """Load and freeze the language model."""
230
+ decoder_kwargs = {
231
+ "attn_implementation": config.attn_implementation,
232
+ "trust_remote_code": True,
233
+ "low_cpu_mem_usage": True,
234
+ "dtype": dtype,
235
+ }
236
+
237
+ decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs)
238
+ decoder.config.use_cache = getattr(config, "use_cache", True)
239
+ decoder.requires_grad_(False)
240
+ decoder.eval()
241
+ return decoder
242
+
243
+ def _create_projector(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
244
+ """Create the trainable audio projector."""
245
+ # Auto-detect dimensions if not specified
246
+ if config.encoder_dim is None:
247
+ enc_cfg = self.audio_tower.config
248
+ config.encoder_dim = getattr(enc_cfg, "hidden_size", None) or getattr(
249
+ enc_cfg, "d_model", None
250
+ )
251
+ if config.encoder_dim is None:
252
+ raise ValueError("Could not auto-detect encoder_dim. Please specify in config.")
253
+
254
+ if config.llm_dim is None:
255
+ dec_cfg = self.language_model.config
256
+ config.llm_dim = getattr(dec_cfg, "hidden_size", None) or getattr(
257
+ dec_cfg, "d_model", None
258
+ )
259
+ if config.llm_dim is None:
260
+ raise ValueError("Could not auto-detect llm_dim. Please specify in config.")
261
+
262
+ # Select projector type based on config
263
+ projector_type = getattr(config, "projector_type", "mlp")
264
+ projector_class = PROJECTOR_CLASSES.get(projector_type)
265
+ if projector_class is None:
266
+ raise ValueError(
267
+ f"Unknown projector_type: {projector_type}. "
268
+ f"Valid options: {list(PROJECTOR_CLASSES.keys())}"
269
+ )
270
+ projector = projector_class(config)
271
+
272
+ # Move projector to same device as language model (important when using quantization)
273
+ device = next(self.language_model.parameters()).device
274
+ return projector.to(device=device, dtype=dtype)
275
+
276
+ def _init_tokenizer(self, config: ASRConfig):
277
+ """Initialize tokenizer with audio token."""
278
+ self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)
279
+
280
+ # Set pad token
281
+ if (
282
+ self.tokenizer.pad_token is None
283
+ or self.tokenizer.pad_token_id == self.tokenizer.eos_token_id
284
+ ) and "<|finetune_right_pad_id|>" in self.tokenizer.get_vocab():
285
+ self.tokenizer.pad_token = "<|finetune_right_pad_id|>"
286
+
287
+ # Add audio token
288
+ existing_special = getattr(self.tokenizer, "additional_special_tokens", None) or []
289
+ if "<audio>" not in existing_special:
290
+ self.tokenizer.add_special_tokens(
291
+ {"additional_special_tokens": existing_special + ["<audio>"]}
292
+ )
293
+ self.language_model.resize_token_embeddings(
294
+ len(self.tokenizer), mean_resizing=False, pad_to_multiple_of=64
295
+ )
296
+
297
+ self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
298
+ self.tokenizer.padding_side = "right"
299
+
300
+ # Sync token IDs to configs
301
+ for cfg in [self.config.text_config, self.language_model.config, self.generation_config]:
302
+ if cfg is not None:
303
+ cfg.pad_token_id = self.tokenizer.pad_token_id
304
+ cfg.eos_token_id = self.tokenizer.eos_token_id
305
+ cfg.bos_token_id = self.tokenizer.bos_token_id
306
+
307
+ def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None):
308
+ """Enable/disable gradient checkpointing for the language model."""
309
+ # The LLM still stores activations during forward for backprop to projector
310
+ # Gradient checkpointing trades compute for memory by recomputing activations
311
+ if hasattr(self.language_model, "_set_gradient_checkpointing"):
312
+ self.language_model._set_gradient_checkpointing(enable, gradient_checkpointing_func)
313
+ elif hasattr(self.language_model, "gradient_checkpointing_enable") and enable:
314
+ self.language_model.gradient_checkpointing_enable(
315
+ gradient_checkpointing_kwargs={"use_reentrant": False}
316
+ )
317
+ elif hasattr(self.language_model, "gradient_checkpointing_disable") and not enable:
318
+ self.language_model.gradient_checkpointing_disable()
319
+
320
+ def get_input_embeddings(self) -> nn.Module:
321
+ return self.language_model.get_input_embeddings()
322
+
323
+ def set_input_embeddings(self, value: nn.Module) -> None:
324
+ self.language_model.set_input_embeddings(value)
325
+
326
+ def get_output_embeddings(self) -> nn.Module:
327
+ return self.language_model.get_output_embeddings()
328
+
329
+ def set_output_embeddings(self, value: nn.Module) -> None:
330
+ self.language_model.set_output_embeddings(value)
331
+
332
+ def get_processor(self):
333
+ """Get the processor for this model."""
334
+ try:
335
+ from .asr_processing import ASRProcessor
336
+ except ImportError:
337
+ from asr_processing import ASRProcessor # type: ignore[no-redef]
338
+
339
+ return ASRProcessor(
340
+ feature_extractor=self.feature_extractor,
341
+ tokenizer=self.tokenizer,
342
+ projector=self.projector,
343
+ encoder_conv_layers=self.config.encoder_conv_layers,
344
+ )
345
+
346
+ # =========================================================================
347
+ # Silero VAD for Interruption Detection (Freeze-Omni style)
348
+ # =========================================================================
349
+
350
+ def load_vad(self, force_reload: bool = False) -> None:
351
+ """Load Silero VAD model for interruption detection.
352
+
353
+ Silero VAD is a lightweight (~2MB) voice activity detector that runs
354
+ in real-time. Used as the first layer of interruption detection.
355
+
356
+ Args:
357
+ force_reload: Force reload even if already loaded
358
+ """
359
+ if self._vad_model is not None and not force_reload:
360
+ return
361
+
362
+ model, utils = torch.hub.load(
363
+ repo_or_dir="snakers4/silero-vad",
364
+ model="silero_vad",
365
+ force_reload=force_reload,
366
+ trust_repo=True,
367
+ )
368
+
369
+ self._vad_model = model
370
+ self._vad_utils = utils
371
+
372
+ # Freeze VAD model
373
+ self._vad_model.eval()
374
+ for param in self._vad_model.parameters():
375
+ param.requires_grad = False
376
+
377
+ def detect_speech(
378
+ self,
379
+ audio_chunk: torch.Tensor,
380
+ sample_rate: int = 16000,
381
+ threshold: float = 0.5,
382
+ ) -> tuple[bool, float]:
383
+ """Detect speech in an audio chunk using Silero VAD.
384
+
385
+ Args:
386
+ audio_chunk: Audio waveform [samples] or [1, samples] at sample_rate
387
+ sample_rate: Audio sample rate (default 16kHz)
388
+ threshold: Speech probability threshold (default 0.5)
389
+
390
+ Returns:
391
+ Tuple of (is_speech, probability)
392
+ """
393
+ if self._vad_model is None:
394
+ self.load_vad()
395
+
396
+ # Ensure 1D tensor
397
+ if audio_chunk.dim() > 1:
398
+ audio_chunk = audio_chunk.squeeze()
399
+
400
+ # VAD expects specific sample rates (8000 or 16000)
401
+ if sample_rate not in (8000, 16000):
402
+ import torchaudio.functional as audio_functional
403
+
404
+ audio_chunk = audio_functional.resample(audio_chunk, sample_rate, 16000)
405
+ sample_rate = 16000
406
+
407
+ # Run VAD
408
+ with torch.no_grad():
409
+ speech_prob = self._vad_model(audio_chunk, sample_rate).item()
410
+
411
+ return speech_prob > threshold, speech_prob
412
+
413
+ def reset_vad_state(self) -> None:
414
+ """Reset VAD internal state between utterances."""
415
+ if self._vad_model is not None:
416
+ self._vad_model.reset_states()
417
+
418
+ def state_dict(self, *args, **kwargs) -> dict[str, torch.Tensor]:
419
+ """Save trainable weights (projector + audio_head if present)."""
420
+ state = {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
421
+ if self.audio_head is not None:
422
+ state.update({f"audio_head.{k}": v for k, v in self.audio_head.state_dict().items()})
423
+ return state
424
+
425
+ def _compute_encoder_output_lengths(
426
+ self,
427
+ audio_attention_mask: torch.Tensor,
428
+ ) -> torch.Tensor:
429
+ """Compute per-sample encoder output lengths using conv layer formulas.
430
+
431
+ Args:
432
+ audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
433
+
434
+ Returns:
435
+ Tensor of encoder output lengths per sample (batch,)
436
+ """
437
+ # Get mel frame lengths from attention mask
438
+ lengths = audio_attention_mask.sum(dim=-1)
439
+
440
+ # Apply conv layer formulas: output = (input + 2*pad - (kernel-1) - 1) // stride + 1
441
+ for padding, kernel_size, stride in self.config.encoder_conv_layers:
442
+ lengths = (lengths + 2 * padding - (kernel_size - 1) - 1) // stride + 1
443
+
444
+ return lengths
445
+
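The conv-length formula applied in `_compute_encoder_output_lengths` can be checked standalone. A minimal sketch, assuming Whisper-style conv layers `(padding=1, kernel=3, stride=1)` followed by `(1, 3, 2)` — the concrete layer tuples are an assumption for illustration, the real values come from `config.encoder_conv_layers`:

```python
# Hypothetical illustration (not part of the model code): how two conv
# layers shrink 3000 mel frames to encoder frames using the standard
# conv output-length formula.
def conv_out_len(length: int, padding: int, kernel_size: int, stride: int) -> int:
    return (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1

length = 3000  # typical Whisper mel frame count for 30s audio
for padding, kernel, stride in [(1, 3, 1), (1, 3, 2)]:
    length = conv_out_len(length, padding, kernel, stride)
print(length)  # 1500
```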
446
+ def _encode_audio(
447
+ self,
448
+ audio_features: torch.Tensor,
449
+ audio_attention_mask: torch.Tensor,
450
+ expected_token_counts: torch.Tensor | None = None,
451
+ ) -> torch.Tensor:
452
+ """Encode audio and project to LLM embedding space.
453
+
454
+ Args:
455
+ audio_features: Mel spectrogram features (batch, n_mels, mel_len)
456
+ audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
457
+ expected_token_counts: Expected number of audio tokens per sample from input_ids.
458
+ If provided, output will match these counts exactly (padding/truncating as needed).
459
+
460
+ Returns:
461
+ Flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
462
+ """
463
+ with torch.no_grad():
464
+ encoder_out = self.audio_tower(input_features=audio_features)
465
+ hidden_states = encoder_out.last_hidden_state
466
+
467
+ # Project to LLM space
468
+ audio_embeds = self.projector(hidden_states)
469
+
470
+ # Use expected token counts if provided (from input_ids), otherwise compute from audio
471
+ if expected_token_counts is not None:
472
+ token_counts = expected_token_counts
473
+ else:
474
+ # Compute per-sample encoder output lengths using conv formulas
475
+ encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
476
+ token_counts = torch.tensor(
477
+ [
478
+ self.projector.get_output_length(int(length.item()))
479
+ for length in encoder_lengths
480
+ ],
481
+ device=audio_embeds.device,
482
+ )
483
+
484
+ # Extract embeddings matching expected token counts per sample
485
+ batch_size = audio_embeds.shape[0]
486
+
487
+ result_embeds = []
488
+ for i in range(batch_size):
489
+ count = int(token_counts[i].item())
490
+ sample_embeds = audio_embeds[i, :count, :] # Take first 'count' embeddings
+ sample_embeds = audio_embeds[i, :count, :]  # Take up to 'count' embeddings (slice clamps if shorter)
491
+ # Pad with learned embedding if we don't have enough embeddings
492
+ if sample_embeds.shape[0] < count:
493
+ pad_count = count - sample_embeds.shape[0]
494
+ padding = self.audio_pad_embedding.expand(pad_count, -1).to(
495
+ device=audio_embeds.device, dtype=audio_embeds.dtype
496
+ )
497
+ sample_embeds = torch.cat([sample_embeds, padding], dim=0)
498
+ result_embeds.append(sample_embeds)
499
+
500
+ return torch.cat(result_embeds, dim=0)
501
+
502
+ def forward(
503
+ self,
504
+ input_ids: Optional[torch.Tensor] = None,
505
+ input_features: Optional[torch.Tensor] = None,
506
+ audio_attention_mask: Optional[torch.Tensor] = None,
507
+ attention_mask: Optional[torch.Tensor] = None,
508
+ position_ids: Optional[torch.Tensor] = None,
509
+ past_key_values: Optional[torch.Tensor] = None,
510
+ inputs_embeds: Optional[torch.Tensor] = None,
511
+ labels: Optional[torch.Tensor] = None,
512
+ use_cache: Optional[bool] = None,
513
+ cache_position: Optional[torch.Tensor] = None,
514
+ **kwargs,
515
+ ) -> CausalLMOutputWithPast:
516
+ """Forward pass for training and inference."""
517
+ # Get text embeddings if not provided
518
+ if inputs_embeds is None:
519
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
520
+
521
+ if input_features is not None and input_ids is not None:
522
+ # Apply SpecAugment during training if enabled
523
+ if self.training and self.spec_augment is not None:
524
+ input_features = self.spec_augment(input_features)
525
+
526
+ # Count expected audio tokens from input_ids (ground truth from collator)
527
+ audio_token_counts = (input_ids == self.audio_token_id).sum(dim=-1)
528
+
529
+ # Encode audio -> flattened (total_audio_tokens, hidden_dim)
530
+ audio_embeds = self._encode_audio(
531
+ input_features, audio_attention_mask, audio_token_counts
532
+ )
533
+
534
+ # Replace <audio> token placeholders with audio embeddings using masked_scatter
535
+ audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
536
+
537
+ inputs_embeds = inputs_embeds.masked_scatter(
538
+ audio_token_mask.to(inputs_embeds.device),
539
+ audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
540
+ )
541
+
542
+ # Remove TRL-specific keys that shouldn't go to the LLM
543
+ kwargs.pop("prompts", None)
544
+ kwargs.pop("prompt_attention_mask", None)
545
+
546
+ # Run through language model (let it compute loss if labels provided)
547
+ outputs = self.language_model(
548
+ attention_mask=attention_mask,
549
+ position_ids=position_ids,
550
+ past_key_values=past_key_values,
551
+ inputs_embeds=inputs_embeds,
552
+ labels=labels,
553
+ use_cache=use_cache,
554
+ cache_position=cache_position,
555
+ **kwargs,
556
+ )
557
+
558
+ return outputs
559
+
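The `masked_scatter` pattern used in `forward()` to swap `<audio>` placeholders for audio embeddings can be illustrated in isolation. A minimal sketch with made-up shapes and a hypothetical token id:

```python
import torch

# Hypothetical illustration of the masked_scatter pattern: replace
# <audio> placeholder positions with precomputed audio embeddings.
AUDIO_TOKEN_ID = 7  # assumed placeholder id for this sketch
input_ids = torch.tensor([[1, 7, 7, 2]])        # two audio slots
inputs_embeds = torch.zeros(1, 4, 3)            # (batch, seq, hidden)
audio_embeds = torch.ones(2, 3)                 # flattened: one row per <audio>

# Mask broadcasts over the hidden dim; source must supply exactly
# mask.sum() * hidden elements, which the flattened audio_embeds does.
mask = (input_ids == AUDIO_TOKEN_ID).unsqueeze(-1)
out = inputs_embeds.masked_scatter(mask, audio_embeds)
print(out[0, 1].tolist())  # [1.0, 1.0, 1.0]
```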
560
+ def prepare_inputs_for_generation(self, *args, **kwargs):
561
+ """Prepare inputs for generation, handling audio features for cached decoding."""
562
+ input_features = kwargs.pop("input_features", None)
563
+ cache_position = kwargs.get("cache_position")
564
+
565
+ model_inputs = self.language_model.prepare_inputs_for_generation(*args, **kwargs)
566
+
567
+ # Only pass audio features on the first generation step (cache_position[0] == 0)
568
+ if cache_position is not None and cache_position[0] == 0 and input_features is not None:
569
+ model_inputs["input_features"] = input_features
570
+
571
+ return model_inputs
572
+
573
+ def _get_num_audio_tokens(
574
+ self,
575
+ audio_attention_mask: torch.Tensor,
576
+ ) -> int:
577
+ """Calculate number of audio tokens based on actual audio length.
578
+
579
+ Uses attention mask to get real audio length, then computes:
580
+ mel_frames -> encoder_frames (via conv formulas) -> projector output tokens
581
+ """
582
+ encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
583
+ # Use max length for batch (all samples should have same token count for generation)
584
+ encoder_output_len = int(encoder_lengths.max().item())
585
+ return int(self.projector.get_output_length(encoder_output_len))
586
+
587
+ def _build_audio_prompt(
588
+ self,
589
+ audio_attention_mask: torch.Tensor,
590
+ batch_size: int,
591
+ device: torch.device,
592
+ system_prompt: Optional[str] = None,
593
+ ) -> tuple[torch.Tensor, torch.Tensor]:
594
+ """Build input_ids and attention_mask for audio-conditioned generation.
595
+
596
+ Args:
597
+ audio_attention_mask: Mask for real vs padded mel frames
598
+ batch_size: Batch size for expanding single prompts
599
+ device: Device to place tensors on
600
+ system_prompt: Optional system prompt override
601
+
602
+ Returns:
603
+ Tuple of (input_ids, attention_mask) tensors
604
+ """
605
+ num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
606
+ audio_placeholder = "<audio>" * num_audio_tokens
607
+
608
+ system_prompt = system_prompt or self.system_prompt
609
+
610
+ messages: list[dict[str, str]] = []
611
+ if system_prompt:
612
+ messages.append({"role": "system", "content": system_prompt})
613
+ user_content = audio_placeholder
614
+ if self.TRANSCRIBE_PROMPT:
615
+ user_content += " " + self.TRANSCRIBE_PROMPT
616
+ messages.append({"role": "user", "content": user_content})
617
+
618
+ chat_result = self.tokenizer.apply_chat_template(
619
+ messages,
620
+ tokenize=True,
621
+ add_generation_prompt=True,
622
+ return_tensors="pt",
+ return_dict=True,
623
+ enable_thinking=getattr(self.config, "enable_thinking", False),
624
+ )
625
+ input_ids = chat_result.input_ids.to(device)
626
+
627
+ if input_ids.dim() == 1:
628
+ input_ids = input_ids.unsqueeze(0)
629
+ if input_ids.shape[0] == 1 and batch_size > 1:
630
+ input_ids = input_ids.expand(batch_size, -1)
631
+
632
+ return input_ids, torch.ones_like(input_ids)
633
+
634
+ def _inject_audio_embeddings(
635
+ self,
636
+ input_ids: torch.Tensor,
637
+ audio_embeds: torch.Tensor,
638
+ ) -> torch.Tensor:
639
+ """Replace audio token placeholders with actual audio embeddings.
640
+
641
+ Args:
642
+ input_ids: Token IDs containing <audio> placeholder tokens
643
+ audio_embeds: Encoded audio embeddings to inject
644
+
645
+ Returns:
646
+ Input embeddings with audio tokens replaced by audio embeddings
647
+ """
648
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
649
+ audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
650
+ return inputs_embeds.masked_scatter(
651
+ audio_token_mask.to(inputs_embeds.device),
652
+ audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
653
+ )
654
+
655
+ @torch.no_grad()
656
+ def generate(
657
+ self,
658
+ input_ids: Optional[torch.Tensor] = None,
659
+ input_features: Optional[torch.Tensor] = None,
660
+ audio_attention_mask: Optional[torch.Tensor] = None,
661
+ attention_mask: Optional[torch.Tensor] = None,
662
+ system_prompt: Optional[str] = None,
663
+ **generate_kwargs,
664
+ ) -> torch.Tensor:
665
+ """Generate transcription from audio input.
666
+
667
+ Can be called in two ways:
668
+ 1. With input_ids containing <audio> tokens (from processor)
669
+ 2. With just audio, and we build the prompt internally
670
+ """
671
+ if input_features is None:
672
+ raise ValueError("input_features required for generation")
673
+ if audio_attention_mask is None:
674
+ raise ValueError("audio_attention_mask required for generation")
675
+
676
+ device = input_features.device
677
+ batch_size = input_features.shape[0]
678
+
679
+ # Encode audio -> flattened embeddings
680
+ audio_embeds = self._encode_audio(input_features, audio_attention_mask)
681
+
682
+ # If input_ids not provided, build prompt with correct number of audio tokens
683
+ if input_ids is None:
684
+ input_ids, attention_mask = self._build_audio_prompt(
685
+ audio_attention_mask, batch_size, device, system_prompt
686
+ )
687
+
688
+ # Replace audio token placeholders with audio embeddings
689
+ inputs_embeds = self._inject_audio_embeddings(input_ids, audio_embeds)
690
+
691
+ # Generate using language model
692
+ # Pass both input_ids and inputs_embeds so repetition_penalty works correctly
693
+ # (it needs input_ids to track which tokens have been used)
694
+ output = self.language_model.generate(
695
+ input_ids=input_ids,
696
+ inputs_embeds=inputs_embeds,
697
+ attention_mask=attention_mask,
698
+ generation_config=self.generation_config,
699
+ **generate_kwargs,
700
+ )
701
+
702
+ # When using inputs_embeds with input_ids, generate returns full sequence
703
+ # Strip the input tokens to return only generated tokens
704
+ sequences = output if isinstance(output, torch.Tensor) else output.sequences
705
+ input_len = input_ids.shape[1]
706
+ return sequences[:, input_len:]
707
+
708
+ def _process_audio(
709
+ self,
710
+ audio,
711
+ sampling_rate: int = 16000,
712
+ ) -> dict[str, torch.Tensor]:
713
+ """Process raw audio waveform to model inputs."""
714
+ # Convert to numpy if tensor
715
+ if isinstance(audio, torch.Tensor):
716
+ audio = audio.cpu().numpy()
717
+
718
+ # Get mel features from feature extractor
719
+ inputs = self.feature_extractor(
720
+ audio,
721
+ sampling_rate=sampling_rate,
722
+ return_attention_mask=True,
723
+ return_tensors="pt",
724
+ )
725
+
726
+ device = next(self.language_model.parameters()).device
727
+ return {
728
+ "input_features": inputs["input_features"].to(device),
729
+ "attention_mask": inputs["attention_mask"].to(device),
730
+ }
731
+
732
+ def save_pretrained(self, save_directory: Union[str, Path], **kwargs) -> None:
733
+ """Save model, tokenizer, and processor."""
734
+ import shutil
735
+ from pathlib import Path as PathlibPath
736
+
737
+ save_dir = PathlibPath(save_directory)
738
+ save_dir.mkdir(parents=True, exist_ok=True)
739
+
740
+ # Update config with actual vocab size
741
+ self.config.vocab_size = self.language_model.config.vocab_size
742
+ self.config.text_config.vocab_size = self.language_model.config.vocab_size
743
+
744
+ if hasattr(self.audio_tower.config, "num_mel_bins"):
745
+ self.config.audio_config.num_mel_bins = self.audio_tower.config.num_mel_bins
746
+
747
+ # Save config
748
+ self.config.save_pretrained(save_dir)
749
+
750
+ # Save state dict directly to avoid HuggingFace's tied weights handling
751
+ # which conflicts with our shared AudioHead embedding
752
+ state_dict = self.state_dict()
753
+ safe_serialization = kwargs.get("safe_serialization", True)
754
+
755
+ if safe_serialization:
756
+ from safetensors.torch import save_file
757
+
758
+ save_file(state_dict, save_dir / "model.safetensors")
759
+ else:
760
+ import torch
761
+
762
+ torch.save(state_dict, save_dir / "pytorch_model.bin")
763
+
764
+ # Save tokenizer and feature extractor
765
+ self.tokenizer.save_pretrained(save_dir)
766
+ self.feature_extractor.save_pretrained(save_dir)
767
+
768
+ # Add processor auto_map to preprocessor_config.json
769
+ config_path = save_dir / "preprocessor_config.json"
770
+ if config_path.exists():
771
+ with config_path.open() as f:
772
+ processor_config = json.load(f)
773
+ else:
774
+ processor_config = {}
775
+
776
+ processor_config.update(
777
+ {
778
+ "processor_class": "ASRProcessor",
779
+ "auto_map": {"AutoProcessor": "asr_processing.ASRProcessor"},
780
+ }
781
+ )
782
+
783
+ with config_path.open("w") as f:
784
+ json.dump(processor_config, f, indent=2)
785
+
786
+ # Copy source files for auto-loading
787
+ src_dir = PathlibPath(__file__).parent
788
+ for asr_file in src_dir.glob("asr_*.py"):
789
+ shutil.copy(asr_file, save_dir / asr_file.name)
790
+ # Copy projectors module
791
+ shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
792
+ # Copy alignment module
793
+ shutil.copy(src_dir / "alignment.py", save_dir / "alignment.py")
794
+ # Copy diarization module
795
+ shutil.copy(src_dir / "diarization.py", save_dir / "diarization.py")
796
+ # Copy audio head for S2S
797
+ audio_head_path = src_dir / "audio_head.py"
798
+ if audio_head_path.exists():
799
+ shutil.copy(audio_head_path, save_dir / "audio_head.py")
800
+ # Copy full duplex session for S2S
801
+ full_duplex_path = src_dir / "full_duplex.py"
802
+ if full_duplex_path.exists():
803
+ shutil.copy(full_duplex_path, save_dir / "full_duplex.py")
804
+
805
+ def push_to_hub(self, repo_id: str, **kwargs) -> str:
806
+ """Push model to HuggingFace Hub."""
807
+ self.config.pretrained_model_path = repo_id
808
+ return super().push_to_hub(repo_id, **kwargs)
809
+
810
+ def create_or_update_model_card(self, output_dir: Union[str, Path]) -> None:
811
+ """No-op for model card creation - we use MODEL_CARD.md in repo instead."""
812
+ pass
813
+
814
+
815
+ # Register with transformers Auto classes
816
+ AutoConfig.register("asr_model", ASRConfig)
817
+ AutoModel.register(ASRConfig, ASRModel)
asr_pipeline.py ADDED
@@ -0,0 +1,370 @@
1
+ """ASR pipeline for audio-to-text transcription with optional timestamps and diarization."""
2
+
3
+ import re
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+ import torch
9
+ import transformers
10
+
11
+ try:
12
+ from .alignment import ForcedAligner
13
+ from .asr_modeling import ASRModel
14
+ from .diarization import LocalSpeakerDiarizer
15
+ except ImportError:
16
+ from alignment import ForcedAligner # type: ignore[no-redef]
17
+ from asr_modeling import ASRModel # type: ignore[no-redef]
18
+ from diarization import LocalSpeakerDiarizer # type: ignore[no-redef]
19
+
20
+ # Re-export for backwards compatibility
21
+ __all__ = ["ForcedAligner", "LocalSpeakerDiarizer", "ASRPipeline", "strip_thinking"]
22
+
23
+
24
+ def strip_thinking(text: str) -> str:
25
+ """Remove <think>...</think> tags from model output.
26
+
27
+ Args:
28
+ text: Model output text that may contain thinking tags
29
+
30
+ Returns:
31
+ Text with thinking content removed
32
+ """
33
+ if not text:
34
+ return text
35
+ text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL)
36
+ return text.strip()
37
+
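A quick sanity check of the `strip_thinking` regex above, reproduced standalone (same pattern, hypothetical input string):

```python
import re

# Same regex as strip_thinking: drop <think>...</think> plus trailing whitespace.
def strip_thinking(text: str) -> str:
    if not text:
        return text
    return re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()

print(strip_thinking("<think>reasoning here</think> Hello world"))  # Hello world
```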
38
+
39
+ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
40
+ """ASR Pipeline for audio-to-text transcription."""
41
+
42
+ model: ASRModel
43
+
44
+ def __init__(self, model: ASRModel, **kwargs):
45
+ """Initialize ASR pipeline.
46
+
47
+ Args:
48
+ model: ASRModel instance for transcription
49
+ **kwargs: Additional arguments (feature_extractor, tokenizer, device)
50
+ """
51
+ feature_extractor = kwargs.pop("feature_extractor", None)
52
+ tokenizer = kwargs.pop("tokenizer", model.tokenizer)
53
+
54
+ if feature_extractor is None:
55
+ feature_extractor = model.get_processor().feature_extractor
56
+
57
+ super().__init__(
58
+ model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
59
+ )
60
+ self._current_audio = None
61
+
62
+ def _sanitize_parameters(self, **kwargs):
63
+ """Intercept our custom parameters before parent class validates them."""
64
+ # Remove our custom parameters so parent doesn't see them
65
+ kwargs.pop("return_timestamps", None)
66
+ kwargs.pop("return_speakers", None)
67
+ kwargs.pop("num_speakers", None)
68
+ kwargs.pop("min_speakers", None)
69
+ kwargs.pop("max_speakers", None)
70
+ kwargs.pop("hf_token", None)
71
+ kwargs.pop("user_prompt", None)
72
+ kwargs.pop("system_prompt", None)
73
+ kwargs.pop("diarization_backend", None)
74
+ return super()._sanitize_parameters(**kwargs)
75
+
76
+ def __call__(
77
+ self,
78
+ inputs,
79
+ **kwargs,
80
+ ):
81
+ """Transcribe audio with optional word-level timestamps and speaker diarization.
82
+
83
+ Args:
84
+ inputs: Audio input (file path, dict with array/sampling_rate, etc.)
85
+ return_timestamps: If True, return word-level timestamps using forced alignment
86
+ return_speakers: If True, return speaker labels for each word
87
+ user_prompt: Custom transcription prompt (default: "Transcribe: ")
88
+ system_prompt: Custom system prompt override (uses model's default if not provided)
89
+ num_speakers: Exact number of speakers (if known, for diarization)
90
+ min_speakers: Minimum number of speakers (for diarization)
91
+ max_speakers: Maximum number of speakers (for diarization)
92
+ **kwargs: Additional arguments passed to the pipeline
93
+
94
+ Returns:
95
+ Dict with 'text' key, 'words' key if return_timestamps=True,
96
+ speaker labels on words if return_speakers=True
97
+ """
98
+ # Extract our params before super().__call__ (which will also call _sanitize_parameters)
99
+ return_timestamps = kwargs.pop("return_timestamps", False)
100
+ return_speakers = kwargs.pop("return_speakers", False)
101
+        user_prompt = kwargs.pop("user_prompt", None)
+        system_prompt = kwargs.pop("system_prompt", None)
+        diarization_params = {
+            "num_speakers": kwargs.pop("num_speakers", None),
+            "min_speakers": kwargs.pop("min_speakers", None),
+            "max_speakers": kwargs.pop("max_speakers", None),
+        }
+
+        if return_speakers:
+            return_timestamps = True
+
+        # Set custom user prompt if provided
+        original_prompt = None
+        if user_prompt:
+            original_prompt = self.model.TRANSCRIBE_PROMPT
+            self.model.TRANSCRIBE_PROMPT = user_prompt
+
+        # Set custom system prompt if provided
+        original_system_prompt = None
+        if system_prompt:
+            original_system_prompt = self.model.system_prompt
+            self.model.system_prompt = system_prompt
+
+        # Store audio for timestamp alignment and diarization
+        if return_timestamps or return_speakers:
+            self._current_audio = self._extract_audio(inputs)
+
+        # Run standard transcription
+        result = super().__call__(inputs, **kwargs)
+
+        # Add timestamps if requested
+        if return_timestamps and self._current_audio is not None:
+            text = result.get("text", "")
+            if text:
+                try:
+                    words = ForcedAligner.align(
+                        self._current_audio["array"],
+                        text,
+                        sample_rate=self._current_audio.get("sampling_rate", 16000),
+                    )
+                    result["words"] = words
+                except Exception as e:
+                    result["words"] = []
+                    result["timestamp_error"] = str(e)
+            else:
+                result["words"] = []
+
+        # Add speaker diarization if requested
+        if return_speakers and self._current_audio is not None:
+            try:
+                # Run diarization
+                speaker_segments = LocalSpeakerDiarizer.diarize(
+                    self._current_audio["array"],
+                    sample_rate=self._current_audio.get("sampling_rate", 16000),
+                    **{k: v for k, v in diarization_params.items() if v is not None},
+                )
+                result["speaker_segments"] = speaker_segments
+
+                # Assign speakers to words
+                if result.get("words"):
+                    result["words"] = LocalSpeakerDiarizer.assign_speakers_to_words(
+                        result["words"],
+                        speaker_segments,
+                    )
+            except Exception as e:
+                result["speaker_segments"] = []
+                result["diarization_error"] = str(e)
+
+        # Clean up
+        self._current_audio = None
+        if original_prompt is not None:
+            self.model.TRANSCRIBE_PROMPT = original_prompt
+        if original_system_prompt is not None:
+            self.model.system_prompt = original_system_prompt
+
+        return result
+
+    def _extract_audio(self, inputs) -> dict | None:
+        """Extract audio array from various input formats.
+
+        Supported input formats:
+        - str: File path to audio file
+        - bytes: Encoded audio (mp3, wav, etc.) - decoded via ffmpeg
+        - np.ndarray: Audio samples as float32 array
+        - dict with "array": Audio samples as numpy array
+        - dict with "raw": Alias for "array" (HF pipeline compat)
+        - dict with "raw_bytes": Raw PCM bytes (requires "dtype", optional "sampling_rate")
+
+        For raw PCM bytes, use:
+            {"raw_bytes": pcm_bytes, "dtype": "int16", "sampling_rate": 16000}
+        """
+        from transformers.pipelines.audio_utils import ffmpeg_read
+
+        if isinstance(inputs, dict):
+            if "array" in inputs:
+                return {
+                    "array": inputs["array"],
+                    "sampling_rate": inputs.get("sampling_rate", 16000),
+                }
+            if "raw" in inputs:
+                return {
+                    "array": inputs["raw"],
+                    "sampling_rate": inputs.get("sampling_rate", 16000),
+                }
+            if "raw_bytes" in inputs:
+                # Raw PCM bytes - convert to float32 array
+                dtype = inputs.get("dtype", "int16")
+                sample_rate = inputs.get("sampling_rate", 16000)
+                audio = np.frombuffer(inputs["raw_bytes"], dtype=dtype).astype(np.float32)
+                # Normalize based on dtype
+                if dtype == "int16":
+                    audio = audio / 32768.0
+                elif dtype == "int32":
+                    audio = audio / 2147483648.0
+                return {"array": audio, "sampling_rate": sample_rate}
+        elif isinstance(inputs, str):
+            # File path - load audio using ffmpeg (same as HF pipeline)
+            with Path(inputs).open("rb") as f:
+                audio = ffmpeg_read(f.read(), sampling_rate=16000)
+            return {"array": audio, "sampling_rate": 16000}
+        elif isinstance(inputs, bytes):
+            audio = ffmpeg_read(inputs, sampling_rate=16000)
+            return {"array": audio, "sampling_rate": 16000}
+        elif isinstance(inputs, np.ndarray):
+            return {"array": inputs, "sampling_rate": 16000}
+
+        return None
+
+    def preprocess(self, inputs, **preprocess_params):
+        """Preprocess audio inputs for the model.
+
+        Args:
+            inputs: Audio input (dict with array, file path, etc.)
+            **preprocess_params: Additional preprocessing parameters
+
+        Yields:
+            Model input dicts with input_features and attention_mask
+        """
+        # Handle dict with "array" key (from datasets)
+        if isinstance(inputs, dict) and "array" in inputs:
+            inputs = {
+                "raw": inputs["array"],
+                "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
+            }
+
+        for item in super().preprocess(inputs, **preprocess_params):
+            if "is_last" not in item:
+                item["is_last"] = True
+            yield item
+
+    def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
+        """Run model forward pass to generate transcription.
+
+        Args:
+            model_inputs: Dict with input_features and attention_mask
+            **generate_kwargs: Generation parameters
+
+        Returns:
+            Dict with generated token IDs
+        """
+        # Extract audio features and is_last flag
+        is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True
+
+        input_features = model_inputs["input_features"].to(self.model.device)
+        audio_attention_mask = model_inputs["attention_mask"].to(self.model.device)
+
+        generated_ids = self.model.generate(
+            input_features=input_features,
+            audio_attention_mask=audio_attention_mask,
+            **generate_kwargs,
+        )
+
+        return {"tokens": generated_ids, "is_last": is_last}
+
+    def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
+        """Convert model output tokens to text.
+
+        Args:
+            model_outputs: Dict with 'tokens' key containing generated IDs
+            **kwargs: Additional postprocessing parameters
+
+        Returns:
+            Dict with 'text' key containing transcription
+        """
+        # Handle list of outputs (from chunking)
+        if isinstance(model_outputs, list):
+            model_outputs = model_outputs[0] if model_outputs else {}
+
+        tokens = model_outputs.get("tokens")
+        if tokens is None:
+            return super().postprocess(model_outputs, **kwargs)
+
+        if torch.is_tensor(tokens):
+            tokens = tokens.cpu()
+            if tokens.dim() > 1:
+                tokens = tokens[0]
+
+        # Filter out eos tokens that the tokenizer doesn't recognize as special
+        # (generation_config.eos_token_id may differ from tokenizer.eos_token_id)
+        if hasattr(self, "model") and hasattr(self.model, "generation_config"):
+            eos_ids = self.model.generation_config.eos_token_id
+            if eos_ids is not None:
+                eos_set = set(eos_ids) if isinstance(eos_ids, list) else {eos_ids}
+                tokens = [t for t in tokens.tolist() if t not in eos_set]
+
+        text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
+        # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
+        text = strip_thinking(text)
+        # Truncate repetitions at end of text
+        text = _truncate_repetitions(text)
+        return {"text": text}
+
+
+def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
+    """Truncate repeated words/phrases/characters at end of text.
+
+    Detects patterns like:
+    - Repeated words: "the the the the" -> "the"
+    - Repeated phrases: "i am sorry i am sorry i am sorry" -> "i am sorry"
+    - Repeated characters: "444444" -> "4"
+
+    Args:
+        text: Input text to process
+        min_repeats: Minimum repetitions to trigger truncation (default 3)
+
+    Returns:
+        Text with trailing repetitions removed
+    """
+    if not text:
+        return text
+
+    # 1. Truncate repeated characters at end (e.g., "444444" -> "4")
+    char_pattern = re.compile(r"(.)\1{" + str(min_repeats - 1) + r",}$")
+    text = char_pattern.sub(r"\1", text)
+
+    # 2. Truncate repeated words at end (e.g., "the the the" -> "the")
+    word_pattern = re.compile(
+        r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$", re.IGNORECASE
+    )
+    while word_pattern.search(text):
+        text = word_pattern.sub(r"\1", text)
+
+    # 3. Truncate repeated phrases (2-20 words) at end
+    # e.g., "i am sorry i am sorry i am sorry" -> "i am sorry"
+    words = text.split()
+    if len(words) >= min_repeats * 2:
+        # Try phrase lengths from 2 to 20 words
+        for phrase_len in range(2, min(21, len(words) // min_repeats + 1)):
+            # Check if the last phrase_len words repeat
+            phrase = " ".join(words[-phrase_len:])
+            # Build pattern to match repeated phrases at end
+            phrase_escaped = re.escape(phrase)
+            phrase_pattern = re.compile(
+                r"(^|.*?\s)("
+                + phrase_escaped
+                + r")(?:\s+"
+                + phrase_escaped
+                + r"){"
+                + str(min_repeats - 1)
+                + r",}\s*$",
+                re.IGNORECASE,
+            )
+            match = phrase_pattern.match(text)
+            if match:
+                # Keep prefix + one instance of the phrase
+                text = (match.group(1) + match.group(2)).strip()
+                words = text.split()
+                break
+
+    return text
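The `_truncate_repetitions` helper above can be exercised in isolation. Below is a minimal, runnable sketch of its character and word passes (the phrase pass is omitted for brevity); the standalone function name is illustrative, not part of the repo.

```python
import re

def truncate_trailing_repeats(text: str, min_repeats: int = 3) -> str:
    """Collapse trailing character and word repetitions (sketch of the two simpler passes)."""
    # Repeated characters at the end: "code 444444" -> "code 4"
    char_pattern = re.compile(r"(.)\1{" + str(min_repeats - 1) + r",}$")
    text = char_pattern.sub(r"\1", text)
    # Repeated words at the end: "the the the" -> "the"
    word_pattern = re.compile(
        r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$", re.IGNORECASE
    )
    while word_pattern.search(text):
        text = word_pattern.sub(r"\1", text)
    return text

print(truncate_trailing_repeats("the cat sat the the the"))  # the cat sat the
```

Note that the character pass runs before the word pass, so a digit run like "444444" collapses to a single character before word-level matching is attempted.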
asr_processing.py ADDED
@@ -0,0 +1,133 @@
+from typing import Optional, Union
+
+import torch
+import transformers
+from transformers import ProcessorMixin
+
+try:
+    from .asr_config import ASRConfig
+except ImportError:
+    from asr_config import ASRConfig  # type: ignore[no-redef]
+
+
+class ASRProcessor(ProcessorMixin):
+    """Processor for Whisper-based ASR models."""
+
+    attributes = ["feature_extractor", "tokenizer"]
+    feature_extractor_class = "AutoFeatureExtractor"
+    tokenizer_class = "AutoTokenizer"
+    AUDIO_TOKEN = "<audio>"
+    TRANSCRIBE_PROMPT = ""
+    # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
+    DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]
+
+    def __init__(
+        self,
+        feature_extractor,
+        tokenizer,
+        projector=None,
+        encoder_conv_layers: Optional[list] = None,
+    ):
+        """Initialize the ASR processor.
+
+        Args:
+            feature_extractor: Audio feature extractor (WhisperFeatureExtractor)
+            tokenizer: Text tokenizer for the language model
+            projector: Audio projector module (for computing output lengths)
+            encoder_conv_layers: Conv layer specs [(pad, kernel, stride), ...]
+        """
+        self.feature_extractor = feature_extractor
+        self.tokenizer = tokenizer
+        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
+        self.projector = projector
+        self.encoder_conv_layers = encoder_conv_layers or self.DEFAULT_ENCODER_CONV_LAYERS
+
+    def _compute_encoder_output_length(self, mel_length: int) -> int:
+        """Compute encoder output length using conv layer formulas."""
+        length = mel_length
+        for padding, kernel_size, stride in self.encoder_conv_layers:
+            length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
+        return length
+
+    def __call__(
+        self,
+        audio: Optional[Union[list, "torch.Tensor"]] = None,
+        text: Optional[str] = None,
+        system_prompt: Optional[str] = None,
+        return_tensors: str = "pt",
+        **kwargs,
+    ) -> dict:
+        """Process audio and text inputs for inference.
+
+        Args:
+            audio: Raw audio waveform(s)
+            text: Target transcription (optional, for training - but use DataCollator instead)
+            system_prompt: Optional system prompt
+            return_tensors: Return format ("pt" for PyTorch)
+
+        Returns:
+            Dict with input_features, input_ids, attention_mask
+        """
+        result = {}
+
+        # Process audio
+        if audio is not None:
+            audio_inputs = self.feature_extractor(
+                audio,
+                sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
+                return_attention_mask=True,
+                return_tensors=return_tensors,
+                **kwargs,
+            )
+            result["input_features"] = audio_inputs["input_features"]
+            result["audio_attention_mask"] = audio_inputs["attention_mask"]
+
+            # Use actual audio length (from attention mask) for token count
+            real_mel_len = int(audio_inputs["attention_mask"].sum(dim=-1).max().item())
+            encoder_output_len = self._compute_encoder_output_length(real_mel_len)
+            num_audio_tokens = self.projector.get_output_length(encoder_output_len)
+        else:
+            num_audio_tokens = 0
+
+        # Build prompt with audio token placeholders (instruction-free)
+        if num_audio_tokens > 0:
+            user_content = self.AUDIO_TOKEN * num_audio_tokens
+            if self.TRANSCRIBE_PROMPT:
+                user_content += " " + self.TRANSCRIBE_PROMPT
+        else:
+            user_content = self.TRANSCRIBE_PROMPT or ""
+
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": user_content})
+        if text is not None:
+            messages.append({"role": "assistant", "content": text})
+
+        # Tokenize
+        tokenized = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=(text is None),
+            return_tensors=return_tensors,
+            enable_thinking=False,  # Disable Qwen3 thinking mode for ASR
+        )
+
+        # Handle both tensor and BatchEncoding returns
+        if isinstance(tokenized, torch.Tensor):
+            input_ids = tokenized
+        else:
+            # BatchEncoding or dict-like object
+            input_ids = tokenized.get("input_ids", tokenized.input_ids)
+
+        if input_ids.dim() == 1:
+            input_ids = input_ids.unsqueeze(0)
+
+        result["input_ids"] = input_ids
+        result["attention_mask"] = torch.ones_like(input_ids)
+
+        return result
+
+
+ASRProcessor.register_for_auto_class()
+transformers.AutoProcessor.register(ASRConfig, ASRProcessor)
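The conv-length arithmetic in `_compute_encoder_output_length` is worth sanity-checking: with the default `[(1, 3, 1), (1, 3, 2)]` layers, a 30-second Whisper mel spectrogram (3000 frames) maps to 1500 encoder positions, which matches the encoder's `max_position_embeddings` of 1500. A minimal standalone sketch (the function name is illustrative):

```python
def conv_out_len(length: int, layers=((1, 3, 1), (1, 3, 2))) -> int:
    """Apply the standard 1-D conv output-length formula per (padding, kernel, stride) layer."""
    for padding, kernel_size, stride in layers:
        length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
    return length

print(conv_out_len(3000))  # 1500
```

The first layer (stride 1) preserves the length, and the second (stride 2) halves it, so the token count per utterance is roughly `mel_frames / 2` before the projector's own downsampling.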
audio_head.py ADDED
@@ -0,0 +1,357 @@
+"""Audio head for speech-to-speech using a trainable AR decoder + NeuCodec.
+
+Generates audio from text tokens via a trainable LlamaModel decoder:
+    Text tokens -> Embedding -> LlamaModel -> head -> NeuCodec FSQ codes -> audio
+
+NeuCodec uses a single FSQ codebook (levels=[4]*8, vocab=65536) at 50 tokens/sec,
+outputting 24kHz audio. No multi-codebook handling needed.
+
+Training: S2SDataCollator prepares codec_input_ids/codec_labels (both 2D: [batch, seq_len]).
+AudioHead predicts FSQ codes via a single head with teacher forcing.
+
+Inference: Autoregressive generation with KV cache, feeding back predicted codes.
+"""
+
+import logging
+from dataclasses import dataclass
+from typing import Iterator, Optional
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F  # noqa: N812
+from transformers import PretrainedConfig, PreTrainedModel
+from transformers.modeling_outputs import ModelOutput
+
+logger = logging.getLogger(__name__)
+
+# NeuCodec FSQ constants (levels=[4]*8, 1 quantizer -> 4^8 = 65536 codes)
+NEUCODEC_VOCAB_SIZE = 65536
+NEUCODEC_SAMPLE_RATE = 24000
+
+# Special tokens (above vocab range)
+BOS_TOKEN = NEUCODEC_VOCAB_SIZE
+EOS_TOKEN = NEUCODEC_VOCAB_SIZE + 1
+PAD_TOKEN = NEUCODEC_VOCAB_SIZE + 2
+TOTAL_VOCAB = NEUCODEC_VOCAB_SIZE + 3  # 65539
+
+
+class AudioHeadConfig(PretrainedConfig):
+    """Configuration class for the AudioHead model."""
+
+    model_type = "audio_head"
+
+    def __init__(
+        self,
+        decoder_dim: int = 512,
+        decoder_layers: int = 6,
+        decoder_heads: int = 8,
+        text_vocab_size: int = 32000,
+        max_audio_tokens: int = 500,
+        neucodec_model_id: str = "neuphonic/neucodec",
+        temperature: float = 1.0,
+        top_k: int = 50,
+        **kwargs,
+    ):
+        self.decoder_dim = decoder_dim
+        self.decoder_layers = decoder_layers
+        self.decoder_heads = decoder_heads
+        self.text_vocab_size = text_vocab_size
+        self.max_audio_tokens = max_audio_tokens
+        self.neucodec_model_id = neucodec_model_id
+        self.temperature = temperature
+        self.top_k = top_k
+        super().__init__(**kwargs)
+
+
+@dataclass
+class AudioHeadOutput(ModelOutput):
+    """Output of AudioHead forward pass.
+
+    Attributes:
+        loss: Cross-entropy loss when codec_labels are provided.
+        codes: Generated codec codes when in inference mode [batch, gen_len].
+    """
+
+    loss: Optional[torch.Tensor] = None
+    codes: Optional[torch.Tensor] = None
+
+
+class AudioHead(PreTrainedModel):
+    """Trainable AR decoder that predicts NeuCodec FSQ codes.
+
+    NeuCodec uses a single FSQ codebook (4^8 = 65536 codes) at 50 tokens/sec.
+    No multi-codebook handling needed — just a flat sequence of codes.
+    """
+
+    config_class = AudioHeadConfig
+
+    def __init__(self, config: AudioHeadConfig):
+        super().__init__(config)
+        self.text_vocab_size = config.text_vocab_size
+        self.decoder_dim = config.decoder_dim
+        self.max_tokens = config.max_audio_tokens
+        self.vocab_size = NEUCODEC_VOCAB_SIZE
+
+        # Embed text tokens to decoder dim
+        self.text_embedding = nn.Embedding(config.text_vocab_size, config.decoder_dim)
+
+        # Codec token embedding (FSQ codes + special tokens)
+        self.token_embedding = nn.Embedding(TOTAL_VOCAB, config.decoder_dim)
+
+        # Small LlamaModel as decoder backbone (from config, NOT pretrained)
+        from transformers import LlamaConfig, LlamaModel
+
+        llama_config = LlamaConfig(
+            hidden_size=config.decoder_dim,
+            intermediate_size=config.decoder_dim * 4,
+            num_hidden_layers=config.decoder_layers,
+            num_attention_heads=config.decoder_heads,
+            vocab_size=TOTAL_VOCAB,
+            max_position_embeddings=4096,
+        )
+        self.decoder = LlamaModel(llama_config)
+        # We handle embeddings ourselves, remove the unused one to save memory
+        self.decoder.embed_tokens = None
+
+        # Sampling parameters for inference
+        self.temperature = config.temperature
+        self.top_k = config.top_k
+
+        # NeuCodec model (loaded lazily, frozen, inference only)
+        self.neucodec_model = None
+
+        # Initialize weights
+        self.post_init()
+
+    def forward(
+        self,
+        text_token_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        codec_labels: Optional[torch.Tensor] = None,
+        codec_input_ids: Optional[torch.Tensor] = None,
+        codec_attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> AudioHeadOutput:
+        """Forward pass for training or inference.
+
+        Args:
+            text_token_ids: Text token IDs [batch, seq_len]
+            attention_mask: Text attention mask [batch, seq_len] (1=real, 0=padding)
+            codec_labels: Target codes [batch, audio_len] (-100 for ignore)
+            codec_input_ids: Teacher-forced input [batch, audio_len]
+            codec_attention_mask: Codec attention mask [batch, audio_len]
+
+        Returns:
+            AudioHeadOutput with loss (training) or codes (inference).
+        """
+        # Embed text tokens (clamp to valid range)
+        if (text_token_ids >= self.text_vocab_size).any() or (text_token_ids < 0).any():
+            logger.warning(
+                "text_token_ids out of range [0, %d): min=%d max=%d. Clamping.",
+                self.text_vocab_size, text_token_ids.min().item(), text_token_ids.max().item(),
+            )
+            text_token_ids = text_token_ids.clamp(0, self.text_vocab_size - 1)
+        prefix = self.text_embedding(text_token_ids)  # [batch, text_len, decoder_dim]
+        batch_size, text_len, _ = prefix.shape
+
+        if codec_labels is not None:
+            # Teacher forcing: codec_input_ids is [batch, audio_len]
+            cb_input = codec_input_ids
+            if (cb_input >= TOTAL_VOCAB).any() or (cb_input < 0).any():
+                logger.warning(
+                    "codec_input_ids out of range [0, %d): min=%d max=%d. Clamping.",
+                    TOTAL_VOCAB, cb_input.min().item(), cb_input.max().item(),
+                )
+                cb_input = cb_input.clamp(0, TOTAL_VOCAB - 1)
+            token_emb = self.token_embedding(cb_input)  # [batch, audio_len, dim]
+
+            audio_len = token_emb.shape[1]
+
+            # Concatenate prefix + codec tokens
+            hidden = torch.cat([prefix, token_emb], dim=1)  # [batch, text+audio, dim]
+
+            # Build combined attention mask
+            if attention_mask is not None:
+                prefix_mask = attention_mask
+            else:
+                prefix_mask = torch.ones(
+                    batch_size, text_len, device=hidden.device, dtype=torch.long
+                )
+
+            if codec_attention_mask is not None:
+                audio_mask = codec_attention_mask
+            else:
+                audio_mask = torch.ones(
+                    batch_size, audio_len, device=hidden.device, dtype=torch.long
+                )
+
+            combined_mask = torch.cat([prefix_mask, audio_mask], dim=1)
+
+            # Build causal mask for codec positions while prefix attends bidirectionally
+            total_len = text_len + audio_len
+            causal_mask = torch.triu(
+                torch.full((total_len, total_len), float("-inf"), device=hidden.device),
+                diagonal=1,
+            )
+            causal_mask[:text_len, :text_len] = 0.0
+            causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1)
+
+            padding_mask = (1 - combined_mask).bool()
+            padding_mask_expanded = padding_mask.unsqueeze(1).unsqueeze(2).expand_as(causal_mask)
+            causal_mask = causal_mask.masked_fill(padding_mask_expanded, float("-inf"))
+
+            position_ids = (
+                torch.arange(total_len, device=hidden.device).unsqueeze(0).expand(batch_size, -1)
+            )
+
+            # Run through LlamaModel
+            outputs = self.decoder(
+                inputs_embeds=hidden,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+            )
+
+            # Extract audio positions only
+            audio_hidden = outputs.last_hidden_state[:, text_len:]  # [batch, audio_len, dim]
+
+            # Predict codes and compute loss
+            labels = codec_labels.clone()  # [batch, audio_len]
+            valid_mask = labels != -100
+            labels[valid_mask] = labels[valid_mask].clamp(0, TOTAL_VOCAB - 1)
+
+            logits = F.linear(audio_hidden, self.token_embedding.weight)  # [batch, audio_len, total_vocab]
+            loss = F.cross_entropy(
+                logits.reshape(-1, TOTAL_VOCAB),
+                labels.reshape(-1),
+                ignore_index=-100,
+            )
+            return AudioHeadOutput(loss=loss)
+
+        # Inference: autoregressive generation
+        codes = self._generate(prefix, attention_mask)
+        return AudioHeadOutput(codes=codes)
+
+    def _generate(
+        self, prefix: torch.Tensor, prefix_mask: Optional[torch.Tensor]
+    ) -> torch.Tensor:
+        """AR generation: predict codes one timestep at a time with KV cache."""
+        batch_size, text_len, _ = prefix.shape
+        device = prefix.device
+
+        all_codes = []
+
+        # Build initial input: prefix + BOS embedding
+        bos_token = torch.full((batch_size, 1), BOS_TOKEN, dtype=torch.long, device=device)
+        bos_emb = self.token_embedding(bos_token)  # [batch, 1, dim]
+        hidden = torch.cat([prefix, bos_emb], dim=1)  # [batch, text_len+1, dim]
+
+        # Position IDs for initial forward
+        position_ids = torch.arange(text_len + 1, device=device).unsqueeze(0).expand(batch_size, -1)
+
+        # Initial forward pass (no KV cache yet)
+        outputs = self.decoder(
+            inputs_embeds=hidden,
+            position_ids=position_ids,
+            use_cache=True,
+        )
+        past_key_values = outputs.past_key_values
+        last_hidden = outputs.last_hidden_state[:, -1:]  # [batch, 1, dim]
+
+        for step in range(self.max_tokens):
+            # Predict code token
+            logits = F.linear(last_hidden.squeeze(1), self.token_embedding.weight)  # [batch, vocab]
+
+            # Apply temperature and top-k sampling
+            if self.temperature > 0 and self.top_k > 0:
+                logits = logits / self.temperature
+                # Mask logits below the top-k threshold
+                top_k_vals, _ = logits.topk(self.top_k, dim=-1)
+                logits[logits < top_k_vals[:, -1:]] = float("-inf")
+                probs = F.softmax(logits, dim=-1)
+                token = torch.multinomial(probs, num_samples=1).squeeze(-1)  # [batch]
+            else:
+                token = logits.argmax(dim=-1)  # [batch]
+
+            # Check for EOS
+            if (token == EOS_TOKEN).all():
+                break
+
+            all_codes.append(token)
+
+            # Feed back prediction for next step
+            next_emb = self.token_embedding(token.unsqueeze(1))  # [batch, 1, dim]
+
+            # Prompt + BOS occupy positions 0..text_len, so the token fed back
+            # after step `step` sits at position text_len + 1 + step
+            next_pos = torch.full(
+                (batch_size, 1), text_len + 1 + step, dtype=torch.long, device=device
+            )
+
+            # Forward with KV cache
+            outputs = self.decoder(
+                inputs_embeds=next_emb,
+                position_ids=next_pos,
+                past_key_values=past_key_values,
+                use_cache=True,
+            )
+            past_key_values = outputs.past_key_values
+            last_hidden = outputs.last_hidden_state  # [batch, 1, dim]
+
+        if all_codes:
+            # [batch, gen_len]
+            codes = torch.stack(all_codes, dim=1)
+        else:
+            codes = torch.empty(batch_size, 0, dtype=torch.long, device=device)
+
+        return codes
+
+    def _load_neucodec(self):
+        """Load frozen NeuCodec model for audio decoding."""
+        from neucodec import NeuCodec
+
+        self.neucodec_model = NeuCodec.from_pretrained(self.config.neucodec_model_id)
+        self.neucodec_model.eval()
+        self.neucodec_model.requires_grad_(False)
+        logger.info("Loaded frozen NeuCodec model for audio decoding")
+
+    def decode_to_audio(self, codes: torch.Tensor) -> list[torch.Tensor]:
+        """Decode NeuCodec FSQ tokens to audio waveforms.
+
+        Args:
+            codes: Codec tokens [batch, seq_len]
+
+        Returns:
+            List of audio waveform tensors (one per batch item)
+        """
+        if self.neucodec_model is None:
+            self._load_neucodec()
+        assert self.neucodec_model is not None
+
+        # NeuCodec decode_code expects [batch, 1, seq_len]
+        codes_3d = codes.unsqueeze(1).to(self.neucodec_model.device)
+
+        with torch.no_grad():
+            audio_values = self.neucodec_model.decode_code(codes_3d)  # [batch, 1, samples]
+
+        return [audio_values[i, 0] for i in range(audio_values.shape[0])]
+
+    def generate_streaming(
+        self,
+        text_token_ids: torch.Tensor,
+        chunk_samples: int = 24000,
+    ) -> Iterator[torch.Tensor]:
+        """Generate audio and yield waveform chunks for streaming playback.
+
+        Args:
+            text_token_ids: Text token IDs [batch, seq_len]
+            chunk_samples: Audio samples per chunk (default 1s at 24kHz)
+
+        Yields:
+            Audio waveform chunks [samples]
+        """
+        output = self(text_token_ids)
+        codes = output.codes
+        audios = self.decode_to_audio(codes)
+
+        for audio in audios:
+            for start in range(0, audio.shape[-1], chunk_samples):
+                end = min(start + chunk_samples, audio.shape[-1])
+                yield audio[..., start:end]
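The sampling step inside `_generate` (temperature scaling, then top-k filtering, then a categorical draw) can be illustrated without torch. The pure-Python sketch below mirrors that logic for a single distribution; the helper name and the stdlib-only implementation are illustrative, not the module's actual code path:

```python
import math
import random

def sample_top_k(logits, k, temperature=1.0, rng=None):
    """Temperature + top-k sampling over a list of logits (pure-Python sketch)."""
    rng = rng or random.Random(0)
    scaled = [x / temperature for x in logits]
    # Keep only values at or above the k-th largest logit
    threshold = sorted(scaled, reverse=True)[k - 1]
    filtered = [s if s >= threshold else float("-inf") for s in scaled]
    # Softmax with max-subtraction for numerical stability
    m = max(filtered)
    weights = [math.exp(s - m) for s in filtered]
    # Draw one index proportionally to the remaining weights
    r = rng.random() * sum(weights)
    acc = 0.0
    for i, w in enumerate(weights):
        acc += w
        if r < acc:
            return i
    return len(weights) - 1
```

With `k=1` this degenerates to greedy argmax, matching the `else` branch of the real loop when sampling is disabled.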
chat_template.jinja ADDED
@@ -0,0 +1,94 @@
+{# ───── defaults ───── #}
+{%- if enable_thinking is not defined -%}
+  {%- set enable_thinking = true -%}
+{%- endif -%}
+{# Defaults so the checks below are safe when no system message is supplied #}
+{%- set system_message = "" -%}
+{%- set custom_instructions = "" -%}
+
+{# ───── reasoning mode ───── #}
+{%- if enable_thinking -%}
+  {%- set reasoning_mode = "/think" -%}
+{%- else -%}
+  {%- set reasoning_mode = "/no_think" -%}
+{%- endif -%}
+
+{# ───── header (system message) ───── #}
+{{- "<|im_start|>system\n" -}}
+
+{%- if messages[0].role == "system" -%}
+  {%- set system_message = messages[0].content -%}
+  {%- if "/no_think" in system_message -%}
+    {%- set reasoning_mode = "/no_think" -%}
+  {%- elif "/think" in system_message -%}
+    {%- set reasoning_mode = "/think" -%}
+  {%- endif -%}
+  {%- set custom_instructions = system_message.replace("/no_think", "").replace("/think", "").rstrip() -%}
+{%- endif -%}
+
+{%- if "/system_override" in system_message -%}
+  {{- custom_instructions.replace("/system_override", "").rstrip() -}}
+  {{- "<|im_end|>\n" -}}
+{%- else -%}
+  {{- "## Metadata\n\n" -}}
+  {{- "Knowledge Cutoff Date: June 2025\n" -}}
+  {%- set today = strftime_now("%d %B %Y") -%}
+  {{- "Today Date: " ~ today ~ "\n" -}}
+  {{- "Reasoning Mode: " + reasoning_mode + "\n\n" -}}
+
+  {{- "## Custom Instructions\n\n" -}}
+  {%- if custom_instructions -%}
+    {{- custom_instructions + "\n\n" -}}
+  {%- elif reasoning_mode == "/think" -%}
+    {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracking, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> Thought section </think> Solution section. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion.\n\n" -}}
+  {%- else -%}
+    {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}}
+  {%- endif -%}
+
+  {%- if xml_tools or python_tools or tools -%}
+    {{- "### Tools\n\n" -}}
+    {%- if xml_tools or tools -%}
+      {%- if tools -%}
+        {%- set xml_tools = tools -%}
+      {%- endif -%}
+      {%- set ns = namespace(xml_tool_string="You may call one or more functions to assist with the user query.\nYou are provided with function signatures within <tools></tools> XML tags:\n\n<tools>\n") -%}
+      {%- for tool in xml_tools[:] -%} {# The slicing makes sure that xml_tools is a list #}
+        {%- set ns.xml_tool_string = ns.xml_tool_string ~ (tool | string) ~ "\n" -%}
+      {%- endfor -%}
+      {%- set xml_tool_string = ns.xml_tool_string + "</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>" -%}
+      {{- xml_tool_string -}}
+    {%- endif -%}
+    {%- if python_tools -%}
+      {%- set ns = namespace(python_tool_string="When you send a message containing Python code between '<code>' and '</code>' tags, it will be executed in a stateful Jupyter notebook environment, and you will then be given the output to continue reasoning in an agentic loop.\n\nYou can use the following tools in your python code like regular functions:\n<tools>\n") -%}
+      {%- for tool in python_tools[:] -%} {# The slicing makes sure that python_tools is a list #}
+        {%- set ns.python_tool_string = ns.python_tool_string ~ (tool | string) ~ "\n" -%}
+      {%- endfor -%}
+      {%- set python_tool_string = ns.python_tool_string + "</tools>\n\nThe state persists between code executions: so variables that you define in one step are still available thereafter." -%}
+      {{- python_tool_string -}}
+    {%- endif -%}
+    {{- "\n\n" -}}
+  {%- endif -%}
+  {{- "<|im_end|>\n" -}}
+{%- endif -%}
+{# ───── main loop ───── #}
+{%- for message in messages -%}
+  {%- set content = message.content if message.content is string else "" -%}
+  {%- if message.role == "user" -%}
+    {{ "<|im_start|>" + message.role + "\n" + content + "<|im_end|>\n" }}
+  {%- elif message.role == "assistant" -%}
+    {% generation %}
+    {%- if reasoning_mode == "/think" -%}
+      {{ "<|im_start|>assistant\n" + content.lstrip("\n") + "<|im_end|>\n" }}
+    {%- else -%}
+      {{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" + content.lstrip("\n") + "<|im_end|>\n" }}
+    {%- endif -%}
+    {% endgeneration %}
+  {%- elif message.role == "tool" -%}
+    {{ "<|im_start|>" + "user\n" + content + "<|im_end|>\n" }}
+  {%- endif -%}
+{%- endfor -%}
+{# ───── generation prompt ───── #}
+{%- if add_generation_prompt -%}
+  {%- if reasoning_mode == "/think" -%}
+    {{ "<|im_start|>assistant\n" }}
+  {%- else -%}
+    {{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" }}
+  {%- endif -%}
+{%- endif -%}
config.json ADDED
@@ -0,0 +1,409 @@
+ {
+   "architectures": [
+     "ASRModel"
+   ],
+   "attn_implementation": null,
+   "audio_config": {
+     "_name_or_path": "zai-org/GLM-ASR-Nano-2512",
+     "architectures": [
+       "GlmAsrForConditionalGeneration"
+     ],
+     "audio_config": {
+       "_name_or_path": "",
+       "architectures": null,
+       "attention_dropout": 0.0,
+       "chunk_size_feed_forward": 0,
+       "dtype": null,
+       "head_dim": 64,
+       "hidden_act": "gelu",
+       "hidden_size": 1280,
+       "id2label": {
+         "0": "LABEL_0",
+         "1": "LABEL_1"
+       },
+       "initializer_range": 0.02,
+       "intermediate_size": 5120,
+       "is_encoder_decoder": false,
+       "label2id": {
+         "LABEL_0": 0,
+         "LABEL_1": 1
+       },
+       "max_position_embeddings": 1500,
+       "model_type": "glmasr_encoder",
+       "num_attention_heads": 20,
+       "num_hidden_layers": 32,
+       "num_key_value_heads": 20,
+       "num_mel_bins": 128,
+       "output_attentions": false,
+       "output_hidden_states": false,
+       "partial_rotary_factor": 0.5,
+       "problem_type": null,
+       "return_dict": true,
+       "rope_parameters": {
+         "partial_rotary_factor": 0.5,
+         "rope_theta": 10000.0,
+         "rope_type": "default"
+       }
+     },
+     "audio_token_id": 59260,
+     "dtype": "bfloat16",
+     "hidden_size": 2048,
+     "model_type": "glmasr",
+     "num_mel_bins": 128,
+     "projector_hidden_act": "gelu",
+     "text_config": {
+       "_name_or_path": "",
+       "architectures": null,
+       "attention_bias": false,
+       "attention_dropout": 0.0,
+       "bos_token_id": 1,
+       "chunk_size_feed_forward": 0,
+       "dtype": null,
+       "eos_token_id": [
+         59246,
+         59253,
+         59255
+       ],
+       "head_dim": 128,
+       "hidden_act": "silu",
+       "hidden_size": 2048,
+       "id2label": {
+         "0": "LABEL_0",
+         "1": "LABEL_1"
+       },
+       "initializer_range": 0.02,
+       "intermediate_size": 6144,
+       "is_encoder_decoder": false,
+       "label2id": {
+         "LABEL_0": 0,
+         "LABEL_1": 1
+       },
+       "max_position_embeddings": 8192,
+       "mlp_bias": false,
+       "model_type": "llama",
+       "num_attention_heads": 16,
+       "num_hidden_layers": 28,
+       "num_key_value_heads": 4,
+       "output_attentions": false,
+       "output_hidden_states": false,
+       "pad_token_id": null,
+       "pretraining_tp": 1,
+       "problem_type": null,
+       "return_dict": true,
+       "rms_norm_eps": 1e-05,
+       "rope_parameters": {
+         "rope_theta": 10000.0,
+         "rope_type": "default"
+       },
+       "tie_word_embeddings": false,
+       "use_cache": true,
+       "vocab_size": 59264
+     },
+     "vocab_size": 59264
+   },
+   "audio_model_id": "zai-org/GLM-ASR-Nano-2512",
+   "audio_sample_rate": 16000,
+   "auto_map": {
+     "AutoConfig": "asr_config.ASRConfig",
+     "AutoModel": "asr_modeling.ASRModel",
+     "AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
+     "AutoProcessor": "asr_processing.ASRProcessor"
+   },
+   "custom_pipelines": {
+     "automatic-speech-recognition": {
+       "impl": "asr_pipeline.ASRPipeline",
+       "pt": [
+         "AutoModelForSpeechSeq2Seq"
+       ],
+       "tf": [],
+       "type": "audio"
+     }
+   },
+   "decoder_dim": 256,
+   "decoder_heads": 4,
+   "decoder_layers": 4,
+   "do_sample": false,
+   "downsample_rate": 5,
+   "dtype": "bfloat16",
+   "enable_thinking": false,
+   "encoder": {
+     "_name_or_path": "zai-org/GLM-ASR-Nano-2512",
+     "architectures": [
+       "GlmAsrForConditionalGeneration"
+     ],
+     "audio_config": {
+       "_name_or_path": "",
+       "architectures": null,
+       "attention_dropout": 0.0,
+       "chunk_size_feed_forward": 0,
+       "dtype": null,
+       "head_dim": 64,
+       "hidden_act": "gelu",
+       "hidden_size": 1280,
+       "id2label": {
+         "0": "LABEL_0",
+         "1": "LABEL_1"
+       },
+       "initializer_range": 0.02,
+       "intermediate_size": 5120,
+       "is_encoder_decoder": false,
+       "label2id": {
+         "LABEL_0": 0,
+         "LABEL_1": 1
+       },
+       "max_position_embeddings": 1500,
+       "model_type": "glmasr_encoder",
+       "num_attention_heads": 20,
+       "num_hidden_layers": 32,
+       "num_key_value_heads": 20,
+       "num_mel_bins": 128,
+       "output_attentions": false,
+       "output_hidden_states": false,
+       "partial_rotary_factor": 0.5,
+       "problem_type": null,
+       "return_dict": true,
+       "rope_parameters": {
+         "partial_rotary_factor": 0.5,
+         "rope_theta": 10000.0,
+         "rope_type": "default"
+       }
+     },
+     "audio_token_id": 59260,
+     "dtype": "bfloat16",
+     "hidden_size": 2048,
+     "model_type": "glmasr",
+     "num_mel_bins": 128,
+     "projector_hidden_act": "gelu",
+     "text_config": {
+       "_name_or_path": "",
+       "architectures": null,
+       "attention_bias": false,
+       "attention_dropout": 0.0,
+       "bos_token_id": 1,
+       "chunk_size_feed_forward": 0,
+       "dtype": null,
+       "eos_token_id": [
+         59246,
+         59253,
+         59255
+       ],
+       "head_dim": 128,
+       "hidden_act": "silu",
+       "hidden_size": 2048,
+       "id2label": {
+         "0": "LABEL_0",
+         "1": "LABEL_1"
+       },
+       "initializer_range": 0.02,
+       "intermediate_size": 6144,
+       "is_encoder_decoder": false,
+       "label2id": {
+         "LABEL_0": 0,
+         "LABEL_1": 1
+       },
+       "max_position_embeddings": 8192,
+       "mlp_bias": false,
+       "model_type": "llama",
+       "num_attention_heads": 16,
+       "num_hidden_layers": 28,
+       "num_key_value_heads": 4,
+       "output_attentions": false,
+       "output_hidden_states": false,
+       "pad_token_id": null,
+       "pretraining_tp": 1,
+       "problem_type": null,
+       "return_dict": true,
+       "rms_norm_eps": 1e-05,
+       "rope_parameters": {
+         "rope_theta": 10000.0,
+         "rope_type": "default"
+       },
+       "tie_word_embeddings": false,
+       "use_cache": true,
+       "vocab_size": 59264
+     },
+     "vocab_size": 59264
+   },
+   "encoder_conv_layers": [
+     [
+       1,
+       3,
+       1
+     ],
+     [
+       1,
+       3,
+       2
+     ]
+   ],
+   "encoder_dim": 1280,
+   "freeze_audio_head": false,
+   "freeze_projector": false,
+   "freq_mask_length": 27,
+   "label_smoothing": 0.0,
+   "length_penalty": 1.0,
+   "llm_dim": 2048,
+   "lora_alpha": 32,
+   "lora_dropout": 0.0,
+   "lora_rank": 8,
+   "lora_target_modules": [
+     "q_proj",
+     "k_proj",
+     "v_proj",
+     "o_proj",
+     "gate_proj",
+     "up_proj",
+     "down_proj"
+   ],
+   "max_audio_tokens": 500,
+   "max_new_tokens": 128,
+   "min_new_tokens": 0,
+   "model_dtype": "bfloat16",
+   "model_type": "asr_model",
+   "neucodec_model_id": "neuphonic/neucodec",
+   "no_repeat_ngram_size": 0,
+   "num_beams": 1,
+   "num_experts": 4,
+   "num_experts_per_tok": 2,
+   "num_freq_masks": 2,
+   "num_time_masks": 2,
+   "pipeline_tag": "automatic-speech-recognition",
+   "pretrained_model_path": "mazesmazes/tiny-audio-s2s-full",
+   "projector_dropout": 0.0,
+   "projector_hidden_dim": 1024,
+   "projector_init_std": 0.02,
+   "projector_num_layers": 2,
+   "projector_pool_stride": 4,
+   "projector_type": "mlp",
+   "qformer_hidden_size": null,
+   "qformer_intermediate_size": null,
+   "qformer_num_heads": 16,
+   "qformer_num_layers": 2,
+   "qformer_window_size": 15,
+   "repetition_penalty": 1.1,
+   "router_aux_loss_coef": 0.01,
+   "system_prompt": "",
+   "temperature": 1.0,
+   "text_config": {
+     "_name_or_path": "HuggingFaceTB/SmolLM3-3B",
+     "architectures": [
+       "SmolLM3ForCausalLM"
+     ],
+     "attention_bias": false,
+     "attention_dropout": 0.0,
+     "bos_token_id": null,
+     "dtype": "bfloat16",
+     "eos_token_id": 128012,
+     "hidden_act": "silu",
+     "hidden_size": 2048,
+     "initializer_range": 0.02,
+     "intermediate_size": 11008,
+     "layer_types": [
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention",
+       "full_attention"
+     ],
+     "max_position_embeddings": 65536,
+     "max_window_layers": 28,
+     "mlp_bias": false,
+     "model_type": "smollm3",
+     "no_rope_layer_interval": 4,
+     "no_rope_layers": [
+       1,
+       1,
+       1,
+       0,
+       1,
+       1,
+       1,
+       0,
+       1,
+       1,
+       1,
+       0,
+       1,
+       1,
+       1,
+       0,
+       1,
+       1,
+       1,
+       0,
+       1,
+       1,
+       1,
+       0,
+       1,
+       1,
+       1,
+       0,
+       1,
+       1,
+       1,
+       0,
+       1,
+       1,
+       1,
+       0
+     ],
+     "num_attention_heads": 16,
+     "num_hidden_layers": 36,
+     "num_key_value_heads": 4,
+     "pad_token_id": 128004,
+     "pretraining_tp": 2,
+     "rms_norm_eps": 1e-06,
+     "rope_parameters": {
+       "rope_theta": 5000000.0,
+       "rope_type": "default"
+     },
+     "sliding_window": null,
+     "tie_word_embeddings": true,
+     "use_cache": false,
+     "use_sliding_window": false,
+     "vocab_size": 128320
+   },
+   "text_model_id": "HuggingFaceTB/SmolLM3-3B",
+   "text_vocab_size": 128257,
+   "time_mask_length": 100,
+   "top_k": 50,
+   "top_p": null,
+   "transformers_version": "5.0.0",
+   "use_audio_head": true,
+   "use_cache": false,
+   "use_lora": false,
+   "use_specaugment": true,
+   "vocab_size": 128320
+ }
diarization.py ADDED
@@ -0,0 +1,706 @@
1
+ """Speaker diarization using TEN-VAD + ECAPA-TDNN + spectral clustering.
2
+
3
+ Spectral clustering implementation adapted from FunASR/3D-Speaker:
4
+ https://github.com/alibaba-damo-academy/FunASR
5
+ MIT License (https://opensource.org/licenses/MIT)
6
+ """
7
+
8
+ import warnings
9
+
10
+ import numpy as np
11
+ import scipy
12
+ import sklearn.metrics.pairwise
13
+ import torch
14
+ from sklearn.cluster._kmeans import k_means
15
+ from sklearn.preprocessing import normalize
16
+
17
+
18
+ def _get_device() -> torch.device:
19
+ """Get best available device for inference."""
20
+ if torch.cuda.is_available():
21
+ return torch.device("cuda")
22
+ if torch.backends.mps.is_available():
23
+ return torch.device("mps")
24
+ return torch.device("cpu")
25
+
26
+
27
+ class SpectralCluster:
28
+ """Spectral clustering using unnormalized Laplacian of affinity matrix.
29
+
30
+ Adapted from FunASR/3D-Speaker and SpeechBrain implementations.
31
+ Uses eigenvalue gap to automatically determine number of speakers.
32
+ """
33
+
34
+ def __init__(self, min_num_spks: int = 1, max_num_spks: int = 15, pval: float = 0.06):
35
+ self.min_num_spks = min_num_spks
36
+ self.max_num_spks = max_num_spks
37
+ self.pval = pval
38
+
39
+ def __call__(self, embeddings: np.ndarray, oracle_num: int | None = None) -> np.ndarray:
40
+ """Run spectral clustering on embeddings.
41
+
42
+ Args:
43
+ embeddings: Speaker embeddings of shape [N, D]
44
+ oracle_num: Optional known number of speakers
45
+
46
+ Returns:
47
+ Cluster labels of shape [N]
48
+ """
49
+ # Similarity matrix computation
50
+ sim_mat = self.get_sim_mat(embeddings)
51
+
52
+ # Refining similarity matrix with pval
53
+ prunned_sim_mat = self.p_pruning(sim_mat)
54
+
55
+ # Symmetrization
56
+ sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T)
57
+
58
+ # Laplacian calculation
59
+ laplacian = self.get_laplacian(sym_prund_sim_mat)
60
+
61
+ # Get Spectral Embeddings
62
+ emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num)
63
+
64
+ # Perform clustering
65
+ return self.cluster_embs(emb, num_of_spk)
66
+
67
+ def get_sim_mat(self, embeddings: np.ndarray) -> np.ndarray:
68
+ """Compute cosine similarity matrix."""
69
+ return sklearn.metrics.pairwise.cosine_similarity(embeddings, embeddings)
70
+
71
+ def p_pruning(self, affinity: np.ndarray) -> np.ndarray:
72
+ """Prune low similarity values in affinity matrix (keep top pval fraction)."""
73
+ n = affinity.shape[0]
74
+ pval = max(self.pval, 6.0 / n)
75
+ k_keep = max(1, int(pval * n))
76
+
77
+ # Vectorized: find top-k indices per row and zero out the rest
78
+ top_k_idx = np.argpartition(affinity, -k_keep, axis=1)[:, -k_keep:]
79
+ mask = np.zeros_like(affinity, dtype=bool)
80
+ np.put_along_axis(mask, top_k_idx, True, axis=1)
81
+ affinity[~mask] = 0
82
+ return affinity
83
+
84
+ def get_laplacian(self, sim_mat: np.ndarray) -> np.ndarray:
85
+ """Compute unnormalized Laplacian matrix."""
86
+ from scipy.sparse.csgraph import laplacian
87
+
88
+ np.fill_diagonal(sim_mat, 0)
89
+ return laplacian(sim_mat, normed=False)
90
+
91
+ def get_spec_embs(
92
+ self, laplacian: np.ndarray, k_oracle: int | None = None
93
+ ) -> tuple[np.ndarray, int]:
94
+ """Extract spectral embeddings from Laplacian.
95
+
96
+ Uses the eigengap heuristic to estimate the number of clusters:
97
+ The number of clusters k is chosen where the gap between consecutive
98
+ eigenvalues is largest, indicating a transition from "cluster" eigenvalues
99
+ (near 0) to "noise" eigenvalues.
100
+ """
101
+ lambdas, eig_vecs = scipy.linalg.eigh(laplacian)
102
+
103
+ num_of_spk = k_oracle if k_oracle is not None else self._estimate_num_speakers(lambdas)
104
+
105
+ emb = eig_vecs[:, :num_of_spk]
106
+ return emb, num_of_spk
107
+
108
+ def _estimate_num_speakers(self, lambdas: np.ndarray) -> int:
109
+ """Estimate number of speakers using refined eigengap heuristic.
110
+
111
+ For spectral clustering, we look for the largest gap in eigenvalues.
112
+ The eigenvalues corresponding to clusters are close to 0, and there
113
+ should be a significant jump to the remaining eigenvalues.
114
+ """
115
+ # Consider eigenvalues from index 1 to max_num_spks (skip first, it's always ~0)
116
+ # We need gaps between positions, so look at indices 1 to max_num_spks+1
117
+ max_idx = min(self.max_num_spks + 1, len(lambdas))
118
+ relevant_lambdas = lambdas[1:max_idx] # Skip first eigenvalue
119
+
120
+ if len(relevant_lambdas) < 2:
121
+ return self.min_num_spks
122
+
123
+ # Compute absolute gaps (not ratios - ratios are unstable near 0)
124
+ gaps = np.diff(relevant_lambdas)
125
+
126
+ # Find the largest gap - the index gives us (k-1) since we skipped first
127
+ # Add 1 to convert from gap index to number of speakers
128
+ # Add 1 again because we skipped the first eigenvalue
129
+ max_gap_idx = int(np.argmax(gaps))
130
+ num_of_spk = max_gap_idx + 2 # +1 for gap->count, +1 for skipped eigenvalue
131
+
132
+ # Clamp between min and max
133
+ return max(self.min_num_spks, min(num_of_spk, self.max_num_spks))
134
+
135
+ def cluster_embs(self, emb: np.ndarray, k: int) -> np.ndarray:
136
+ """Cluster spectral embeddings using k-means."""
137
+ _, labels, _ = k_means(emb, k, n_init=10)
138
+ return labels
139
+
140
+
141
+ class SpeakerClusterer:
142
+ """Speaker clustering backend using spectral clustering with speaker merging.
143
+
144
+ Features:
145
+ - Spectral clustering with eigenvalue gap for auto speaker count detection
146
+ - P-pruning for affinity matrix refinement
147
+ - Post-clustering speaker merging by cosine similarity
148
+ """
149
+
150
+ def __init__(
151
+ self,
152
+ min_num_spks: int = 2,
153
+ max_num_spks: int = 10,
154
+ merge_thr: float = 0.90, # Moderate merging
155
+ ):
156
+ self.min_num_spks = min_num_spks
157
+ self.max_num_spks = max_num_spks
158
+ self.merge_thr = merge_thr
159
+ self._spectral_cluster: SpectralCluster | None = None
160
+
161
+ def _get_spectral_cluster(self) -> SpectralCluster:
162
+ """Lazy-load spectral clusterer."""
163
+ if self._spectral_cluster is None:
164
+ self._spectral_cluster = SpectralCluster(
165
+ min_num_spks=self.min_num_spks,
166
+ max_num_spks=self.max_num_spks,
167
+ )
168
+ return self._spectral_cluster
169
+
170
+ def __call__(self, embeddings: np.ndarray, num_speakers: int | None = None) -> np.ndarray:
171
+ """Cluster speaker embeddings and return labels.
172
+
173
+ Args:
174
+ embeddings: Speaker embeddings of shape [N, D]
175
+ num_speakers: Optional oracle number of speakers
176
+
177
+ Returns:
178
+ Cluster labels of shape [N]
179
+ """
180
+ import warnings
181
+
182
+ if len(embeddings.shape) != 2:
183
+ raise ValueError(f"Expected 2D array, got shape {embeddings.shape}")
184
+
185
+ # Handle edge cases
186
+ if embeddings.shape[0] == 0:
187
+ return np.array([], dtype=int)
188
+ if embeddings.shape[0] == 1:
189
+ return np.array([0], dtype=int)
190
+ if embeddings.shape[0] < 6:
191
+ return np.zeros(embeddings.shape[0], dtype=int)
192
+
193
+ # Normalize embeddings and replace NaN/inf
194
+ embeddings = np.nan_to_num(embeddings, nan=0.0, posinf=0.0, neginf=0.0)
195
+ embeddings = normalize(embeddings)
196
+
197
+ # Run spectral clustering (suppress numerical warnings)
198
+ spectral = self._get_spectral_cluster()
199
+
200
+ # Update min/max for oracle case
201
+ if num_speakers is not None:
202
+ spectral.min_num_spks = num_speakers
203
+ spectral.max_num_spks = num_speakers
204
+
205
+ with warnings.catch_warnings():
206
+ warnings.filterwarnings("ignore", category=RuntimeWarning)
207
+ labels = spectral(embeddings, oracle_num=num_speakers)
208
+
209
+ # Reset min/max
210
+ if num_speakers is not None:
211
+ spectral.min_num_spks = self.min_num_spks
212
+ spectral.max_num_spks = self.max_num_spks
213
+
214
+ # Merge similar speakers if no oracle
215
+ if num_speakers is None:
216
+ labels = self._merge_by_cos(labels, embeddings, self.merge_thr)
217
+
218
+ # Re-index labels sequentially
219
+ _, labels = np.unique(labels, return_inverse=True)
220
+
221
+ return labels
222
+
223
+ def _merge_by_cos(self, labels: np.ndarray, embs: np.ndarray, cos_thr: float) -> np.ndarray:
224
+ """Merge similar speakers by cosine similarity of centroids."""
225
+ from scipy.cluster.hierarchy import fcluster, linkage
226
+ from scipy.spatial.distance import pdist
227
+
228
+ unique_labels = np.unique(labels)
229
+ if len(unique_labels) <= 1:
230
+ return labels
231
+
232
+ # Compute normalized speaker centroids
233
+ centroids = np.array([embs[labels == lbl].mean(0) for lbl in unique_labels])
234
+ centroids = normalize(centroids)
235
+
236
+ # Hierarchical clustering with cosine distance
237
+ distances = pdist(centroids, metric="cosine")
238
+ linkage_matrix = linkage(distances, method="average")
239
+ merged_labels = fcluster(linkage_matrix, t=1.0 - cos_thr, criterion="distance") - 1
240
+
241
+ # Map original labels to merged labels
242
+ label_map = dict(zip(unique_labels, merged_labels))
243
+ return np.array([label_map[lbl] for lbl in labels])
244
+
245
+
246
+ class LocalSpeakerDiarizer:
247
+ """Local speaker diarization using TEN-VAD + ECAPA-TDNN + spectral clustering.
248
+
249
+ Pipeline:
250
+ 1. TEN-VAD detects speech segments
251
+ 2. Sliding window (1.0s, 75% overlap) for uniform embedding extraction
252
+ 3. ECAPA-TDNN extracts speaker embeddings per window
253
+ 4. Spectral clustering with eigenvalue gap for auto speaker detection
254
+ 5. Frame-level consensus voting for segment reconstruction
255
+ 6. Post-processing merges short segments to reduce flicker
256
+
257
+ Tunable Parameters (class attributes):
258
+ - WINDOW_SIZE: Embedding extraction window size in seconds
259
+ - STEP_SIZE: Sliding window step size (overlap = WINDOW_SIZE - STEP_SIZE)
260
+ - VAD_THRESHOLD: Speech detection threshold (lower = more sensitive)
261
+ - VAD_MIN_DURATION: Minimum speech segment duration
262
+ - VAD_MAX_GAP: Maximum gap to bridge between segments
263
+ - VAD_PAD_ONSET/OFFSET: Padding added to speech segments
264
+ - VOTING_RATE: Frame resolution for consensus voting
265
+ - MIN_SEGMENT_DURATION: Minimum final segment duration
266
+ - SAME_SPEAKER_GAP: Maximum gap to merge same-speaker segments
267
+ - TAIL_COVERAGE_RATIO: Minimum tail coverage to add extra window
268
+ """
269
+
270
+ _ten_vad_model = None
271
+ _ecapa_model = None
272
+ _device = None
273
+
274
+ # ==================== TUNABLE PARAMETERS ====================
275
+
276
+ # Sliding window for embedding extraction
277
+ WINDOW_SIZE = 0.75 # seconds - shorter window for finer resolution
278
+ STEP_SIZE = 0.15 # seconds (80% overlap for more votes)
279
+ TAIL_COVERAGE_RATIO = 0.1 # Add extra window if tail > this ratio of window
280
+
281
+ # VAD hysteresis parameters
282
+ VAD_THRESHOLD = 0.25 # Balanced threshold
283
+ VAD_MIN_DURATION = 0.05 # Minimum speech segment duration (seconds)
284
+ VAD_MAX_GAP = 0.50 # Bridge gaps shorter than this (seconds)
285
+ VAD_PAD_ONSET = 0.05 # Padding at segment start (seconds)
286
+ VAD_PAD_OFFSET = 0.05 # Padding at segment end (seconds)
287
+
288
+ # Frame-level voting
289
+ VOTING_RATE = 0.01 # 10ms resolution for consensus voting
290
+
291
+ # Post-processing
292
+ MIN_SEGMENT_DURATION = 0.15 # Minimum final segment duration (seconds)
293
+ SHORT_SEGMENT_GAP = 0.1 # Gap threshold for merging short segments
294
+ SAME_SPEAKER_GAP = 0.5 # Gap threshold for merging same-speaker segments
295
+
296
+ # ===========================================================
297
+
298
+ @classmethod
299
+ def _get_ten_vad_model(cls):
300
+ """Lazy-load TEN-VAD model (singleton)."""
301
+ if cls._ten_vad_model is None:
302
+ from ten_vad import TenVad
303
+
304
+ cls._ten_vad_model = TenVad(hop_size=256, threshold=cls.VAD_THRESHOLD)
305
+ return cls._ten_vad_model
306
+
307
+ @classmethod
308
+ def _get_device(cls) -> torch.device:
309
+ """Get the best available device."""
310
+ if cls._device is None:
311
+ cls._device = _get_device()
312
+ return cls._device
313
+
314
+ @classmethod
315
+ def _get_ecapa_model(cls):
316
+ """Lazy-load ECAPA-TDNN speaker embedding model (singleton)."""
317
+ if cls._ecapa_model is None:
318
+ # Suppress torchaudio deprecation warning from SpeechBrain
319
+ with warnings.catch_warnings():
320
+ warnings.filterwarnings("ignore", message="torchaudio._backend")
321
+ from speechbrain.inference.speaker import EncoderClassifier
322
+
323
+ device = cls._get_device()
324
+ cls._ecapa_model = EncoderClassifier.from_hparams(
325
+ source="speechbrain/spkrec-ecapa-voxceleb",
326
+ run_opts={"device": str(device)},
327
+ )
328
+
329
+ return cls._ecapa_model
330
+
331
+ @classmethod
332
+ def diarize(
333
+ cls,
334
+ audio: np.ndarray | str,
335
+ sample_rate: int = 16000,
336
+ num_speakers: int | None = None,
337
+ min_speakers: int = 2,
338
+ max_speakers: int = 10,
339
+ **_kwargs,
340
+ ) -> list[dict]:
341
+ """Run speaker diarization on audio.
342
+
343
+ Args:
344
+ audio: Audio waveform as numpy array or path to audio file
345
+ sample_rate: Audio sample rate (default 16000)
346
+ num_speakers: Exact number of speakers (if known)
347
+ min_speakers: Minimum number of speakers
348
+ max_speakers: Maximum number of speakers
349
+
350
+ Returns:
351
+ List of dicts with 'speaker', 'start', 'end' keys
352
+ """
353
+ # Handle file path input
354
+ if isinstance(audio, str):
355
+ import librosa
356
+
357
+ audio, sample_rate = librosa.load(audio, sr=16000)
358
+
359
+ # Ensure correct sample rate
360
+ if sample_rate != 16000:
361
+ import librosa
362
+
363
+ audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
364
+ sample_rate = 16000
365
+
366
+ audio = audio.astype(np.float32)
367
+ total_duration = len(audio) / sample_rate
368
+
369
+ # Step 1: VAD (returns segments and raw frame-level decisions)
370
+ segments, vad_frames = cls._get_speech_segments(audio, sample_rate)
371
+ if not segments:
372
+ return []
373
+
374
+ # Step 2: Extract embeddings
375
+ embeddings, window_segments = cls._extract_embeddings(audio, segments, sample_rate)
376
+ if len(embeddings) == 0:
377
+ return []
378
+
379
+ # Step 3: Cluster
380
+ clusterer = SpeakerClusterer(min_num_spks=min_speakers, max_num_spks=max_speakers)
381
+ labels = clusterer(embeddings, num_speakers)
382
+
383
+ # Step 4: Post-process with consensus voting (VAD-aware)
384
+ return cls._postprocess_segments(window_segments, labels, total_duration, vad_frames)
385
+
386
+ @classmethod
387
+ def _get_speech_segments(
388
+ cls, audio_array: np.ndarray, sample_rate: int = 16000
389
+ ) -> tuple[list[dict], list[bool]]:
390
+ """Get speech segments using TEN-VAD.
391
+
392
+ Returns:
393
+ Tuple of (segments list, vad_frames list of per-frame speech decisions)
394
+ """
395
+ vad_model = cls._get_ten_vad_model()
396
+
397
+ # Convert to int16 as required by TEN-VAD
398
+ # Clip to prevent integer overflow
399
+ if audio_array.dtype != np.int16:
400
+ audio_int16 = (np.clip(audio_array, -1.0, 1.0) * 32767).astype(np.int16)
401
+ else:
402
+ audio_int16 = audio_array
403
+
404
+ # Process frame by frame
405
+ hop_size = 256
406
+ frame_duration = hop_size / sample_rate
407
+ speech_frames: list[bool] = []
408
+
409
+ for i in range(0, len(audio_int16) - hop_size, hop_size):
410
+ frame = audio_int16[i : i + hop_size]
411
+ _, is_speech = vad_model.process(frame)
412
+ speech_frames.append(is_speech)
413
+
414
+ # Convert frame-level decisions to segments
415
+ segments = []
416
+ in_speech = False
417
+ start_idx = 0
418
+
419
+ for i, is_speech in enumerate(speech_frames):
420
+ if is_speech and not in_speech:
421
+ start_idx = i
422
+ in_speech = True
423
+ elif not is_speech and in_speech:
424
+ start_time = start_idx * frame_duration
425
+ end_time = i * frame_duration
426
+ segments.append(
427
+ {
428
+ "start": start_time,
429
+ "end": end_time,
430
+ "start_sample": int(start_time * sample_rate),
431
+ "end_sample": int(end_time * sample_rate),
432
+ }
433
+ )
434
+ in_speech = False
435
+
436
+ # Handle trailing speech
437
+ if in_speech:
438
+ start_time = start_idx * frame_duration
439
+ end_time = len(speech_frames) * frame_duration
440
+ segments.append(
441
+ {
442
+ "start": start_time,
443
+ "end": end_time,
444
+ "start_sample": int(start_time * sample_rate),
445
+ "end_sample": int(end_time * sample_rate),
446
+ }
447
+ )
448
+
449
+ return cls._apply_vad_hysteresis(segments, sample_rate), speech_frames
450
+
451
+ @classmethod
452
+ def _apply_vad_hysteresis(cls, segments: list[dict], sample_rate: int = 16000) -> list[dict]:
453
+ """Apply hysteresis-like post-processing to VAD segments."""
454
+ if not segments:
455
+ return segments
456
+
457
+ segments = sorted(segments, key=lambda x: x["start"])
458
+
459
+ # Fill short gaps
460
+ merged = [segments[0].copy()]
461
+ for seg in segments[1:]:
462
+ gap = seg["start"] - merged[-1]["end"]
463
+ if gap <= cls.VAD_MAX_GAP:
464
+ merged[-1]["end"] = seg["end"]
465
+ merged[-1]["end_sample"] = seg["end_sample"]
466
+ else:
467
+ merged.append(seg.copy())
468
+
469
+ # Remove short segments
470
+ filtered = [seg for seg in merged if (seg["end"] - seg["start"]) >= cls.VAD_MIN_DURATION]
471
+
472
+ # Dilate segments (add padding)
473
+ for seg in filtered:
474
+ seg["start"] = max(0.0, seg["start"] - cls.VAD_PAD_ONSET)
475
+ seg["end"] = seg["end"] + cls.VAD_PAD_OFFSET
476
+ seg["start_sample"] = int(seg["start"] * sample_rate)
477
+ seg["end_sample"] = int(seg["end"] * sample_rate)
478
+
479
+ return filtered
480
+
481
+ @classmethod
482
+    def _extract_embeddings(
+        cls, audio_array: np.ndarray, segments: list[dict], sample_rate: int
+    ) -> tuple[np.ndarray, list[dict]]:
+        """Extract speaker embeddings using sliding windows."""
+        speaker_model = cls._get_ecapa_model()
+
+        window_samples = int(cls.WINDOW_SIZE * sample_rate)
+        step_samples = int(cls.STEP_SIZE * sample_rate)
+
+        embeddings = []
+        window_segments = []
+
+        with torch.no_grad():
+            for seg in segments:
+                seg_start = seg["start_sample"]
+                seg_end = seg["end_sample"]
+                seg_len = seg_end - seg_start
+
+                # Generate window positions
+                if seg_len <= window_samples:
+                    starts = [seg_start]
+                    ends = [seg_end]
+                else:
+                    starts = list(range(seg_start, seg_end - window_samples + 1, step_samples))
+                    ends = [s + window_samples for s in starts]
+
+                # Cover tail if > TAIL_COVERAGE_RATIO of window remains
+                if ends and ends[-1] < seg_end:
+                    remainder = seg_end - ends[-1]
+                    if remainder > (window_samples * cls.TAIL_COVERAGE_RATIO):
+                        starts.append(seg_end - window_samples)
+                        ends.append(seg_end)
+
+                for c_start, c_end in zip(starts, ends):
+                    chunk = audio_array[c_start:c_end]
+
+                    # Pad short chunks with reflection
+                    if len(chunk) < window_samples:
+                        pad_width = window_samples - len(chunk)
+                        chunk = np.pad(chunk, (0, pad_width), mode="reflect")
+
+                    # Extract embedding using SpeechBrain's encode_batch
+                    chunk_tensor = torch.from_numpy(chunk).float().unsqueeze(0)
+                    embedding = (
+                        speaker_model.encode_batch(chunk_tensor).squeeze(0).squeeze(0).cpu().numpy()
+                    )
+
+                    # Validate embedding
+                    if np.isfinite(embedding).all() and np.linalg.norm(embedding) > 1e-8:
+                        embeddings.append(embedding)
+                        window_segments.append(
+                            {
+                                "start": c_start / sample_rate,
+                                "end": c_end / sample_rate,
+                            }
+                        )
+
+        # Normalize all embeddings at once
+        if embeddings:
+            return normalize(np.array(embeddings)), window_segments
+        return np.array([]), []
+
+    @classmethod
+    def _resample_vad(cls, vad_frames: list[bool], num_frames: int) -> np.ndarray:
+        """Resample VAD frame decisions to match voting grid resolution.
+
+        VAD operates at 256 samples / 16000 Hz = 16ms per frame.
+        Voting operates at VOTING_RATE (default 10ms) per frame.
+        This maps VAD decisions to the finer voting grid.
+        """
+        if not vad_frames:
+            return np.zeros(num_frames, dtype=bool)
+
+        vad_rate = 256 / 16000  # 16ms per VAD frame
+        vad_arr = np.array(vad_frames)
+
+        # Vectorized: compute VAD frame indices for each voting frame
+        voting_times = np.arange(num_frames) * cls.VOTING_RATE
+        vad_indices = np.clip((voting_times / vad_rate).astype(int), 0, len(vad_arr) - 1)
+        return vad_arr[vad_indices]
+
+    @classmethod
+    def _postprocess_segments(
+        cls,
+        window_segments: list[dict],
+        labels: np.ndarray,
+        total_duration: float,
+        vad_frames: list[bool],
+    ) -> list[dict]:
+        """Post-process using frame-level consensus voting with VAD-aware silence."""
+        if not window_segments or len(labels) == 0:
+            return []
+
+        # Remap labels to be contiguous
+        unique_labels = np.unique(labels)
+        label_map = {old: new for new, old in enumerate(unique_labels)}
+        clean_labels = np.array([label_map[lbl] for lbl in labels])
+        num_speakers = len(unique_labels)
+
+        if num_speakers == 0:
+            return []
+
+        # Create voting grid
+        num_frames = int(np.ceil(total_duration / cls.VOTING_RATE)) + 1
+        votes = np.zeros((num_frames, num_speakers), dtype=np.float32)
+
+        # Accumulate votes
+        for win, label in zip(window_segments, clean_labels):
+            start_frame = int(win["start"] / cls.VOTING_RATE)
+            end_frame = int(win["end"] / cls.VOTING_RATE)
+            end_frame = min(end_frame, num_frames)
+            if start_frame < end_frame:
+                votes[start_frame:end_frame, label] += 1.0
+
+        # Determine winner per frame
+        frame_speakers = np.argmax(votes, axis=1)
+        max_votes = np.max(votes, axis=1)
+
+        # Resample VAD to voting grid resolution for silence-aware voting
+        vad_resampled = cls._resample_vad(vad_frames, num_frames)
+
+        # Convert frames to segments
+        final_segments = []
+        current_speaker = -1
+        seg_start = 0.0
+
+        for f in range(num_frames):
+            speaker = int(frame_speakers[f])
+            score = max_votes[f]
+
+            # Force silence if VAD says no speech OR no votes
+            if score == 0 or not vad_resampled[f]:
+                speaker = -1
+
+            if speaker != current_speaker:
+                if current_speaker != -1:
+                    final_segments.append(
+                        {
+                            "speaker": f"SPEAKER_{current_speaker}",
+                            "start": seg_start,
+                            "end": f * cls.VOTING_RATE,
+                        }
+                    )
+                current_speaker = speaker
+                seg_start = f * cls.VOTING_RATE
+
+        # Close last segment
+        if current_speaker != -1:
+            final_segments.append(
+                {
+                    "speaker": f"SPEAKER_{current_speaker}",
+                    "start": seg_start,
+                    "end": num_frames * cls.VOTING_RATE,
+                }
+            )
+
+        return cls._merge_short_segments(final_segments)
+
+    @classmethod
+    def _merge_short_segments(cls, segments: list[dict]) -> list[dict]:
+        """Merge short segments to reduce flicker."""
+        if not segments:
+            return []
+
+        clean: list[dict] = []
+        for seg in segments:
+            dur = seg["end"] - seg["start"]
+            if dur < cls.MIN_SEGMENT_DURATION:
+                if (
+                    clean
+                    and clean[-1]["speaker"] == seg["speaker"]
+                    and seg["start"] - clean[-1]["end"] < cls.SHORT_SEGMENT_GAP
+                ):
+                    clean[-1]["end"] = seg["end"]
+                continue
+
+            if (
+                clean
+                and clean[-1]["speaker"] == seg["speaker"]
+                and seg["start"] - clean[-1]["end"] < cls.SAME_SPEAKER_GAP
+            ):
+                clean[-1]["end"] = seg["end"]
+            else:
+                clean.append(seg)
+
+        return clean
+
+    @classmethod
+    def assign_speakers_to_words(
+        cls,
+        words: list[dict],
+        speaker_segments: list[dict],
+    ) -> list[dict]:
+        """Assign speaker labels to words based on timestamp overlap.
+
+        Args:
+            words: List of word dicts with 'word', 'start', 'end' keys
+            speaker_segments: List of speaker dicts with 'speaker', 'start', 'end' keys
+
+        Returns:
+            Words list with 'speaker' key added to each word
+        """
+        for word in words:
+            word_mid = (word["start"] + word["end"]) / 2
+
+            # Find the speaker segment that contains this word's midpoint
+            best_speaker = None
+            for seg in speaker_segments:
+                if seg["start"] <= word_mid <= seg["end"]:
+                    best_speaker = seg["speaker"]
+                    break
+
+            # If no exact match, find closest segment
+            if best_speaker is None and speaker_segments:
+                min_dist = float("inf")
+                for seg in speaker_segments:
+                    seg_mid = (seg["start"] + seg["end"]) / 2
+                    dist = abs(word_mid - seg_mid)
+                    if dist < min_dist:
+                        min_dist = dist
+                        best_speaker = seg["speaker"]
+
+            word["speaker"] = best_speaker
+
+        return words
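The midpoint rule implemented by `assign_speakers_to_words` can be exercised on its own. The sketch below (function and data names are illustrative, not part of this commit) reproduces the two-step lookup: exact containment first, then nearest-segment fallback.

```python
def assign_speakers(words, speaker_segments):
    """Midpoint-based word-to-speaker assignment (standalone sketch)."""
    for word in words:
        mid = (word["start"] + word["end"]) / 2
        # Step 1: find a segment that contains the word midpoint
        speaker = next(
            (s["speaker"] for s in speaker_segments if s["start"] <= mid <= s["end"]),
            None,
        )
        # Step 2: fall back to the segment whose midpoint is closest
        if speaker is None and speaker_segments:
            speaker = min(
                speaker_segments,
                key=lambda s: abs(mid - (s["start"] + s["end"]) / 2),
            )["speaker"]
        word["speaker"] = speaker
    return words


segments = [
    {"speaker": "SPEAKER_0", "start": 0.0, "end": 2.0},
    {"speaker": "SPEAKER_1", "start": 2.5, "end": 5.0},
]
words = [
    {"word": "hello", "start": 0.4, "end": 0.8},  # midpoint 0.6, inside SPEAKER_0
    {"word": "uh", "start": 2.1, "end": 2.3},     # midpoint 2.2, falls in the gap
]
labeled = assign_speakers(words, segments)
```

The gap word resolves to SPEAKER_0 via the fallback: segment midpoints are 1.0 and 3.75, and 2.2 is nearer to 1.0.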
full_duplex.py ADDED
@@ -0,0 +1,475 @@
+"""Full-duplex audio session for speech-to-speech.
+
+Implements Freeze-Omni style full-duplex conversation where the model
+can listen and speak simultaneously, with support for user interruption.
+
+Architecture:
+    - Dual queue system: PCMQueue (input) + AudioQueue (output)
+    - Multi-threaded: Listen thread + Generate thread run concurrently
+    - State machine: listen -> speak -> (interrupt) -> listen
+    - VAD-based turn detection using model's built-in Silero VAD
+
+Usage (sync):
+    session = FullDuplexSession(model)
+    session.start()
+
+    while has_audio:
+        session.push_audio(audio_chunk)
+        output = session.pop_audio()
+        if output is not None:
+            speaker.play(output)
+
+    session.stop()
+
+Usage (async/web):
+    session = FullDuplexSession(
+        model,
+        on_state_change=lambda s: send_status(s),
+        on_text=lambda t: send_text(t),
+        on_audio=lambda a: send_audio(a),
+    )
+    session.start()
+
+    # In your receive loop:
+    session.push_audio(audio_chunk)
+"""
+
+import logging
+import queue
+import threading
+import time
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import TYPE_CHECKING, Callable, Optional
+
+import numpy as np
+import torch
+
+if TYPE_CHECKING:
+    from .asr_modeling import ASRModel
+
+logger = logging.getLogger(__name__)
+
+
+class ConversationState(Enum):
+    """State machine for full-duplex conversation."""
+
+    IDLE = "idle"
+    LISTENING = "listening"
+    PROCESSING = "processing"
+    SPEAKING = "speaking"
+
+
+@dataclass
+class FullDuplexConfig:
+    """Configuration for full-duplex session."""
+
+    # Audio settings
+    sample_rate: int = 16000
+    chunk_size: int = 512  # Samples per chunk (32ms at 16kHz)
+    output_sample_rate: int = 44100  # DAC output rate
+
+    # VAD settings
+    vad_threshold: float = 0.5
+    silence_duration_ms: float = 700  # Silence to end turn
+    min_speech_duration_ms: float = 100  # Minimum speech to trigger
+
+    # Generation settings
+    audio_chunk_size: int = 4  # Tokens per audio chunk
+
+    # Timing
+    poll_interval: float = 0.01
+
+
+class PCMQueue:
+    """Thread-safe queue for streaming PCM audio input."""
+
+    def __init__(self):
+        self.buffer = np.array([], dtype=np.float32)
+        self.lock = threading.Lock()
+
+    def put(self, audio: np.ndarray) -> None:
+        with self.lock:
+            self.buffer = np.concatenate([self.buffer, audio.astype(np.float32)])
+
+    def get(self, length: int) -> Optional[np.ndarray]:
+        with self.lock:
+            if len(self.buffer) < length:
+                return None
+            result = self.buffer[:length]
+            self.buffer = self.buffer[length:]
+            return result
+
+    def clear(self) -> None:
+        with self.lock:
+            self.buffer = np.array([], dtype=np.float32)
+
+    def __len__(self) -> int:
+        with self.lock:
+            return len(self.buffer)
+
+
+class AudioQueue:
+    """Thread-safe queue for output audio chunks."""
+
+    def __init__(self):
+        self._queue: queue.Queue = queue.Queue()
+
+    def put(self, audio: torch.Tensor) -> None:
+        self._queue.put(audio)
+
+    def get(self) -> Optional[torch.Tensor]:
+        try:
+            return self._queue.get_nowait()
+        except queue.Empty:
+            return None
+
+    def clear(self) -> None:
+        while not self._queue.empty():
+            try:
+                self._queue.get_nowait()
+            except queue.Empty:
+                break
+
+    def is_empty(self) -> bool:
+        return self._queue.empty()
+
+
+@dataclass
+class _SessionState:
+    """Internal state for full-duplex session."""
+
+    state: ConversationState = ConversationState.IDLE
+    speech_buffer: list = field(default_factory=list)
+    speech_start_time: float = 0.0
+    last_speech_time: float = 0.0
+    silence_frames: int = 0
+    stop_generate: bool = False
+    is_generating: bool = False
+    generated_text: str = ""
+
+
+class FullDuplexSession:
+    """Full-duplex speech-to-speech session (Freeze-Omni style).
+
+    Manages simultaneous listening and speaking with VAD-based turn detection.
+    Designed to be easy to integrate with both sync and async (web) applications.
+
+    Args:
+        model: ASRModel with audio_head configured
+        config: FullDuplexConfig for session parameters
+        on_state_change: Callback when state changes (state: ConversationState)
+        on_text: Callback when text is generated (text: str, interim: bool)
+        on_audio: Callback when audio chunk is ready (audio: torch.Tensor)
+            If provided, audio is sent here instead of output_queue
+        on_interrupted: Callback when generation is interrupted
+    """
+
+    def __init__(
+        self,
+        model: "ASRModel",
+        config: Optional[FullDuplexConfig] = None,
+        on_state_change: Optional[Callable[[ConversationState], None]] = None,
+        on_text: Optional[Callable[[str, bool], None]] = None,
+        on_audio: Optional[Callable[[torch.Tensor], None]] = None,
+        on_interrupted: Optional[Callable[[], None]] = None,
+    ):
+        self.model = model
+        self.config = config or FullDuplexConfig()
+
+        # Callbacks
+        self.on_state_change = on_state_change
+        self.on_text = on_text
+        self.on_audio = on_audio
+        self.on_interrupted = on_interrupted
+
+        # Queues
+        self.input_queue = PCMQueue()
+        self.output_queue = AudioQueue()
+
+        # State
+        self._state = _SessionState()
+        self._running = False
+        self._state_lock = threading.Lock()
+
+        # Threads
+        self._listen_thread: Optional[threading.Thread] = None
+        self._generate_thread: Optional[threading.Thread] = None
+
+        # Precompute timing thresholds
+        ms_per_chunk = self.config.chunk_size / self.config.sample_rate * 1000
+        self._silence_threshold = int(self.config.silence_duration_ms / ms_per_chunk)
+        self._min_speech_chunks = int(self.config.min_speech_duration_ms / ms_per_chunk)
+
+        # Ensure VAD is loaded
+        self.model.load_vad()
+
+    @property
+    def state(self) -> ConversationState:
+        with self._state_lock:
+            return self._state.state
+
+    def _set_state(self, value: ConversationState) -> None:
+        with self._state_lock:
+            old_state = self._state.state
+            self._state.state = value
+        if old_state != value:  # notify outside the lock so callbacks may read state
+            logger.debug(f"State: {old_state.value} -> {value.value}")
+            if self.on_state_change:
+                try:
+                    self.on_state_change(value)
+                except Exception as e:
+                    logger.error(f"on_state_change callback error: {e}")
+
+    @property
+    def is_generating(self) -> bool:
+        with self._state_lock:
+            return self._state.is_generating
+
+    @property
+    def generated_text(self) -> str:
+        with self._state_lock:
+            return self._state.generated_text
+
+    def start(self) -> None:
+        """Start the full-duplex session."""
+        if self._running:
+            return
+
+        self._running = True
+        self._set_state(ConversationState.LISTENING)
+
+        self._listen_thread = threading.Thread(target=self._listen_loop, daemon=True)
+        self._listen_thread.start()
+
+        logger.info("Full-duplex session started")
+
+    def stop(self) -> None:
+        """Stop the full-duplex session."""
+        self._running = False
+
+        with self._state_lock:
+            self._state.stop_generate = True
+
+        if self._listen_thread:
+            self._listen_thread.join(timeout=2.0)
+        if self._generate_thread:
+            self._generate_thread.join(timeout=2.0)
+
+        self.input_queue.clear()
+        self.output_queue.clear()
+        self._set_state(ConversationState.IDLE)
+
+        logger.info("Full-duplex session stopped")
+
+    def push_audio(self, audio: np.ndarray) -> None:
+        """Push audio samples to the input queue.
+
+        Args:
+            audio: Audio samples as numpy array (float32 normalized or int16)
+        """
+        if audio.dtype == np.int16:
+            audio = audio.astype(np.float32) / 32768.0
+        self.input_queue.put(audio)
+
+    def pop_audio(self) -> Optional[torch.Tensor]:
+        """Pop generated audio from the output queue.
+
+        Only used if the on_audio callback is not set.
+
+        Returns:
+            Audio tensor [samples] or None
+        """
+        return self.output_queue.get()
+
+    def interrupt(self) -> None:
+        """Interrupt current generation and return to listening."""
+        with self._state_lock:
+            self._state.stop_generate = True
+
+        # Wait for generation to stop
+        timeout = 2.0
+        start = time.time()
+        while self._state.is_generating and (time.time() - start) < timeout:
+            time.sleep(self.config.poll_interval)
+
+        # Clear output queue
+        self.output_queue.clear()
+
+        # Reset state
+        with self._state_lock:
+            self._state.stop_generate = False
+            self._state.generated_text = ""
+            self._state.speech_buffer.clear()
+            self._state.silence_frames = 0
+
+        self._set_state(ConversationState.LISTENING)
+        self.model.reset_vad_state()
+
+        if self.on_interrupted:
+            try:
+                self.on_interrupted()
+            except Exception as e:
+                logger.error(f"on_interrupted callback error: {e}")
+
+        logger.debug("Generation interrupted")
+
+    def _emit_audio(self, audio: torch.Tensor) -> None:
+        """Send audio to callback or queue."""
+        if self.on_audio:
+            try:
+                self.on_audio(audio)
+            except Exception as e:
+                logger.error(f"on_audio callback error: {e}")
+        else:
+            self.output_queue.put(audio)
+
+    def _emit_text(self, text: str, interim: bool = False) -> None:
+        """Send text to callback."""
+        if self.on_text:
+            try:
+                self.on_text(text, interim)
+            except Exception as e:
+                logger.error(f"on_text callback error: {e}")
+
+    def _listen_loop(self) -> None:
+        """Main listening loop - processes audio and detects speech."""
+        is_speaking = False
+
+        while self._running:
+            audio = self.input_queue.get(self.config.chunk_size)
+            if audio is None:
+                time.sleep(self.config.poll_interval)
+                continue
+
+            # Run VAD
+            audio_tensor = torch.from_numpy(audio)
+            is_speech, prob = self.model.detect_speech(
+                audio_tensor,
+                self.config.sample_rate,
+                self.config.vad_threshold,
+            )
+
+            current_time = time.time()
+
+            # Check for interruption during generation
+            if self._state.is_generating and is_speech:
+                logger.debug(f"Interruption detected (prob={prob:.2f})")
+                self.interrupt()
+                # Start new utterance with this chunk
+                is_speaking = True
+                with self._state_lock:
+                    self._state.speech_buffer = [audio]
+                    self._state.speech_start_time = current_time
+                    self._state.last_speech_time = current_time
+                    self._state.silence_frames = 0
+                continue
+
+            # Normal VAD state machine
+            if is_speech:
+                if not is_speaking:
+                    is_speaking = True
+                    with self._state_lock:
+                        self._state.speech_buffer = []
+                        self._state.speech_start_time = current_time
+                with self._state_lock:
+                    self._state.speech_buffer.append(audio)
+                    self._state.last_speech_time = current_time
+                    self._state.silence_frames = 0
+
+            elif is_speaking:
+                with self._state_lock:
+                    self._state.speech_buffer.append(audio)
+                    self._state.silence_frames += 1
+
+                if self._state.silence_frames >= self._silence_threshold:
+                    is_speaking = False
+
+                    # Check minimum speech duration
+                    if len(self._state.speech_buffer) >= self._min_speech_chunks:
+                        speech_audio = np.concatenate(self._state.speech_buffer)
+                        self._state.speech_buffer = []
+                        self._state.silence_frames = 0
+
+                        # Start generation
+                        self._generate_thread = threading.Thread(
+                            target=self._generate_loop,
+                            args=(speech_audio,),
+                            daemon=True,
+                        )
+                        self._generate_thread.start()
+                    else:
+                        self._state.speech_buffer = []
+                        self._state.silence_frames = 0
+
+    def _generate_loop(self, speech_audio: np.ndarray) -> None:
+        """Generation loop - produces text and audio response."""
+        with self._state_lock:
+            self._state.is_generating = True
+            self._state.generated_text = ""
+            self._state.stop_generate = False
+
+        try:
+            self._set_state(ConversationState.PROCESSING)
+
+            # Process input audio
+            device = next(self.model.language_model.parameters()).device
+            inputs = self.model._process_audio(speech_audio, self.config.sample_rate)
+            input_features = inputs["input_features"]
+            audio_attention_mask = inputs["attention_mask"]
+
+            # Encode
+            audio_embeds = self.model._encode_audio(input_features, audio_attention_mask)
+            input_ids, attention_mask = self.model._build_audio_prompt(
+                audio_attention_mask, 1, device
+            )
+            inputs_embeds = self.model._inject_audio_embeddings(input_ids, audio_embeds)
+
+            # Check for interruption
+            if self._state.stop_generate:
+                return
+
+            # Generate text
+            with torch.no_grad():
+                output = self.model.language_model.generate(
+                    input_ids=input_ids,
+                    inputs_embeds=inputs_embeds,
+                    attention_mask=attention_mask,
+                    generation_config=self.model.generation_config,
+                )
+
+            if self._state.stop_generate:
+                return
+
+            # Extract text
+            text_ids = output[:, input_ids.shape[1] :]
+            text = self.model.tokenizer.decode(text_ids[0], skip_special_tokens=True)
+
+            with self._state_lock:
+                self._state.generated_text = text
+
+            self._emit_text(text, interim=False)
+
+            if self._state.stop_generate:
+                return
+
+            # Generate audio
+            if self.model.audio_head is not None:
+                self._set_state(ConversationState.SPEAKING)
+
+                for audio_chunk in self.model.audio_head.generate_streaming(
+                    text_token_ids=text_ids,
+                ):
+                    if self._state.stop_generate:
+                        return
+                    self._emit_audio(audio_chunk)
+
+            self._set_state(ConversationState.LISTENING)
+
+        except Exception as e:
+            logger.error(f"Generation error: {e}")
+            self._set_state(ConversationState.LISTENING)
+
+        finally:
+            with self._state_lock:
+                self._state.is_generating = False
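`PCMQueue` above accumulates pushes of arbitrary length and releases fixed-size chunks, returning `None` when too few samples have arrived (the listen loop then sleeps and polls again). A pure-Python sketch of that contract (the `ChunkBuffer` name is illustrative, not part of this commit; the real class stores a numpy float32 buffer):

```python
import threading
from collections import deque


class ChunkBuffer:
    """Sketch of PCMQueue's contract: arbitrary-size puts, fixed-size gets."""

    def __init__(self):
        self._samples = deque()
        self._lock = threading.Lock()

    def put(self, samples):
        with self._lock:
            self._samples.extend(samples)

    def get(self, length):
        with self._lock:
            if len(self._samples) < length:
                return None  # not enough buffered; caller polls again
            return [self._samples.popleft() for _ in range(length)]


buf = ChunkBuffer()
buf.put([0.1, 0.2, 0.3])
first = buf.get(2)    # [0.1, 0.2]
starved = buf.get(2)  # None: only one sample remains
buf.put([0.4])
second = buf.get(2)   # [0.3, 0.4]
```

Returning `None` instead of blocking keeps the listen thread responsive: it can check `self._running` between polls rather than being parked inside a blocking read.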
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4df14201e66792d6b8ccefd124852fdc5d47f013a6276f9973540d63956097ad
+size 122303840
preprocessor_config.json ADDED
@@ -0,0 +1,19 @@
+{
+  "chunk_length": 30,
+  "dither": 0.0,
+  "feature_extractor_type": "WhisperFeatureExtractor",
+  "feature_size": 128,
+  "hop_length": 160,
+  "n_fft": 400,
+  "n_samples": 480000,
+  "nb_max_frames": 3000,
+  "padding": false,
+  "padding_side": "right",
+  "padding_value": 0.0,
+  "return_attention_mask": false,
+  "sampling_rate": 16000,
+  "processor_class": "ASRProcessor",
+  "auto_map": {
+    "AutoProcessor": "asr_processing.ASRProcessor"
+  }
+}
projectors.py ADDED
@@ -0,0 +1,64 @@
+"""Audio projector module for bridging encoder and decoder embeddings.
+
+MLPAudioProjector: Simple 2-layer MLP with frame stacking downsampling.
+"""
+
+import torch
+import torch.nn as nn
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
+
+
+class MLPAudioProjector(nn.Module):
+    """2-layer MLP projector with frame-stacking downsampling (matches GLM-ASR)."""
+
+    def __init__(self, config):
+        """Initialize MLP projector.
+
+        Args:
+            config: ASRConfig with encoder_dim, llm_dim, projector_pool_stride
+        """
+        super().__init__()
+
+        encoder_dim = getattr(config, "encoder_dim", 768)
+        llm_dim = getattr(config, "llm_dim", 2048)
+        self.k = getattr(config, "projector_pool_stride", 4)
+
+        # Frame stacking: concat k adjacent frames then project
+        in_dim = encoder_dim * self.k
+        # Hidden dim defaults to llm_dim, can be overridden via config
+        hidden_dim = getattr(config, "projector_hidden_dim", None) or llm_dim
+        self.linear_1 = nn.Linear(in_dim, hidden_dim, bias=False)
+        self.norm = LlamaRMSNorm(hidden_dim, eps=1e-6)
+        self.act = nn.GELU()
+        self.linear_2 = nn.Linear(hidden_dim, llm_dim, bias=False)
+
+    def get_output_length(self, input_length: int) -> int:
+        """Calculate output sequence length given input length (matches GLM-ASR)."""
+        # GLM-ASR formula: (L - merge_factor) // merge_factor + 1
+        return (input_length - self.k) // self.k + 1
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Project audio features to LLM embedding space.
+
+        Args:
+            x: Audio encoder output of shape [batch, seq_len, encoder_dim]
+
+        Returns:
+            Projected features of shape [batch, (seq_len - k) // k + 1, llm_dim]
+        """
+        batch, seq, dim = x.shape
+        # Truncate to match GLM-ASR: use (seq - k) // k + 1 frames
+        # This drops trailing frames that don't fill a complete k-frame window
+        out_len = (seq - self.k) // self.k + 1
+        x = x[:, : out_len * self.k, :]  # Truncate to exact multiple
+        x = x.reshape(batch, out_len, dim * self.k)
+
+        x = self.linear_1(x)
+        x = self.norm(x)
+        x = self.act(x)
+        return self.linear_2(x)
+
+
+PROJECTOR_CLASSES = {
+    "mlp": MLPAudioProjector,
+}
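The projector keeps only complete k-frame windows, so the output length follows `(L - k) // k + 1`. A quick standalone check of that arithmetic (illustrative only, not part of this commit):

```python
def projector_output_length(input_length: int, k: int = 4) -> int:
    # Frame stacking: k encoder frames collapse into one projected frame;
    # a trailing partial window is dropped (GLM-ASR style formula).
    return (input_length - k) // k + 1


# With the default stride k=4:
lengths = {n: projector_output_length(n) for n in (4, 7, 100, 101)}
# 4 -> 1, 7 -> 1 (3 trailing frames dropped), 100 -> 25, 101 -> 25
```

This matches `get_output_length` above and explains the truncation `x[:, : out_len * self.k, :]` in `forward`: at most k-1 encoder frames are discarded per utterance.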
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d4aeaf198f783cbf58d8cd59812baac429ffe49147bf9648f6618de20b8d4a4c
+size 17209003
tokenizer_config.json ADDED
@@ -0,0 +1,19 @@
+{
+  "backend": "tokenizers",
+  "bos_token": null,
+  "clean_up_tokenization_spaces": true,
+  "eos_token": "<|im_end|>",
+  "extra_special_tokens": [
+    "<audio>"
+  ],
+  "fast": false,
+  "is_local": false,
+  "model_input_names": [
+    "input_ids",
+    "attention_mask"
+  ],
+  "model_max_length": 131072,
+  "model_specific_special_tokens": {},
+  "pad_token": "<|finetune_right_pad_id|>",
+  "tokenizer_class": "TokenizersBackend"
+}