IgorSlinko commited on
Commit
0c41621
Β·
1 Parent(s): c63e9d7

Refactor routing strategy UI

Browse files

- Replace per-model strategy with global Router Strategy
- Add three strategies: Random weights, Every k-th step, Replace part of trajectory
- Random weights: each step randomly assigned based on weights (must sum to 1.0)
- Every k-th step: first model has priority on overlaps
- Replace part: non-overlapping ranges for each model
- Simplify UI: remove containers, use individual visible components
- Fix strategy parameter visibility on strategy change

Files changed (1) hide show
  1. app.py +160 -179
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import json
2
  import os
 
3
  import re
4
  import subprocess
5
  from pathlib import Path
@@ -29,53 +30,6 @@ _calculated_tokens_cache = {}
29
  _trajectory_steps_cache = {}
30
 
31
 
32
- def parse_start_end(start: float, end: float, total_steps: int, mode: str) -> tuple[int, int]:
33
- """
34
- Parse start and end values based on mode.
35
-
36
- Args:
37
- start: start value
38
- end: end value
39
- total_steps: total number of steps in trajectory
40
- mode: "Indexes" or "Percentages"
41
-
42
- Returns: (start_idx, end_idx) - both 0-based
43
- """
44
- if mode == "Indexes":
45
- return int(start), int(end)
46
- else:
47
- return int(start * total_steps / 100), int(end * total_steps / 100)
48
-
49
-
50
- def get_routed_steps(total_steps: int, strategy: str, params: dict) -> set:
51
- """
52
- Determine which steps should be routed to alternative model.
53
-
54
- Returns set of step indices (0-based) that should use the routing model.
55
- """
56
- import random
57
-
58
- routed = set()
59
-
60
- if strategy == "Replace on random steps":
61
- pct = params.get("percentage", 50) / 100.0
62
- num_to_route = int(total_steps * pct)
63
- if num_to_route > 0:
64
- routed = set(random.sample(range(total_steps), min(num_to_route, total_steps)))
65
-
66
- elif strategy == "Replace every step k":
67
- k = int(params.get("k", 2))
68
- if k > 0:
69
- routed = set(range(0, total_steps, k))
70
-
71
- elif strategy == "Replace part of trajectory":
72
- mode = params.get("mode", "Percentages")
73
- start, end = parse_start_end(params.get("start", 0), params.get("end", 30), total_steps, mode)
74
- routed = set(range(start, min(end, total_steps)))
75
-
76
- return routed
77
-
78
-
79
  def calculate_routing_tokens(steps: list[dict]) -> dict:
