Commit
Β·
5c06e74
1
Parent(s):
99badd3
Add routing calculation with proper caching simulation
Browse files
- Add 'Let's ROUTE!!' button with yield for staged rendering
- Add routing token/cost charts grouped by model
- Fix original cost calculation (use uncached_input, not prompt_tokens)
- Support multiple additional models with different colors
- Rename 'routing model' to 'additional model' in charts
- Each model maintains independent cache context
- When switching models, cache is preserved (not reset)
- Proper calculation: uncached_input includes obs from prev step
app.py
CHANGED
|
@@ -43,6 +43,114 @@ def parse_step_or_ratio(value: float, total_steps: int) -> int:
|
|
| 43 |
return int(value * total_steps)
|
| 44 |
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
def get_default_overhead(model_name: str) -> float:
|
| 47 |
"""Get default tokenizer overhead for model provider"""
|
| 48 |
model_lower = model_name.lower() if model_name else ""
|
|
@@ -947,6 +1055,92 @@ def on_row_select(evt: gr.SelectData, df: pd.DataFrame):
|
|
| 947 |
)
|
| 948 |
|
| 949 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 950 |
def build_app():
|
| 951 |
leaderboard_df = get_bash_only_df()
|
| 952 |
|
|
@@ -976,6 +1170,10 @@ def build_app():
|
|
| 976 |
plot_tokens = gr.Plot(label="Token Usage by Type")
|
| 977 |
plot_tokens_cost = gr.Plot(label="Cost by Token Type ($)")
|
| 978 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 979 |
with gr.Row():
|
| 980 |
plot_stacked = gr.Plot(label="Tokens per Trajectory")
|
| 981 |
plot_cost_breakdown = gr.Plot(label="Cost per Trajectory ($)")
|
|
@@ -1117,6 +1315,11 @@ def build_app():
|
|
| 1117 |
start_step_3 = gr.Number(label="Start (int=step; 0,0-1,0=ratio)", value=0, minimum=0, precision=2, interactive=True)
|
| 1118 |
end_step_3 = gr.Number(label="End (int=step; 0,0-1,0=ratio)", value=0.5, minimum=0, precision=2, interactive=True)
|
| 1119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1120 |
def on_strategy_change(strategy):
|
| 1121 |
return (
|
| 1122 |
gr.update(visible=strategy == "Replace on random steps"),
|
|
@@ -1242,6 +1445,186 @@ def build_app():
|
|
| 1242 |
outputs=[routing_price_3_input, routing_price_3_cache_read, routing_price_3_cache_creation, routing_price_3_completion],
|
| 1243 |
)
|
| 1244 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1245 |
def update_calculated_options_visibility(source):
|
| 1246 |
is_calc = source == "Calculated"
|
| 1247 |
return gr.update(visible=is_calc), gr.update(visible=is_calc)
|
|
|
|
| 43 |
return int(value * total_steps)
|
| 44 |
|
| 45 |
|
| 46 |
+
def get_routed_steps(total_steps: int, strategy: str, params: dict) -> set:
    """Pick the 0-based step indices that should be handled by the routing model.

    The strategy name selects how indices are chosen:
      - "Replace on random steps": a random sample sized by params["percentage"].
      - "Replace every step k": every k-th step, starting at step 0.
      - "Replace part of trajectory": the [start, end) slice, where start/end
        may be absolute steps or 0-1 ratios (see parse_step_or_ratio).
    Any other strategy name routes nothing (empty set).
    """
    import random

    if strategy == "Replace on random steps":
        fraction = params.get("percentage", 50) / 100.0
        sample_size = int(total_steps * fraction)
        if sample_size <= 0:
            return set()
        # min() guards against percentages above 100.
        return set(random.sample(range(total_steps), min(sample_size, total_steps)))

    if strategy == "Replace every step k":
        stride = int(params.get("k", 2))
        return set(range(0, total_steps, stride)) if stride > 0 else set()

    if strategy == "Replace part of trajectory":
        first = parse_step_or_ratio(params.get("start", 0), total_steps)
        last = parse_step_or_ratio(params.get("end", 0.5), total_steps)
        return set(range(first, min(last, total_steps)))

    return set()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def calculate_routed_cost(
    trajectory_tokens: dict,
    routed_steps: set,
    base_prices: dict,
    routing_prices: dict,
) -> dict:
    """
    Calculate cost for a trajectory with routing.

    Each model maintains its own independent cache.
    When switching back to a model, its cache is still available.

    Per-step token detail is not available, so trajectory totals are spread
    evenly across the `api_calls` steps.

    Args:
        trajectory_tokens: dict with api_calls, prompt_tokens, completion_tokens,
            cache_read_tokens, cache_creation_tokens
        routed_steps: set of step indices using routing model
        base_prices: {input, cache_read, cache_creation, completion} for base model
            (all prices are per million tokens)
        routing_prices: same for routing model

    Returns:
        dict with base_cost, routing_cost, total_cost (USD)
    """
    total_steps = trajectory_tokens.get("api_calls", 0)
    if total_steps == 0:
        return {"base_cost": 0, "routing_cost": 0, "total_cost": 0}

    prompt_tokens = trajectory_tokens.get("prompt_tokens", 0)
    completion_tokens = trajectory_tokens.get("completion_tokens", 0)
    cache_read = trajectory_tokens.get("cache_read_tokens", 0)
    cache_creation = trajectory_tokens.get("cache_creation_tokens", 0)

    # Even per-step averages (total_steps > 0 is guaranteed above).
    avg_prompt_per_step = prompt_tokens / total_steps
    avg_completion_per_step = completion_tokens / total_steps
    avg_cache_read_per_step = cache_read / total_steps
    avg_cache_creation_per_step = cache_creation / total_steps

    # BUG FIX: cache-creation tokens are billed separately at the cache_creation
    # rate on every step below, so they must be excluded from the uncached-input
    # portion; previously only cache reads were subtracted, double-counting
    # cache-creation tokens at the input rate. This matches the
    # prompt - cache_read - cache_creation accounting used elsewhere in this
    # file, clamped at 0 in case the counts overlap oddly.
    uncached_input = max(
        0, avg_prompt_per_step - avg_cache_read_per_step - avg_cache_creation_per_step
    )

    base_cost = 0
    routing_cost = 0

    # Cache contexts are tracked per model: tokens accumulated while one model
    # handled steps stay cached for that model only, and survive switches.
    base_cache_context = 0
    routing_cache_context = 0

    for step in range(total_steps):
        is_routed = step in routed_steps
        prices = routing_prices if is_routed else base_prices
        cache_ctx = routing_cache_context if is_routed else base_cache_context

        if cache_ctx == 0:
            # First call for this model: nothing is cached yet, so no cache
            # reads; the whole prompt (minus the separately billed
            # cache-creation write) is billed at the uncached input rate.
            step_cache_read = 0
            step_uncached = max(0, avg_prompt_per_step - avg_cache_creation_per_step)
        else:
            step_cache_read = avg_cache_read_per_step
            step_uncached = uncached_input

        step_cost = (
            step_uncached * prices["input"] / 1e6 +
            step_cache_read * prices["cache_read"] / 1e6 +
            avg_cache_creation_per_step * prices["cache_creation"] / 1e6 +
            avg_completion_per_step * prices["completion"] / 1e6
        )

        if is_routed:
            routing_cost += step_cost
            # Everything this model just saw becomes cacheable context for it.
            routing_cache_context += avg_prompt_per_step + avg_completion_per_step
        else:
            base_cost += step_cost
            base_cache_context += avg_prompt_per_step + avg_completion_per_step

    return {
        "base_cost": base_cost,
        "routing_cost": routing_cost,
        "total_cost": base_cost + routing_cost,
    }
