mazesmazes committed on
Commit
072895c
·
verified ·
1 Parent(s): 2581718

Training in progress - step 1000

Browse files
Files changed (2) hide show
  1. asr_config.py +91 -131
  2. asr_pipeline.py +80 -3
asr_config.py CHANGED
@@ -4,139 +4,120 @@ import transformers
4
 
5
 
6
  class ASRConfig(transformers.PretrainedConfig):
7
- """Configuration class for the ASR model.
8
-
9
- This config combines settings for:
10
- - Audio encoder (GLM-ASR/Whisper)
11
- - Text decoder (Qwen)
12
- - Projector (MLP, MOSA, MoE, QFormer)
13
- - Generation parameters
14
- - Training options (SpecAugment, LoRA)
15
- """
16
 
17
  model_type = "asr_model"
18
  is_composition = True
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def __init__(
21
  self,
 
22
  audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
23
  text_model_id: str = "Qwen/Qwen3-0.6B",
 
24
  attn_implementation: str = "sdpa",
25
  model_dtype: str = "bfloat16",
26
- num_beams: Optional[int] = None,
27
  system_prompt: str = "You are a helpful assistant.",
 
 
28
  encoder_dim: Optional[int] = None,
29
  llm_dim: Optional[int] = None,
30
- # Encoder conv layers: list of (padding, kernel_size, stride) tuples
31
- # Default is Whisper/GLM-ASR structure: conv1(k=3,s=1,p=1) + conv2(k=3,s=2,p=1)
32
  encoder_conv_layers: Optional[list] = None,
33
  audio_sample_rate: int = 16000,
 
 
34
  projector_pool_stride: int = 4,
35
- downsample_rate: int = 5, # Granite default
36
  projector_hidden_dim: Optional[int] = None,
37
- projector_type: str = "mlp", # "mlp", "mosa", "moe", "qformer"
38
- projector_num_layers: int = 2, # Number of layers in MLP projector
39
- projector_init_std: float = 0.02, # Weight initialization std
40
- projector_dropout: float = 0.0, # Dropout rate for projector layers
41
- # MoE-specific configuration
42
- num_experts: int = 4, # Number of experts in MoE projectors
43
- num_experts_per_tok: int = 2, # Top-k experts per token
44
- router_aux_loss_coef: float = 0.01, # Auxiliary loss coefficient for load balancing
45
- # QFormer-specific configuration (Granite defaults)
46
- qformer_window_size: int = 15, # Window size for QFormer processing
47
- qformer_hidden_size: Optional[int] = None, # QFormer hidden size (defaults to encoder_dim)
48
- qformer_num_layers: int = 2, # Number of QFormer transformer layers
49
- qformer_num_heads: int = 16, # Number of attention heads in QFormer
50
- qformer_intermediate_size: Optional[int] = None, # FFN size (defaults to 4x hidden)
51
- label_smoothing: float = 0.0, # Label smoothing for cross-entropy loss
52
- inference_warmup_tokens: int = 10,
53
- # SpecAugment settings
54
  use_specaugment: bool = False,
55
  num_time_masks: int = 2,
56
  time_mask_length: int = 10,
57
  num_freq_masks: int = 0,
58
  freq_mask_length: int = 10,
59
- # LoRA configuration (for Stage 2 fine-tuning)
60
  use_lora: bool = False,
61
- lora_rank: int = 8, # SALMONN default
62
- lora_alpha: int = 32, # SALMONN default (scaling factor 4.0)
63
  lora_dropout: float = 0.0,
64
- lora_target_modules: Optional[list] = None, # Default: all linear layers
65
- freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
66
- do_sample: bool = False,
67
- enable_thinking: bool = False, # Enable Qwen3 thinking mode for omni models
68
- temperature: Optional[float] = None,
69
- top_p: Optional[float] = None,
70
- top_k: Optional[int] = None,
71
- max_new_tokens: Optional[int] = None,
72
- min_new_tokens: Optional[int] = None,
73
- repetition_penalty: Optional[float] = None,
74
- length_penalty: Optional[float] = None,
75
- no_repeat_ngram_size: Optional[int] = None,
76
- use_cache: Optional[bool] = None,
77
  **kwargs,
78
  ):
