auto-swe-agent-ui / tracking /cost_tracker.py
DevilBits's picture
fix: enforce safe empty bounds for tracking data charts and match dataframe list alignments
6085b61
"""Cost and token tracking for auto-swe-agent LLM calls."""
from __future__ import annotations

import json
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
@dataclass
class CostConfig:
    """Token pricing for a single model, in USD per 1,000 tokens.

    Attributes:
        model_name: Identifier of the model the prices apply to.
        input_cost_per_1k: USD per 1K input tokens.
        output_cost_per_1k: USD per 1K output tokens.
    """

    model_name: str
    input_cost_per_1k: float
    output_cost_per_1k: float
# Approximate pricing (USD per 1K tokens). Update as providers change rates.
# Each entry is keyed by its own ``model_name`` so the identifier is written
# only once per model.
MODEL_COSTS: dict[str, CostConfig] = {
    pricing.model_name: pricing
    for pricing in (
        CostConfig(
            model_name="gemini/gemini-2.0-flash",
            input_cost_per_1k=0.000075,
            output_cost_per_1k=0.0003,
        ),
        CostConfig(
            model_name="gemini/gemini-2.0-flash-lite",
            input_cost_per_1k=0.0000375,
            output_cost_per_1k=0.00015,
        ),
        CostConfig(
            model_name="groq/llama-3.3-70b-versatile",
            input_cost_per_1k=0.00059,
            output_cost_per_1k=0.00079,
        ),
        CostConfig(
            model_name="groq/llama3-8b-8192",
            input_cost_per_1k=0.00005,
            output_cost_per_1k=0.0001,
        ),
    )
}
@dataclass
class CostRecord:
    """One recorded LLM call: which model, how many tokens, and what it cost.

    Attributes:
        model_used: Name of the model that served the call.
        input_tokens: Number of input tokens for the call.
        output_tokens: Number of output tokens for the call.
        cost_usd: Cost of the call in USD.
        timestamp: When the call was recorded, as a string.
        node_type: Label of the node that made the call, e.g. "planner".
        estimated: True if token counts were estimated, not from the API.
    """

    model_used: str
    input_tokens: int
    output_tokens: int
    cost_usd: float
    timestamp: str
    node_type: str
    estimated: bool
class CostTracker:
    """Accumulate per-call LLM usage records and compare spend to a budget.

    Attributes:
        budget_usd: Spending limit in USD. A value of 0 or less disables
            the budget check.
        records: One ``CostRecord`` per recorded call, in call order.
        total_cost: Running sum of the (unrounded) per-call costs in USD.
    """

    def __init__(self, budget_usd: float = 5.0):
        """Create an empty tracker with the given budget in USD."""
        self.budget_usd = budget_usd
        self.records: list[CostRecord] = []
        self.total_cost: float = 0.0

    def add_call(
        self,
        model_name: str,
        input_tokens: int,
        output_tokens: int,
        node_type: str = "planner",
        estimated: bool = False,
    ) -> float:
        """Record an LLM call and return the cost of this call in USD.

        Args:
            model_name: Key looked up in ``MODEL_COSTS`` to price the call.
            input_tokens: Number of input tokens used.
            output_tokens: Number of output tokens used.
            node_type: Label stored on the record, e.g. "planner".
            estimated: True if the token counts were estimated rather than
                reported by the API.

        Returns:
            The unrounded cost of the call. Models missing from
            ``MODEL_COSTS`` are still recorded, at a cost of 0.0.
        """
        config = MODEL_COSTS.get(model_name)
        if config is not None:
            cost = (input_tokens / 1000) * config.input_cost_per_1k + (
                output_tokens / 1000
            ) * config.output_cost_per_1k
        else:
            cost = 0.0  # unknown model — don't guess

        # datetime.utcnow() is deprecated since Python 3.12. Read an aware UTC
        # time instead and drop the tzinfo so the stored string keeps the same
        # naive ISO-8601 format that existing consumers already parse.
        timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

        record = CostRecord(
            model_used=model_name,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cost_usd=round(cost, 8),
            timestamp=timestamp,
            node_type=node_type,
            estimated=estimated,
        )
        self.records.append(record)
        self.total_cost += cost
        return cost

    def get_total_cost(self) -> float:
        """Return the accumulated cost in USD, rounded to 6 decimal places."""
        return round(self.total_cost, 6)

    def get_total_tokens(self) -> int:
        """Return the sum of input and output tokens over all records."""
        return sum(r.input_tokens + r.output_tokens for r in self.records)

    def get_model_breakdown(self) -> dict[str, dict]:
        """Return per-model totals keyed by model name.

        Each value is a dict with the keys ``"calls"``, ``"tokens"`` and
        ``"cost"`` (USD, rounded to 8 decimal places after each addition).
        """
        breakdown: dict[str, dict] = {}
        for r in self.records:
            m = breakdown.setdefault(
                r.model_used, {"calls": 0, "tokens": 0, "cost": 0.0}
            )
            m["calls"] += 1
            m["tokens"] += r.input_tokens + r.output_tokens
            m["cost"] = round(m["cost"] + r.cost_usd, 8)
        return breakdown

    def get_summary(self) -> dict:
        """Return an aggregate view of all recorded calls.

        ``"most_used_model"`` is the model with the most calls, or ``None``
        when nothing has been recorded yet.
        """
        breakdown = self.get_model_breakdown()
        most_used = (
            max(breakdown, key=lambda m: breakdown[m]["calls"]) if breakdown else None
        )
        return {
            "total_cost_usd": self.get_total_cost(),
            "total_calls": len(self.records),
            "total_tokens": self.get_total_tokens(),
            "model_breakdown": breakdown,
            "most_used_model": most_used,
            "budget_usd": self.budget_usd,
            "budget_exceeded": self.check_budget_exceeded(),
        }

    def check_budget_exceeded(self) -> bool:
        """Return True if total cost is above the budget.

        The check is disabled (always False) when ``budget_usd`` is 0 or
        negative.
        """
        return self.budget_usd > 0 and self.total_cost > self.budget_usd

    def export_json(self, filepath: str | Path) -> None:
        """Write the summary and all records to ``filepath`` as indented JSON.

        Missing parent directories are created first, so exporting into a
        fresh output folder does not fail. The file is written as UTF-8.
        """
        target = Path(filepath)
        target.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "summary": self.get_summary(),
            "records": [asdict(r) for r in self.records],
        }
        target.write_text(json.dumps(payload, indent=2), encoding="utf-8")

    def reset(self) -> None:
        """Drop all records and set the accumulated cost back to zero."""
        self.records.clear()
        self.total_cost = 0.0