|
| 152 |
+
|
| 153 |
+
|
| 154 |
def get_default_overhead(model_name: str) -> float:
|
| 155 |
"""Get default tokenizer overhead for model provider"""
|
| 156 |
model_lower = model_name.lower() if model_name else ""
|
|
|
|
| 1055 |
)
|
| 1056 |
|
| 1057 |
|
| 1058 |
+
def create_routed_token_chart(base_tokens: dict, additional_models: list):
    """Build a grouped bar chart of token usage per type, one bar set per model.

    Args:
        base_tokens: dict with uncached_input, cache_read, cache_creation, completion
        additional_models: list of (model_name, tokens_dict) tuples
    """
    import plotly.graph_objects as go

    categories = ["Uncached Input", "Cache Read", "Cache Creation", "Completion"]
    colors = ["#636EFA", "#EF553B", "#00CC96", "#AB63FA", "#FFA15A"]
    token_keys = ("uncached_input", "cache_read", "cache_creation", "completion")

    def millions(tokens: dict) -> list:
        # Token counts are plotted in millions to keep the axis readable.
        return [tokens.get(key, 0) / 1e6 for key in token_keys]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        name="Base Model",
        x=categories,
        y=millions(base_tokens),
        marker_color=colors[0],
    ))

    for idx, (model_name, tokens) in enumerate(additional_models):
        fig.add_trace(go.Bar(
            name=model_name or f"Model {idx+1}",
            x=categories,
            y=millions(tokens),
            # Palette cycles after the base color; with >4 additional models
            # colors repeat.
            marker_color=colors[(idx + 1) % len(colors)],
        ))

    fig.update_layout(
        title="Tokens by Type (per Model)",
        yaxis_title="Tokens (M)",
        barmode="group",
        margin=dict(l=40, r=40, t=60, b=40),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )
    return fig
