IgorSlinko committed on
Commit
5c06e74
Β·
1 Parent(s): 99badd3

Add routing calculation with proper caching simulation

Browse files

- Add 'Let's ROUTE!!' button with yield for staged rendering
- Add routing token/cost charts grouped by model
- Fix original cost calculation (use uncached_input, not prompt_tokens)
- Support multiple additional models with different colors
- Rename 'routing model' to 'additional model' in charts
- Each model maintains independent cache context
- When switching models, cache is preserved (not reset)
- Proper calculation: uncached_input includes obs from prev step

Files changed (1) hide show
  1. app.py +383 -0
app.py CHANGED
@@ -43,6 +43,114 @@ def parse_step_or_ratio(value: float, total_steps: int) -> int:
43
  return int(value * total_steps)
44
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def get_default_overhead(model_name: str) -> float:
47
  """Get default tokenizer overhead for model provider"""
48
  model_lower = model_name.lower() if model_name else ""
@@ -947,6 +1055,92 @@ def on_row_select(evt: gr.SelectData, df: pd.DataFrame):
947
  )
948
 
949
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
950
  def build_app():
951
  leaderboard_df = get_bash_only_df()
952
 
@@ -976,6 +1170,10 @@ def build_app():
976
  plot_tokens = gr.Plot(label="Token Usage by Type")
977
  plot_tokens_cost = gr.Plot(label="Cost by Token Type ($)")
978
 
 
 
 
 
979
  with gr.Row():
980
  plot_stacked = gr.Plot(label="Tokens per Trajectory")
981
  plot_cost_breakdown = gr.Plot(label="Cost per Trajectory ($)")
@@ -1117,6 +1315,11 @@ def build_app():
1117
  start_step_3 = gr.Number(label="Start (int=step; 0,0-1,0=ratio)", value=0, minimum=0, precision=2, interactive=True)
1118
  end_step_3 = gr.Number(label="End (int=step; 0,0-1,0=ratio)", value=0.5, minimum=0, precision=2, interactive=True)
1119
 
 
 
 
 
 
1120
  def on_strategy_change(strategy):
1121
  return (
1122
  gr.update(visible=strategy == "Replace on random steps"),
@@ -1242,6 +1445,186 @@ def build_app():
1242
  outputs=[routing_price_3_input, routing_price_3_cache_read, routing_price_3_cache_creation, routing_price_3_completion],
1243
  )
1244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1245
  def update_calculated_options_visibility(source):
1246
  is_calc = source == "Calculated"
1247
  return gr.update(visible=is_calc), gr.update(visible=is_calc)
 
43
  return int(value * total_steps)
44
 
45
 
46
def get_routed_steps(total_steps: int, strategy: str, params: dict) -> set:
    """Pick which steps go to the alternative model.

    Returns a set of 0-based step indices that should use the routing
    model; an unknown strategy routes nothing.
    """
    import random

    if strategy == "Replace on random steps":
        fraction = params.get("percentage", 50) / 100.0
        sample_size = int(total_steps * fraction)
        if sample_size <= 0:
            return set()
        return set(random.sample(range(total_steps), min(sample_size, total_steps)))

    if strategy == "Replace every step k":
        stride = int(params.get("k", 2))
        return set(range(0, total_steps, stride)) if stride > 0 else set()

    if strategy == "Replace part of trajectory":
        # Bounds may be given as absolute steps or as 0..1 ratios.
        first = parse_step_or_ratio(params.get("start", 0), total_steps)
        last = parse_step_or_ratio(params.get("end", 0.5), total_steps)
        return set(range(first, min(last, total_steps)))

    return set()