80
  """
81
  Calculate token breakdown per model with proper caching simulation.
@@ -1275,12 +1229,6 @@ def build_app():
1275
  with gr.Column(visible=False) as routing_section:
1276
  gr.Markdown("### πŸ”€ Routing Models")
1277
 
1278
- STRATEGY_CHOICES = [
1279
- "Replace on random steps",
1280
- "Replace every step k",
1281
- "Replace part of trajectory",
1282
- ]
1283
-
1284
  with gr.Column():
1285
  with gr.Group():
1286
  gr.Markdown("#### Route to Model 1")
@@ -1295,26 +1243,6 @@ def build_app():
1295
  routing_price_1_cache_read = gr.Number(label="Cache Read", precision=3, scale=1)
1296
  routing_price_1_cache_creation = gr.Number(label="Cache Creation", precision=3, scale=1)
1297
  routing_price_1_completion = gr.Number(label="Completion", precision=3, scale=1)
1298
- strategy_1 = gr.Dropdown(
1299
- label="Strategy",
1300
- choices=STRATEGY_CHOICES,
1301
- value="Replace on random steps",
1302
- interactive=True,
1303
- )
1304
- with gr.Row(visible=True) as random_params_1:
1305
- random_pct_1 = gr.Number(label="Percentage (%)", value=50, minimum=0, maximum=100, precision=0, interactive=True)
1306
- with gr.Row(visible=False) as every_k_params_1:
1307
- step_k_1 = gr.Number(label="k", value=2, minimum=1, precision=0, interactive=True)
1308
- with gr.Column(visible=False) as part_params_1:
1309
- part_mode_1 = gr.Radio(
1310
- choices=["Indexes", "Percentages"],
1311
- value="Percentages",
1312
- label="Mode",
1313
- interactive=True,
1314
- )
1315
- with gr.Row():
1316
- start_step_1 = gr.Number(label="Start", value=0, minimum=0, precision=0, interactive=True)
1317
- end_step_1 = gr.Number(label="End", value=30, minimum=0, precision=0, interactive=True)
1318
 
1319
  add_model_2_btn = gr.Button("+ Add another model", size="sm", visible=False)
1320
 
@@ -1332,26 +1260,6 @@ def build_app():
1332
  routing_price_2_cache_read = gr.Number(label="Cache Read", precision=3, scale=1)
1333
  routing_price_2_cache_creation = gr.Number(label="Cache Creation", precision=3, scale=1)
1334
  routing_price_2_completion = gr.Number(label="Completion", precision=3, scale=1)
1335
- strategy_2 = gr.Dropdown(
1336
- label="Strategy",
1337
- choices=STRATEGY_CHOICES,
1338
- value="Replace on random steps",
1339
- interactive=True,
1340
- )
1341
- with gr.Row(visible=True) as random_params_2:
1342
- random_pct_2 = gr.Number(label="Percentage (%)", value=50, minimum=0, maximum=100, precision=0, interactive=True)
1343
- with gr.Row(visible=False) as every_k_params_2:
1344
- step_k_2 = gr.Number(label="k", value=2, minimum=1, precision=0, interactive=True)
1345
- with gr.Column(visible=False) as part_params_2:
1346
- part_mode_2 = gr.Radio(
1347
- choices=["Indexes", "Percentages"],
1348
- value="Percentages",
1349
- label="Mode",
1350
- interactive=True,
1351
- )
1352
- with gr.Row():
1353
- start_step_2 = gr.Number(label="Start", value=0, minimum=0, precision=0, interactive=True)
1354
- end_step_2 = gr.Number(label="End", value=30, minimum=0, precision=0, interactive=True)
1355
 
1356
  add_model_3_btn = gr.Button("+ Add another model", size="sm", visible=False)
1357
 
@@ -1369,39 +1277,48 @@ def build_app():
1369
  routing_price_3_cache_read = gr.Number(label="Cache Read", precision=3, scale=1)
1370
  routing_price_3_cache_creation = gr.Number(label="Cache Creation", precision=3, scale=1)
1371
  routing_price_3_completion = gr.Number(label="Completion", precision=3, scale=1)
1372
- strategy_3 = gr.Dropdown(
1373
- label="Strategy",
1374
- choices=STRATEGY_CHOICES,
1375
- value="Replace on random steps",
1376
- interactive=True,
1377
- )
1378
- with gr.Row(visible=True) as random_params_3:
1379
- random_pct_3 = gr.Number(label="Percentage (%)", value=50, minimum=0, maximum=100, precision=0, interactive=True)
1380
- with gr.Row(visible=False) as every_k_params_3:
1381
- step_k_3 = gr.Number(label="k", value=2, minimum=1, precision=0, interactive=True)
1382
- with gr.Column(visible=False) as part_params_3:
1383
- part_mode_3 = gr.Radio(
1384
- choices=["Indexes", "Percentages"],
1385
- value="Percentages",
1386
- label="Mode",
1387
- interactive=True,
1388
- )
1389
- with gr.Row():
1390
- start_step_3 = gr.Number(label="Start", value=0, minimum=0, precision=0, interactive=True)
1391
- end_step_3 = gr.Number(label="End", value=30, minimum=0, precision=0, interactive=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1392
 
1393
  gr.Markdown("---")
1394
  route_btn = gr.Button("πŸš€ Let's ROUTE!!", variant="primary", size="lg", interactive=False)
1395
  routing_result = gr.Markdown(visible=False)
1396
 
1397
 
1398
- def on_strategy_change(strategy):
1399
- return (
1400
- gr.update(visible=strategy == "Replace on random steps"),
1401
- gr.update(visible=strategy == "Replace every step k"),
1402
- gr.update(visible=strategy == "Replace part of trajectory"),
1403
- )
1404
-
1405
  def toggle_routing_section():
1406
  return gr.update(visible=True)
1407
 
@@ -1410,22 +1327,31 @@ def build_app():
1410
  outputs=[routing_section],
1411
  )
1412
 
1413
- strategy_1.change(
1414
- fn=on_strategy_change,
1415
- inputs=[strategy_1],
1416
- outputs=[random_params_1, every_k_params_1, part_params_1],
1417
- )
1418
-
1419
- strategy_2.change(
1420
- fn=on_strategy_change,
1421
- inputs=[strategy_2],
1422
- outputs=[random_params_2, every_k_params_2, part_params_2],
1423
- )
 
 
 
 
 
1424
 
1425
- strategy_3.change(
1426
  fn=on_strategy_change,
1427
- inputs=[strategy_3],
1428
- outputs=[random_params_3, every_k_params_3, part_params_3],
 
 
 
 
1429
  )
1430
 
1431
  def filter_models(query):
@@ -1498,9 +1424,23 @@ def build_app():
1498
  outputs=[routing_price_1_input, routing_price_1_cache_read, routing_price_1_cache_creation, routing_price_1_completion, add_model_2_btn, route_btn],
1499
  )
1500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1501
  add_model_2_btn.click(
1502
- fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
1503
- outputs=[routing_block_2, add_model_2_btn],
 
1504
  )
1505
 
1506
  routing_model_2.change(
@@ -1509,9 +1449,23 @@ def build_app():
1509
  outputs=[routing_price_2_input, routing_price_2_cache_read, routing_price_2_cache_creation, routing_price_2_completion, add_model_3_btn],
1510
  )
1511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1512
  add_model_3_btn.click(
1513
- fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
1514
- outputs=[routing_block_3, add_model_3_btn],
 
1515
  )
1516
 
1517
  routing_model_3.change(
@@ -1524,11 +1478,12 @@ def build_app():
1524
  state_data,
1525
  base_input, base_cache_read, base_cache_creation, base_completion,
1526
  routing_model_1_val, r1_input, r1_cache_read, r1_cache_creation, r1_completion,
1527
- strategy_1_val, random_pct_1_val, step_k_1_val, part_mode_1_val, start_1_val, end_1_val,
1528
  routing_model_2_val, r2_input, r2_cache_read, r2_cache_creation, r2_completion,
1529
- strategy_2_val, random_pct_2_val, step_k_2_val, part_mode_2_val, start_2_val, end_2_val,
1530
  routing_model_3_val, r3_input, r3_cache_read, r3_cache_creation, r3_completion,
1531
- strategy_3_val, random_pct_3_val, step_k_3_val, part_mode_3_val, start_3_val, end_3_val,
 
 
 
1532
  source, overhead, with_cache
1533
  ):
1534
  if state_data is None:
@@ -1556,6 +1511,7 @@ def build_app():
1556
  )
1557
  return
1558
 
 
1559
  df_calc = state_data.get("calculated")
1560
  if df_calc is not None and not df_calc.empty:
1561
  df_for_cost = apply_thinking_overhead(df_calc.copy(), overhead)
@@ -1579,50 +1535,56 @@ def build_app():
1579
  "completion": base_completion,
1580
  }
1581
 
1582
- def build_strategy_params(strategy, random_pct, step_k, part_mode, start_val, end_val):
1583
- params = {}
1584
- if strategy == "Replace on random steps":
1585
- params["percentage"] = random_pct
1586
- elif strategy == "Replace every step k":
1587
- params["k"] = step_k
1588
- elif strategy == "Replace part of trajectory":
1589
- params["mode"] = part_mode
1590
- params["start"] = start_val
1591
- params["end"] = end_val
1592
- return params
1593
-
1594
  routing_models = []
1595
  if routing_model_1_val:
1596
- if strategy_1_val == "Replace part of trajectory" and start_1_val >= end_1_val:
1597
- yield (gr.update(visible=True, value="❌ Model 1: Start must be less than End"), gr.update(visible=False), None, None)
1598
- return
1599
  routing_models.append({
1600
  "name": routing_model_1_val,
1601
  "prices": {"input": r1_input, "cache_read": r1_cache_read, "cache_creation": r1_cache_creation, "completion": r1_completion},
1602
- "strategy": strategy_1_val,
1603
- "params": build_strategy_params(strategy_1_val, random_pct_1_val, step_k_1_val, part_mode_1_val, start_1_val, end_1_val),
1604
  })
1605
  if routing_model_2_val:
1606
- if strategy_2_val == "Replace part of trajectory" and start_2_val >= end_2_val:
1607
- yield (gr.update(visible=True, value="❌ Model 2: Start must be less than End"), gr.update(visible=False), None, None)
1608
- return
1609
  routing_models.append({
1610
  "name": routing_model_2_val,
1611
  "prices": {"input": r2_input, "cache_read": r2_cache_read, "cache_creation": r2_cache_creation, "completion": r2_completion},
1612
- "strategy": strategy_2_val,
1613
- "params": build_strategy_params(strategy_2_val, random_pct_2_val, step_k_2_val, part_mode_2_val, start_2_val, end_2_val),
1614
  })
1615
  if routing_model_3_val:
1616
- if strategy_3_val == "Replace part of trajectory" and start_3_val >= end_3_val:
1617
- yield (gr.update(visible=True, value="❌ Model 3: Start must be less than End"), gr.update(visible=False), None, None)
1618
- return
1619
  routing_models.append({
1620
  "name": routing_model_3_val,
1621
  "prices": {"input": r3_input, "cache_read": r3_cache_read, "cache_creation": r3_cache_creation, "completion": r3_completion},
1622
- "strategy": strategy_3_val,
1623
- "params": build_strategy_params(strategy_3_val, random_pct_3_val, step_k_3_val, part_mode_3_val, start_3_val, end_3_val),
1624
  })
1625
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1626
  BASE_MODEL = "__base__"
1627
  model_keys = [BASE_MODEL] + [f"__routing_{i}__" for i in range(len(routing_models))]
1628
 
@@ -1635,17 +1597,35 @@ def build_app():
1635
 
1636
  total_steps = len(steps)
1637
 
1638
- routed_sets = []
1639
- for rm in routing_models:
1640
- routed_sets.append(get_routed_steps(total_steps, rm["strategy"], rm["params"]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1641
 
1642
  modified_steps = []
1643
  for i, step in enumerate(steps):
1644
- model = BASE_MODEL
1645
- for j, routed_set in enumerate(routed_sets):
1646
- if i in routed_set:
1647
- model = f"__routing_{j}__"
1648
- break
1649
  modified_steps.append({
1650
  "model": model,
1651
  "system_user": step.get("system_user", 0),
@@ -1752,11 +1732,12 @@ def build_app():
1752
  trajectories_state,
1753
  price_input, price_cache_read, price_cache_creation, price_completion,
1754
  routing_model_1, routing_price_1_input, routing_price_1_cache_read, routing_price_1_cache_creation, routing_price_1_completion,
1755
- strategy_1, random_pct_1, step_k_1, part_mode_1, start_step_1, end_step_1,
1756
  routing_model_2, routing_price_2_input, routing_price_2_cache_read, routing_price_2_cache_creation, routing_price_2_completion,
1757
- strategy_2, random_pct_2, step_k_2, part_mode_2, start_step_2, end_step_2,
1758
  routing_model_3, routing_price_3_input, routing_price_3_cache_read, routing_price_3_cache_creation, routing_price_3_completion,
1759
- strategy_3, random_pct_3, step_k_3, part_mode_3, start_step_3, end_step_3,
 
 
 
1760
  token_source, thinking_overhead, use_cache,
1761
  ],
1762
  outputs=[routing_result, routing_plots_row, routing_tokens_plot, routing_cost_plot],
 
1
  import json
2
  import os
3
+ import random
4
  import re
5
  import subprocess
6
  from pathlib import Path
 
30
  _trajectory_steps_cache = {}
31
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def calculate_routing_tokens(steps: list[dict]) -> dict:
34
  """
