fciannella's picture
Working with service run on 7860
53ea588
raw
history blame
26.9 kB
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD 2-Clause License
"""NVIDIA Riva speech services implementation.
This module provides integration with NVIDIA Riva's speech services, including:
- Text-to-Speech (TTS) with support for multiple voices and languages
- Automatic Speech Recognition (ASR) with streaming capabilities
The services can be configured to use either a local Riva Speech Server or
NVIDIA's cloud-hosted models through NVCF.
For documentation on how to configure the Riva Speech models, please refer to the
[Riva Speech Quick Start Guide](https://docs.nvidia.com/deeplearning/riva/user-guide/docs/quick-start-guide.html).
"""
import asyncio
import concurrent.futures
from collections.abc import AsyncGenerator
from pathlib import Path
import riva.client
from loguru import logger
from pipecat.audio.vad.vad_analyzer import VADState
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
StartFrame,
StartInterruptionFrame,
StopInterruptionFrame,
TranscriptionFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
TTSTextFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.services.stt_service import STTService
from pipecat.services.tts_service import TTSService
from pipecat.transcriptions.language import Language
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
from pipecat.utils.time import time_now_iso8601
from riva.client.proto.riva_audio_pb2 import AudioEncoding
from nvidia_pipecat.frames.riva import RivaInterimTranscriptionFrame
from nvidia_pipecat.utils.tracing import AttachmentStrategy, traceable, traced
@traceable
class RivaTTSService(TTSService):
    """NVIDIA Riva Text-to-Speech service implementation.

    Provides speech synthesis using NVIDIA's Riva TTS models with support for
    multiple voices, languages, and custom dictionaries.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        server: str = "grpc.nvcf.nvidia.com:443",
        voice_id: str = "English-US.Female-1",
        sample_rate: int = 16000,
        function_id: str = "0149dedb-2be8-4195-b9a0-e57e0e14f972",
        language: Language | None = Language.EN_US,
        zero_shot_quality: int | None = 20,
        model: str = "fastpitch-hifigan-tts",
        custom_dictionary: dict | None = None,
        encoding: AudioEncoding = AudioEncoding.LINEAR_PCM,
        zero_shot_audio_prompt_file: Path | None = None,
        audio_prompt_encoding: AudioEncoding = AudioEncoding.ENCODING_UNSPECIFIED,
        use_ssl: bool = False,
        text_aggregator: BaseTextAggregator | None = None,
        **kwargs,
    ):
        """Initializes the Riva TTS service.

        Args:
            api_key (str | None, optional): API key for authentication. Defaults to None.
            server (str, optional): Server address for Riva service. Defaults to "grpc.nvcf.nvidia.com:443".
            voice_id (str, optional): Voice identifier. Defaults to "English-US.Female-1".
            sample_rate (int, optional): Audio sample rate in Hz. Defaults to 16000.
            function_id (str, optional): NVCF function identifier sent as request metadata.
                Defaults to "0149dedb-2be8-4195-b9a0-e57e0e14f972".
            language (Language | None, optional): Language for synthesis. Defaults to Language.EN_US.
            zero_shot_quality (int | None, optional): Quality level for zero-shot synthesis. Defaults to 20.
            model (str, optional): Model name reported via ``set_model_name``. Defaults to "fastpitch-hifigan-tts".
            custom_dictionary (dict | None, optional): Custom pronunciation dictionary. Defaults to None.
            encoding (AudioEncoding, optional): Output audio encoding. Defaults to AudioEncoding.LINEAR_PCM.
            zero_shot_audio_prompt_file (Path | None, optional): Path to the zero-shot audio prompt file.
                Defaults to None.
            audio_prompt_encoding (AudioEncoding, optional): Encoding of the audio prompt.
                Defaults to AudioEncoding.ENCODING_UNSPECIFIED.
            use_ssl (bool, optional): Whether to use SSL for the connection. Defaults to False; it is
                forced to True when ``server`` is "grpc.nvcf.nvidia.com:443".
            text_aggregator (BaseTextAggregator | None, optional): Text aggregator for sentence detection.
                Defaults to None, which uses SimpleTextAggregator.
            **kwargs: Additional keyword arguments passed to parent class.

        Raises:
            Exception: If the Riva client cannot be created or the warm-up
                ``GetRivaSynthesisConfig`` request fails.

        Usage:
            If server is not set then it defaults to "grpc.nvcf.nvidia.com:443" and uses NVCF hosted models.
            Update function ID to use a different NVCF model. API key is required for NVCF hosted models.
            For using a locally deployed Riva Speech Server, set server to "localhost:50051" and
            follow the quick start guide to set up the server.
        """
        super().__init__(
            sample_rate=sample_rate,
            # run_tts emits the TTSTextFrame itself (before the audio), so the base class must not.
            push_text_frames=False,
            push_stop_frames=True,
            text_aggregator=text_aggregator,
            **kwargs,
        )
        self._api_key = api_key
        self._function_id = function_id
        self._voice_id = voice_id
        self._sample_rate = sample_rate
        self._language_code = language
        self._zero_shot_quality = zero_shot_quality
        self.set_model_name(model)
        self.set_voice(voice_id)
        self._custom_dictionary = custom_dictionary
        self._encoding = encoding
        self._zero_shot_audio_prompt_file = zero_shot_audio_prompt_file
        self._audio_prompt_encoding = audio_prompt_encoding

        metadata = [
            ["function-id", function_id],
            ["authorization", f"Bearer {api_key}"],
        ]

        # The NVCF endpoint is only reachable over TLS.
        if server == "grpc.nvcf.nvidia.com:443":
            use_ssl = True

        try:
            auth = riva.client.Auth(None, use_ssl, server, metadata)
            self._service = riva.client.SpeechSynthesisService(auth)
            # Warm up the service; this also fails fast if the server is unreachable.
            _ = self._service.stub.GetRivaSynthesisConfig(riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest())
        except Exception as e:
            logger.error(
                "In order to use nvidia Riva TTSService or STTService, you will either need a locally "
                "deployed Riva Speech Server with ASR and TTS models (Follow riva quick start guide at "
                "https://docs.nvidia.com/deeplearning/riva/user-guide/docs/quick-start-guide.html and "
                "edit the config file to deploy which model you want to use and set the server url to "
                "localhost:50051), or you can set the NVIDIA_API_KEY environment "
                "variable to connect with nvcf hosted models."
            )
            raise Exception(f"Missing module: {e}") from e

    def can_generate_metrics(self) -> bool:
        """Check if the service can generate metrics.

        Returns:
            bool: True as this service supports metric generation.
        """
        return True

    async def _push_tts_frames(self, text: str):
        """Filter ``text`` and run it through :meth:`run_tts`.

        Overrides the base implementation. Because ``push_text_frames=False`` is
        passed to the base class, no trailing TTSTextFrame is pushed here;
        :meth:`run_tts` emits it right after TTSStartedFrame instead.
        """
        # Remove leading newlines only
        text = text.lstrip("\n")

        # Don't send only whitespace. This causes problems for some TTS models. But also don't
        # strip all whitespace, as whitespace can influence prosody.
        if not text.strip():
            return

        # This is just a flag that indicates if we sent something to the TTS
        # service. It will be cleared if we sent text because of a TTSSpeakFrame
        # or when we received an LLMFullResponseEndFrame
        self._processing_text = True

        await self.start_processing_metrics()

        # Run the text through every configured text filter.
        for text_filter in self._text_filters:
            text_filter.reset_interruption()
            text = text_filter.filter(text)

        if text:
            await self.process_generator(self.run_tts(text))

        await self.stop_processing_metrics()

    @traced(attachment_strategy=AttachmentStrategy.NONE, name="tts")
    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Run text-to-speech synthesis.

        Streams audio from Riva's ``synthesize_online`` call. Output is always
        bracketed by TTSStartedFrame / TTSStoppedFrame once synthesis starts,
        even if the Riva response stream fails midway (the error is logged).

        Args:
            text: Text to synthesize. Text without any alphanumeric character is skipped.

        Yields:
            Frame: TTSStartedFrame, a TTSTextFrame with ``text``, zero or more
            TTSAudioRawFrame chunks, then TTSStoppedFrame.
        """
        # Check if text contains any alphanumeric characters
        if not any(c.isalnum() for c in text):
            logger.debug(f"Skipping TTS for text with no alphanumeric characters: [{text}]")
            return

        stripped_text = text.strip()
        logger.debug(f"Generating TTS: [{stripped_text}]")
        responses = self._service.synthesize_online(
            stripped_text,
            self._voice_id,
            self._language_code,
            sample_rate_hz=self._sample_rate,
            zero_shot_audio_prompt_file=self._zero_shot_audio_prompt_file,
            audio_prompt_encoding=self._audio_prompt_encoding,
            zero_shot_quality=self._zero_shot_quality,
            custom_dictionary=self._custom_dictionary,
            encoding=self._encoding,
        )

        await self.start_ttfb_metrics()
        yield TTSStartedFrame()
        # Push text frame immediately after TTSStartedFrame.
        # TTSService base processor will push the tts text after sending generated tts audio downstream
        # Need to push the text before audio frame for better TTS transcription.
        yield TTSTextFrame(text)

        loop = asyncio.get_running_loop()
        response_iterator = iter(responses)

        def _next_response():
            # The gRPC iterator blocks, so it is advanced in an executor thread.
            # asyncio futures cannot carry StopIteration, so exhaustion is mapped to None.
            try:
                return next(response_iterator)
            except StopIteration:
                return None

        total_audio_length = 0
        try:
            # The fetch is inside the try so that stream/RPC errors are handled here and
            # TTSStoppedFrame is still emitted below.
            while (resp := await loop.run_in_executor(None, _next_response)) is not None:
                total_audio_length += len(resp.audio)
                await self.stop_ttfb_metrics()
                yield TTSAudioRawFrame(
                    audio=resp.audio,
                    sample_rate=self._sample_rate,
                    num_channels=1,
                )
        except Exception as e:
            logger.error(f"{self} Error processing TTS response: {e}")

        await self.start_tts_usage_metrics(text)
        # Duration estimate assumes 16-bit (2-byte) mono samples.
        logger.debug(f"Total generated TTS audio length: {total_audio_length / (self._sample_rate * 2)} seconds")
        yield TTSStoppedFrame()