73
+
74
+
75
def calculate_routed_cost(
    trajectory_tokens: dict,
    routed_steps: set,
    base_prices: dict,
    routing_prices: dict,
) -> dict:
    """
    Calculate cost for a trajectory with routing.

    Each model maintains its own independent cache.
    When switching back to a model, its cache is still available.

    Token counts are spread evenly across steps (per-step averages). The
    first step a given model serves pays the full prompt at the uncached
    input rate (cold cache); later steps pay the cached/uncached split.

    Args:
        trajectory_tokens: dict with per-step token counts
            (api_calls, prompt_tokens, completion_tokens,
            cache_read_tokens, cache_creation_tokens)
        routed_steps: set of step indices using routing model
        base_prices: {input, cache_read, cache_creation, completion}
            for base model, in $ per 1M tokens
        routing_prices: same for routing model

    Returns:
        dict with base_cost, routing_cost, total_cost (in $)
    """
    total_steps = trajectory_tokens.get("api_calls", 0)
    if total_steps == 0:
        return {"base_cost": 0, "routing_cost": 0, "total_cost": 0}

    prompt_tokens = trajectory_tokens.get("prompt_tokens", 0)
    completion_tokens = trajectory_tokens.get("completion_tokens", 0)
    cache_read = trajectory_tokens.get("cache_read_tokens", 0)
    cache_creation = trajectory_tokens.get("cache_creation_tokens", 0)

    avg_prompt_per_step = prompt_tokens / total_steps
    avg_completion_per_step = completion_tokens / total_steps
    avg_cache_read_per_step = cache_read / total_steps
    avg_cache_creation_per_step = cache_creation / total_steps

    # Loop-invariant: uncached portion once a model's cache is warm.
    # Clamped at 0 because cache_read can exceed prompt_tokens in some
    # trajectory metadata, which would otherwise yield a NEGATIVE per-step
    # cost (same clamp run_routing applies to its aggregate tokens).
    warm_uncached = max(avg_prompt_per_step - avg_cache_read_per_step, 0)

    base_cost = 0.0
    routing_cost = 0.0

    # Simulated cache context per model: 0 means that model has not served
    # a step yet, so its cache is cold.
    base_cache_context = 0
    routing_cache_context = 0

    for step in range(total_steps):
        is_routed = step in routed_steps
        prices = routing_prices if is_routed else base_prices
        cache_ctx = routing_cache_context if is_routed else base_cache_context

        if cache_ctx == 0:
            # Cold cache: the whole prompt is billed at the input rate.
            step_cache_read = 0
            step_uncached = avg_prompt_per_step
        else:
            step_cache_read = avg_cache_read_per_step
            step_uncached = warm_uncached

        # Prices are $ per 1M tokens, hence the 1e6 divisor.
        step_cost = (
            step_uncached * prices["input"] / 1e6 +
            step_cache_read * prices["cache_read"] / 1e6 +
            avg_cache_creation_per_step * prices["cache_creation"] / 1e6 +
            avg_completion_per_step * prices["completion"] / 1e6
        )

        if is_routed:
            routing_cost += step_cost
            routing_cache_context += avg_prompt_per_step + avg_completion_per_step
        else:
            base_cost += step_cost
            base_cache_context += avg_prompt_per_step + avg_completion_per_step

    return {
        "base_cost": base_cost,
        "routing_cost": routing_cost,
        "total_cost": base_cost + routing_cost,
    }
152
+
153
+
154
  def get_default_overhead(model_name: str) -> float:
155
  """Get default tokenizer overhead for model provider"""
156
  model_lower = model_name.lower() if model_name else ""
 
1055
  )
1056
 
1057
 
1058
def create_routed_token_chart(base_tokens: dict, additional_models: list):
    """
    Build a grouped bar chart of token counts (in millions) per token type,
    comparing the base model against each additional model.

    Args:
        base_tokens: dict with uncached_input, cache_read, cache_creation, completion
        additional_models: list of (model_name, tokens_dict) tuples
    """
    import plotly.graph_objects as go

    categories = ["Uncached Input", "Cache Read", "Cache Creation", "Completion"]
    # Dict keys, in the same order as the category labels above.
    keys = ("uncached_input", "cache_read", "cache_creation", "completion")
    palette = ["#636EFA", "#EF553B", "#00CC96", "#AB63FA", "#FFA15A"]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        name="Base Model",
        x=categories,
        y=[base_tokens.get(key, 0) / 1e6 for key in keys],
        marker_color=palette[0],
    ))

    for idx, (model_name, tokens) in enumerate(additional_models):
        # Cycle through the palette, skipping slot 0 (the base model color).
        fig.add_trace(go.Bar(
            name=model_name or f"Model {idx+1}",
            x=categories,
            y=[tokens.get(key, 0) / 1e6 for key in keys],
            marker_color=palette[(idx + 1) % len(palette)],
        ))

    fig.update_layout(
        title="Tokens by Type (per Model)",
        yaxis_title="Tokens (M)",
        barmode="group",
        margin=dict(l=40, r=40, t=60, b=40),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )
    return fig
