Commit
·
99badd3
1
Parent(s):
9399ab7
Add routing UI: model selection, strategy params, support .traj files
Browse files
- Add routing section with up to 3 models
- Each model has: dropdown with search, price fields, strategy selection
- Strategy options: random steps, every k, part of trajectory
- Support both .traj.json and .traj file formats
- Load & Analyze auto-downloads trajectories if needed
- Add Routing button appears only after successful load
- Fix error handling for models without trajectories on S3
- Compact token prices UI
app.py
CHANGED
|
@@ -28,6 +28,21 @@ _trajectories_cache = {}
|
|
| 28 |
_calculated_tokens_cache = {}
|
| 29 |
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
def get_default_overhead(model_name: str) -> float:
|
| 32 |
"""Get default tokenizer overhead for model provider"""
|
| 33 |
model_lower = model_name.lower() if model_name else ""
|
|
@@ -172,8 +187,12 @@ def load_all_trajectories_calculated(folder: str) -> pd.DataFrame:
|
|
| 172 |
output_dir = TRAJS_DIR / folder
|
| 173 |
|
| 174 |
traj_files = list(output_dir.glob("*/*.traj.json"))
|
|
|
|
|
|
|
| 175 |
if not traj_files:
|
| 176 |
traj_files = list(output_dir.glob("*.traj.json"))
|
|
|
|
|
|
|
| 177 |
if not traj_files:
|
| 178 |
traj_files = list(output_dir.glob("*.json"))
|
| 179 |
|
|
@@ -212,6 +231,12 @@ def load_all_trajectories_calculated(folder: str) -> pd.DataFrame:
|
|
| 212 |
return df
|
| 213 |
|
| 214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
def get_litellm_prices() -> dict:
|
| 216 |
global _litellm_prices_cache
|
| 217 |
if _litellm_prices_cache is not None:
|
|
@@ -357,6 +382,8 @@ def download_trajectories_from_s3(folder: str, progress=gr.Progress()):
|
|
| 357 |
output_dir = TRAJS_DIR / folder
|
| 358 |
if output_dir.exists() and any(output_dir.iterdir()):
|
| 359 |
file_count = len(list(output_dir.glob("*/*.traj.json")))
|
|
|
|
|
|
|
| 360 |
if file_count == 0:
|
| 361 |
file_count = len(list(output_dir.glob("*.json")))
|
| 362 |
return f"β
Already downloaded: {output_dir}\n\n{file_count} trajectory files", gr.update(visible=True)
|
|
@@ -378,9 +405,14 @@ def download_trajectories_from_s3(folder: str, progress=gr.Progress()):
|
|
| 378 |
return f"β S3 download failed:\n{result.stderr}", gr.update(visible=False)
|
| 379 |
|
| 380 |
file_count = len(list(output_dir.glob("*/*.traj.json")))
|
|
|
|
|
|
|
| 381 |
if file_count == 0:
|
| 382 |
file_count = len(list(output_dir.glob("*.json")))
|
| 383 |
|
|
|
|
|
|
|
|
|
|
| 384 |
per_instance = model.get("per_instance_details", {})
|
| 385 |
resolved_count = sum(1 for v in per_instance.values() if v.get("resolved"))
|
| 386 |
total_count = len(per_instance)
|
|
@@ -452,8 +484,12 @@ def load_all_trajectories(folder: str) -> pd.DataFrame:
|
|
| 452 |
output_dir = TRAJS_DIR / folder
|
| 453 |
|
| 454 |
traj_files = list(output_dir.glob("*/*.traj.json"))
|
|
|
|
|
|
|
| 455 |
if not traj_files:
|
| 456 |
traj_files = list(output_dir.glob("*.traj.json"))
|
|
|
|
|
|
|
| 457 |
if not traj_files:
|
| 458 |
traj_files = list(output_dir.glob("*.json"))
|
| 459 |
|
|
@@ -873,12 +909,11 @@ def on_row_select(evt: gr.SelectData, df: pd.DataFrame):
|
|
| 873 |
if evt.index is None:
|
| 874 |
return (
|
| 875 |
"", "",
|
| 876 |
-
gr.update(interactive=False),
|
| 877 |
gr.update(visible=False),
|
| 878 |
-
gr.update(value=0, label="
|
| 879 |
-
gr.update(value=0, label="
|
| 880 |
-
gr.update(value=0, label="
|
| 881 |
-
gr.update(value=0, label="
|
| 882 |
"",
|
| 883 |
gr.update(value=1.0),
|
| 884 |
)
|
|
@@ -888,8 +923,6 @@ def on_row_select(evt: gr.SelectData, df: pd.DataFrame):
|
|
| 888 |
folder = row["folder"]
|
| 889 |
name = row["name"]
|
| 890 |
|
| 891 |
-
show_analyze = check_trajectories_downloaded(folder)
|
| 892 |
-
|
| 893 |
prices_dict, model_hint = get_prices_for_folder(folder)
|
| 894 |
default_overhead = get_default_overhead(model_hint)
|
| 895 |
|
|
@@ -898,14 +931,13 @@ def on_row_select(evt: gr.SelectData, df: pd.DataFrame):
|
|
| 898 |
if price_info["found"]:
|
| 899 |
return gr.update(value=value, label=f"β
{name}")
|
| 900 |
elif value > 0:
|
| 901 |
-
return gr.update(value=value, label=f"β {name} (
|
| 902 |
else:
|
| 903 |
return gr.update(value=0, label=f"β {name}")
|
| 904 |
|
| 905 |
return (
|
| 906 |
folder, name,
|
| 907 |
-
gr.update(
|
| 908 |
-
gr.update(visible=show_analyze),
|
| 909 |
price_update(prices_dict["input"], "Input"),
|
| 910 |
price_update(prices_dict["cache_read"], "Cache Read"),
|
| 911 |
price_update(prices_dict["cache_creation"], "Cache Creation"),
|
|
@@ -953,18 +985,17 @@ def build_app():
|
|
| 953 |
gr.Markdown("### Selected Model")
|
| 954 |
selected_name = gr.Textbox(label="Model Name", interactive=False)
|
| 955 |
|
| 956 |
-
download_btn = gr.Button("π₯ Download Trajectories", interactive=False)
|
| 957 |
-
download_status = gr.Textbox(label="Status", interactive=False, lines=3)
|
| 958 |
-
|
| 959 |
analyze_btn = gr.Button("π Load & Analyze", visible=False, variant="primary")
|
|
|
|
| 960 |
|
| 961 |
gr.Markdown("---")
|
| 962 |
gr.Markdown("### π° Token Prices ($/1M) Β· *[litellm](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)*")
|
| 963 |
detected_model = gr.Textbox(label="Detected Model", interactive=False)
|
| 964 |
-
|
| 965 |
-
|
| 966 |
-
|
| 967 |
-
|
|
|
|
| 968 |
|
| 969 |
gr.Markdown("---")
|
| 970 |
gr.Markdown("### π Token Count Source")
|
|
@@ -986,6 +1017,231 @@ def build_app():
|
|
| 986 |
visible=False,
|
| 987 |
)
|
| 988 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 989 |
def update_calculated_options_visibility(source):
    """Return two visibility updates, both shown only when ``source`` is "Calculated"."""
    show = (source == "Calculated")
    updates = [gr.update(visible=show) for _ in range(2)]
    return tuple(updates)
|
|
@@ -999,30 +1255,47 @@ def build_app():
|
|
| 999 |
leaderboard_table.select(
|
| 1000 |
fn=on_row_select,
|
| 1001 |
inputs=[leaderboard_table],
|
| 1002 |
-
outputs=[selected_folder, selected_name,
|
| 1003 |
-
)
|
| 1004 |
-
|
| 1005 |
-
download_btn.click(
|
| 1006 |
-
fn=download_trajectories_from_s3,
|
| 1007 |
-
inputs=[selected_folder],
|
| 1008 |
-
outputs=[download_status, analyze_btn],
|
| 1009 |
)
|
| 1010 |
|
| 1011 |
-
def load_and_analyze(folder, input_price, cache_read_price, cache_creation_price, completion_price, source, overhead, with_cache):
|
| 1012 |
empty_result = (
|
|
|
|
| 1013 |
gr.update(visible=False),
|
| 1014 |
None, None, None, None, None, None,
|
| 1015 |
None,
|
|
|
|
| 1016 |
)
|
| 1017 |
|
| 1018 |
if not folder:
|
| 1019 |
yield empty_result
|
| 1020 |
return
|
| 1021 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1022 |
yield (
|
|
|
|
| 1023 |
gr.update(visible=True),
|
| 1024 |
None, None, None, None, None, None,
|
| 1025 |
None,
|
|
|
|
| 1026 |
)
|
| 1027 |
|
| 1028 |
df_meta = load_all_trajectories(folder)
|
|
@@ -1040,7 +1313,13 @@ def build_app():
|
|
| 1040 |
df = apply_no_cache(df)
|
| 1041 |
|
| 1042 |
if df.empty:
|
| 1043 |
-
yield
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1044 |
return
|
| 1045 |
|
| 1046 |
fig_steps, fig_cost, fig_tokens, fig_tokens_cost, fig_stacked = create_basic_histograms(
|
|
@@ -1049,18 +1328,22 @@ def build_app():
|
|
| 1049 |
fig_cost_breakdown = create_cost_breakdown(df, input_price, cache_read_price, cache_creation_price, completion_price)
|
| 1050 |
|
| 1051 |
yield (
|
|
|
|
| 1052 |
gr.update(visible=True),
|
| 1053 |
fig_steps, fig_cost, fig_tokens, fig_tokens_cost, fig_stacked, fig_cost_breakdown,
|
| 1054 |
state_data,
|
|
|
|
| 1055 |
)
|
| 1056 |
|
| 1057 |
analyze_btn.click(
|
| 1058 |
fn=load_and_analyze,
|
| 1059 |
inputs=[selected_folder, price_input, price_cache_read, price_cache_creation, price_completion, token_source, thinking_overhead, use_cache],
|
| 1060 |
outputs=[
|
|
|
|
| 1061 |
analysis_section,
|
| 1062 |
plot_steps, plot_cost, plot_tokens, plot_tokens_cost, plot_stacked, plot_cost_breakdown,
|
| 1063 |
trajectories_state,
|
|
|
|
| 1064 |
],
|
| 1065 |
)
|
| 1066 |
|
|
|
|
| 28 |
_calculated_tokens_cache = {}
|
| 29 |
|
| 30 |
|
| 31 |
+
def parse_step_or_ratio(value: float, total_steps: int) -> int:
    """Convert a user-supplied boundary into a step offset within a trajectory.

    The value is interpreted in one of two ways:

    * A whole number >= 1 (e.g. ``3.0``, ``5.0``) is an absolute step number
      and is used unchanged (no 1-based -> 0-based conversion is applied).
    * Anything else (e.g. ``0``, ``0.5``, ``0.25``) is a fraction of
      ``total_steps``; the product is truncated toward zero.

    Note that ``1.0`` is a whole number and therefore means "step 1", not
    "100% of the trajectory".

    Args:
        value: Step number or ratio. ``None`` (an empty input field) is
            treated as 0.
        total_steps: Number of steps in the trajectory.

    Returns:
        A step offset clamped to ``[0, total_steps]``, so that out-of-range
        steps, ratios above 1 and negative inputs can never produce an
        offset outside the trajectory.
    """
    if value is None:
        return 0

    upper = max(total_steps, 0)
    if value >= 1 and float(value).is_integer():
        step = int(value)
    else:
        step = int(value * total_steps)

    # Clamp instead of returning a negative or past-the-end offset.
    return min(max(step, 0), upper)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
def get_default_overhead(model_name: str) -> float:
|
| 47 |
"""Get default tokenizer overhead for model provider"""
|
| 48 |
model_lower = model_name.lower() if model_name else ""
|
|
|
|
| 187 |
output_dir = TRAJS_DIR / folder
|
| 188 |
|
| 189 |
traj_files = list(output_dir.glob("*/*.traj.json"))
|
| 190 |
+
if not traj_files:
|
| 191 |
+
traj_files = list(output_dir.glob("*/*.traj"))
|
| 192 |
if not traj_files:
|
| 193 |
traj_files = list(output_dir.glob("*.traj.json"))
|
| 194 |
+
if not traj_files:
|
| 195 |
+
traj_files = list(output_dir.glob("*.traj"))
|
| 196 |
if not traj_files:
|
| 197 |
traj_files = list(output_dir.glob("*.json"))
|
| 198 |
|
|
|
|
| 231 |
return df
|
| 232 |
|
| 233 |
|
| 234 |
+
def get_litellm_model_list() -> list[str]:
    """Return the keys of the litellm price mapping (model names) in sorted order."""
    model_names = list(get_litellm_prices().keys())
    model_names.sort()
    return model_names
|
| 238 |
+
|
| 239 |
+
|
| 240 |
def get_litellm_prices() -> dict:
|
| 241 |
global _litellm_prices_cache
|
| 242 |
if _litellm_prices_cache is not None:
|
|
|
|
| 382 |
output_dir = TRAJS_DIR / folder
|
| 383 |
if output_dir.exists() and any(output_dir.iterdir()):
|
| 384 |
file_count = len(list(output_dir.glob("*/*.traj.json")))
|
| 385 |
+
if file_count == 0:
|
| 386 |
+
file_count = len(list(output_dir.glob("*/*.traj")))
|
| 387 |
if file_count == 0:
|
| 388 |
file_count = len(list(output_dir.glob("*.json")))
|
| 389 |
return f"β
Already downloaded: {output_dir}\n\n{file_count} trajectory files", gr.update(visible=True)
|
|
|
|
| 405 |
return f"β S3 download failed:\n{result.stderr}", gr.update(visible=False)
|
| 406 |
|
| 407 |
file_count = len(list(output_dir.glob("*/*.traj.json")))
|
| 408 |
+
if file_count == 0:
|
| 409 |
+
file_count = len(list(output_dir.glob("*/*.traj")))
|
| 410 |
if file_count == 0:
|
| 411 |
file_count = len(list(output_dir.glob("*.json")))
|
| 412 |
|
| 413 |
+
if file_count == 0:
|
| 414 |
+
return f"β No trajectory files found on S3 for {folder}", gr.update(visible=False)
|
| 415 |
+
|
| 416 |
per_instance = model.get("per_instance_details", {})
|
| 417 |
resolved_count = sum(1 for v in per_instance.values() if v.get("resolved"))
|
| 418 |
total_count = len(per_instance)
|
|
|
|
| 484 |
output_dir = TRAJS_DIR / folder
|
| 485 |
|
| 486 |
traj_files = list(output_dir.glob("*/*.traj.json"))
|
| 487 |
+
if not traj_files:
|
| 488 |
+
traj_files = list(output_dir.glob("*/*.traj"))
|
| 489 |
if not traj_files:
|
| 490 |
traj_files = list(output_dir.glob("*.traj.json"))
|
| 491 |
+
if not traj_files:
|
| 492 |
+
traj_files = list(output_dir.glob("*.traj"))
|
| 493 |
if not traj_files:
|
| 494 |
traj_files = list(output_dir.glob("*.json"))
|
| 495 |
|
|
|
|
| 909 |
if evt.index is None:
|
| 910 |
return (
|
| 911 |
"", "",
|
|
|
|
| 912 |
gr.update(visible=False),
|
| 913 |
+
gr.update(value=0, label="Input"),
|
| 914 |
+
gr.update(value=0, label="Cache Read"),
|
| 915 |
+
gr.update(value=0, label="Cache Creation"),
|
| 916 |
+
gr.update(value=0, label="Completion"),
|
| 917 |
"",
|
| 918 |
gr.update(value=1.0),
|
| 919 |
)
|
|
|
|
| 923 |
folder = row["folder"]
|
| 924 |
name = row["name"]
|
| 925 |
|
|
|
|
|
|
|
| 926 |
prices_dict, model_hint = get_prices_for_folder(folder)
|
| 927 |
default_overhead = get_default_overhead(model_hint)
|
| 928 |
|
|
|
|
| 931 |
if price_info["found"]:
|
| 932 |
return gr.update(value=value, label=f"β
{name}")
|
| 933 |
elif value > 0:
|
| 934 |
+
return gr.update(value=value, label=f"β {name} (est.)")
|
| 935 |
else:
|
| 936 |
return gr.update(value=0, label=f"β {name}")
|
| 937 |
|
| 938 |
return (
|
| 939 |
folder, name,
|
| 940 |
+
gr.update(visible=True),
|
|
|
|
| 941 |
price_update(prices_dict["input"], "Input"),
|
| 942 |
price_update(prices_dict["cache_read"], "Cache Read"),
|
| 943 |
price_update(prices_dict["cache_creation"], "Cache Creation"),
|
|
|
|
| 985 |
gr.Markdown("### Selected Model")
|
| 986 |
selected_name = gr.Textbox(label="Model Name", interactive=False)
|
| 987 |
|
|
|
|
|
|
|
|
|
|
| 988 |
analyze_btn = gr.Button("π Load & Analyze", visible=False, variant="primary")
|
| 989 |
+
download_status = gr.Textbox(label="Status", interactive=False, lines=3)
|
| 990 |
|
| 991 |
gr.Markdown("---")
|
| 992 |
gr.Markdown("### π° Token Prices ($/1M) Β· *[litellm](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)*")
|
| 993 |
detected_model = gr.Textbox(label="Detected Model", interactive=False)
|
| 994 |
+
with gr.Row():
|
| 995 |
+
price_input = gr.Number(label="Input", value=0, precision=2, scale=1)
|
| 996 |
+
price_cache_read = gr.Number(label="Cache Read", value=0, precision=2, scale=1)
|
| 997 |
+
price_cache_creation = gr.Number(label="Cache Creation", value=0, precision=2, scale=1)
|
| 998 |
+
price_completion = gr.Number(label="Completion", value=0, precision=2, scale=1)
|
| 999 |
|
| 1000 |
gr.Markdown("---")
|
| 1001 |
gr.Markdown("### π Token Count Source")
|
|
|
|
| 1017 |
visible=False,
|
| 1018 |
)
|
| 1019 |
|
| 1020 |
+
gr.Markdown("---")
|
| 1021 |
+
add_routing_btn = gr.Button("β Add Routing", variant="primary", visible=False)
|
| 1022 |
+
|
| 1023 |
+
with gr.Column(visible=False) as routing_section:
|
| 1024 |
+
gr.Markdown("### π Routing Models")
|
| 1025 |
+
|
| 1026 |
+
STRATEGY_CHOICES = [
|
| 1027 |
+
"Replace on random steps",
|
| 1028 |
+
"Replace every step k",
|
| 1029 |
+
"Replace part of trajectory",
|
| 1030 |
+
]
|
| 1031 |
+
|
| 1032 |
+
with gr.Column():
|
| 1033 |
+
with gr.Group():
|
| 1034 |
+
gr.Markdown("#### Route to Model 1")
|
| 1035 |
+
routing_model_1 = gr.Dropdown(
|
| 1036 |
+
label="Model (type 3+ chars to search)",
|
| 1037 |
+
choices=[],
|
| 1038 |
+
allow_custom_value=True,
|
| 1039 |
+
interactive=True,
|
| 1040 |
+
)
|
| 1041 |
+
with gr.Row():
|
| 1042 |
+
routing_price_1_input = gr.Number(label="Input", precision=3, scale=1)
|
| 1043 |
+
routing_price_1_cache_read = gr.Number(label="Cache Read", precision=3, scale=1)
|
| 1044 |
+
routing_price_1_cache_creation = gr.Number(label="Cache Creation", precision=3, scale=1)
|
| 1045 |
+
routing_price_1_completion = gr.Number(label="Completion", precision=3, scale=1)
|
| 1046 |
+
strategy_1 = gr.Dropdown(
|
| 1047 |
+
label="Strategy",
|
| 1048 |
+
choices=STRATEGY_CHOICES,
|
| 1049 |
+
value="Replace on random steps",
|
| 1050 |
+
interactive=True,
|
| 1051 |
+
)
|
| 1052 |
+
with gr.Row(visible=True) as random_params_1:
|
| 1053 |
+
random_pct_1 = gr.Number(label="Percentage (%)", value=50, minimum=0, maximum=100, precision=0, interactive=True)
|
| 1054 |
+
with gr.Row(visible=False) as every_k_params_1:
|
| 1055 |
+
step_k_1 = gr.Number(label="k", value=2, minimum=1, precision=0, interactive=True)
|
| 1056 |
+
with gr.Row(visible=False) as part_params_1:
|
| 1057 |
+
start_step_1 = gr.Number(label="Start (int=step; 0,0-1,0=ratio)", value=0, minimum=0, precision=2, interactive=True)
|
| 1058 |
+
end_step_1 = gr.Number(label="End (int=step; 0,0-1,0=ratio)", value=0.5, minimum=0, precision=2, interactive=True)
|
| 1059 |
+
|
| 1060 |
+
add_model_2_btn = gr.Button("+ Add another model", size="sm", visible=False)
|
| 1061 |
+
|
| 1062 |
+
with gr.Column(visible=False) as routing_block_2:
|
| 1063 |
+
with gr.Group():
|
| 1064 |
+
gr.Markdown("#### Route to Model 2")
|
| 1065 |
+
routing_model_2 = gr.Dropdown(
|
| 1066 |
+
label="Model (type 3+ chars to search)",
|
| 1067 |
+
choices=[],
|
| 1068 |
+
allow_custom_value=True,
|
| 1069 |
+
interactive=True,
|
| 1070 |
+
)
|
| 1071 |
+
with gr.Row():
|
| 1072 |
+
routing_price_2_input = gr.Number(label="Input", precision=3, scale=1)
|
| 1073 |
+
routing_price_2_cache_read = gr.Number(label="Cache Read", precision=3, scale=1)
|
| 1074 |
+
routing_price_2_cache_creation = gr.Number(label="Cache Creation", precision=3, scale=1)
|
| 1075 |
+
routing_price_2_completion = gr.Number(label="Completion", precision=3, scale=1)
|
| 1076 |
+
strategy_2 = gr.Dropdown(
|
| 1077 |
+
label="Strategy",
|
| 1078 |
+
choices=STRATEGY_CHOICES,
|
| 1079 |
+
value="Replace on random steps",
|
| 1080 |
+
interactive=True,
|
| 1081 |
+
)
|
| 1082 |
+
with gr.Row(visible=True) as random_params_2:
|
| 1083 |
+
random_pct_2 = gr.Number(label="Percentage (%)", value=50, minimum=0, maximum=100, precision=0, interactive=True)
|
| 1084 |
+
with gr.Row(visible=False) as every_k_params_2:
|
| 1085 |
+
step_k_2 = gr.Number(label="k", value=2, minimum=1, precision=0, interactive=True)
|
| 1086 |
+
with gr.Row(visible=False) as part_params_2:
|
| 1087 |
+
start_step_2 = gr.Number(label="Start (int=step; 0,0-1,0=ratio)", value=0, minimum=0, precision=2, interactive=True)
|
| 1088 |
+
end_step_2 = gr.Number(label="End (int=step; 0,0-1,0=ratio)", value=0.5, minimum=0, precision=2, interactive=True)
|
| 1089 |
+
|
| 1090 |
+
add_model_3_btn = gr.Button("+ Add another model", size="sm", visible=False)
|
| 1091 |
+
|
| 1092 |
+
with gr.Column(visible=False) as routing_block_3:
|
| 1093 |
+
with gr.Group():
|
| 1094 |
+
gr.Markdown("#### Route to Model 3")
|
| 1095 |
+
routing_model_3 = gr.Dropdown(
|
| 1096 |
+
label="Model (type 3+ chars to search)",
|
| 1097 |
+
choices=[],
|
| 1098 |
+
allow_custom_value=True,
|
| 1099 |
+
interactive=True,
|
| 1100 |
+
)
|
| 1101 |
+
with gr.Row():
|
| 1102 |
+
routing_price_3_input = gr.Number(label="Input", precision=3, scale=1)
|
| 1103 |
+
routing_price_3_cache_read = gr.Number(label="Cache Read", precision=3, scale=1)
|
| 1104 |
+
routing_price_3_cache_creation = gr.Number(label="Cache Creation", precision=3, scale=1)
|
| 1105 |
+
routing_price_3_completion = gr.Number(label="Completion", precision=3, scale=1)
|
| 1106 |
+
strategy_3 = gr.Dropdown(
|
| 1107 |
+
label="Strategy",
|
| 1108 |
+
choices=STRATEGY_CHOICES,
|
| 1109 |
+
value="Replace on random steps",
|
| 1110 |
+
interactive=True,
|
| 1111 |
+
)
|
| 1112 |
+
with gr.Row(visible=True) as random_params_3:
|
| 1113 |
+
random_pct_3 = gr.Number(label="Percentage (%)", value=50, minimum=0, maximum=100, precision=0, interactive=True)
|
| 1114 |
+
with gr.Row(visible=False) as every_k_params_3:
|
| 1115 |
+
step_k_3 = gr.Number(label="k", value=2, minimum=1, precision=0, interactive=True)
|
| 1116 |
+
with gr.Row(visible=False) as part_params_3:
|
| 1117 |
+
start_step_3 = gr.Number(label="Start (int=step; 0,0-1,0=ratio)", value=0, minimum=0, precision=2, interactive=True)
|
| 1118 |
+
end_step_3 = gr.Number(label="End (int=step; 0,0-1,0=ratio)", value=0.5, minimum=0, precision=2, interactive=True)
|
| 1119 |
+
|
| 1120 |
+
def on_strategy_change(strategy):
    """Build visibility updates for the three strategy parameter rows.

    Returns a 3-tuple of updates, one per strategy in the order
    "random steps", "every step k", "part of trajectory"; only the entry
    matching ``strategy`` is visible.
    """
    options = (
        "Replace on random steps",
        "Replace every step k",
        "Replace part of trajectory",
    )
    return tuple(gr.update(visible=(strategy == option)) for option in options)
|
| 1126 |
+
|
| 1127 |
+
def toggle_routing_section():
    """Return an update that makes its output visible (it never hides it again)."""
    reveal = gr.update(visible=True)
    return reveal
|
| 1129 |
+
|
| 1130 |
+
add_routing_btn.click(
|
| 1131 |
+
fn=toggle_routing_section,
|
| 1132 |
+
outputs=[routing_section],
|
| 1133 |
+
)
|
| 1134 |
+
|
| 1135 |
+
strategy_1.change(
|
| 1136 |
+
fn=on_strategy_change,
|
| 1137 |
+
inputs=[strategy_1],
|
| 1138 |
+
outputs=[random_params_1, every_k_params_1, part_params_1],
|
| 1139 |
+
)
|
| 1140 |
+
|
| 1141 |
+
strategy_2.change(
|
| 1142 |
+
fn=on_strategy_change,
|
| 1143 |
+
inputs=[strategy_2],
|
| 1144 |
+
outputs=[random_params_2, every_k_params_2, part_params_2],
|
| 1145 |
+
)
|
| 1146 |
+
|
| 1147 |
+
strategy_3.change(
|
| 1148 |
+
fn=on_strategy_change,
|
| 1149 |
+
inputs=[strategy_3],
|
| 1150 |
+
outputs=[random_params_3, every_k_params_3, part_params_3],
|
| 1151 |
+
)
|
| 1152 |
+
|
| 1153 |
+
def filter_models(query):
    """Suggest model names containing ``query``, ignoring case.

    An empty query or one shorter than 3 characters clears the choices;
    otherwise at most the first 50 matches (in model-list order) are offered.
    """
    if query and len(query) >= 3:
        candidates = get_litellm_model_list()
        needle = query.lower()
        matches = [name for name in candidates if needle in name.lower()]
        return gr.update(choices=matches[:50])
    return gr.update(choices=[])
|
| 1161 |
+
|
| 1162 |
+
routing_model_1.input(fn=filter_models, inputs=[routing_model_1], outputs=[routing_model_1])
|
| 1163 |
+
routing_model_2.input(fn=filter_models, inputs=[routing_model_2], outputs=[routing_model_2])
|
| 1164 |
+
routing_model_3.input(fn=filter_models, inputs=[routing_model_3], outputs=[routing_model_3])
|
| 1165 |
+
|
| 1166 |
+
def get_routing_prices_with_labels(model_name):
|
| 1167 |
+
"""Get all 4 prices for a routing model with found/estimated labels"""
|
| 1168 |
+
if not model_name:
|
| 1169 |
+
return (
|
| 1170 |
+
gr.update(value=0, label="Input"),
|
| 1171 |
+
gr.update(value=0, label="Cache Read"),
|
| 1172 |
+
gr.update(value=0, label="Cache Creation"),
|
| 1173 |
+
gr.update(value=0, label="Completion"),
|
| 1174 |
+
)
|
| 1175 |
+
|
| 1176 |
+
prices = get_litellm_prices()
|
| 1177 |
+
model_prices = prices.get(model_name, {})
|
| 1178 |
+
|
| 1179 |
+
input_price = model_prices.get("input_cost_per_token", 0) * 1e6
|
| 1180 |
+
cache_read = model_prices.get("cache_read_input_token_cost", 0) * 1e6
|
| 1181 |
+
cache_creation = model_prices.get("cache_creation_input_token_cost", 0) * 1e6
|
| 1182 |
+
completion = model_prices.get("output_cost_per_token", 0) * 1e6
|
| 1183 |
+
|
| 1184 |
+
input_found = input_price > 0
|
| 1185 |
+
cache_read_found = cache_read > 0
|
| 1186 |
+
cache_creation_found = cache_creation > 0
|
| 1187 |
+
completion_found = completion > 0
|
| 1188 |
+
|
| 1189 |
+
if not cache_read_found and input_price > 0:
|
| 1190 |
+
cache_read = input_price * 0.1
|
| 1191 |
+
if not cache_creation_found and input_price > 0:
|
| 1192 |
+
cache_creation = input_price * 1.25
|
| 1193 |
+
|
| 1194 |
+
def label(name, found):
|
| 1195 |
+
return f"β
{name}" if found else f"β {name}"
|
| 1196 |
+
|
| 1197 |
+
return (
|
| 1198 |
+
gr.update(value=input_price, label=label("Input", input_found)),
|
| 1199 |
+
gr.update(value=cache_read, label=label("Cache Read", cache_read_found)),
|
| 1200 |
+
gr.update(value=cache_creation, label=label("Cache Creation", cache_creation_found)),
|
| 1201 |
+
gr.update(value=completion, label=label("Completion", completion_found)),
|
| 1202 |
+
)
|
| 1203 |
+
|
| 1204 |
+
def on_routing_model_1_select(model_name):
    """Return the price updates for ``model_name`` plus a visibility update.

    The trailing update is visible only when ``model_name`` is non-empty.
    """
    price_updates = get_routing_prices_with_labels(model_name)
    button_update = gr.update(visible=bool(model_name))
    return (*price_updates, button_update)
|
| 1208 |
+
|
| 1209 |
+
def on_routing_model_2_select(model_name):
    """Return the price updates for ``model_name`` plus a visibility update.

    The trailing update is visible only when ``model_name`` is non-empty.
    """
    price_updates = get_routing_prices_with_labels(model_name)
    button_update = gr.update(visible=bool(model_name))
    return (*price_updates, button_update)
|
| 1213 |
+
|
| 1214 |
+
def on_routing_model_3_select(model_name):
    """Return the price updates for ``model_name`` (no extra button update)."""
    price_updates = get_routing_prices_with_labels(model_name)
    return price_updates
|
| 1216 |
+
|
| 1217 |
+
routing_model_1.change(
|
| 1218 |
+
fn=on_routing_model_1_select,
|
| 1219 |
+
inputs=[routing_model_1],
|
| 1220 |
+
outputs=[routing_price_1_input, routing_price_1_cache_read, routing_price_1_cache_creation, routing_price_1_completion, add_model_2_btn],
|
| 1221 |
+
)
|
| 1222 |
+
|
| 1223 |
+
add_model_2_btn.click(
|
| 1224 |
+
fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
|
| 1225 |
+
outputs=[routing_block_2, add_model_2_btn],
|
| 1226 |
+
)
|
| 1227 |
+
|
| 1228 |
+
routing_model_2.change(
|
| 1229 |
+
fn=on_routing_model_2_select,
|
| 1230 |
+
inputs=[routing_model_2],
|
| 1231 |
+
outputs=[routing_price_2_input, routing_price_2_cache_read, routing_price_2_cache_creation, routing_price_2_completion, add_model_3_btn],
|
| 1232 |
+
)
|
| 1233 |
+
|
| 1234 |
+
add_model_3_btn.click(
|
| 1235 |
+
fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
|
| 1236 |
+
outputs=[routing_block_3, add_model_3_btn],
|
| 1237 |
+
)
|
| 1238 |
+
|
| 1239 |
+
routing_model_3.change(
|
| 1240 |
+
fn=on_routing_model_3_select,
|
| 1241 |
+
inputs=[routing_model_3],
|
| 1242 |
+
outputs=[routing_price_3_input, routing_price_3_cache_read, routing_price_3_cache_creation, routing_price_3_completion],
|
| 1243 |
+
)
|
| 1244 |
+
|
| 1245 |
def update_calculated_options_visibility(source):
    """Return two visibility updates, both shown only when ``source`` is "Calculated"."""
    show = (source == "Calculated")
    updates = [gr.update(visible=show) for _ in range(2)]
    return tuple(updates)
|
|
|
|
| 1255 |
leaderboard_table.select(
|
| 1256 |
fn=on_row_select,
|
| 1257 |
inputs=[leaderboard_table],
|
| 1258 |
+
outputs=[selected_folder, selected_name, analyze_btn, price_input, price_cache_read, price_cache_creation, price_completion, detected_model, thinking_overhead],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1259 |
)
|
| 1260 |
|
| 1261 |
+
def load_and_analyze(folder, input_price, cache_read_price, cache_creation_price, completion_price, source, overhead, with_cache, progress=gr.Progress()):
|
| 1262 |
empty_result = (
|
| 1263 |
+
"",
|
| 1264 |
gr.update(visible=False),
|
| 1265 |
None, None, None, None, None, None,
|
| 1266 |
None,
|
| 1267 |
+
gr.update(visible=False),
|
| 1268 |
)
|
| 1269 |
|
| 1270 |
if not folder:
|
| 1271 |
yield empty_result
|
| 1272 |
return
|
| 1273 |
|
| 1274 |
+
if not check_trajectories_downloaded(folder):
|
| 1275 |
+
yield (
|
| 1276 |
+
"β³ Downloading trajectories...",
|
| 1277 |
+
gr.update(visible=False),
|
| 1278 |
+
None, None, None, None, None, None,
|
| 1279 |
+
None,
|
| 1280 |
+
gr.update(visible=False),
|
| 1281 |
+
)
|
| 1282 |
+
status, _ = download_trajectories_from_s3(folder)
|
| 1283 |
+
if "β" in status:
|
| 1284 |
+
yield (
|
| 1285 |
+
status,
|
| 1286 |
+
gr.update(visible=False),
|
| 1287 |
+
None, None, None, None, None, None,
|
| 1288 |
+
None,
|
| 1289 |
+
gr.update(visible=False),
|
| 1290 |
+
)
|
| 1291 |
+
return
|
| 1292 |
+
|
| 1293 |
yield (
|
| 1294 |
+
"β³ Loading trajectories...",
|
| 1295 |
gr.update(visible=True),
|
| 1296 |
None, None, None, None, None, None,
|
| 1297 |
None,
|
| 1298 |
+
gr.update(visible=False),
|
| 1299 |
)
|
| 1300 |
|
| 1301 |
df_meta = load_all_trajectories(folder)
|
|
|
|
| 1313 |
df = apply_no_cache(df)
|
| 1314 |
|
| 1315 |
if df.empty:
|
| 1316 |
+
yield (
|
| 1317 |
+
"β No trajectories found",
|
| 1318 |
+
gr.update(visible=False),
|
| 1319 |
+
None, None, None, None, None, None,
|
| 1320 |
+
None,
|
| 1321 |
+
gr.update(visible=False),
|
| 1322 |
+
)
|
| 1323 |
return
|
| 1324 |
|
| 1325 |
fig_steps, fig_cost, fig_tokens, fig_tokens_cost, fig_stacked = create_basic_histograms(
|
|
|
|
| 1328 |
fig_cost_breakdown = create_cost_breakdown(df, input_price, cache_read_price, cache_creation_price, completion_price)
|
| 1329 |
|
| 1330 |
yield (
|
| 1331 |
+
f"β
Loaded {len(df)} trajectories",
|
| 1332 |
gr.update(visible=True),
|
| 1333 |
fig_steps, fig_cost, fig_tokens, fig_tokens_cost, fig_stacked, fig_cost_breakdown,
|
| 1334 |
state_data,
|
| 1335 |
+
gr.update(visible=True),
|
| 1336 |
)
|
| 1337 |
|
| 1338 |
analyze_btn.click(
|
| 1339 |
fn=load_and_analyze,
|
| 1340 |
inputs=[selected_folder, price_input, price_cache_read, price_cache_creation, price_completion, token_source, thinking_overhead, use_cache],
|
| 1341 |
outputs=[
|
| 1342 |
+
download_status,
|
| 1343 |
analysis_section,
|
| 1344 |
plot_steps, plot_cost, plot_tokens, plot_tokens_cost, plot_stacked, plot_cost_breakdown,
|
| 1345 |
trajectories_state,
|
| 1346 |
+
add_routing_btn,
|
| 1347 |
],
|
| 1348 |
)
|
| 1349 |
|