# tiny-audio-omni / s2s_pipeline.py
"""Speech-to-Speech pipeline for audio-in, audio-out generation.
This pipeline combines ASR (speech-to-text) with TTS (text-to-speech) to create
a unified speech-to-speech interface that can be used with HuggingFace's pipeline API.
Usage:
from transformers import pipeline
# Load as speech-to-speech pipeline
pipe = pipeline("speech-to-speech", model="mazesmazes/tiny-audio-omni", trust_remote_code=True)
# Process audio (outputs 48kHz by default for browser compatibility)
result = pipe("audio.wav")
# Returns: {"text": "transcription", "audio": np.array, "sampling_rate": 48000}
# With custom TTS voice
result = pipe("audio.wav", tts_voice="af_bella")
# Output at native TTS rate (24kHz) without resampling
result = pipe("audio.wav", output_sample_rate=24000)
# Get only audio output (for streaming/playback)
audio, sr = result["audio"], result["sampling_rate"]
# Streaming with built-in VAD (Voice Activity Detection)
for result in pipe.stream(audio_chunk_generator()):
print(result["text"])
play_audio(result["audio"], result["sampling_rate"])
"""
from collections.abc import Generator, Iterator
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import numpy as np
import scipy.signal
import torch
from transformers import Pipeline
from transformers.pipelines.audio_utils import ffmpeg_read
try:
from .asr_modeling import ASRModel
from .asr_pipeline import _truncate_repetitions, strip_thinking
except ImportError:
from asr_modeling import ASRModel # type: ignore[no-redef]
from asr_pipeline import _truncate_repetitions, strip_thinking # type: ignore[no-redef]
__all__ = ["SpeechToSpeechPipeline", "VADConfig"]
# Default TTS settings
DEFAULT_TTS_VOICE = "af_heart"
TTS_SAMPLE_RATE = 24000 # Native Kokoro TTS sample rate
DEFAULT_OUTPUT_SAMPLE_RATE = 48000 # Browser-friendly sample rate
# Default VAD settings
DEFAULT_VAD_THRESHOLD = 0.5
DEFAULT_SILENCE_DURATION_MS = 700
DEFAULT_INPUT_SAMPLE_RATE = 16000
@dataclass
class VADConfig:
"""Configuration for Voice Activity Detection.
Args:
threshold: VAD probability threshold (0.0-1.0). Higher = stricter.
silence_duration_ms: Milliseconds of silence before end-of-speech.
sample_rate: Expected input audio sample rate.
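
    Example (a minimal sketch; the values are illustrative, not tuned, and
    `pipe` / `chunk_iter` are assumed to be a loaded speech-to-speech pipeline
    and an iterator of 16kHz float32 chunks):
        >>> config = VADConfig(threshold=0.6, silence_duration_ms=500)
        >>> for result in pipe.stream(chunk_iter, vad_config=config):
        ...     print(result["text"])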
"""
threshold: float = DEFAULT_VAD_THRESHOLD
silence_duration_ms: int = DEFAULT_SILENCE_DURATION_MS
sample_rate: int = DEFAULT_INPUT_SAMPLE_RATE
@dataclass
class _VADState:
"""Internal state for VAD streaming."""
is_speaking: bool = False
silence_frames: int = 0
audio_buffer: list[np.ndarray] = field(default_factory=list)
def reset(self):
"""Reset state after processing an utterance."""
self.is_speaking = False
self.silence_frames = 0
self.audio_buffer = []
class SpeechToSpeechPipeline(Pipeline):
"""HuggingFace pipeline for speech-to-speech generation.
This pipeline takes audio input, transcribes it using an ASR model,
and synthesizes the response as speech using Kokoro TTS.
Args:
model: ASRModel instance for transcription
tts_voice: Default Kokoro TTS voice ID (default: "af_heart")
output_sample_rate: Output audio sample rate (default: 48000 for browser compatibility)
**kwargs: Additional arguments passed to Pipeline base class
Example:
>>> from transformers import pipeline
>>> pipe = pipeline("speech-to-speech", model="mazesmazes/tiny-audio-omni", trust_remote_code=True)
>>> result = pipe("audio.wav")
>>> result["text"] # Transcription/response text
>>> result["audio"] # Audio as numpy array (48kHz)
>>> result["sampling_rate"] # 48000
"""
model: ASRModel
def __init__(
self,
model: ASRModel,
tts_voice: str = DEFAULT_TTS_VOICE,
output_sample_rate: int = DEFAULT_OUTPUT_SAMPLE_RATE,
vad_config: VADConfig | None = None,
**kwargs,
):
"""Initialize Speech-to-Speech pipeline."""
feature_extractor = kwargs.pop("feature_extractor", None)
tokenizer = kwargs.pop("tokenizer", model.tokenizer)
if feature_extractor is None:
feature_extractor = model.get_processor().feature_extractor
super().__init__(
model=model,
feature_extractor=feature_extractor,
tokenizer=tokenizer,
**kwargs,
)
self.tts_voice = tts_voice
self.output_sample_rate = output_sample_rate
self.vad_config = vad_config or VADConfig()
self._tts_pipeline = None
self._vad_model = None
self._vad_utils = None
@property
def tts_pipeline(self):
"""Lazy-load Kokoro TTS pipeline on first use."""
if self._tts_pipeline is None:
try:
from kokoro import KPipeline
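                # lang_code="a" selects Kokoro's American English voices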
self._tts_pipeline = KPipeline(lang_code="a", repo_id="hexgrad/Kokoro-82M")
except ImportError as e:
raise ImportError(
"Kokoro TTS is required for speech-to-speech. "
"Install with: pip install kokoro>=0.9.2\n"
"Also requires espeak-ng: apt-get install espeak-ng"
) from e
return self._tts_pipeline
@property
def vad_model(self):
"""Lazy-load Silero VAD model on first use."""
if self._vad_model is None:
self._vad_model, self._vad_utils = torch.hub.load(
repo_or_dir="snakers4/silero-vad",
model="silero_vad",
force_reload=False,
)
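            # torch.hub returns (model, utils); utils is a tuple whose first
            # element is get_speech_timestamps, which stream() uses below.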
return self._vad_model
@property
def vad_utils(self):
"""Get VAD utilities (loads model if needed)."""
if self._vad_utils is None:
# Access vad_model to trigger loading
_ = self.vad_model
return self._vad_utils
def stream(
self,
audio_chunks: Iterator[np.ndarray],
tts_voice: str | None = None,
output_sample_rate: int | None = None,
vad_config: VADConfig | None = None,
) -> Generator[dict[str, Any], None, None]:
"""Process streaming audio with VAD and yield responses.
Takes an iterator of audio chunks, detects speech using Silero VAD,
and yields responses when speech ends (after silence threshold).
Args:
audio_chunks: Iterator yielding audio chunks as numpy arrays (float32, 16kHz).
Each chunk should be ~100-500ms of audio.
tts_voice: Kokoro voice ID for TTS output (default: self.tts_voice)
output_sample_rate: Output sample rate (default: self.output_sample_rate)
vad_config: VAD configuration (default: self.vad_config)
Yields:
Dict with 'text', 'audio', and 'sampling_rate' for each detected utterance.
Example:
>>> def audio_generator():
... while True:
... chunk = get_audio_chunk() # Get ~100ms of audio
... if chunk is None:
... break
... yield chunk
>>> for result in pipe.stream(audio_generator()):
... print(result["text"])
... play_audio(result["audio"], result["sampling_rate"])
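
        Example (chunking a pre-recorded array into ~100ms frames; a sketch
        assuming `audio` is a 16kHz mono float32 numpy array):
            >>> frame = 1600  # 100 ms at 16 kHz
            >>> chunks = (audio[i:i + frame] for i in range(0, len(audio), frame))
            >>> for result in pipe.stream(chunks):
            ...     print(result["text"])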
"""
config = vad_config or self.vad_config
voice = tts_voice or self.tts_voice
target_sr = output_sample_rate or self.output_sample_rate
state = _VADState()
vad_utils = self.vad_utils
if vad_utils is None:
raise RuntimeError("Failed to load Silero VAD model")
get_speech_timestamps = vad_utils[0]
# Calculate silence threshold in frames
# Assuming ~100ms chunks at 16kHz = 1600 samples per chunk
# silence_duration_ms / chunk_duration_ms = number of silent chunks
chunk_duration_ms = 100 # Approximate, will be calculated per chunk
silence_threshold = max(1, config.silence_duration_ms // chunk_duration_ms)
for chunk in audio_chunks:
# Ensure chunk is float32
if chunk.dtype != np.float32:
chunk = chunk.astype(np.float32)
# Normalize if needed (int16 range to float32)
            if chunk.size and np.abs(chunk).max() > 1.0:
                chunk = chunk / 32768.0
# Update chunk duration estimate for silence threshold
chunk_duration_ms = len(chunk) / config.sample_rate * 1000
silence_threshold = max(1, int(config.silence_duration_ms / chunk_duration_ms))
# Run VAD
speech_timestamps = get_speech_timestamps(
torch.from_numpy(chunk),
self.vad_model,
sampling_rate=config.sample_rate,
threshold=config.threshold,
)
has_speech = len(speech_timestamps) > 0
if has_speech:
if not state.is_speaking:
state.is_speaking = True
state.audio_buffer = []
state.audio_buffer.append(chunk)
state.silence_frames = 0
elif state.is_speaking:
state.audio_buffer.append(chunk)
state.silence_frames += 1
if state.silence_frames >= silence_threshold:
# End of speech detected - process the utterance
if state.audio_buffer:
full_audio = np.concatenate(state.audio_buffer)
result = self(
{"array": full_audio, "sampling_rate": config.sample_rate},
tts_voice=voice,
output_sample_rate=target_sr,
)
yield result
state.reset()
def _sanitize_parameters(
self,
tts_voice: str | None = None,
output_sample_rate: int | None = None,
return_text_only: bool = False,
user_prompt: str | None = None,
**kwargs,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
"""Sanitize and route parameters to preprocessing, forward, and postprocessing."""
preprocess_kwargs: dict[str, Any] = {}
forward_kwargs: dict[str, Any] = {}
postprocess_kwargs: dict[str, Any] = {}
if tts_voice is not None:
postprocess_kwargs["tts_voice"] = tts_voice
if output_sample_rate is not None:
postprocess_kwargs["output_sample_rate"] = output_sample_rate
if return_text_only:
postprocess_kwargs["return_text_only"] = return_text_only
if user_prompt is not None:
forward_kwargs["user_prompt"] = user_prompt
return preprocess_kwargs, forward_kwargs, postprocess_kwargs
def preprocess(self, inputs, **kwargs) -> dict[str, Any]:
"""Preprocess audio inputs for the model.
Handles various input formats:
- File path (str)
- Dict with 'array' and 'sampling_rate'
- Dict with 'raw' audio bytes
- Raw numpy array
- Bytes
Returns:
Dict with input_features and attention_mask for the model
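
        Example (calling this stage directly; a sketch assuming `audio` is a
        16kHz mono float32 numpy array):
            >>> feats = pipe.preprocess({"array": audio, "sampling_rate": 16000})
            >>> sorted(feats)
            ['attention_mask', 'input_features']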
"""
# Extract audio array from various formats
audio_array = self._extract_audio(inputs)
if audio_array is None:
raise ValueError(f"Could not extract audio from input type: {type(inputs)}")
# Use feature extractor to get mel features
processed = self.feature_extractor(
audio_array,
sampling_rate=self.feature_extractor.sampling_rate,
return_tensors="pt",
return_attention_mask=True,
)
return {
"input_features": processed.input_features,
"attention_mask": processed.attention_mask,
}
def _forward(self, model_inputs: dict, user_prompt: str | None = None) -> dict[str, Any]:
"""Run ASR model to generate text from audio.
Args:
model_inputs: Dict with input_features and attention_mask
user_prompt: Optional custom prompt for the model
Returns:
Dict with generated token IDs
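
        Example (overriding the prompt through the public API; the prompt
        text is illustrative):
            >>> result = pipe("audio.wav", user_prompt="Reply to the speaker in one sentence.")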
"""
input_features = model_inputs["input_features"].to(self.model.device)
attention_mask = model_inputs["attention_mask"].to(self.model.device)
# Set custom prompt if provided
original_prompt = None
if user_prompt:
original_prompt = self.model.TRANSCRIBE_PROMPT
self.model.TRANSCRIBE_PROMPT = user_prompt
try:
generated_ids = self.model.generate(
input_features=input_features,
audio_attention_mask=attention_mask,
)
finally:
if original_prompt is not None:
self.model.TRANSCRIBE_PROMPT = original_prompt
return {"tokens": generated_ids}
def postprocess(
self,
model_outputs: dict,
tts_voice: str | None = None,
output_sample_rate: int | None = None,
return_text_only: bool = False,
) -> dict[str, Any]:
"""Convert model output to text and synthesize speech.
Args:
model_outputs: Dict with 'tokens' containing generated IDs
tts_voice: Kokoro voice ID (default: self.tts_voice)
output_sample_rate: Output sample rate (default: self.output_sample_rate)
return_text_only: If True, skip TTS and return only text
Returns:
Dict with 'text', 'audio' (numpy array), and 'sampling_rate'
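
        Example (text-only mode skips TTS, so no 'audio' key is returned):
            >>> result = pipe("audio.wav", return_text_only=True)
            >>> "audio" in result
            False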
"""
target_sr = output_sample_rate or self.output_sample_rate
tokens = model_outputs.get("tokens")
if tokens is None:
return {
"text": "",
"audio": np.array([], dtype=np.float32),
"sampling_rate": target_sr,
}
# Convert tokens to text
if torch.is_tensor(tokens):
tokens = tokens.cpu()
if tokens.dim() > 1:
tokens = tokens[0]
# Filter EOS tokens
if hasattr(self.model, "generation_config") and self.model.generation_config is not None:
eos_ids = self.model.generation_config.eos_token_id
if eos_ids is not None:
eos_set = set(eos_ids) if isinstance(eos_ids, list) else {eos_ids}
tokens = [t for t in tokens.tolist() if t not in eos_set]
text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
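        # Shared ASR post-processing: drop any "thinking" markup the model
        # emits and truncate runaway repetition loops in the decoded text.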
text = strip_thinking(text)
text = _truncate_repetitions(text)
result = {"text": text}
# Synthesize speech unless text-only requested
if not return_text_only:
voice = tts_voice or self.tts_voice
audio = self._synthesize_speech(text, voice)
# Resample if target sample rate differs from native TTS rate
audio = self._resample_audio(audio, TTS_SAMPLE_RATE, target_sr)
result["audio"] = audio
result["sampling_rate"] = target_sr
return result
def _synthesize_speech(self, text: str, voice: str) -> np.ndarray:
"""Synthesize speech from text using Kokoro TTS.
Args:
text: Text to synthesize
voice: Kokoro voice ID
Returns:
Audio as numpy array (float32, 24kHz native TTS rate)
"""
if not text or not text.strip():
return np.array([], dtype=np.float32)
try:
audio_chunks = []
for _, _, audio in self.tts_pipeline(text, voice=voice):
audio_chunks.append(audio)
if audio_chunks:
return np.concatenate(audio_chunks)
except Exception:
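            # Swallow TTS errors so a synthesis failure degrades to empty
            # audio instead of losing the transcription.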
pass
return np.array([], dtype=np.float32)
def _resample_audio(self, audio: np.ndarray, from_sr: int, to_sr: int) -> np.ndarray:
"""Resample audio to target sample rate.
Args:
audio: Input audio array
from_sr: Source sample rate
to_sr: Target sample rate
Returns:
Resampled audio array
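
        Example (24 kHz -> 48 kHz doubles the sample count):
            >>> y = pipe._resample_audio(np.zeros(24000, dtype=np.float32), 24000, 48000)
            >>> len(y)
            48000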
"""
if len(audio) == 0 or from_sr == to_sr:
return audio
num_samples = int(len(audio) * to_sr / from_sr)
return scipy.signal.resample(audio, num_samples).astype(np.float32)
def text_to_speech(
self,
text: str,
voice: str | None = None,
output_sample_rate: int | None = None,
) -> dict[str, Any]:
"""Convert text to speech using Kokoro TTS.
This is a convenience method for generating audio from text without
going through the full speech-to-speech pipeline.
Args:
text: Text to synthesize
voice: Kokoro voice ID (default: self.tts_voice)
output_sample_rate: Output sample rate (default: self.output_sample_rate)
Returns:
Dict with 'audio' (numpy array) and 'sampling_rate' keys
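
        Example (a sketch; "af_bella" assumes Kokoro's stock voice catalog):
            >>> out = pipe.text_to_speech("Hello there.", voice="af_bella")
            >>> audio, sr = out["audio"], out["sampling_rate"]  # float32, 48000 Hz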
"""
voice = voice or self.tts_voice
target_sr = output_sample_rate or self.output_sample_rate
audio = self._synthesize_speech(text, voice)
audio = self._resample_audio(audio, TTS_SAMPLE_RATE, target_sr)
return {"audio": audio, "sampling_rate": target_sr}
def _extract_audio(self, inputs) -> np.ndarray | None:
"""Extract audio array from various input formats.
Args:
inputs: Audio input in various formats
Returns:
Audio as numpy array (float32) or None if extraction fails
"""
if isinstance(inputs, dict):
if "array" in inputs:
audio = inputs["array"]
if isinstance(audio, np.ndarray):
return audio.astype(np.float32) if audio.dtype != np.float32 else audio
return np.array(audio, dtype=np.float32)
if "raw" in inputs:
audio = inputs["raw"]
if isinstance(audio, np.ndarray):
return audio.astype(np.float32) if audio.dtype != np.float32 else audio
return np.array(audio, dtype=np.float32)
elif isinstance(inputs, str):
# File path
with Path(inputs).open("rb") as f:
return ffmpeg_read(f.read(), sampling_rate=16000)
elif isinstance(inputs, bytes):
return ffmpeg_read(inputs, sampling_rate=16000)
elif isinstance(inputs, np.ndarray):
return inputs.astype(np.float32) if inputs.dtype != np.float32 else inputs
return None
def __call__(self, inputs, **kwargs) -> dict[str, Any]:
"""Process audio input and return speech output.
Args:
inputs: Audio input (file path, dict with array, numpy array, or bytes)
            tts_voice: Kokoro voice ID for TTS output (default: "af_heart")
            output_sample_rate: Output sample rate in Hz (default: 48000)
return_text_only: If True, skip TTS and return only transcription
user_prompt: Custom prompt for the model
Returns:
Dict with:
- 'text': Transcription/response text
- 'audio': Synthesized speech as numpy array (float32)
                - 'sampling_rate': Audio sample rate (48000 by default)
"""
return super().__call__(inputs, **kwargs)