IgorSlinko committed on
Commit
c63e9d7
·
1 Parent(s): 1ae03ef

Support multiple routing models (up to 3)

Browse files

- Add interactive=True to all part_mode Radio buttons
- Rewrite run_routing to support all 3 models
- Each model can have different strategy and parameters
- Results table shows cost breakdown per model
- Charts support any number of models with different colors
- Validate Start < End for each model separately

Files changed (1) hide show
  1. app.py +103 -64
app.py CHANGED
@@ -1310,6 +1310,7 @@ def build_app():
1310
  choices=["Indexes", "Percentages"],
1311
  value="Percentages",
1312
  label="Mode",
 
1313
  )
1314
  with gr.Row():
1315
  start_step_1 = gr.Number(label="Start", value=0, minimum=0, precision=0, interactive=True)
@@ -1346,6 +1347,7 @@ def build_app():
1346
  choices=["Indexes", "Percentages"],
1347
  value="Percentages",
1348
  label="Mode",
 
1349
  )
1350
  with gr.Row():
1351
  start_step_2 = gr.Number(label="Start", value=0, minimum=0, precision=0, interactive=True)
@@ -1382,6 +1384,7 @@ def build_app():
1382
  choices=["Indexes", "Percentages"],
1383
  value="Percentages",
1384
  label="Mode",
 
1385
  )
1386
  with gr.Row():
1387
  start_step_3 = gr.Number(label="Start", value=0, minimum=0, precision=0, interactive=True)
@@ -1522,6 +1525,10 @@ def build_app():
1522
  base_input, base_cache_read, base_cache_creation, base_completion,
1523
  routing_model_1_val, r1_input, r1_cache_read, r1_cache_creation, r1_completion,
1524
  strategy_1_val, random_pct_1_val, step_k_1_val, part_mode_1_val, start_1_val, end_1_val,
 
 
 
 
1525
  source, overhead, with_cache
1526
  ):
1527
  if state_data is None:
@@ -1571,47 +1578,74 @@ def build_app():
1571
  "cache_creation": base_cache_creation,
1572
  "completion": base_completion,
1573
  }
1574
- routing_prices = {
1575
- "input": r1_input,
1576
- "cache_read": r1_cache_read,
1577
- "cache_creation": r1_cache_creation,
1578
- "completion": r1_completion,
1579
- }
1580
 
1581
- strategy_params = {}
1582
- if strategy_1_val == "Replace on random steps":
1583
- strategy_params["percentage"] = random_pct_1_val
1584
- elif strategy_1_val == "Replace every step k":
1585
- strategy_params["k"] = step_k_1_val
1586
- elif strategy_1_val == "Replace part of trajectory":
1587
- strategy_params["mode"] = part_mode_1_val
1588
- strategy_params["start"] = start_1_val
1589
- strategy_params["end"] = end_1_val
1590
- if start_1_val >= end_1_val:
1591
- yield (
1592
- gr.update(visible=True, value="❌ Start must be less than End"),
1593
- gr.update(visible=False),
1594
- None, None,
1595
- )
 
1596
  return
