mazesmazes commited on
Commit
01fed2c
·
verified ·
1 Parent(s): f9e43c3

Training in progress - step 25000

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. asr_config.py +1 -16
  3. asr_modeling.py +2 -180
  4. asr_pipeline.py +2 -172
.gitattributes CHANGED
@@ -1,3 +1,4 @@
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
2
  *.bin filter=lfs diff=lfs merge=lfs -text
3
  tokenizer_config.json -filter -diff -merge text
 
 
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
2
  *.bin filter=lfs diff=lfs merge=lfs -text
3
  tokenizer_config.json -filter -diff -merge text
4
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
asr_config.py CHANGED
@@ -70,12 +70,6 @@ class ASRConfig(transformers.PretrainedConfig):
70
  lora_target_modules: Optional[list] = None,
71
  freeze_projector: bool = False,
72
  label_smoothing: float = 0.0,
73
- # Audio Head settings (Freeze-Omni style AR decoder)
74
- use_audio_head: bool = False,
75
- audio_head_hidden_dim: int = 512, # AR decoder hidden dimension
76
- codebook_size: int = 2048, # Mimi codec vocabulary size
77
- num_codebooks: int = 1, # Number of codebooks to predict (first 1-2 most important)
78
- freeze_audio_head: bool = False, # Freeze entire audio head
79
  **kwargs,
80
  ):
81
  # Merge generation defaults with kwargs (kwargs takes precedence)
@@ -140,13 +134,6 @@ class ASRConfig(transformers.PretrainedConfig):
140
  self.freeze_projector = freeze_projector
141
  self.label_smoothing = label_smoothing
142
 
143
- # Audio Head settings (Freeze-Omni style AR decoder)
144
- self.use_audio_head = use_audio_head
145
- self.audio_head_hidden_dim = audio_head_hidden_dim
146
- self.codebook_size = codebook_size
147
- self.num_codebooks = num_codebooks
148
- self.freeze_audio_head = freeze_audio_head
149
-
150
  # Generation parameters (from kwargs after merge with defaults)
151
  self.num_beams = kwargs.pop("num_beams")
152
  self.max_new_tokens = kwargs.pop("max_new_tokens")
@@ -163,9 +150,7 @@ class ASRConfig(transformers.PretrainedConfig):
163
  # Load sub-configs
164
  self.audio_config = kwargs.pop("audio_config", None)
165
  if self.audio_config is None:
166
- self.audio_config = transformers.AutoConfig.from_pretrained(
167
- audio_model_id, trust_remote_code=True
168
- )
169
  self.audio_config.dtype = model_dtype
170
  elif isinstance(self.audio_config, dict) and self.audio_config.get("model_type"):
171
  config_class = transformers.AutoConfig.for_model(
 
70
  lora_target_modules: Optional[list] = None,
71
  freeze_projector: bool = False,
72
  label_smoothing: float = 0.0,
 
 
 
 
 
 
73
  **kwargs,
74
  ):
75
  # Merge generation defaults with kwargs (kwargs takes precedence)
 
134
  self.freeze_projector = freeze_projector
135
  self.label_smoothing = label_smoothing
136
 
 
 
 
 
 
 
 
137
  # Generation parameters (from kwargs after merge with defaults)
138
  self.num_beams = kwargs.pop("num_beams")
139
  self.max_new_tokens = kwargs.pop("max_new_tokens")
 
150
  # Load sub-configs
151
  self.audio_config = kwargs.pop("audio_config", None)
152
  if self.audio_config is None:
153
+ self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
 
 
154
  self.audio_config.dtype = model_dtype
155
  elif isinstance(self.audio_config, dict) and self.audio_config.get("model_type"):
156
  config_class = transformers.AutoConfig.for_model(
asr_modeling.py CHANGED
@@ -181,19 +181,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
181
  else:
182
  self.spec_augment = None
183
 
184
- # Audio Head for S2S (trainable)
185
- if getattr(config, "use_audio_head", False):
186
- from .audio_head import AudioHead
187
-
188
- self.audio_head = AudioHead(config).to(
189
- device=next(self.language_model.parameters()).device,
190
- dtype=target_dtype,
191
- )
192
- if getattr(config, "freeze_audio_head", False):
193
- self.audio_head.requires_grad_(False)
194
- else:
195
- self.audio_head = None
196
-
197
  # For model parallelism
198
  self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
199
 
@@ -378,11 +365,8 @@ class ASRModel(PreTrainedModel, GenerationMixin):
378
  )
