# open-voice-agent/src/agent/_langchain_usage_patch.py
"""Patch LiveKit LangChain bridge to propagate token usage into LLMMetrics.
LiveKit computes LLMMetrics token fields from ``ChatChunk.usage``. The upstream
``livekit.plugins.langchain.langgraph`` bridge emits chunks with content but, in
some providers, does not map LangChain ``usage_metadata`` into ``usage``.
When that happens, Langfuse receives token metrics as zeros.
"""

from __future__ import annotations

from typing import Any, Mapping, Optional

from langchain_core.messages import BaseMessageChunk
from livekit.agents import llm, utils
from livekit.plugins.langchain import langgraph as lk_langgraph

from src.core.logger import logger

_PATCH_FLAG = "_open_voice_agent_usage_patch_applied"
_ORIGINAL_TO_CHAT_CHUNK = "_open_voice_agent_original_to_chat_chunk"
_MISSING_USAGE_LOGGED = "_open_voice_agent_missing_usage_logged"


def apply_langchain_usage_patch() -> bool:
"""Apply an idempotent patch to preserve token usage from LangChain chunks."""
if getattr(lk_langgraph, _PATCH_FLAG, False):
return False
original = getattr(lk_langgraph, "_to_chat_chunk", None)
if not callable(original):
logger.warning("LangChain usage patch skipped: _to_chat_chunk not found")
return False
def _patched_to_chat_chunk(msg: str | Any) -> llm.ChatChunk | None:
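        """Mirror the upstream converter while also forwarding token usage."""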
        message_id = utils.shortuuid("LC_")
        content: str | None = None
        usage = _completion_usage_from_message_chunk(msg)

        if isinstance(msg, str):
            content = msg
        elif isinstance(msg, BaseMessageChunk):
            # Defensive: ``text`` is a property on some langchain-core versions
            # and a zero-arg method on others, so accept either form.
            text_value = getattr(msg, "text", None)
            if callable(text_value):
                text_value = text_value()
            if text_value is not None:
                content = str(text_value)
            chunk_id = getattr(msg, "id", None)
            if chunk_id:
                message_id = chunk_id  # type: ignore[assignment]

        # Preserve usage-only chunks (the final event often carries usage with no text).
        if not content and usage is None:
            return None
        delta = (
            llm.ChoiceDelta(
                role="assistant",
                content=content,
            )
            if content
            else None
        )
        return llm.ChatChunk(
            id=message_id,
            delta=delta,
            usage=usage,
        )

    setattr(lk_langgraph, _ORIGINAL_TO_CHAT_CHUNK, original)
    setattr(lk_langgraph, "_to_chat_chunk", _patched_to_chat_chunk)
    setattr(lk_langgraph, _PATCH_FLAG, True)
    logger.info("Applied LangChain usage bridge patch for LLM token metrics")
    return True
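

# For reference, LangChain chunks usually expose usage in one of two places
# (keys below are illustrative; exact shapes vary by provider):
#   * ``msg.usage_metadata``    -> {"input_tokens": 11, "output_tokens": 42, "total_tokens": 53}
#   * ``msg.response_metadata`` -> {"token_usage": {"prompt_tokens": 11, "completion_tokens": 42, ...}}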
def _completion_usage_from_message_chunk(
    message_chunk: Any,
) -> Optional[llm.CompletionUsage]:
    if not isinstance(message_chunk, BaseMessageChunk):
        return None

    usage_metadata = _as_mapping(getattr(message_chunk, "usage_metadata", None))
    if usage_metadata:
        usage = _completion_usage_from_mapping(usage_metadata)
        if usage:
            return usage

    # Fallback for providers that place usage in response_metadata.
    response_metadata = _as_mapping(getattr(message_chunk, "response_metadata", None))
    if not response_metadata:
        _log_missing_usage_once(message_chunk)
        return None

    response_usage = (
        _as_mapping(response_metadata.get("usage_metadata"))
        or _as_mapping(response_metadata.get("usage"))
        or _as_mapping(response_metadata.get("token_usage"))
    )
    usage = _completion_usage_from_mapping(response_usage) if response_usage else None
    if usage is None:
        _log_missing_usage_once(message_chunk)
    return usage
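

# Worked example: {"input_tokens": 11, "output_tokens": 42} yields
# prompt_tokens=11, completion_tokens=42, and a derived total_tokens=53;
# a mapping that cannot be completed to all three fields is rejected.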
def _completion_usage_from_mapping(
    usage: Mapping[str, Any] | None,
) -> Optional[llm.CompletionUsage]:
    if not usage:
        return None

    prompt_tokens = _as_int(
        usage.get("input_tokens"),
        usage.get("prompt_tokens"),
    )
    completion_tokens = _as_int(
        usage.get("output_tokens"),
        usage.get("completion_tokens"),
    )
    total_tokens = _as_int(usage.get("total_tokens"))

    # Derive missing pieces when enough data is present.
    if total_tokens is None and prompt_tokens is not None and completion_tokens is not None:
        total_tokens = prompt_tokens + completion_tokens
    if prompt_tokens is None and total_tokens is not None and completion_tokens is not None:
        prompt_tokens = max(total_tokens - completion_tokens, 0)
    if completion_tokens is None and total_tokens is not None and prompt_tokens is not None:
        completion_tokens = max(total_tokens - prompt_tokens, 0)

    if prompt_tokens is None or completion_tokens is None or total_tokens is None:
        return None

    prompt_cached_tokens = _extract_prompt_cached_tokens(usage)
    return llm.CompletionUsage(
        completion_tokens=completion_tokens,
        prompt_tokens=prompt_tokens,
        prompt_cached_tokens=prompt_cached_tokens,
        total_tokens=total_tokens,
    )
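

# Cached prompt tokens are reported under provider-specific keys, e.g.
# OpenAI-style ``prompt_tokens_details.cached_tokens`` or LangChain's
# ``input_token_details["cache_read"]``; the common layouts are checked in turn.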
def _extract_prompt_cached_tokens(usage: Mapping[str, Any]) -> int:
    cached = _as_int(
        usage.get("prompt_cached_tokens"),
        usage.get("cached_tokens"),
        usage.get("cache_read_tokens"),
    )
    if cached is not None:
        return max(cached, 0)

    input_details = _as_mapping(usage.get("input_token_details"))
    if input_details:
        cached = _as_int(
            input_details.get("cache_read"),
            input_details.get("cache_read_tokens"),
            input_details.get("cached_tokens"),
        )
        if cached is not None:
            return max(cached, 0)

    prompt_details = _as_mapping(usage.get("prompt_tokens_details"))
    if prompt_details:
        cached = _as_int(
            prompt_details.get("cached_tokens"),
            prompt_details.get("cache_read"),
            prompt_details.get("cache_read_tokens"),
        )
        if cached is not None:
            return max(cached, 0)

    return 0
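

# Usage payloads may arrive as plain dicts or as pydantic models; ``model_dump``
# covers pydantic v2 objects and ``dict()`` covers pydantic v1.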
def _as_mapping(value: Any) -> Optional[Mapping[str, Any]]:
    if isinstance(value, Mapping):
        return value
    if hasattr(value, "model_dump"):
        dumped = value.model_dump(exclude_none=True)
        if isinstance(dumped, Mapping):
            return dumped
    if hasattr(value, "dict"):
        dumped = value.dict()
        if isinstance(dumped, Mapping):
            return dumped
    return None
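

# ``bool`` is skipped explicitly because it is a subclass of ``int`` in Python
# and would otherwise be coerced to a 0/1 token count.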
def _as_int(*values: Any) -> Optional[int]:
    for value in values:
        if isinstance(value, bool):
            continue
        if isinstance(value, int):
            return value
        if isinstance(value, float) and value.is_integer():
            return int(value)
        if isinstance(value, str):
            stripped = value.strip()
            if stripped.isdigit():
                return int(stripped)
    return None


def _log_missing_usage_once(message_chunk: BaseMessageChunk) -> None:
    if getattr(lk_langgraph, _MISSING_USAGE_LOGGED, False):
        return

    setattr(lk_langgraph, _MISSING_USAGE_LOGGED, True)
    logger.info(
        "LLM chunk arrived without token usage metadata; LLM token metrics may remain zero. chunk_type=%s",
        type(message_chunk).__name__,
    )