"""ASR pipeline for audio-to-text transcription with optional timestamps and diarization."""

import re
from pathlib import Path
from typing import Any

import numpy as np
import torch
import transformers
from transformers.pipelines.audio_utils import ffmpeg_read

try:
    from .alignment import ForcedAligner
    from .asr_modeling import ASRModel
    from .diarization import SpeakerDiarizer
except ImportError:
    from alignment import ForcedAligner  # type: ignore[no-redef]
    from asr_modeling import ASRModel  # type: ignore[no-redef]
    from diarization import SpeakerDiarizer  # type: ignore[no-redef]

# Re-export for backwards compatibility
__all__ = ["ForcedAligner", "SpeakerDiarizer", "ASRPipeline"]

# Matches one complete <think>...</think> reasoning block plus trailing whitespace.
# DOTALL lets the block span several lines; the lazy quantifier keeps multiple
# blocks from being merged into one match.
_THINK_TAG_RE = re.compile(r"<think>.*?</think>\s*", flags=re.DOTALL)

_DEFAULT_MIN_REPEATS = 3
# Longest phrase (in words) considered by the repeated-phrase truncation.
_MAX_PHRASE_WORDS = 20
_TRAILING_CHAR_RE = re.compile(rf"(.)\1{{{_DEFAULT_MIN_REPEATS - 1},}}$")
_TRAILING_WORD_RE = re.compile(
    rf"\b(\w+)(?:\s+\1){{{_DEFAULT_MIN_REPEATS - 1},}}\s*$", re.IGNORECASE
)

# Keyword arguments owned by this pipeline; the parent class must never see them.
_CUSTOM_CALL_PARAMS = (
    "return_timestamps",
    "return_speakers",
    "num_speakers",
    "min_speakers",
    "max_speakers",
    "hf_token",
    "user_prompt",
    "diarization_backend",
)