35
  Calculate token breakdown per model with proper caching simulation.
 
1229
  with gr.Column(visible=False) as routing_section:
1230
  gr.Markdown("### πŸ”€ Routing Models")
1231
 
 
 
 
 
 
 
1232
  with gr.Column():
1233
  with gr.Group():
1234
  gr.Markdown("#### Route to Model 1")
 
1243
  routing_price_1_cache_read = gr.Number(label="Cache Read", precision=3, scale=1)
1244
  routing_price_1_cache_creation = gr.Number(label="Cache Creation", precision=3, scale=1)
1245
  routing_price_1_completion = gr.Number(label="Completion", precision=3, scale=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1246
 
1247
  add_model_2_btn = gr.Button("+ Add another model", size="sm", visible=False)
1248
 
 
1260
  routing_price_2_cache_read = gr.Number(label="Cache Read", precision=3, scale=1)
1261
  routing_price_2_cache_creation = gr.Number(label="Cache Creation", precision=3, scale=1)
1262
  routing_price_2_completion = gr.Number(label="Completion", precision=3, scale=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1263
 
1264
  add_model_3_btn = gr.Button("+ Add another model", size="sm", visible=False)
1265
 
 
1277
  routing_price_3_cache_read = gr.Number(label="Cache Read", precision=3, scale=1)
1278
  routing_price_3_cache_creation = gr.Number(label="Cache Creation", precision=3, scale=1)
1279
  routing_price_3_completion = gr.Number(label="Completion", precision=3, scale=1)
1280
+
1281
+ gr.Markdown("---")
1282
+ gr.Markdown("### 🎯 Router Strategy")
1283
+
1284
+ selected_strategy = gr.Radio(
1285
+ choices=["Random weights", "Every k-th step", "Replace part of trajectory"],
1286
+ value="Random weights",
1287
+ label="Strategy",
1288
+ interactive=True,
1289
+ )
1290
+
1291
+ random_hint = gr.Markdown("*Weights must sum to 1.0*", visible=True)
1292
+ weight_base = gr.Number(label="Base weight", value=0.5, minimum=0, maximum=1, precision=2, interactive=True, visible=True)
1293
+ weight_model_1 = gr.Number(label="Model 1 weight", value=0.5, minimum=0, maximum=1, precision=2, interactive=True, visible=True)
1294
+ weight_model_2 = gr.Number(label="Model 2 weight", value=0, minimum=0, maximum=1, precision=2, interactive=True, visible=False)
1295
+ weight_model_3 = gr.Number(label="Model 3 weight", value=0, minimum=0, maximum=1, precision=2, interactive=True, visible=False)
1296
+
1297
+ every_k_hint = gr.Markdown("*First model has priority on overlaps*", visible=False)
1298
+ k_model_1 = gr.Number(label="k₁ (Model 1)", value=2, minimum=1, precision=0, interactive=True, visible=False)
1299
+ k_model_2 = gr.Number(label="kβ‚‚ (Model 2)", value=3, minimum=1, precision=0, interactive=True, visible=False)
1300
+ k_model_3 = gr.Number(label="k₃ (Model 3)", value=5, minimum=1, precision=0, interactive=True, visible=False)
1301
+
1302
+ part_hint = gr.Markdown("*Ranges must not overlap*", visible=False)
1303
+ part_mode = gr.Radio(
1304
+ choices=["Indexes", "Percentages"],
1305
+ value="Percentages",
1306
+ label="Mode",
1307
+ interactive=True,
1308
+ visible=False,
1309
+ )
1310
+ start_1 = gr.Number(label="M1 Start", value=0, minimum=0, precision=0, interactive=True, visible=False)
1311
+ end_1 = gr.Number(label="M1 End", value=30, minimum=0, precision=0, interactive=True, visible=False)
1312
+ start_2 = gr.Number(label="M2 Start", value=30, minimum=0, precision=0, interactive=True, visible=False)
1313
+ end_2 = gr.Number(label="M2 End", value=60, minimum=0, precision=0, interactive=True, visible=False)
1314
+ start_3 = gr.Number(label="M3 Start", value=60, minimum=0, precision=0, interactive=True, visible=False)
1315
+ end_3 = gr.Number(label="M3 End", value=100, minimum=0, precision=0, interactive=True, visible=False)
1316
 
1317
  gr.Markdown("---")
1318
  route_btn = gr.Button("πŸš€ Let's ROUTE!!", variant="primary", size="lg", interactive=False)
1319
  routing_result = gr.Markdown(visible=False)
1320
 
1321
 
 
 
 
 
 
 
 
1322
  def toggle_routing_section():
1323
  return gr.update(visible=True)
1324
 
 
1327
  outputs=[routing_section],
1328
  )
1329
 
1330
+ def on_strategy_change(strategy):
1331
+ is_random = strategy == "Random weights"
1332
+ is_every_k = strategy == "Every k-th step"
1333
+ is_part = strategy == "Replace part of trajectory"
1334
+ print(f"DEBUG on_strategy_change: strategy={strategy}")
1335
+ return (
1336
+ gr.update(visible=is_random),
1337
+ gr.update(visible=is_random),
1338
+ gr.update(visible=is_random),
1339
+ gr.update(visible=is_every_k),
1340
+ gr.update(visible=is_every_k),
1341
+ gr.update(visible=is_part),
1342
+ gr.update(visible=is_part),
1343
+ gr.update(visible=is_part),
1344
+ gr.update(visible=is_part),
1345
+ )
1346
 
1347
+ selected_strategy.change(
1348
  fn=on_strategy_change,
1349
+ inputs=[selected_strategy],
1350
+ outputs=[
1351
+ random_hint, weight_base, weight_model_1,
1352
+ every_k_hint, k_model_1,
1353
+ part_hint, part_mode, start_1, end_1,
1354
+ ],
1355
  )
1356
 
1357
  def filter_models(query):
 
1424
  outputs=[routing_price_1_input, routing_price_1_cache_read, routing_price_1_cache_creation, routing_price_1_completion, add_model_2_btn, route_btn],
1425
  )
