| """ASR pipeline for audio-to-text transcription with optional timestamps and diarization.""" |
|
|
| import re |
| from pathlib import Path |
| from typing import Any, Iterator, Union |
|
|
| import numpy as np |
| import torch |
| import transformers |
|
|
| try: |
| from .alignment import ForcedAligner |
| from .asr_modeling import ASRModel |
| from .diarization import SpeakerDiarizer |
| except ImportError: |
| from alignment import ForcedAligner |
| from asr_modeling import ASRModel |
| from diarization import SpeakerDiarizer |
|
|
| |
| __all__ = ["ForcedAligner", "SpeakerDiarizer", "ASRPipeline", "strip_thinking"] |
|
|
| |
| DEFAULT_TTS_VOICE = "af_heart" |
| TTS_SAMPLE_RATE = 24000 |
|
|
|
|
def strip_thinking(text: str) -> str:
    """Strip any <think>...</think> reasoning blocks from model output.

    Args:
        text: Raw model output that may include thinking tags

    Returns:
        The text with every thinking section (plus the whitespace that
        immediately follows its closing tag) removed, then stripped of
        surrounding whitespace
    """
    if not text:
        return text
    thinking = re.compile(r"<think>.*?</think>\s*", re.DOTALL)
    return thinking.sub("", text).strip()
