Spaces:
Running
Running
| # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: BSD 2-Clause License | |
| """Extension to Elevenlabs services for improved ACE compatability.""" | |
| import base64 | |
| import json | |
| from typing import Any | |
| from loguru import logger | |
| from pipecat.frames.frames import ( | |
| Frame, | |
| StartInterruptionFrame, | |
| TTSAudioRawFrame, | |
| TTSStoppedFrame, | |
| ) | |
| from pipecat.processors.frame_processor import FrameDirection | |
| from pipecat.services.elevenlabs.tts import ElevenLabsTTSService, calculate_word_times | |
| from nvidia_pipecat.utils.tracing import AttachmentStrategy, traceable, traced | |
| class ElevenLabsTTSServiceWithEndOfSpeech(ElevenLabsTTSService): | |
| """ElevenLabs TTS service with end-of-speech detection. | |
| This class extends the base ElevenLabs TTS service to add functionality for detecting | |
| and handling the end of speech segments. This is useful for interactive avatar experiences | |
| where TTSStoppedFrames are required to signal the end of a speech segment to control lip movement | |
| of the avatar. | |
| Input frames: | |
| TextFrame: Text to synthesize into speech. | |
| TTSSpeakFrame: Alternative text input for speech synthesis. | |
| LLMFullResponseEndFrame: Signals LLM response completion. | |
| BotStoppedSpeakingFrame: Signals bot speech completion. | |
| Output frames: | |
| TTSStartedFrame: Signals TTS start. | |
| TTSTextFrame: Contains text being synthesized. | |
| TTSAudioRawFrame: Contains raw audio data. | |
| TTSStoppedFrame: Signals TTS completion. | |
| """ | |
| def __init__(self, *args, **kwargs): | |
| """Initialize the ElevenLabsTTSServiceWithEndOfSpeech. | |
| Shares all the parameters with the parent class ElevenLabsTTSService. | |
| Args: | |
| *args: Variable length argument list passed to parent ElevenLabsTTSService. | |
| **kwargs: Arbitrary keyword arguments passed to parent ElevenLabsTTSService. | |
| """ | |
| super().__init__(*args, **kwargs) | |
| self._partial_word: dict | None = None | |
| self._context_id_to_close: str | None = None | |
| async def process_frame(self, frame: Frame, direction: FrameDirection) -> None: | |
| """Processes frames. | |
| Args: | |
| frame (Frame): Incoming frame to process. | |
| direction (FrameDirection): Frame flow direction. | |
| """ | |
| await super().process_frame(frame, direction) | |
| if isinstance(frame, StartInterruptionFrame): | |
| self._partial_word = None | |
| async def flush_audio(self): | |
| """Flushes remaining audio in websocket connection. | |
| Sends special marker messages to flush audio buffer and signal end of speech. | |
| """ | |
| if self._websocket and self._context_id: | |
| self._context_id_to_close = self._context_id | |
| msg = {"context_id": self._context_id, "flush": True} | |
| await self._websocket.send(json.dumps(msg)) | |
| msg = {"context_id": self._context_id, "close_context": True} | |
| await self._websocket.send(json.dumps(msg)) | |
| self._context_id = None | |
| async def run_tts(self, text: str): | |
| """Run text-to-speech synthesis. | |
| Compared to the based class method this method is instrumented for tracing. | |
| """ | |
| async for frame in super().run_tts(text): | |
| yield frame | |
| async def _receive_messages(self): | |
| async for message in self._get_websocket(): | |
| msg = json.loads(message) | |
| # Check if this message belongs to the current context | |
| # The default context may return null/None for context_id | |
| received_ctx_id = msg.get("contextId") | |
| if self._context_id is not None and received_ctx_id is not None and received_ctx_id != self._context_id: | |
| logger.trace(f"Ignoring message from different context: {received_ctx_id}") | |
| continue | |
| if msg.get("audio"): | |
| await self.stop_ttfb_metrics() | |
| self.start_word_timestamps() | |
| audio = base64.b64decode(msg["audio"]) | |
| frame = TTSAudioRawFrame(audio, self.sample_rate, 1) | |
| await self.push_frame(frame) | |
| if msg.get("alignment"): | |
| msg["alignment"] = self._shift_partial_words(msg["alignment"]) | |
| word_times = calculate_word_times(msg["alignment"], self._cumulative_time) | |
| await self.add_word_timestamps(word_times) | |
| self._cumulative_time = word_times[-1][1] | |
| if msg.get("isFinal"): | |
| logger.trace(f"Received final message for context {received_ctx_id}") | |
| # Context has finished | |
| if self._context_id == received_ctx_id or self._context_id_to_close == received_ctx_id: | |
| self._context_id = None | |
| self._context_id_to_close = None | |
| self._started = False | |
| await self.push_frame(TTSStoppedFrame()) | |
| def _shift_partial_words(self, alignment_info: dict[str, Any]) -> dict[str, Any]: | |
| """Shifts partial words from the previous alignment and retains incomplete words.""" | |
| keys = ["chars", "charStartTimesMs", "charDurationsMs"] | |
| # Add partial word from the previous part | |
| if self._partial_word: | |
| for key in keys: | |
| alignment_info[key] = self._partial_word[key] + alignment_info[key] | |
| self._partial_word = None | |
| # Check if the last word is incomplete | |
| if not alignment_info["chars"][-1].isspace(): | |
| # Find the last space character | |
| last_space_index = -1 | |
| for i in range(len(alignment_info["chars"]) - 1, -1, -1): | |
| if alignment_info["chars"][i].isspace(): | |
| last_space_index = i + 1 | |
| break | |
| if last_space_index > -1: | |
| # Split into completed and partial parts | |
| self._partial_word = {key: alignment_info[key][last_space_index:] for key in keys} | |
| for key in keys: | |
| alignment_info[key] = alignment_info[key][:last_space_index] | |
| return alignment_info | |