1426
 
1427
+ def show_model_2(strategy):
1428
+ is_random = strategy == "Random weights"
1429
+ is_every_k = strategy == "Every k-th step"
1430
+ is_part = strategy == "Replace part of trajectory"
1431
+ return (
1432
+ gr.update(visible=True),
1433
+ gr.update(visible=False),
1434
+ gr.update(visible=is_random),
1435
+ gr.update(visible=is_every_k),
1436
+ gr.update(visible=is_part),
1437
+ gr.update(visible=is_part),
1438
+ )
1439
+
1440
  add_model_2_btn.click(
1441
+ fn=show_model_2,
1442
+ inputs=[selected_strategy],
1443
+ outputs=[routing_block_2, add_model_2_btn, weight_model_2, k_model_2, start_2, end_2],
1444
  )
1445
 
1446
  routing_model_2.change(
 
1449
  outputs=[routing_price_2_input, routing_price_2_cache_read, routing_price_2_cache_creation, routing_price_2_completion, add_model_3_btn],
1450
  )
1451
 
1452
+ def show_model_3(strategy):
1453
+ is_random = strategy == "Random weights"
1454
+ is_every_k = strategy == "Every k-th step"
1455
+ is_part = strategy == "Replace part of trajectory"
1456
+ return (
1457
+ gr.update(visible=True),
1458
+ gr.update(visible=False),
1459
+ gr.update(visible=is_random),
1460
+ gr.update(visible=is_every_k),
1461
+ gr.update(visible=is_part),
1462
+ gr.update(visible=is_part),
1463
+ )
1464
+
1465
  add_model_3_btn.click(
1466
+ fn=show_model_3,
1467
+ inputs=[selected_strategy],
1468
+ outputs=[routing_block_3, add_model_3_btn, weight_model_3, k_model_3, start_3, end_3],
1469
  )
