IgorSlinko committed on
Commit f7c61dd · 1 Parent(s): 81a982c

Unify token calculation using calculate_routing_tokens


- Rewrite load_all_trajectories_calculated() to use calculate_routing_tokens
- Remove obsolete calculate_tokens_from_trajectory() function
- Remove obsolete calculate_routed_cost() function
- Single source of truth for token calculation logic
- prompt_tokens = cache_read + uncached_input (mathematically equivalent; see the sketch below)
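
A minimal sketch of the equivalence the last bullet relies on; the per-call numbers are invented, and the field names mirror the per-model totals that calculate_routing_tokens returns later in this diff:

# Hypothetical per-call numbers: each API call's prompt splits into the part
# served from cache (cache_read) and the part sent uncached (uncached_input).
calls = [
    {"cache_read": 0, "uncached_input": 1200},     # first call: nothing cached yet
    {"cache_read": 1200, "uncached_input": 300},   # second call: prior context cached
    {"cache_read": 1500, "uncached_input": 250},   # third call
]

cache_read = sum(c["cache_read"] for c in calls)           # 2700
uncached_input = sum(c["uncached_input"] for c in calls)   # 1750
prompt_tokens = cache_read + uncached_input                # 4450

# Summing the full prompt of every call gives the same number, so
# prompt_tokens == cache_read + uncached_input.
assert prompt_tokens == sum(c["cache_read"] + c["uncached_input"] for c in calls)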

Files changed (1)
  1. app.py +31 -173
app.py CHANGED
@@ -73,85 +73,6 @@ def get_routed_steps(total_steps: int, strategy: str, params: dict) -> set:
     return routed
 
 
-def calculate_routed_cost(
-    trajectory_tokens: dict,
-    routed_steps: set,
-    base_prices: dict,
-    routing_prices: dict,
-) -> dict:
-    """
-    Calculate cost for a trajectory with routing.
-
-    Each model maintains its own independent cache.
-    When switching back to a model, its cache is still available.
-
-    Args:
-        trajectory_tokens: dict with per-step token counts
-        routed_steps: set of step indices using routing model
-        base_prices: {input, cache_read, cache_creation, completion} for base model
-        routing_prices: same for routing model
-
-    Returns:
-        dict with base_cost, routing_cost, total_cost
-    """
-    total_steps = trajectory_tokens.get("api_calls", 0)
-    if total_steps == 0:
-        return {"base_cost": 0, "routing_cost": 0, "total_cost": 0}
-
-    prompt_tokens = trajectory_tokens.get("prompt_tokens", 0)
-    completion_tokens = trajectory_tokens.get("completion_tokens", 0)
-    cache_read = trajectory_tokens.get("cache_read_tokens", 0)
-    cache_creation = trajectory_tokens.get("cache_creation_tokens", 0)
-
-    avg_prompt_per_step = prompt_tokens / total_steps if total_steps > 0 else 0
-    avg_completion_per_step = completion_tokens / total_steps if total_steps > 0 else 0
-    avg_cache_read_per_step = cache_read / total_steps if total_steps > 0 else 0
-    avg_cache_creation_per_step = cache_creation / total_steps if total_steps > 0 else 0
-
-    base_cost = 0
-    routing_cost = 0
-
-    base_cache_context = 0
-    routing_cache_context = 0
-
-    for step in range(total_steps):
-        is_routed = step in routed_steps
-        prices = routing_prices if is_routed else base_prices
-
-        if is_routed:
-            cache_ctx = routing_cache_context
-        else:
-            cache_ctx = base_cache_context
-
-        uncached_input = avg_prompt_per_step - avg_cache_read_per_step
-        if cache_ctx == 0:
-            step_cache_read = 0
-            step_uncached = avg_prompt_per_step
-        else:
-            step_cache_read = avg_cache_read_per_step
-            step_uncached = uncached_input
-
-        step_cost = (
-            step_uncached * prices["input"] / 1e6 +
-            step_cache_read * prices["cache_read"] / 1e6 +
-            avg_cache_creation_per_step * prices["cache_creation"] / 1e6 +
-            avg_completion_per_step * prices["completion"] / 1e6
-        )
-
-        if is_routed:
-            routing_cost += step_cost
-            routing_cache_context += avg_prompt_per_step + avg_completion_per_step
-        else:
-            base_cost += step_cost
-            base_cache_context += avg_prompt_per_step + avg_completion_per_step
-
-    return {
-        "base_cost": base_cost,
-        "routing_cost": routing_cost,
-        "total_cost": base_cost + routing_cost,
-    }
-
-
 def calculate_routing_tokens(steps: list[dict]) -> dict:
     """
     Calculate token breakdown per model with proper caching simulation.
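The docstring above is cut off at the hunk boundary. Judging from how the rewritten load_all_trajectories_calculated consumes its result further down in this diff, calculate_routing_tokens appears to return per-model totals keyed by model name. A hedged sketch of that assumed shape (the model name and numbers are invented):

# Assumed return shape, inferred from the totals.get(...) calls later in this diff.
model_totals = {
    "claude-sonnet-4": {          # hypothetical model name
        "uncached_input": 1750,
        "cache_read": 2700,
        "cache_creation": 1900,
        "completion": 800,
    },
}

totals = model_totals.get("claude-sonnet-4", {})
prompt_tokens = totals.get("cache_read", 0) + totals.get("uncached_input", 0)
total_tokens = prompt_tokens + totals.get("completion", 0)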
@@ -309,78 +230,6 @@ def get_tokenizer(model_name: str):
     return lambda text: len(enc.encode(text)), tokenizer_name
 
 
-def calculate_tokens_from_trajectory(traj_path: Path, model_name: str) -> dict:
-    """
-    Calculate tokens from trajectory messages simulating API behavior.
-
-    API counts prompt_tokens cumulatively for each call (full context each time).
-    With caching: cache_read = previous context, cache_creation = new content.
-
-    Returns dict with:
-    - prompt_tokens: total input tokens (cumulative across all API calls)
-    - completion_tokens: total output tokens
-    - cache_read_tokens: tokens read from cache
-    - cache_creation_tokens: tokens written to cache
-    - api_calls: number of assistant responses
-    """
-    with open(traj_path, "r", encoding="utf-8") as f:
-        data = json.load(f)
-
-    messages = data.get("messages", [])
-    if not messages:
-        return {"prompt_tokens": 0, "completion_tokens": 0, "cache_read_tokens": 0, "cache_creation_tokens": 0, "api_calls": 0}
-
-    count_tokens, _ = get_tokenizer(model_name)
-
-    message_tokens = []
-    for msg in messages:
-        content = msg.get("content", "")
-        if isinstance(content, list):
-            content = json.dumps(content)
-        tokens = count_tokens(str(content))
-        message_tokens.append({
-            "role": msg.get("role", "user"),
-            "tokens": tokens
-        })
-
-    # Simulate API behavior: each call sends full context
-    # LLM APIs cache full context including assistant responses
-    prompt_tokens = 0  # Cumulative prompt tokens across all API calls
-    completion_tokens = 0
-    cache_read_tokens = 0
-    cache_creation_tokens = 0
-    api_calls = 0
-
-    context_so_far = 0  # Total tokens in context (including assistant responses)
-    cached_context = 0  # Tokens that are cached from previous API calls
-
-    for i, mt in enumerate(message_tokens):
-        if mt["role"] == "assistant":
-            completion_tokens += mt["tokens"]
-            api_calls += 1
-            context_so_far += mt["tokens"]
-        else:
-            context_so_far += mt["tokens"]
-
-        next_is_assistant = (i + 1 < len(message_tokens) and message_tokens[i + 1]["role"] == "assistant")
-
-        if next_is_assistant:
-            prompt_tokens += context_so_far
-            cache_read_tokens += cached_context
-
-            assistant_tokens = message_tokens[i + 1]["tokens"]
-            cache_creation_tokens += (context_so_far - cached_context) + assistant_tokens
-            cached_context = context_so_far + assistant_tokens
-
-    return {
-        "prompt_tokens": prompt_tokens,
-        "completion_tokens": completion_tokens,
-        "cache_read_tokens": cache_read_tokens,
-        "cache_creation_tokens": cache_creation_tokens,
-        "api_calls": api_calls,
-    }
-
-
 def apply_thinking_overhead(df: pd.DataFrame, overhead: float) -> pd.DataFrame:
     """Apply tokenizer overhead multiplier to all token counts"""
     if df.empty or overhead == 1.0:
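The removed function's docstring describes the accounting it simulated: each API call's full context counts toward prompt_tokens, the previously cached context counts as cache_read, and the new content plus the assistant reply counts as cache_creation. A small self-contained re-derivation of that arithmetic with invented message sizes (not the author's code):

# Invented token counts for a two-call trajectory:
# system (500), user (200), assistant (300), tool result (150), assistant (250).
sizes = [("system", 500), ("user", 200), ("assistant", 300), ("user", 150), ("assistant", 250)]

prompt_tokens = cache_read = cache_creation = completion = 0
context = cached = 0
for i, (role, n) in enumerate(sizes):
    context += n
    if role == "assistant":
        completion += n
    elif i + 1 < len(sizes) and sizes[i + 1][0] == "assistant":
        # An API call happens here: the whole context is the prompt, the
        # previously cached prefix is read, the rest plus the reply is written.
        prompt_tokens += context
        cache_read += cached
        cache_creation += (context - cached) + sizes[i + 1][1]
        cached = context + sizes[i + 1][1]

print(prompt_tokens, completion, cache_read, cache_creation)  # 1850 550 1000 1400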
@@ -407,15 +256,16 @@ def apply_no_cache(df: pd.DataFrame) -> pd.DataFrame:
 
 
 def load_all_trajectories_calculated(folder: str) -> pd.DataFrame:
-    """Load trajectories with self-calculated token counts"""
+    """Load trajectories with self-calculated token counts using calculate_routing_tokens"""
     global _calculated_tokens_cache
-
+
     cache_key = f"calculated_{folder}"
     if cache_key in _calculated_tokens_cache:
         return _calculated_tokens_cache[cache_key]
-
+
+    trajectory_steps = load_all_trajectory_steps(folder)
+
     output_dir = TRAJS_DIR / folder
-
     traj_files = list(output_dir.glob("*/*.traj.json"))
     if not traj_files:
         traj_files = list(output_dir.glob("*/*.traj"))
@@ -423,10 +273,7 @@ def load_all_trajectories_calculated(folder: str) -> pd.DataFrame:
     traj_files = list(output_dir.glob("*.traj.json"))
     if not traj_files:
         traj_files = list(output_dir.glob("*.traj"))
-    if not traj_files:
-        traj_files = list(output_dir.glob("*.json"))
-
-    # Get model name from first trajectory
+
     model_name = ""
     if traj_files:
         try:
@@ -436,26 +283,37 @@ def load_all_trajectories_calculated(folder: str) -> pd.DataFrame:
             model_name = config.get("cost_calc_model_override", config.get("model_name", ""))
         except Exception:
             pass
-
+
     rows = []
-    for traj_path in traj_files:
+    for instance_id, steps in trajectory_steps.items():
+        if not steps:
+            continue
+
         try:
-            tokens = calculate_tokens_from_trajectory(traj_path, model_name)
-
+            model_totals = calculate_routing_tokens(steps)
+            totals = model_totals.get(model_name, {})
+
+            cache_read = totals.get("cache_read", 0)
+            uncached_input = totals.get("uncached_input", 0)
+            completion = totals.get("completion", 0)
+            cache_creation = totals.get("cache_creation", 0)
+
+            prompt_tokens = cache_read + uncached_input
+
             rows.append({
-                "instance_id": traj_path.stem.replace(".traj", ""),
+                "instance_id": instance_id,
                 "model_name": model_name,
-                "api_calls": tokens["api_calls"],
-                "instance_cost": 0,  # Will be calculated from prices
-                "prompt_tokens": tokens["prompt_tokens"],
-                "completion_tokens": tokens["completion_tokens"],
-                "total_tokens": tokens["prompt_tokens"] + tokens["completion_tokens"],
-                "cache_read_tokens": tokens["cache_read_tokens"],
-                "cache_creation_tokens": tokens["cache_creation_tokens"],
+                "api_calls": len(steps),
+                "instance_cost": 0,
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion,
+                "total_tokens": prompt_tokens + completion,
+                "cache_read_tokens": cache_read,
+                "cache_creation_tokens": cache_creation,
             })
         except Exception as e:
-            print(f"Error calculating tokens for {traj_path}: {e}")
-
+            print(f"Error calculating tokens for {instance_id}: {e}")
+
     df = pd.DataFrame(rows)
     _calculated_tokens_cache[cache_key] = df
     return df
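
For orientation, a hedged sketch of how the rewritten loader might be called; the folder name is hypothetical, and the column list is taken from the rows built above:

df = load_all_trajectories_calculated("example_run")  # hypothetical folder under TRAJS_DIR

# One row per instance, with the token columns assembled in the loop above.
print(df[["instance_id", "api_calls", "prompt_tokens", "completion_tokens",
          "cache_read_tokens", "cache_creation_tokens", "total_tokens"]].head())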