|
|
|
|
class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
    """ASR Pipeline for audio-to-text transcription."""

    model: ASRModel

    def __init__(self, model: ASRModel, **kwargs):
        """Initialize ASR pipeline.

        Args:
            model: ASRModel instance for transcription
            **kwargs: Additional arguments (feature_extractor, tokenizer, device)
        """
        feature_extractor = kwargs.pop("feature_extractor", None)
        tokenizer = kwargs.pop("tokenizer", model.tokenizer)

        if feature_extractor is None:
            feature_extractor = model.get_processor().feature_extractor

        super().__init__(
            model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
        )
        # Audio cached by __call__ for alignment/diarization; always cleared
        # when the call finishes (success or failure).
        self._current_audio = None
        # Kokoro TTS pipeline, created lazily on first use (see tts_pipeline).
        self._tts_pipeline = None

    @property
    def tts_pipeline(self):
        """Lazy-load Kokoro TTS pipeline on first use.

        Raises:
            ImportError: If the optional ``kokoro`` package is not installed.
        """
        if self._tts_pipeline is None:
            try:
                from kokoro import KPipeline

                self._tts_pipeline = KPipeline(lang_code="a", repo_id="hexgrad/Kokoro-82M")
            except ImportError as e:
                raise ImportError(
                    "Kokoro TTS is required for audio output. "
                    "Install with: pip install kokoro>=0.9.2\n"
                    "Also requires espeak-ng: apt-get install espeak-ng"
                ) from e
        return self._tts_pipeline

    def text_to_speech(self, text: str, voice: str = DEFAULT_TTS_VOICE) -> dict[str, Any]:
        """Convert text to speech using Kokoro TTS.

        Args:
            text: Text to synthesize
            voice: Kokoro voice ID (default: "af_heart")

        Returns:
            Dict with 'audio' (numpy array) and 'sample_rate' keys
        """
        if not text or not text.strip():
            return {"audio": np.array([], dtype=np.float32), "sample_rate": TTS_SAMPLE_RATE}

        # Kokoro yields (graphemes, phonemes, audio) per chunk; collect audio.
        audio_chunks = []
        for _, _, audio in self.tts_pipeline(text, voice=voice):
            audio_chunks.append(audio)

        audio = np.concatenate(audio_chunks) if audio_chunks else np.array([], dtype=np.float32)
        return {"audio": audio, "sample_rate": TTS_SAMPLE_RATE}

    def transcribe_streaming(
        self,
        inputs: Union[str, bytes, np.ndarray, dict],
        system_prompt: str | None = None,
    ) -> Iterator[str]:
        """Transcribe audio with streaming token output for low-latency applications.

        Yields partial transcript strings as tokens are generated, reducing
        time-to-first-word compared to waiting for full transcription.

        Args:
            inputs: Audio input in any supported format:
                - str: File path to audio file
                - bytes: Raw audio bytes
                - np.ndarray: Audio samples as numpy array
                - dict: {"array": np.ndarray, "sampling_rate": int}
            system_prompt: Optional system prompt override (uses model's default if not provided)

        Yields:
            Partial transcript text as each token is generated

        Example:
            >>> for partial in pipeline.transcribe_streaming("audio.wav"):
            ...     print(partial, end="", flush=True)
        """
        audio_data = self._extract_audio(inputs)
        if audio_data is None:
            # Unrecognized input format: yield nothing rather than raising.
            return

        audio_array = audio_data["array"]
        sample_rate = audio_data.get("sampling_rate", 16000)

        model_inputs = self.feature_extractor(
            audio_array,
            sampling_rate=sample_rate,
            return_tensors="pt",
            return_attention_mask=True,
        )

        # Match the model's device and parameter dtype (e.g. fp16/bf16).
        device = self.model.device
        model_dtype = next(self.model.parameters()).dtype
        input_features = model_inputs.input_features.to(device, dtype=model_dtype)
        attention_mask = model_inputs.attention_mask.to(device)

        yield from self.model.generate_streaming(
            input_features=input_features,
            audio_attention_mask=attention_mask,
            system_prompt=system_prompt,
        )

    def transcribe_streaming_with_audio(
        self,
        inputs: Union[str, bytes, np.ndarray, dict],
        voice: str = DEFAULT_TTS_VOICE,
        system_prompt: str | None = None,
    ) -> Iterator[dict[str, Any]]:
        """Transcribe audio with streaming text AND audio output.

        Yields partial text as tokens are generated, and audio chunks
        as complete sentences are detected. This enables low-latency
        voice agents that can start speaking before transcription completes.

        Args:
            inputs: Audio input (same formats as transcribe_streaming)
            voice: Kokoro TTS voice ID
            system_prompt: Optional system prompt override (uses model's default if not provided)

        Yields:
            Dicts with either:
                - {"type": "text", "text": str, "interim": bool} for text updates
                - {"type": "audio", "audio": np.ndarray, "sample_rate": int} for audio chunks

        Example:
            >>> for chunk in pipeline.transcribe_streaming_with_audio(audio):
            ...     if chunk["type"] == "text":
            ...         print(chunk["text"], end="", flush=True)
            ...     elif chunk["type"] == "audio":
            ...         play_audio(chunk["audio"], chunk["sample_rate"])
        """
        sentence_buffer = ""
        full_text = ""

        # A sentence is complete at terminal punctuation followed by
        # whitespace or end-of-buffer.
        sentence_end_pattern = re.compile(r"[.!?](?:\s|$)")

        for token_text in self.transcribe_streaming(inputs, system_prompt=system_prompt):
            full_text += token_text
            sentence_buffer += token_text

            # Interim text update with the transcript so far.
            yield {"type": "text", "text": full_text, "interim": True}

            match = sentence_end_pattern.search(sentence_buffer)
            if match:
                end_pos = match.end()
                complete_text = sentence_buffer[:end_pos].strip()
                sentence_buffer = sentence_buffer[end_pos:]

                if complete_text:
                    audio_event = self._synthesize_chunk(complete_text, voice)
                    if audio_event is not None:
                        yield audio_event

        # Final, non-interim text event.
        yield {"type": "text", "text": full_text, "interim": False}

        # Flush any unterminated trailing text as a last audio chunk.
        remaining = sentence_buffer.strip()
        if remaining:
            audio_event = self._synthesize_chunk(remaining, voice)
            if audio_event is not None:
                yield audio_event

    def _synthesize_chunk(self, text: str, voice: str) -> dict[str, Any] | None:
        """Best-effort TTS for one streaming chunk.

        Returns an audio event dict, or None when synthesis fails or
        produces no samples. TTS failures are deliberately swallowed so
        audio problems never interrupt the text stream.
        """
        try:
            tts_result = self.text_to_speech(text, voice=voice)
        except Exception:
            return None
        audio = tts_result["audio"]
        if audio is not None and len(audio) > 0:
            return {
                "type": "audio",
                "audio": audio,
                "sample_rate": tts_result["sample_rate"],
            }
        return None

    def _sanitize_parameters(self, **kwargs):
        """Intercept our custom parameters before parent class validates them."""
        # These are consumed by __call__ / diarization / TTS, not by the
        # parent pipeline, which would reject unknown keys.
        kwargs.pop("return_timestamps", None)
        kwargs.pop("return_speakers", None)
        kwargs.pop("num_speakers", None)
        kwargs.pop("min_speakers", None)
        kwargs.pop("max_speakers", None)
        kwargs.pop("hf_token", None)
        kwargs.pop("user_prompt", None)
        kwargs.pop("system_prompt", None)
        kwargs.pop("diarization_backend", None)
        kwargs.pop("return_audio", None)
        kwargs.pop("tts_voice", None)

        return super()._sanitize_parameters(**kwargs)

    def __call__(
        self,
        inputs,
        **kwargs,
    ):
        """Transcribe audio with optional word-level timestamps and speaker diarization.

        Args:
            inputs: Audio input (file path, dict with array/sampling_rate, etc.)
            return_timestamps: If True, return word-level timestamps using forced alignment
            return_speakers: If True, return speaker labels for each word
            return_audio: If True, synthesize transcription as speech using Kokoro TTS
            tts_voice: Kokoro voice ID for TTS output (default: "af_heart")
            user_prompt: Custom transcription prompt (default: "Transcribe: ")
            system_prompt: Custom system prompt override (uses model's default if not provided)
            num_speakers: Exact number of speakers (if known, for diarization)
            min_speakers: Minimum number of speakers (for diarization)
            max_speakers: Maximum number of speakers (for diarization)
            **kwargs: Additional arguments passed to the pipeline

        Returns:
            Dict with 'text' key, 'words' key if return_timestamps=True,
            speaker labels on words if return_speakers=True,
            and 'audio'/'sample_rate' keys if return_audio=True
        """
        return_timestamps = kwargs.pop("return_timestamps", False)
        return_speakers = kwargs.pop("return_speakers", False)
        return_audio = kwargs.pop("return_audio", False)
        tts_voice = kwargs.pop("tts_voice", DEFAULT_TTS_VOICE)
        user_prompt = kwargs.pop("user_prompt", None)
        system_prompt = kwargs.pop("system_prompt", None)
        diarization_params = {
            "num_speakers": kwargs.pop("num_speakers", None),
            "min_speakers": kwargs.pop("min_speakers", None),
            "max_speakers": kwargs.pop("max_speakers", None),
        }

        # Speaker assignment needs word timestamps.
        if return_speakers:
            return_timestamps = True

        # Temporarily override the model's prompts. Restoration happens in
        # the `finally` below so a failure mid-call cannot leak the override
        # (or the cached audio) into subsequent calls.
        original_prompt = None
        if user_prompt:
            original_prompt = self.model.TRANSCRIBE_PROMPT
            self.model.TRANSCRIBE_PROMPT = user_prompt

        original_system_prompt = None
        if system_prompt:
            original_system_prompt = self.model.system_prompt
            self.model.system_prompt = system_prompt

        try:
            # Cache the decoded audio so alignment/diarization can reuse it.
            if return_timestamps or return_speakers:
                self._current_audio = self._extract_audio(inputs)

            result = super().__call__(inputs, **kwargs)

            # Word-level timestamps via forced alignment (best-effort).
            if return_timestamps and self._current_audio is not None:
                text = result.get("text", "")
                if text:
                    try:
                        words = ForcedAligner.align(
                            self._current_audio["array"],
                            text,
                            sample_rate=self._current_audio.get("sampling_rate", 16000),
                        )
                        result["words"] = words
                    except Exception as e:
                        result["words"] = []
                        result["timestamp_error"] = str(e)
                else:
                    result["words"] = []

            # Speaker diarization and speaker-to-word assignment (best-effort).
            if return_speakers and self._current_audio is not None:
                try:
                    speaker_segments = SpeakerDiarizer.diarize(
                        self._current_audio["array"],
                        sample_rate=self._current_audio.get("sampling_rate", 16000),
                        **{k: v for k, v in diarization_params.items() if v is not None},
                    )
                    result["speaker_segments"] = speaker_segments

                    if result.get("words"):
                        result["words"] = SpeakerDiarizer.assign_speakers_to_words(
                            result["words"],
                            speaker_segments,
                        )
                except Exception as e:
                    result["speaker_segments"] = []
                    result["diarization_error"] = str(e)

            # Optional TTS of the final transcript (best-effort).
            if return_audio:
                text = result.get("text", "")
                try:
                    tts_result = self.text_to_speech(text, voice=tts_voice)
                    result["audio"] = tts_result["audio"]
                    result["sample_rate"] = tts_result["sample_rate"]
                except Exception as e:
                    result["audio"] = np.array([], dtype=np.float32)
                    result["sample_rate"] = TTS_SAMPLE_RATE
                    result["tts_error"] = str(e)

            return result
        finally:
            # Always restore pipeline/model state, even on error.
            self._current_audio = None
            if original_prompt is not None:
                self.model.TRANSCRIBE_PROMPT = original_prompt
            if original_system_prompt is not None:
                self.model.system_prompt = original_system_prompt

    def _extract_audio(self, inputs) -> dict | None:
        """Extract audio array from various input formats.

        Supported input formats:
            - str: File path to audio file
            - bytes: Encoded audio (mp3, wav, etc.) - decoded via ffmpeg
            - np.ndarray: Audio samples as float32 array
            - dict with "array": Audio samples as numpy array
            - dict with "raw": Alias for "array" (HF pipeline compat)
            - dict with "raw_bytes": Raw PCM bytes (requires "dtype", optional "sampling_rate")

        For raw PCM bytes (e.g., from pipecat), use:
            {"raw_bytes": pcm_bytes, "dtype": "int16", "sampling_rate": 16000}

        Returns:
            Dict with "array" and "sampling_rate" keys, or None for
            unrecognized input types.
        """
        from transformers.pipelines.audio_utils import ffmpeg_read

        if isinstance(inputs, dict):
            if "array" in inputs:
                return {
                    "array": inputs["array"],
                    "sampling_rate": inputs.get("sampling_rate", 16000),
                }
            if "raw" in inputs:
                return {
                    "array": inputs["raw"],
                    "sampling_rate": inputs.get("sampling_rate", 16000),
                }
            if "raw_bytes" in inputs:
                dtype = inputs.get("dtype", "int16")
                sample_rate = inputs.get("sampling_rate", 16000)
                audio = np.frombuffer(inputs["raw_bytes"], dtype=dtype).astype(np.float32)
                # Normalize integer PCM to [-1.0, 1.0).
                if dtype == "int16":
                    audio = audio / 32768.0
                elif dtype == "int32":
                    audio = audio / 2147483648.0
                return {"array": audio, "sampling_rate": sample_rate}
        elif isinstance(inputs, str):
            # File path: decode via ffmpeg at the model's 16 kHz rate.
            with Path(inputs).open("rb") as f:
                audio = ffmpeg_read(f.read(), sampling_rate=16000)
            return {"array": audio, "sampling_rate": 16000}
        elif isinstance(inputs, bytes):
            audio = ffmpeg_read(inputs, sampling_rate=16000)
            return {"array": audio, "sampling_rate": 16000}
        elif isinstance(inputs, np.ndarray):
            # Assumed to already be 16 kHz samples — TODO confirm with callers.
            return {"array": inputs, "sampling_rate": 16000}

        return None

    def preprocess(self, inputs, **preprocess_params):
        """Preprocess audio inputs for the model.

        Args:
            inputs: Audio input (dict with array, file path, etc.)
            **preprocess_params: Additional preprocessing parameters

        Yields:
            Model input dicts with input_features and attention_mask
        """
        # Map our "array" dict format onto the parent's expected "raw" key.
        if isinstance(inputs, dict) and "array" in inputs:
            inputs = {
                "raw": inputs["array"],
                "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
            }

        for item in super().preprocess(inputs, **preprocess_params):
            # The parent pipeline expects an is_last flag on each chunk.
            if "is_last" not in item:
                item["is_last"] = True
            yield item

    def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
        """Run model forward pass to generate transcription.

        Args:
            model_inputs: Dict with input_features and attention_mask
            **generate_kwargs: Generation parameters

        Returns:
            Dict with generated token IDs
        """
        is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True

        input_features = model_inputs["input_features"].to(self.model.device)
        audio_attention_mask = model_inputs["attention_mask"].to(self.model.device)

        generated_ids = self.model.generate(
            input_features=input_features,
            audio_attention_mask=audio_attention_mask,
            **generate_kwargs,
        )

        return {"tokens": generated_ids, "is_last": is_last}

    def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
        """Convert model output tokens to text.

        Args:
            model_outputs: Dict with 'tokens' key containing generated IDs
            **kwargs: Additional postprocessing parameters

        Returns:
            Dict with 'text' key containing transcription
        """
        if isinstance(model_outputs, list):
            model_outputs = model_outputs[0] if model_outputs else {}

        tokens = model_outputs.get("tokens")
        if tokens is None:
            # No token path — defer to the parent pipeline's postprocessing.
            return super().postprocess(model_outputs, **kwargs)

        # Normalize to a plain list of ids so EOS filtering works for both
        # tensor and list outputs (the old code crashed on list inputs).
        if torch.is_tensor(tokens):
            tokens = tokens.cpu()
            if tokens.dim() > 1:
                tokens = tokens[0]
            token_list = tokens.tolist()
        else:
            token_list = list(tokens)

        # Drop EOS ids explicitly; skip_special_tokens alone may not cover
        # model-specific EOS ids set on generation_config.
        if hasattr(self, "model") and hasattr(self.model, "generation_config"):
            eos_ids = self.model.generation_config.eos_token_id
            if eos_ids is not None:
                eos_set = set(eos_ids) if isinstance(eos_ids, list) else {eos_ids}
                token_list = [t for t in token_list if t not in eos_set]

        text = self.tokenizer.decode(token_list, skip_special_tokens=True).strip()
        # Remove <think>...</think> reasoning blocks, then cut off runaway
        # trailing repetitions (a common ASR failure mode).
        text = strip_thinking(text)
        text = _truncate_repetitions(text)
        return {"text": text}