1470
 
1471
  routing_model_3.change(
 
1478
  state_data,
1479
  base_input, base_cache_read, base_cache_creation, base_completion,
1480
  routing_model_1_val, r1_input, r1_cache_read, r1_cache_creation, r1_completion,
 
1481
  routing_model_2_val, r2_input, r2_cache_read, r2_cache_creation, r2_completion,
 
1482
  routing_model_3_val, r3_input, r3_cache_read, r3_cache_creation, r3_completion,
1483
+ strategy_val,
1484
+ weight_base_val, weight_1_val, weight_2_val, weight_3_val,
1485
+ k_1_val, k_2_val, k_3_val,
1486
+ part_mode_val, start_1_val, end_1_val, start_2_val, end_2_val, start_3_val, end_3_val,
1487
  source, overhead, with_cache
1488
  ):
1489
  if state_data is None:
 
1511
  )
1512
  return
1513
 
1514
+
1515
  df_calc = state_data.get("calculated")
1516
  if df_calc is not None and not df_calc.empty:
1517
  df_for_cost = apply_thinking_overhead(df_calc.copy(), overhead)
 
1535
  "completion": base_completion,
1536
  }
1537
 
 
 
 
 
 
 
 
 
 
 
 
 
1538
  routing_models = []
1539
  if routing_model_1_val:
 
 
 