379
 
380
  def state_dict(self, *args, **kwargs) -> dict[str, torch.Tensor]:
381
- """Save trainable weights (projector + audio_head if present)."""
382
- state = {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
383
- if self.audio_head is not None:
384
- state.update({f"audio_head.{k}": v for k, v in self.audio_head.state_dict().items()})
385
- return state
386
 
387
  def _compute_encoder_output_lengths(
388
  self,
@@ -476,8 +460,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
476
  labels: Optional[torch.Tensor] = None,
477
  use_cache: Optional[bool] = None,
478
  cache_position: Optional[torch.Tensor] = None,
479
- codec_targets: Optional[torch.Tensor] = None,
480
- codec_lengths: Optional[torch.Tensor] = None,
481
  **kwargs,
482
  ) -> CausalLMOutputWithPast:
483
  """Forward pass for training and inference."""
@@ -505,10 +487,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
505
  audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
506
  )
507
 
508
- # Request hidden states if training audio head with codec targets
509
- if self.audio_head is not None and codec_targets is not None:
510
- kwargs["output_hidden_states"] = True
511
-
512
  # Run through language model (let it compute loss if labels provided)
513
  outputs = self.language_model(
514
  attention_mask=attention_mask,
@@ -527,29 +505,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
527
  if aux_loss is not None and aux_loss.numel() > 0:
528
  outputs.loss = outputs.loss + aux_loss.to(outputs.loss.device)
529
 
530
- # Compute audio head loss if training S2S with codec targets
531
- if self.audio_head is not None and codec_targets is not None:
532
- hidden_states = outputs.hidden_states[-1] # Last layer hidden states
533
- # No detach needed: LLM is frozen (requires_grad=False), so gradients
534
- # naturally stop there. Hidden states keep their grad_fn for proper backprop.
535
- audio_head_loss = self.audio_head(
536
- hidden_states,
537
- codec_targets=codec_targets,
538
- codec_lengths=codec_lengths,
539
- )
540
- # Add audio_head_loss directly to outputs.loss
541
- # (CausalLMOutputWithPast doesn't preserve custom attributes through Accelerator)
542
- if outputs.loss is not None:
543
- outputs.loss = outputs.loss + audio_head_loss
544
- else:
545
- # S2S-only training: audio head loss is the only loss
546
- outputs.loss = audio_head_loss
547
- else:
548
- print(
549
- f"DEBUG: audio_head branch NOT taken: audio_head={self.audio_head is not None}, codec_targets={codec_targets is not None}"
550
- )
551
-
552
- print(f"DEBUG: returning outputs.loss={outputs.loss}")
553
  return outputs
554
 
555
  def prepare_inputs_for_generation(self, *args, **kwargs):
@@ -833,139 +788,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
833
  response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
834
  return response.strip()
835
 
836
- @torch.no_grad()
837
- def generate_with_audio(
838
- self,
839
- input_features: torch.Tensor,
840
- audio_attention_mask: torch.Tensor,
841
- **generate_kwargs,
842
- ) -> dict[str, torch.Tensor]:
843
- """Generate text and NeuCodec tokens for Speech-to-Speech.
844
-
845
- Args:
846
- input_features: Mel spectrogram features (batch, n_mels, mel_len)
847
- audio_attention_mask: Mask for real vs padded mel frames (batch, mel_len)
848
- **generate_kwargs: Additional generation arguments
849
-
850
- Returns:
851
- Dict with:
852
- - text_ids: Generated text token IDs (batch, seq_len)
853
- - text: Decoded text strings (list of str)
854
- - codec_tokens: Predicted NeuCodec tokens (batch, audio_len)
855
- """
856
- if self.audio_head is None:
857
- raise ValueError("Audio head not configured. Set use_audio_head=True in config.")
858
-
859
- device = input_features.device
860
- batch_size = input_features.shape[0]
861
-
862
- # Encode audio -> flattened embeddings
863
- audio_embeds = self._encode_audio(input_features, audio_attention_mask)
864
-
865
- # Build prompt with correct number of audio tokens
866
- num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
867
- audio_placeholder = "<audio>" * num_audio_tokens
868
-
869
- messages: list[dict[str, str]] = []
870
- if self.system_prompt:
871
- messages.append({"role": "system", "content": self.system_prompt})
872
- user_content = audio_placeholder
873
- if self.TRANSCRIBE_PROMPT:
874
- user_content += " " + self.TRANSCRIBE_PROMPT
875
- messages.append({"role": "user", "content": user_content})
876
-
877
- chat_result = self.tokenizer.apply_chat_template(
878
- messages,
879
- tokenize=True,
880
- add_generation_prompt=True,
881
- return_tensors="pt",
882
- enable_thinking=getattr(self.config, "enable_thinking", False),
883
- )
884
- input_ids = chat_result.input_ids.to(device)
885
-
886
- if input_ids.dim() == 1:
887
- input_ids = input_ids.unsqueeze(0)
888
- if input_ids.shape[0] == 1 and batch_size > 1:
889
- input_ids = input_ids.expand(batch_size, -1)
890
-
891
- attention_mask = torch.ones_like(input_ids)
892
-
893
- # Get text embeddings and replace audio tokens with audio embeddings
894
- inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
895
- audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
896
- inputs_embeds = inputs_embeds.masked_scatter(
897
- audio_token_mask.to(inputs_embeds.device),
898
- audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
899
- )
900
-
901
- # Generate with hidden states
902
- output = self.language_model.generate(
903
- input_ids=input_ids,
904
- inputs_embeds=inputs_embeds,
905
- attention_mask=attention_mask,
906
- generation_config=self.generation_config,
907
- output_hidden_states=True,
908
- return_dict_in_generate=True,
909
- **generate_kwargs,
910
- )
911
-
912
- # Extract generated text
913
- text_ids = output.sequences[:, input_ids.shape[1] :]
914
- text = self.tokenizer.batch_decode(text_ids, skip_special_tokens=True)
915
-
916
- # Extract hidden states from generation steps and concatenate
917
- # output.hidden_states is tuple of (step,) where each step is tuple of (layer,)
918
- # Each layer tensor is (batch, 1, hidden_dim) for generated tokens
919
- last_layer_states = []
920
- for step_hidden in output.hidden_states:
921
- # step_hidden is tuple of (num_layers,) tensors
922
- # Get last layer: shape (batch, 1, hidden_dim)
923
- last_layer_states.append(step_hidden[-1])
924
-
925
- # Concatenate across generation steps: (batch, gen_seq_len, hidden_dim)
926
- hidden_states = torch.cat(last_layer_states, dim=1)
927
-
928
- # Predict codec tokens (uses inference heuristic for duration)
929
- # WavTokenizer: single codebook, shape (batch, audio_len)
930
- codec_tokens = self.audio_head(hidden_states)
931
-
932
- return {
933
- "text_ids": text_ids,
934
- "text": text,
935
- "codec_tokens": codec_tokens,
936
- }
937
-
938
- def decode_audio(
939
- self,
940
- codec_tokens: torch.Tensor,
941
- codec_model_id: str = "neuphonic/neucodec",
942
- ) -> torch.Tensor:
943
- """Decode NeuCodec tokens to waveform.
944
-
945
- Args:
946
- codec_tokens: Codec token indices (batch, audio_len)
947
- codec_model_id: HuggingFace model ID for NeuCodec
948
-
949
- Returns:
950
- Waveform tensor (batch, 1, samples) at 24kHz
951
- """
952
- try:
953
- from neucodec import NeuCodec
954
- except ImportError as e:
955
- raise ImportError(
956
- "NeuCodec required for audio decoding. Install with: pip install neucodec"
957
- ) from e
958
-
959
- model = NeuCodec.from_pretrained(codec_model_id)
960
- model = model.to(codec_tokens.device)
961
- model.eval()
962
-
963
- # NeuCodec decode expects (batch, 1, seq_len)
964
- codes = codec_tokens.unsqueeze(1)
965
-
966
- with torch.no_grad():
967
- return model.decode_code(codes)
968
-
969
  def save_pretrained(self, save_directory: Union[str, Path], **kwargs) -> None:
970
  """Save model, tokenizer, and processor."""
971
  import shutil
 
181
  else:
182
  self.spec_augment = None
183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  # For model parallelism
185
  self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
186
 
 
365
  )
