# Copyright (c) 2025 NVIDIA Corporation. All rights reserved.
# NVIDIA Corporation and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA Corporation is strictly prohibited.
"""Speech planner service for managing real-time speech interactions and VAD-based interruptions.
This module provides the SpeechPlanner class which handles:
- Voice Activity Detection (VAD) processing
- Speech interaction management
- Interruption handling based on VAD signals
- Coordination of speech prediction and transcription frames
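
Example:
    A minimal construction sketch (hypothetical file name; the surrounding
    pipecat pipeline wiring is elided). The planner is typically placed
    between the STT service and the main LLM so that it can gate
    transcription frames before they reach the LLM:

        import os

        planner = SpeechPlanner(
            prompt_file="speech_planner.yaml",
            api_key=os.environ["NVIDIA_API_KEY"],
        )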
"""
from collections.abc import AsyncIterator
from datetime import datetime
from typing import Any
import yaml
from langchain_core.messages.base import BaseMessageChunk
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from loguru import logger
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
Frame,
InterimTranscriptionFrame,
LLMMessagesFrame,
LLMUpdateSettingsFrame,
StartInterruptionFrame,
StopInterruptionFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext, OpenAILLMContextFrame
from pipecat.processors.frame_processor import FrameDirection
from pydantic import BaseModel, Field
from nvidia_pipecat.services.nvidia_llm import NvidiaLLMService
class SpeechPlanner(NvidiaLLMService):
"""Speech planner that manages speech interactions and interruptions based on VAD and predictions."""
class InputParams(BaseModel):
"""Parameters for controlling NVIDIA LLM behavior."""
frequency_penalty: float | None = Field(default=None, ge=-2.0, le=2.0)
presence_penalty: float | None = Field(default=None, ge=-2.0, le=2.0)
seed: int | None = Field(default=None, ge=0)
temperature: float | None = Field(default=None, ge=0.0, le=2.0)
top_k: int | None = Field(default=None, ge=0)
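        # NOTE: top_k is validated here but is currently not forwarded to the
        # model (it is absent from _settings and get_chat_completions).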
top_p: float | None = Field(default=None, ge=0.0, le=1.0)
max_tokens: int | None = Field(default=None, ge=1)
max_completion_tokens: int | None = Field(default=None, ge=1)
extra: dict[str, Any] | None = Field(default=None)
def __init__(
self,
*,
prompt_file: str,
model: str = "nvdev/google/gemma-2b-it",
        api_key: str | None = None,
base_url: str | None = None,
        context: OpenAILLMContext | None = None,
params: InputParams | None = None,
context_window: int = 1, # Number of previous conversation turns to consider for the current conversation.
**kwargs,
):
"""Initialize the speech planner.
Args:
prompt_file: Path to YAML file containing prompts
model: Name of the NVIDIA LLM model. Defaults to "nvdev/google/gemma-2b-it"
api_key: API key for authentication
base_url: Base URL for the API
params: Input parameters for the service
context: Context manager for conversation history
context_window: Number of previous conversation turns to consider. Defaults to 1
**kwargs: Additional keyword arguments passed to parent class.
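
        Example:
            Illustrative prompt-file layout (hypothetical; the class reads
            prompts["prompts"]["completion_prompt"] and
            prompts["configurations"]["using_chat_history"]):

                prompts:
                  completion_prompt: |
                    Classify the transcript: {transcript}
                configurations:
                  using_chat_history: false

            When using_chat_history is true, a {chat_history} placeholder in
            completion_prompt is filled with recent turns from get_chat_history().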
"""
super().__init__(**kwargs)
if params is None:
params = SpeechPlanner.InputParams()
self._settings = {
"frequency_penalty": params.frequency_penalty,
"presence_penalty": params.presence_penalty,
"seed": params.seed,
"temperature": params.temperature,
"top_p": params.top_p,
"max_tokens": params.max_tokens,
"max_completion_tokens": params.max_completion_tokens,
"extra": params.extra if isinstance(params.extra, dict) else {},
}
self.set_model_name(model)
self._client = self.create_client(api_key=api_key, base_url=base_url, **kwargs)
self.context = context
self.context_window = context_window
with open(prompt_file) as file:
self.prompts = yaml.safe_load(file)
self.last_processed_frame = None
self.last_frame = None
self.last_complete_interim_frame = None
self.user_speaking = None
self.latest_bot_started_speaking_frame_timestamp = None
self.current_prediction = None
self._current_task = None
def create_client(self, api_key=None, base_url=None, **kwargs):
"""Create a client for the NVIDIA LLM service."""
return ChatNVIDIA(
base_url=base_url,
model=self.model_name,
api_key=api_key,
)
async def get_chat_completions(self, messages) -> AsyncIterator[BaseMessageChunk]:
"""Get chat completions from the LLM model.
Args:
messages: The input messages to process
Returns:
AsyncIterator[BaseMessageChunk]: Stream of response chunks from the model
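
        Example:
            Consuming the stream (a sketch; ChatNVIDIA accepts a plain string
            as input):

                chunks = await planner.get_chat_completions("Is that all?")
                async for chunk in chunks:
                    print(chunk.content, end="")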
"""
params = {
"model": self.model_name,
"stream": True,
"stream_options": {"include_usage": True},
"messages": messages,
"frequency_penalty": self._settings["frequency_penalty"],
"presence_penalty": self._settings["presence_penalty"],
"seed": self._settings["seed"],
"temperature": self._settings["temperature"],
"top_p": self._settings["top_p"],
"max_tokens": self._settings["max_tokens"],
"max_completion_tokens": self._settings["max_completion_tokens"],
}
params.update(self._settings["extra"])
chunks = self._client.astream(input=messages, config=params)
return chunks
async def _stream_chat_completions(self, prompt: str) -> AsyncIterator[BaseMessageChunk]:
"""Stream chat completions for a given prompt.
Args:
prompt (str): The prompt to send to the model
Returns:
AsyncIterator[BaseMessageChunk]: Stream of response chunks
"""
logger.debug(f"Generating chat: {prompt}")
chunks = await self.get_chat_completions(prompt)
return chunks
def get_chat_history(self) -> list:
"""Retrieves a subset of the conversation history for context in speech planning.
This method calculates how many recent conversation turns to include based on
the configured context_window size. It ensures we start with a user message
for proper conversation flow.
The method will:
1. Return empty list if no messages exist
2. Calculate the starting point based on context_window size
3. Include messages from either:
- The last N user-assistant pairs (where N = context_window)
- Or slightly more if needed to start with a user message
Returns:
list: A slice of conversation history messages, starting with a user message
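
        Example:
            With context_window=1 and six alternating messages
            [user, assistant, user, assistant, user, assistant], one complete
            turn (two messages) is kept, so the method returns the final
            [user, assistant] pair.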
"""
chat_history = []
messages = self.context.get_messages()
if len(messages) == 0:
return chat_history
# Calculate how many conversation turns to include
# Each turn consists of 2 messages (user + assistant)
conversation_turns_to_include = min(
self.context_window, # Max turns specified in config
(len(messages) - 2) / 2, # Available complete turns
)
start_position = max(0, int(conversation_turns_to_include * 2))
# Try to start from a user message
if -len(messages) <= -start_position < len(messages) and messages[-start_position]["role"] == "user":
chat_history = messages[-start_position:]
# If the above doesn't work, try starting one message later
elif (
-len(messages) <= (-start_position + 1) < len(messages) and messages[-start_position + 1]["role"] == "user"
):
chat_history = messages[-start_position + 1 :]
return chat_history
async def _cancel_current_task(self):
"""Cancel the current prediction task if it exists and is running."""
if self._current_task is not None:
if not (self._current_task.done() or self._current_task.cancelled()):
logger.debug("Speech Planner: Cancelling previous task")
await self.cancel_task(self._current_task)
self._current_task = None
    async def _process_complete_context(self, frame: TranscriptionFrame | InterimTranscriptionFrame):
        """Process a transcription frame to determine if it represents a complete utterance.

        This method uses the LLM to analyze the transcription and determine whether it is
        a complete thought or sentence. If complete, it triggers the appropriate
        interruption frames and forwards the transcription.

        Args:
            frame (TranscriptionFrame | InterimTranscriptionFrame): The transcription frame to process.
        """
try:
base_prompt = self.prompts["prompts"]["completion_prompt"]
chat_history = self.get_chat_history()
transcript = frame.text
prompt = ""
if self.prompts["configurations"]["using_chat_history"]:
prompt = base_prompt.format(transcript=transcript, chat_history=chat_history)
else:
prompt = base_prompt.format(transcript=transcript)
chunk_stream = await self._stream_chat_completions(prompt)
pred = ""
async for chunk in chunk_stream:
if not chunk.content:
continue
try:
pred += chunk.content
except Exception as e:
logger.debug(
f"Failed to append chunk content: {e}, chunk: {chunk}, setting prediction to '<silent>'"
)
pred = "<silent>"
pred = pred.strip()
            logger.debug(
                f"Speech Planner: Smart EOU Detection\n"
                f"\tTranscript: {transcript}\tPrompt: {prompt}\tPrediction: {pred}"
            )
def preprocess_pred(x):
"""Maps LLM speech classification labels to EOU detection states.
Args:
x (str): LLM prediction containing Label1-4 classifications
Returns:
str: "Complete" (Label1/3/4) or "Incomplete" (Label2/unrecognized)
Note: Handles "Label1" and "Label 1" formats. Defaults to "Incomplete".
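
                Example:
                    >>> preprocess_pred("Label 1: the utterance is finished")
                    'Complete'
                    >>> preprocess_pred("Label2")
                    'Incomplete'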
"""
if (
"Label1" in x
or "Label 1" in x
or "Label3" in x
or "Label 3" in x
or "Label4" in x
or "Label 4" in x
):
return "Complete"
else:
return "Incomplete"
pred = preprocess_pred(pred)
except Exception as e:
logger.warning(f"Disabling Smart EOU detection due to error: {e}", exc_info=True)
pred = "Complete"
self.current_prediction = pred
if pred == "Complete":
            # Send the transcript frame downstream if the utterance is complete.
if isinstance(frame, InterimTranscriptionFrame):
self.last_complete_interim_frame = frame
if len(frame.text) > 0:
logger.debug(f"Speech Planner: Pushing Complete Transcript to LLM at {datetime.now()}")
await self.push_frame(StartInterruptionFrame())
await self.push_frame(StopInterruptionFrame())
await self.push_frame(TranscriptionFrame(frame.text, frame.user_id, frame.timestamp, frame.language))
return
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process incoming frames and manage speech interactions.
Args:
frame: The incoming frame to process.
direction: The direction of the frame flow.
"""
if isinstance(frame, TranscriptionFrame):
self.last_frame = frame
            if self.current_prediction == "Complete":  # needs to be reset every time a new transcript arrives
logger.debug("Speech Planner: Holding final frame")
self.last_frame = None
else:
await self._cancel_current_task()
# await self.push_frame(StartInterruptionFrame(), FrameDirection.DOWNSTREAM) # Triggering Interruption
logger.debug(
f"Speech Planner: Pushing final frame when prediction is {self.current_prediction} "
f"and user_speaking is {self.user_speaking}"
)
if not self.user_speaking: # Utilising acoustic VAD signal
logger.debug("Speech Planner: Sent final after VAD signal go ahead")
await self.push_frame(StartInterruptionFrame(), FrameDirection.DOWNSTREAM)
await self.push_frame(frame, FrameDirection.DOWNSTREAM)
self.last_frame = None
self.last_processed_frame = None
self.last_complete_interim_frame = None
self.current_prediction = None
elif isinstance(frame, InterimTranscriptionFrame):
self.last_frame = frame
logger.debug(f"Speech Planner: Last Complete Interim Frame {self.last_complete_interim_frame}")
if self.last_processed_frame is None or (self.last_processed_frame.text.strip() != frame.text.strip()):
self.current_prediction = None # predictions need to be reset every time new partial comes
await self._cancel_current_task()
if self.current_prediction == "Complete":
await self.push_frame(
StartInterruptionFrame(), FrameDirection.DOWNSTREAM
) # Triggering Interruption
if not self.user_speaking: # Utilising acoustic VAD signal
logger.debug("Speech Planner: Sent interim after VAD signal go ahead")
self._current_task = self.create_task(self._process_complete_context(frame))
self.last_processed_frame = frame
self.last_frame = None
elif isinstance(frame, BotStartedSpeakingFrame):
self.latest_bot_started_speaking_frame_timestamp = datetime.now()
await self.push_frame(frame, direction)
elif isinstance(frame, BotStoppedSpeakingFrame):
self.latest_bot_started_speaking_frame_timestamp = None
await self.push_frame(frame, direction)
elif isinstance(frame, UserStartedSpeakingFrame):
logger.debug("Speech Planner: Setting user speaking to True")
self.user_speaking = True
self.last_frame = None
await self.push_frame(frame, direction)
elif isinstance(frame, UserStoppedSpeakingFrame):
logger.debug("Speech Planner: Setting user speaking to False")
self.user_speaking = False
if self.last_frame is not None:
if isinstance(self.last_frame, TranscriptionFrame):
logger.debug("Speech Planner: Sent final after VAD signal go ahead")
await self.push_frame(StartInterruptionFrame(), FrameDirection.DOWNSTREAM)
await self.push_frame(self.last_frame, FrameDirection.DOWNSTREAM)
elif isinstance(self.last_frame, InterimTranscriptionFrame):
logger.debug("Speech Planner: Sent interim after VAD signal go ahead")
self._current_task = self.create_task(self._process_complete_context(self.last_frame))
self.last_processed_frame = self.last_frame
self.last_frame = None
await self.push_frame(frame, direction)
elif not isinstance(frame, OpenAILLMContextFrame | LLMMessagesFrame | LLMUpdateSettingsFrame):
await super().process_frame(frame, direction)
else:
await self.push_frame(frame, direction)
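

if __name__ == "__main__":
    # Minimal construction sketch (illustrative only, not a supported entry
    # point): writes a throwaway prompt file containing the two keys read by
    # _process_complete_context, then instantiates the planner. Depending on
    # NvidiaLLMService.__init__, additional keyword arguments may be required.
    import os
    import tempfile

    _EXAMPLE_PROMPTS = """\
prompts:
  completion_prompt: |
    Classify the user transcript as Label1 (complete), Label2 (incomplete),
    Label3, or Label4. Transcript: {transcript}
configurations:
  using_chat_history: false
"""
    with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
        f.write(_EXAMPLE_PROMPTS)
        prompt_path = f.name
    try:
        planner = SpeechPlanner(
            prompt_file=prompt_path,
            api_key=os.environ.get("NVIDIA_API_KEY"),
        )
        print(f"SpeechPlanner initialized with model: {planner.model_name}")
    finally:
        os.unlink(prompt_path)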