Spaces:
Sleeping
Sleeping
File size: 5,294 Bytes
a1bf219 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
"""LangChain callback for automatic cost tracking."""
import logging
from typing import Any, Dict, List, Optional, Union
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult
from utils.cost_tracker import CostTracker
# Module-level logger named after this module, so its output can be filtered
# or redirected through the standard logging configuration.
logger = logging.getLogger(__name__)
class CostTrackingCallback(BaseCallbackHandler):
    """
    LangChain callback handler for tracking LLM API costs.

    When an LLM call finishes, the handler reads the token counts and model
    name from the response and forwards them to ``CostTracker.track_call``
    under the configured agent name. Tracking is best-effort: any failure is
    logged as a warning and never propagated to the caller.
    """

    def __init__(
        self,
        cost_tracker: CostTracker,
        agent_name: str,
        provider: Optional[str] = None,
    ):
        """
        Initialize cost tracking callback.

        Args:
            cost_tracker: CostTracker instance to use
            agent_name: Name of the agent making LLM calls
            provider: Provider name, passed through unchanged to
                ``track_call`` (auto-detected if None)
        """
        super().__init__()
        self.cost_tracker = cost_tracker
        self.agent_name = agent_name
        self.provider = provider

    @staticmethod
    def _counts_from(usage: Any) -> tuple:
        """
        Read ``(input_tokens, output_tokens)`` from one usage mapping.

        Accepts both ``prompt_tokens``/``completion_tokens`` and
        ``input_tokens``/``output_tokens`` spellings. Anything that is not a
        dict, and any missing or ``None`` count, is treated as zero.
        """
        if not isinstance(usage, dict):
            return 0, 0
        prompt = usage.get("prompt_tokens") or usage.get("input_tokens") or 0
        completion = (
            usage.get("completion_tokens") or usage.get("output_tokens") or 0
        )
        return prompt, completion

    @staticmethod
    def _messages(response: LLMResult) -> List[Any]:
        """Collect the ``message`` attribute of every generation that has one."""
        found: List[Any] = []
        for candidates in getattr(response, "generations", None) or []:
            for generation in candidates:
                message = getattr(generation, "message", None)
                if message is not None:
                    found.append(message)
        return found

    def _extract_token_usage(
        self, llm_output: Dict[str, Any], messages: List[Any]
    ) -> tuple:
        """
        Return ``(input_tokens, output_tokens)`` for the finished call.

        ``llm_output["token_usage"]`` is consulted first, so responses that
        already worked keep producing the same numbers. If it is absent or
        empty, ``llm_output["usage"]`` and then each message's
        ``usage_metadata`` are tried; the first source with a non-zero count
        wins. Only one source is used, never a sum, so the same usage repeated
        on several generations is not counted twice.
        """
        # NOTE(review): the fallback keys follow LangChain's usage_metadata
        # convention; confirm against the provider integrations in use.
        for key in ("token_usage", "usage"):
            counts = self._counts_from(llm_output.get(key))
            if any(counts):
                return counts
        for message in messages:
            counts = self._counts_from(getattr(message, "usage_metadata", None))
            if any(counts):
                return counts
        return 0, 0

    def _extract_model(
        self, llm_output: Dict[str, Any], messages: List[Any]
    ) -> str:
        """
        Return the model name, or ``"unknown"`` when none is reported.

        Checks ``model_name`` then ``model`` in ``llm_output``, then the same
        keys in each message's ``response_metadata``.
        """
        model = llm_output.get("model_name") or llm_output.get("model")
        if model:
            return model
        for message in messages:
            metadata = getattr(message, "response_metadata", None)
            if isinstance(metadata, dict):
                model = metadata.get("model_name") or metadata.get("model")
                if model:
                    return model
        return "unknown"

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: Any,
        parent_run_id: Optional[Any] = None,
        **kwargs: Any,
    ) -> Any:
        """
        Track cost when LLM call completes.

        Args:
            response: LLM response with token usage info
            run_id: Unique identifier for this run
            parent_run_id: Parent run ID if nested
            **kwargs: Additional callback arguments

        Returns:
            None. Calls reporting zero tokens are skipped with a warning.
        """
        try:
            logger.info(
                "CostTrackingCallback.on_llm_end called for %s", self.agent_name
            )

            # llm_output may be None; normalise so the lookups below are safe.
            llm_output = response.llm_output or {}
            logger.info("llm_output keys: %s", list(llm_output.keys()))

            messages = self._messages(response)
            input_tokens, output_tokens = self._extract_token_usage(
                llm_output, messages
            )
            logger.info(
                "Token usage: input=%s, output=%s", input_tokens, output_tokens
            )

            model = self._extract_model(llm_output, messages)
            logger.info("Model: %s", model)

            # Skip tracking if no tokens (e.g., cached response)
            if not (input_tokens or output_tokens):
                logger.warning(
                    "Skipping tracking for %s - no tokens", self.agent_name
                )
                return

            cost = self.cost_tracker.track_call(
                agent_name=self.agent_name,
                model=model,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                provider=self.provider,
            )
            logger.info(
                "✓ Cost tracked: %s | %s | %s + %s tokens | $%.6f",
                self.agent_name,
                model,
                input_tokens,
                output_tokens,
                cost,
            )
        except Exception as e:
            # Best-effort by design: a tracking failure must not break the LLM
            # call itself. Keep the traceback so the cause can be diagnosed.
            logger.warning(
                "Failed to track cost for %s: %s",
                self.agent_name,
                e,
                exc_info=True,
            )

    def on_llm_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: Any,
        parent_run_id: Optional[Any] = None,
        **kwargs: Any,
    ) -> Any:
        """
        Handle LLM errors (no cost tracking needed).

        Only logs the error at debug level; nothing is recorded in the
        cost tracker.

        Args:
            error: The error that occurred
            run_id: Unique identifier for this run
            parent_run_id: Parent run ID if nested
            **kwargs: Additional callback arguments
        """
        logger.debug("LLM error in %s: %s", self.agent_name, error)
class WorkflowCostTracker:
    """
    Owns one CostTracker for a whole workflow and hands out per-agent
    callbacks that all report into it.
    """

    def __init__(self, budget_config=None):
        """
        Create the underlying CostTracker.

        Args:
            budget_config: Optional BudgetConfig, forwarded as-is to CostTracker
        """
        tracker = CostTracker(budget_config=budget_config)
        self.cost_tracker = tracker

    def get_callback(
        self,
        agent_name: str,
        provider: Optional[str] = None,
    ) -> CostTrackingCallback:
        """
        Build a callback bound to this workflow's tracker.

        Args:
            agent_name: Name of the agent the callback reports for
            provider: Provider name (auto-detected if None)

        Returns:
            A new CostTrackingCallback sharing this workflow's CostTracker
        """
        callback_kwargs = {
            "cost_tracker": self.cost_tracker,
            "agent_name": agent_name,
            "provider": provider,
        }
        return CostTrackingCallback(**callback_kwargs)

    def get_summary(self) -> Dict[str, Any]:
        """
        Return the tracker's cost summary.

        Returns:
            Whatever ``CostTracker.get_summary()`` produces
        """
        summary = self.cost_tracker.get_summary()
        return summary

    def format_summary(self) -> str:
        """
        Return the tracker's cost summary rendered as text.

        Returns:
            Whatever ``CostTracker.format_summary()`` produces
        """
        rendered = self.cost_tracker.format_summary()
        return rendered

    def reset(self):
        """Reset the underlying tracker by delegating to ``CostTracker.reset()``."""
        tracker = self.cost_tracker
        tracker.reset()
|