# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-2-Clause
"""Animation graph service for managing avatar animations and interactions with the animation graph service.
You can find more information about the animation graph service here:
https://docs.nvidia.com/ace/animation-graph-microservice/latest/index.html
"""
import asyncio
import json
import os
import time
from collections.abc import Callable
from dataclasses import dataclass
from datetime import timedelta
from itertools import chain
from pathlib import Path
from typing import Any
import grpc
import torch
from grpc.aio import StreamUnaryCall
from loguru import logger
from nvidia_ace.audio_pb2 import AudioHeader
from nvidia_ace.status_pb2 import Status
from nvidia_animation_graph.animgraph_pb2_grpc import AnimationDataServiceStub
from nvidia_animation_graph.messages_pb2 import AnimationDataStream, AnimationDataStreamHeader, AnimationIds
from pipecat.frames.frames import (
BotSpeakingFrame,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
EndFrame,
ErrorFrame,
Frame,
StartFrame,
StartInterruptionFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util # type: ignore
from nvidia_pipecat.frames.action import (
ActionFrame,
FinishedFacialGestureBotActionFrame,
FinishedGestureBotActionFrame,
FinishedMotionEffectCameraActionFrame,
FinishedPositionBotActionFrame,
FinishedPostureBotActionFrame,
FinishedShotCameraActionFrame,
StartedFacialGestureBotActionFrame,
StartedGestureBotActionFrame,
StartedMotionEffectCameraActionFrame,
StartedPositionBotActionFrame,
StartedPostureBotActionFrame,
StartedShotCameraActionFrame,
StartFacialGestureBotActionFrame,
StartGestureBotActionFrame,
StartMotionEffectCameraActionFrame,
StartPositionBotActionFrame,
StartPostureBotActionFrame,
StartShotCameraActionFrame,
StopFacialGestureBotActionFrame,
StopGestureBotActionFrame,
StopMotionEffectCameraActionFrame,
StopPositionBotActionFrame,
StopPostureBotActionFrame,
StopShotCameraActionFrame,
)
from nvidia_pipecat.frames.animation import (
AnimationDataStreamRawFrame,
AnimationDataStreamStartedFrame,
AnimationDataStreamStoppedFrame,
)
from nvidia_pipecat.services.action_handler import (
ActionHandler,
InternalStateMachineAbortFrame,
InternalStateMachineTriggerFrame,
)
from nvidia_pipecat.services.base_action_service import (
BaseActionService,
ModalityManager,
OverrideModalityManager,
ReplaceModalityManager,
)
from nvidia_pipecat.utils.http_client import CallMethod, HttpClient
from nvidia_pipecat.utils.message_broker import MessageBrokerConfig, message_broker_factory
# Setting the number of threads is required due to an issue when running torch models in a
# multiprocessing context. Since this module is imported before any child process is started,
# setting it here ensures that the sentence transformer models work correctly.
# More information here: https://github.com/pytorch/pytorch/issues/36191
torch.set_num_threads(1)
# Sentence Transformers
cache_path = Path(os.getenv("ANIMATION_GRAPH_SERVICE_CACHE", "./models/"))
model_name = "all-MiniLM-L6-v2"
# We need to use the full local path here, otherwise SentenceTransformers will still query upstream models
model_cache_path = cache_path.resolve() / model_name
model_loaded_from_cache = False
if model_cache_path.is_dir():
try:
model = SentenceTransformer(str(model_cache_path), device="cpu")
model_loaded_from_cache = True
except Exception:
# If loading from cache fails for any reason we try to redownload the model
pass
if not model_loaded_from_cache:
model = SentenceTransformer(model_name, device="cpu")
try:
model.save(str(model_cache_path))
except Exception:
logger.warning("Could not cache sentence transformer model. This can impact startup times.")
def _compute_embedding(text: str) -> Any:
return model.encode(text, convert_to_tensor=True, show_progress_bar=False)
def _similarity(doc1: Any, doc2: Any) -> Any:
return util.cos_sim(doc1, doc2)
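

# Example (texts are illustrative): _similarity(_compute_embedding("wave hello"),
# _compute_embedding("greeting gesture")) returns a 1x1 tensor whose single value
# is the cosine similarity of the two sentence embeddings, in [-1, 1].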
async def _delay(coroutine, seconds) -> None:
await asyncio.sleep(seconds)
await coroutine
class ClipParameters(BaseModel):
"""Parameters for a clip."""
clip_id: str
description: str
meaning: str
duration: float
class AnimationConfiguration(BaseModel):
"""Configuration for an animation."""
default_clip_id: str
clips: list[ClipParameters]
class AnimationType(BaseModel):
"""Type of animation."""
duration_relevant_animation_name: str
animations: dict[str, AnimationConfiguration]
class AnimationGraphConfiguration(BaseModel):
"""Configuration for the animation graph service."""
animation_types: dict[str, AnimationType]
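

# The configuration is typically parsed from JSON/YAML. A minimal illustrative
# example (clip values are placeholders, not shipped defaults):
#
#     {
#         "animation_types": {
#             "gesture": {
#                 "duration_relevant_animation_name": "gesture",
#                 "animations": {
#                     "gesture": {
#                         "default_clip_id": "none",
#                         "clips": [
#                             {"clip_id": "wave", "description": "A friendly wave",
#                              "meaning": "greeting someone", "duration": 2.5}
#                         ]
#                     }
#                 }
#             }
#         }
#     }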
@dataclass
class Animation:
"""Represents a single animation clip with metadata.
Attributes:
id (str): Unique identifier for the animation.
description (str): Natural language description of the animation.
meaning (str): Semantic meaning or purpose of the animation.
duration (float): Length of the animation in seconds.
description_embedding (Any): Computed embedding of the description.
meaning_embedding (Any): Computed embedding of the meaning.
"""
id: str
description: str
meaning: str
duration: float
description_embedding: Any
meaning_embedding: Any
def __str__(self) -> str:
"""Returns string representation of the animation.
Returns:
str: Animation ID in string format.
"""
return f"Animation('{self.id}')"
@dataclass
class AnimationMatch:
"""Represents a match between a query and an animation.
Attributes:
animation (Animation): The matched animation.
description_score (float): Similarity score for description match.
meaning_score (float): Similarity score for meaning match.
"""
animation: Animation
description_score: float
meaning_score: float
class AnimationDatabase:
"""Database for managing and querying animation clips.
This class provides functionality to store, search, and retrieve animation clips
based on natural language descriptions (NLD) or animation IDs. It uses semantic
similarity to match queries against animation descriptions and meanings.
Attributes:
animations: List of all available animations.
id_to_animation: Dictionary mapping animation IDs to Animation objects.
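
    Example:
        A minimal usage sketch (clip id and texts are illustrative)::

            clips = [
                ClipParameters(
                    clip_id="wave",
                    description="A friendly wave with the right hand",
                    meaning="greeting someone",
                    duration=2.5,
                )
            ]
            db = AnimationDatabase(clips)
            best = db.query_one("say hello")  # best semantic match
            print(best.animation.id, float(best.description_score))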
"""
def __init__(self, available_animations: list[ClipParameters]) -> None:
"""Initializes the animation database.
Args:
available_animations: List of animation clip parameters to load into the database.
"""
self.id_to_animation: dict[str, Animation] = {}
self.animations = self._load_animations(available_animations)
def _compute_similarities(self, nld: str) -> list[AnimationMatch]:
query_doc = _compute_embedding(nld)
result = []
for anim in self.animations:
scores = []
for doc in [anim.description_embedding, anim.meaning_embedding]:
scores.append(_similarity(query_doc, doc))
result.append(AnimationMatch(anim, scores[0], scores[1]))
return result
def _load_animations(self, available_animations: list[ClipParameters]) -> list[Animation]:
# Pre-compute all embeddings in batch for better performance
descriptions = [anim.description for anim in available_animations]
meanings = [anim.meaning for anim in available_animations]
# Compute embeddings in batch
description_embeddings = model.encode(descriptions, convert_to_tensor=True, show_progress_bar=False)
meaning_embeddings = model.encode(meanings, convert_to_tensor=True, show_progress_bar=False)
result = []
for i, anim in enumerate(available_animations):
a = Animation(
anim.clip_id,
anim.description,
anim.meaning,
anim.duration,
description_embeddings[i],
meaning_embeddings[i],
)
result.append(a)
self.id_to_animation[a.id] = a
return result
def query(self, nld: str, n: int = 3) -> list[AnimationMatch]:
"""Query the database for animations matching a natural language description.
Args:
nld: Natural language description to match against.
n: Number of top matches to return. Defaults to 3.
Returns:
            list[AnimationMatch]: Top n animation matches by similarity score, best match first.
        """
        matches = self._compute_similarities(nld)
        sorted_matches = sorted(matches, key=lambda m: max(m.description_score, m.meaning_score), reverse=True)
        return sorted_matches[:n]
def query_one(self, nld: str) -> AnimationMatch:
"""Query the database for the best matching animation.
Args:
nld: Natural language description to match against.
Returns:
AnimationMatch: Best matching animation by similarity score.
"""
return self.query(nld, n=1)[0]
def query_id(self, id: str) -> Animation | None:
"""Query the database for an animation by its ID.
Args:
id: The ID of the animation to find.
Returns:
            Animation | None: The matching animation if found, None otherwise.
                Falls back to case-insensitive, whitespace-stripped matching if
                no exact ID match exists.
        """
animation = self.id_to_animation.get(id, None)
if not animation:
for anim in self.animations:
if anim.id.casefold().strip() == id.casefold().strip():
return anim
return animation
class AnimationGraphClient(HttpClient):
"""Client for interacting with the Animation Graph service.
This class provides methods to control various aspects of avatar animation through
HTTP requests to the Animation Graph service. It handles state variables for
postures, gestures, facial expressions, camera shots, and other animation controls.
Attributes:
url: Base URL of the Animation Graph service.
stream_uid: Unique identifier for the animation stream.
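
    Example:
        A minimal sketch (URL and stream id are illustrative placeholders)::

            client = AnimationGraphClient("http://localhost:8020", "stream-123")
            # Issues a PUT to
            # /streams/stream-123/animation_graphs/avatar/variables/posture_state/talking
            ok = await client.set_posture_state("talking")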
"""
def __init__(self, url: str, stream_uid: str) -> None:
"""Initialize the Animation Graph client.
Args:
url: Base URL of the Animation Graph service.
stream_uid: Unique identifier for the animation stream.
"""
super().__init__()
self.url = url
self.stream_uid = stream_uid
async def register_stream(self) -> bool:
"""Register a new animation stream with the service.
Returns:
bool: True if registration was successful.
"""
return True
async def stop_request_playback(self, request_id: str) -> bool:
"""Stop playback of a specific animation request.
Args:
request_id: ID of the animation request to stop.
Returns:
bool: True if the playback was successfully stopped.
"""
return await self.delete(
url=f"{self.url}/streams/{self.stream_uid}/requests/{request_id}", headers={"x-stream-id": self.stream_uid}
)
async def set_state_variable(
self, variable_name: str, variable_value: str, graph: str = "avatar", **kwargs
) -> bool:
"""Set a state variable in the animation graph.
Args:
variable_name: Name of the state variable to set.
variable_value: Value to set for the state variable.
graph: Name of the animation graph. Defaults to "avatar".
**kwargs: Additional arguments to pass to the request.
Returns:
bool: True if the state variable was successfully set.
"""
return await self.send_request(
url=f"{self.url}/streams/{self.stream_uid}/animation_graphs/{graph}/variables/{variable_name}/{variable_value}",
params={},
payload={},
headers={"Content-Type": "application/json", "x-stream-id": self.stream_uid},
call_method=CallMethod.PUT,
)
async def set_posture_state(self, posture: str) -> bool:
"""Set the avatar's posture state.
Args:
posture: Name of the posture to set.
Returns:
bool: True if the posture state was successfully set.
"""
return await self.set_state_variable("posture_state", posture)
async def set_gesture_state(self, gesture: str) -> bool:
"""Set the avatar's gesture state.
Args:
gesture: Name of the gesture to set.
Returns:
bool: True if the gesture state was successfully set.
"""
return await self.set_state_variable("gesture_state", gesture)
async def set_facial_gesture_state(self, facial_gesture: str) -> bool:
"""Set the avatar's facial gesture state.
Args:
facial_gesture: Name of the facial gesture to set.
Returns:
bool: True if the facial gesture state was successfully set.
"""
return await self.set_state_variable("facial_gesture_state", facial_gesture)
async def set_position_state(self, position: str) -> bool:
"""Set the avatar's position state.
Args:
position: Name of the position to set.
Returns:
bool: True if the position state was successfully set.
"""
return await self.set_state_variable("position_state", position)
async def set_shot_state(self, shot: str, **kwargs) -> bool:
"""Set the camera shot state.
Args:
shot: Name of the camera shot to set.
**kwargs: Additional arguments to pass to the request.
Returns:
bool: True if the shot state was successfully set.
"""
return await self.set_state_variable("shot_state", shot)
async def set_shot_transition_speed(
self,
start_transition: str | None = None,
stop_transition: str | None = None,
**kwargs,
) -> bool:
"""Set the speed of shot transitions.
Args:
start_transition: Name of the start transition effect.
stop_transition: Name of the stop transition effect.
            **kwargs: Additional arguments to pass to the request.

        Returns:
            bool: True if the transition speed was successfully set, False if
                neither transition was provided.
        """
if not start_transition and not stop_transition:
return False
value = start_transition or stop_transition or ""
return await self.set_state_variable("shot_transition_speed", value)
async def set_camera_motion_effect_state(self, effect: str, **kwargs) -> bool:
"""Set the camera motion effect state.
Args:
effect: Name of the camera motion effect to set.
**kwargs: Additional arguments to pass to the request.
Returns:
bool: True if the camera motion effect state was successfully set.
"""
return await self.set_state_variable("camera_motion_effect_state", effect)
class EmbodiedBotActionHandler(ActionHandler):
"""Base class for all action handlers bot animations.
This includes actions like, gestures, postures, facial expressions.
"""
_frame_type_lookup: dict[str, type[ActionFrame]] = {
"start": ActionFrame,
"started": ActionFrame,
"finished": ActionFrame,
"stop": ActionFrame,
}
action_name = "UnknownEmbodiedAction"
nld_parameter_names: set[str] = set()
def __init__(
self,
parent_processor: FrameProcessor,
get_client: Callable[[], AnimationGraphClient],
animation_databases: dict[str, AnimationDatabase],
animation_type_config: AnimationType,
) -> None:
"""Initialize the embodied bot action handler.
Args:
parent_processor: The parent frame processor.
get_client: Function to get the animation graph client.
animation_databases: Dictionary of animation databases.
animation_type_config: Configuration for this animation type.
"""
super().__init__(parent_processor)
self.get_client = get_client
self.animation_databases = animation_databases
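        # In addition to the base action state machine, finite animations can time
        # out: the "timeout" trigger moves an action from "running" to "stopping"
        # and marks it as successful (see animation_done()).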
self.machine.add_transition("timeout", source="running", dest="stopping", before=["animation_done"]) # type: ignore
self.task: asyncio.Task | None = None
self.resolved_parameters: dict[str, str] = {}
self.selected_animations: dict[str, Animation] = {}
self.previous_action_state: dict[str, Any] | None = None
self.animation_type_config = animation_type_config
self.nld_parameter_names = {self.animation_type_config.duration_relevant_animation_name}
def _select_animation(self, data_base: str, animation_nld: str) -> Animation:
"""If the NLD equals an animation ID return that animation.
Return a similarity search based match otherwise.
"""
assert data_base in self.animation_databases
animation = self.animation_databases[data_base].query_id(animation_nld)
if animation:
return animation
else:
return self.animation_databases[data_base].query_one(animation_nld).animation
def _resolve_nld_parameters(self) -> None:
self.resolved_parameters = {}
for parameter_name, parameter_nld in self.nld_parameters.items():
self.selected_animations[parameter_name] = self._select_animation(parameter_name, parameter_nld)
parameter_resolved = self.selected_animations[parameter_name].id
self.resolved_parameters[parameter_name] = parameter_resolved
logger.info(f"{self.action_name} NLD parameter resolution: {self.resolved_parameters}")
async def start_animation(self) -> bool:
"""Start the animation.
Returns:
bool: True if the animation was successfully started.
"""
raise NotImplementedError()
async def resume_animation(self) -> bool:
"""Resume a previously paused animation.
Returns:
bool: True if the animation was successfully resumed.
"""
return await self.start_animation()
@property
def nld_parameters(self) -> dict[str, str]:
"""Get the natural language description parameters.
Returns:
dict[str, str]: Dictionary of parameter names and values.
"""
return {name: self.action_state.get(name, "") for name in self.nld_parameter_names}
async def on_enter_starting(self, frame: ActionFrame | InternalStateMachineTriggerFrame) -> None:
"""Handle entering the starting state.
Args:
frame: The frame that triggered the state transition.
"""
self._resolve_nld_parameters()
if not await self.start_animation():
logger.warning("Request to Animation Graph Endpoint failed")
await self.parent_processor.queue_for_internal_processing(
InternalStateMachineAbortFrame(
action_name=self.action_name,
action_id=self.action_id,
reason="Request to Animation Graph Endpoint failed",
)
)
else:
await self.parent_processor.queue_for_internal_processing(
self._frame_type_lookup["started"](action_id=self.action_id)
)
async def on_enter_stopping(self, frame: ActionFrame | InternalStateMachineTriggerFrame) -> None:
"""Handle entering the stopping state.
Args:
frame: The frame that triggered the state transition.
"""
if self.task:
await self.parent_processor.cancel_task(self.task)
self.task = None
await self.parent_processor.queue_for_internal_processing(
self._frame_type_lookup["finished"](
action_id=self.action_id,
is_success=self.action_is_success,
was_stopped=self.was_stopped,
failure_reason=self.action_failure_reason,
)
)
async def on_enter_paused(self, frame: ActionFrame | InternalStateMachineTriggerFrame) -> None:
"""Handle entering the paused state.
Args:
frame: The frame that triggered the state transition.
"""
if self.task:
await self.parent_processor.cancel_task(self.task)
self.task = None
async def on_enter_resuming(self, frame: ActionFrame | InternalStateMachineTriggerFrame) -> None:
"""Handle entering the resuming state.
Args:
frame: The frame that triggered the state transition.
"""
if isinstance(frame, InternalStateMachineTriggerFrame) and "previous_action_state" in frame.data:
self.previous_action_state = frame.data["previous_action_state"]
if not await self.resume_animation():
await self.parent_processor.queue_for_internal_processing(
InternalStateMachineAbortFrame(
action_name=self.action_name,
action_id=self.action_id,
reason="Request to Animation Graph endpoint failed",
)
)
else:
logger.info(f"Action {self.action_name}({self.action_id}) resumed.")
await self.parent_processor.queue_for_internal_processing(
InternalStateMachineTriggerFrame("resumed", action_name=self.action_name, action_id=self.action_id)
)
async def animation_done(self, frame: ActionFrame | InternalStateMachineTriggerFrame) -> None:
"""Handle animation completion.
Args:
frame: The frame that triggered the completion.
"""
self.action_is_success = True
async def clear_modality(self) -> None:
"""Clear the current modality state."""
await self.get_client().set_gesture_state("none")
class FiniteAnimationActionHandler(EmbodiedBotActionHandler):
"""Base class for all action handlers dealing with fixed duration animation clips."""
_frame_type_lookup: dict[str, type[ActionFrame]] = {
"start": ActionFrame,
"started": ActionFrame,
"finished": ActionFrame,
"stop": ActionFrame,
}
action_name = "UnknownFiniteAnimationAction"
    # The Animation Graph MS currently does not support sending out end-of-clip events,
    # so we rely on the configured clip duration instead (see
    # animation_type_config.duration_relevant_animation_name in on_enter_running()).
async def on_enter_running(self, frame: ActionFrame | InternalStateMachineTriggerFrame) -> None:
"""Handle entering the running state.
Args:
frame: The frame that triggered the state transition.
"""
assert self.animation_type_config.duration_relevant_animation_name in self.selected_animations
async def animation_finished() -> None:
params = ",".join([f"{name}={value}" for name, value in self.resolved_parameters.items()])
logger.info(f"{self.action_name}({params}) finished. ")
await self.parent_processor.queue_for_internal_processing(
InternalStateMachineTriggerFrame("timeout", action_name=self.action_name, action_id=self.action_id)
)
seconds = timedelta(
seconds=self.selected_animations[self.animation_type_config.duration_relevant_animation_name].duration
).total_seconds()
if seconds > 0:
self.task = self.parent_processor.create_task(_delay(animation_finished(), seconds))
class GestureBotActionHandler(FiniteAnimationActionHandler):
"""Handler for avatar gesture animations.
Manages gesture-based animations, including starting, stopping, and transitioning
between different gesture states.
"""
_frame_type_lookup: dict[str, type[ActionFrame]] = {
"start": StartGestureBotActionFrame,
"started": StartedGestureBotActionFrame,
"finished": FinishedGestureBotActionFrame,
"stop": StopGestureBotActionFrame,
}
action_name = "GestureBotAction"
async def start_animation(self) -> bool:
"""Start a gesture animation.
Returns:
bool: True if the gesture animation was successfully started.
"""
return await self.get_client().set_gesture_state(**self.resolved_parameters)
async def clear_modality(self) -> None:
"""Clear the current gesture state."""
await self.get_client().set_gesture_state("none")
class FacialGestureBotActionHandler(GestureBotActionHandler):
"""Handler for avatar facial gesture animations.
Manages facial expression animations, including starting, stopping, and transitioning
between different facial gesture states.
"""
_frame_type_lookup: dict[str, type[ActionFrame]] = {
"start": StartFacialGestureBotActionFrame,
"started": StartedFacialGestureBotActionFrame,
"finished": FinishedFacialGestureBotActionFrame,
"stop": StopFacialGestureBotActionFrame,
}
action_name = "FacialGestureBotAction"
async def start_animation(self) -> bool:
"""Start a facial gesture animation.
Returns:
bool: True if the facial gesture animation was successfully started.
"""
return await self.get_client().set_facial_gesture_state(**self.resolved_parameters)
async def clear_modality(self) -> None:
"""Clear the current facial gesture state."""
await self.get_client().set_facial_gesture_state("none")
class MotionEffectCameraActionHandler(FiniteAnimationActionHandler):
"""Handler for camera motion effect actions.
Manages camera motion effects, including starting, stopping, and transitioning
between different camera motion states.
"""
_frame_type_lookup: dict[str, type[ActionFrame]] = {
"start": StartMotionEffectCameraActionFrame,
"started": StartedMotionEffectCameraActionFrame,
"finished": FinishedMotionEffectCameraActionFrame,
"stop": StopMotionEffectCameraActionFrame,
}
action_name = "MotionEffectCameraAction"
async def start_animation(self) -> bool:
"""Start a camera motion effect.
Returns:
bool: True if the camera motion effect was successfully started.
"""
return await self.get_client().set_camera_motion_effect_state(**self.resolved_parameters)
async def clear_modality(self) -> None:
"""Clear the current camera motion effect state."""
await self.get_client().set_camera_motion_effect_state("none")
class PostureBotActionHandler(EmbodiedBotActionHandler):
"""Handler for avatar posture animations.
Manages posture-based animations, including starting, stopping, and transitioning
between different posture states.
"""
_frame_type_lookup: dict[str, type[ActionFrame]] = {
"start": StartPostureBotActionFrame,
"started": StartedPostureBotActionFrame,
"finished": FinishedPostureBotActionFrame,
"stop": StopPostureBotActionFrame,
}
action_name = "PostureBotAction"
async def start_animation(self) -> bool:
"""Start a posture animation.
Returns:
bool: True if the posture animation was successfully started.
"""
return await self.get_client().set_posture_state(**self.resolved_parameters)
async def clear_modality(self) -> None:
"""Clear the current posture state."""
await self.get_client().set_posture_state(self.animation_type_config.animations["posture"].default_clip_id)
class PositionBotActionHandler(EmbodiedBotActionHandler):
"""Handler for avatar position animations.
Manages position-based animations, including starting, stopping, and transitioning
between different position states.
"""
_frame_type_lookup: dict[str, type[ActionFrame]] = {
"start": StartPositionBotActionFrame,
"started": StartedPositionBotActionFrame,
"finished": FinishedPositionBotActionFrame,
"stop": StopPositionBotActionFrame,
}
action_name = "PositionBotAction"
async def start_animation(self) -> bool:
"""Start a position animation.
Returns:
bool: True if the position animation was successfully started.
"""
return await self.get_client().set_position_state(**self.resolved_parameters)
async def clear_modality(self) -> None:
"""Clear the current position state."""
await self.get_client().set_position_state(self.animation_type_config.animations["position"].default_clip_id)
class ShotCameraActionHandler(EmbodiedBotActionHandler):
"""Handler for camera shot animations.
Manages camera shot transitions and states, including starting, stopping, and
transitioning between different camera angles and positions.
"""
_frame_type_lookup: dict[str, type[ActionFrame]] = {
"start": StartShotCameraActionFrame,
"started": StartedShotCameraActionFrame,
"finished": FinishedShotCameraActionFrame,
"stop": StopShotCameraActionFrame,
}
action_name = "ShotCameraAction"
async def _update_state_variables(self, **kwargs) -> bool:
if await self.get_client().set_shot_transition_speed(**kwargs):
return await self.get_client().set_shot_state(**kwargs)
return False
async def start_animation(self) -> bool:
"""Start a camera shot animation.
Returns:
bool: True if the camera shot animation was successfully started.
"""
return await self._update_state_variables(**self.resolved_parameters)
async def resume_animation(self) -> bool:
"""Resume a previously paused camera shot animation.
Returns:
bool: True if the camera shot animation was successfully resumed.
"""
if self.previous_action_state:
transition = self._select_animation("stop_transition", self.previous_action_state["stop_transition"]).id
else:
transition = self.resolved_parameters["start_transition"]
parameters = {
"shot": self.resolved_parameters["shot"],
"start_transition": transition,
}
return await self._update_state_variables(**parameters)
async def clear_modality(self) -> None:
"""Clear the current camera shot state."""
await self.get_client().set_shot_state(self.animation_type_config.animations["shots"].default_clip_id)
def get_action_handler_factory(
st: "AnimationGraphService.ActionConfig", service: BaseActionService, name: str
) -> Callable[[Frame], ActionHandler]:
"""Create a factory function for action handlers.
Args:
st: Action configuration from the AnimationGraphService.
service: The base action service instance.
name: Name of the action type.
Returns:
Callable[[Frame], ActionHandler]: Factory function that creates action handlers.
"""
def factory(frame: Frame) -> ActionHandler:
"""Create an action handler instance.
Args:
frame: Frame containing action parameters.
Returns:
ActionHandler: New action handler instance.
"""
return st.action_handler_cls(
service,
service.get_client,
service.animation_databases,
service.config.animation_types[name],
)
return factory
class AnimationGraphService(BaseActionService):
"""Manage avatar animations and interactions with the animation graph service.
    This service coordinates different types of animations, including gestures, postures, facial expressions,
    and camera shots, and interacts with the animation graph service. It maintains animation databases, handles
    action state transitions, and manages animation data streaming (coming from the audio2face service).
The service supports multiple animation types through dedicated action handlers:
- Gesture animations (GestureBotActionHandler)
- Posture animations (PostureBotActionHandler)
- Facial gesture animations (FacialGestureBotActionHandler)
- Camera shot animations (ShotCameraActionHandler)
Each animation type has its own modality manager to handle how animations interact and override each other.
The service processes animation frames, manages animation state, and coordinates with the animation graph client
for streaming animation data.
    Note on startup performance:
        To speed up pipeline startup, use the class method `pregenerate_animation_databases()`
        to pre-generate the animation databases before starting the pipeline.
Input Frames:
- StartPostureBotActionFrame: Initiates a looping posture animation
- StopPostureBotActionFrame: Stops a running posture animation
- StartGestureBotActionFrame: Initiates a finite gesture animation
- StopGestureBotActionFrame (optional): Stops a running gesture animation
- StartFacialGestureBotActionFrame: Initiates a facial gesture animation
- StopFacialGestureBotActionFrame (optional): Stops a running facial gesture animation
- StartPositionBotActionFrame: Initiates a position animation
- StopPositionBotActionFrame: Stops a running position animation
- StartMotionEffectCameraActionFrame: Initiates a camera motion effect animation
- StopMotionEffectCameraActionFrame (optional): Stops a running camera motion effect animation
- StartShotCameraActionFrame: Initiates a camera shot animation
- StopShotCameraActionFrame (optional): Stops a running camera shot animation
- StartInterruptionFrame (optional): Interrupts the avatar and the current lip animation.
- AnimationDataStreamStartedFrame (consumed): Signals start of animation data stream.
At the moment, this is only used for lip animation data coming from the audio2face service.
- AnimationDataStreamRawFrame (consumed): Contains animation data
- AnimationDataStreamStoppedFrame (consumed): Signals end of animation data stream
Output Frames:
- StartedPostureBotActionFrame (sent up & down): Confirms posture animation has started
- FinishedPostureBotActionFrame (sent up & down): Signals completion of posture animation
- StartedGestureBotActionFrame (sent up & down): Confirms gesture animation has started
- FinishedGestureBotActionFrame (sent up & down): Signals completion of gesture animation
- StartedFacialGestureBotActionFrame (sent up & down): Confirms facial gesture animation has started
- FinishedFacialGestureBotActionFrame (sent up & down): Signals completion of facial gesture animation
- StartedPositionBotActionFrame (sent up & down): Confirms position animation has started
- FinishedPositionBotActionFrame (sent up & down): Signals completion of position animation
- StartedMotionEffectCameraActionFrame (sent up & down): Confirms camera motion effect animation has started
- FinishedMotionEffectCameraActionFrame (sent up & down): Signals completion of camera motion effect animation
- StartedShotCameraActionFrame (sent up & down): Confirms camera shot animation has started
- FinishedShotCameraActionFrame (sent up & down): Signals completion of camera shot animation
- BotStartedSpeakingFrame (sent up & down): Signals that the bot has started speaking
- BotStoppedSpeakingFrame (sent up & down): Signals that the bot has stopped speaking
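
    Example:
        A minimal construction sketch; ``config_dict``, the endpoints, and
        ``broker_config`` are illustrative placeholders::

            config = AnimationGraphConfiguration.model_validate(config_dict)
            # Optional: build the animation databases ahead of time to speed up start().
            AnimationGraphService.pregenerate_animation_databases(config)
            service = AnimationGraphService(
                animation_graph_rest_url="http://localhost:8020",
                animation_graph_grpc_target="localhost:51000",
                message_broker_config=broker_config,  # a MessageBrokerConfig instance
                config=config,
            )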
"""
@dataclass
class ActionConfig:
"""Configuration for an action type.
Attributes:
modality_manager_cls: Class to use for modality management.
action_name: Name of the action type.
action_handler_cls: Class to use for handling actions.
frame_types_to_process: Types of frames this action type processes.
failure_action_factory: Function to create failure action frames.
"""
modality_manager_cls: type[ModalityManager]
action_name: str
action_handler_cls: type[EmbodiedBotActionHandler]
frame_types_to_process: tuple[type[Frame], ...]
failure_action_factory: Callable[[str, Frame], ActionFrame]
supported_animation_types: dict[str, ActionConfig] = {
"gesture": ActionConfig(
ReplaceModalityManager,
"GestureBotAction",
GestureBotActionHandler,
(
StartGestureBotActionFrame,
StartedGestureBotActionFrame,
StopGestureBotActionFrame,
FinishedGestureBotActionFrame,
),
lambda reason, frame: FinishedGestureBotActionFrame(
action_id=frame.action_id, is_success=False, failure_reason=reason
),
),
"posture": ActionConfig(
ReplaceModalityManager,
"PostureBotAction",
PostureBotActionHandler,
(
StartPostureBotActionFrame,
StartedPostureBotActionFrame,
StopPostureBotActionFrame,
FinishedPostureBotActionFrame,
),
lambda reason, frame: FinishedPostureBotActionFrame(
action_id=frame.action_id, is_success=False, failure_reason=reason
),
),
"facial_gesture": ActionConfig(
ReplaceModalityManager,
"FacialGestureBotAction",
FacialGestureBotActionHandler,
(
StartFacialGestureBotActionFrame,
StartedFacialGestureBotActionFrame,
StopFacialGestureBotActionFrame,
FinishedFacialGestureBotActionFrame,
),
lambda reason, frame: FinishedFacialGestureBotActionFrame(
action_id=frame.action_id, is_success=False, failure_reason=reason
),
),
"camera_motion_effect": ActionConfig(
ReplaceModalityManager,
"MotionEffectCameraAction",
MotionEffectCameraActionHandler,
(
StartMotionEffectCameraActionFrame,
StartedMotionEffectCameraActionFrame,
StopMotionEffectCameraActionFrame,
FinishedMotionEffectCameraActionFrame,
),
lambda reason, frame: FinishedMotionEffectCameraActionFrame(
action_id=frame.action_id, is_success=False, failure_reason=reason
),
),
"position": ActionConfig(
OverrideModalityManager,
"PositionBotAction",
PositionBotActionHandler,
(
StartPositionBotActionFrame,
StartedPositionBotActionFrame,
StopPositionBotActionFrame,
FinishedPositionBotActionFrame,
),
lambda reason, frame: FinishedPositionBotActionFrame(
action_id=frame.action_id, is_success=False, failure_reason=reason
),
),
"camera_shot": ActionConfig(
OverrideModalityManager,
"ShotCameraAction",
ShotCameraActionHandler,
(
StartShotCameraActionFrame,
StartedShotCameraActionFrame,
StopShotCameraActionFrame,
FinishedShotCameraActionFrame,
),
lambda reason, frame: FinishedShotCameraActionFrame(
action_id=frame.action_id, is_success=False, failure_reason=reason
),
),
}
animation_databases: dict[str, AnimationDatabase] = {}
frame_types_to_process: tuple[type[Frame], ...] = tuple(
chain.from_iterable(t.frame_types_to_process for t in supported_animation_types.values())
)
def __init__(
self,
*,
animation_graph_rest_url: str,
animation_graph_grpc_target: str,
message_broker_config: MessageBrokerConfig,
config: AnimationGraphConfiguration,
check_data_starvation: bool = True,
):
"""Initialize the animation graph service.
Args:
animation_graph_rest_url: The REST URL for the animation graph service.
animation_graph_grpc_target: The gRPC target for the animation graph service.
message_broker_config: The message broker configuration.
config: The animation graph configuration.
            check_data_starvation: Whether to check for data starvation. If enabled, this will print a warning
                and close the stream to AnimGraph if data is not received in time. This is useful to prevent
                leaving the avatar in a bad state when the audio stream stops too early (e.g. a TTS connectivity issue).
"""
self.animation_graph_client: AnimationGraphClient | None = None
self.animation_graph_rest_url = animation_graph_rest_url
self.animation_graph_grpc_target = animation_graph_grpc_target
self.config = config
self.message_broker_config = message_broker_config
self.avatar_done_talking = asyncio.Event()
self.avatar_done_talking.set()
self.current_speaking_action_id: str | None = None
self._bot_speaking: bool = False
# Data stream monitoring variables
self._check_data_starvation = check_data_starvation
self._data_stream_in_progress = False
self._data_stream_warning_sent = False
self.animation_data_queue: asyncio.Queue[
AnimationDataStreamStartedFrame | AnimationDataStreamRawFrame | AnimationDataStreamStoppedFrame
] = asyncio.Queue()
self.stream_animation_data_task: asyncio.Task | None = None
self.process_animation_events_task: asyncio.Task | None = None
self._bot_speaking_handler_task: asyncio.Task | None = None
self.channel = grpc.aio.insecure_channel(self.animation_graph_grpc_target)
self.stub = AnimationDataServiceStub(self.channel)
managers = []
for name in self.config.animation_types:
if name in self.supported_animation_types:
supported_type = self.supported_animation_types[name]
managers.append(
supported_type.modality_manager_cls(
frame_types_to_process=supported_type.frame_types_to_process,
action_name=supported_type.action_name,
service=self,
action_handler_factory=get_action_handler_factory(supported_type, self, name),
failure_action_factory=supported_type.failure_action_factory,
)
)
super().__init__(managers)
def get_client(self) -> AnimationGraphClient:
"""Get the animation graph client.
Returns:
AnimationGraphClient: The current animation graph client instance.
Raises:
AssertionError: If the client is not initialized.
"""
assert self.animation_graph_client
return self.animation_graph_client
async def start(self, frame: StartFrame) -> None:
"""Called during pipeline start."""
await super().start(frame)
self.stream_animation_data_task = self.create_task(self._stream_animation_data())
self.process_animation_events_task = self.create_task(self._process_animation_events())
self._bot_speaking_handler_task = self.create_task(self._bot_speaking_handler())
# Load animation databases from cache or create new ones
for type_name, _supported_type in self.supported_animation_types.items():
if type_name in self.config.animation_types:
for animation_name, animation_config in self.config.animation_types[type_name].animations.items():
if animation_name not in AnimationGraphService.animation_databases:
db = AnimationDatabase(animation_config.clips)
AnimationGraphService.animation_databases[animation_name] = db
await self._create_animation_graph_client()
async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
"""Process incoming frames.
Handles animation data stream frames and manages animation state transitions.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
if isinstance(
frame, AnimationDataStreamStartedFrame | AnimationDataStreamRawFrame | AnimationDataStreamStoppedFrame
):
# Put into animation data queue and don't push AnimationDataStream frames any further
self.animation_data_queue.put_nowait(frame)
else:
await super().process_frame(frame, direction)
if isinstance(frame, StartInterruptionFrame):
await self._interrupt_avatar()
async def _interrupt_avatar(self) -> None:
logger.debug("_interrupt_avatar called")
if self.current_speaking_action_id:
logger.debug(f"stopping animation clip with action_id={self.current_speaking_action_id}")
if not await self.get_client().stop_request_playback(self.current_speaking_action_id):
                # This can often happen if the playback finished before the stop request was received,
                # so we don't want to treat it as an error; we just log it as a debug message.
info_message = f"Stopping playback for {self.current_speaking_action_id} failed (usually harmless)"
logger.debug(info_message)
else:
logger.debug("received StartInterruptionFrame when no speaking animation clip is playing")
await self._bot_stopped_speaking()
logger.debug("waiting for streaming task to finish")
await self._stop_stream_animation_data_task()
logger.debug("creating new data streaming task with an empty queue")
self.animation_data_queue = asyncio.Queue()
self.stream_animation_data_task = self.create_task(self._stream_animation_data())
async def _process_animation_events(self):
"""Subscribe to animation graph events redis pub/sub channel to receive events about animation clip playback."""
try:
broker = message_broker_factory(config=self.message_broker_config, channels=[])
await broker.wait_for_connection()
except Exception as e:
logger.exception(f"Could not create message broker {e}")
return
while True:
try:
message = await broker.pubsub_receive_message(channels=["animation_graph_events"], timeout=None)
                # Check if we got a message and that it is for the current stream before parsing it (expensive)
if message and self.stream_id in message:
logger.debug(f"received event from animation graph ms: {message}")
event = json.loads(message)
if (
event["event_type"] == "request_playback_ended"
or event["event_type"] == "request_playback_interrupted"
):
logger.debug(
f"Bot stopped speaking based on event received from animation graph: {event['event_type']}"
)
await self._bot_stopped_speaking()
except Exception as e:
logger.error(e)
async def _bot_started_speaking(self):
self.avatar_done_talking.clear()
if not self._bot_speaking:
logger.debug("Bot started speaking")
await self.push_frame(BotStartedSpeakingFrame())
await self.push_frame(BotStartedSpeakingFrame(), FrameDirection.UPSTREAM)
self._bot_speaking = True
async def _bot_stopped_speaking(self):
self.avatar_done_talking.set()
if self._bot_speaking:
logger.debug("Bot stopped speaking")
await self.push_frame(BotStoppedSpeakingFrame())
await self.push_frame(BotStoppedSpeakingFrame(), FrameDirection.UPSTREAM)
self._bot_speaking = False
async def _bot_speaking_handler(self):
"""This task sends out BotSpeakingFrames every 200ms to fulfill the protocol of the output transports.
        We have no way of knowing when exactly audio chunks are being sent by the renderer to the client,
        so we simply send a BotSpeakingFrame every 200ms.
"""
TIMEOUT = 0.2
while not self._cancelling:
await asyncio.sleep(TIMEOUT)
if self._bot_speaking:
await self.push_frame(BotSpeakingFrame())
await self.push_frame(BotSpeakingFrame(), FrameDirection.UPSTREAM)
async def _stream_animation_data(self) -> None:
stream: StreamUnaryCall[AnimationDataStream, Status] | None = None
DATA_STREAM_TIMEOUT = 0.3
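        # Stream protocol sketch: each gRPC stream starts with a header message
        # (animation ids plus audio/skeleton headers), is followed by raw
        # animation-data messages, and is closed via done_writing() once a
        # StoppedFrame arrives. While a stream is open, the queue is polled with
        # a timeout so data starvation can be detected even when no frames arrive.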
while not self._cancelling:
try:
if stream:
frame = await asyncio.wait_for(self.animation_data_queue.get(), timeout=DATA_STREAM_TIMEOUT)
else:
frame = await self.animation_data_queue.get()
new_message = None
if isinstance(frame, AnimationDataStreamStartedFrame):
await self.avatar_done_talking.wait()
stream = self.stub.PushAnimationDataStream(metadata=(("x-stream-id", self.stream_id),))
audio_header: AudioHeader = frame.audio_header
self.current_speaking_action_id = frame.action_id
logger.debug(f"Sending AnimationData header with request_id={self.current_speaking_action_id}")
new_message = AnimationDataStream(
animation_data_stream_header=AnimationDataStreamHeader(
animation_ids=AnimationIds(
stream_id=self.stream_id,
request_id=self.current_speaking_action_id,
target_object_id="toto",
),
source_service_id=frame.animation_source_id,
audio_header=audio_header,
skel_animation_header=frame.animation_header,
start_time_code_since_epoch=time.time(),
)
)
await self._bot_started_speaking()
elif isinstance(frame, AnimationDataStreamRawFrame):
new_message = AnimationDataStream(
animation_data=frame.animation_data,
)
elif isinstance(frame, AnimationDataStreamStoppedFrame):
logger.debug(
f"AnimationData stream stopped for request_id={self.current_speaking_action_id}. "
"Does not affect playback."
)
if stream and not stream.done():
await stream.done_writing()
stream = None
# Send new message only if we have an active stream
if new_message and stream and not stream.done():
await stream.write(new_message)
if await self._check_datastream_starvation(frame):
await self._close_stream(stream)
stream = None
except TimeoutError:
if await self._check_datastream_starvation(None):
await self._close_stream(stream)
stream = None
except Exception as e:
logger.error(f"Exception: {e}")
async def _create_animation_graph_client(self) -> None:
if self.stream_id:
await self._close_animation_graph_client()
self.animation_graph_client = AnimationGraphClient(
self.animation_graph_rest_url,
self.stream_id,
)
await self.animation_graph_client.register_stream()
async def _close_animation_graph_client(self) -> None:
if self.animation_graph_client:
await self.animation_graph_client.close()
async def _stop_stream_animation_data_task(self) -> None:
if self.stream_animation_data_task:
await self.cancel_task(self.stream_animation_data_task, timeout=0.3)
self.stream_animation_data_task = None
self._data_stream_in_progress = False
async def _stop_running_tasks(self) -> None:
await self._stop_stream_animation_data_task()
if self.process_animation_events_task:
await self.cancel_task(self.process_animation_events_task, timeout=0.3)
self.process_animation_events_task = None
if self._bot_speaking_handler_task:
await self.cancel_task(self._bot_speaking_handler_task, timeout=0.3)
self._bot_speaking_handler_task = None
async def cleanup(self) -> None:
"""Clean up the service."""
await super().cleanup()
await self._close_animation_graph_client()
await self._stop_running_tasks()
async def stop(self, frame: EndFrame) -> None:
"""Called during pipeline end."""
await self._close_animation_graph_client()
await self._stop_running_tasks()
@classmethod
def pregenerate_animation_databases(cls, config: AnimationGraphConfiguration) -> None:
"""Pre-generate animation databases from configuration and cache them.
        You can do this before you start the pipeline so you don't have to wait
        for the databases to be created during startup.
Args:
config: Animation graph configuration containing animation types and clips
"""
# Generate databases for each animation type
for type_name in config.animation_types:
for animation_name, animation_config in config.animation_types[type_name].animations.items():
db = AnimationDatabase(animation_config.clips)
cls.animation_databases[animation_name] = db
    async def _close_stream(self, stream: StreamUnaryCall[AnimationDataStream, Status] | None) -> None:
logger.debug(
f"Closing stream for request_id={self.current_speaking_action_id}"
" before StoppedFrame was received due to stream starvation."
)
self._data_stream_in_progress = False
if stream and not stream.done():
await stream.done_writing()
try:
await self.get_client().stop_request_playback(self.current_speaking_action_id)
except Exception as e:
logger.info(f"Could not stop request playback on timeout: {e}")
async def _check_datastream_starvation(
self,
frame: AnimationDataStreamStartedFrame | AnimationDataStreamRawFrame | AnimationDataStreamStoppedFrame | None,
) -> bool:
if not self._check_data_starvation:
return False
WARNING_TIMEOUT_S = 0.1 # Print warning if data is delayed by more than this
STARVATION_CLOSE_TIMEOUT_S = 1.0 # Indicate starvation if data is delayed by more than this
if isinstance(frame, AnimationDataStreamStartedFrame):
self._data_until_playback_starves = time.monotonic()
self._audio_samples_per_second = frame.audio_header.samples_per_second
self._audio_bits_per_sample = frame.audio_header.bits_per_sample
self._audio_channel_count = frame.audio_header.channel_count
self._data_stream_in_progress = True
self._data_stream_warning_sent = False
elif isinstance(frame, AnimationDataStreamStoppedFrame):
self._data_stream_in_progress = False
elif self._data_stream_in_progress and isinstance(frame, AnimationDataStreamRawFrame):
audio_buffer_size = len(frame.animation_data.audio.audio_buffer)
audio_buffer_size_s = audio_buffer_size / (
self._audio_samples_per_second * self._audio_bits_per_sample / 8 * self._audio_channel_count
)
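            # e.g. for 16 kHz, 16-bit mono audio: 16000 * 16 / 8 * 1 = 32000 bytes/s,
            # so a 6400-byte buffer corresponds to 0.2 s of playback.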
self._data_until_playback_starves += audio_buffer_size_s
if self._data_stream_in_progress:
now = time.monotonic()
if now > self._data_until_playback_starves + STARVATION_CLOSE_TIMEOUT_S:
logger.warning(
f"Data stream starvation detected: data behind by {now - self._data_until_playback_starves}s"
)
await self.push_frame(
ErrorFrame("AnimGraph: Data stream starvation detected. AnimGraph connection will be reset.")
)
return True
elif now > self._data_until_playback_starves + WARNING_TIMEOUT_S:
logger.info(f"Data stream data behind by {now - self._data_until_playback_starves}s")
if not self._data_stream_warning_sent:
await self.push_frame(
ErrorFrame(f"AnimGraph: Received data stream is behind by more than {WARNING_TIMEOUT_S}s")
)
self._data_stream_warning_sent = True
return False
else:
return False