"""ASR pipeline for audio-to-text transcription with optional timestamps and diarization."""
import re
from pathlib import Path
from typing import Any
import numpy as np
import torch
import transformers
from transformers.pipelines.audio_utils import ffmpeg_read
try:
from .alignment import ForcedAligner
from .asr_modeling import ASRModel
from .diarization import SpeakerDiarizer
except ImportError:
from alignment import ForcedAligner # type: ignore[no-redef]
from asr_modeling import ASRModel # type: ignore[no-redef]
from diarization import SpeakerDiarizer # type: ignore[no-redef]
# Public names of this module. ForcedAligner and SpeakerDiarizer are defined in
# sibling modules (imported above) and are re-exported here for backwards
# compatibility.
__all__ = ["ForcedAligner", "SpeakerDiarizer", "ASRPipeline"]
# Matches one complete <think>...</think> block (non-greedy, may span lines)
# together with any whitespace that follows it. The literal tags are required:
# without them the pattern degenerates to ".*?\s*", which matches every run of
# whitespace and would glue all words of a transcription together.
_THINK_TAG_RE = re.compile(r"<think>.*?</think>\s*", flags=re.DOTALL)

# Default number of consecutive occurrences that counts as a runaway repetition.
_DEFAULT_MIN_REPEATS = 3

