Instructions to use mazesmazes/tiny-audio-next-encoder with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use mazesmazes/tiny-audio-next-encoder with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="mazesmazes/tiny-audio-next-encoder", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("mazesmazes/tiny-audio-next-encoder", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """ASR pipeline for audio-to-text transcription with optional timestamps and diarization.""" | |
| import re | |
| from pathlib import Path | |
| from typing import Any | |
| import numpy as np | |
| import torch | |
| import transformers | |
| from transformers.pipelines.audio_utils import ffmpeg_read | |
| try: | |
| from .alignment import ForcedAligner | |
| from .asr_modeling import ASRModel | |
| from .diarization import SpeakerDiarizer | |
| except ImportError: | |
| from alignment import ForcedAligner # type: ignore[no-redef] | |
| from asr_modeling import ASRModel # type: ignore[no-redef] | |
| from diarization import SpeakerDiarizer # type: ignore[no-redef] | |
# Public names of this module; ForcedAligner and SpeakerDiarizer are
# re-exported here for backwards compatibility.
__all__ = [
    "ForcedAligner",
    "SpeakerDiarizer",
    "ASRPipeline",
]
| _THINK_TAG_RE = re.compile(r"<think>.*?</think>\s*", flags=re.DOTALL) | |
| _DEFAULT_MIN_REPEATS = 3 | |
| _TRAILING_CHAR_RE = re.compile(rf"(.)\1{{{_DEFAULT_MIN_REPEATS - 1},}}$") | |
| _TRAILING_WORD_RE = re.compile( | |
| rf"\b(\w+)(?:\s+\1){{{_DEFAULT_MIN_REPEATS - 1},}}\s*$", re.IGNORECASE | |
| ) | |
class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
    """ASR pipeline for audio-to-text transcription.

    Extends ``transformers.AutomaticSpeechRecognitionPipeline`` with:

    * a per-call custom transcription prompt (``user_prompt``),
    * word-level timestamps via ``ForcedAligner`` (``return_timestamps``),
    * per-word speaker labels via ``SpeakerDiarizer`` (``return_speakers``),
    * optional per-step top-1/top-2 log-probabilities when generation is run
      with ``output_scores=True``.
    """

    model: ASRModel

    # Call-time kwargs consumed by this class. They are stripped in
    # `_sanitize_parameters` so the parent pipeline never validates them.
    # `hf_token` and `diarization_backend` are accepted but discarded here.
    _CUSTOM_CALL_KWARGS = (
        "return_timestamps",
        "return_speakers",
        "num_speakers",
        "min_speakers",
        "max_speakers",
        "hf_token",
        "user_prompt",
        "diarization_backend",
    )

    def __init__(self, model: ASRModel, **kwargs):
        """Initialize the ASR pipeline.

        Args:
            model: ASRModel instance used for transcription.
            **kwargs: Forwarded to the parent pipeline. ``feature_extractor`` and
                ``tokenizer`` default to the model's own (``model.get_processor()
                .feature_extractor`` and ``model.tokenizer``) when missing or None.
        """
        feature_extractor = kwargs.pop("feature_extractor", None)
        tokenizer = kwargs.pop("tokenizer", None)
        if feature_extractor is None:
            feature_extractor = model.get_processor().feature_extractor
        if tokenizer is None:
            # Resolved lazily so a caller-supplied tokenizer never touches model.tokenizer.
            tokenizer = model.tokenizer
        super().__init__(
            model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
        )
        # Audio of the in-flight __call__ (for alignment/diarization); None between calls.
        self._current_audio = None

    def _sanitize_parameters(self, **kwargs):
        """Drop this class's custom kwargs before the parent validates the rest."""
        for key in self._CUSTOM_CALL_KWARGS:
            kwargs.pop(key, None)
        return super()._sanitize_parameters(**kwargs)

    def __call__(
        self,
        inputs,
        **kwargs,
    ):
        """Transcribe audio with optional word-level timestamps and speaker diarization.

        Args:
            inputs: Audio input (file path, bytes, numpy array, or a dict with
                ``array``/``raw`` and ``sampling_rate``), as accepted by the parent.
            return_timestamps: If True, add ``result["words"]`` from forced alignment.
            return_speakers: If True, add ``result["speaker_segments"]`` and speaker
                labels on words; implies ``return_timestamps``.
            user_prompt: Custom prompt that temporarily replaces
                ``model.TRANSCRIBE_PROMPT`` for this call only.
            num_speakers: Exact number of speakers (diarization).
            min_speakers: Minimum number of speakers (diarization).
            max_speakers: Maximum number of speakers (diarization).
            **kwargs: Additional arguments passed to the parent pipeline.

        Returns:
            The parent's result (a dict with ``text`` for a single input). When
            timestamps/speakers are requested and the audio could be extracted,
            ``words`` / ``speaker_segments`` are added. Alignment or diarization
            failures are reported under ``timestamp_error`` / ``diarization_error``
            instead of being raised.
        """
        # Pull our params out before super().__call__ (which re-runs _sanitize_parameters).
        return_timestamps = kwargs.pop("return_timestamps", False)
        return_speakers = kwargs.pop("return_speakers", False)
        user_prompt = kwargs.pop("user_prompt", None)
        diarization_params = {
            "num_speakers": kwargs.pop("num_speakers", None),
            "min_speakers": kwargs.pop("min_speakers", None),
            "max_speakers": kwargs.pop("max_speakers", None),
        }
        if return_speakers:
            # Speaker labels are attached per word, so word timings are required.
            return_timestamps = True

        original_prompt = None
        if user_prompt:
            original_prompt = self.model.TRANSCRIBE_PROMPT
            self.model.TRANSCRIBE_PROMPT = user_prompt
        try:
            audio = self._extract_audio(inputs) if return_timestamps else None
            self._current_audio = audio
            result = super().__call__(inputs, **kwargs)
            if audio is not None:
                self._add_word_timestamps(result, audio)
                if return_speakers:
                    self._add_speaker_labels(result, audio, diarization_params)
        finally:
            # Always undo per-call state, even if transcription raised; otherwise the
            # model would keep the caller's prompt for every later call.
            self._current_audio = None
            if user_prompt:
                self.model.TRANSCRIBE_PROMPT = original_prompt
        return result

    @staticmethod
    def _add_word_timestamps(result: dict, audio: dict) -> None:
        """Attach forced-alignment word timings to ``result`` in place (best effort)."""
        text = result.get("text", "")
        if not text:
            result["words"] = []
            return
        try:
            result["words"] = ForcedAligner.align(
                audio["array"],
                text,
                sample_rate=audio.get("sampling_rate", 16000),
            )
        except Exception as e:  # best effort: report, don't fail the transcription
            result["words"] = []
            result["timestamp_error"] = str(e)

    @staticmethod
    def _add_speaker_labels(result: dict, audio: dict, diarization_params: dict) -> None:
        """Attach diarization segments and per-word speakers to ``result`` in place (best effort)."""
        try:
            speaker_segments = SpeakerDiarizer.diarize(
                audio["array"],
                sample_rate=audio.get("sampling_rate", 16000),
                **{k: v for k, v in diarization_params.items() if v is not None},
            )
            result["speaker_segments"] = speaker_segments
            if result.get("words"):
                result["words"] = SpeakerDiarizer.assign_speakers_to_words(
                    result["words"],
                    speaker_segments,
                )
        except Exception as e:  # best effort: report, don't fail the transcription
            result["speaker_segments"] = []
            result["diarization_error"] = str(e)

    def _extract_audio(self, inputs) -> dict | None:
        """Return ``{"array", "sampling_rate"}`` for the supported input types, else None.

        Dicts keep their own ``sampling_rate`` (default 16000); file paths and raw
        bytes are decoded with ffmpeg at 16 kHz; bare numpy arrays are assumed to be
        16 kHz. Other inputs (e.g. lists for batching) return None, which disables
        alignment/diarization for the call.
        """
        if isinstance(inputs, dict):
            for key in ("array", "raw"):
                if key in inputs:
                    return {
                        "array": inputs[key],
                        "sampling_rate": inputs.get("sampling_rate", 16000),
                    }
            return None
        if isinstance(inputs, (str, bytes)):
            # A str is a file path; decode it the same way as raw bytes.
            data = Path(inputs).read_bytes() if isinstance(inputs, str) else inputs
            return {"array": ffmpeg_read(data, sampling_rate=16000), "sampling_rate": 16000}
        if isinstance(inputs, np.ndarray):
            return {"array": inputs, "sampling_rate": 16000}
        return None

    def preprocess(self, inputs, **preprocess_params):
        """Preprocess audio inputs for the model.

        Args:
            inputs: Audio input (dict with ``array``, file path, etc.).
            **preprocess_params: Forwarded to the parent ``preprocess``.

        Yields:
            Model-input dicts from the parent, with ``is_last`` defaulted to True.
        """
        # `datasets` audio dicts use "array"; the parent pipeline expects "raw".
        if isinstance(inputs, dict) and "array" in inputs:
            inputs = {
                "raw": inputs["array"],
                "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
            }
        for item in super().preprocess(inputs, **preprocess_params):
            item.setdefault("is_last", True)
            yield item

    def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
        """Run generation for one preprocessed input.

        Args:
            model_inputs: Dict with ``input_features`` and ``attention_mask``.
            **generate_kwargs: Forwarded to ``model.generate``. Passing
                ``output_scores=True`` also defaults ``return_dict_in_generate``
                to True so that per-step scores come back.

        Returns:
            ``{"tokens": ..., "is_last": ...}``. When ``generate`` returned a
            dict-like output instead of a tensor, also ``top1_logprob`` and
            ``top2_logprob``: lists of Python floats, one per generation step,
            computed from the first batch element only.
        """
        is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True
        device = self.model.device
        input_features = model_inputs["input_features"].to(device)
        audio_attention_mask = model_inputs["attention_mask"].to(device)

        if generate_kwargs.get("output_scores", False):
            # Needed to get a GenerateOutput (with .scores) instead of a bare tensor.
            generate_kwargs.setdefault("return_dict_in_generate", True)

        generate_output = self.model.generate(
            input_features=input_features,
            audio_attention_mask=audio_attention_mask,
            **generate_kwargs,
        )
        if torch.is_tensor(generate_output):
            return {"tokens": generate_output, "is_last": is_last}

        # `scores` is a per-step tuple of logits; keep only the top-2 log-probs per
        # step so postprocess carries two short lists instead of full-vocab tensors.
        top1_logprobs: list[float] = []
        top2_logprobs: list[float] = []
        for step_logits in generate_output.scores or ():
            step_logprobs = torch.log_softmax(step_logits[0].float(), dim=-1)
            # One .tolist() per step = one device->host sync instead of two .item() calls.
            best, second = torch.topk(step_logprobs, k=2).values.tolist()
            top1_logprobs.append(best)
            top2_logprobs.append(second)
        return {
            "tokens": generate_output.sequences,
            "top1_logprob": top1_logprobs,
            "top2_logprob": top2_logprobs,
            "is_last": is_last,
        }

    def postprocess(self, model_outputs, **kwargs) -> dict[str, Any]:
        """Decode generated token IDs into cleaned-up text.

        Args:
            model_outputs: Dict with a ``tokens`` key (tensor or list of IDs), or a
                list of such dicts.
            **kwargs: Forwarded to the parent when no ``tokens`` are present.

        Returns:
            ``{"text": ...}`` plus ``top1_logprob`` / ``top2_logprob`` when
            ``_forward`` captured them.
        """
        # NOTE(review): only the first element of a list is decoded; if chunked
        # long-form audio reaches here, later chunks appear to be dropped — confirm.
        if isinstance(model_outputs, list):
            model_outputs = model_outputs[0] if model_outputs else {}

        tokens = model_outputs.get("tokens")
        if tokens is None:
            return super().postprocess(model_outputs, **kwargs)

        if torch.is_tensor(tokens):
            tokens = tokens.cpu()
            if tokens.dim() > 1:
                tokens = tokens[0]
            tokens = tokens.tolist()

        # Filter out eos tokens that the tokenizer doesn't recognize as special
        # (generation_config.eos_token_id may differ from tokenizer.eos_token_id).
        generation_config = getattr(getattr(self, "model", None), "generation_config", None)
        eos_ids = getattr(generation_config, "eos_token_id", None)
        if eos_ids is not None:
            eos_set = set(eos_ids) if isinstance(eos_ids, (list, tuple, set)) else {eos_ids}
            tokens = [t for t in tokens if t not in eos_set]

        text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
        # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt).
        if "<think>" in text:
            text = _THINK_TAG_RE.sub("", text).strip()
        text = _truncate_repetitions(text)

        out: dict[str, Any] = {"text": text}
        # Pass through per-step logprobs so eval harnesses can compute confidence
        # stats without re-running the model.
        for key in ("top1_logprob", "top2_logprob"):
            if key in model_outputs:
                out[key] = model_outputs[key]
        return out
| def _truncate_repetitions(text: str, min_repeats: int = 3) -> str: | |
| """Truncate repeated words/phrases/characters at end of text. | |
| Detects patterns like: | |
| - Repeated words: "the the the the" -> "the" | |
| - Repeated phrases: "i am sorry i am sorry i am sorry" -> "i am sorry" | |
| - Repeated characters: "444444" -> "4" | |
| Args: | |
| text: Input text to process | |
| min_repeats: Minimum repetitions to trigger truncation (default 3) | |
| Returns: | |
| Text with trailing repetitions removed | |
| """ | |
| if not text: | |
| return text | |
| if min_repeats == _DEFAULT_MIN_REPEATS: | |
| char_pattern = _TRAILING_CHAR_RE | |
| word_pattern = _TRAILING_WORD_RE | |
| else: | |
| char_pattern = re.compile(rf"(.)\1{{{min_repeats - 1},}}$") | |
| word_pattern = re.compile(rf"\b(\w+)(?:\s+\1){{{min_repeats - 1},}}\s*$", re.IGNORECASE) | |
| text = char_pattern.sub(r"\1", text) | |
| while word_pattern.search(text): | |
| text = word_pattern.sub(r"\1", text) | |
| # 3. Truncate repeated phrases (2-20 words) at end | |
| # e.g., "i am sorry i am sorry i am sorry" -> "i am sorry" | |
| words = text.split() | |
| if len(words) < min_repeats * 2: | |
| return text | |
| # Cheap pre-check: trailing window must contain duplicates for any phrase repeat | |
| # to be possible. set(window) == window means all unique → no repetition. | |
| window = words[-min_repeats * 2 :] | |
| if len(set(window)) == len(window): | |
| return text | |
| for phrase_len in range(2, min(21, len(words) // min_repeats + 1)): | |
| phrase_escaped = re.escape(" ".join(words[-phrase_len:])) | |
| phrase_pattern = re.compile( | |
| rf"(^|.*?\s)({phrase_escaped})(?:\s+{phrase_escaped}){{{min_repeats - 1},}}\s*$", | |
| re.IGNORECASE, | |
| ) | |
| match = phrase_pattern.match(text) | |
| if match: | |
| text = (match.group(1) + match.group(2)).strip() | |
| break | |
| return text | |