1597
-
1598
- total_base_tokens = {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0}
1599
- total_routing_tokens = {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0}
1600
- total_original_tokens = {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1601
 
1602
  BASE_MODEL = "__base__"
1603
- ROUTING_MODEL = "__routing__"
 
 
 
1604
 
1605
  for instance_id, steps in trajectory_steps.items():
1606
  if not steps:
1607
  continue
1608
 
1609
  total_steps = len(steps)
1610
- routed_step_indices = get_routed_steps(total_steps, strategy_1_val, strategy_params)
 
 
 
1611
 
1612
  modified_steps = []
1613
  for i, step in enumerate(steps):
1614
- model = ROUTING_MODEL if i in routed_step_indices else BASE_MODEL
 
 
 
 
1615
  modified_steps.append({
1616
  "model": model,
1617
  "system_user": step.get("system_user", 0),
@@ -1621,22 +1655,12 @@ def build_app():
1621
 
1622
  model_totals = calculate_routing_tokens(modified_steps)
1623
 
1624
- base_totals = model_totals.get(BASE_MODEL, {
1625
- "cache_read": 0, "uncached_input": 0, "completion": 0, "cache_creation": 0
1626
- })
1627
- routing_totals = model_totals.get(ROUTING_MODEL, {
1628
- "cache_read": 0, "uncached_input": 0, "completion": 0, "cache_creation": 0
1629
- })
1630
-
1631
- total_base_tokens["cache_read"] += base_totals.get("cache_read", 0)
1632
- total_base_tokens["uncached_input"] += base_totals.get("uncached_input", 0)
1633
- total_base_tokens["completion"] += base_totals.get("completion", 0)
1634
- total_base_tokens["cache_creation"] += base_totals.get("cache_creation", 0)
1635
-
1636
- total_routing_tokens["cache_read"] += routing_totals.get("cache_read", 0)
1637
- total_routing_tokens["uncached_input"] += routing_totals.get("uncached_input", 0)
1638
- total_routing_tokens["completion"] += routing_totals.get("completion", 0)
1639
- total_routing_tokens["cache_creation"] += routing_totals.get("cache_creation", 0)
1640
 
1641
  original_steps = []
1642
  for step in steps:
@@ -1661,11 +1685,23 @@ def build_app():
1661
  tokens["completion"] * prices["completion"] / 1e6
1662
  )
1663
 
1664
- base_costs = {k: total_base_tokens[k] * base_prices[{"uncached_input": "input", "cache_read": "cache_read", "cache_creation": "cache_creation", "completion": "completion"}[k]] / 1e6 for k in total_base_tokens}
1665
- routing_costs = {k: total_routing_tokens[k] * routing_prices[{"uncached_input": "input", "cache_read": "cache_read", "cache_creation": "cache_creation", "completion": "completion"}[k]] / 1e6 for k in total_routing_tokens}
 
1666
 
 
 
1667
  total_base_cost = calc_cost(total_base_tokens, base_prices)
1668
- total_routing_cost = calc_cost(total_routing_tokens, routing_prices)
 
 
 
 
 
 
 
 
 
1669
 
1670
  if total_original_cost_from_df is not None:
1671
  total_original_cost = total_original_cost_from_df
@@ -1676,23 +1712,22 @@ def build_app():
1676
  savings = total_original_cost - total_routed_cost
1677
  savings_pct = (savings / total_original_cost * 100) if total_original_cost > 0 else 0
1678
 
1679
- result_text = f"""
1680
- ## 🚀 Routing Results
1681
-
1682
- | Metric | Value |
1683
- |--------|-------|
1684
- | **Original Cost (base model only)** | ${total_original_cost:.2f} |
1685
- | **Routed Cost** | ${total_routed_cost:.2f} |
1686
- | ↳ Base model portion | ${total_base_cost:.2f} |
1687
- | ↳ Routing model portion | ${total_routing_cost:.2f} |
1688
- | **Savings** | ${savings:.2f} ({savings_pct:+.1f}%) |
1689
-
1690
- *Strategy: {strategy_1_val}*
1691
- *Routing model: {routing_model_1_val}*
1692
- """
1693
-
1694
- additional_token_models = [(routing_model_1_val, total_routing_tokens)]
1695
- additional_cost_models = [(routing_model_1_val, routing_costs)]
1696
 
1697
  yield (
1698
  gr.update(visible=True, value="⏳ Creating charts..."),
@@ -1718,6 +1753,10 @@ def build_app():
1718
  price_input, price_cache_read, price_cache_creation, price_completion,
1719
  routing_model_1, routing_price_1_input, routing_price_1_cache_read, routing_price_1_cache_creation, routing_price_1_completion,
1720
  strategy_1, random_pct_1, step_k_1, part_mode_1, start_step_1, end_step_1,
 
 
 
 
1721
  token_source, thinking_overhead, use_cache,
1722
  ],
1723
  outputs=[routing_result, routing_plots_row, routing_tokens_plot, routing_cost_plot],
 
1310
  choices=["Indexes", "Percentages"],
1311
  value="Percentages",
1312
  label="Mode",
1313
+ interactive=True,
1314
  )
1315
  with gr.Row():
1316
  start_step_1 = gr.Number(label="Start", value=0, minimum=0, precision=0, interactive=True)
 
1347
  choices=["Indexes", "Percentages"],
1348
  value="Percentages",
1349
  label="Mode",
1350
+ interactive=True,
1351
  )
1352
  with gr.Row():
1353
  start_step_2 = gr.Number(label="Start", value=0, minimum=0, precision=0, interactive=True)
 
1384
  choices=["Indexes", "Percentages"],
1385
  value="Percentages",
1386
  label="Mode",
1387
+ interactive=True,
1388
  )
1389
  with gr.Row():
1390
  start_step_3 = gr.Number(label="Start", value=0, minimum=0, precision=0, interactive=True)
 
1525
  base_input, base_cache_read, base_cache_creation, base_completion,
