IgorSlinko committed on
Commit
81a982c
·
1 Parent(s): 5c06e74

Integrate calculate_routing_tokens for accurate multi-model caching

Browse files

- Add calculate_routing_tokens() function for precise token tracking
- Add parse_trajectory_to_steps() to extract step data from trajectories
- Add load_all_trajectory_steps() with caching for routing calculations
- Rewrite run_routing() to use step-by-step token calculation
- Each model maintains independent cache context
- Proper handling of system/user, completion, and observation tokens
- Accurate uncached_input and cache_creation per step

Files changed (1) hide show
  1. app.py +245 -75
app.py CHANGED
@@ -26,6 +26,7 @@ LITELLM_PRICES_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/mod
26
  _litellm_prices_cache = None
27
  _trajectories_cache = {}
28
  _calculated_tokens_cache = {}
 
29
 
30
 
31
  def parse_step_or_ratio(value: float, total_steps: int) -> int:
@@ -151,6 +152,127 @@ def calculate_routed_cost(
151
  }
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def get_default_overhead(model_name: str) -> float:
155
  """Get default tokenizer overhead for model provider"""
156
  model_lower = model_name.lower() if model_name else ""
@@ -339,6 +461,55 @@ def load_all_trajectories_calculated(folder: str) -> pd.DataFrame:
339
  return df
340
 
341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
  def get_litellm_model_list() -> list[str]:
343
  """Get list of model names from litellm prices"""
344
  prices = get_litellm_prices()
@@ -1468,21 +1639,15 @@ def build_app():
1468
  )
1469
  return
1470
 
1471
- df_key = "meta" if source == "Metadata" else "calculated"
1472
- df = state_data.get(df_key)
1473
- if df is None or df.empty:
1474
  yield (
1475
- gr.update(visible=True, value="❌ No trajectory data available."),
1476
  gr.update(visible=False),
1477
  None, None,
1478
  )
1479
  return
1480
 