|
| 1099 |
+
|
| 1100 |
+
|
| 1101 |
+
def create_routed_cost_chart(base_costs: dict, additional_models: list):
    """Build a grouped bar chart of dollar cost per token type, one bar set per model.

    Args:
        base_costs: dict with uncached_input, cache_read, cache_creation, completion
        additional_models: list of (model_name, costs_dict) tuples
    """
    import plotly.graph_objects as go

    categories = ["Uncached Input", "Cache Read", "Cache Creation", "Completion"]
    colors = ["#636EFA", "#EF553B", "#00CC96", "#AB63FA", "#FFA15A"]
    cost_keys = ("uncached_input", "cache_read", "cache_creation", "completion")

    def dollar_values(costs: dict) -> list:
        # Costs are already in dollars; no scaling needed.
        return [costs.get(key, 0) for key in cost_keys]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        name="Base Model",
        x=categories,
        y=dollar_values(base_costs),
        marker_color=colors[0],
    ))

    for idx, (model_name, costs) in enumerate(additional_models):
        fig.add_trace(go.Bar(
            name=model_name or f"Model {idx+1}",
            x=categories,
            y=dollar_values(costs),
            # Palette cycles after the base color; with >4 additional models
            # colors repeat.
            marker_color=colors[(idx + 1) % len(colors)],
        ))

    fig.update_layout(
        title="Cost by Type (per Model) ($)",
        yaxis_title="Cost ($)",
        barmode="group",
        margin=dict(l=40, r=40, t=60, b=40),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )
    return fig