1540
  routing_models.append({
1541
  "name": routing_model_1_val,
1542
  "prices": {"input": r1_input, "cache_read": r1_cache_read, "cache_creation": r1_cache_creation, "completion": r1_completion},
 
 
1543
  })
1544
  if routing_model_2_val:
 
 
 
1545
  routing_models.append({
1546
  "name": routing_model_2_val,
1547
  "prices": {"input": r2_input, "cache_read": r2_cache_read, "cache_creation": r2_cache_creation, "completion": r2_completion},
 
 
1548
  })
1549
  if routing_model_3_val:
 
 
 
1550
  routing_models.append({
1551
  "name": routing_model_3_val,
1552
  "prices": {"input": r3_input, "cache_read": r3_cache_read, "cache_creation": r3_cache_creation, "completion": r3_completion},
 
 
1553
  })
1554
 
1555
+ if strategy_val == "Replace part of trajectory":
1556
+ ranges = [(start_1_val, end_1_val)]
1557
+ if len(routing_models) > 1:
1558
+ ranges.append((start_2_val, end_2_val))
1559
+ if len(routing_models) > 2:
1560
+ ranges.append((start_3_val, end_3_val))
1561
+ for i, (s, e) in enumerate(ranges):
1562
+ if s >= e:
1563
+ yield (gr.update(visible=True, value=f"❌ Model {i+1}: Start must be less than End"), gr.update(visible=False), None, None)
1564
+ return
1565
+ for i in range(len(ranges)):
1566
+ for j in range(i+1, len(ranges)):
1567
+ s1, e1 = ranges[i]
1568
+ s2, e2 = ranges[j]
1569
+ if not (e1 <= s2 or e2 <= s1):
1570
+ yield (gr.update(visible=True, value=f"❌ Model {i+1} and Model {j+1} ranges overlap"), gr.update(visible=False), None, None)
1571
+ return
1572
+
1573
+ weights = None
1574
+ if strategy_val == "Random weights":
1575
+ weights = [weight_base_val, weight_1_val]
1576
+ if len(routing_models) > 1:
1577
+ weights.append(weight_2_val)
1578
+ if len(routing_models) > 2:
1579
+ weights.append(weight_3_val)
1580
+ total_weight = sum(weights)
1581
+ if abs(total_weight - 1.0) > 0.01:
1582
+ yield (gr.update(visible=True, value=f"❌ Weights must sum to 1.0 (current: {total_weight:.2f})"), gr.update(visible=False), None, None)
1583
+ return
1584
+
1585
+ k_values = [k_1_val, k_2_val, k_3_val][:len(routing_models)]
1586
+ part_ranges = [(start_1_val, end_1_val), (start_2_val, end_2_val), (start_3_val, end_3_val)][:len(routing_models)]
1587
+
1588
  BASE_MODEL = "__base__"