1099
+
1100
+
1101
def create_routed_cost_chart(base_costs: dict, additional_models: list):
    """
    Build a grouped bar chart of dollar cost per token type, comparing the
    base model against each additional model.

    Args:
        base_costs: dict with uncached_input, cache_read, cache_creation, completion
        additional_models: list of (model_name, costs_dict) tuples
    """
    import plotly.graph_objects as go

    categories = ["Uncached Input", "Cache Read", "Cache Creation", "Completion"]
    # Dict keys, in the same order as the category labels above.
    keys = ("uncached_input", "cache_read", "cache_creation", "completion")
    palette = ["#636EFA", "#EF553B", "#00CC96", "#AB63FA", "#FFA15A"]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        name="Base Model",
        x=categories,
        y=[base_costs.get(key, 0) for key in keys],
        marker_color=palette[0],
    ))

    for idx, (model_name, costs) in enumerate(additional_models):
        # Cycle through the palette, skipping slot 0 (the base model color).
        fig.add_trace(go.Bar(
            name=model_name or f"Model {idx+1}",
            x=categories,
            y=[costs.get(key, 0) for key in keys],
            marker_color=palette[(idx + 1) % len(palette)],
        ))

    fig.update_layout(
        title="Cost by Type (per Model) ($)",
        yaxis_title="Cost ($)",
        barmode="group",
        margin=dict(l=40, r=40, t=60, b=40),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )
    return fig
1142
+
1143
+
1144
  def build_app():
1145
  leaderboard_df = get_bash_only_df()
1146
 
 
1170
  plot_tokens = gr.Plot(label="Token Usage by Type")
1171
  plot_tokens_cost = gr.Plot(label="Cost by Token Type ($)")
1172
 
1173
+ with gr.Row(visible=False) as routing_plots_row:
1174
+ routing_tokens_plot = gr.Plot(label="Tokens by Type (per Model)")
1175
+ routing_cost_plot = gr.Plot(label="Cost by Type (per Model)")
1176
+
1177
  with gr.Row():
1178
  plot_stacked = gr.Plot(label="Tokens per Trajectory")
1179
  plot_cost_breakdown = gr.Plot(label="Cost per Trajectory ($)")
 
1315
  start_step_3 = gr.Number(label="Start (int=step; 0,0-1,0=ratio)", value=0, minimum=0, precision=2, interactive=True)
1316
  end_step_3 = gr.Number(label="End (int=step; 0,0-1,0=ratio)", value=0.5, minimum=0, precision=2, interactive=True)
1317
 
1318
+ gr.Markdown("---")
1319
+ route_btn = gr.Button("πŸš€ Let's ROUTE!!", variant="primary", size="lg")
1320
+ routing_result = gr.Markdown(visible=False)
1321
+
1322
+
1323
  def on_strategy_change(strategy):
1324
  return (
1325
  gr.update(visible=strategy == "Replace on random steps"),
 
1445
  outputs=[routing_price_3_input, routing_price_3_cache_read, routing_price_3_cache_creation, routing_price_3_completion],
1446
  )
1447
 
