Commit
·
0c41621
1
Parent(s):
c63e9d7
Refactor routing strategy UI
Browse files
- Replace per-model strategy with global Router Strategy
- Add three strategies: Random weights, Every k-th step, Replace part of trajectory
- Random weights: each step randomly assigned based on weights (must sum to 1.0)
- Every k-th step: first model has priority on overlaps
- Replace part: non-overlapping ranges for each model
- Simplify UI: remove containers, use individual visible components
- Fix strategy parameter visibility on strategy change
app.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import json
|
| 2 |
import os
|
|
|
|
| 3 |
import re
|
| 4 |
import subprocess
|
| 5 |
from pathlib import Path
|
|
@@ -29,53 +30,6 @@ _calculated_tokens_cache = {}
|
|
| 29 |
_trajectory_steps_cache = {}
|
| 30 |
|
| 31 |
|
| 32 |
-
def parse_start_end(start: float, end: float, total_steps: int, mode: str) -> tuple[int, int]:
|
| 33 |
-
"""
|
| 34 |
-
Parse start and end values based on mode.
|
| 35 |
-
|
| 36 |
-
Args:
|
| 37 |
-
start: start value
|
| 38 |
-
end: end value
|
| 39 |
-
total_steps: total number of steps in trajectory
|
| 40 |
-
mode: "Indexes" or "Percentages"
|
| 41 |
-
|
| 42 |
-
Returns: (start_idx, end_idx) - both 0-based
|
| 43 |
-
"""
|
| 44 |
-
if mode == "Indexes":
|
| 45 |
-
return int(start), int(end)
|
| 46 |
-
else:
|
| 47 |
-
return int(start * total_steps / 100), int(end * total_steps / 100)
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def get_routed_steps(total_steps: int, strategy: str, params: dict) -> set:
|
| 51 |
-
"""
|
| 52 |
-
Determine which steps should be routed to alternative model.
|
| 53 |
-
|
| 54 |
-
Returns set of step indices (0-based) that should use the routing model.
|
| 55 |
-
"""
|
| 56 |
-
import random
|
| 57 |
-
|
| 58 |
-
routed = set()
|
| 59 |
-
|
| 60 |
-
if strategy == "Replace on random steps":
|
| 61 |
-
pct = params.get("percentage", 50) / 100.0
|
| 62 |
-
num_to_route = int(total_steps * pct)
|
| 63 |
-
if num_to_route > 0:
|
| 64 |
-
routed = set(random.sample(range(total_steps), min(num_to_route, total_steps)))
|
| 65 |
-
|
| 66 |
-
elif strategy == "Replace every step k":
|
| 67 |
-
k = int(params.get("k", 2))
|
| 68 |
-
if k > 0:
|
| 69 |
-
routed = set(range(0, total_steps, k))
|
| 70 |
-
|
| 71 |
-
elif strategy == "Replace part of trajectory":
|
| 72 |
-
mode = params.get("mode", "Percentages")
|
| 73 |
-
start, end = parse_start_end(params.get("start", 0), params.get("end", 30), total_steps, mode)
|
| 74 |
-
routed = set(range(start, min(end, total_steps)))
|
| 75 |
-
|
| 76 |
-
return routed
|
| 77 |
-
|
| 78 |
-
|
| 79 |
def calculate_routing_tokens(steps: list[dict]) -> dict:
|
| 80 |
"""
|
| 81 |
Calculate token breakdown per model with proper caching simulation.
|
|
@@ -1275,12 +1229,6 @@ def build_app():
|
|
| 1275 |
with gr.Column(visible=False) as routing_section:
|
| 1276 |
gr.Markdown("### π Routing Models")
|
| 1277 |
|
| 1278 |
-
STRATEGY_CHOICES = [
|
| 1279 |
-
"Replace on random steps",
|
| 1280 |
-
"Replace every step k",
|
| 1281 |
-
"Replace part of trajectory",
|
| 1282 |
-
]
|
| 1283 |
-
|
| 1284 |
with gr.Column():
|
| 1285 |
with gr.Group():
|
| 1286 |
gr.Markdown("#### Route to Model 1")
|
|
@@ -1295,26 +1243,6 @@ def build_app():
|
|
| 1295 |
routing_price_1_cache_read = gr.Number(label="Cache Read", precision=3, scale=1)
|
| 1296 |
routing_price_1_cache_creation = gr.Number(label="Cache Creation", precision=3, scale=1)
|
| 1297 |
routing_price_1_completion = gr.Number(label="Completion", precision=3, scale=1)
|
| 1298 |
-
strategy_1 = gr.Dropdown(
|
| 1299 |
-
label="Strategy",
|
| 1300 |
-
choices=STRATEGY_CHOICES,
|
| 1301 |
-
value="Replace on random steps",
|
| 1302 |
-
interactive=True,
|
| 1303 |
-
)
|
| 1304 |
-
with gr.Row(visible=True) as random_params_1:
|
| 1305 |
-
random_pct_1 = gr.Number(label="Percentage (%)", value=50, minimum=0, maximum=100, precision=0, interactive=True)
|
| 1306 |
-
with gr.Row(visible=False) as every_k_params_1:
|
| 1307 |
-
step_k_1 = gr.Number(label="k", value=2, minimum=1, precision=0, interactive=True)
|
| 1308 |
-
with gr.Column(visible=False) as part_params_1:
|
| 1309 |
-
part_mode_1 = gr.Radio(
|
| 1310 |
-
choices=["Indexes", "Percentages"],
|
| 1311 |
-
value="Percentages",
|
| 1312 |
-
label="Mode",
|
| 1313 |
-
interactive=True,
|
| 1314 |
-
)
|
| 1315 |
-
with gr.Row():
|
| 1316 |
-
start_step_1 = gr.Number(label="Start", value=0, minimum=0, precision=0, interactive=True)
|
| 1317 |
-
end_step_1 = gr.Number(label="End", value=30, minimum=0, precision=0, interactive=True)
|
| 1318 |
|
| 1319 |
add_model_2_btn = gr.Button("+ Add another model", size="sm", visible=False)
|
| 1320 |
|
|
@@ -1332,26 +1260,6 @@ def build_app():
|
|
| 1332 |
routing_price_2_cache_read = gr.Number(label="Cache Read", precision=3, scale=1)
|
| 1333 |
routing_price_2_cache_creation = gr.Number(label="Cache Creation", precision=3, scale=1)
|
| 1334 |
routing_price_2_completion = gr.Number(label="Completion", precision=3, scale=1)
|
| 1335 |
-
strategy_2 = gr.Dropdown(
|
| 1336 |
-
label="Strategy",
|
| 1337 |
-
choices=STRATEGY_CHOICES,
|
| 1338 |
-
value="Replace on random steps",
|
| 1339 |
-
interactive=True,
|
| 1340 |
-
)
|
| 1341 |
-
with gr.Row(visible=True) as random_params_2:
|
| 1342 |
-
random_pct_2 = gr.Number(label="Percentage (%)", value=50, minimum=0, maximum=100, precision=0, interactive=True)
|
| 1343 |
-
with gr.Row(visible=False) as every_k_params_2:
|
| 1344 |
-
step_k_2 = gr.Number(label="k", value=2, minimum=1, precision=0, interactive=True)
|
| 1345 |
-
with gr.Column(visible=False) as part_params_2:
|
| 1346 |
-
part_mode_2 = gr.Radio(
|
| 1347 |
-
choices=["Indexes", "Percentages"],
|
| 1348 |
-
value="Percentages",
|
| 1349 |
-
label="Mode",
|
| 1350 |
-
interactive=True,
|
| 1351 |
-
)
|
| 1352 |
-
with gr.Row():
|
| 1353 |
-
start_step_2 = gr.Number(label="Start", value=0, minimum=0, precision=0, interactive=True)
|
| 1354 |
-
end_step_2 = gr.Number(label="End", value=30, minimum=0, precision=0, interactive=True)
|
| 1355 |
|
| 1356 |
add_model_3_btn = gr.Button("+ Add another model", size="sm", visible=False)
|
| 1357 |
|
|
@@ -1369,39 +1277,48 @@ def build_app():
|
|
| 1369 |
routing_price_3_cache_read = gr.Number(label="Cache Read", precision=3, scale=1)
|
| 1370 |
routing_price_3_cache_creation = gr.Number(label="Cache Creation", precision=3, scale=1)
|
| 1371 |
routing_price_3_completion = gr.Number(label="Completion", precision=3, scale=1)
|
| 1372 |
-
|
| 1373 |
-
|
| 1374 |
-
|
| 1375 |
-
|
| 1376 |
-
|
| 1377 |
-
|
| 1378 |
-
|
| 1379 |
-
|
| 1380 |
-
|
| 1381 |
-
|
| 1382 |
-
|
| 1383 |
-
|
| 1384 |
-
|
| 1385 |
-
|
| 1386 |
-
|
| 1387 |
-
|
| 1388 |
-
|
| 1389 |
-
|
| 1390 |
-
|
| 1391 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1392 |
|
| 1393 |
gr.Markdown("---")
|
| 1394 |
route_btn = gr.Button("π Let's ROUTE!!", variant="primary", size="lg", interactive=False)
|
| 1395 |
routing_result = gr.Markdown(visible=False)
|
| 1396 |
|
| 1397 |
|
| 1398 |
-
def on_strategy_change(strategy):
|
| 1399 |
-
return (
|
| 1400 |
-
gr.update(visible=strategy == "Replace on random steps"),
|
| 1401 |
-
gr.update(visible=strategy == "Replace every step k"),
|
| 1402 |
-
gr.update(visible=strategy == "Replace part of trajectory"),
|
| 1403 |
-
)
|
| 1404 |
-
|
| 1405 |
def toggle_routing_section():
|
| 1406 |
return gr.update(visible=True)
|
| 1407 |
|
|
@@ -1410,22 +1327,31 @@ def build_app():
|
|
| 1410 |
outputs=[routing_section],
|
| 1411 |
)
|
| 1412 |
|
| 1413 |
-
|
| 1414 |
-
|
| 1415 |
-
|
| 1416 |
-
|
| 1417 |
-
|
| 1418 |
-
|
| 1419 |
-
|
| 1420 |
-
|
| 1421 |
-
|
| 1422 |
-
|
| 1423 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1424 |
|
| 1425 |
-
|
| 1426 |
fn=on_strategy_change,
|
| 1427 |
-
inputs=[
|
| 1428 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1429 |
)
|
| 1430 |
|
| 1431 |
def filter_models(query):
|
|
@@ -1498,9 +1424,23 @@ def build_app():
|
|
| 1498 |
outputs=[routing_price_1_input, routing_price_1_cache_read, routing_price_1_cache_creation, routing_price_1_completion, add_model_2_btn, route_btn],
|
| 1499 |
)
|
| 1500 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1501 |
add_model_2_btn.click(
|
| 1502 |
-
fn=
|
| 1503 |
-
|
|
|
|
| 1504 |
)
|
| 1505 |
|
| 1506 |
routing_model_2.change(
|
|
@@ -1509,9 +1449,23 @@ def build_app():
|
|
| 1509 |
outputs=[routing_price_2_input, routing_price_2_cache_read, routing_price_2_cache_creation, routing_price_2_completion, add_model_3_btn],
|
| 1510 |
)
|
| 1511 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1512 |
add_model_3_btn.click(
|
| 1513 |
-
fn=
|
| 1514 |
-
|
|
|
|
| 1515 |
)
|
| 1516 |
|
| 1517 |
routing_model_3.change(
|
|
@@ -1524,11 +1478,12 @@ def build_app():
|
|
| 1524 |
state_data,
|
| 1525 |
base_input, base_cache_read, base_cache_creation, base_completion,
|
| 1526 |
routing_model_1_val, r1_input, r1_cache_read, r1_cache_creation, r1_completion,
|
| 1527 |
-
strategy_1_val, random_pct_1_val, step_k_1_val, part_mode_1_val, start_1_val, end_1_val,
|
| 1528 |
routing_model_2_val, r2_input, r2_cache_read, r2_cache_creation, r2_completion,
|
| 1529 |
-
strategy_2_val, random_pct_2_val, step_k_2_val, part_mode_2_val, start_2_val, end_2_val,
|
| 1530 |
routing_model_3_val, r3_input, r3_cache_read, r3_cache_creation, r3_completion,
|
| 1531 |
-
|
|
|
|
|
|
|
|
|
|
| 1532 |
source, overhead, with_cache
|
| 1533 |
):
|
| 1534 |
if state_data is None:
|
|
@@ -1556,6 +1511,7 @@ def build_app():
|
|
| 1556 |
)
|
| 1557 |
return
|
| 1558 |
|
|
|
|
| 1559 |
df_calc = state_data.get("calculated")
|
| 1560 |
if df_calc is not None and not df_calc.empty:
|
| 1561 |
df_for_cost = apply_thinking_overhead(df_calc.copy(), overhead)
|
|
@@ -1579,50 +1535,56 @@ def build_app():
|
|
| 1579 |
"completion": base_completion,
|
| 1580 |
}
|
| 1581 |
|
| 1582 |
-
def build_strategy_params(strategy, random_pct, step_k, part_mode, start_val, end_val):
|
| 1583 |
-
params = {}
|
| 1584 |
-
if strategy == "Replace on random steps":
|
| 1585 |
-
params["percentage"] = random_pct
|
| 1586 |
-
elif strategy == "Replace every step k":
|
| 1587 |
-
params["k"] = step_k
|
| 1588 |
-
elif strategy == "Replace part of trajectory":
|
| 1589 |
-
params["mode"] = part_mode
|
| 1590 |
-
params["start"] = start_val
|
| 1591 |
-
params["end"] = end_val
|
| 1592 |
-
return params
|
| 1593 |
-
|
| 1594 |
routing_models = []
|
| 1595 |
if routing_model_1_val:
|
| 1596 |
-
if strategy_1_val == "Replace part of trajectory" and start_1_val >= end_1_val:
|
| 1597 |
-
yield (gr.update(visible=True, value="❌ Model 1: Start must be less than End"), gr.update(visible=False), None, None)
|
| 1598 |
-
return
|
| 1599 |
routing_models.append({
|
| 1600 |
"name": routing_model_1_val,
|
| 1601 |
"prices": {"input": r1_input, "cache_read": r1_cache_read, "cache_creation": r1_cache_creation, "completion": r1_completion},
|
| 1602 |
-
"strategy": strategy_1_val,
|
| 1603 |
-
"params": build_strategy_params(strategy_1_val, random_pct_1_val, step_k_1_val, part_mode_1_val, start_1_val, end_1_val),
|
| 1604 |
})
|
| 1605 |
if routing_model_2_val:
|
| 1606 |
-
if strategy_2_val == "Replace part of trajectory" and start_2_val >= end_2_val:
|
| 1607 |
-
yield (gr.update(visible=True, value="❌ Model 2: Start must be less than End"), gr.update(visible=False), None, None)
|
| 1608 |
-
return
|
| 1609 |
routing_models.append({
|
| 1610 |
"name": routing_model_2_val,
|
| 1611 |
"prices": {"input": r2_input, "cache_read": r2_cache_read, "cache_creation": r2_cache_creation, "completion": r2_completion},
|
| 1612 |
-
"strategy": strategy_2_val,
|
| 1613 |
-
"params": build_strategy_params(strategy_2_val, random_pct_2_val, step_k_2_val, part_mode_2_val, start_2_val, end_2_val),
|
| 1614 |
})
|
| 1615 |
if routing_model_3_val:
|
| 1616 |
-
if strategy_3_val == "Replace part of trajectory" and start_3_val >= end_3_val:
|
| 1617 |
-
yield (gr.update(visible=True, value="❌ Model 3: Start must be less than End"), gr.update(visible=False), None, None)
|
| 1618 |
-
return
|
| 1619 |
routing_models.append({
|
| 1620 |
"name": routing_model_3_val,
|
| 1621 |
"prices": {"input": r3_input, "cache_read": r3_cache_read, "cache_creation": r3_cache_creation, "completion": r3_completion},
|
| 1622 |
-
"strategy": strategy_3_val,
|
| 1623 |
-
"params": build_strategy_params(strategy_3_val, random_pct_3_val, step_k_3_val, part_mode_3_val, start_3_val, end_3_val),
|
| 1624 |
})
|
| 1625 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1626 |
BASE_MODEL = "__base__"
|
| 1627 |
model_keys = [BASE_MODEL] + [f"__routing_{i}__" for i in range(len(routing_models))]
|
| 1628 |
|
|
@@ -1635,17 +1597,35 @@ def build_app():
|
|
| 1635 |
|
| 1636 |
total_steps = len(steps)
|
| 1637 |
|
| 1638 |
-
|
| 1639 |
-
|
| 1640 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1641 |
|
| 1642 |
modified_steps = []
|
| 1643 |
for i, step in enumerate(steps):
|
| 1644 |
-
model = BASE_MODEL
|
| 1645 |
-
for j, routed_set in enumerate(routed_sets):
|
| 1646 |
-
if i in routed_set:
|
| 1647 |
-
model = f"__routing_{j}__"
|
| 1648 |
-
break
|
| 1649 |
modified_steps.append({
|
| 1650 |
"model": model,
|
| 1651 |
"system_user": step.get("system_user", 0),
|
|
@@ -1752,11 +1732,12 @@ def build_app():
|
|
| 1752 |
trajectories_state,
|
| 1753 |
price_input, price_cache_read, price_cache_creation, price_completion,
|
| 1754 |
routing_model_1, routing_price_1_input, routing_price_1_cache_read, routing_price_1_cache_creation, routing_price_1_completion,
|
| 1755 |
-
strategy_1, random_pct_1, step_k_1, part_mode_1, start_step_1, end_step_1,
|
| 1756 |
routing_model_2, routing_price_2_input, routing_price_2_cache_read, routing_price_2_cache_creation, routing_price_2_completion,
|
| 1757 |
-
strategy_2, random_pct_2, step_k_2, part_mode_2, start_step_2, end_step_2,
|
| 1758 |
routing_model_3, routing_price_3_input, routing_price_3_cache_read, routing_price_3_cache_creation, routing_price_3_completion,
|
| 1759 |
-
|
|
|
|
|
|
|
|
|
|
| 1760 |
token_source, thinking_overhead, use_cache,
|
| 1761 |
],
|
| 1762 |
outputs=[routing_result, routing_plots_row, routing_tokens_plot, routing_cost_plot],
|
|
|
|
| 1 |
import json
|
| 2 |
import os
|
| 3 |
+
import random
|
| 4 |
import re
|
| 5 |
import subprocess
|
| 6 |
from pathlib import Path
|
|
|
|
| 30 |
_trajectory_steps_cache = {}
|
| 31 |
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
def calculate_routing_tokens(steps: list[dict]) -> dict:
|
| 34 |
"""
|
| 35 |
Calculate token breakdown per model with proper caching simulation.
|
|
|
|
| 1229 |
with gr.Column(visible=False) as routing_section:
|
| 1230 |
gr.Markdown("### π Routing Models")
|
| 1231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1232 |
with gr.Column():
|
| 1233 |
with gr.Group():
|
| 1234 |
gr.Markdown("#### Route to Model 1")
|
|
|
|
| 1243 |
routing_price_1_cache_read = gr.Number(label="Cache Read", precision=3, scale=1)
|
| 1244 |
routing_price_1_cache_creation = gr.Number(label="Cache Creation", precision=3, scale=1)
|
| 1245 |
routing_price_1_completion = gr.Number(label="Completion", precision=3, scale=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1246 |
|
| 1247 |
add_model_2_btn = gr.Button("+ Add another model", size="sm", visible=False)
|
| 1248 |
|
|
|
|
| 1260 |
routing_price_2_cache_read = gr.Number(label="Cache Read", precision=3, scale=1)
|
| 1261 |
routing_price_2_cache_creation = gr.Number(label="Cache Creation", precision=3, scale=1)
|
| 1262 |
routing_price_2_completion = gr.Number(label="Completion", precision=3, scale=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1263 |
|
| 1264 |
add_model_3_btn = gr.Button("+ Add another model", size="sm", visible=False)
|
| 1265 |
|
|
|
|
| 1277 |
routing_price_3_cache_read = gr.Number(label="Cache Read", precision=3, scale=1)
|
| 1278 |
routing_price_3_cache_creation = gr.Number(label="Cache Creation", precision=3, scale=1)
|
| 1279 |
routing_price_3_completion = gr.Number(label="Completion", precision=3, scale=1)
|
| 1280 |
+
|
| 1281 |
+
gr.Markdown("---")
|
| 1282 |
+
gr.Markdown("### 🎯 Router Strategy")
|
| 1283 |
+
|
| 1284 |
+
selected_strategy = gr.Radio(
|
| 1285 |
+
choices=["Random weights", "Every k-th step", "Replace part of trajectory"],
|
| 1286 |
+
value="Random weights",
|
| 1287 |
+
label="Strategy",
|
| 1288 |
+
interactive=True,
|
| 1289 |
+
)
|
| 1290 |
+
|
| 1291 |
+
random_hint = gr.Markdown("*Weights must sum to 1.0*", visible=True)
|
| 1292 |
+
weight_base = gr.Number(label="Base weight", value=0.5, minimum=0, maximum=1, precision=2, interactive=True, visible=True)
|
| 1293 |
+
weight_model_1 = gr.Number(label="Model 1 weight", value=0.5, minimum=0, maximum=1, precision=2, interactive=True, visible=True)
|
| 1294 |
+
weight_model_2 = gr.Number(label="Model 2 weight", value=0, minimum=0, maximum=1, precision=2, interactive=True, visible=False)
|
| 1295 |
+
weight_model_3 = gr.Number(label="Model 3 weight", value=0, minimum=0, maximum=1, precision=2, interactive=True, visible=False)
|
| 1296 |
+
|
| 1297 |
+
every_k_hint = gr.Markdown("*First model has priority on overlaps*", visible=False)
|
| 1298 |
+
k_model_1 = gr.Number(label="k₁ (Model 1)", value=2, minimum=1, precision=0, interactive=True, visible=False)
|
| 1299 |
+
k_model_2 = gr.Number(label="k₂ (Model 2)", value=3, minimum=1, precision=0, interactive=True, visible=False)
|
| 1300 |
+
k_model_3 = gr.Number(label="k₃ (Model 3)", value=5, minimum=1, precision=0, interactive=True, visible=False)
|
| 1301 |
+
|
| 1302 |
+
part_hint = gr.Markdown("*Ranges must not overlap*", visible=False)
|
| 1303 |
+
part_mode = gr.Radio(
|
| 1304 |
+
choices=["Indexes", "Percentages"],
|
| 1305 |
+
value="Percentages",
|
| 1306 |
+
label="Mode",
|
| 1307 |
+
interactive=True,
|
| 1308 |
+
visible=False,
|
| 1309 |
+
)
|
| 1310 |
+
start_1 = gr.Number(label="M1 Start", value=0, minimum=0, precision=0, interactive=True, visible=False)
|
| 1311 |
+
end_1 = gr.Number(label="M1 End", value=30, minimum=0, precision=0, interactive=True, visible=False)
|
| 1312 |
+
start_2 = gr.Number(label="M2 Start", value=30, minimum=0, precision=0, interactive=True, visible=False)
|
| 1313 |
+
end_2 = gr.Number(label="M2 End", value=60, minimum=0, precision=0, interactive=True, visible=False)
|
| 1314 |
+
start_3 = gr.Number(label="M3 Start", value=60, minimum=0, precision=0, interactive=True, visible=False)
|
| 1315 |
+
end_3 = gr.Number(label="M3 End", value=100, minimum=0, precision=0, interactive=True, visible=False)
|
| 1316 |
|
| 1317 |
gr.Markdown("---")
|
| 1318 |
route_btn = gr.Button("π Let's ROUTE!!", variant="primary", size="lg", interactive=False)
|
| 1319 |
routing_result = gr.Markdown(visible=False)
|
| 1320 |
|
| 1321 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1322 |
def toggle_routing_section():
|
| 1323 |
return gr.update(visible=True)
|
| 1324 |
|
|
|
|
| 1327 |
outputs=[routing_section],
|
| 1328 |
)
|
| 1329 |
|
| 1330 |
+
def on_strategy_change(strategy):
|
| 1331 |
+
is_random = strategy == "Random weights"
|
| 1332 |
+
is_every_k = strategy == "Every k-th step"
|
| 1333 |
+
is_part = strategy == "Replace part of trajectory"
|
| 1334 |
+
print(f"DEBUG on_strategy_change: strategy={strategy}")
|
| 1335 |
+
return (
|
| 1336 |
+
gr.update(visible=is_random),
|
| 1337 |
+
gr.update(visible=is_random),
|
| 1338 |
+
gr.update(visible=is_random),
|
| 1339 |
+
gr.update(visible=is_every_k),
|
| 1340 |
+
gr.update(visible=is_every_k),
|
| 1341 |
+
gr.update(visible=is_part),
|
| 1342 |
+
gr.update(visible=is_part),
|
| 1343 |
+
gr.update(visible=is_part),
|
| 1344 |
+
gr.update(visible=is_part),
|
| 1345 |
+
)
|
| 1346 |
|
| 1347 |
+
selected_strategy.change(
|
| 1348 |
fn=on_strategy_change,
|
| 1349 |
+
inputs=[selected_strategy],
|
| 1350 |
+
outputs=[
|
| 1351 |
+
random_hint, weight_base, weight_model_1,
|
| 1352 |
+
every_k_hint, k_model_1,
|
| 1353 |
+
part_hint, part_mode, start_1, end_1,
|
| 1354 |
+
],
|
| 1355 |
)
|
| 1356 |
|
| 1357 |
def filter_models(query):
|
|
|
|
| 1424 |
outputs=[routing_price_1_input, routing_price_1_cache_read, routing_price_1_cache_creation, routing_price_1_completion, add_model_2_btn, route_btn],
|
| 1425 |
)
|
| 1426 |
|
| 1427 |
+
def show_model_2(strategy):
|
| 1428 |
+
is_random = strategy == "Random weights"
|
| 1429 |
+
is_every_k = strategy == "Every k-th step"
|
| 1430 |
+
is_part = strategy == "Replace part of trajectory"
|
| 1431 |
+
return (
|
| 1432 |
+
gr.update(visible=True),
|
| 1433 |
+
gr.update(visible=False),
|
| 1434 |
+
gr.update(visible=is_random),
|
| 1435 |
+
gr.update(visible=is_every_k),
|
| 1436 |
+
gr.update(visible=is_part),
|
| 1437 |
+
gr.update(visible=is_part),
|
| 1438 |
+
)
|
| 1439 |
+
|
| 1440 |
add_model_2_btn.click(
|
| 1441 |
+
fn=show_model_2,
|
| 1442 |
+
inputs=[selected_strategy],
|
| 1443 |
+
outputs=[routing_block_2, add_model_2_btn, weight_model_2, k_model_2, start_2, end_2],
|
| 1444 |
)
|
| 1445 |
|
| 1446 |
routing_model_2.change(
|
|
|
|
| 1449 |
outputs=[routing_price_2_input, routing_price_2_cache_read, routing_price_2_cache_creation, routing_price_2_completion, add_model_3_btn],
|
| 1450 |
)
|
| 1451 |
|
| 1452 |
+
def show_model_3(strategy):
|
| 1453 |
+
is_random = strategy == "Random weights"
|
| 1454 |
+
is_every_k = strategy == "Every k-th step"
|
| 1455 |
+
is_part = strategy == "Replace part of trajectory"
|
| 1456 |
+
return (
|
| 1457 |
+
gr.update(visible=True),
|
| 1458 |
+
gr.update(visible=False),
|
| 1459 |
+
gr.update(visible=is_random),
|
| 1460 |
+
gr.update(visible=is_every_k),
|
| 1461 |
+
gr.update(visible=is_part),
|
| 1462 |
+
gr.update(visible=is_part),
|
| 1463 |
+
)
|
| 1464 |
+
|
| 1465 |
add_model_3_btn.click(
|
| 1466 |
+
fn=show_model_3,
|
| 1467 |
+
inputs=[selected_strategy],
|
| 1468 |
+
outputs=[routing_block_3, add_model_3_btn, weight_model_3, k_model_3, start_3, end_3],
|
| 1469 |
)
|
| 1470 |
|
| 1471 |
routing_model_3.change(
|
|
|
|
| 1478 |
state_data,
|
| 1479 |
base_input, base_cache_read, base_cache_creation, base_completion,
|
| 1480 |
routing_model_1_val, r1_input, r1_cache_read, r1_cache_creation, r1_completion,
|
|
|
|
| 1481 |
routing_model_2_val, r2_input, r2_cache_read, r2_cache_creation, r2_completion,
|
|
|
|
| 1482 |
routing_model_3_val, r3_input, r3_cache_read, r3_cache_creation, r3_completion,
|
| 1483 |
+
strategy_val,
|
| 1484 |
+
weight_base_val, weight_1_val, weight_2_val, weight_3_val,
|
| 1485 |
+
k_1_val, k_2_val, k_3_val,
|
| 1486 |
+
part_mode_val, start_1_val, end_1_val, start_2_val, end_2_val, start_3_val, end_3_val,
|
| 1487 |
source, overhead, with_cache
|
| 1488 |
):
|
| 1489 |
if state_data is None:
|
|
|
|
| 1511 |
)
|
| 1512 |
return
|
| 1513 |
|
| 1514 |
+
|
| 1515 |
df_calc = state_data.get("calculated")
|
| 1516 |
if df_calc is not None and not df_calc.empty:
|
| 1517 |
df_for_cost = apply_thinking_overhead(df_calc.copy(), overhead)
|
|
|
|
| 1535 |
"completion": base_completion,
|
| 1536 |
}
|
| 1537 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1538 |
routing_models = []
|
| 1539 |
if routing_model_1_val:
|
|
|
|
|
|
|
|
|
|
| 1540 |
routing_models.append({
|
| 1541 |
"name": routing_model_1_val,
|
| 1542 |
"prices": {"input": r1_input, "cache_read": r1_cache_read, "cache_creation": r1_cache_creation, "completion": r1_completion},
|
|
|
|
|
|
|
| 1543 |
})
|
| 1544 |
if routing_model_2_val:
|
|
|
|
|
|
|
|
|
|
| 1545 |
routing_models.append({
|
| 1546 |
"name": routing_model_2_val,
|
| 1547 |
"prices": {"input": r2_input, "cache_read": r2_cache_read, "cache_creation": r2_cache_creation, "completion": r2_completion},
|
|
|
|
|
|
|
| 1548 |
})
|
| 1549 |
if routing_model_3_val:
|
|
|
|
|
|
|
|
|
|
| 1550 |
routing_models.append({
|
| 1551 |
"name": routing_model_3_val,
|
| 1552 |
"prices": {"input": r3_input, "cache_read": r3_cache_read, "cache_creation": r3_cache_creation, "completion": r3_completion},
|
|
|
|
|
|
|
| 1553 |
})
|
| 1554 |
|
| 1555 |
+
if strategy_val == "Replace part of trajectory":
|
| 1556 |
+
ranges = [(start_1_val, end_1_val)]
|
| 1557 |
+
if len(routing_models) > 1:
|
| 1558 |
+
ranges.append((start_2_val, end_2_val))
|
| 1559 |
+
if len(routing_models) > 2:
|
| 1560 |
+
ranges.append((start_3_val, end_3_val))
|
| 1561 |
+
for i, (s, e) in enumerate(ranges):
|
| 1562 |
+
if s >= e:
|
| 1563 |
+
yield (gr.update(visible=True, value=f"❌ Model {i+1}: Start must be less than End"), gr.update(visible=False), None, None)
|
| 1564 |
+
return
|
| 1565 |
+
for i in range(len(ranges)):
|
| 1566 |
+
for j in range(i+1, len(ranges)):
|
| 1567 |
+
s1, e1 = ranges[i]
|
| 1568 |
+
s2, e2 = ranges[j]
|
| 1569 |
+
if not (e1 <= s2 or e2 <= s1):
|
| 1570 |
+
yield (gr.update(visible=True, value=f"❌ Model {i+1} and Model {j+1} ranges overlap"), gr.update(visible=False), None, None)
|
| 1571 |
+
return
|
| 1572 |
+
|
| 1573 |
+
weights = None
|
| 1574 |
+
if strategy_val == "Random weights":
|
| 1575 |
+
weights = [weight_base_val, weight_1_val]
|
| 1576 |
+
if len(routing_models) > 1:
|
| 1577 |
+
weights.append(weight_2_val)
|
| 1578 |
+
if len(routing_models) > 2:
|
| 1579 |
+
weights.append(weight_3_val)
|
| 1580 |
+
total_weight = sum(weights)
|
| 1581 |
+
if abs(total_weight - 1.0) > 0.01:
|
| 1582 |
+
yield (gr.update(visible=True, value=f"❌ Weights must sum to 1.0 (current: {total_weight:.2f})"), gr.update(visible=False), None, None)
|
| 1583 |
+
return
|
| 1584 |
+
|
| 1585 |
+
k_values = [k_1_val, k_2_val, k_3_val][:len(routing_models)]
|
| 1586 |
+
part_ranges = [(start_1_val, end_1_val), (start_2_val, end_2_val), (start_3_val, end_3_val)][:len(routing_models)]
|
| 1587 |
+
|
| 1588 |
BASE_MODEL = "__base__"
|
| 1589 |
model_keys = [BASE_MODEL] + [f"__routing_{i}__" for i in range(len(routing_models))]
|
| 1590 |
|
|
|
|
| 1597 |
|
| 1598 |
total_steps = len(steps)
|
| 1599 |
|
| 1600 |
+
step_to_model = {}
|
| 1601 |
+
|
| 1602 |
+
if strategy_val == "Random weights":
|
| 1603 |
+
model_choices = [BASE_MODEL] + [f"__routing_{j}__" for j in range(len(routing_models))]
|
| 1604 |
+
for i in range(total_steps):
|
| 1605 |
+
step_to_model[i] = random.choices(model_choices, weights=weights)[0]
|
| 1606 |
+
|
| 1607 |
+
elif strategy_val == "Every k-th step":
|
| 1608 |
+
for j, k_val in enumerate(k_values):
|
| 1609 |
+
if k_val and k_val > 0:
|
| 1610 |
+
for i in range(total_steps):
|
| 1611 |
+
if (i + 1) % int(k_val) == 0:
|
| 1612 |
+
if i not in step_to_model:
|
| 1613 |
+
step_to_model[i] = f"__routing_{j}__"
|
| 1614 |
+
|
| 1615 |
+
elif strategy_val == "Replace part of trajectory":
|
| 1616 |
+
for j, (start_val, end_val) in enumerate(part_ranges):
|
| 1617 |
+
if part_mode_val == "Percentages":
|
| 1618 |
+
start_idx = int(total_steps * start_val / 100)
|
| 1619 |
+
end_idx = int(total_steps * end_val / 100)
|
| 1620 |
+
else:
|
| 1621 |
+
start_idx = int(start_val)
|
| 1622 |
+
end_idx = min(int(end_val), total_steps)
|
| 1623 |
+
for i in range(start_idx, end_idx):
|
| 1624 |
+
step_to_model[i] = f"__routing_{j}__"
|
| 1625 |
|
| 1626 |
modified_steps = []
|
| 1627 |
for i, step in enumerate(steps):
|
| 1628 |
+
model = step_to_model.get(i, BASE_MODEL)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1629 |
modified_steps.append({
|
| 1630 |
"model": model,
|
| 1631 |
"system_user": step.get("system_user", 0),
|
|
|
|
| 1732 |
trajectories_state,
|
| 1733 |
price_input, price_cache_read, price_cache_creation, price_completion,
|
| 1734 |
routing_model_1, routing_price_1_input, routing_price_1_cache_read, routing_price_1_cache_creation, routing_price_1_completion,
|
|
|
|
| 1735 |
routing_model_2, routing_price_2_input, routing_price_2_cache_read, routing_price_2_cache_creation, routing_price_2_completion,
|
|
|
|
| 1736 |
routing_model_3, routing_price_3_input, routing_price_3_cache_read, routing_price_3_cache_creation, routing_price_3_completion,
|
| 1737 |
+
selected_strategy,
|
| 1738 |
+
weight_base, weight_model_1, weight_model_2, weight_model_3,
|
| 1739 |
+
k_model_1, k_model_2, k_model_3,
|
| 1740 |
+
part_mode, start_1, end_1, start_2, end_2, start_3, end_3,
|
| 1741 |
token_source, thinking_overhead, use_cache,
|
| 1742 |
],
|
| 1743 |
outputs=[routing_result, routing_plots_row, routing_tokens_plot, routing_cost_plot],
|