fciannella's picture
Working with service run on 7860
53ea588
raw
history blame
6.53 kB
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD 2-Clause License
"""NVIDIA Riva Neural Machine Translation (NMT) service implementation.
This module provides integration with NVIDIA Riva's NMT service for text translation
between different languages. It supports:
- Real-time text translation
- Multiple language pairs
- Integration with LLM and sentence aggregation pipelines
"""
import asyncio
import re

from loguru import logger
from pipecat.frames.frames import (
    ErrorFrame,
    Frame,
    LLMFullResponseEndFrame,
    LLMFullResponseStartFrame,
    LLMMessagesFrame,
    TextFrame,
    TranscriptionFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_service import AIService
from pipecat.transcriptions.language import Language

try:
    import riva.client
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error("In order to use nvidia rivaskills NMT, you need to `pip install pipecat-ai[riva]`.")
    raise Exception(f"Missing module: {e}") from e
class RivaNMTService(AIService):
    """Base class for services using Riva NMT.

    Handles translation of text between languages using NVIDIA's Riva Neural Machine
    Translation service. Requires Riva NMT models to be deployed following:
    https://docs.nvidia.com/deeplearning/riva/user-guide/docs/quick-start-guide/nmt.html
    """

    # Sentence-ending punctuation stripped from the LLM response before translation
    # (see the comment in process_frame for the reason).
    _SENTENCE_PUNCTUATION = re.compile(r"[.?!:;]")

    def __init__(
        self,
        source_language: Language,
        target_language: Language,
        model_name: str = "",
        server: str = "localhost:50051",
        **kwargs,
    ) -> None:
        """Initialize the Riva NMT service.

        Args:
            source_language: Source language for translation.
            target_language: Target language for translation.
            model_name: Name of the RIVA translation model. Empty string will
                auto-select an available model.
            server: Riva server address.
            **kwargs: Additional arguments for AIService parent class.

        Raises:
            Exception: If source_language or target_language is not provided.
        """
        if not source_language:
            raise Exception("No source language provided for the translation..")
        if not target_language:
            raise Exception("No target language provided for the translation..")

        super().__init__(**kwargs)
        self.set_model_name(model_name)
        self.source_language = source_language
        self.target_language = target_language

        # State used to buffer the text of one LLM response between its
        # start and end frames.
        self.llm_full_response_started = False
        self.llm_full_response = ""

        self.auth = riva.client.Auth(uri=server)
        self.riva_nmt_client = riva.client.NeuralMachineTranslationClient(self.auth)

    def _report_translation_error(self, detail: object) -> str:
        """Log a translation failure and return the message that was logged."""
        message = (
            f"Error while translating the text from {self.source_language} "
            f"to {self.target_language}, Error: {detail}"
        )
        logger.error(message)
        return message

    async def translate_text(self, text: str = "") -> tuple[str | None, str | None]:
        """Translates text using Riva NMT service.

        This method does not raise: every failure, including empty input, is
        logged and reported through the second element of the returned tuple.

        Args:
            text: The text to translate. Must not be empty.

        Returns:
            A tuple containing:
            - str | None: Translated text if successful, None if failed
            - str | None: Error message if failed, None if successful
        """
        if not text:
            return None, self._report_translation_error(
                "No input text provided for the translation.."
            )

        logger.debug(f"Received text: {text}")
        logger.debug(f"Translating the text from {self.source_language} to {self.target_language}")
        try:
            # The Riva client call is synchronous; run it in a worker thread so
            # the event loop is not blocked while waiting for the server.
            response = await asyncio.to_thread(
                self.riva_nmt_client.translate,
                [text],
                self._model_name,
                self.source_language,
                self.target_language,
            )
            translated_text = response.translations[0].text
        except Exception as e:
            # Boundary handler: callers expect an error tuple, not an exception.
            return None, self._report_translation_error(e)

        logger.debug(f"Final translated text: {translated_text}")
        return translated_text, None

    async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
        """Processes incoming frames for translation.

        Handles different frame types:
        - TranscriptionFrame: Translates text and pushes LLMMessagesFrame
        - LLMFullResponseStartFrame: Marks start of LLM response and clears the buffer
        - LLMFullResponseEndFrame: Translates accumulated response and pushes TextFrame
        - TextFrame: Accumulates text during LLM response
        Any other frame is pushed on unchanged in its original direction. When a
        translation fails, an ErrorFrame is pushed via push_error instead of the
        translated output.

        Args:
            frame: Frame to process.
            direction: Direction of frame processing.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, TranscriptionFrame):
            await self.start_processing_metrics()
            translated_text, err = await self.translate_text(frame.text)
            await self.stop_processing_metrics()
            if err is not None:
                await self.push_error(ErrorFrame(err))
            else:
                messages = [{"role": "system", "content": translated_text}]
                await self.push_frame(LLMMessagesFrame(messages))
        elif isinstance(frame, LLMFullResponseStartFrame):
            self.llm_full_response_started = True
            # Drop anything left over from a response that never reached its
            # end frame so it cannot leak into this one.
            self.llm_full_response = ""
        elif isinstance(frame, LLMFullResponseEndFrame):
            self.llm_full_response_started = False
            # Removing period, question mark, exclamation point, colon, or semicolon
            # as these match end of sentence regex in
            # _process_text_frame() method of TTSService of pipecat/services/ai_services.py
            # and TTS response gets truncated.
            pending_text = self._SENTENCE_PUNCTUATION.sub("", self.llm_full_response)
            # Clear the buffer before translating so a failed translation does
            # not leave stale text to be prepended to the next response.
            self.llm_full_response = ""

            await self.start_processing_metrics()
            translated_text, err = await self.translate_text(pending_text)
            await self.stop_processing_metrics()
            if err is not None:
                await self.push_error(ErrorFrame(err))
            else:
                await self.push_frame(TextFrame(translated_text + "."))
        elif self.llm_full_response_started and isinstance(frame, TextFrame):
            self.llm_full_response += frame.text
        else:
            await self.push_frame(frame, direction)