|
|
|
|
| def _truncate_repetitions(text: str, min_repeats: int = 3) -> str: |
| """Truncate repeated words/phrases/characters at end of text. |
| |
| Detects patterns like: |
| - Repeated words: "the the the the" -> "the" |
| - Repeated phrases: "i am sorry i am sorry i am sorry" -> "i am sorry" |
| - Repeated characters: "444444" -> "4" |
| |
| Args: |
| text: Input text to process |
| min_repeats: Minimum repetitions to trigger truncation (default 3) |
| |
| Returns: |
| Text with trailing repetitions removed |
| """ |
| if not text: |
| return text |
|
|
| |
| char_pattern = re.compile(r"(.)\1{" + str(min_repeats - 1) + r",}$") |
| text = char_pattern.sub(r"\1", text) |
|
|
| |
| word_pattern = re.compile( |
| r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$", re.IGNORECASE |
| ) |
| while word_pattern.search(text): |
| text = word_pattern.sub(r"\1", text) |
|
|
| |
| |
| words = text.split() |
| if len(words) >= min_repeats * 2: |
| |
| for phrase_len in range(2, min(21, len(words) // min_repeats + 1)): |
| |
| phrase = " ".join(words[-phrase_len:]) |
| |
| phrase_escaped = re.escape(phrase) |
| phrase_pattern = re.compile( |
| r"(^|.*?\s)(" |
| + phrase_escaped |
| + r")(?:\s+" |
| + phrase_escaped |
| + r"){" |
| + str(min_repeats - 1) |
| + r",}\s*$", |
| re.IGNORECASE, |
| ) |
| match = phrase_pattern.match(text) |
| if match: |
| |
| text = (match.group(1) + match.group(2)).strip() |
| words = text.split() |
| break |
|
|
| return text |
|
|