|
| 1142 |
+
|
| 1143 |
+
|
| 1144 |
def build_app():
|
| 1145 |
leaderboard_df = get_bash_only_df()
|
| 1146 |
|
|
|
|
| 1170 |
plot_tokens = gr.Plot(label="Token Usage by Type")
|
| 1171 |
plot_tokens_cost = gr.Plot(label="Cost by Token Type ($)")
|
| 1172 |
|
| 1173 |
+
with gr.Row(visible=False) as routing_plots_row:
|
| 1174 |
+
routing_tokens_plot = gr.Plot(label="Tokens by Type (per Model)")
|
| 1175 |
+
routing_cost_plot = gr.Plot(label="Cost by Type (per Model)")
|
| 1176 |
+
|
| 1177 |
with gr.Row():
|
| 1178 |
plot_stacked = gr.Plot(label="Tokens per Trajectory")
|
| 1179 |
plot_cost_breakdown = gr.Plot(label="Cost per Trajectory ($)")
|
|
|
|
| 1315 |
start_step_3 = gr.Number(label="Start (int=step; 0,0-1,0=ratio)", value=0, minimum=0, precision=2, interactive=True)
|
| 1316 |
end_step_3 = gr.Number(label="End (int=step; 0,0-1,0=ratio)", value=0.5, minimum=0, precision=2, interactive=True)
|
| 1317 |
|
| 1318 |
+
gr.Markdown("---")
|
| 1319 |
+
route_btn = gr.Button("π Let's ROUTE!!", variant="primary", size="lg")
|
| 1320 |
+
routing_result = gr.Markdown(visible=False)
|
| 1321 |
+
|
| 1322 |
+
|
| 1323 |
def on_strategy_change(strategy):
|
| 1324 |
return (
|
| 1325 |
gr.update(visible=strategy == "Replace on random steps"),
|
|
|
|
| 1445 |
outputs=[routing_price_3_input, routing_price_3_cache_read, routing_price_3_cache_creation, routing_price_3_completion],
|
| 1446 |
)
|
| 1447 |
|
| 1448 |
+
def run_routing(
    state_data,
    base_input, base_cache_read, base_cache_creation, base_completion,
    routing_model_1_val, r1_input, r1_cache_read, r1_cache_creation, r1_completion,
    strategy_1_val, random_pct_1_val, step_k_1_val, start_1_val, end_1_val,
    source, overhead, with_cache
):
    """Simulate routing part of each trajectory to an additional model and stream results.

    Generator wired to the ROUTE button. Each yield is a 4-tuple for
    (routing_result markdown, routing_plots_row visibility, tokens plot, cost plot),
    so Gradio can show progress text before the (slower) charts are built.

    All prices are per million tokens. `state_data` is assumed to hold
    "meta"/"calculated" DataFrames produced by Load & Analyze — TODO confirm
    against the loader elsewhere in this file.
    """
    # Guard: analysis data must be loaded first.
    if state_data is None:
        yield (
            gr.update(visible=True, value="β No trajectories loaded. Click 'Load & Analyze' first."),
            gr.update(visible=False),
            None, None,
        )
        return

    # Guard: a routing model must be chosen.
    if not routing_model_1_val:
        yield (
            gr.update(visible=True, value="β Please select at least one routing model."),
            gr.update(visible=False),
            None, None,
        )
        return

    # Pick which token source to simulate against.
    df_key = "meta" if source == "Metadata" else "calculated"
    df = state_data.get(df_key)
    if df is None or df.empty:
        yield (
            gr.update(visible=True, value="β No trajectory data available."),
            gr.update(visible=False),
            None, None,
        )
        return

    # Calculated tokens get the thinking-overhead multiplier and optional
    # no-cache transformation applied on a copy (the stored df is untouched).
    if source == "Calculated":
        df = apply_thinking_overhead(df.copy(), overhead)
        if not with_cache:
            df = apply_no_cache(df)

    # Per-million-token prices for the base and routing models.
    base_prices = {
        "input": base_input,
        "cache_read": base_cache_read,
        "cache_creation": base_cache_creation,
        "completion": base_completion,
    }
    routing_prices = {
        "input": r1_input,
        "cache_read": r1_cache_read,
        "cache_creation": r1_cache_creation,
        "completion": r1_completion,
    }

    # Only the parameters relevant to the chosen strategy are passed along
    # (see get_routed_steps for how each strategy consumes them).
    strategy_params = {}
    if strategy_1_val == "Replace on random steps":
        strategy_params["percentage"] = random_pct_1_val
    elif strategy_1_val == "Replace every step k":
        strategy_params["k"] = step_k_1_val
    elif strategy_1_val == "Replace part of trajectory":
        strategy_params["start"] = start_1_val
        strategy_params["end"] = end_1_val

    # Aggregate cost totals across all trajectories.
    total_base_cost = 0
    total_routing_cost = 0
    total_original_cost = 0

    # Per-model token / cost tallies by token type, used for the grouped charts.
    base_tokens = {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0}
    routing_tokens = {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0}
    base_costs = {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0}
    routing_costs = {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0}

    for _, row in df.iterrows():
        total_steps = int(row.get("api_calls", 0))
        if total_steps == 0:
            continue

        routed_steps = get_routed_steps(total_steps, strategy_1_val, strategy_params)
        num_base_steps = total_steps - len(routed_steps)
        num_routing_steps = len(routed_steps)

        prompt_tokens = row.get("prompt_tokens", 0)
        completion_tokens = row.get("completion_tokens", 0)
        cache_read_tokens = row.get("cache_read_tokens", 0)
        cache_creation_tokens = row.get("cache_creation_tokens", 0)
        # Uncached input = prompt minus everything billed at cache rates;
        # clamp at 0 in case the counts overlap.
        uncached_input_tokens = prompt_tokens - cache_read_tokens - cache_creation_tokens
        if uncached_input_tokens < 0:
            uncached_input_tokens = 0

        # For the charts, split this trajectory's tokens/costs between the
        # two models proportionally to how many steps each one handled.
        base_ratio = num_base_steps / total_steps if total_steps > 0 else 0
        routing_ratio = num_routing_steps / total_steps if total_steps > 0 else 0

        base_tokens["uncached_input"] += uncached_input_tokens * base_ratio
        base_tokens["cache_read"] += cache_read_tokens * base_ratio
        base_tokens["cache_creation"] += cache_creation_tokens * base_ratio
        base_tokens["completion"] += completion_tokens * base_ratio

        routing_tokens["uncached_input"] += uncached_input_tokens * routing_ratio
        routing_tokens["cache_read"] += cache_read_tokens * routing_ratio
        routing_tokens["cache_creation"] += cache_creation_tokens * routing_ratio
        routing_tokens["completion"] += completion_tokens * routing_ratio

        base_costs["uncached_input"] += uncached_input_tokens * base_ratio * base_prices["input"] / 1e6
        base_costs["cache_read"] += cache_read_tokens * base_ratio * base_prices["cache_read"] / 1e6
        base_costs["cache_creation"] += cache_creation_tokens * base_ratio * base_prices["cache_creation"] / 1e6
        base_costs["completion"] += completion_tokens * base_ratio * base_prices["completion"] / 1e6

        routing_costs["uncached_input"] += uncached_input_tokens * routing_ratio * routing_prices["input"] / 1e6
        routing_costs["cache_read"] += cache_read_tokens * routing_ratio * routing_prices["cache_read"] / 1e6
        routing_costs["cache_creation"] += cache_creation_tokens * routing_ratio * routing_prices["cache_creation"] / 1e6
        routing_costs["completion"] += completion_tokens * routing_ratio * routing_prices["completion"] / 1e6

        # The summary costs use the step-by-step cache simulation instead of
        # the proportional split above, so summary and charts can differ.
        traj_tokens = {
            "api_calls": total_steps,
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "cache_read_tokens": cache_read_tokens,
            "cache_creation_tokens": cache_creation_tokens,
        }

        result = calculate_routed_cost(traj_tokens, routed_steps, base_prices, routing_prices)
        total_base_cost += result["base_cost"]
        total_routing_cost += result["routing_cost"]

        # Baseline: everything billed to the base model at its observed
        # cache behavior (uncached_input, not raw prompt_tokens).
        original_cost = (
            uncached_input_tokens * base_prices["input"] / 1e6 +
            cache_read_tokens * base_prices["cache_read"] / 1e6 +
            cache_creation_tokens * base_prices["cache_creation"] / 1e6 +
            completion_tokens * base_prices["completion"] / 1e6
        )
        total_original_cost += original_cost

    total_routed_cost = total_base_cost + total_routing_cost
    savings = total_original_cost - total_routed_cost
    # Negative savings_pct means routing is more expensive than the baseline.
    savings_pct = (savings / total_original_cost * 100) if total_original_cost > 0 else 0

    result_text = f"""
## π Routing Results

| Metric | Value |
|--------|-------|
| **Original Cost (base model only)** | ${total_original_cost:.2f} |
| **Routed Cost** | ${total_routed_cost:.2f} |
| β³ Base model portion | ${total_base_cost:.2f} |
| β³ Routing model portion | ${total_routing_cost:.2f} |
| **Savings** | ${savings:.2f} ({savings_pct:+.1f}%) |

*Strategy: {strategy_1_val}*
*Routing model: {routing_model_1_val}*
"""

    # Chart helpers accept a list of additional models; currently only one.
    additional_token_models = [(routing_model_1_val, routing_tokens)]
    additional_cost_models = [(routing_model_1_val, routing_costs)]

    # Stage 1: show progress text and reveal the (still empty) plot row.
    yield (
        gr.update(visible=True, value="β³ Creating charts..."),
        gr.update(visible=True),
        None,
        None,
    )

    tokens_chart = create_routed_token_chart(base_tokens, additional_token_models)
    cost_chart = create_routed_cost_chart(base_costs, additional_cost_models)

    # Stage 2: final summary plus the finished charts.
    yield (
        gr.update(visible=True, value=result_text),
        gr.update(visible=True),
        tokens_chart,
        cost_chart,
    )
|
| 1615 |
+
|
| 1616 |
+
route_btn.click(
|
| 1617 |
+
fn=run_routing,
|
| 1618 |
+
inputs=[
|
| 1619 |
+
trajectories_state,
|
| 1620 |
+
price_input, price_cache_read, price_cache_creation, price_completion,
|
| 1621 |
+
routing_model_1, routing_price_1_input, routing_price_1_cache_read, routing_price_1_cache_creation, routing_price_1_completion,
|
| 1622 |
+
strategy_1, random_pct_1, step_k_1, start_step_1, end_step_1,
|
| 1623 |
+
token_source, thinking_overhead, use_cache,
|
| 1624 |
+
],
|
| 1625 |
+
outputs=[routing_result, routing_plots_row, routing_tokens_plot, routing_cost_plot],
|
| 1626 |
+
)
|
| 1627 |
+
|
| 1628 |
def update_calculated_options_visibility(source):
    """Show the two calculated-token option widgets only for the "Calculated" source."""
    show = source == "Calculated"
    first_update = gr.update(visible=show)
    return first_update, gr.update(visible=show)
|