1526
  routing_model_1_val, r1_input, r1_cache_read, r1_cache_creation, r1_completion,
1527
  strategy_1_val, random_pct_1_val, step_k_1_val, part_mode_1_val, start_1_val, end_1_val,
1528
+ routing_model_2_val, r2_input, r2_cache_read, r2_cache_creation, r2_completion,
1529
+ strategy_2_val, random_pct_2_val, step_k_2_val, part_mode_2_val, start_2_val, end_2_val,
1530
+ routing_model_3_val, r3_input, r3_cache_read, r3_cache_creation, r3_completion,
1531
+ strategy_3_val, random_pct_3_val, step_k_3_val, part_mode_3_val, start_3_val, end_3_val,
1532
  source, overhead, with_cache
1533
  ):
1534
  if state_data is None:
 
1578
  "cache_creation": base_cache_creation,
1579
  "completion": base_completion,
1580
  }
 
 
 
 
 
 
1581
 
1582
+ def build_strategy_params(strategy, random_pct, step_k, part_mode, start_val, end_val):
1583
+ params = {}
1584
+ if strategy == "Replace on random steps":
1585
+ params["percentage"] = random_pct
1586
+ elif strategy == "Replace every step k":
1587
+ params["k"] = step_k
1588
+ elif strategy == "Replace part of trajectory":
1589
+ params["mode"] = part_mode
1590
+ params["start"] = start_val
1591
+ params["end"] = end_val
1592
+ return params
1593
+
1594
+ routing_models = []
1595
+ if routing_model_1_val:
1596
+ if strategy_1_val == "Replace part of trajectory" and start_1_val >= end_1_val:
1597
+ yield (gr.update(visible=True, value="❌ Model 1: Start must be less than End"), gr.update(visible=False), None, None)
1598
  return
1599
+ routing_models.append({
1600
+ "name": routing_model_1_val,
1601
+ "prices": {"input": r1_input, "cache_read": r1_cache_read, "cache_creation": r1_cache_creation, "completion": r1_completion},
1602
+ "strategy": strategy_1_val,
1603
+ "params": build_strategy_params(strategy_1_val, random_pct_1_val, step_k_1_val, part_mode_1_val, start_1_val, end_1_val),
1604
+ })
1605
+ if routing_model_2_val:
1606
+ if strategy_2_val == "Replace part of trajectory" and start_2_val >= end_2_val:
1607
+ yield (gr.update(visible=True, value="❌ Model 2: Start must be less than End"), gr.update(visible=False), None, None)
1608
+ return
1609
+ routing_models.append({
1610
+ "name": routing_model_2_val,
1611
+ "prices": {"input": r2_input, "cache_read": r2_cache_read, "cache_creation": r2_cache_creation, "completion": r2_completion},
1612
+ "strategy": strategy_2_val,
1613
+ "params": build_strategy_params(strategy_2_val, random_pct_2_val, step_k_2_val, part_mode_2_val, start_2_val, end_2_val),
1614
+ })
1615
+ if routing_model_3_val:
1616
+ if strategy_3_val == "Replace part of trajectory" and start_3_val >= end_3_val:
1617
+ yield (gr.update(visible=True, value="❌ Model 3: Start must be less than End"), gr.update(visible=False), None, None)
1618
+ return
1619
+ routing_models.append({
1620
+ "name": routing_model_3_val,
1621
+ "prices": {"input": r3_input, "cache_read": r3_cache_read, "cache_creation": r3_cache_creation, "completion": r3_completion},
1622
+ "strategy": strategy_3_val,
1623
+ "params": build_strategy_params(strategy_3_val, random_pct_3_val, step_k_3_val, part_mode_3_val, start_3_val, end_3_val),
1624
+ })
1625
 
1626
  BASE_MODEL = "__base__"
1627
+ model_keys = [BASE_MODEL] + [f"__routing_{i}__" for i in range(len(routing_models))]
1628
+
1629
+ all_tokens = {key: {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0} for key in model_keys}
1630
+ total_original_tokens = {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0}
1631
 
1632
  for instance_id, steps in trajectory_steps.items():
1633
  if not steps:
1634
  continue
1635
 
1636
  total_steps = len(steps)
1637
+
1638
+ routed_sets = []
1639
+ for rm in routing_models:
1640
+ routed_sets.append(get_routed_steps(total_steps, rm["strategy"], rm["params"]))
1641
 
1642
  modified_steps = []
1643
  for i, step in enumerate(steps):
