File size: 6,263 Bytes
53ea588
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD 2-Clause License

"""Extension to Elevenlabs services for improved ACE compatability."""

import base64
import json
from typing import Any

from loguru import logger
from pipecat.frames.frames import (
    Frame,
    StartInterruptionFrame,
    TTSAudioRawFrame,
    TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService, calculate_word_times

from nvidia_pipecat.utils.tracing import AttachmentStrategy, traceable, traced


@traceable
class ElevenLabsTTSServiceWithEndOfSpeech(ElevenLabsTTSService):
    """ElevenLabs TTS service with end-of-speech detection.

    This class extends the base ElevenLabs TTS service to add functionality for detecting
    and handling the end of speech segments. This is useful for interactive avatar experiences
    where TTSStoppedFrames are required to signal the end of a speech segment to control lip movement
    of the avatar.

    Input frames:
        TextFrame: Text to synthesize into speech.
        TTSSpeakFrame: Alternative text input for speech synthesis.
        LLMFullResponseEndFrame: Signals LLM response completion.
        BotStoppedSpeakingFrame: Signals bot speech completion.

    Output frames:
        TTSStartedFrame: Signals TTS start.
        TTSTextFrame: Contains text being synthesized.
        TTSAudioRawFrame: Contains raw audio data.
        TTSStoppedFrame: Signals TTS completion.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the ElevenLabsTTSServiceWithEndOfSpeech.

        Shares all the parameters with the parent class ElevenLabsTTSService.

        Args:
            *args: Variable length argument list passed to parent ElevenLabsTTSService.
            **kwargs: Arbitrary keyword arguments passed to parent ElevenLabsTTSService.
        """
        super().__init__(*args, **kwargs)
        # Tail of the last alignment chunk that did not end on whitespace; it is
        # held back until the rest of the word arrives (or the context finishes).
        self._partial_word: dict | None = None
        # Context that has been asked to close in flush_audio() but whose final
        # message has not been received yet.
        self._context_id_to_close: str | None = None

    async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
        """Processes frames.

        Delegates to the parent implementation and additionally discards any
        held-back partial word when an interruption starts.

        Args:
            frame (Frame): Incoming frame to process.
            direction (FrameDirection): Frame flow direction.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, StartInterruptionFrame):
            # The interrupted utterance will not be completed, so its dangling
            # word fragment must not leak into the next one.
            self._partial_word = None

    async def flush_audio(self):
        """Flushes remaining audio in websocket connection.

        Sends special marker messages to flush audio buffer and signal end of speech.
        Does nothing when there is no websocket or no active context.
        """
        if self._websocket and self._context_id:
            # Remember the context so its final message is still recognised in
            # _receive_messages() after _context_id has been cleared below.
            self._context_id_to_close = self._context_id
            msg = {"context_id": self._context_id, "flush": True}
            await self._websocket.send(json.dumps(msg))
            msg = {"context_id": self._context_id, "close_context": True}
            await self._websocket.send(json.dumps(msg))
            self._context_id = None

    @traced(attachment_strategy=AttachmentStrategy.NONE, name="tts")
    async def run_tts(self, text: str):
        """Run text-to-speech synthesis.

        Compared to the base class method this method is instrumented for tracing.

        Args:
            text: Text to synthesize.

        Yields:
            The frames produced by the parent class ``run_tts``.
        """
        async for frame in super().run_tts(text):
            yield frame

    async def _receive_messages(self):
        """Consume websocket messages and turn them into frames and word timestamps.

        For each JSON message: audio is decoded and pushed as a TTSAudioRawFrame,
        alignment data is converted to word timestamps (holding back an incomplete
        trailing word), and a final message for the current or closing context
        emits a TTSStoppedFrame.
        """
        async for message in self._get_websocket():
            msg = json.loads(message)
            # Check if this message belongs to the current context
            # The default context may return null/None for context_id
            received_ctx_id = msg.get("contextId")
            if self._context_id is not None and received_ctx_id is not None and received_ctx_id != self._context_id:
                logger.trace(f"Ignoring message from different context: {received_ctx_id}")
                continue

            if msg.get("audio"):
                await self.stop_ttfb_metrics()
                self.start_word_timestamps()

                audio = base64.b64decode(msg["audio"])
                frame = TTSAudioRawFrame(audio, self.sample_rate, 1)
                await self.push_frame(frame)
            if msg.get("alignment"):
                msg["alignment"] = self._shift_partial_words(msg["alignment"])
                word_times = calculate_word_times(msg["alignment"], self._cumulative_time)
                # An alignment without characters yields no word timings; skip it
                # instead of failing on word_times[-1].
                if word_times:
                    await self.add_word_timestamps(word_times)
                    self._cumulative_time = word_times[-1][1]
            if msg.get("isFinal"):
                logger.trace(f"Received final message for context {received_ctx_id}")
                # Context has finished
                if self._context_id == received_ctx_id or self._context_id_to_close == received_ctx_id:
                    # Emit the held-back last word now; otherwise it would be
                    # prepended to the first alignment of the next context.
                    await self._flush_partial_word()
                    self._context_id = None
                    self._context_id_to_close = None
                    self._started = False
                    await self.push_frame(TTSStoppedFrame())

    async def _flush_partial_word(self) -> None:
        """Emit word timestamps for the held-back partial word, if any, and clear it."""
        leftover = self._partial_word
        self._partial_word = None
        if not leftover or not leftover["chars"]:
            return
        # NOTE(review): assumes calculate_word_times only needs the keys kept in
        # the partial word ("chars", "charStartTimesMs", "charDurationsMs") - confirm.
        word_times = calculate_word_times(leftover, self._cumulative_time)
        if word_times:
            await self.add_word_timestamps(word_times)
            self._cumulative_time = word_times[-1][1]

    def _shift_partial_words(self, alignment_info: dict[str, Any]) -> dict[str, Any]:
        """Shifts partial words from the previous alignment and retains incomplete words.

        The partial word kept from the previous call (if any) is prepended to
        ``alignment_info``. If the resulting characters do not end on whitespace,
        everything after the last whitespace character is cut off and stored in
        ``self._partial_word`` for the next call.

        Args:
            alignment_info: Alignment dict with the list-valued keys "chars",
                "charStartTimesMs" and "charDurationsMs". It is modified in place.

        Returns:
            The same ``alignment_info`` object, limited to complete words. When the
            characters contain no whitespace at all, nothing is held back and the
            alignment is returned whole.
        """
        keys = ["chars", "charStartTimesMs", "charDurationsMs"]
        # Add partial word from the previous part
        if self._partial_word:
            for key in keys:
                alignment_info[key] = self._partial_word[key] + alignment_info[key]
            self._partial_word = None

        chars = alignment_info["chars"]
        # Nothing to split for an empty alignment, or when the last word is complete.
        if not chars or chars[-1].isspace():
            return alignment_info

        # Index just past the last whitespace character, or -1 if there is none.
        split_at = next(
            (len(chars) - offset for offset, char in enumerate(reversed(chars)) if char.isspace()),
            -1,
        )
        if split_at > -1:
            # Split into completed and partial parts
            self._partial_word = {key: alignment_info[key][split_at:] for key in keys}
            for key in keys:
                alignment_info[key] = alignment_info[key][:split_at]

        return alignment_info