79
- """Initialize ASR model configuration.
80
-
81
- Args:
82
- audio_model_id: HuggingFace model ID for audio encoder (GLM-ASR/Whisper)
83
- text_model_id: HuggingFace model ID for text decoder (Qwen)
84
- attn_implementation: Attention implementation ("sdpa", "flash_attention_2", "eager")
85
- model_dtype: Model dtype ("bfloat16", "float16", "float32")
86
- projector_type: Projector architecture ("mlp", "mosa", "moe", "qformer")
87
- use_lora: Enable LoRA adapters for Stage 2 fine-tuning
88
- use_specaugment: Enable SpecAugment data augmentation
89
- """
90
- # Set default generation parameters (greedy decoding only)
91
- generation_defaults = {
92
- "num_beams": 1,
93
- "max_new_tokens": 128,
94
- "min_new_tokens": 0,
95
- "repetition_penalty": 1.0,
96
- "length_penalty": 1.0,
97
- "no_repeat_ngram_size": 0, # Prevent repeating 3-grams like "so so so"
98
- "use_cache": True,
99
- }
100
-
101
- # Apply defaults (config.json values take precedence)
102
- kwargs = {**generation_defaults, **kwargs}
103
 
 
104
  self.audio_model_id = audio_model_id
105
  self.text_model_id = text_model_id
106
  self.attn_implementation = attn_implementation
107
  self.model_dtype = model_dtype
108
  self.system_prompt = system_prompt
 
 
 
109
  self.encoder_dim = encoder_dim
110
  self.llm_dim = llm_dim
111
- # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
112
  self.encoder_conv_layers = encoder_conv_layers or [(1, 3, 1), (1, 3, 2)]
113
  self.audio_sample_rate = audio_sample_rate
114
- self.projector_init_std = projector_init_std
 
 
115
  self.projector_pool_stride = projector_pool_stride
116
- self.downsample_rate = downsample_rate
117
  self.projector_hidden_dim = projector_hidden_dim
118
- self.projector_type = projector_type
119
  self.projector_num_layers = projector_num_layers
 
120
  self.projector_dropout = projector_dropout
121
- # MoE-specific configuration
 
122
  self.num_experts = num_experts
123
  self.num_experts_per_tok = num_experts_per_tok
124
  self.router_aux_loss_coef = router_aux_loss_coef
125
- # QFormer-specific configuration
 
126
  self.qformer_window_size = qformer_window_size
127
  self.qformer_hidden_size = qformer_hidden_size
128
  self.qformer_num_layers = qformer_num_layers
129
  self.qformer_num_heads = qformer_num_heads
130
  self.qformer_intermediate_size = qformer_intermediate_size
131
- self.label_smoothing = label_smoothing
132
- self.inference_warmup_tokens = inference_warmup_tokens
133
- # SpecAugment configuration
134
  self.use_specaugment = use_specaugment
135
  self.num_time_masks = num_time_masks
136
  self.time_mask_length = time_mask_length
137
  self.num_freq_masks = num_freq_masks
138
  self.freq_mask_length = freq_mask_length
139
- # LoRA configuration
140
  self.use_lora = use_lora
141
  self.lora_rank = lora_rank
142
  self.lora_alpha = lora_alpha
@@ -151,69 +132,48 @@ class ASRConfig(transformers.PretrainedConfig):
151
  "down_proj",
152
  ]
153
  self.freeze_projector = freeze_projector
 
154
 
155
- # Generation parameters (use explicit value if provided, else use default)
156
- self.num_beams = num_beams if num_beams is not None else generation_defaults["num_beams"]
157
- self.max_new_tokens = (
158
- max_new_tokens if max_new_tokens is not None else generation_defaults["max_new_tokens"]
159
- )
160
- self.min_new_tokens = (
161
- min_new_tokens if min_new_tokens is not None else generation_defaults["min_new_tokens"]
162
- )
163
- self.repetition_penalty = (
164
- repetition_penalty
165
- if repetition_penalty is not None
166
- else generation_defaults["repetition_penalty"]
167
- )
168
- self.length_penalty = (
169
- length_penalty if length_penalty is not None else generation_defaults["length_penalty"]
170
- )
171
- self.no_repeat_ngram_size = (
172
- no_repeat_ngram_size
173
- if no_repeat_ngram_size is not None
174
- else generation_defaults["no_repeat_ngram_size"]
175
- )
176
- self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
177
- self.do_sample = do_sample
178
- self.enable_thinking = enable_thinking
179
- self.temperature = temperature
180
- self.top_p = top_p
181
- self.top_k = top_k
182
-
183
- if "audio_config" not in kwargs:
184
  self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