1644
+ model = BASE_MODEL
1645
+ for j, routed_set in enumerate(routed_sets):
1646
+ if i in routed_set:
1647
+ model = f"__routing_{j}__"
1648
+ break
1649
  modified_steps.append({
1650
  "model": model,
1651
  "system_user": step.get("system_user", 0),
 
1655
 
1656
  model_totals = calculate_routing_tokens(modified_steps)
1657
 
1658
+ for key in model_keys:
1659
+ totals = model_totals.get(key, {})
1660
+ all_tokens[key]["cache_read"] += totals.get("cache_read", 0)
1661
+ all_tokens[key]["uncached_input"] += totals.get("uncached_input", 0)
1662
+ all_tokens[key]["completion"] += totals.get("completion", 0)
1663
+ all_tokens[key]["cache_creation"] += totals.get("cache_creation", 0)
 
 
 
 
 
 
 
 
 
 
1664
 
1665
  original_steps = []
1666
  for step in steps:
 
1685
  tokens["completion"] * prices["completion"] / 1e6
1686
  )
1687
 
1688
+ def tokens_to_costs(tokens: dict, prices: dict) -> dict:
1689
+ price_map = {"uncached_input": "input", "cache_read": "cache_read", "cache_creation": "cache_creation", "completion": "completion"}
1690
+ return {k: tokens[k] * prices[price_map[k]] / 1e6 for k in tokens}
1691
 
1692
+ total_base_tokens = all_tokens[BASE_MODEL]
1693
+ base_costs = tokens_to_costs(total_base_tokens, base_prices)
1694
  total_base_cost = calc_cost(total_base_tokens, base_prices)
1695
+
1696
+ routing_costs_list = []
1697
+ total_routing_cost = 0
1698
+ for i, rm in enumerate(routing_models):
1699
+ key = f"__routing_{i}__"
1700
+ tokens = all_tokens[key]
1701
+ costs = tokens_to_costs(tokens, rm["prices"])
1702
+ cost = calc_cost(tokens, rm["prices"])
1703
+ routing_costs_list.append({"name": rm["name"], "tokens": tokens, "costs": costs, "cost": cost})
1704
+ total_routing_cost += cost
1705
 
1706
  if total_original_cost_from_df is not None:
1707
  total_original_cost = total_original_cost_from_df
 
1712
  savings = total_original_cost - total_routed_cost
1713
  savings_pct = (savings / total_original_cost * 100) if total_original_cost > 0 else 0
1714
 
1715
+ result_lines = [
1716
+ "## 🚀 Routing Results",
1717
+ "",
1718
+ "| Metric | Value |",
1719
+ "|--------|-------|",
1720
+ f"| **Original Cost (base model only)** | ${total_original_cost:.2f} |",
1721
+ f"| **Routed Cost** | ${total_routed_cost:.2f} |",
1722
+ f"| ↳ Base model portion | ${total_base_cost:.2f} |",
1723
+ ]
1724
+ for rc in routing_costs_list:
1725
+ result_lines.append(f"| ↳ {rc['name']} | ${rc['cost']:.2f} |")
1726
+ result_lines.append(f"| **Savings** | ${savings:.2f} ({savings_pct:+.1f}%) |")
1727
+ result_text = "\n".join(result_lines)
1728
+
1729
+ additional_token_models = [(rc["name"], rc["tokens"]) for rc in routing_costs_list]
1730
+ additional_cost_models = [(rc["name"], rc["costs"]) for rc in routing_costs_list]
 
1731
 
1732
  yield (
1733
  gr.update(visible=True, value="⏳ Creating charts..."),
 
1753
  price_input, price_cache_read, price_cache_creation, price_completion,
1754
  routing_model_1, routing_price_1_input, routing_price_1_cache_read, routing_price_1_cache_creation, routing_price_1_completion,
1755
  strategy_1, random_pct_1, step_k_1, part_mode_1, start_step_1, end_step_1,
1756
+ routing_model_2, routing_price_2_input, routing_price_2_cache_read, routing_price_2_cache_creation, routing_price_2_completion,
1757
+ strategy_2, random_pct_2, step_k_2, part_mode_2, start_step_2, end_step_2,
1758
+ routing_model_3, routing_price_3_input, routing_price_3_cache_read, routing_price_3_cache_creation, routing_price_3_completion,
1759
+ strategy_3, random_pct_3, step_k_3, part_mode_3, start_step_3, end_step_3,
1760
  token_source, thinking_overhead, use_cache,
1761
  ],
1762
  outputs=[routing_result, routing_plots_row, routing_tokens_plot, routing_cost_plot],