1589
  model_keys = [BASE_MODEL] + [f"__routing_{i}__" for i in range(len(routing_models))]
1590
 
 
1597
 
1598
  total_steps = len(steps)
1599
 
1600
+ step_to_model = {}
1601
+
1602
+ if strategy_val == "Random weights":
1603
+ model_choices = [BASE_MODEL] + [f"__routing_{j}__" for j in range(len(routing_models))]
1604
+ for i in range(total_steps):
1605
+ step_to_model[i] = random.choices(model_choices, weights=weights)[0]
1606
+
1607
+ elif strategy_val == "Every k-th step":
1608
+ for j, k_val in enumerate(k_values):
1609
+ if k_val and k_val > 0:
1610
+ for i in range(total_steps):
1611
+ if (i + 1) % int(k_val) == 0:
1612
+ if i not in step_to_model:
1613
+ step_to_model[i] = f"__routing_{j}__"
1614
+
1615
+ elif strategy_val == "Replace part of trajectory":
1616
+ for j, (start_val, end_val) in enumerate(part_ranges):
1617
+ if part_mode_val == "Percentages":
1618
+ start_idx = int(total_steps * start_val / 100)
1619
+ end_idx = int(total_steps * end_val / 100)
1620
+ else:
1621
+ start_idx = int(start_val)
1622
+ end_idx = min(int(end_val), total_steps)
1623
+ for i in range(start_idx, end_idx):
1624
+ step_to_model[i] = f"__routing_{j}__"
1625
 
1626
  modified_steps = []
1627
  for i, step in enumerate(steps):
1628
+ model = step_to_model.get(i, BASE_MODEL)
 
 
 
 
1629
  modified_steps.append({
1630
  "model": model,
1631
  "system_user": step.get("system_user", 0),
 
1732
  trajectories_state,
1733
  price_input, price_cache_read, price_cache_creation, price_completion,
1734
  routing_model_1, routing_price_1_input, routing_price_1_cache_read, routing_price_1_cache_creation, routing_price_1_completion,
 
1735
  routing_model_2, routing_price_2_input, routing_price_2_cache_read, routing_price_2_cache_creation, routing_price_2_completion,
 
1736
  routing_model_3, routing_price_3_input, routing_price_3_cache_read, routing_price_3_cache_creation, routing_price_3_completion,
1737
+ selected_strategy,
1738
+ weight_base, weight_model_1, weight_model_2, weight_model_3,
1739
+ k_model_1, k_model_2, k_model_3,
1740
+ part_mode, start_1, end_1, start_2, end_2, start_3, end_3,
1741
  token_source, thinking_overhead, use_cache,
1742
  ],
1743
  outputs=[routing_result, routing_plots_row, routing_tokens_plot, routing_cost_plot],