1448
def run_routing(
    state_data,
    base_input, base_cache_read, base_cache_creation, base_completion,
    routing_model_1_val, r1_input, r1_cache_read, r1_cache_creation, r1_completion,
    strategy_1_val, random_pct_1_val, step_k_1_val, start_1_val, end_1_val,
    source, overhead, with_cache
):
    """
    Simulate routing part of each trajectory to an alternative model and
    compare the result against running everything on the base model.

    Generator wired to the ROUTE button. It yields staged UI updates:
    validation errors first, then a "creating charts" placeholder, then
    the final results. Every yield is a 4-tuple matching the click
    handler's outputs: (routing_result markdown, routing_plots_row
    visibility, routing_tokens_plot, routing_cost_plot).

    Prices are in $ per 1M tokens (hence the 1e6 divisors below).
    """
    # --- validation: each failure path shows an error and stops early ---
    if state_data is None:
        yield (
            gr.update(visible=True, value="❌ No trajectories loaded. Click 'Load & Analyze' first."),
            gr.update(visible=False),
            None, None,
        )
        return

    if not routing_model_1_val:
        yield (
            gr.update(visible=True, value="❌ Please select at least one routing model."),
            gr.update(visible=False),
            None, None,
        )
        return

    # Pick which per-trajectory DataFrame to use from the loaded state.
    df_key = "meta" if source == "Metadata" else "calculated"
    df = state_data.get(df_key)
    if df is None or df.empty:
        yield (
            gr.update(visible=True, value="❌ No trajectory data available."),
            gr.update(visible=False),
            None, None,
        )
        return

    # Calculated tokens get the thinking-overhead multiplier, and are
    # optionally converted to a no-cache view, before costing.
    if source == "Calculated":
        df = apply_thinking_overhead(df.copy(), overhead)
        if not with_cache:
            df = apply_no_cache(df)

    base_prices = {
        "input": base_input,
        "cache_read": base_cache_read,
        "cache_creation": base_cache_creation,
        "completion": base_completion,
    }
    routing_prices = {
        "input": r1_input,
        "cache_read": r1_cache_read,
        "cache_creation": r1_cache_creation,
        "completion": r1_completion,
    }

    # Only the parameters relevant to the chosen strategy are passed on
    # to get_routed_steps.
    strategy_params = {}
    if strategy_1_val == "Replace on random steps":
        strategy_params["percentage"] = random_pct_1_val
    elif strategy_1_val == "Replace every step k":
        strategy_params["k"] = step_k_1_val
    elif strategy_1_val == "Replace part of trajectory":
        strategy_params["start"] = start_1_val
        strategy_params["end"] = end_1_val

    # Aggregates across all trajectories.
    total_base_cost = 0
    total_routing_cost = 0
    total_original_cost = 0

    # Token and cost totals split per model, feeding the grouped charts.
    base_tokens = {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0}
    routing_tokens = {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0}
    base_costs = {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0}
    routing_costs = {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0}

    for _, row in df.iterrows():
        total_steps = int(row.get("api_calls", 0))
        if total_steps == 0:
            continue

        routed_steps = get_routed_steps(total_steps, strategy_1_val, strategy_params)
        num_base_steps = total_steps - len(routed_steps)
        num_routing_steps = len(routed_steps)

        prompt_tokens = row.get("prompt_tokens", 0)
        completion_tokens = row.get("completion_tokens", 0)
        cache_read_tokens = row.get("cache_read_tokens", 0)
        cache_creation_tokens = row.get("cache_creation_tokens", 0)
        # Uncached input excludes both cached reads and cache-creation
        # writes; clamp because metadata can make this negative.
        uncached_input_tokens = prompt_tokens - cache_read_tokens - cache_creation_tokens
        if uncached_input_tokens < 0:
            uncached_input_tokens = 0

        # Chart totals are split proportionally to the step counts each
        # model served (the exact per-step cache simulation happens in
        # calculate_routed_cost below).
        base_ratio = num_base_steps / total_steps if total_steps > 0 else 0
        routing_ratio = num_routing_steps / total_steps if total_steps > 0 else 0

        base_tokens["uncached_input"] += uncached_input_tokens * base_ratio
        base_tokens["cache_read"] += cache_read_tokens * base_ratio
        base_tokens["cache_creation"] += cache_creation_tokens * base_ratio
        base_tokens["completion"] += completion_tokens * base_ratio

        routing_tokens["uncached_input"] += uncached_input_tokens * routing_ratio
        routing_tokens["cache_read"] += cache_read_tokens * routing_ratio
        routing_tokens["cache_creation"] += cache_creation_tokens * routing_ratio
        routing_tokens["completion"] += completion_tokens * routing_ratio

        base_costs["uncached_input"] += uncached_input_tokens * base_ratio * base_prices["input"] / 1e6
        base_costs["cache_read"] += cache_read_tokens * base_ratio * base_prices["cache_read"] / 1e6
        base_costs["cache_creation"] += cache_creation_tokens * base_ratio * base_prices["cache_creation"] / 1e6
        base_costs["completion"] += completion_tokens * base_ratio * base_prices["completion"] / 1e6

        routing_costs["uncached_input"] += uncached_input_tokens * routing_ratio * routing_prices["input"] / 1e6
        routing_costs["cache_read"] += cache_read_tokens * routing_ratio * routing_prices["cache_read"] / 1e6
        routing_costs["cache_creation"] += cache_creation_tokens * routing_ratio * routing_prices["cache_creation"] / 1e6
        routing_costs["completion"] += completion_tokens * routing_ratio * routing_prices["completion"] / 1e6

        traj_tokens = {
            "api_calls": total_steps,
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "cache_read_tokens": cache_read_tokens,
            "cache_creation_tokens": cache_creation_tokens,
        }

        # Headline costs come from the per-step simulation with
        # independent caches per model.
        result = calculate_routed_cost(traj_tokens, routed_steps, base_prices, routing_prices)
        total_base_cost += result["base_cost"]
        total_routing_cost += result["routing_cost"]

        # Baseline: the whole trajectory on the base model, billed with
        # uncached_input (not raw prompt_tokens) to respect caching.
        original_cost = (
            uncached_input_tokens * base_prices["input"] / 1e6 +
            cache_read_tokens * base_prices["cache_read"] / 1e6 +
            cache_creation_tokens * base_prices["cache_creation"] / 1e6 +
            completion_tokens * base_prices["completion"] / 1e6
        )
        total_original_cost += original_cost

    total_routed_cost = total_base_cost + total_routing_cost
    savings = total_original_cost - total_routed_cost
    savings_pct = (savings / total_original_cost * 100) if total_original_cost > 0 else 0

    result_text = f"""
## 🚀 Routing Results

| Metric | Value |
|--------|-------|
| **Original Cost (base model only)** | ${total_original_cost:.2f} |
| **Routed Cost** | ${total_routed_cost:.2f} |
| ↳ Base model portion | ${total_base_cost:.2f} |
| ↳ Routing model portion | ${total_routing_cost:.2f} |
| **Savings** | ${savings:.2f} ({savings_pct:+.1f}%) |

*Strategy: {strategy_1_val}*
*Routing model: {routing_model_1_val}*
"""

    additional_token_models = [(routing_model_1_val, routing_tokens)]
    additional_cost_models = [(routing_model_1_val, routing_costs)]

    # Staged render: show a placeholder while the plotly figures build.
    yield (
        gr.update(visible=True, value="⏳ Creating charts..."),
        gr.update(visible=True),
        None,
        None,
    )

    tokens_chart = create_routed_token_chart(base_tokens, additional_token_models)
    cost_chart = create_routed_cost_chart(base_costs, additional_cost_models)

    yield (
        gr.update(visible=True, value=result_text),
        gr.update(visible=True),
        tokens_chart,
        cost_chart,
    )

# Wire the ROUTE button; input order must match run_routing's signature.
route_btn.click(
    fn=run_routing,
    inputs=[
        trajectories_state,
        price_input, price_cache_read, price_cache_creation, price_completion,
        routing_model_1, routing_price_1_input, routing_price_1_cache_read, routing_price_1_cache_creation, routing_price_1_completion,
        strategy_1, random_pct_1, step_k_1, start_step_1, end_step_1,
        token_source, thinking_overhead, use_cache,
    ],
    outputs=[routing_result, routing_plots_row, routing_tokens_plot, routing_cost_plot],
)
1627
+
1628
def update_calculated_options_visibility(source):
    """Show the calculated-only option widgets only for the "Calculated" source."""
    show_calc = source == "Calculated"
    return gr.update(visible=show_calc), gr.update(visible=show_calc)