class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
    """ASR Pipeline for audio-to-text transcription."""

    model: ASRModel

    def __init__(self, model: ASRModel, **kwargs):
        """Initialize ASR pipeline.

        Args:
            model: ASRModel instance for transcription
            **kwargs: Additional arguments (feature_extractor, tokenizer, device)
        """
        feature_extractor = kwargs.pop("feature_extractor", None)
        tokenizer = kwargs.pop("tokenizer", model.tokenizer)

        if feature_extractor is None:
            feature_extractor = model.get_processor().feature_extractor

        super().__init__(
            model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
        )
        # Audio of the request currently being processed; only populated while
        # __call__ needs it for alignment / diarization.
        self._current_audio = None

    def _sanitize_parameters(self, **kwargs):
        """Intercept our custom parameters before parent class validates them.

        ``hf_token`` and ``diarization_backend`` are accepted here so that passing
        them does not raise, but ``__call__`` in this module does not forward
        them anywhere.
        """
        for name in _CUSTOM_CALL_PARAMS:
            kwargs.pop(name, None)
        return super()._sanitize_parameters(**kwargs)

    def __call__(
        self,
        inputs,
        **kwargs,
    ):
        """Transcribe audio with optional word-level timestamps and speaker diarization.

        Args:
            inputs: Audio input (file path, dict with array/sampling_rate, etc.)
            return_timestamps: If True, return word-level timestamps using forced alignment
            return_speakers: If True, return speaker labels for each word
                (implies ``return_timestamps``)
            user_prompt: Custom transcription prompt (default: "Transcribe: ")
            num_speakers: Exact number of speakers (if known, for diarization)
            min_speakers: Minimum number of speakers (for diarization)
            max_speakers: Maximum number of speakers (for diarization)
            **kwargs: Additional arguments passed to the pipeline

        Returns:
            Dict with 'text' key, 'words' key if return_timestamps=True,
            and speaker labels on words if return_speakers=True. Alignment and
            diarization failures are reported through the 'timestamp_error' /
            'diarization_error' keys instead of being raised.
        """
        # Extract our params before super().__call__ (which will also call _sanitize_parameters)
        return_timestamps = kwargs.pop("return_timestamps", False)
        return_speakers = kwargs.pop("return_speakers", False)
        user_prompt = kwargs.pop("user_prompt", None)
        diarization_params = {
            "num_speakers": kwargs.pop("num_speakers", None),
            "min_speakers": kwargs.pop("min_speakers", None),
            "max_speakers": kwargs.pop("max_speakers", None),
        }

        # Speaker labels are attached to aligned words, so they need timestamps.
        if return_speakers:
            return_timestamps = True

        # Temporarily override the model prompt. An explicit flag (rather than
        # "original is not None") makes sure a None prompt is restored as well.
        prompt_overridden = bool(user_prompt)
        original_prompt = None
        if prompt_overridden:
            original_prompt = self.model.TRANSCRIBE_PROMPT
            self.model.TRANSCRIBE_PROMPT = user_prompt

        try:
            # Store audio for timestamp alignment and diarization
            if return_timestamps or return_speakers:
                self._current_audio = self._extract_audio(inputs)

            # Run standard transcription
            result = super().__call__(inputs, **kwargs)

            # Add timestamps if requested
            if return_timestamps and self._current_audio is not None:
                text = result.get("text", "")
                if text:
                    try:
                        result["words"] = ForcedAligner.align(
                            self._current_audio["array"],
                            text,
                            sample_rate=self._current_audio.get("sampling_rate", 16000),
                        )
                    except Exception as e:  # best effort: keep the transcription
                        result["words"] = []
                        result["timestamp_error"] = str(e)
                else:
                    result["words"] = []

            # Add speaker diarization if requested
            if return_speakers and self._current_audio is not None:
                try:
                    speaker_segments = SpeakerDiarizer.diarize(
                        self._current_audio["array"],
                        sample_rate=self._current_audio.get("sampling_rate", 16000),
                        **{k: v for k, v in diarization_params.items() if v is not None},
                    )
                    result["speaker_segments"] = speaker_segments

                    # Assign speakers to words
                    if result.get("words"):
                        result["words"] = SpeakerDiarizer.assign_speakers_to_words(
                            result["words"],
                            speaker_segments,
                        )
                except Exception as e:  # best effort: keep the transcription
                    result["speaker_segments"] = []
                    result["diarization_error"] = str(e)

            return result
        finally:
            # Always clean up, even when loading audio or transcription raised,
            # so a failed call cannot leak its prompt or audio into the next one.
            self._current_audio = None
            if prompt_overridden:
                self.model.TRANSCRIBE_PROMPT = original_prompt

    def _extract_audio(self, inputs) -> dict | None:
        """Extract audio array from various input formats using HF utilities.

        Args:
            inputs: Dict with an "array" or "raw" entry, a file path string, encoded
                audio bytes, or a numpy array.

        Returns:
            Dict with "array" and "sampling_rate" keys, or None when the input
            format is not recognised. Paths and bytes are decoded at 16 kHz;
            bare numpy arrays are assumed to already be 16 kHz.
        """
        if isinstance(inputs, dict):
            if "array" in inputs:
                return {
                    "array": inputs["array"],
                    "sampling_rate": inputs.get("sampling_rate", 16000),
                }
            if "raw" in inputs:
                return {
                    "array": inputs["raw"],
                    "sampling_rate": inputs.get("sampling_rate", 16000),
                }
        elif isinstance(inputs, str):
            # File path - load audio using ffmpeg (same as HF pipeline)
            with Path(inputs).open("rb") as f:
                audio = ffmpeg_read(f.read(), sampling_rate=16000)
            return {"array": audio, "sampling_rate": 16000}
        elif isinstance(inputs, bytes):
            audio = ffmpeg_read(inputs, sampling_rate=16000)
            return {"array": audio, "sampling_rate": 16000}
        elif isinstance(inputs, np.ndarray):
            return {"array": inputs, "sampling_rate": 16000}

        return None

    def preprocess(self, inputs, **preprocess_params):
        """Preprocess audio inputs for the model.

        Args:
            inputs: Audio input (dict with array, file path, etc.)
            **preprocess_params: Additional preprocessing parameters

        Yields:
            Model input dicts with input_features and attention_mask
        """
        # Handle dict with "array" key (from datasets): rename it to the "raw"
        # key before delegating to the parent preprocess.
        if isinstance(inputs, dict) and "array" in inputs:
            inputs = {
                "raw": inputs["array"],
                "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
            }

        for item in super().preprocess(inputs, **preprocess_params):
            # Make sure every item carries the flag that _forward reads.
            if "is_last" not in item:
                item["is_last"] = True
            yield item

    def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
        """Run model forward pass to generate transcription.

        Args:
            model_inputs: Dict with input_features and attention_mask
            **generate_kwargs: Generation parameters. Pass ``output_scores=True``
                (and ``return_dict_in_generate=True``, which is then implied) to
                also return per-step top-1 and top-2 log-probabilities — used by
                the eval harness's confidence metric. Backward-compatible: when
                unset, returns just token IDs as before.

        Returns:
            Dict with generated token IDs, and optionally per-step
            ``top1_logprob`` / ``top2_logprob`` lists when scores were requested.
        """
        # Extract audio features and is_last flag
        is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True

        input_features = model_inputs["input_features"].to(self.model.device)
        audio_attention_mask = model_inputs["attention_mask"].to(self.model.device)

        # Opt-in: when output_scores is requested, force return_dict_in_generate
        # so we get a GenerateOutput rather than a bare token tensor.
        want_scores = bool(generate_kwargs.get("output_scores", False))
        if want_scores:
            generate_kwargs.setdefault("return_dict_in_generate", True)

        generate_output = self.model.generate(
            input_features=input_features,
            audio_attention_mask=audio_attention_mask,
            **generate_kwargs,
        )

        # Default (no scores requested): generate returns a tensor of token IDs.
        if torch.is_tensor(generate_output):
            return {"tokens": generate_output, "is_last": is_last}

        # Scores requested: GenerateOutput dict-like with .sequences and .scores.
        # `scores` is a tuple of per-step logits tensors (batch, vocab); convert
        # each to log-probs and take top-2 to produce two short lists over the
        # generation horizon — kept small (no full vocab) so this is cheap to
        # carry through postprocess. Only batch item 0 is inspected.
        sequences = generate_output.sequences
        scores = generate_output.scores
        top1_logprobs: list[float] = []
        top2_logprobs: list[float] = []
        if scores:
            for step_logits in scores:
                step_logprobs = torch.log_softmax(step_logits[0].float(), dim=-1)
                top2 = torch.topk(step_logprobs, k=2)
                top1_logprobs.append(top2.values[0].item())
                top2_logprobs.append(top2.values[1].item())

        return {
            "tokens": sequences,
            "top1_logprob": top1_logprobs,
            "top2_logprob": top2_logprobs,
            "is_last": is_last,
        }

    def postprocess(self, model_outputs, **kwargs) -> dict[str, Any]:
        """Convert model output tokens to text.

        Args:
            model_outputs: Dict with 'tokens' key containing generated IDs
                (or a list of such dicts, of which only the first is used)
            **kwargs: Additional postprocessing parameters

        Returns:
            Dict with 'text' key containing transcription, plus
            'top1_logprob' / 'top2_logprob' when they were captured in _forward
        """
        # Handle list of outputs (from chunking)
        if isinstance(model_outputs, list):
            model_outputs = model_outputs[0] if model_outputs else {}

        tokens = model_outputs.get("tokens")
        if tokens is None:
            return super().postprocess(model_outputs, **kwargs)

        if torch.is_tensor(tokens):
            tokens = tokens.cpu()
            if tokens.dim() > 1:
                tokens = tokens[0]

        # Filter out eos tokens that the tokenizer doesn't recognize as special
        # (generation_config.eos_token_id may differ from tokenizer.eos_token_id)
        if hasattr(self, "model") and hasattr(self.model, "generation_config"):
            eos_ids = self.model.generation_config.eos_token_id
            if eos_ids is not None:
                if isinstance(eos_ids, (list, tuple, set)):
                    eos_set = set(eos_ids)
                else:
                    eos_set = {eos_ids}
                # Accept tensors/arrays as well as plain sequences of ids.
                token_ids = tokens.tolist() if hasattr(tokens, "tolist") else list(tokens)
                tokens = [t for t in token_ids if t not in eos_set]

        text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
        # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
        if "<think>" in text:
            text = _THINK_TAG_RE.sub("", text).strip()
        text = _truncate_repetitions(text)
        out: dict[str, Any] = {"text": text}
        # Pass through per-step logprobs when _forward captured them (i.e. caller
        # passed output_scores=True). Lets eval harnesses compute confidence
        # stats without re-running the model.
        if "top1_logprob" in model_outputs:
            out["top1_logprob"] = model_outputs["top1_logprob"]
        if "top2_logprob" in model_outputs:
            out["top2_logprob"] = model_outputs["top2_logprob"]
        return out


