File size: 2,485 Bytes
ab8db29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""Utilities for extracting and normalizing token usage from LLM responses."""
from typing import Dict, Any, Optional


def extract_usage_metadata(response: Any) -> Dict[str, Any]:
    """
    Extract usage metadata from a LangChain LLM response.

    Reads ``response.usage_metadata`` and returns it as a plain dict.
    Handles different formats:
    - Dict: returned as a shallow copy
    - Object with a ``__dict__``: its instance attributes, shallow-copied
    - Any object: well-known token-count attributes that are not in the
      instance dict (e.g. properties or ``__slots__``) are probed and added
      when their value is not None
    - None / missing attribute: returns an empty dict

    A copy is always returned, so callers may modify the result without
    mutating the response or its usage object.

    Args:
        response: LLM response object (AIMessage, etc.). Objects without a
            ``usage_metadata`` attribute yield an empty dict.

    Returns:
        Dictionary with usage information (may be empty)
    """
    # All aliases that normalize_usage() reads, so property- or slot-based
    # usage objects are not reduced to an empty dict.
    known_attrs = (
        "input_tokens", "prompt_tokens", "n_prompt_tokens",
        "output_tokens", "completion_tokens", "n_completion_tokens",
        "total_tokens", "n_total_tokens",
    )

    usage = getattr(response, "usage_metadata", None)
    if usage is None:
        return {}

    if isinstance(usage, dict):
        # Copy so callers cannot mutate the response's own metadata dict.
        return dict(usage)

    # Copy the instance dict: returning ``usage.__dict__`` itself would let
    # callers rewrite the usage object's attributes through the result.
    result: Dict[str, Any] = dict(getattr(usage, "__dict__", {}))

    # Properties, __slots__ and computed attributes never appear in __dict__,
    # so probe the known names for anything not already collected.
    for attr in known_attrs:
        if attr in result:
            continue
        value = getattr(usage, attr, None)
        if value is not None:
            result[attr] = value

    return result


def normalize_usage(usage: Dict[str, Any]) -> Dict[str, int]:
    """
    Collapse a provider-specific usage mapping into one canonical shape.

    Providers report token counts under different key names
    (``input_tokens`` / ``prompt_tokens`` / ``n_prompt_tokens`` and so on).
    For each canonical field, the first alias holding a truthy value wins.
    Missing keys, ``None`` and ``0`` all fall through to the next alias,
    and finally to ``0``.

    Args:
        usage: Raw usage mapping from the LLM (an empty or None value is
            treated as "no usage reported").

    Returns:
        Dict with integer ``input_tokens``, ``output_tokens`` and
        ``total_tokens``. When no total is reported, the total is the sum
        of the input and output counts.
    """
    if not usage:
        return {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}

    def pick(*aliases: str) -> Any:
        # First truthy value among the aliases; behaves like `a or b or ... or 0`.
        for alias in aliases:
            candidate = usage.get(alias)
            if candidate:
                return candidate
        return 0

    prompt = int(pick("input_tokens", "prompt_tokens", "n_prompt_tokens"))
    completion = int(
        pick("output_tokens", "completion_tokens", "n_completion_tokens")
    )
    # Fall back to summing the parts when no truthy total is reported.
    total = pick("total_tokens", "n_total_tokens") or prompt + completion

    return {
        "input_tokens": prompt,
        "output_tokens": completion,
        "total_tokens": int(total),
    }