File size: 6,527 Bytes
53ea588
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD 2-Clause License

"""NVIDIA Riva Neural Machine Translation (NMT) service implementation.

This module provides integration with NVIDIA Riva's NMT service for text translation
between different languages. It supports:
- Real-time text translation
- Multiple language pairs
- Integration with LLM and sentence aggregation pipelines
"""

import re

from loguru import logger
from pipecat.frames.frames import (
    ErrorFrame,
    Frame,
    LLMFullResponseEndFrame,
    LLMFullResponseStartFrame,
    LLMMessagesFrame,
    TextFrame,
    TranscriptionFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_service import AIService
from pipecat.transcriptions.language import Language

try:
    import riva.client
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error("In order to use nvidia rivaskills NMT, you need to `pip install pipecat-ai[riva]`.")
    raise Exception(f"Missing module: {e}") from e


class RivaNMTService(AIService):
    """Base class for services using Riva NMT.

    Handles translation of text between languages using NVIDIA's Riva Neural Machine
    Translation service. Requires Riva NMT models to be deployed following:
    https://docs.nvidia.com/deeplearning/riva/user-guide/docs/quick-start-guide/nmt.html
    """

    def __init__(
        self,
        source_language: Language,
        target_language: Language,
        model_name: str = "",
        server: str = "localhost:50051",
        **kwargs,
    ):
        """Initialize the Riva NMT service.

        Args:
            source_language: Source language for translation.
            target_language: Target language for translation.
            model_name: Name of the RIVA translation model. Empty string will
                auto-select an available model.
            server: Riva server address.
            **kwargs: Additional arguments for AIService parent class.

        Raises:
            Exception: If source_language or target_language is not provided.
        """
        if not source_language:
            raise Exception("No source language provided for the translation..")
        if not target_language:
            raise Exception("No target language provided for the translation..")
        super().__init__(**kwargs)
        self.set_model_name(model_name)
        self.source_language = source_language
        self.target_language = target_language
        self.llm_full_response_started = False
        self.llm_full_response = ""
        self.auth = riva.client.Auth(uri=server)
        self.riva_nmt_client = riva.client.NeuralMachineTranslationClient(self.auth)

    async def translate_text(self, text: str = "") -> tuple[str | None, str | None]:
        """Translates text using Riva NMT service.

        Args:
            text: The text to translate. Must not be empty.

        Returns:
            A tuple containing:
                - str | None: Translated text if successful, None if failed
                - str | None: Error message if failed, None if successful

        Raises:
            Exception: If no input text is provided for translation.
        """
        try:
            if not text:
                raise Exception("No input text provided for the translation..")

            logger.debug(f"Received text: {text}")
            logger.debug(f"Translating the text from {self.source_language} to {self.target_language}")
            response = self.riva_nmt_client.translate(
                [text], self._model_name, self.source_language, self.target_language
            )
            logger.debug(f"Final translated text: {response.translations[0].text}")
            return response.translations[0].text, None
        except Exception as e:
            logger.error(
                f"Error while translating the text from {self.source_language} to {self.target_language}, Error: {e}"
            )
            return (
                None,
                f"Error while translating the text from {self.source_language} to {self.target_language}, Error: {e}",
            )

    async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
        """Processes incoming frames for translation.

        Handles different frame types:
            - TranscriptionFrame: Translates text and pushes LLMMessagesFrame
            - LLMFullResponseStartFrame: Marks start of LLM response
            - LLMFullResponseEndFrame: Translates accumulated response and pushes TextFrame
            - TextFrame: Accumulates text during LLM response

        Args:
            frame: Frame to process.
            direction: Direction of frame processing.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, TranscriptionFrame):
            await self.start_processing_metrics()
            translated_text, err = await self.translate_text(frame.text)
            await self.stop_processing_metrics()
            if err is not None:
                await self.push_error(ErrorFrame(err))
            else:
                messages = [{"role": "system", "content": translated_text}]
                await self.push_frame(LLMMessagesFrame(messages))
        elif isinstance(frame, LLMFullResponseStartFrame):
            self.llm_full_response_started = True
        elif isinstance(frame, LLMFullResponseEndFrame):
            self.llm_full_response_started = False
            # Removing period, question mark, exclamation point, colon, or semicolon
            # as these match end of sentence regex in
            # _process_text_frame() method of TTSService of pipecat/services/ai_services.py
            # and TTS response gets truncated.
            await self.start_processing_metrics()
            self.llm_full_response = re.sub("[.?!:;]", "", self.llm_full_response)
            translated_text, err = await self.translate_text(self.llm_full_response)
            await self.stop_processing_metrics()
            if err is not None:
                await self.push_error(ErrorFrame(err))
            else:
                await self.push_frame(TextFrame(translated_text + "."))
            self.llm_full_response = ""
        elif self.llm_full_response_started and isinstance(frame, TextFrame):
            self.llm_full_response += frame.text
        else:
            await self.push_frame(frame, direction)