# Precompiled trailing-repetition patterns for the default threshold.
# Character form: one character followed by at least (threshold - 1) copies of
# itself at the end of the text.
_TRAILING_CHAR_RE = re.compile(rf"(.)\1{{{_DEFAULT_MIN_REPEATS - 1},}}$")
# Word form: one word followed by at least (threshold - 1) whitespace-separated
# copies of itself (case-insensitive) at the end of the text.
_TRAILING_WORD_RE = re.compile(
    rf"\b(\w+)(?:\s+\1){{{_DEFAULT_MIN_REPEATS - 1},}}\s*$", re.IGNORECASE
)
class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
    """ASR pipeline for audio-to-text transcription.

    Wraps an ``ASRModel`` in the Hugging Face ASR pipeline interface and layers
    two optional post-steps on top of the plain transcription: word-level
    timestamps (``ForcedAligner.align``) and speaker labels
    (``SpeakerDiarizer.diarize`` / ``assign_speakers_to_words``).
    """

    model: ASRModel

    # Call-time options owned by this class. They are removed in
    # ``_sanitize_parameters`` so the parent pipeline never validates them.
    _CUSTOM_PARAMS = (
        "return_timestamps",
        "return_speakers",
        "num_speakers",
        "min_speakers",
        "max_speakers",
        "hf_token",
        "user_prompt",
        "diarization_backend",
    )

    # Sampling rate assumed when the input does not carry one, and the rate
    # file/bytes inputs are decoded to.
    _FALLBACK_SAMPLING_RATE = 16000

    # One complete <think>...</think> block plus the whitespace following it.
    _THINK_BLOCK_RE = re.compile(r"<think>.*?</think>\s*", flags=re.DOTALL)

    def __init__(self, model: ASRModel, **kwargs):
        """Initialize ASR pipeline.

        Args:
            model: ASRModel instance for transcription
            **kwargs: Additional arguments (feature_extractor, tokenizer, device).
                A missing ``tokenizer`` defaults to ``model.tokenizer``; a missing
                or ``None`` ``feature_extractor`` is taken from
                ``model.get_processor()``.
        """
        feature_extractor = kwargs.pop("feature_extractor", None)
        tokenizer = kwargs.pop("tokenizer", model.tokenizer)
        if feature_extractor is None:
            feature_extractor = model.get_processor().feature_extractor
        super().__init__(
            model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
        )
        # Audio of the call currently in flight; only set while __call__ runs.
        self._current_audio = None

    def _sanitize_parameters(self, **kwargs):
        """Drop this class's custom parameters, then defer to the parent."""
        for name in self._CUSTOM_PARAMS:
            kwargs.pop(name, None)
        return super()._sanitize_parameters(**kwargs)

    def __call__(
        self,
        inputs,
        **kwargs,
    ):
        """Transcribe audio with optional word-level timestamps and speaker diarization.

        Args:
            inputs: Audio input (file path, dict with array/sampling_rate, etc.)
            return_timestamps: If True, return word-level timestamps using forced alignment
            return_speakers: If True, return speaker labels for each word
                (implies ``return_timestamps``)
            user_prompt: Custom transcription prompt; temporarily replaces
                ``model.TRANSCRIBE_PROMPT`` for the duration of this call
            num_speakers: Exact number of speakers (if known, for diarization)
            min_speakers: Minimum number of speakers (for diarization)
            max_speakers: Maximum number of speakers (for diarization)
            **kwargs: Additional arguments passed to the pipeline. ``hf_token``
                and ``diarization_backend`` are accepted but only discarded
                (see ``_sanitize_parameters``); nothing here consumes them.

        Returns:
            Dict with 'text' key, 'words' key if return_timestamps=True,
            and speaker labels on words if return_speakers=True. Alignment or
            diarization failures do not raise: the affected list is empty and
            the message is stored under 'timestamp_error' / 'diarization_error'.
        """
        # Extract our params before super().__call__ (which will also call
        # _sanitize_parameters).
        return_timestamps = kwargs.pop("return_timestamps", False)
        return_speakers = kwargs.pop("return_speakers", False)
        user_prompt = kwargs.pop("user_prompt", None)
        diarization_params = {
            "num_speakers": kwargs.pop("num_speakers", None),
            "min_speakers": kwargs.pop("min_speakers", None),
            "max_speakers": kwargs.pop("max_speakers", None),
        }
        if return_speakers:
            # Speakers are attached to aligned words, so alignment is required.
            return_timestamps = True

        # Temporarily swap in the custom prompt. A flag (rather than comparing
        # the saved value with None) guarantees the original is always restored.
        prompt_overridden = False
        original_prompt = None
        if user_prompt:
            original_prompt = self.model.TRANSCRIBE_PROMPT
            self.model.TRANSCRIBE_PROMPT = user_prompt
            prompt_overridden = True

        try:
            # Capture the audio before the pipeline consumes the inputs; it is
            # needed again for alignment and diarization.
            if return_timestamps:
                self._current_audio = self._extract_audio(inputs)

            # Run standard transcription
            result = super().__call__(inputs, **kwargs)

            audio = self._current_audio
            if audio is not None:
                sample_rate = audio.get("sampling_rate", self._FALLBACK_SAMPLING_RATE)

                if return_timestamps:
                    text = result.get("text", "")
                    if text:
                        try:
                            result["words"] = ForcedAligner.align(
                                audio["array"],
                                text,
                                sample_rate=sample_rate,
                            )
                        except Exception as e:
                            # Best effort: keep the transcript, report the failure.
                            result["words"] = []
                            result["timestamp_error"] = str(e)
                    else:
                        result["words"] = []

                if return_speakers:
                    try:
                        speaker_segments = SpeakerDiarizer.diarize(
                            audio["array"],
                            sample_rate=sample_rate,
                            **{k: v for k, v in diarization_params.items() if v is not None},
                        )
                        result["speaker_segments"] = speaker_segments
                        if result.get("words"):
                            result["words"] = SpeakerDiarizer.assign_speakers_to_words(
                                result["words"],
                                speaker_segments,
                            )
                    except Exception as e:
                        # Best effort: keep the transcript, report the failure.
                        result["speaker_segments"] = []
                        result["diarization_error"] = str(e)
        finally:
            # Always undo per-call state, even when audio loading or
            # transcription raised, so the next call starts clean.
            self._current_audio = None
            if prompt_overridden:
                self.model.TRANSCRIBE_PROMPT = original_prompt

        return result

    def _extract_audio(self, inputs) -> dict | None:
        """Extract audio array from various input formats using HF utilities.

        Args:
            inputs: A dict holding the samples under "array" or "raw", a file
                path (str), encoded audio bytes, or a numpy array.

        Returns:
            ``{"array": ..., "sampling_rate": ...}``, or None when the input
            format is not recognised. Paths and bytes are decoded at 16 kHz;
            dicts without "sampling_rate" and bare arrays are assumed 16 kHz.
        """
        rate = self._FALLBACK_SAMPLING_RATE

        if isinstance(inputs, dict):
            # "array" (datasets style) takes precedence over "raw" (pipeline style).
            for key in ("array", "raw"):
                if key in inputs:
                    return {
                        "array": inputs[key],
                        "sampling_rate": inputs.get("sampling_rate", rate),
                    }
            return None

        if isinstance(inputs, str):
            # File path - load audio using ffmpeg (same as HF pipeline)
            with Path(inputs).open("rb") as f:
                audio = ffmpeg_read(f.read(), sampling_rate=rate)
            return {"array": audio, "sampling_rate": rate}

        if isinstance(inputs, bytes):
            audio = ffmpeg_read(inputs, sampling_rate=rate)
            return {"array": audio, "sampling_rate": rate}

        if isinstance(inputs, np.ndarray):
            return {"array": inputs, "sampling_rate": rate}

        return None

    def preprocess(self, inputs, **preprocess_params):
        """Preprocess audio inputs for the model.

        Args:
            inputs: Audio input (dict with array, file path, etc.)
            **preprocess_params: Additional preprocessing parameters

        Yields:
            The items produced by the parent ``preprocess``, each guaranteed to
            carry an ``is_last`` entry (set to True when the parent omitted it).
        """
        # Handle dict with "array" key (from datasets): re-key it as "raw" in a
        # fresh dict so the caller's dict is left untouched.
        if isinstance(inputs, dict) and "array" in inputs:
            inputs = {
                "raw": inputs["array"],
                "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
            }
        for item in super().preprocess(inputs, **preprocess_params):
            if "is_last" not in item:
                item["is_last"] = True
            yield item

    def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
        """Run model forward pass to generate transcription.

        Args:
            model_inputs: Dict with input_features and attention_mask
            **generate_kwargs: Generation parameters. Pass ``output_scores=True``
                (and ``return_dict_in_generate=True``, which is then implied) to
                also return per-step top-1 and top-2 log-probabilities. When
                unset, only token IDs are returned.

        Returns:
            Dict with ``tokens`` and ``is_last``, plus per-step ``top1_logprob``
            / ``top2_logprob`` lists of floats when generation returned a
            structured output with scores. Log-probs are computed for the first
            sequence in the batch only.
        """
        # Extract audio features and is_last flag
        is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True
        input_features = model_inputs["input_features"].to(self.model.device)
        audio_attention_mask = model_inputs["attention_mask"].to(self.model.device)

        # Opt-in: when output_scores is requested, force return_dict_in_generate
        # so generate returns an object with .sequences/.scores, not a bare tensor.
        if generate_kwargs.get("output_scores", False):
            generate_kwargs.setdefault("return_dict_in_generate", True)

        generate_output = self.model.generate(
            input_features=input_features,
            audio_attention_mask=audio_attention_mask,
            **generate_kwargs,
        )

        # Default (no scores requested): generate returns a tensor of token IDs.
        if torch.is_tensor(generate_output):
            return {"tokens": generate_output, "is_last": is_last}

        # Structured output: reduce each step's logits to its two best
        # log-probabilities so the full vocabulary is not carried any further.
        top1_logprobs: list[float] = []
        top2_logprobs: list[float] = []
        for step_logits in generate_output.scores or ():
            step_logprobs = torch.log_softmax(step_logits[0].float(), dim=-1)
            best_two = torch.topk(step_logprobs, k=2).values
            top1_logprobs.append(best_two[0].item())
            top2_logprobs.append(best_two[1].item())

        return {
            "tokens": generate_output.sequences,
            "top1_logprob": top1_logprobs,
            "top2_logprob": top2_logprobs,
            "is_last": is_last,
        }

    def postprocess(self, model_outputs, **kwargs) -> dict[str, Any]:
        """Convert model output tokens to text.

        Args:
            model_outputs: Dict with 'tokens' key containing generated IDs, or a
                list of such dicts (only the first element is used)
            **kwargs: Additional postprocessing parameters

        Returns:
            Dict with 'text' key containing the cleaned transcription, plus
            'top1_logprob' / 'top2_logprob' when present in ``model_outputs``.
        """
        # Handle list of outputs (from chunking)
        if isinstance(model_outputs, list):
            model_outputs = model_outputs[0] if model_outputs else {}

        tokens = model_outputs.get("tokens")
        if tokens is None:
            return super().postprocess(model_outputs, **kwargs)

        # Normalise to a flat Python list of ids (first sequence of a batch).
        if torch.is_tensor(tokens):
            tokens = tokens.cpu()
            if tokens.dim() > 1:
                tokens = tokens[0]
            tokens = tokens.tolist()

        # Filter out eos tokens that the tokenizer doesn't recognize as special
        # (generation_config.eos_token_id may differ from tokenizer.eos_token_id)
        generation_config = getattr(self.model, "generation_config", None)
        eos_ids = getattr(generation_config, "eos_token_id", None)
        if eos_ids is not None:
            eos_set = set(eos_ids) if isinstance(eos_ids, (list, tuple, set)) else {eos_ids}
            tokens = [t for t in tokens if t not in eos_set]

        text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()

        # Strip <think>...</think> blocks (Qwen3 doesn't respect /no_think prompt)
        if "<think>" in text:
            text = self._THINK_BLOCK_RE.sub("", text).strip()

        text = _truncate_repetitions(text)

        out: dict[str, Any] = {"text": text}
        # Pass through per-step logprobs when _forward captured them (i.e. the
        # caller passed output_scores=True), so they need not be recomputed.
        for key in ("top1_logprob", "top2_logprob"):
            if key in model_outputs:
                out[key] = model_outputs[key]
        return out