@traceable
class RivaASRService(STTService):
    """NVIDIA Riva Automatic Speech Recognition service.

    Provides streaming speech recognition using Riva ASR models with support for:
    - Real-time transcription
    - Interim results
    - Interruption handling
    - Voice activity detection
    - Language model customization

    Audio received in :meth:`run_stt` is queued and consumed by a worker thread
    that drives Riva's blocking streaming API, using this object as the audio
    iterator (see :meth:`__next__`). Responses are handed back to the event loop
    and turned into frames by :meth:`_handle_response`.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        server: str = "grpc.nvcf.nvidia.com:443",
        function_id: str = "1598d209-5e27-4d3c-8079-4751568b1081",
        language: Language | None = Language.EN_US,
        model: str = "parakeet-1.1b-en-US-asr-streaming-asr-bls-ensemble",
        profanity_filter: bool = False,
        automatic_punctuation: bool = False,
        no_verbatim_transcripts: bool = True,
        boosted_lm_words: dict | None = None,
        boosted_lm_score: float = 4.0,
        start_history: int = -1,
        start_threshold: float = -1.0,
        stop_history: int = 500,
        stop_threshold: float = -1.0,
        stop_history_eou: int = 240,
        stop_threshold_eou: float = -1.0,
        custom_configuration: str = "enable_vad_endpointing:true,neural_vad.onset:0.65,apply_partial_itn:true",
        sample_rate: int = 16000,
        audio_channel_count: int = 1,
        max_alternatives: int = 1,
        interim_results: bool = True,
        generate_interruptions: bool = False,  # Only set to True if transport VAD is disabled
        idle_timeout: int = 30,  # Timeout for idle Riva ASR request
        use_ssl: bool = False,
        **kwargs,
    ):
        """Initializes the Riva ASR service.

        Args:
            api_key: NVIDIA API key for cloud access.
            server: Riva server address.
            function_id: NVCF function identifier.
            language: Language for recognition.
            model: ASR model name.
            profanity_filter: Enable profanity filtering.
            automatic_punctuation: Enable automatic punctuation.
            no_verbatim_transcripts: Disable verbatim transcripts (sent as ``verbatim_transcripts=not value``).
            boosted_lm_words: Words to boost in language model.
            boosted_lm_score: Score for boosted words.
            start_history: VAD start history frames.
            start_threshold: VAD start threshold.
            stop_history: VAD stop history frames.
            stop_threshold: VAD stop threshold.
            stop_history_eou: End-of-utterance history frames.
            stop_threshold_eou: End-of-utterance threshold.
            custom_configuration: Additional configuration string.
            sample_rate: Audio sample rate in Hz.
            audio_channel_count: Number of audio channels.
            max_alternatives: Maximum number of alternatives.
            interim_results: Enable interim results.
            generate_interruptions: Emit user started/stopped speaking and interruption frames from
                ASR results. Only set to True if transport VAD is disabled.
            idle_timeout: Seconds without audio after which the active Riva request is terminated.
            use_ssl: Enable SSL connection. Forced to True for "grpc.nvcf.nvidia.com:443".
            **kwargs: Additional arguments for STTService.

        Raises:
            Exception: If the Riva ASR client cannot be created.

        Usage:
            If server is not set then it defaults to "grpc.nvcf.nvidia.com:443" and uses NVCF hosted models.
            Update function ID to use a different NVCF model. API key is required for NVCF hosted models.
            For using a locally deployed Riva Speech Server, set server to "localhost:50051" and
            follow the quick start guide to set up the server.
        """
        super().__init__(**kwargs)
        self._profanity_filter = profanity_filter
        self._automatic_punctuation = automatic_punctuation
        self._no_verbatim_transcripts = no_verbatim_transcripts
        self._language_code = language
        self._boosted_lm_words = boosted_lm_words
        self._boosted_lm_score = boosted_lm_score
        self._start_history = start_history
        self._start_threshold = start_threshold
        self._stop_history = stop_history
        self._stop_threshold = stop_threshold
        self._stop_history_eou = stop_history_eou
        self._stop_threshold_eou = stop_threshold_eou
        self._custom_configuration = custom_configuration
        self._sample_rate: int = sample_rate
        self._model = model
        self._audio_channel_count = audio_channel_count
        self._max_alternatives = max_alternatives
        self._interim_results = interim_results
        self._idle_timeout = idle_timeout
        self.last_transcript_frame = None
        self.set_model_name(model)

        metadata = [
            ["function-id", function_id],
            ["authorization", f"Bearer {api_key}"],
        ]

        # The NVCF endpoint is only reachable over TLS.
        if server == "grpc.nvcf.nvidia.com:443":
            use_ssl = True

        try:
            auth = riva.client.Auth(None, use_ssl, server, metadata)
            self._asr_service = riva.client.ASRService(auth)
        except Exception as e:
            logger.error(
                "In order to use nvidia Riva TTSService or STTService, you will either need a locally "
                "deployed Riva Speech Server with ASR and TTS models (Follow riva quick start guide at "
                "https://docs.nvidia.com/deeplearning/riva/user-guide/docs/quick-start-guide.html and "
                "edit the config file to deploy which model you want to use and set the server url to "
                "localhost:50051), or you can set the NVIDIA_API_KEY environment "
                "variable to connect with nvcf hosted models."
            )
            raise Exception(f"Missing module: {e}") from e

        config = riva.client.StreamingRecognitionConfig(
            config=riva.client.RecognitionConfig(
                encoding=riva.client.AudioEncoding.LINEAR_PCM,
                language_code=self._language_code,
                model=self._model,
                max_alternatives=self._max_alternatives,
                profanity_filter=self._profanity_filter,
                enable_automatic_punctuation=self._automatic_punctuation,
                verbatim_transcripts=not self._no_verbatim_transcripts,
                sample_rate_hertz=self._sample_rate,
                audio_channel_count=self._audio_channel_count,
            ),
            interim_results=self._interim_results,
        )
        riva.client.add_word_boosting_to_config(config, self._boosted_lm_words, self._boosted_lm_score)
        riva.client.add_endpoint_parameters_to_config(
            config,
            self._start_history,
            self._start_threshold,
            self._stop_history,
            self._stop_history_eou,
            self._stop_threshold,
            self._stop_threshold_eou,
        )
        riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
        self._config = config

        # Audio chunks waiting to be pulled by the Riva worker thread via __next__.
        self._queue = asyncio.Queue()
        self._generate_interruptions = generate_interruptions
        if self._generate_interruptions:
            self._vad_state = VADState.QUIET

        # Initialize the thread task and response task
        self._thread_task = None
        self._response_task = None
        # Whether the worker thread should keep feeding audio; read by __next__.
        self._thread_running = False

        # Initialize ASR compute latency tracking
        self._audio_duration_counter = 0.0  # Tracks cumulative audio duration sent to Riva (in seconds)

    def can_generate_metrics(self) -> bool:
        """Check if the service can generate metrics.

        Returns:
            bool: False as this service does not support metric generation.
        """
        return False

    async def start(self, frame: StartFrame):
        """Start the ASR service.

        Args:
            frame: The StartFrame that triggered the start.
        """
        await super().start(frame)
        # Create the queue before the task that consumes it.
        self._response_queue = asyncio.Queue()
        self._response_task = self.create_task(self._response_task_handler())

    async def stop(self, frame: EndFrame):
        """Stop the ASR service and cleanup resources.

        Args:
            frame: The EndFrame that triggered the stop.
        """
        await super().stop(frame)
        await self._stop_tasks()

    async def cancel(self, frame: CancelFrame):
        """Cancel the ASR service and cleanup resources.

        Args:
            frame: The CancelFrame that triggered the cancellation.
        """
        await super().cancel(frame)
        await self._stop_tasks()

    async def _stop_tasks(self):
        """Cancel the streaming thread task and the response task if they are still running."""
        if self._thread_task is not None and not self._thread_task.done():
            await self.cancel_task(self._thread_task)
        if self._response_task is not None and not self._response_task.done():
            await self.cancel_task(self._response_task)

    def _response_handler(self):
        """Run one blocking Riva streaming request (executed in a worker thread).

        ``self`` is passed as the audio iterator, so audio is pulled through
        :meth:`__next__`. Non-empty responses are forwarded to the event loop's
        response queue.
        """
        try:
            logger.debug("Sending new Riva ASR streaming request...")
            responses = self._asr_service.streaming_response_generator(
                audio_chunks=self,
                streaming_config=self._config,
            )
            for response in responses:
                if not response.results:
                    continue
                asyncio.run_coroutine_threadsafe(self._response_queue.put(response), self.get_event_loop())
        except Exception as e:
            logger.error(f"Error in Riva ASR stream: {e}")
            raise
        logger.debug("Riva ASR streaming request terminated.")

    @traced(attachment_strategy=AttachmentStrategy.NONE, name="asr")
    async def _thread_task_handler(self):
        """Start a new Riva streaming session in a worker thread and await its completion."""
        try:
            # Reset audio duration counter for new ASR session
            self._audio_duration_counter = 0.0
            self._thread_running = True
            await asyncio.to_thread(self._response_handler)
        except asyncio.CancelledError:
            # Cancelling the task does not stop the thread; this flag makes __next__ end the stream.
            self._thread_running = False
            raise

    async def _handle_interruptions(self, frame: Frame):
        """Emit interruption frames for user speaking events (if allowed), then push ``frame``."""
        if self.interruptions_allowed:
            # Make sure we notify about interruptions quickly out-of-band.
            if isinstance(frame, UserStartedSpeakingFrame):
                logger.debug("User started speaking")
                await self._start_interruption()
                # Push an out-of-band frame (i.e. not using the ordered push
                # frame task) to stop everything, specially at the output
                # transport.
                await self.push_frame(StartInterruptionFrame())
            elif isinstance(frame, UserStoppedSpeakingFrame):
                logger.debug("User stopped speaking")
                await self._stop_interruption()
                await self.push_frame(StopInterruptionFrame())
        await self.push_frame(frame)

    async def _handle_response(self, response):
        """Process ASR response and generate appropriate transcription frames.

        Handles three types of transcription results:
        1. Final results (is_final=True): Complete, confirmed transcriptions
        2. Stable interim results (stability=1.0): High-confidence partial results
        3. Partial results (stability<1.0): Lower-confidence, in-progress transcriptions

        Also manages voice activity detection (VAD) state and interruption handling
        when enabled. Each type of result generates appropriate transcription frames
        with different stability values.
        """
        partial_transcript = ""
        for result in response.results:
            if not result or not result.alternatives:
                continue
            transcript = result.alternatives[0].transcript
            if not transcript:
                continue

            await self.stop_ttfb_metrics()
            if result.is_final:
                await self.stop_processing_metrics()
                if self._generate_interruptions:
                    self._vad_state = VADState.QUIET
                    await self._handle_interruptions(UserStoppedSpeakingFrame())
                # Compute latency = audio sent so far minus audio Riva reports as processed.
                if result.audio_processed:
                    compute_latency = self._audio_duration_counter - result.audio_processed
                    logger.debug(f"{self.name} ASR compute latency: {compute_latency}")
                logger.debug(f"Final user transcript: [{transcript}]")
                await self.push_frame(TranscriptionFrame(transcript, "", time_now_iso8601(), None))
                self.last_transcript_frame = None
                break
            elif result.stability == 1.0:
                if self._generate_interruptions and self._vad_state != VADState.SPEAKING:
                    self._vad_state = VADState.SPEAKING
                    await self._handle_interruptions(UserStartedSpeakingFrame())
                # Only push when it differs from the last stable interim frame.
                if (
                    self.last_transcript_frame is None
                    or (self.last_transcript_frame.stability != 1.0)
                    or (self.last_transcript_frame.text.rstrip() != transcript.rstrip())
                ):
                    logger.debug(f"Interim user transcript: [{transcript}]")
                    frame = RivaInterimTranscriptionFrame(
                        transcript, "", time_now_iso8601(), None, stability=result.stability
                    )
                    await self.push_frame(frame)
                    self.last_transcript_frame = frame
                break
            else:
                if self._generate_interruptions and self._vad_state != VADState.SPEAKING:
                    self._vad_state = VADState.SPEAKING
                    await self._handle_interruptions(UserStartedSpeakingFrame())
                # Unstable results are concatenated into a single partial transcript.
                partial_transcript += transcript

        if len(partial_transcript) > 0 and (
            self.last_transcript_frame is None
            or (self.last_transcript_frame.stability == 1.0)
            or (self.last_transcript_frame.text.rstrip() != partial_transcript.rstrip())
        ):
            logger.debug(f"Partial user transcript: [{partial_transcript}]")
            frame = RivaInterimTranscriptionFrame(partial_transcript, "", time_now_iso8601(), None, stability=0.1)
            await self.push_frame(frame)
            self.last_transcript_frame = frame

    async def _response_task_handler(self):
        """Consume Riva responses queued by the worker thread until cancelled."""
        while True:
            try:
                response = await self._response_queue.get()
                await self._handle_response(response)
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep consuming so one failing response does not stop all later transcripts.
                logger.error(f"{self} Error handling Riva ASR response: {e}")

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Run speech-to-text recognition.

        Starts a Riva streaming session if none is active and queues the audio
        for it. Results are pushed asynchronously by the response task.

        Args:
            audio: The audio data to process.

        Yields:
            Frame: A single ``None`` placeholder; transcription frames are pushed separately.
        """
        if self._thread_task is None or self._thread_task.done():
            self._thread_task = self.create_task(self._thread_task_handler())
        await self._queue.put(audio)
        yield None

    def __next__(self) -> bytes:
        """Get the next audio chunk for processing (called from the Riva worker thread).

        Returns:
            bytes: The next audio chunk.

        Raises:
            StopIteration: When the session was stopped or no audio arrived within ``idle_timeout``.
        """
        if not self._thread_running:
            raise StopIteration
        # Scheduled outside the try so that a failure here is not masked by an
        # UnboundLocalError on ``future`` in the handlers below.
        future = asyncio.run_coroutine_threadsafe(self._queue.get(), self.get_event_loop())
        try:
            result = future.result(timeout=self._idle_timeout)
        except concurrent.futures.TimeoutError:
            future.cancel()
            logger.info(f"ASR service is idle for {self._idle_timeout} seconds, terminating active RIVA ASR request...")
            self._thread_task = None
            raise StopIteration from None
        except Exception:
            future.cancel()
            raise

        # Increment audio duration counter based on audio chunk size
        # Assuming LINEAR_PCM encoding: bytes_per_sample = 2, channels = self._audio_channel_count
        bytes_per_sample = 2  # 16-bit PCM
        total_samples = len(result) // (bytes_per_sample * self._audio_channel_count)
        self._audio_duration_counter += total_samples / self._sample_rate
        return result

    def __iter__(self):
        """Get iterator for audio chunks.

        Returns:
            RivaASRService: Self reference for iteration.
        """
        return self