trading-tools / utils /callbacks /cost_tracking_callback.py
Deploy Bot
Deploy Trading Analysis Platform to HuggingFace Spaces
a1bf219
"""LangChain callback for automatic cost tracking."""
import logging
from typing import Any, Dict, List, Optional, Tuple, Union

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult

from utils.cost_tracker import CostTracker
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class CostTrackingCallback(BaseCallbackHandler):
    """
    LangChain callback handler for tracking LLM API costs.

    When an LLM call finishes, this handler pulls the token counts and the
    model name out of the ``LLMResult`` and records them with a shared
    ``CostTracker``. Tracking is best-effort: any failure while extracting
    or recording is logged and never propagated back into the LLM call.
    """

    def __init__(
        self,
        cost_tracker: CostTracker,
        agent_name: str,
        provider: Optional[str] = None,
    ):
        """
        Initialize cost tracking callback.

        Args:
            cost_tracker: CostTracker instance that records each call
            agent_name: Name of the agent making LLM calls; passed to
                ``CostTracker.track_call`` and used in log messages
            provider: Provider name forwarded to ``CostTracker.track_call``;
                when None the tracker is presumably expected to auto-detect
                it (TODO confirm in utils.cost_tracker)
        """
        super().__init__()
        self.cost_tracker = cost_tracker
        self.agent_name = agent_name
        self.provider = provider

    @staticmethod
    def _as_dict(value: Any) -> Dict[str, Any]:
        """Return *value* if it is a dict, otherwise an empty dict.

        Guards against keys that are present but set to None (or to a
        non-dict object), which ``dict.get(key, {})`` does not cover.
        """
        return value if isinstance(value, dict) else {}

    @staticmethod
    def _first_messages(response: LLMResult) -> List[Any]:
        """Return the chat message of the first generation for each prompt.

        ``response.generations`` holds one list of candidate generations per
        prompt. Only the first candidate is used. NOTE(review): some
        providers appear to attach the same request-level usage to every
        candidate when several completions are requested, so summing all
        candidates could over-count. Plain ``Generation`` objects have no
        ``message`` attribute and are skipped.
        """
        messages: List[Any] = []
        for candidates in response.generations or []:
            if not candidates:
                continue
            message = getattr(candidates[0], "message", None)
            if message is not None:
                messages.append(message)
        return messages

    @classmethod
    def _extract_usage(cls, response: LLMResult) -> Tuple[int, int]:
        """
        Return ``(input_tokens, output_tokens)`` reported by *response*.

        Sources are tried in order. The first one that reports any tokens
        wins, so a call is never counted twice:

        1. ``llm_output["token_usage"]`` with ``prompt_tokens`` /
           ``completion_tokens`` (the OpenAI-style keys this callback
           originally read).
        2. ``llm_output["usage"]`` with ``input_tokens`` / ``output_tokens``
           (the shape used by e.g. langchain_anthropic; verify for other
           providers).
        3. ``usage_metadata`` on the chat message of each prompt's first
           generation (LangChain's provider-neutral usage field).

        Returns ``(0, 0)`` when no source reports usage.
        """
        llm_output = cls._as_dict(response.llm_output)

        token_usage = cls._as_dict(llm_output.get("token_usage"))
        input_tokens = token_usage.get("prompt_tokens") or 0
        output_tokens = token_usage.get("completion_tokens") or 0
        if input_tokens or output_tokens:
            return input_tokens, output_tokens

        usage = cls._as_dict(llm_output.get("usage"))
        input_tokens = usage.get("input_tokens") or 0
        output_tokens = usage.get("output_tokens") or 0
        if input_tokens or output_tokens:
            return input_tokens, output_tokens

        input_tokens = output_tokens = 0
        for message in cls._first_messages(response):
            metadata = cls._as_dict(getattr(message, "usage_metadata", None))
            input_tokens += metadata.get("input_tokens") or 0
            output_tokens += metadata.get("output_tokens") or 0
        return input_tokens, output_tokens

    @classmethod
    def _extract_model(cls, response: LLMResult) -> str:
        """
        Return the model name reported by *response*, or ``"unknown"``.

        Checks ``llm_output`` first ("model_name", then "model" as used by
        e.g. HuggingFace wrappers). It then falls back to the
        ``response_metadata`` of each prompt's first chat message, for
        responses that carry no ``llm_output``.
        """
        llm_output = cls._as_dict(response.llm_output)
        model = llm_output.get("model_name") or llm_output.get("model")
        if model:
            return model
        for message in cls._first_messages(response):
            metadata = cls._as_dict(getattr(message, "response_metadata", None))
            model = metadata.get("model_name") or metadata.get("model")
            if model:
                return model
        return "unknown"

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: Any,
        parent_run_id: Optional[Any] = None,
        **kwargs: Any,
    ) -> Any:
        """
        Track cost when LLM call completes.

        Args:
            response: LLM response carrying token usage info
            run_id: Unique identifier for this run
            parent_run_id: Parent run ID if nested
            **kwargs: Additional callback arguments
        """
        try:
            input_tokens, output_tokens = self._extract_usage(response)
            model = self._extract_model(response)
            logger.debug(
                "CostTrackingCallback.on_llm_end for %s: model=%s input=%s output=%s",
                self.agent_name,
                model,
                input_tokens,
                output_tokens,
            )

            # Zero usage is expected for some responses (e.g. cached ones),
            # so this is not logged as a warning.
            if input_tokens == 0 and output_tokens == 0:
                logger.debug("Skipping tracking for %s - no tokens", self.agent_name)
                return

            cost = self.cost_tracker.track_call(
                agent_name=self.agent_name,
                model=model,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                provider=self.provider,
            )
            logger.info(
                "✓ Cost tracked: %s | %s | %s + %s tokens | $%.6f",
                self.agent_name,
                model,
                input_tokens,
                output_tokens,
                cost,
            )
        except Exception as e:
            # Best-effort bookkeeping: never let cost tracking break the LLM
            # call, but keep the traceback so failures are diagnosable.
            logger.warning(
                "Failed to track cost for %s: %s", self.agent_name, e, exc_info=True
            )

    def on_llm_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: Any,
        parent_run_id: Optional[Any] = None,
        **kwargs: Any,
    ) -> Any:
        """
        Handle LLM errors; nothing is recorded with the cost tracker.

        Args:
            error: The error that occurred
            run_id: Unique identifier for this run
            parent_run_id: Parent run ID if nested
            **kwargs: Additional callback arguments
        """
        logger.debug("LLM error in %s: %s", self.agent_name, error)