def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
    """Truncate repeated words/phrases/characters at end of text.

    Detects patterns like:
    - Repeated words: "the the the the" -> "the"
    - Repeated phrases: "i am sorry i am sorry i am sorry" -> "i am sorry"
    - Repeated characters: "444444" -> "4"

    Word and phrase comparison is case-insensitive; the first occurrence is the
    one that is kept, with its original casing. Phrases of 2 to 20 words are
    considered, and the shortest repeating phrase wins.

    NOTE(review): the character rule also shortens legitimate endings such as
    "1000" -> "10"; confirm this is acceptable for numeric transcripts.

    Args:
        text: Input text to process
        min_repeats: Minimum repetitions to trigger truncation (default 3).
            Values below 2 cannot describe a repetition, so the text is
            returned unchanged.

    Returns:
        Text with trailing repetitions removed
    """
    # A threshold of 1 would turn the word pattern into "{0,}", which matches
    # any trailing word and would make the collapse loop below spin forever.
    if not text or min_repeats < 2:
        return text

    extra = min_repeats - 1  # copies required after the first occurrence

    # 1. Collapse a trailing run of a single repeated character.
    #    (re.compile memoises patterns, so building them per call is cheap.)
    char_pattern = re.compile(rf"(.)\1{{{extra},}}$")
    text = char_pattern.sub(r"\1", text)

    # 2. Collapse a trailing run of a single repeated word. Each substitution
    #    removes at least one word, so the loop always terminates.
    word_pattern = re.compile(rf"\b(\w+)(?:\s+\1){{{extra},}}\s*$", re.IGNORECASE)
    while word_pattern.search(text):
        text = word_pattern.sub(r"\1", text)

    # 3. Truncate repeated phrases (2-20 words) at end
    #    e.g., "i am sorry i am sorry i am sorry" -> "i am sorry"
    # Work on word positions rather than a per-call regex: this is insensitive
    # to the kind of whitespace between words and finds long phrases too.
    spans = [m.span() for m in re.finditer(r"\S+", text)]
    total = len(spans)
    folded = [text[start:end].casefold() for start, end in spans]

    longest = min(20, total // min_repeats)
    for phrase_len in range(2, longest + 1):
        phrase = folded[total - phrase_len :]

        # Count how many back-to-back copies of the phrase end the text.
        repeats = 1
        start = total - 2 * phrase_len
        while start >= 0 and folded[start : start + phrase_len] == phrase:
            repeats += 1
            start -= phrase_len

        if repeats >= min_repeats:
            # Keep everything up to the end of the first copy of the phrase.
            kept_words = total - (repeats - 1) * phrase_len
            return text[: spans[kept_words - 1][1]].strip()

    return text