File size: 2,952 Bytes
ae5413a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9e710f
 
 
 
 
 
 
 
 
 
 
 
 
ae5413a
 
 
 
 
 
 
 
 
 
 
 
 
 
3572ba0
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""Token tracking middleware for monitoring API usage."""

from collections.abc import Awaitable, Callable

import structlog
from agent_framework._middleware import ChatContext, ChatMiddleware

# Module-level structlog logger; TokenTrackingMiddleware emits its
# "Token usage" events (with key/value fields) through it.
logger = structlog.get_logger()


class TokenTrackingMiddleware(ChatMiddleware):
    """Tracks token usage across chat requests.

    After each chat completion this middleware looks for usage information
    on the response, adds the per-request counts to running totals held on
    this instance, and logs the per-request and cumulative figures via
    structlog for observability.

    Attributes:
        total_input_tokens: Input (prompt) tokens summed over every request
            that reported usage.
        total_output_tokens: Output (completion) tokens summed over every
            request that reported usage.
        request_count: Number of requests for which usage information was
            found. Requests whose response carries no usage are not counted.
    """

    # Attributes on the chat result that may carry a usage payload, tried in
    # order. ``usage_details`` is presumably the name used by agent_framework's
    # ChatResponse — verify against the installed version.
    _USAGE_ATTRIBUTES: tuple[str, ...] = ("usage", "usage_details")

    # Candidate field names for each direction, tried in order; the first
    # truthy value wins. Covers OpenAI-style (prompt/completion), newer
    # input/output naming, and ``*_token_count`` naming (presumably
    # agent_framework's UsageDetails — verify).
    _INPUT_TOKEN_FIELDS: tuple[str, ...] = (
        "input_tokens",
        "prompt_tokens",
        "input_token_count",
    )
    _OUTPUT_TOKEN_FIELDS: tuple[str, ...] = (
        "output_tokens",
        "completion_tokens",
        "output_token_count",
    )

    def __init__(self) -> None:
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.request_count = 0

    async def process(
        self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
    ) -> None:
        """Run the downstream pipeline, then record the response's token usage.

        Args:
            context: Chat context for the request. ``context.result`` is
                inspected only after ``next`` has completed.
            next: Continuation for the rest of the middleware pipeline.
        """
        await next(context)

        usage = self._extract_usage(context.result)
        if not usage:
            # No (or an empty) usage payload: leave the running totals alone.
            return

        input_tokens, output_tokens = self._read_token_counts(usage)

        self.total_input_tokens += input_tokens
        self.total_output_tokens += output_tokens
        self.request_count += 1

        logger.info(
            "Token usage",
            request_input=input_tokens,
            request_output=output_tokens,
            total_input=self.total_input_tokens,
            total_output=self.total_output_tokens,
            total_requests=self.request_count,
        )

    @classmethod
    def _extract_usage(cls, result: object) -> object:
        """Locate the usage payload on a chat result.

        Checks the result's own usage attributes first, then the ``usage``
        entry of the first message's ``metadata`` mapping.

        Returns:
            The first non-empty usage payload found, or None.
        """
        if result is None:
            return None

        for attr in cls._USAGE_ATTRIBUTES:
            # Test the value, not hasattr(): a response may define the
            # attribute but leave it None, and the metadata fallback below
            # must still be reachable in that case.
            usage = getattr(result, attr, None)
            if usage:
                return usage

        messages = getattr(result, "messages", None)
        if messages:
            metadata = getattr(messages[0], "metadata", None)
            if metadata:
                return metadata.get("usage")
        return None

    @classmethod
    def _read_token_counts(cls, usage: object) -> tuple[int, int]:
        """Return ``(input_tokens, output_tokens)`` from a usage payload.

        Dict-like payloads (anything exposing ``get``) are read by key; other
        objects (e.g. Pydantic models) are read by attribute. For each
        direction the first truthy candidate field wins. When every candidate
        is missing, None or 0, the count is 0, so the running totals never
        receive None.
        """
        is_mapping = hasattr(usage, "get")

        def first_count(fields: tuple[str, ...]) -> int:
            for field in fields:
                value = usage.get(field) if is_mapping else getattr(usage, field, None)
                if value:
                    return value
            return 0

        return (
            first_count(cls._INPUT_TOKEN_FIELDS),
            first_count(cls._OUTPUT_TOKEN_FIELDS),
        )

    def get_stats(self) -> dict[str, int]:
        """Get cumulative token usage statistics.

        Returns:
            A new dictionary with ``total_input``, ``total_output`` and
            ``request_count`` (the number of requests that reported usage).
        """
        return {
            "total_input": self.total_input_tokens,
            "total_output": self.total_output_tokens,
            "request_count": self.request_count,
        }