# Author: Cyril Dupland
# Agent A2 with tokens
# Commit: ab8db29
"""Utilities for extracting and normalizing token usage from LLM responses."""
from typing import Dict, Any, Optional
def extract_usage_metadata(response: Any) -> Dict[str, Any]:
    """Pull token-usage metadata off a LangChain LLM response.

    Accepts several shapes of ``usage_metadata``:
      * a plain dict — returned as-is;
      * an object exposing ``__dict__`` — converted via its attribute dict;
      * any other object — probed for well-known token-count attributes.

    Args:
        response: LLM response object (e.g. an ``AIMessage``).

    Returns:
        A dictionary of usage information; empty when none is attached.
    """
    usage = getattr(response, "usage_metadata", None)
    if usage is None:
        return {}
    if isinstance(usage, dict):
        return usage
    if hasattr(usage, "__dict__"):
        return usage.__dict__
    # Last resort: probe the usual token-count attribute names and keep
    # whichever are present with a non-None value.
    known_attrs = (
        "input_tokens",
        "prompt_tokens",
        "output_tokens",
        "completion_tokens",
        "total_tokens",
    )
    return {
        name: getattr(usage, name)
        for name in known_attrs
        if getattr(usage, name, None) is not None
    }
def normalize_usage(usage: Dict[str, Any]) -> Dict[str, int]:
    """Normalize a raw usage dictionary to the canonical format.

    Converts the various provider spellings (``prompt_tokens``/
    ``completion_tokens``, ``n_prompt_tokens``/``n_completion_tokens``)
    to consistent ``input_tokens``/``output_tokens``/``total_tokens``.

    Args:
        usage: Raw usage dictionary from an LLM (may be None or empty).

    Returns:
        Dict with integer ``input_tokens``, ``output_tokens`` and
        ``total_tokens``; all zero when *usage* is falsy.
    """
    if not usage:
        return {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}

    def _first_present(*keys: str) -> Optional[int]:
        # Return the value of the first key present with a non-None value.
        # BUGFIX: the previous `or`-chain treated a legitimate 0 count as
        # missing and fell through to the next key (e.g.
        # {"input_tokens": 0, "prompt_tokens": 10} wrongly yielded 10).
        for key in keys:
            value = usage.get(key)
            if value is not None:
                return int(value)
        return None

    input_tokens = _first_present("input_tokens", "prompt_tokens", "n_prompt_tokens")
    input_tokens = 0 if input_tokens is None else input_tokens

    output_tokens = _first_present("output_tokens", "completion_tokens", "n_completion_tokens")
    output_tokens = 0 if output_tokens is None else output_tokens

    total_tokens = _first_present("total_tokens", "n_total_tokens")
    if total_tokens is None:
        # No provider-supplied total: derive it from the two parts.
        total_tokens = input_tokens + output_tokens

    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": total_tokens,
    }