185
- # Override dtype to match model_dtype
186
  self.audio_config.dtype = model_dtype
187
- else:
188
- self.audio_config = kwargs.pop("audio_config")
189
-
190
- if "text_config" not in kwargs:
 
 
 
 
191
  self.text_config = transformers.AutoConfig.from_pretrained(
192
  text_model_id, trust_remote_code=True
193
  )
194
- # Override dtype to match model_dtype
195
  self.text_config.dtype = model_dtype
196
- else:
197
- self.text_config = kwargs.pop("text_config")
198
-
199
- if isinstance(self.text_config, dict):
200
- # Reconstruct config from dict using the model_type stored in the dict
201
- model_type = self.text_config["model_type"]
202
- config_class = transformers.AutoConfig.for_model(model_type).__class__
203
  self.text_config = config_class(**self.text_config)
204
 
205
- if isinstance(self.audio_config, dict):
206
- model_type = self.audio_config.get("model_type")
207
- if model_type:
208
- config_class = transformers.AutoConfig.for_model(model_type).__class__
209
- self.audio_config = config_class(**self.audio_config)
210
-
211
  super().__init__(**kwargs)
212
 
213
- # Point encoder to audio_config so pipeline uses correct feature extractor
214
- # The pipeline looks for config.encoder._name_or_path for feature extractor
215
  self.encoder = self.audio_config