366
 
367
  def state_dict(self, *args, **kwargs) -> dict[str, torch.Tensor]:
368
+ """Only save trainable projector weights."""
369
+ return {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
 
 
 
370
 
371
  def _compute_encoder_output_lengths(
372
  self,
 
460
  labels: Optional[torch.Tensor] = None,
461
  use_cache: Optional[bool] = None,
462
  cache_position: Optional[torch.Tensor] = None,
 
 
463
  **kwargs,
464
  ) -> CausalLMOutputWithPast:
465
  """Forward pass for training and inference."""
 
487
  audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
488
  )
489
 
 
 
 
 
490
  # Run through language model (let it compute loss if labels provided)
491
  outputs = self.language_model(
492
  attention_mask=attention_mask,
 
505
  if aux_loss is not None and aux_loss.numel() > 0:
506
  outputs.loss = outputs.loss + aux_loss.to(outputs.loss.device)
507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
  return outputs
509
 
510
  def prepare_inputs_for_generation(self, *args, **kwargs):
 
788
  response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
789
  return response.strip()
790
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
791
  def save_pretrained(self, save_directory: Union[str, Path], **kwargs) -> None:
792
  """Save model, tokenizer, and processor."""
793
  import shutil
asr_pipeline.py CHANGED
@@ -2,7 +2,7 @@
2
 
3
  import re
4
  from pathlib import Path
5
- from typing import Any, Iterator, Union
6
 
7
  import numpy as np
8
  import torch
@@ -101,142 +101,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
101
  audio = np.concatenate(audio_chunks) if audio_chunks else np.array([], dtype=np.float32)
102
  return {"audio": audio, "sample_rate": TTS_SAMPLE_RATE}
103
 
104
- def transcribe_streaming(
105
- self,
106
- inputs: Union[str, bytes, np.ndarray, dict],
107
- system_prompt: str | None = None,
108
- ) -> Iterator[str]:
109
- """Transcribe audio with streaming token output for low-latency applications.
110
-
111
- Yields partial transcript strings as tokens are generated, reducing
112
- time-to-first-word compared to waiting for full transcription.
113
-
114
- Args:
115
- inputs: Audio input in any supported format:
116
- - str: File path to audio file
117
- - bytes: Raw audio bytes
118
- - np.ndarray: Audio samples as numpy array
119
- - dict: {"array": np.ndarray, "sampling_rate": int}
120
- system_prompt: Optional system prompt override (uses model's default if not provided)
121
-
122
- Yields:
123
- Partial transcript text as each token is generated
124
-
125
- Example:
126
- >>> for partial in pipeline.transcribe_streaming("audio.wav"):
127
- ... print(partial, end="", flush=True)
128
- """
129
- # Extract audio array from various input formats
130
- audio_data = self._extract_audio(inputs)
131
- if audio_data is None:
132
- return
133
-
134
- audio_array = audio_data["array"]
135
- sample_rate = audio_data.get("sampling_rate", 16000)
136
-
137
- # Preprocess audio through feature extractor
138
- model_inputs = self.feature_extractor(
139
- audio_array,
140
- sampling_rate=sample_rate,
141
- return_tensors="pt",
142
- return_attention_mask=True,
143
- )
144
-
145
- # Get model dtype and device, cast inputs to match
146
- device = self.model.device
147
- model_dtype = next(self.model.parameters()).dtype
148
- input_features = model_inputs.input_features.to(device, dtype=model_dtype)
149
- attention_mask = model_inputs.attention_mask.to(device)
150
-
151
- # Stream tokens from model
152
- yield from self.model.generate_streaming(
153
- input_features=input_features,
154
- audio_attention_mask=attention_mask,
155
- system_prompt=system_prompt,
156
- )
157
-
158
- def transcribe_streaming_with_audio(
159
- self,
160
- inputs: Union[str, bytes, np.ndarray, dict],
161
- voice: str = DEFAULT_TTS_VOICE,
162
- system_prompt: str | None = None,
163
- ) -> Iterator[dict[str, Any]]:
164
- """Transcribe audio with streaming text AND audio output.
165
-
166
- Yields partial text as tokens are generated, and audio chunks
167
- as complete sentences are detected. This enables low-latency
168
- voice agents that can start speaking before transcription completes.
169
-
170
- Args:
171
- inputs: Audio input (same formats as transcribe_streaming)
172
- voice: Kokoro TTS voice ID
173
- system_prompt: Optional system prompt override (uses model's default if not provided)
174
-
175
- Yields:
176
- Dicts with either:
177
- - {"type": "text", "text": str, "interim": bool} for text updates
178
- - {"type": "audio", "audio": np.ndarray, "sample_rate": int} for audio chunks
179
-
180
- Example:
181
- >>> for chunk in pipeline.transcribe_streaming_with_audio(audio):
182
- ... if chunk["type"] == "text":
183
- ... print(chunk["text"], end="", flush=True)
184
- ... elif chunk["type"] == "audio":
185
- ... play_audio(chunk["audio"], chunk["sample_rate"])
186
- """
187
- import re
188
-
189
- sentence_buffer = ""
190
- full_text = ""
191
-
192
- # Sentence-ending patterns (handles ., !, ?, and common abbreviations)
193
- sentence_end_pattern = re.compile(r"[.!?](?:\s|$)")
194
-
195
- for token_text in self.transcribe_streaming(inputs, system_prompt=system_prompt):
196
- full_text += token_text
197
- sentence_buffer += token_text
198
-
199
- # Yield text update
200
- yield {"type": "text", "text": full_text, "interim": True}
201
-
202
- # Check for complete sentence
203
- match = sentence_end_pattern.search(sentence_buffer)
204
- if match:
205
- # Extract complete sentence(s)
206
- end_pos = match.end()
207
- complete_text = sentence_buffer[:end_pos].strip()
208
- sentence_buffer = sentence_buffer[end_pos:]
209
-
210
- # Generate audio for the complete sentence
211
- if complete_text:
212
- try:
213
- tts_result = self.text_to_speech(complete_text, voice=voice)
214
- if tts_result["audio"] is not None and len(tts_result["audio"]) > 0:
215
- yield {
216
- "type": "audio",
217
- "audio": tts_result["audio"],
218
- "sample_rate": tts_result["sample_rate"],
219
- }
220
- except Exception:
221
- pass # Skip audio on TTS errors
222
-
223
- # Final text update (not interim)
224
- yield {"type": "text", "text": full_text, "interim": False}
225
-
226
- # Generate audio for any remaining text
227
- remaining = sentence_buffer.strip()
228
- if remaining:
229
- try:
230
- tts_result = self.text_to_speech(remaining, voice=voice)
231
- if tts_result["audio"] is not None and len(tts_result["audio"]) > 0:
232
- yield {
233
- "type": "audio",
234
- "audio": tts_result["audio"],
235
- "sample_rate": tts_result["sample_rate"],
236
- }
237
- except Exception:
238
- pass
239
-
240
  def _sanitize_parameters(self, **kwargs):
241
  """Intercept our custom parameters before parent class validates them."""
242
  # Remove our custom parameters so parent doesn't see them
@@ -247,7 +111,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
247
  kwargs.pop("max_speakers", None)
248
  kwargs.pop("hf_token", None)
249
  kwargs.pop("user_prompt", None)
250
- kwargs.pop("system_prompt", None)
251
  kwargs.pop("diarization_backend", None)
252
  # TTS parameters
253
  kwargs.pop("return_audio", None)
@@ -269,7 +132,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
269
  return_audio: If True, synthesize transcription as speech using Kokoro TTS
270
  tts_voice: Kokoro voice ID for TTS output (default: "af_heart")
271
  user_prompt: Custom transcription prompt (default: "Transcribe: ")
272
- system_prompt: Custom system prompt override (uses model's default if not provided)
273
  num_speakers: Exact number of speakers (if known, for diarization)
274
  min_speakers: Minimum number of speakers (for diarization)
275
  max_speakers: Maximum number of speakers (for diarization)
@@ -286,7 +148,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
286
  return_audio = kwargs.pop("return_audio", False)
287
  tts_voice = kwargs.pop("tts_voice", DEFAULT_TTS_VOICE)
288
  user_prompt = kwargs.pop("user_prompt", None)
289
- system_prompt = kwargs.pop("system_prompt", None)
290
  diarization_params = {
291
  "num_speakers": kwargs.pop("num_speakers", None),
292
  "min_speakers": kwargs.pop("min_speakers", None),
@@ -302,12 +163,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
302
  original_prompt = self.model.TRANSCRIBE_PROMPT
303
  self.model.TRANSCRIBE_PROMPT = user_prompt
304
 
305
- # Set custom system prompt if provided
306
- original_system_prompt = None
307
- if system_prompt:
308
- original_system_prompt = self.model.system_prompt
309
- self.model.system_prompt = system_prompt
310
-
311
  # Store audio for timestamp alignment and diarization
312
  if return_timestamps or return_speakers:
313
  self._current_audio = self._extract_audio(inputs)
@@ -369,25 +224,11 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
369
  self._current_audio = None
370
  if original_prompt is not None:
371
  self.model.TRANSCRIBE_PROMPT = original_prompt
372
- if original_system_prompt is not None:
373
- self.model.system_prompt = original_system_prompt
374
 
375
  return result
376
 
377
  def _extract_audio(self, inputs) -> dict | None:
378
- """Extract audio array from various input formats.
379
-
380
- Supported input formats:
381
- - str: File path to audio file
382
- - bytes: Encoded audio (mp3, wav, etc.) - decoded via ffmpeg
383
- - np.ndarray: Audio samples as float32 array
384
- - dict with "array": Audio samples as numpy array
385
- - dict with "raw": Alias for "array" (HF pipeline compat)
386
- - dict with "raw_bytes": Raw PCM bytes (requires "dtype", optional "sampling_rate")
387
-
388
- For raw PCM bytes (e.g., from pipecat), use:
389
- {"raw_bytes": pcm_bytes, "dtype": "int16", "sampling_rate": 16000}
390
- """
391
  from transformers.pipelines.audio_utils import ffmpeg_read
392
 
393
  if isinstance(inputs, dict):
@@ -401,17 +242,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
401
  "array": inputs["raw"],
402
  "sampling_rate": inputs.get("sampling_rate", 16000),
403
  }
404
- if "raw_bytes" in inputs:
405
- # Raw PCM bytes - convert to float32 array
406
- dtype = inputs.get("dtype", "int16")
407
- sample_rate = inputs.get("sampling_rate", 16000)
408
- audio = np.frombuffer(inputs["raw_bytes"], dtype=dtype).astype(np.float32)
409
- # Normalize based on dtype
410
- if dtype == "int16":
411
- audio = audio / 32768.0
412
- elif dtype == "int32":
413
- audio = audio / 2147483648.0
414
- return {"array": audio, "sampling_rate": sample_rate}
415
  elif isinstance(inputs, str):
416
  # File path - load audio using ffmpeg (same as HF pipeline)
417
  with Path(inputs).open("rb") as f:
 
2
 
3
  import re
4
  from pathlib import Path
5
+ from typing import Any
6
 
7
  import numpy as np
8
  import torch
 
101
  audio = np.concatenate(audio_chunks) if audio_chunks else np.array([], dtype=np.float32)
102
  return {"audio": audio, "sample_rate": TTS_SAMPLE_RATE}
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  def _sanitize_parameters(self, **kwargs):
105
  """Intercept our custom parameters before parent class validates them."""
106
  # Remove our custom parameters so parent doesn't see them
 
111
  kwargs.pop("max_speakers", None)
112
  kwargs.pop("hf_token", None)
113
  kwargs.pop("user_prompt", None)
 
114
  kwargs.pop("diarization_backend", None)
115
  # TTS parameters
116
  kwargs.pop("return_audio", None)
 
132
  return_audio: If True, synthesize transcription as speech using Kokoro TTS
133
  tts_voice: Kokoro voice ID for TTS output (default: "af_heart")
134
  user_prompt: Custom transcription prompt (default: "Transcribe: ")
 
135
  num_speakers: Exact number of speakers (if known, for diarization)
136
  min_speakers: Minimum number of speakers (for diarization)
137
  max_speakers: Maximum number of speakers (for diarization)
 
148
  return_audio = kwargs.pop("return_audio", False)
149
  tts_voice = kwargs.pop("tts_voice", DEFAULT_TTS_VOICE)
150
  user_prompt = kwargs.pop("user_prompt", None)
 
151
  diarization_params = {
152
  "num_speakers": kwargs.pop("num_speakers", None),
153
  "min_speakers": kwargs.pop("min_speakers", None),
 
163
  original_prompt = self.model.TRANSCRIBE_PROMPT
164
  self.model.TRANSCRIBE_PROMPT = user_prompt
165
 
 
 
 
 
 
 
166
  # Store audio for timestamp alignment and diarization
167
  if return_timestamps or return_speakers:
168
  self._current_audio = self._extract_audio(inputs)
 
224
  self._current_audio = None
225
  if original_prompt is not None:
226
  self.model.TRANSCRIBE_PROMPT = original_prompt
 
 
227
 
228
  return result
229
 
230
  def _extract_audio(self, inputs) -> dict | None:
231
+ """Extract audio array from various input formats using HF utilities."""
 
 
 
 
 
 
 
 
 
 
 
 
232
  from transformers.pipelines.audio_utils import ffmpeg_read
233
 
234
  if isinstance(inputs, dict):
 
242
  "array": inputs["raw"],
243
  "sampling_rate": inputs.get("sampling_rate", 16000),
244
  }
 
 
 
 
 
 
 
 
 
 
 
245
  elif isinstance(inputs, str):
246
  # File path - load audio using ffmpeg (same as HF pipeline)
247
  with Path(inputs).open("rb") as f: