Spaces:
Running
Running
File size: 6,527 Bytes
53ea588 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD 2-Clause License
"""NVIDIA Riva Neural Machine Translation (NMT) service implementation.
This module provides integration with NVIDIA Riva's NMT service for text translation
between different languages. It supports:
- Real-time text translation
- Multiple language pairs
- Integration with LLM and sentence aggregation pipelines
"""
import asyncio
import re

from loguru import logger
from pipecat.frames.frames import (
    ErrorFrame,
    Frame,
    LLMFullResponseEndFrame,
    LLMFullResponseStartFrame,
    LLMMessagesFrame,
    TextFrame,
    TranscriptionFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_service import AIService
from pipecat.transcriptions.language import Language

try:
    import riva.client
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error("In order to use nvidia rivaskills NMT, you need to `pip install pipecat-ai[riva]`.")
    raise Exception(f"Missing module: {e}") from e
class RivaNMTService(AIService):
    """Base class for services using Riva NMT.

    Handles translation of text between languages using NVIDIA's Riva Neural Machine
    Translation service. Requires Riva NMT models to be deployed following:
    https://docs.nvidia.com/deeplearning/riva/user-guide/docs/quick-start-guide/nmt.html

    Two translation paths are handled in ``process_frame``:

    - User transcriptions are translated and forwarded as an ``LLMMessagesFrame``.
    - Text streamed between ``LLMFullResponseStartFrame`` and
      ``LLMFullResponseEndFrame`` is buffered, translated as a whole when the
      response ends, and forwarded as a single ``TextFrame``.
    """

    # Period, question mark, exclamation point, colon and semicolon are removed from
    # the buffered LLM response before translation. These characters match the
    # end-of-sentence regex used by pipecat's TTSService text processing, and leaving
    # them in caused the TTS response to be truncated. A single "." is appended after
    # translation so the whole response is emitted as one sentence.
    _SENTENCE_END_RE = re.compile(r"[.?!:;]")

    def __init__(
        self,
        source_language: Language,
        target_language: Language,
        model_name: str = "",
        server: str = "localhost:50051",
        **kwargs,
    ):
        """Initialize the Riva NMT service.

        Args:
            source_language: Source language for translation.
            target_language: Target language for translation.
            model_name: Name of the RIVA translation model. Empty string will
                auto-select an available model.
            server: Riva server address, passed as ``uri`` to ``riva.client.Auth``.
            **kwargs: Additional arguments for AIService parent class.

        Raises:
            Exception: If source_language or target_language is not provided.
        """
        if not source_language:
            raise Exception("No source language provided for the translation..")
        if not target_language:
            raise Exception("No target language provided for the translation..")

        super().__init__(**kwargs)
        self.set_model_name(model_name)
        self.source_language = source_language
        self.target_language = target_language
        # True while between LLMFullResponseStartFrame and LLMFullResponseEndFrame.
        self.llm_full_response_started = False
        # Text accumulated from TextFrames during the current LLM response.
        self.llm_full_response = ""
        self.auth = riva.client.Auth(uri=server)
        self.riva_nmt_client = riva.client.NeuralMachineTranslationClient(self.auth)

    async def translate_text(self, text: str = "") -> tuple[str | None, str | None]:
        """Translate text using the Riva NMT service.

        The Riva client's ``translate`` method is synchronous. It is therefore
        dispatched to a worker thread with ``asyncio.to_thread`` so the event loop
        keeps processing other frames while the request is in flight.

        Args:
            text: The text to translate. Must not be empty.

        Returns:
            A tuple ``(translated_text, error_message)``. Exactly one element is
            non-``None``: the translated text on success, or a human-readable
            error message on failure.

        Note:
            No exception propagates from this method. Empty input, or any failure
            raised by the Riva client, is logged and reported through the
            error-message element of the returned tuple.
        """
        try:
            if not text:
                raise Exception("No input text provided for the translation..")

            logger.debug(f"Received text: {text}")
            logger.debug(f"Translating the text from {self.source_language} to {self.target_language}")

            response = await asyncio.to_thread(
                self.riva_nmt_client.translate,
                [text],
                self._model_name,
                self.source_language,
                self.target_language,
            )
            translated = response.translations[0].text
            logger.debug(f"Final translated text: {translated}")
            return translated, None
        except Exception as e:
            error_msg = (
                f"Error while translating the text from {self.source_language} to "
                f"{self.target_language}, Error: {e}"
            )
            logger.error(error_msg)
            return None, error_msg

    async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
        """Process incoming frames for translation.

        Handles different frame types:

        - TranscriptionFrame: translates the text and pushes an LLMMessagesFrame.
        - LLMFullResponseStartFrame: marks the start of an LLM response (not forwarded).
        - LLMFullResponseEndFrame: translates the accumulated response and pushes a
          TextFrame (the end frame itself is not forwarded).
        - TextFrame received during an LLM response: accumulated, not forwarded.
        - Any other frame: forwarded unchanged in its original direction.

        Args:
            frame: Frame to process.
            direction: Direction of frame processing.
        """
        await super().process_frame(frame, direction)

        # TranscriptionFrame is checked before the generic TextFrame branch on purpose.
        if isinstance(frame, TranscriptionFrame):
            await self._translate_transcription(frame)
        elif isinstance(frame, LLMFullResponseStartFrame):
            self.llm_full_response_started = True
        elif isinstance(frame, LLMFullResponseEndFrame):
            self.llm_full_response_started = False
            await self._translate_llm_response()
        elif self.llm_full_response_started and isinstance(frame, TextFrame):
            self.llm_full_response += frame.text
        else:
            await self.push_frame(frame, direction)

    async def _translate_transcription(self, frame: TranscriptionFrame) -> None:
        """Translate a user transcription and push it as a single-message LLMMessagesFrame."""
        await self.start_processing_metrics()
        translated_text, err = await self.translate_text(frame.text)
        await self.stop_processing_metrics()

        if err is not None:
            await self.push_error(ErrorFrame(err))
            return

        messages = [{"role": "system", "content": translated_text}]
        await self.push_frame(LLMMessagesFrame(messages))

    async def _translate_llm_response(self) -> None:
        """Translate the buffered LLM response and push it as a single TextFrame."""
        # Take the buffer and clear it before translating. This way a failed
        # translation can never leak stale text into the next LLM response.
        text = self._SENTENCE_END_RE.sub("", self.llm_full_response)
        self.llm_full_response = ""

        if not text.strip():
            # No spoken content in this response: no TextFrames, or punctuation or
            # whitespace only (e.g. a turn that produced only a tool call).
            # Translating it would fail and surface a spurious ErrorFrame.
            logger.debug("Skipping translation of empty LLM response")
            return

        await self.start_processing_metrics()
        translated_text, err = await self.translate_text(text)
        await self.stop_processing_metrics()

        if err is not None:
            await self.push_error(ErrorFrame(err))
            return

        await self.push_frame(TextFrame(translated_text + "."))
|