1481
- if source == "Calculated":
1482
- df = apply_thinking_overhead(df.copy(), overhead)
1483
- if not with_cache:
1484
- df = apply_no_cache(df)
1485
-
1486
  base_prices = {
1487
  "input": base_input,
1488
  "cache_read": base_cache_read,
@@ -1505,74 +1670,78 @@ def build_app():
1505
  strategy_params["start"] = start_1_val
1506
  strategy_params["end"] = end_1_val
1507
 
1508
- total_base_cost = 0
1509
- total_routing_cost = 0
1510
- total_original_cost = 0
1511
 
1512
- base_tokens = {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0}
1513
- routing_tokens = {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0}
1514
- base_costs = {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0}
1515
- routing_costs = {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0}
1516
 
1517
- for _, row in df.iterrows():
1518
- total_steps = int(row.get("api_calls", 0))
1519
- if total_steps == 0:
1520
  continue
1521
 
1522
- routed_steps = get_routed_steps(total_steps, strategy_1_val, strategy_params)
1523
- num_base_steps = total_steps - len(routed_steps)
1524
- num_routing_steps = len(routed_steps)
1525
-
1526
- prompt_tokens = row.get("prompt_tokens", 0)
1527
- completion_tokens = row.get("completion_tokens", 0)
1528
- cache_read_tokens = row.get("cache_read_tokens", 0)
1529
- cache_creation_tokens = row.get("cache_creation_tokens", 0)
1530
- uncached_input_tokens = prompt_tokens - cache_read_tokens - cache_creation_tokens
1531
- if uncached_input_tokens < 0:
1532
- uncached_input_tokens = 0
1533
-
1534
- base_ratio = num_base_steps / total_steps if total_steps > 0 else 0
1535
- routing_ratio = num_routing_steps / total_steps if total_steps > 0 else 0
1536
-
1537
- base_tokens["uncached_input"] += uncached_input_tokens * base_ratio
1538
- base_tokens["cache_read"] += cache_read_tokens * base_ratio
1539
- base_tokens["cache_creation"] += cache_creation_tokens * base_ratio
1540
- base_tokens["completion"] += completion_tokens * base_ratio
1541
-
1542
- routing_tokens["uncached_input"] += uncached_input_tokens * routing_ratio
1543
- routing_tokens["cache_read"] += cache_read_tokens * routing_ratio
1544
- routing_tokens["cache_creation"] += cache_creation_tokens * routing_ratio
1545
- routing_tokens["completion"] += completion_tokens * routing_ratio
1546
-
1547
- base_costs["uncached_input"] += uncached_input_tokens * base_ratio * base_prices["input"] / 1e6
1548
- base_costs["cache_read"] += cache_read_tokens * base_ratio * base_prices["cache_read"] / 1e6
1549
- base_costs["cache_creation"] += cache_creation_tokens * base_ratio * base_prices["cache_creation"] / 1e6
1550
- base_costs["completion"] += completion_tokens * base_ratio * base_prices["completion"] / 1e6
1551
-
1552
- routing_costs["uncached_input"] += uncached_input_tokens * routing_ratio * routing_prices["input"] / 1e6
1553
- routing_costs["cache_read"] += cache_read_tokens * routing_ratio * routing_prices["cache_read"] / 1e6
1554
- routing_costs["cache_creation"] += cache_creation_tokens * routing_ratio * routing_prices["cache_creation"] / 1e6
1555
- routing_costs["completion"] += completion_tokens * routing_ratio * routing_prices["completion"] / 1e6
1556
-
1557
- traj_tokens = {
1558
- "api_calls": total_steps,
1559
- "prompt_tokens": prompt_tokens,
1560
- "completion_tokens": completion_tokens,
1561
- "cache_read_tokens": cache_read_tokens,
1562
- "cache_creation_tokens": cache_creation_tokens,
1563
- }
1564
-
1565
- result = calculate_routed_cost(traj_tokens, routed_steps, base_prices, routing_prices)
1566
- total_base_cost += result["base_cost"]
1567
- total_routing_cost += result["routing_cost"]
1568
-
1569
- original_cost = (
1570
- uncached_input_tokens * base_prices["input"] / 1e6 +
1571
- cache_read_tokens * base_prices["cache_read"] / 1e6 +
1572
- cache_creation_tokens * base_prices["cache_creation"] / 1e6 +
1573
- completion_tokens * base_prices["completion"] / 1e6
 
1574
  )
1575
- total_original_cost += original_cost
 
 
 
 
 
 
1576
 
1577
  total_routed_cost = total_base_cost + total_routing_cost
1578
  savings = total_original_cost - total_routed_cost
@@ -1593,7 +1762,7 @@ def build_app():
1593
  *Routing model: {routing_model_1_val}*
1594
  """
1595
 
1596
- additional_token_models = [(routing_model_1_val, routing_tokens)]
1597
  additional_cost_models = [(routing_model_1_val, routing_costs)]
1598
 
1599
  yield (
@@ -1603,7 +1772,7 @@ def build_app():
1603
  None,
1604
  )
1605
 
1606
- tokens_chart = create_routed_token_chart(base_tokens, additional_token_models)
1607
  cost_chart = create_routed_cost_chart(base_costs, additional_cost_models)
1608
 
1609
  yield (
@@ -1685,8 +1854,9 @@ def build_app():
1685
  df_calc = load_all_trajectories_calculated(folder)
1686
  df_calc["api_calls"] = df_meta["api_calls"].values
1687
  df_calc["instance_cost"] = df_meta["instance_cost"].values
 
1688
 
1689
- state_data = {"meta": df_meta, "calculated": df_calc}
1690
 
1691
  if source == "Metadata":
1692
  df = df_meta
 
26
  _litellm_prices_cache = None
27
  _trajectories_cache = {}
28
  _calculated_tokens_cache = {}
29
+ _trajectory_steps_cache = {}
30
 
31
 
32
  def parse_step_or_ratio(value: float, total_steps: int) -> int:
 
152
  }
153
 
154
 
155
def calculate_routing_tokens(steps: list[dict]) -> dict:
    """
    Calculate token breakdown per model with proper caching simulation.

    Each model maintains an independent prompt cache: when a step is served
    by a model, everything that model has not yet cached is sent as uncached
    input and then written into that model's cache (Anthropic-style caching,
    where the new input plus the completion become cache-creation tokens).

    Args:
        steps: list of dicts with keys:
            - model: str (model name)
            - system_user: int (tokens for system/user message, usually only step 0)
            - completion: int (generated tokens)
            - observation: int or None (env response tokens, None for last step)

    Returns:
        dict with per-model totals:
        {model_name: {cache_read, uncached_input, completion, observation, cache_creation}}
    """
    model_caches: dict[str, int] = {}   # tokens currently cached, per model
    model_totals: dict[str, dict] = {}  # accumulated per-model breakdown

    total_context = 0      # full conversation context after the previous step
    prev_observation = 0   # observation produced by the previous step

    for i, step in enumerate(steps):
        model = step["model"]
        system_user = step.get("system_user", 0)
        completion = step.get("completion", 0)
        # None (last step has no env response) counts as zero tokens.
        observation = step.get("observation") or 0

        if model not in model_caches:
            model_caches[model] = 0
        if model not in model_totals:
            model_totals[model] = {
                "cache_read": 0,
                "uncached_input": 0,
                "completion": 0,
                "observation": 0,
                "cache_creation": 0,
            }

        # Everything this model cached on earlier steps is read back cheaply.
        cache_read = model_caches[model]

        if i == 0:
            # First step: only the system/user prompt goes in.
            uncached_input = system_user
        else:
            # Later steps must carry the whole prior context plus the last
            # observation; whatever this model already cached is subtracted.
            full_context_needed = total_context + prev_observation
            # Clamp defensively so malformed step data can never produce a
            # negative token count (with non-negative, monotonically growing
            # contexts the difference is always >= 0).
            uncached_input = max(full_context_needed - cache_read, 0)

        # New input plus the completion are written into this model's cache
        # for the next call routed to it.
        cache_creation = uncached_input + completion

        model_caches[model] = cache_read + cache_creation

        totals = model_totals[model]
        totals["cache_read"] += cache_read
        totals["uncached_input"] += uncached_input
        totals["completion"] += completion
        totals["observation"] += observation
        totals["cache_creation"] += cache_creation

        total_context = cache_read + uncached_input + completion
        prev_observation = observation

    return model_totals
215
+
216
+
217
def parse_trajectory_to_steps(traj_path: Path, model_name: str) -> list[dict]:
    """
    Parse trajectory file into step format for calculate_routing_tokens.

    Returns list of steps with:
        - model: base model name
        - system_user: tokens for system + user message (step 0 only)
        - completion: assistant response tokens
        - observation: env response tokens (None for last step)
    """
    with open(traj_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    messages = data.get("messages", [])
    if not messages:
        return []

    count_tokens, _ = get_tokenizer(model_name)

    steps: list[dict] = []
    system_user_tokens = 0

    # NOTE: a plain `for` loop replaces the original manual-index `while`
    # loop, which only advanced the index inside the three role branches and
    # therefore hung forever on any message with an unexpected role
    # (e.g. "tool"). Unknown roles are now simply skipped.
    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")
        if isinstance(content, list):
            # Structured content blocks are tokenized via their JSON form.
            content = json.dumps(content)
        tokens = count_tokens(str(content))

        if role == "system":
            system_user_tokens += tokens
        elif role == "user":
            if not steps:
                # The initial user message belongs to the system/user prompt.
                system_user_tokens += tokens
            else:
                # A user message after an assistant turn is the environment
                # observation for the most recent step (last one wins if
                # several user messages follow one assistant turn).
                steps[-1]["observation"] = tokens
        elif role == "assistant":
            steps.append({
                "model": model_name,
                # Only the very first assistant step carries the prompt cost.
                "system_user": system_user_tokens if not steps else 0,
                "completion": tokens,
                "observation": None,
            })
            system_user_tokens = 0
        # Other roles contribute no tokens and are ignored.

    return steps
274
+
275
+
276
  def get_default_overhead(model_name: str) -> float:
277
  """Get default tokenizer overhead for model provider"""
278
  model_lower = model_name.lower() if model_name else ""
 
461
  return df
462
 
463
 
464
def load_all_trajectory_steps(folder: str) -> dict[str, list[dict]]:
    """
    Load all trajectories as step sequences for routing calculations.

    Results are memoized in the module-level _trajectory_steps_cache so the
    tokenizer-heavy parsing runs at most once per folder.

    Args:
        folder: subdirectory of TRAJS_DIR holding the trajectory files.

    Returns:
        dict mapping instance_id -> list of steps for calculate_routing_tokens
    """
    global _trajectory_steps_cache

    cache_key = f"steps_{folder}"
    if cache_key in _trajectory_steps_cache:
        return _trajectory_steps_cache[cache_key]

    output_dir = TRAJS_DIR / folder

    # Probe layouts from most to least specific; the first pattern that
    # matches anything wins (replaces five copy-pasted fallback branches).
    traj_files: list[Path] = []
    for pattern in ("*/*.traj.json", "*/*.traj", "*.traj.json", "*.traj", "*.json"):
        traj_files = list(output_dir.glob(pattern))
        if traj_files:
            break

    # The model name is read from the first trajectory's config. Best-effort
    # only: an unreadable or differently-shaped file leaves the name empty.
    model_name = ""
    if traj_files:
        try:
            with open(traj_files[0], "r", encoding="utf-8") as f:
                first_data = json.load(f)
            config = first_data.get("info", {}).get("config", {}).get("model", {})
            model_name = config.get("cost_calc_model_override", config.get("model_name", ""))
        except Exception:
            pass

    result = {}
    for traj_path in traj_files:
        try:
            # "foo.traj.json" -> stem "foo.traj" -> instance id "foo".
            instance_id = traj_path.stem.replace(".traj", "")
            steps = parse_trajectory_to_steps(traj_path, model_name)
            if steps:
                result[instance_id] = steps
        except Exception as e:
            # Keep going: one corrupt trajectory must not sink the folder.
            print(f"Error parsing steps for {traj_path}: {e}")

    _trajectory_steps_cache[cache_key] = result
    return result
511
+
512
+
513
  def get_litellm_model_list() -> list[str]:
514
  """Get list of model names from litellm prices"""
515
  prices = get_litellm_prices()
 
1639
  )
1640
  return
1641
 
1642
+ trajectory_steps = state_data.get("steps", {})
1643
+ if not trajectory_steps:
 
1644
  yield (
1645
+ gr.update(visible=True, value="❌ No trajectory steps data available."),
1646
  gr.update(visible=False),
1647
  None, None,
1648
  )
1649
  return
1650
 
 
 
 
 
 
1651
  base_prices = {
1652
  "input": base_input,
1653
  "cache_read": base_cache_read,
 
1670
  strategy_params["start"] = start_1_val
1671
  strategy_params["end"] = end_1_val
1672
 
1673
+ total_base_tokens = {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0}
1674
+ total_routing_tokens = {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0}
1675
+ total_original_tokens = {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0}
1676
 
1677
+ BASE_MODEL = "__base__"
1678
+ ROUTING_MODEL = "__routing__"
 
 
1679
 
1680
+ for instance_id, steps in trajectory_steps.items():
1681
+ if not steps:
 
1682
  continue
1683
 
1684
+ total_steps = len(steps)
1685
+ routed_step_indices = get_routed_steps(total_steps, strategy_1_val, strategy_params)
1686
+
1687
+ modified_steps = []
1688
+ for i, step in enumerate(steps):
1689
+ model = ROUTING_MODEL if i in routed_step_indices else BASE_MODEL
1690
+ modified_steps.append({
1691
+ "model": model,
1692
+ "system_user": step.get("system_user", 0),
1693
+ "completion": int(step.get("completion", 0) * (overhead if source == "Calculated" else 1)),
1694
+ "observation": step.get("observation"),
1695
+ })
1696
+
1697
+ model_totals = calculate_routing_tokens(modified_steps)
1698
+
1699
+ base_totals = model_totals.get(BASE_MODEL, {
1700
+ "cache_read": 0, "uncached_input": 0, "completion": 0, "cache_creation": 0
1701
+ })
1702
+ routing_totals = model_totals.get(ROUTING_MODEL, {
1703
+ "cache_read": 0, "uncached_input": 0, "completion": 0, "cache_creation": 0
1704
+ })
1705
+
1706
+ total_base_tokens["cache_read"] += base_totals.get("cache_read", 0)
1707
+ total_base_tokens["uncached_input"] += base_totals.get("uncached_input", 0)
1708
+ total_base_tokens["completion"] += base_totals.get("completion", 0)
1709
+ total_base_tokens["cache_creation"] += base_totals.get("cache_creation", 0)
1710
+
1711
+ total_routing_tokens["cache_read"] += routing_totals.get("cache_read", 0)
1712
+ total_routing_tokens["uncached_input"] += routing_totals.get("uncached_input", 0)
1713
+ total_routing_tokens["completion"] += routing_totals.get("completion", 0)
1714
+ total_routing_tokens["cache_creation"] += routing_totals.get("cache_creation", 0)
1715
+
1716
+ original_steps = []
1717
+ for step in steps:
1718
+ original_steps.append({
1719
+ "model": BASE_MODEL,
1720
+ "system_user": step.get("system_user", 0),
1721
+ "completion": int(step.get("completion", 0) * (overhead if source == "Calculated" else 1)),
1722
+ "observation": step.get("observation"),
1723
+ })
1724
+ original_totals = calculate_routing_tokens(original_steps)
1725
+ orig = original_totals.get(BASE_MODEL, {})
1726
+ total_original_tokens["cache_read"] += orig.get("cache_read", 0)
1727
+ total_original_tokens["uncached_input"] += orig.get("uncached_input", 0)
1728
+ total_original_tokens["completion"] += orig.get("completion", 0)
1729
+ total_original_tokens["cache_creation"] += orig.get("cache_creation", 0)
1730
+
1731
+ def calc_cost(tokens: dict, prices: dict) -> float:
1732
+ return (
1733
+ tokens["uncached_input"] * prices["input"] / 1e6 +
1734
+ tokens["cache_read"] * prices["cache_read"] / 1e6 +
1735
+ tokens["cache_creation"] * prices["cache_creation"] / 1e6 +
1736
+ tokens["completion"] * prices["completion"] / 1e6
1737
  )
1738
+
1739
+ base_costs = {k: total_base_tokens[k] * base_prices[{"uncached_input": "input", "cache_read": "cache_read", "cache_creation": "cache_creation", "completion": "completion"}[k]] / 1e6 for k in total_base_tokens}
1740
+ routing_costs = {k: total_routing_tokens[k] * routing_prices[{"uncached_input": "input", "cache_read": "cache_read", "cache_creation": "cache_creation", "completion": "completion"}[k]] / 1e6 for k in total_routing_tokens}
1741
+
1742
+ total_base_cost = calc_cost(total_base_tokens, base_prices)
1743
+ total_routing_cost = calc_cost(total_routing_tokens, routing_prices)
1744
+ total_original_cost = calc_cost(total_original_tokens, base_prices)
1745
 
1746
  total_routed_cost = total_base_cost + total_routing_cost
1747
  savings = total_original_cost - total_routed_cost
 
1762
  *Routing model: {routing_model_1_val}*
1763
  """
1764
 
1765
+ additional_token_models = [(routing_model_1_val, total_routing_tokens)]
1766
  additional_cost_models = [(routing_model_1_val, routing_costs)]
1767
 
1768
  yield (
 
1772
  None,
1773
  )
1774
 
1775
+ tokens_chart = create_routed_token_chart(total_base_tokens, additional_token_models)
1776
  cost_chart = create_routed_cost_chart(base_costs, additional_cost_models)
1777
 
1778
  yield (
 
1854
  df_calc = load_all_trajectories_calculated(folder)
1855
  df_calc["api_calls"] = df_meta["api_calls"].values
1856
  df_calc["instance_cost"] = df_meta["instance_cost"].values
1857
+ trajectory_steps = load_all_trajectory_steps(folder)
1858
 
1859
+ state_data = {"meta": df_meta, "calculated": df_calc, "folder": folder, "steps": trajectory_steps}
1860
 
1861
  if source == "Metadata":
1862
  df = df_meta