Spaces:
Running
Running
File size: 6,263 Bytes
53ea588 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD 2-Clause License
"""Extension to ElevenLabs services for improved ACE compatibility."""
import base64
import json
from typing import Any
from loguru import logger
from pipecat.frames.frames import (
Frame,
StartInterruptionFrame,
TTSAudioRawFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService, calculate_word_times
from nvidia_pipecat.utils.tracing import AttachmentStrategy, traceable, traced
@traceable
class ElevenLabsTTSServiceWithEndOfSpeech(ElevenLabsTTSService):
    """ElevenLabs TTS service with end-of-speech detection.

    This class extends the base ElevenLabs TTS service to add functionality for detecting
    and handling the end of speech segments. This is useful for interactive avatar experiences
    where TTSStoppedFrames are required to signal the end of a speech segment to control lip movement
    of the avatar.

    Input frames:
        TextFrame: Text to synthesize into speech.
        TTSSpeakFrame: Alternative text input for speech synthesis.
        LLMFullResponseEndFrame: Signals LLM response completion.
        BotStoppedSpeakingFrame: Signals bot speech completion.

    Output frames:
        TTSStartedFrame: Signals TTS start.
        TTSTextFrame: Contains text being synthesized.
        TTSAudioRawFrame: Contains raw audio data.
        TTSStoppedFrame: Signals TTS completion.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the ElevenLabsTTSServiceWithEndOfSpeech.

        Shares all the parameters with the parent class ElevenLabsTTSService.

        Args:
            *args: Variable length argument list passed to parent ElevenLabsTTSService.
            **kwargs: Arbitrary keyword arguments passed to parent ElevenLabsTTSService.
        """
        super().__init__(*args, **kwargs)
        # Alignment entries of a word that was cut off at the end of the previous
        # alignment chunk; prepended to the next chunk by _shift_partial_words().
        self._partial_word: dict | None = None
        # Context for which flush_audio() requested a close and whose final
        # message has not been received yet.
        self._context_id_to_close: str | None = None

    async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
        """Processes frames.

        Delegates to the parent implementation and additionally drops any
        held-back partial word when an interruption starts.

        Args:
            frame (Frame): Incoming frame to process.
            direction (FrameDirection): Frame flow direction.
        """
        await super().process_frame(frame, direction)
        if isinstance(frame, StartInterruptionFrame):
            self._partial_word = None

    async def flush_audio(self):
        """Flushes remaining audio in websocket connection.

        Sends a ``flush`` message followed by a ``close_context`` message for the
        active context and remembers that context so that its final message can
        still be recognized in ``_receive_messages``. Does nothing when there is
        no websocket or no active context.
        """
        if self._websocket and self._context_id:
            self._context_id_to_close = self._context_id
            msg = {"context_id": self._context_id, "flush": True}
            await self._websocket.send(json.dumps(msg))
            msg = {"context_id": self._context_id, "close_context": True}
            await self._websocket.send(json.dumps(msg))
            self._context_id = None

    @traced(attachment_strategy=AttachmentStrategy.NONE, name="tts")
    async def run_tts(self, text: str):
        """Run text-to-speech synthesis.

        Compared to the base class method this method is instrumented for tracing.

        Args:
            text (str): Text to synthesize.

        Yields:
            The frames produced by the parent ``run_tts`` implementation, unchanged.
        """
        async for frame in super().run_tts(text):
            yield frame

    async def _receive_messages(self):
        """Handles messages arriving on the websocket.

        For every message of the active context this pushes the decoded audio as a
        TTSAudioRawFrame, registers word timestamps derived from the alignment data
        and, once the context reports ``isFinal``, pushes a TTSStoppedFrame.
        """
        async for message in self._get_websocket():
            msg = json.loads(message)

            # Check if this message belongs to the current context
            # The default context may return null/None for context_id
            received_ctx_id = msg.get("contextId")
            if (
                self._context_id is not None
                and received_ctx_id is not None
                and received_ctx_id != self._context_id
            ):
                logger.trace(f"Ignoring message from different context: {received_ctx_id}")
                continue

            if msg.get("audio"):
                await self.stop_ttfb_metrics()
                self.start_word_timestamps()
                audio = base64.b64decode(msg["audio"])
                frame = TTSAudioRawFrame(audio, self.sample_rate, 1)
                await self.push_frame(frame)

            if msg.get("alignment"):
                msg["alignment"] = self._shift_partial_words(msg["alignment"])
                word_times = calculate_word_times(msg["alignment"], self._cumulative_time)
                await self.add_word_timestamps(word_times)
                # An alignment without any usable characters yields no word
                # times; keep the previous cumulative time instead of failing
                # on the [-1] lookup and aborting the receive loop.
                if word_times:
                    self._cumulative_time = word_times[-1][1]

            if msg.get("isFinal"):
                logger.trace(f"Received final message for context {received_ctx_id}")
                # Context has finished: either the active one or the one that
                # flush_audio() asked to close.
                if received_ctx_id in (self._context_id, self._context_id_to_close):
                    self._context_id = None
                    self._context_id_to_close = None
                    self._started = False
                    await self.push_frame(TTSStoppedFrame())

    def _shift_partial_words(self, alignment_info: dict[str, Any]) -> dict[str, Any]:
        """Shifts partial words from the previous alignment and retains incomplete words.

        The entries held back from the previous call are prepended to the given
        alignment. If the resulting alignment does not end with whitespace, everything
        after its last whitespace character is held back for the next call. A chunk
        without any whitespace is passed through unchanged.

        Args:
            alignment_info (dict[str, Any]): Alignment data with the list-valued keys
                ``chars``, ``charStartTimesMs`` and ``charDurationsMs``. The dict is
                modified in place.

        Returns:
            dict[str, Any]: The same ``alignment_info`` object, trimmed to end after
            its last whitespace character where applicable.
        """
        keys = ("chars", "charStartTimesMs", "charDurationsMs")

        # Add partial word from the previous part
        if self._partial_word:
            for key in keys:
                alignment_info[key] = self._partial_word[key] + alignment_info[key]
            self._partial_word = None

        chars = alignment_info["chars"]
        # Nothing to split when there are no characters at all (guards the [-1]
        # lookup) or when the chunk already ends on a word boundary.
        if not chars or chars[-1].isspace():
            return alignment_info

        # Index just past the last whitespace character, or -1 if there is none
        split_at = next(
            (i + 1 for i in range(len(chars) - 1, -1, -1) if chars[i].isspace()),
            -1,
        )
        if split_at > -1:
            # Split into completed and partial parts
            self._partial_word = {key: alignment_info[key][split_at:] for key in keys}
            for key in keys:
                alignment_info[key] = alignment_info[key][:split_at]

        return alignment_info
|