File size: 5,294 Bytes
a1bf219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
"""LangChain callback for automatic cost tracking."""

import logging
from typing import Any, Dict, List, Optional, Union

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult

from utils.cost_tracker import CostTracker

logger = logging.getLogger(__name__)


class CostTrackingCallback(BaseCallbackHandler):
    """
    LangChain callback handler for tracking LLM API costs.

    On every completed LLM call this extracts token usage from the
    ``LLMResult`` and records it with the shared ``CostTracker``. Because
    integrations report usage in different places, token counts are looked
    up, in order, in:

    * ``llm_output["token_usage"]`` (``prompt_tokens`` / ``completion_tokens``),
    * ``llm_output["usage"]`` (``input_tokens`` / ``output_tokens``),
    * ``usage_metadata`` on the generated chat messages, summed over all
      generations (only when ``llm_output`` reported no tokens, so usage is
      never counted twice).

    Tracking is best-effort: any exception raised while extracting or
    recording usage is logged as a warning and never propagates to the
    LLM call that triggered the callback.
    """

    def __init__(
        self,
        cost_tracker: CostTracker,
        agent_name: str,
        provider: Optional[str] = None,
    ):
        """
        Initialize cost tracking callback.

        Args:
            cost_tracker: CostTracker instance to use
            agent_name: Name of the agent making LLM calls
            provider: Provider name (auto-detected if None)
        """
        super().__init__()
        self.cost_tracker = cost_tracker
        self.agent_name = agent_name
        self.provider = provider

    @staticmethod
    def _as_int(value: Any) -> int:
        """Coerce a possibly-missing (``None``) token count to an int."""
        return int(value or 0)

    @classmethod
    def _extract_token_usage(cls, response: LLMResult) -> "tuple[int, int]":
        """
        Return ``(input_tokens, output_tokens)`` reported for ``response``.

        ``llm_output`` is consulted first; per-message ``usage_metadata`` is
        used only as a fallback when ``llm_output`` reports zero tokens.
        """
        llm_output = response.llm_output or {}
        for key in ("token_usage", "usage"):
            usage = llm_output.get(key)
            # The key may be absent or explicitly None; only dicts carry counts.
            if not isinstance(usage, dict):
                continue
            input_tokens = cls._as_int(
                usage.get("prompt_tokens") or usage.get("input_tokens")
            )
            output_tokens = cls._as_int(
                usage.get("completion_tokens") or usage.get("output_tokens")
            )
            if input_tokens or output_tokens:
                return input_tokens, output_tokens

        # Fallback: chat messages may carry usage_metadata. A batched call
        # produces one list of generations per prompt, so sum across all.
        input_tokens = output_tokens = 0
        for generation_list in response.generations or []:
            for generation in generation_list:
                message = getattr(generation, "message", None)
                metadata = getattr(message, "usage_metadata", None)
                if metadata:
                    input_tokens += cls._as_int(metadata.get("input_tokens"))
                    output_tokens += cls._as_int(metadata.get("output_tokens"))
        return input_tokens, output_tokens

    @staticmethod
    def _extract_model_name(response: LLMResult) -> str:
        """
        Return the model name for ``response``, or ``"unknown"``.

        Checks ``llm_output`` ("model_name", then "model"), then each chat
        message's ``response_metadata`` with the same keys.
        """
        llm_output = response.llm_output or {}
        model = llm_output.get("model_name") or llm_output.get("model")
        if model:
            return model
        for generation_list in response.generations or []:
            for generation in generation_list:
                message = getattr(generation, "message", None)
                metadata = getattr(message, "response_metadata", None) or {}
                model = metadata.get("model_name") or metadata.get("model")
                if model:
                    return model
        return "unknown"

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: Any,
        parent_run_id: Optional[Any] = None,
        **kwargs: Any,
    ) -> Any:
        """
        Track cost when LLM call completes.

        Args:
            response: LLM response with token usage info
            run_id: Unique identifier for this run
            parent_run_id: Parent run ID if nested
            **kwargs: Additional callback arguments

        Returns:
            None. Failures are logged, never raised.
        """
        try:
            input_tokens, output_tokens = self._extract_token_usage(response)
            model = self._extract_model_name(response)
            logger.debug(
                "CostTrackingCallback.on_llm_end for %s: llm_output keys=%s, "
                "model=%s, input=%d, output=%d",
                self.agent_name,
                list((response.llm_output or {}).keys()),
                model,
                input_tokens,
                output_tokens,
            )

            # Skip tracking if no tokens (e.g., cached response)
            if not input_tokens and not output_tokens:
                logger.warning("Skipping tracking for %s - no tokens", self.agent_name)
                return

            cost = self.cost_tracker.track_call(
                agent_name=self.agent_name,
                model=model,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                provider=self.provider,
            )

            # float() here (inside the try) so an unexpected cost type fails
            # loudly via the warning below rather than inside logging.
            logger.info(
                "✓ Cost tracked: %s | %s | %d + %d tokens | $%.6f",
                self.agent_name,
                model,
                input_tokens,
                output_tokens,
                float(cost),
            )

        except Exception as e:
            # Best-effort: cost tracking must never break the LLM call itself.
            logger.warning("Failed to track cost for %s: %s", self.agent_name, e)

    def on_llm_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: Any,
        parent_run_id: Optional[Any] = None,
        **kwargs: Any,
    ) -> Any:
        """
        Handle LLM errors (no cost tracking needed).

        Args:
            error: The error that occurred
            run_id: Unique identifier for this run
            parent_run_id: Parent run ID if nested
            **kwargs: Additional callback arguments
        """
        logger.debug(f"LLM error in {self.agent_name}: {error}")


class WorkflowCostTracker:
    """
    Own one CostTracker for an entire workflow.

    Hands out per-agent callbacks that all report into that same tracker,
    and exposes the tracker's summary/reset operations.
    """

    def __init__(self, budget_config=None):
        """
        Create the workflow's shared CostTracker.

        Args:
            budget_config: Optional BudgetConfig forwarded to CostTracker
                (for cost limits and alerts)
        """
        tracker = CostTracker(budget_config=budget_config)
        self.cost_tracker = tracker

    def get_callback(
        self,
        agent_name: str,
        provider: Optional[str] = None,
    ) -> CostTrackingCallback:
        """
        Build a new callback bound to this workflow's tracker.

        Args:
            agent_name: Agent label recorded with every call the callback sees
            provider: Provider name passed through to the callback
                (auto-detected if None)

        Returns:
            A fresh CostTrackingCallback that shares ``self.cost_tracker``
        """
        callback = CostTrackingCallback(
            agent_name=agent_name,
            provider=provider,
            cost_tracker=self.cost_tracker,
        )
        return callback

    def get_summary(self) -> Dict[str, Any]:
        """
        Return the tracker's cost breakdown for this workflow.

        Returns:
            Whatever ``CostTracker.get_summary()`` produces
        """
        tracker = self.cost_tracker
        return tracker.get_summary()

    def format_summary(self) -> str:
        """
        Return the tracker's human-readable cost summary.

        Returns:
            Whatever ``CostTracker.format_summary()`` produces
        """
        tracker = self.cost_tracker
        return tracker.format_summary()

    def reset(self):
        """Clear all tracked costs by resetting the shared tracker."""
        tracker = self.cost_tracker
        tracker.reset()