216
-
217
  self.auto_map = {
218
  "AutoConfig": "asr_config.ASRConfig",
219
  "AutoModel": "asr_modeling.ASRModel",
 
4
 
5
 
6
  class ASRConfig(transformers.PretrainedConfig):
7
+ """Configuration class for the ASR model."""
 
 
 
 
 
 
 
 
8
 
9
  model_type = "asr_model"
10
  is_composition = True
11
 
12
+ # Generation defaults
13
+ GENERATION_DEFAULTS = {
14
+ "num_beams": 1,
15
+ "max_new_tokens": 128,
16
+ "min_new_tokens": 0,
17
+ "repetition_penalty": 1.0,
18
+ "length_penalty": 1.0,
19
+ "no_repeat_ngram_size": 0,
20
+ "use_cache": True,
21
+ "do_sample": False,
22
+ "temperature": None,
23
+ "top_p": None,
24
+ "top_k": None,
25
+ }
26
+
27
  def __init__(
28
  self,
29
+ # Model IDs
30
  audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
31
  text_model_id: str = "Qwen/Qwen3-0.6B",
32
+ # Model settings
33
  attn_implementation: str = "sdpa",
34
  model_dtype: str = "bfloat16",
 
35
  system_prompt: str = "You are a helpful assistant.",
36
+ enable_thinking: bool = False,
37
+ # Encoder settings (auto-detected if None)
38
  encoder_dim: Optional[int] = None,
39
  llm_dim: Optional[int] = None,
 
 
40
  encoder_conv_layers: Optional[list] = None,
41
  audio_sample_rate: int = 16000,
42
+ # Projector settings
43
+ projector_type: str = "mlp",
44
  projector_pool_stride: int = 4,
 
45
  projector_hidden_dim: Optional[int] = None,
46
+ projector_num_layers: int = 2,
47
+ projector_init_std: float = 0.02,
48
+ projector_dropout: float = 0.0,
49
+ # MoE projector settings
50
+ num_experts: int = 4,
51
+ num_experts_per_tok: int = 2,
52
+ router_aux_loss_coef: float = 0.01,
53
+ # QFormer projector settings
54
+ qformer_window_size: int = 15,
55
+ qformer_hidden_size: Optional[int] = None,
56
+ qformer_num_layers: int = 2,
57
+ qformer_num_heads: int = 16,
58
+ qformer_intermediate_size: Optional[int] = None,
59
+ downsample_rate: int = 5,
60
+ # Training settings (not saved to config.json for inference)
 
 
61
  use_specaugment: bool = False,
62
  num_time_masks: int = 2,
63
  time_mask_length: int = 10,
64
  num_freq_masks: int = 0,
65
  freq_mask_length: int = 10,
 
66
  use_lora: bool = False,
67
+ lora_rank: int = 8,
68
+ lora_alpha: int = 32,
69
  lora_dropout: float = 0.0,
70
+ lora_target_modules: Optional[list] = None,
71
+ freeze_projector: bool = False,
72
+ label_smoothing: float = 0.0,
 
 
 
 
 
 
 
 
 
 
73
  **kwargs,
74
  ):
75
+ # Merge generation defaults with kwargs (kwargs takes precedence)
76
+ for key, default in self.GENERATION_DEFAULTS.items():
77
+ if key not in kwargs:
78
+ kwargs[key] = default
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
+ # Core model settings
81
  self.audio_model_id = audio_model_id
82
  self.text_model_id = text_model_id
83
  self.attn_implementation = attn_implementation
84
  self.model_dtype = model_dtype
85
  self.system_prompt = system_prompt
86
+ self.enable_thinking = enable_thinking
87
+
88
+ # Encoder settings
89
  self.encoder_dim = encoder_dim
90
  self.llm_dim = llm_dim
 
91
  self.encoder_conv_layers = encoder_conv_layers or [(1, 3, 1), (1, 3, 2)]
92
  self.audio_sample_rate = audio_sample_rate
93
+
94
+ # Projector settings
95
+ self.projector_type = projector_type
96
  self.projector_pool_stride = projector_pool_stride
 
97
  self.projector_hidden_dim = projector_hidden_dim
 
98
  self.projector_num_layers = projector_num_layers
99
+ self.projector_init_std = projector_init_std
100
  self.projector_dropout = projector_dropout
101
+
102
+ # MoE settings
103
  self.num_experts = num_experts
104
  self.num_experts_per_tok = num_experts_per_tok
105
  self.router_aux_loss_coef = router_aux_loss_coef
106
+
107
+ # QFormer settings
108
  self.qformer_window_size = qformer_window_size
109
  self.qformer_hidden_size = qformer_hidden_size
110
  self.qformer_num_layers = qformer_num_layers
111
  self.qformer_num_heads = qformer_num_heads
112
  self.qformer_intermediate_size = qformer_intermediate_size
113
+ self.downsample_rate = downsample_rate
114
+
115
+ # Training settings
116
  self.use_specaugment = use_specaugment
117
  self.num_time_masks = num_time_masks
118
  self.time_mask_length = time_mask_length
119
  self.num_freq_masks = num_freq_masks
120
  self.freq_mask_length = freq_mask_length
 
121
  self.use_lora = use_lora
122
  self.lora_rank = lora_rank
123
  self.lora_alpha = lora_alpha
 
132
  "down_proj",
133
  ]
134
  self.freeze_projector = freeze_projector
135
+ self.label_smoothing = label_smoothing
136
 
137
+ # Generation parameters (from kwargs after merge with defaults)
138
+ self.num_beams = kwargs.pop("num_beams")
139
+ self.max_new_tokens = kwargs.pop("max_new_tokens")
140
+ self.min_new_tokens = kwargs.pop("min_new_tokens")
141
+ self.repetition_penalty = kwargs.pop("repetition_penalty")
142
+ self.length_penalty = kwargs.pop("length_penalty")
143
+ self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size")
144
+ self.use_cache = kwargs.pop("use_cache")
145
+ self.do_sample = kwargs.pop("do_sample")
146
+ self.temperature = kwargs.pop("temperature")
147
+ self.top_p = kwargs.pop("top_p")
148
+ self.top_k = kwargs.pop("top_k")
149
+
150
+ # Load sub-configs
151
+ self.audio_config = kwargs.pop("audio_config", None)
152
+ if self.audio_config is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
 
154
  self.audio_config.dtype = model_dtype
155
+ elif isinstance(self.audio_config, dict) and self.audio_config.get("model_type"):
156
+ config_class = transformers.AutoConfig.for_model(
157
+ self.audio_config["model_type"]
158
+ ).__class__
159
+ self.audio_config = config_class(**self.audio_config)
160
+
161
+ self.text_config = kwargs.pop("text_config", None)
162
+ if self.text_config is None:
163
  self.text_config = transformers.AutoConfig.from_pretrained(
164
  text_model_id, trust_remote_code=True
165
  )
 