def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
    """Truncate repeated words/phrases/characters at end of text.

    Detects patterns like:
    - Repeated words: "the the the the" -> "the"
    - Repeated phrases: "i am sorry i am sorry i am sorry" -> "i am sorry"
    - Repeated characters: "444444" -> "4"

    Args:
        text: Input text to process
        min_repeats: Minimum repetitions to trigger truncation (default 3)

    Returns:
        Text with trailing repetitions removed
    """
    if not text:
        return text

    # Reuse the precompiled patterns for the default threshold.
    if min_repeats == _DEFAULT_MIN_REPEATS:
        char_pattern = _TRAILING_CHAR_RE
        word_pattern = _TRAILING_WORD_RE
    else:
        char_pattern = re.compile(rf"(.)\1{{{min_repeats - 1},}}$")
        word_pattern = re.compile(rf"\b(\w+)(?:\s+\1){{{min_repeats - 1},}}\s*$", re.IGNORECASE)

    # 1. Collapse a trailing run of one repeated character.
    text = char_pattern.sub(r"\1", text)

    # 2. Collapse trailing repeated words. Loop until the text stops changing
    # (rather than "while the pattern matches") so a pattern that matches a
    # single word, e.g. with min_repeats=1, cannot spin forever.
    while True:
        collapsed = word_pattern.sub(r"\1", text)
        if collapsed == text:
            break
        text = collapsed

    # 3. Truncate repeated phrases (2-20 words) at end
    # e.g., "i am sorry i am sorry i am sorry" -> "i am sorry"
    words = text.split()

    if len(words) < min_repeats * 2:
        return text

    # Cheap pre-check: if a phrase of p <= _MAX_PHRASE_WORDS words repeats at the
    # end, the last word equals the word p positions earlier, so the trailing
    # _MAX_PHRASE_WORDS + 1 words must contain a duplicate. Compare casefolded
    # words because the phrase pattern below is case-insensitive.
    window = [word.casefold() for word in words[-(_MAX_PHRASE_WORDS + 1):]]
    if len(set(window)) == len(window):
        return text

    longest_phrase = min(_MAX_PHRASE_WORDS, len(words) // min_repeats)
    for phrase_len in range(2, longest_phrase + 1):
        phrase_escaped = re.escape(" ".join(words[-phrase_len:]))
        phrase_pattern = re.compile(
            rf"(^|.*?\s)({phrase_escaped})(?:\s+{phrase_escaped}){{{min_repeats - 1},}}\s*$",
            re.IGNORECASE,
        )
        match = phrase_pattern.match(text)
        if match:
            # Keep the prefix plus a single copy of the phrase.
            text = (match.group(1) + match.group(2)).strip()
            break

    return text