Spaces:
Sleeping
Sleeping
| """Cost and token tracking for auto-swe-agent LLM calls.""" | |
| from __future__ import annotations | |
| import json | |
| from dataclasses import asdict, dataclass, field | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Optional | |
@dataclass
class CostConfig:
    """Per-model pricing used to convert token counts into a USD cost.

    Declared as a dataclass so that instances can be built positionally
    (``CostConfig(name, input_rate, output_rate)``), which is how the
    module-level pricing table constructs them.
    """

    model_name: str  # model identifier, e.g. "groq/llama3-8b-8192"
    input_cost_per_1k: float  # USD per 1K input tokens
    output_cost_per_1k: float  # USD per 1K output tokens
# Approximate pricing in USD per 1K tokens (input rate, then output rate).
# Update as providers change rates. Each entry is keyed by its own
# ``model_name`` so the lookup key and the stored name cannot drift apart.
MODEL_COSTS: dict[str, CostConfig] = {
    entry.model_name: entry
    for entry in (
        CostConfig("gemini/gemini-2.0-flash", 0.000075, 0.0003),
        CostConfig("gemini/gemini-2.0-flash-lite", 0.0000375, 0.00015),
        CostConfig("groq/llama-3.3-70b-versatile", 0.00059, 0.00079),
        CostConfig("groq/llama3-8b-8192", 0.00005, 0.0001),
    )
}
@dataclass
class CostRecord:
    """One recorded LLM call: which model ran, its token usage, and its cost.

    Declared as a dataclass so that records can be created with keyword
    arguments and serialised with ``dataclasses.asdict``.
    """

    model_used: str  # model identifier the call was made against
    input_tokens: int  # number of input tokens for the call
    output_tokens: int  # number of output tokens for the call
    cost_usd: float  # cost of this call in USD
    timestamp: str  # ISO-8601 string recorded when the call was logged
    node_type: str  # e.g. "planner"
    estimated: bool  # True if token counts were estimated, not from API
class CostTracker:
    """Accumulate per-call token usage and USD cost for LLM calls.

    Each call is stored as a ``CostRecord`` in ``records``; ``total_cost``
    keeps the running (unrounded) sum so that budget checks are not affected
    by per-record rounding.
    """

    def __init__(self, budget_usd: float = 5.0):
        """Create an empty tracker.

        Args:
            budget_usd: Spending limit in USD. A value of 0 or less disables
                the budget check.
        """
        self.budget_usd = budget_usd
        self.records: list[CostRecord] = []
        self.total_cost: float = 0.0

    def add_call(
        self,
        model_name: str,
        input_tokens: int,
        output_tokens: int,
        node_type: str = "planner",
        estimated: bool = False,
    ) -> float:
        """Record an LLM call. Returns the cost of this call in USD.

        Args:
            model_name: Key looked up in ``MODEL_COSTS``. Unknown models are
                still recorded, but with a cost of 0.0.
            input_tokens: Number of input tokens for the call.
            output_tokens: Number of output tokens for the call.
            node_type: Label stored on the record (e.g. "planner").
            estimated: True if the token counts were estimated rather than
                reported by the API.

        Returns:
            The unrounded cost of this call. The stored record keeps the
            cost rounded to 8 decimal places.
        """
        # Local import: the file's top-level imports only bring in `datetime`.
        from datetime import timezone

        config = MODEL_COSTS.get(model_name)
        if config is not None:
            cost = (input_tokens / 1000) * config.input_cost_per_1k + (
                output_tokens / 1000
            ) * config.output_cost_per_1k
        else:
            cost = 0.0  # unknown model — don't guess

        # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
        # "now" and drop the tzinfo so the stored string keeps the same naive
        # ISO-8601 format as before.
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)

        record = CostRecord(
            model_used=model_name,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cost_usd=round(cost, 8),
            timestamp=now_utc.isoformat(),
            node_type=node_type,
            estimated=estimated,
        )
        self.records.append(record)
        self.total_cost += cost
        return cost

    def get_total_cost(self) -> float:
        """Return the accumulated cost in USD, rounded to 6 decimal places."""
        return round(self.total_cost, 6)

    def get_total_tokens(self) -> int:
        """Return the sum of input and output tokens over all records."""
        return sum(r.input_tokens + r.output_tokens for r in self.records)

    def get_model_breakdown(self) -> dict[str, dict]:
        """Return per-model totals.

        Returns:
            A dict mapping each model name to
            ``{"calls": int, "tokens": int, "cost": float}``, in order of
            first use.
        """
        breakdown: dict[str, dict] = {}
        for r in self.records:
            m = breakdown.setdefault(
                r.model_used, {"calls": 0, "tokens": 0, "cost": 0.0}
            )
            m["calls"] += 1
            m["tokens"] += r.input_tokens + r.output_tokens
            # Re-round after every addition to keep float noise out of the
            # reported per-model cost.
            m["cost"] = round(m["cost"] + r.cost_usd, 8)
        return breakdown

    def get_summary(self) -> dict:
        """Return an overall summary of cost, usage, and budget state.

        ``most_used_model`` is the model with the most recorded calls (the
        first one encountered on ties), or None when nothing was recorded.
        """
        breakdown = self.get_model_breakdown()
        most_used = (
            max(breakdown, key=lambda m: breakdown[m]["calls"]) if breakdown else None
        )
        return {
            "total_cost_usd": self.get_total_cost(),
            "total_calls": len(self.records),
            "total_tokens": self.get_total_tokens(),
            "model_breakdown": breakdown,
            "most_used_model": most_used,
            "budget_usd": self.budget_usd,
            "budget_exceeded": self.check_budget_exceeded(),
        }

    def check_budget_exceeded(self) -> bool:
        """Return True if total cost has exceeded the budget.

        The check is disabled (always False) when ``budget_usd`` is 0 or
        negative.
        """
        return self.budget_usd > 0 and self.total_cost > self.budget_usd

    def export_json(self, filepath: str | Path) -> None:
        """Write the summary and all records to ``filepath`` as JSON.

        The payload has two keys: ``"summary"`` (see ``get_summary``) and
        ``"records"`` (one dict per recorded call).
        """
        payload = {
            "summary": self.get_summary(),
            "records": [asdict(r) for r in self.records],
        }
        # Explicit encoding so the output does not depend on the platform's
        # default locale encoding.
        Path(filepath).write_text(json.dumps(payload, indent=2), encoding="utf-8")

    def reset(self) -> None:
        """Discard all records and reset the running total to zero."""
        self.records.clear()
        self.total_cost = 0.0