class WorkflowCostTracker:
    """
    Holds one CostTracker for a whole workflow run and hands out per-agent
    CostTrackingCallback objects that all report into that same tracker.
    """

    def __init__(self, budget_config=None):
        """
        Create the shared tracker for this workflow.

        Args:
            budget_config: Optional BudgetConfig passed straight through to
                CostTracker (cost limits and alerts)
        """
        tracker = CostTracker(budget_config=budget_config)
        self.cost_tracker = tracker

    def get_callback(
        self,
        agent_name: str,
        provider: Optional[str] = None,
    ) -> CostTrackingCallback:
        """
        Build a callback bound to this workflow's shared tracker.

        Args:
            agent_name: Agent label attached to every call the callback records
            provider: Provider name forwarded to the callback (None leaves the
                choice to the tracker)

        Returns:
            A new CostTrackingCallback that writes into ``self.cost_tracker``
        """
        shared_tracker = self.cost_tracker
        callback = CostTrackingCallback(
            agent_name=agent_name,
            provider=provider,
            cost_tracker=shared_tracker,
        )
        return callback

    def get_summary(self) -> Dict[str, Any]:
        """
        Return the tracker's cost breakdown for this workflow.

        Returns:
            Whatever ``CostTracker.get_summary`` produces
        """
        summary = self.cost_tracker.get_summary()
        return summary

    def format_summary(self) -> str:
        """
        Return the tracker's human-readable cost report.

        Returns:
            Whatever ``CostTracker.format_summary`` produces
        """
        report = self.cost_tracker.format_summary()
        return report

    def reset(self):
        """Clear all tracked costs by resetting the shared tracker."""
        tracker = self.cost_tracker
        tracker.reset()