Spaces:
Sleeping
Sleeping
File size: 4,084 Bytes
"""Per-workflow API usage tracker.
Listens to every Message produced by the backend and accumulates token counts
and estimated cost per model role. Register with add_message_listener() and
retrieve results via get_rows() when the workflow finishes.
"""
from __future__ import annotations
import threading
from scider.core.types import Message
class UsageTracker:
    """Thread-safe accumulator of token usage and cost per model role.

    Feed every message through :meth:`on_message`; read the aggregated
    figures with :meth:`get_rows`, :attr:`total_tokens` and
    :attr:`total_cost`.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        # role -> {model, input_tokens, output_tokens, calls,
        #          est_input_tokens, est_output_tokens,
        #          cost_override_sum, has_override}
        self._stats: dict[str, dict] = {}

    def on_message(self, msg: Message) -> None:
        """Record the token usage (and exact cost, if reported) of *msg*.

        Messages with a falsy ``llm_sender`` or with neither prompt nor
        completion tokens are ignored.
        """
        input_t = msg.prompt_tokens or 0
        output_t = msg.completion_tokens or 0
        if not msg.llm_sender or not (input_t or output_t):
            return
        role = msg.llm_sender
        # Some senders (e.g. the Claude Agent SDK coding subagent) report an
        # exact cost directly — that price correctly accounts for prompt-cache
        # read/write rates, so prefer it over a flat litellm token estimate.
        cost_override = getattr(msg, "cost_usd", None)
        with self._lock:
            if role not in self._stats:
                # First message for this role: resolve its model id once.
                self._stats[role] = {
                    "model": _model_id_for_role(role),
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "calls": 0,
                    # tokens from messages WITHOUT an exact cost — estimated
                    # later in get_rows()
                    "est_input_tokens": 0,
                    "est_output_tokens": 0,
                    "cost_override_sum": 0.0,
                    "has_override": False,
                }
            s = self._stats[role]
            s["input_tokens"] += input_t
            s["output_tokens"] += output_t
            s["calls"] += 1
            if cost_override is not None:
                s["cost_override_sum"] += cost_override
                s["has_override"] = True
            else:
                s["est_input_tokens"] += input_t
                s["est_output_tokens"] += output_t

    def get_rows(self) -> list[dict]:
        """Return per-role stats sorted by total token usage (descending).

        Each row has the keys ``role``, ``model``, ``input_tokens``,
        ``output_tokens``, ``calls`` and ``cost_usd``. ``cost_usd`` is the sum
        of exact reported costs plus an estimate for the remaining tokens when
        the role reported any exact cost; otherwise it is the estimate alone,
        which is ``None`` when estimation fails.
        """
        # Copy the counters while holding the lock, then price them after
        # releasing it: _estimate_cost() lazily imports litellm and performs a
        # pricing lookup, and on_message() must not be blocked on that.
        with self._lock:
            snapshot = {role: dict(s) for role, s in self._stats.items()}

        rows = []
        for role, s in snapshot.items():
            est = _estimate_cost(s["model"], s["est_input_tokens"], s["est_output_tokens"])
            if s["has_override"]:
                # Exact provider cost + estimate for any non-exact messages.
                cost = s["cost_override_sum"] + (est or 0.0)
            else:
                cost = est
            rows.append(
                {
                    "role": role,
                    "model": s["model"],
                    "input_tokens": s["input_tokens"],
                    "output_tokens": s["output_tokens"],
                    "calls": s["calls"],
                    "cost_usd": cost,
                }
            )
        return sorted(rows, key=lambda r: -(r["input_tokens"] + r["output_tokens"]))

    @property
    def total_tokens(self) -> int:
        """Total input plus output tokens recorded across all roles."""
        with self._lock:
            return sum(s["input_tokens"] + s["output_tokens"] for s in self._stats.values())

    @property
    def total_cost(self) -> float | None:
        """Sum of every role's known cost, or ``None`` if no role has one."""
        rows = self.get_rows()
        costs = [r["cost_usd"] for r in rows if r["cost_usd"] is not None]
        return sum(costs) if costs else None
def _model_id_for_role(role: str) -> str:
    """Return the model id registered for *role*, or ``"unknown"``.

    The id is read from the ``"model"`` key of the role's entry in
    ``ModelRegistry.instance().models``. A missing role, a missing key, or any
    exception raised along the way (including a failed import) yields
    ``"unknown"``.
    """
    fallback = "unknown"
    try:
        # Imported lazily, inside the try, so an import failure also falls
        # back instead of propagating.
        from scider.core.llms import ModelRegistry

        registered = ModelRegistry.instance().models
        entry = registered.get(role, {})
        return entry.get("model", fallback)
    except Exception:
        return fallback
def _estimate_cost(model_id: str, input_tokens: int, output_tokens: int) -> float | None:
    """Estimate the cost of a token count for *model_id* via litellm.

    Returns the sum of the prompt and completion costs reported by
    ``litellm.cost_per_token``, or ``None`` if litellm cannot be imported or
    the lookup raises for any reason.
    """
    try:
        import litellm

        per_side = litellm.cost_per_token(
            model=model_id,
            prompt_tokens=input_tokens,
            completion_tokens=output_tokens,
        )
        input_cost, output_cost = per_side
        total = input_cost + output_cost
    except Exception:
        # Best-effort: any failure means "no estimate available".
        total = None
    return total
|