Spaces:
Sleeping
Sleeping
| """LangChain callback for automatic cost tracking.""" | |
| import logging | |
| from typing import Any, Dict, List, Optional, Union | |
| from langchain_core.callbacks import BaseCallbackHandler | |
| from langchain_core.messages import BaseMessage | |
| from langchain_core.outputs import LLMResult | |
| from utils.cost_tracker import CostTracker | |
# Module-level logger; the cost-tracking callback reports its diagnostics here.
logger: logging.Logger = logging.getLogger(name=__name__)
class CostTrackingCallback(BaseCallbackHandler):
    """
    LangChain callback handler for tracking LLM API costs.

    When an LLM call finishes, token usage is read from the ``LLMResult`` and
    recorded with ``CostTracker.track_call``. Tracking is best-effort: any
    failure is logged as a warning and swallowed, so cost accounting never
    interrupts the LLM call that triggered it.
    """

    def __init__(
        self,
        cost_tracker: CostTracker,
        agent_name: str,
        provider: Optional[str] = None,
    ):
        """
        Initialize cost tracking callback.

        Args:
            cost_tracker: CostTracker instance to use
            agent_name: Name of the agent making LLM calls
            provider: Provider name, forwarded unchanged to
                ``CostTracker.track_call`` (auto-detected if None)
        """
        super().__init__()
        self.cost_tracker = cost_tracker
        self.agent_name = agent_name
        self.provider = provider

    @staticmethod
    def _extract_token_counts(llm_output: Dict[str, Any]) -> tuple:
        """
        Return ``(input_tokens, output_tokens)`` from an ``llm_output`` dict.

        Usage is looked up under ``token_usage`` first, which is the key this
        callback originally supported (OpenAI/HuggingFace style). It then
        falls back to ``usage``, the key the Anthropic integration uses.
        Within each, both the ``prompt_tokens``/``completion_tokens`` and the
        ``input_tokens``/``output_tokens`` naming are accepted. Missing or
        ``None`` counts are treated as 0.

        The first source with a non-zero count wins. ``(0, 0)`` is returned
        when no source reports any tokens.
        """
        for key in ("token_usage", "usage"):
            usage = llm_output.get(key)
            if not isinstance(usage, dict):
                continue
            input_tokens = usage.get("prompt_tokens") or usage.get("input_tokens") or 0
            output_tokens = (
                usage.get("completion_tokens") or usage.get("output_tokens") or 0
            )
            if input_tokens or output_tokens:
                return int(input_tokens), int(output_tokens)
        return 0, 0

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: Any,
        parent_run_id: Optional[Any] = None,
        **kwargs: Any,
    ) -> Any:
        """
        Track cost when LLM call completes.

        Args:
            response: LLM response with token usage info
            run_id: Unique identifier for this run
            parent_run_id: Parent run ID if nested
            **kwargs: Additional callback arguments

        Returns:
            None. Errors are logged as warnings rather than raised.
        """
        try:
            # Per-call diagnostics go to DEBUG with lazy formatting so they do
            # not flood INFO logs on every LLM call.
            logger.debug("CostTrackingCallback.on_llm_end called for %s", self.agent_name)
            llm_output = response.llm_output or {}
            logger.debug("llm_output keys: %s", list(llm_output.keys()))

            input_tokens, output_tokens = self._extract_token_counts(llm_output)
            logger.debug("Token usage: input=%s, output=%s", input_tokens, output_tokens)

            # "model_name" for OpenAI/Anthropic, "model" for HuggingFace. A key
            # that is present but empty also falls back to "unknown".
            model = llm_output.get("model_name") or llm_output.get("model") or "unknown"
            logger.debug("Model: %s", model)

            # Skip tracking if no tokens (e.g. a cached response, or a provider
            # that reports usage under a key not handled above).
            if input_tokens == 0 and output_tokens == 0:
                logger.warning(f"Skipping tracking for {self.agent_name} - no tokens")
                return

            cost = self.cost_tracker.track_call(
                agent_name=self.agent_name,
                model=model,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                provider=self.provider,
            )
            logger.info(
                f"✓ Cost tracked: {self.agent_name} | {model} | "
                f"{input_tokens} + {output_tokens} tokens | ${cost:.6f}"
            )
        except Exception as e:
            # Deliberate best-effort: a tracking failure must not break the
            # LLM call this callback is attached to.
            logger.warning(f"Failed to track cost for {self.agent_name}: {e}")

    def on_llm_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: Any,
        parent_run_id: Optional[Any] = None,
        **kwargs: Any,
    ) -> Any:
        """
        Handle LLM errors (no cost tracking needed).

        Args:
            error: The error that occurred
            run_id: Unique identifier for this run
            parent_run_id: Parent run ID if nested
            **kwargs: Additional callback arguments
        """
        logger.debug("LLM error in %s: %s", self.agent_name, error)
class WorkflowCostTracker:
    """
    Owns one CostTracker for an entire workflow.

    Hands out per-agent LangChain callbacks that all report into that shared
    tracker. Summary and reset calls are delegated to the tracker.
    """

    def __init__(self, budget_config=None):
        """
        Build the shared CostTracker for this workflow.

        Args:
            budget_config: Optional BudgetConfig passed to CostTracker for
                cost limits and alerts
        """
        shared_tracker = CostTracker(budget_config=budget_config)
        self.cost_tracker = shared_tracker

    def get_callback(
        self,
        agent_name: str,
        provider: Optional[str] = None,
    ) -> CostTrackingCallback:
        """
        Create a callback that records one agent's LLM usage in the shared tracker.

        Args:
            agent_name: Name of the agent
            provider: Provider name (auto-detected if None)

        Returns:
            A new CostTrackingCallback bound to this workflow's tracker
        """
        shared_tracker = self.cost_tracker
        callback = CostTrackingCallback(
            agent_name=agent_name,
            provider=provider,
            cost_tracker=shared_tracker,
        )
        return callback

    def get_summary(self) -> Dict[str, Any]:
        """
        Return the shared tracker's cost breakdown.

        Returns:
            Whatever ``CostTracker.get_summary()`` produces
        """
        summary = self.cost_tracker.get_summary()
        return summary

    def format_summary(self) -> str:
        """
        Return the shared tracker's human-readable cost summary.

        Returns:
            Whatever ``CostTracker.format_summary()`` produces
        """
        text = self.cost_tracker.format_summary()
        return text

    def reset(self):
        """Clear all tracked costs on the shared tracker."""
        tracker = self.cost_tracker
        tracker.reset()