"""LangChain callback for automatic cost tracking."""

import logging
from typing import Any, Dict, List, Optional, Tuple, Union

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult

from utils.cost_tracker import CostTracker

logger = logging.getLogger(__name__)


def _first_int(mapping: Dict[str, Any], *keys: str) -> int:
    """Return the first non-None value among ``keys`` in ``mapping`` as an int.

    Missing keys and ``None`` values are skipped. A non-numeric value yields 0,
    so a malformed usage payload never aborts cost tracking.
    """
    for key in keys:
        value = mapping.get(key)
        if value is None:
            continue
        try:
            return int(value)
        except (TypeError, ValueError):
            return 0
    return 0


def _iter_messages(response: LLMResult):
    """Yield the ``message`` attribute of each generation that has one.

    Chat generations carry a message; plain-text generations do not and are
    skipped.
    """
    for generation_list in response.generations or []:
        for generation in generation_list:
            message = getattr(generation, "message", None)
            if message is not None:
                yield message


def _extract_token_usage(response: LLMResult) -> Tuple[int, int]:
    """Return ``(input_tokens, output_tokens)`` reported for an LLM response.

    Sources are checked in this order:

    1. ``llm_output["token_usage"]`` (``prompt_tokens`` / ``completion_tokens``,
       the lookup this module originally used);
    2. ``llm_output["usage"]`` (``input_tokens`` / ``output_tokens`` naming);
    3. ``usage_metadata`` on the chat messages in ``response.generations``.

    Returns ``(0, 0)`` when no source reports any usage.
    """
    llm_output = response.llm_output or {}
    for key in ("token_usage", "usage"):
        usage = llm_output.get(key)
        if isinstance(usage, dict):
            input_tokens = _first_int(usage, "prompt_tokens", "input_tokens")
            output_tokens = _first_int(usage, "completion_tokens", "output_tokens")
            if input_tokens or output_tokens:
                return input_tokens, output_tokens

    # Take the first message's usage rather than summing across generations:
    # usage_metadata may hold run-level totals repeated on every generation,
    # and summing would double count. NOTE(review): verify per provider.
    for message in _iter_messages(response):
        usage = getattr(message, "usage_metadata", None)
        if isinstance(usage, dict) and usage:
            return (
                _first_int(usage, "input_tokens"),
                _first_int(usage, "output_tokens"),
            )
    return 0, 0


def _extract_model_name(response: LLMResult) -> str:
    """Return the model name for a response, or ``"unknown"``.

    ``llm_output`` is checked first: ``"model_name"`` for OpenAI/Anthropic
    and ``"model"`` for HuggingFace. The messages' ``response_metadata`` is
    the fallback.
    """
    llm_output = response.llm_output or {}
    model = llm_output.get("model_name") or llm_output.get("model")
    if not model:
        for message in _iter_messages(response):
            metadata = getattr(message, "response_metadata", None)
            if isinstance(metadata, dict):
                model = metadata.get("model_name") or metadata.get("model")
                if model:
                    break
    return str(model) if model else "unknown"


class CostTrackingCallback(BaseCallbackHandler):
    """
    LangChain callback handler for tracking LLM API costs.

    This callback automatically extracts token usage from LLM responses
    and tracks costs using the CostTracker.
    """

    def __init__(
        self,
        cost_tracker: CostTracker,
        agent_name: str,
        provider: Optional[str] = None,
    ):
        """
        Initialize cost tracking callback.

        Args:
            cost_tracker: CostTracker instance to use
            agent_name: Name of the agent making LLM calls
            provider: Provider name (auto-detected if None)
        """
        super().__init__()
        self.cost_tracker = cost_tracker
        self.agent_name = agent_name
        self.provider = provider

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: Any,
        parent_run_id: Optional[Any] = None,
        **kwargs: Any,
    ) -> Any:
        """
        Track cost when LLM call completes.

        Tracking is best-effort. Any exception raised while extracting usage
        or recording the call is logged and swallowed, so a cost-tracking
        problem never breaks the LLM call chain.

        Args:
            response: LLM response with token usage info
            run_id: Unique identifier for this run
            parent_run_id: Parent run ID if nested
            **kwargs: Additional callback arguments
        """
        try:
            input_tokens, output_tokens = _extract_token_usage(response)
            model = _extract_model_name(response)
            logger.debug(
                "CostTrackingCallback.on_llm_end for %s: model=%s input=%d output=%d",
                self.agent_name,
                model,
                input_tokens,
                output_tokens,
            )

            # Skip tracking if no tokens (e.g., cached response)
            if input_tokens == 0 and output_tokens == 0:
                logger.warning("Skipping tracking for %s - no tokens", self.agent_name)
                return None

            cost = self.cost_tracker.track_call(
                agent_name=self.agent_name,
                model=model,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                provider=self.provider,
            )
        except Exception as e:
            logger.warning(
                "Failed to track cost for %s: %s", self.agent_name, e, exc_info=True
            )
            return None

        # Formatting happens outside the try. An unexpected return type from
        # track_call then cannot be misreported as a tracking failure after
        # the call was already recorded.
        cost_text = f"${cost:.6f}" if isinstance(cost, (int, float)) else f"${cost}"
        logger.info(
            "✓ Cost tracked: %s | %s | %d + %d tokens | %s",
            self.agent_name,
            model,
            input_tokens,
            output_tokens,
            cost_text,
        )
        return None

    def on_llm_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: Any,
        parent_run_id: Optional[Any] = None,
        **kwargs: Any,
    ) -> Any:
        """
        Handle LLM errors (no cost tracking needed).

        Args:
            error: The error that occurred
            run_id: Unique identifier for this run
            parent_run_id: Parent run ID if nested
            **kwargs: Additional callback arguments
        """
        logger.debug("LLM error in %s: %s", self.agent_name, error)


class WorkflowCostTracker:
    """
    Workflow-level cost tracker that manages a CostTracker instance
    and provides callbacks for agents.

    All callbacks handed out by ``get_callback`` share this tracker's single
    CostTracker, so their costs accumulate into one workflow summary.
    """

    def __init__(self, budget_config=None):
        """
        Initialize workflow cost tracker.

        Args:
            budget_config: Optional BudgetConfig for cost limits and alerts
        """
        self.cost_tracker = CostTracker(budget_config=budget_config)

    def get_callback(
        self,
        agent_name: str,
        provider: Optional[str] = None,
    ) -> CostTrackingCallback:
        """
        Get a cost tracking callback for an agent.

        Args:
            agent_name: Name of the agent
            provider: Provider name (auto-detected if None)

        Returns:
            CostTrackingCallback instance bound to this workflow's CostTracker
        """
        return CostTrackingCallback(
            cost_tracker=self.cost_tracker,
            agent_name=agent_name,
            provider=provider,
        )

    def get_summary(self) -> Dict[str, Any]:
        """
        Get cost summary for the workflow.

        Returns:
            Dictionary with cost breakdown (delegates to CostTracker.get_summary)
        """
        return self.cost_tracker.get_summary()

    def format_summary(self) -> str:
        """
        Format cost summary as human-readable string.

        Returns:
            Formatted cost summary (delegates to CostTracker.format_summary)
        """
        return self.cost_tracker.format_summary()

    def reset(self) -> None:
        """Reset cost tracking (delegates to CostTracker.reset)."""
        self.cost_tracker.reset()