Commit
·
c63e9d7
1
Parent(s):
1ae03ef
Support multiple routing models (up to 3)
Browse files
- Add interactive=True to all part_mode Radio buttons
- Rewrite run_routing to support all 3 models
- Each model can have different strategy and parameters
- Results table shows cost breakdown per model
- Charts support any number of models with different colors
- Validate Start < End for each model separately
app.py
CHANGED
|
@@ -1310,6 +1310,7 @@ def build_app():
|
|
| 1310 |
choices=["Indexes", "Percentages"],
|
| 1311 |
value="Percentages",
|
| 1312 |
label="Mode",
|
|
|
|
| 1313 |
)
|
| 1314 |
with gr.Row():
|
| 1315 |
start_step_1 = gr.Number(label="Start", value=0, minimum=0, precision=0, interactive=True)
|
|
@@ -1346,6 +1347,7 @@ def build_app():
|
|
| 1346 |
choices=["Indexes", "Percentages"],
|
| 1347 |
value="Percentages",
|
| 1348 |
label="Mode",
|
|
|
|
| 1349 |
)
|
| 1350 |
with gr.Row():
|
| 1351 |
start_step_2 = gr.Number(label="Start", value=0, minimum=0, precision=0, interactive=True)
|
|
@@ -1382,6 +1384,7 @@ def build_app():
|
|
| 1382 |
choices=["Indexes", "Percentages"],
|
| 1383 |
value="Percentages",
|
| 1384 |
label="Mode",
|
|
|
|
| 1385 |
)
|
| 1386 |
with gr.Row():
|
| 1387 |
start_step_3 = gr.Number(label="Start", value=0, minimum=0, precision=0, interactive=True)
|
|
@@ -1522,6 +1525,10 @@ def build_app():
|
|
| 1522 |
base_input, base_cache_read, base_cache_creation, base_completion,
|
| 1523 |
routing_model_1_val, r1_input, r1_cache_read, r1_cache_creation, r1_completion,
|
| 1524 |
strategy_1_val, random_pct_1_val, step_k_1_val, part_mode_1_val, start_1_val, end_1_val,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1525 |
source, overhead, with_cache
|
| 1526 |
):
|
| 1527 |
if state_data is None:
|
|
@@ -1571,47 +1578,74 @@ def build_app():
|
|
| 1571 |
"cache_creation": base_cache_creation,
|
| 1572 |
"completion": base_completion,
|
| 1573 |
}
|
| 1574 |
-
routing_prices = {
|
| 1575 |
-
"input": r1_input,
|
| 1576 |
-
"cache_read": r1_cache_read,
|
| 1577 |
-
"cache_creation": r1_cache_creation,
|
| 1578 |
-
"completion": r1_completion,
|
| 1579 |
-
}
|
| 1580 |
|
| 1581 |
-
|
| 1582 |
-
|
| 1583 |
-
|
| 1584 |
-
|
| 1585 |
-
|
| 1586 |
-
|
| 1587 |
-
|
| 1588 |
-
|
| 1589 |
-
|
| 1590 |
-
|
| 1591 |
-
|
| 1592 |
-
|
| 1593 |
-
|
| 1594 |
-
|
| 1595 |
-
|
|
|
|
| 1596 |
return
|
| 1597 |
-
|
| 1598 |
-
|
| 1599 |
-
|
| 1600 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1601 |
|
| 1602 |
BASE_MODEL = "__base__"
|
| 1603 |
-
|
|
|
|
|
|
|
|
|
|
| 1604 |
|
| 1605 |
for instance_id, steps in trajectory_steps.items():
|
| 1606 |
if not steps:
|
| 1607 |
continue
|
| 1608 |
|
| 1609 |
total_steps = len(steps)
|
| 1610 |
-
|
|
|
|
|
|
|
|
|
|
| 1611 |
|
| 1612 |
modified_steps = []
|
| 1613 |
for i, step in enumerate(steps):
|
| 1614 |
-
model =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1615 |
modified_steps.append({
|
| 1616 |
"model": model,
|
| 1617 |
"system_user": step.get("system_user", 0),
|
|
@@ -1621,22 +1655,12 @@ def build_app():
|
|
| 1621 |
|
| 1622 |
model_totals = calculate_routing_tokens(modified_steps)
|
| 1623 |
|
| 1624 |
-
|
| 1625 |
-
|
| 1626 |
-
|
| 1627 |
-
|
| 1628 |
-
"
|
| 1629 |
-
|
| 1630 |
-
|
| 1631 |
-
total_base_tokens["cache_read"] += base_totals.get("cache_read", 0)
|
| 1632 |
-
total_base_tokens["uncached_input"] += base_totals.get("uncached_input", 0)
|
| 1633 |
-
total_base_tokens["completion"] += base_totals.get("completion", 0)
|
| 1634 |
-
total_base_tokens["cache_creation"] += base_totals.get("cache_creation", 0)
|
| 1635 |
-
|
| 1636 |
-
total_routing_tokens["cache_read"] += routing_totals.get("cache_read", 0)
|
| 1637 |
-
total_routing_tokens["uncached_input"] += routing_totals.get("uncached_input", 0)
|
| 1638 |
-
total_routing_tokens["completion"] += routing_totals.get("completion", 0)
|
| 1639 |
-
total_routing_tokens["cache_creation"] += routing_totals.get("cache_creation", 0)
|
| 1640 |
|
| 1641 |
original_steps = []
|
| 1642 |
for step in steps:
|
|
@@ -1661,11 +1685,23 @@ def build_app():
|
|
| 1661 |
tokens["completion"] * prices["completion"] / 1e6
|
| 1662 |
)
|
| 1663 |
|
| 1664 |
-
|
| 1665 |
-
|
|
|
|
| 1666 |
|
|
|
|
|
|
|
| 1667 |
total_base_cost = calc_cost(total_base_tokens, base_prices)
|
| 1668 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1669 |
|
| 1670 |
if total_original_cost_from_df is not None:
|
| 1671 |
total_original_cost = total_original_cost_from_df
|
|
@@ -1676,23 +1712,22 @@ def build_app():
|
|
| 1676 |
savings = total_original_cost - total_routed_cost
|
| 1677 |
savings_pct = (savings / total_original_cost * 100) if total_original_cost > 0 else 0
|
| 1678 |
|
| 1679 |
-
|
| 1680 |
-
## π Routing Results
|
| 1681 |
-
|
| 1682 |
-
| Metric | Value |
|
| 1683 |
-
|--------|-------|
|
| 1684 |
-
| **Original Cost (base model only)** | ${total_original_cost:.2f} |
|
| 1685 |
-
| **Routed Cost** | ${total_routed_cost:.2f} |
|
| 1686 |
-
| ↳ Base model portion | ${total_base_cost:.2f} |
|
| 1687 |
-
|
| 1688 |
-
|
| 1689 |
-
|
| 1690 |
-
|
| 1691 |
-
|
| 1692 |
-
|
| 1693 |
-
|
| 1694 |
-
|
| 1695 |
-
additional_cost_models = [(routing_model_1_val, routing_costs)]
|
| 1696 |
|
| 1697 |
yield (
|
| 1698 |
gr.update(visible=True, value="⏳ Creating charts..."),
|
|
@@ -1718,6 +1753,10 @@ def build_app():
|
|
| 1718 |
price_input, price_cache_read, price_cache_creation, price_completion,
|
| 1719 |
routing_model_1, routing_price_1_input, routing_price_1_cache_read, routing_price_1_cache_creation, routing_price_1_completion,
|
| 1720 |
strategy_1, random_pct_1, step_k_1, part_mode_1, start_step_1, end_step_1,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1721 |
token_source, thinking_overhead, use_cache,
|
| 1722 |
],
|
| 1723 |
outputs=[routing_result, routing_plots_row, routing_tokens_plot, routing_cost_plot],
|
|
|
|
| 1310 |
choices=["Indexes", "Percentages"],
|
| 1311 |
value="Percentages",
|
| 1312 |
label="Mode",
|
| 1313 |
+
interactive=True,
|
| 1314 |
)
|
| 1315 |
with gr.Row():
|
| 1316 |
start_step_1 = gr.Number(label="Start", value=0, minimum=0, precision=0, interactive=True)
|
|
|
|
| 1347 |
choices=["Indexes", "Percentages"],
|
| 1348 |
value="Percentages",
|
| 1349 |
label="Mode",
|
| 1350 |
+
interactive=True,
|
| 1351 |
)
|
| 1352 |
with gr.Row():
|
| 1353 |
start_step_2 = gr.Number(label="Start", value=0, minimum=0, precision=0, interactive=True)
|
|
|
|
| 1384 |
choices=["Indexes", "Percentages"],
|
| 1385 |
value="Percentages",
|
| 1386 |
label="Mode",
|
| 1387 |
+
interactive=True,
|
| 1388 |
)
|
| 1389 |
with gr.Row():
|
| 1390 |
start_step_3 = gr.Number(label="Start", value=0, minimum=0, precision=0, interactive=True)
|
|
|
|
| 1525 |
base_input, base_cache_read, base_cache_creation, base_completion,
|
| 1526 |
routing_model_1_val, r1_input, r1_cache_read, r1_cache_creation, r1_completion,
|
| 1527 |
strategy_1_val, random_pct_1_val, step_k_1_val, part_mode_1_val, start_1_val, end_1_val,
|
| 1528 |
+
routing_model_2_val, r2_input, r2_cache_read, r2_cache_creation, r2_completion,
|
| 1529 |
+
strategy_2_val, random_pct_2_val, step_k_2_val, part_mode_2_val, start_2_val, end_2_val,
|
| 1530 |
+
routing_model_3_val, r3_input, r3_cache_read, r3_cache_creation, r3_completion,
|
| 1531 |
+
strategy_3_val, random_pct_3_val, step_k_3_val, part_mode_3_val, start_3_val, end_3_val,
|
| 1532 |
source, overhead, with_cache
|
| 1533 |
):
|
| 1534 |
if state_data is None:
|
|
|
|
| 1578 |
"cache_creation": base_cache_creation,
|
| 1579 |
"completion": base_completion,
|
| 1580 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1581 |
|
| 1582 |
+
def build_strategy_params(strategy, random_pct, step_k, part_mode, start_val, end_val):
|
| 1583 |
+
params = {}
|
| 1584 |
+
if strategy == "Replace on random steps":
|
| 1585 |
+
params["percentage"] = random_pct
|
| 1586 |
+
elif strategy == "Replace every step k":
|
| 1587 |
+
params["k"] = step_k
|
| 1588 |
+
elif strategy == "Replace part of trajectory":
|
| 1589 |
+
params["mode"] = part_mode
|
| 1590 |
+
params["start"] = start_val
|
| 1591 |
+
params["end"] = end_val
|
| 1592 |
+
return params
|
| 1593 |
+
|
| 1594 |
+
routing_models = []
|
| 1595 |
+
if routing_model_1_val:
|
| 1596 |
+
if strategy_1_val == "Replace part of trajectory" and start_1_val >= end_1_val:
|
| 1597 |
+
yield (gr.update(visible=True, value="❌ Model 1: Start must be less than End"), gr.update(visible=False), None, None)
|
| 1598 |
return
|
| 1599 |
+
routing_models.append({
|
| 1600 |
+
"name": routing_model_1_val,
|
| 1601 |
+
"prices": {"input": r1_input, "cache_read": r1_cache_read, "cache_creation": r1_cache_creation, "completion": r1_completion},
|
| 1602 |
+
"strategy": strategy_1_val,
|
| 1603 |
+
"params": build_strategy_params(strategy_1_val, random_pct_1_val, step_k_1_val, part_mode_1_val, start_1_val, end_1_val),
|
| 1604 |
+
})
|
| 1605 |
+
if routing_model_2_val:
|
| 1606 |
+
if strategy_2_val == "Replace part of trajectory" and start_2_val >= end_2_val:
|
| 1607 |
+
yield (gr.update(visible=True, value="❌ Model 2: Start must be less than End"), gr.update(visible=False), None, None)
|
| 1608 |
+
return
|
| 1609 |
+
routing_models.append({
|
| 1610 |
+
"name": routing_model_2_val,
|
| 1611 |
+
"prices": {"input": r2_input, "cache_read": r2_cache_read, "cache_creation": r2_cache_creation, "completion": r2_completion},
|
| 1612 |
+
"strategy": strategy_2_val,
|
| 1613 |
+
"params": build_strategy_params(strategy_2_val, random_pct_2_val, step_k_2_val, part_mode_2_val, start_2_val, end_2_val),
|
| 1614 |
+
})
|
| 1615 |
+
if routing_model_3_val:
|
| 1616 |
+
if strategy_3_val == "Replace part of trajectory" and start_3_val >= end_3_val:
|
| 1617 |
+
yield (gr.update(visible=True, value="❌ Model 3: Start must be less than End"), gr.update(visible=False), None, None)
|
| 1618 |
+
return
|
| 1619 |
+
routing_models.append({
|
| 1620 |
+
"name": routing_model_3_val,
|
| 1621 |
+
"prices": {"input": r3_input, "cache_read": r3_cache_read, "cache_creation": r3_cache_creation, "completion": r3_completion},
|
| 1622 |
+
"strategy": strategy_3_val,
|
| 1623 |
+
"params": build_strategy_params(strategy_3_val, random_pct_3_val, step_k_3_val, part_mode_3_val, start_3_val, end_3_val),
|
| 1624 |
+
})
|
| 1625 |
|
| 1626 |
BASE_MODEL = "__base__"
|
| 1627 |
+
model_keys = [BASE_MODEL] + [f"__routing_{i}__" for i in range(len(routing_models))]
|
| 1628 |
+
|
| 1629 |
+
all_tokens = {key: {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0} for key in model_keys}
|
| 1630 |
+
total_original_tokens = {"uncached_input": 0, "cache_read": 0, "cache_creation": 0, "completion": 0}
|
| 1631 |
|
| 1632 |
for instance_id, steps in trajectory_steps.items():
|
| 1633 |
if not steps:
|
| 1634 |
continue
|
| 1635 |
|
| 1636 |
total_steps = len(steps)
|
| 1637 |
+
|
| 1638 |
+
routed_sets = []
|
| 1639 |
+
for rm in routing_models:
|
| 1640 |
+
routed_sets.append(get_routed_steps(total_steps, rm["strategy"], rm["params"]))
|
| 1641 |
|
| 1642 |
modified_steps = []
|
| 1643 |
for i, step in enumerate(steps):
|
| 1644 |
+
model = BASE_MODEL
|
| 1645 |
+
for j, routed_set in enumerate(routed_sets):
|
| 1646 |
+
if i in routed_set:
|
| 1647 |
+
model = f"__routing_{j}__"
|
| 1648 |
+
break
|
| 1649 |
modified_steps.append({
|
| 1650 |
"model": model,
|
| 1651 |
"system_user": step.get("system_user", 0),
|
|
|
|
| 1655 |
|
| 1656 |
model_totals = calculate_routing_tokens(modified_steps)
|
| 1657 |
|
| 1658 |
+
for key in model_keys:
|
| 1659 |
+
totals = model_totals.get(key, {})
|
| 1660 |
+
all_tokens[key]["cache_read"] += totals.get("cache_read", 0)
|
| 1661 |
+
all_tokens[key]["uncached_input"] += totals.get("uncached_input", 0)
|
| 1662 |
+
all_tokens[key]["completion"] += totals.get("completion", 0)
|
| 1663 |
+
all_tokens[key]["cache_creation"] += totals.get("cache_creation", 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1664 |
|
| 1665 |
original_steps = []
|
| 1666 |
for step in steps:
|
|
|
|
| 1685 |
tokens["completion"] * prices["completion"] / 1e6
|
| 1686 |
)
|
| 1687 |
|
| 1688 |
+
def tokens_to_costs(tokens: dict, prices: dict) -> dict:
|
| 1689 |
+
price_map = {"uncached_input": "input", "cache_read": "cache_read", "cache_creation": "cache_creation", "completion": "completion"}
|
| 1690 |
+
return {k: tokens[k] * prices[price_map[k]] / 1e6 for k in tokens}
|
| 1691 |
|
| 1692 |
+
total_base_tokens = all_tokens[BASE_MODEL]
|
| 1693 |
+
base_costs = tokens_to_costs(total_base_tokens, base_prices)
|
| 1694 |
total_base_cost = calc_cost(total_base_tokens, base_prices)
|
| 1695 |
+
|
| 1696 |
+
routing_costs_list = []
|
| 1697 |
+
total_routing_cost = 0
|
| 1698 |
+
for i, rm in enumerate(routing_models):
|
| 1699 |
+
key = f"__routing_{i}__"
|
| 1700 |
+
tokens = all_tokens[key]
|
| 1701 |
+
costs = tokens_to_costs(tokens, rm["prices"])
|
| 1702 |
+
cost = calc_cost(tokens, rm["prices"])
|
| 1703 |
+
routing_costs_list.append({"name": rm["name"], "tokens": tokens, "costs": costs, "cost": cost})
|
| 1704 |
+
total_routing_cost += cost
|
| 1705 |
|
| 1706 |
if total_original_cost_from_df is not None:
|
| 1707 |
total_original_cost = total_original_cost_from_df
|
|
|
|
| 1712 |
savings = total_original_cost - total_routed_cost
|
| 1713 |
savings_pct = (savings / total_original_cost * 100) if total_original_cost > 0 else 0
|
| 1714 |
|
| 1715 |
+
result_lines = [
|
| 1716 |
+
"## π Routing Results",
|
| 1717 |
+
"",
|
| 1718 |
+
"| Metric | Value |",
|
| 1719 |
+
"|--------|-------|",
|
| 1720 |
+
f"| **Original Cost (base model only)** | ${total_original_cost:.2f} |",
|
| 1721 |
+
f"| **Routed Cost** | ${total_routed_cost:.2f} |",
|
| 1722 |
+
f"| ↳ Base model portion | ${total_base_cost:.2f} |",
|
| 1723 |
+
]
|
| 1724 |
+
for rc in routing_costs_list:
|
| 1725 |
+
result_lines.append(f"| ↳ {rc['name']} | ${rc['cost']:.2f} |")
|
| 1726 |
+
result_lines.append(f"| **Savings** | ${savings:.2f} ({savings_pct:+.1f}%) |")
|
| 1727 |
+
result_text = "\n".join(result_lines)
|
| 1728 |
+
|
| 1729 |
+
additional_token_models = [(rc["name"], rc["tokens"]) for rc in routing_costs_list]
|
| 1730 |
+
additional_cost_models = [(rc["name"], rc["costs"]) for rc in routing_costs_list]
|
|
|
|
| 1731 |
|
| 1732 |
yield (
|
| 1733 |
gr.update(visible=True, value="⏳ Creating charts..."),
|
|
|
|
| 1753 |
price_input, price_cache_read, price_cache_creation, price_completion,
|
| 1754 |
routing_model_1, routing_price_1_input, routing_price_1_cache_read, routing_price_1_cache_creation, routing_price_1_completion,
|
| 1755 |
strategy_1, random_pct_1, step_k_1, part_mode_1, start_step_1, end_step_1,
|
| 1756 |
+
routing_model_2, routing_price_2_input, routing_price_2_cache_read, routing_price_2_cache_creation, routing_price_2_completion,
|
| 1757 |
+
strategy_2, random_pct_2, step_k_2, part_mode_2, start_step_2, end_step_2,
|
| 1758 |
+
routing_model_3, routing_price_3_input, routing_price_3_cache_read, routing_price_3_cache_creation, routing_price_3_completion,
|
| 1759 |
+
strategy_3, random_pct_3, step_k_3, part_mode_3, start_step_3, end_step_3,
|
| 1760 |
token_source, thinking_overhead, use_cache,
|
| 1761 |
],
|
| 1762 |
outputs=[routing_result, routing_plots_row, routing_tokens_plot, routing_cost_plot],
|