166
  self.text_config.dtype = model_dtype
167
+ elif isinstance(self.text_config, dict):
168
+ config_class = transformers.AutoConfig.for_model(
169
+ self.text_config["model_type"]
170
+ ).__class__
 
 
 
171
  self.text_config = config_class(**self.text_config)
172
 
 
 
 
 
 
 
173
  super().__init__(**kwargs)
174
 
175
+ # Pipeline configuration
 
176
  self.encoder = self.audio_config
 
177
  self.auto_map = {
178
  "AutoConfig": "asr_config.ASRConfig",
179
  "AutoModel": "asr_modeling.ASRModel",
asr_pipeline.py CHANGED
@@ -18,7 +18,26 @@ except ImportError:
18
  from diarization import SpeakerDiarizer # type: ignore[no-redef]
19
 
20
  # Re-export for backwards compatibility
21
- __all__ = ["ForcedAligner", "SpeakerDiarizer", "ASRPipeline"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
 
24
  class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
@@ -43,6 +62,44 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
43
  model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
44
  )
45
  self._current_audio = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  def _sanitize_parameters(self, **kwargs):
48
  """Intercept our custom parameters before parent class validates them."""
@@ -55,6 +112,9 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
55
  kwargs.pop("hf_token", None)
56
  kwargs.pop("user_prompt", None)
57
  kwargs.pop("diarization_backend", None)
 
 
 
58
 
59
  return super()._sanitize_parameters(**kwargs)
60
 
@@ -69,6 +129,8 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
69
  inputs: Audio input (file path, dict with array/sampling_rate, etc.)
70
  return_timestamps: If True, return word-level timestamps using forced alignment
71
  return_speakers: If True, return speaker labels for each word
 
 
72
  user_prompt: Custom transcription prompt (default: "Transcribe: ")
73
  num_speakers: Exact number of speakers (if known, for diarization)
74
  min_speakers: Minimum number of speakers (for diarization)
@@ -77,11 +139,14 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
77
 
78
  Returns:
79
  Dict with 'text' key, 'words' key if return_timestamps=True,
80
- and speaker labels on words if return_speakers=True
 
81
  """
82
  # Extract our params before super().__call__ (which will also call _sanitize_parameters)
83
  return_timestamps = kwargs.pop("return_timestamps", False)
84
  return_speakers = kwargs.pop("return_speakers", False)
 
 
85
  user_prompt = kwargs.pop("user_prompt", None)
86
  diarization_params = {
87
  "num_speakers": kwargs.pop("num_speakers", None),
@@ -143,6 +208,18 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
143
  result["speaker_segments"] = []
144
  result["diarization_error"] = str(e)
145
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  # Clean up
147
  self._current_audio = None
148
  if original_prompt is not None:
@@ -257,7 +334,7 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
257
 
258
  text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
259
  # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
260
- text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
261
  # Truncate repetitions at end of text
262
  text = _truncate_repetitions(text)
263
  return {"text": text}
 
18
  from diarization import SpeakerDiarizer # type: ignore[no-redef]
19
 
20
  # Re-export for backwards compatibility
21
+ __all__ = ["ForcedAligner", "SpeakerDiarizer", "ASRPipeline", "strip_thinking"]
22
+
23
+ # Default TTS voice for Kokoro
24
+ DEFAULT_TTS_VOICE = "af_heart"
25
+ TTS_SAMPLE_RATE = 24000
26
+
27
+
28
+ def strip_thinking(text: str) -> str:
29
+ """Remove <think>...</think> tags from model output.
30
+
31
+ Args:
32
+ text: Model output text that may contain thinking tags
33
+
34
+ Returns:
35
+ Text with thinking content removed
36
+ """
37
+ if not text:
38
+ return text
39
+ text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL)
40
+ return text.strip()
41
 
42
 
43
  class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
 
62
  model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
63
  )
64
  self._current_audio = None
65
+ self._tts_pipeline = None
66
+
67
+ @property
68
+ def tts_pipeline(self):
69
+ """Lazy-load Kokoro TTS pipeline on first use."""
70
+ if self._tts_pipeline is None:
71
+ try:
72
+ from kokoro import KPipeline
73
+
74
+ self._tts_pipeline = KPipeline(lang_code="a", repo_id="hexgrad/Kokoro-82M")
75
+ except ImportError as e:
76
+ raise ImportError(
77
+ "Kokoro TTS is required for audio output. "
78
+ "Install with: pip install kokoro>=0.9.2\n"
79
+ "Also requires espeak-ng: apt-get install espeak-ng"
80
+ ) from e
81
+ return self._tts_pipeline
82
+
83
+ def text_to_speech(self, text: str, voice: str = DEFAULT_TTS_VOICE) -> dict[str, Any]:
84
+ """Convert text to speech using Kokoro TTS.
85
+
86
+ Args:
87
+ text: Text to synthesize
88
+ voice: Kokoro voice ID (default: "af_heart")
89
+
90
+ Returns:
91
+ Dict with 'audio' (numpy array) and 'sample_rate' keys
92
+ """
93
+ if not text or not text.strip():
94
+ return {"audio": np.array([], dtype=np.float32), "sample_rate": TTS_SAMPLE_RATE}
95
+
96
+ # Generate audio chunks and concatenate
97
+ audio_chunks = []
98
+ for _, _, audio in self.tts_pipeline(text, voice=voice):
99
+ audio_chunks.append(audio)
100
+
101
+ audio = np.concatenate(audio_chunks) if audio_chunks else np.array([], dtype=np.float32)
102
+ return {"audio": audio, "sample_rate": TTS_SAMPLE_RATE}
103
 
104
  def _sanitize_parameters(self, **kwargs):
105
  """Intercept our custom parameters before parent class validates them."""
 
112
  kwargs.pop("hf_token", None)
113
  kwargs.pop("user_prompt", None)
114
  kwargs.pop("diarization_backend", None)
115
+ # TTS parameters
116
+ kwargs.pop("return_audio", None)
117
+ kwargs.pop("tts_voice", None)
118
 
119
  return super()._sanitize_parameters(**kwargs)
120
 
 
129
  inputs: Audio input (file path, dict with array/sampling_rate, etc.)
130
  return_timestamps: If True, return word-level timestamps using forced alignment
131
  return_speakers: If True, return speaker labels for each word
132
+ return_audio: If True, synthesize transcription as speech using Kokoro TTS
133
+ tts_voice: Kokoro voice ID for TTS output (default: "af_heart")
134
  user_prompt: Custom transcription prompt (default: "Transcribe: ")
135
  num_speakers: Exact number of speakers (if known, for diarization)
136
  min_speakers: Minimum number of speakers (for diarization)
 
139
 
140
  Returns:
141
  Dict with 'text' key, 'words' key if return_timestamps=True,
142
+ speaker labels on words if return_speakers=True,
143
+ and 'audio'/'sample_rate' keys if return_audio=True
144
  """
145
  # Extract our params before super().__call__ (which will also call _sanitize_parameters)
146
  return_timestamps = kwargs.pop("return_timestamps", False)
147
  return_speakers = kwargs.pop("return_speakers", False)
148
+ return_audio = kwargs.pop("return_audio", False)
149
+ tts_voice = kwargs.pop("tts_voice", DEFAULT_TTS_VOICE)
150
  user_prompt = kwargs.pop("user_prompt", None)
151
  diarization_params = {
152
  "num_speakers": kwargs.pop("num_speakers", None),
 
208
  result["speaker_segments"] = []
209
  result["diarization_error"] = str(e)
210
 
211
+ # Synthesize transcription as speech if requested
212
+ if return_audio:
213
+ text = result.get("text", "")
214
+ try:
215
+ tts_result = self.text_to_speech(text, voice=tts_voice)
216
+ result["audio"] = tts_result["audio"]
217
+ result["sample_rate"] = tts_result["sample_rate"]
218
+ except Exception as e:
219
+ result["audio"] = np.array([], dtype=np.float32)
220
+ result["sample_rate"] = TTS_SAMPLE_RATE
221
+ result["tts_error"] = str(e)
222
+
223
  # Clean up
224
  self._current_audio = None
225
  if original_prompt is not None:
 
334
 
335
  text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
336
  # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
337
+ text = strip_thinking(text)
338
  # Truncate repetitions at end of text
339
  text = _truncate_repetitions(text)
340
  return {"text": text}