IgorSlinko commited on
Commit
99badd3
Β·
1 Parent(s): 9399ab7

Add routing UI: model selection, strategy params, support .traj files

Browse files

- Add routing section with up to 3 models
- Each model has: dropdown with search, price fields, strategy selection
- Strategy options: random steps, every k, part of trajectory
- Support both .traj.json and .traj file formats
- Load & Analyze auto-downloads trajectories if needed
- Add Routing button appears only after successful load
- Fix error handling for models without trajectories on S3
- Compact token prices UI

Files changed (1) hide show
  1. app.py +309 -26
app.py CHANGED
@@ -28,6 +28,21 @@ _trajectories_cache = {}
28
  _calculated_tokens_cache = {}
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def get_default_overhead(model_name: str) -> float:
32
  """Get default tokenizer overhead for model provider"""
33
  model_lower = model_name.lower() if model_name else ""
@@ -172,8 +187,12 @@ def load_all_trajectories_calculated(folder: str) -> pd.DataFrame:
172
  output_dir = TRAJS_DIR / folder
173
 
174
  traj_files = list(output_dir.glob("*/*.traj.json"))
 
 
175
  if not traj_files:
176
  traj_files = list(output_dir.glob("*.traj.json"))
 
 
177
  if not traj_files:
178
  traj_files = list(output_dir.glob("*.json"))
179
 
@@ -212,6 +231,12 @@ def load_all_trajectories_calculated(folder: str) -> pd.DataFrame:
212
  return df
213
 
214
 
 
 
 
 
 
 
215
  def get_litellm_prices() -> dict:
216
  global _litellm_prices_cache
217
  if _litellm_prices_cache is not None:
@@ -357,6 +382,8 @@ def download_trajectories_from_s3(folder: str, progress=gr.Progress()):
357
  output_dir = TRAJS_DIR / folder
358
  if output_dir.exists() and any(output_dir.iterdir()):
359
  file_count = len(list(output_dir.glob("*/*.traj.json")))
 
 
360
  if file_count == 0:
361
  file_count = len(list(output_dir.glob("*.json")))
362
  return f"βœ… Already downloaded: {output_dir}\n\n{file_count} trajectory files", gr.update(visible=True)
@@ -378,9 +405,14 @@ def download_trajectories_from_s3(folder: str, progress=gr.Progress()):
378
  return f"❌ S3 download failed:\n{result.stderr}", gr.update(visible=False)
379
 
380
  file_count = len(list(output_dir.glob("*/*.traj.json")))
 
 
381
  if file_count == 0:
382
  file_count = len(list(output_dir.glob("*.json")))
383
 
 
 
 
384
  per_instance = model.get("per_instance_details", {})
385
  resolved_count = sum(1 for v in per_instance.values() if v.get("resolved"))
386
  total_count = len(per_instance)
@@ -452,8 +484,12 @@ def load_all_trajectories(folder: str) -> pd.DataFrame:
452
  output_dir = TRAJS_DIR / folder
453
 
454
  traj_files = list(output_dir.glob("*/*.traj.json"))
 
 
455
  if not traj_files:
456
  traj_files = list(output_dir.glob("*.traj.json"))
 
 
457
  if not traj_files:
458
  traj_files = list(output_dir.glob("*.json"))
459
 
@@ -873,12 +909,11 @@ def on_row_select(evt: gr.SelectData, df: pd.DataFrame):
873
  if evt.index is None:
874
  return (
875
  "", "",
876
- gr.update(interactive=False),
877
  gr.update(visible=False),
878
- gr.update(value=0, label="πŸ’² Input"),
879
- gr.update(value=0, label="πŸ’² Cache Read"),
880
- gr.update(value=0, label="πŸ’² Cache Creation"),
881
- gr.update(value=0, label="πŸ’² Completion"),
882
  "",
883
  gr.update(value=1.0),
884
  )
@@ -888,8 +923,6 @@ def on_row_select(evt: gr.SelectData, df: pd.DataFrame):
888
  folder = row["folder"]
889
  name = row["name"]
890
 
891
- show_analyze = check_trajectories_downloaded(folder)
892
-
893
  prices_dict, model_hint = get_prices_for_folder(folder)
894
  default_overhead = get_default_overhead(model_hint)
895
 
@@ -898,14 +931,13 @@ def on_row_select(evt: gr.SelectData, df: pd.DataFrame):
898
  if price_info["found"]:
899
  return gr.update(value=value, label=f"βœ… {name}")
900
  elif value > 0:
901
- return gr.update(value=value, label=f"❌ {name} (estimated)")
902
  else:
903
  return gr.update(value=0, label=f"❌ {name}")
904
 
905
  return (
906
  folder, name,
907
- gr.update(interactive=True),
908
- gr.update(visible=show_analyze),
909
  price_update(prices_dict["input"], "Input"),
910
  price_update(prices_dict["cache_read"], "Cache Read"),
911
  price_update(prices_dict["cache_creation"], "Cache Creation"),
@@ -953,18 +985,17 @@ def build_app():
953
  gr.Markdown("### Selected Model")
954
  selected_name = gr.Textbox(label="Model Name", interactive=False)
955
 
956
- download_btn = gr.Button("πŸ“₯ Download Trajectories", interactive=False)
957
- download_status = gr.Textbox(label="Status", interactive=False, lines=3)
958
-
959
  analyze_btn = gr.Button("πŸ“Š Load & Analyze", visible=False, variant="primary")
 
960
 
961
  gr.Markdown("---")
962
  gr.Markdown("### πŸ’° Token Prices ($/1M) Β· *[litellm](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)*")
963
  detected_model = gr.Textbox(label="Detected Model", interactive=False)
964
- price_input = gr.Number(label="πŸ’² Input", value=0, precision=2)
965
- price_cache_read = gr.Number(label="πŸ’² Cache Read", value=0, precision=2)
966
- price_cache_creation = gr.Number(label="πŸ’² Cache Creation", value=0, precision=2)
967
- price_completion = gr.Number(label="πŸ’² Completion", value=0, precision=2)
 
968
 
969
  gr.Markdown("---")
970
  gr.Markdown("### πŸ“Š Token Count Source")
@@ -986,6 +1017,231 @@ def build_app():
986
  visible=False,
987
  )
988
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
989
  def update_calculated_options_visibility(source):
990
  is_calc = source == "Calculated"
991
  return gr.update(visible=is_calc), gr.update(visible=is_calc)
@@ -999,30 +1255,47 @@ def build_app():
999
  leaderboard_table.select(
1000
  fn=on_row_select,
1001
  inputs=[leaderboard_table],
1002
- outputs=[selected_folder, selected_name, download_btn, analyze_btn, price_input, price_cache_read, price_cache_creation, price_completion, detected_model, thinking_overhead],
1003
- )
1004
-
1005
- download_btn.click(
1006
- fn=download_trajectories_from_s3,
1007
- inputs=[selected_folder],
1008
- outputs=[download_status, analyze_btn],
1009
  )
1010
 
1011
- def load_and_analyze(folder, input_price, cache_read_price, cache_creation_price, completion_price, source, overhead, with_cache):
1012
  empty_result = (
 
1013
  gr.update(visible=False),
1014
  None, None, None, None, None, None,
1015
  None,
 
1016
  )
1017
 
1018
  if not folder:
1019
  yield empty_result
1020
  return
1021
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1022
  yield (
 
1023
  gr.update(visible=True),
1024
  None, None, None, None, None, None,
1025
  None,
 
1026
  )
1027
 
1028
  df_meta = load_all_trajectories(folder)
@@ -1040,7 +1313,13 @@ def build_app():
1040
  df = apply_no_cache(df)
1041
 
1042
  if df.empty:
1043
- yield empty_result
 
 
 
 
 
 
1044
  return
1045
 
1046
  fig_steps, fig_cost, fig_tokens, fig_tokens_cost, fig_stacked = create_basic_histograms(
@@ -1049,18 +1328,22 @@ def build_app():
1049
  fig_cost_breakdown = create_cost_breakdown(df, input_price, cache_read_price, cache_creation_price, completion_price)
1050
 
1051
  yield (
 
1052
  gr.update(visible=True),
1053
  fig_steps, fig_cost, fig_tokens, fig_tokens_cost, fig_stacked, fig_cost_breakdown,
1054
  state_data,
 
1055
  )
1056
 
1057
  analyze_btn.click(
1058
  fn=load_and_analyze,
1059
  inputs=[selected_folder, price_input, price_cache_read, price_cache_creation, price_completion, token_source, thinking_overhead, use_cache],
1060
  outputs=[
 
1061
  analysis_section,
1062
  plot_steps, plot_cost, plot_tokens, plot_tokens_cost, plot_stacked, plot_cost_breakdown,
1063
  trajectories_state,
 
1064
  ],
1065
  )
1066
 
 
28
  _calculated_tokens_cache = {}
29
 
30
 
31
+ def parse_step_or_ratio(value: float, total_steps: int) -> int:
32
+ """
33
+ Parse a value as either step number or ratio.
34
+
35
+ If value is integer (e.g., 3.0, 5.0) -> treat as step number
36
+ If value is float with decimal (e.g., 0.5, 0.25) -> treat as ratio of total_steps
37
+
38
+ Returns: step index (0-based)
39
+ """
40
+ if value == int(value) and value >= 1:
41
+ return int(value)
42
+ else:
43
+ return int(value * total_steps)
44
+
45
+
46
  def get_default_overhead(model_name: str) -> float:
47
  """Get default tokenizer overhead for model provider"""
48
  model_lower = model_name.lower() if model_name else ""
 
187
  output_dir = TRAJS_DIR / folder
188
 
189
  traj_files = list(output_dir.glob("*/*.traj.json"))
190
+ if not traj_files:
191
+ traj_files = list(output_dir.glob("*/*.traj"))
192
  if not traj_files:
193
  traj_files = list(output_dir.glob("*.traj.json"))
194
+ if not traj_files:
195
+ traj_files = list(output_dir.glob("*.traj"))
196
  if not traj_files:
197
  traj_files = list(output_dir.glob("*.json"))
198
 
 
231
  return df
232
 
233
 
234
+ def get_litellm_model_list() -> list[str]:
235
+ """Get list of model names from litellm prices"""
236
+ prices = get_litellm_prices()
237
+ return sorted(prices.keys())
238
+
239
+
240
  def get_litellm_prices() -> dict:
241
  global _litellm_prices_cache
242
  if _litellm_prices_cache is not None:
 
382
  output_dir = TRAJS_DIR / folder
383
  if output_dir.exists() and any(output_dir.iterdir()):
384
  file_count = len(list(output_dir.glob("*/*.traj.json")))
385
+ if file_count == 0:
386
+ file_count = len(list(output_dir.glob("*/*.traj")))
387
  if file_count == 0:
388
  file_count = len(list(output_dir.glob("*.json")))
389
  return f"βœ… Already downloaded: {output_dir}\n\n{file_count} trajectory files", gr.update(visible=True)
 
405
  return f"❌ S3 download failed:\n{result.stderr}", gr.update(visible=False)
406
 
407
  file_count = len(list(output_dir.glob("*/*.traj.json")))
408
+ if file_count == 0:
409
+ file_count = len(list(output_dir.glob("*/*.traj")))
410
  if file_count == 0:
411
  file_count = len(list(output_dir.glob("*.json")))
412
 
413
+ if file_count == 0:
414
+ return f"❌ No trajectory files found on S3 for {folder}", gr.update(visible=False)
415
+
416
  per_instance = model.get("per_instance_details", {})
417
  resolved_count = sum(1 for v in per_instance.values() if v.get("resolved"))
418
  total_count = len(per_instance)
 
484
  output_dir = TRAJS_DIR / folder
485
 
486
  traj_files = list(output_dir.glob("*/*.traj.json"))
487
+ if not traj_files:
488
+ traj_files = list(output_dir.glob("*/*.traj"))
489
  if not traj_files:
490
  traj_files = list(output_dir.glob("*.traj.json"))
491
+ if not traj_files:
492
+ traj_files = list(output_dir.glob("*.traj"))
493
  if not traj_files:
494
  traj_files = list(output_dir.glob("*.json"))
495
 
 
909
  if evt.index is None:
910
  return (
911
  "", "",
 
912
  gr.update(visible=False),
913
+ gr.update(value=0, label="Input"),
914
+ gr.update(value=0, label="Cache Read"),
915
+ gr.update(value=0, label="Cache Creation"),
916
+ gr.update(value=0, label="Completion"),
917
  "",
918
  gr.update(value=1.0),
919
  )
 
923
  folder = row["folder"]
924
  name = row["name"]
925
 
 
 
926
  prices_dict, model_hint = get_prices_for_folder(folder)
927
  default_overhead = get_default_overhead(model_hint)
928
 
 
931
  if price_info["found"]:
932
  return gr.update(value=value, label=f"βœ… {name}")
933
  elif value > 0:
934
+ return gr.update(value=value, label=f"❌ {name} (est.)")
935
  else:
936
  return gr.update(value=0, label=f"❌ {name}")
937
 
938
  return (
939
  folder, name,
940
+ gr.update(visible=True),
 
941
  price_update(prices_dict["input"], "Input"),
942
  price_update(prices_dict["cache_read"], "Cache Read"),
943
  price_update(prices_dict["cache_creation"], "Cache Creation"),
 
985
  gr.Markdown("### Selected Model")
986
  selected_name = gr.Textbox(label="Model Name", interactive=False)
987
 
 
 
 
988
  analyze_btn = gr.Button("πŸ“Š Load & Analyze", visible=False, variant="primary")
989
+ download_status = gr.Textbox(label="Status", interactive=False, lines=3)
990
 
991
  gr.Markdown("---")
992
  gr.Markdown("### πŸ’° Token Prices ($/1M) Β· *[litellm](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)*")
993
  detected_model = gr.Textbox(label="Detected Model", interactive=False)
994
+ with gr.Row():
995
+ price_input = gr.Number(label="Input", value=0, precision=2, scale=1)
996
+ price_cache_read = gr.Number(label="Cache Read", value=0, precision=2, scale=1)
997
+ price_cache_creation = gr.Number(label="Cache Creation", value=0, precision=2, scale=1)
998
+ price_completion = gr.Number(label="Completion", value=0, precision=2, scale=1)
999
 
1000
  gr.Markdown("---")
1001
  gr.Markdown("### πŸ“Š Token Count Source")
 
1017
  visible=False,
1018
  )
1019
 
1020
+ gr.Markdown("---")
1021
+ add_routing_btn = gr.Button("βž• Add Routing", variant="primary", visible=False)
1022
+
1023
+ with gr.Column(visible=False) as routing_section:
1024
+ gr.Markdown("### πŸ”€ Routing Models")
1025
+
1026
+ STRATEGY_CHOICES = [
1027
+ "Replace on random steps",
1028
+ "Replace every step k",
1029
+ "Replace part of trajectory",
1030
+ ]
1031
+
1032
+ with gr.Column():
1033
+ with gr.Group():
1034
+ gr.Markdown("#### Route to Model 1")
1035
+ routing_model_1 = gr.Dropdown(
1036
+ label="Model (type 3+ chars to search)",
1037
+ choices=[],
1038
+ allow_custom_value=True,
1039
+ interactive=True,
1040
+ )
1041
+ with gr.Row():
1042
+ routing_price_1_input = gr.Number(label="Input", precision=3, scale=1)
1043
+ routing_price_1_cache_read = gr.Number(label="Cache Read", precision=3, scale=1)
1044
+ routing_price_1_cache_creation = gr.Number(label="Cache Creation", precision=3, scale=1)
1045
+ routing_price_1_completion = gr.Number(label="Completion", precision=3, scale=1)
1046
+ strategy_1 = gr.Dropdown(
1047
+ label="Strategy",
1048
+ choices=STRATEGY_CHOICES,
1049
+ value="Replace on random steps",
1050
+ interactive=True,
1051
+ )
1052
+ with gr.Row(visible=True) as random_params_1:
1053
+ random_pct_1 = gr.Number(label="Percentage (%)", value=50, minimum=0, maximum=100, precision=0, interactive=True)
1054
+ with gr.Row(visible=False) as every_k_params_1:
1055
+ step_k_1 = gr.Number(label="k", value=2, minimum=1, precision=0, interactive=True)
1056
+ with gr.Row(visible=False) as part_params_1:
1057
+ start_step_1 = gr.Number(label="Start (int=step; 0,0-1,0=ratio)", value=0, minimum=0, precision=2, interactive=True)
1058
+ end_step_1 = gr.Number(label="End (int=step; 0,0-1,0=ratio)", value=0.5, minimum=0, precision=2, interactive=True)
1059
+
1060
+ add_model_2_btn = gr.Button("+ Add another model", size="sm", visible=False)
1061
+
1062
+ with gr.Column(visible=False) as routing_block_2:
1063
+ with gr.Group():
1064
+ gr.Markdown("#### Route to Model 2")
1065
+ routing_model_2 = gr.Dropdown(
1066
+ label="Model (type 3+ chars to search)",
1067
+ choices=[],
1068
+ allow_custom_value=True,
1069
+ interactive=True,
1070
+ )
1071
+ with gr.Row():
1072
+ routing_price_2_input = gr.Number(label="Input", precision=3, scale=1)
1073
+ routing_price_2_cache_read = gr.Number(label="Cache Read", precision=3, scale=1)
1074
+ routing_price_2_cache_creation = gr.Number(label="Cache Creation", precision=3, scale=1)
1075
+ routing_price_2_completion = gr.Number(label="Completion", precision=3, scale=1)
1076
+ strategy_2 = gr.Dropdown(
1077
+ label="Strategy",
1078
+ choices=STRATEGY_CHOICES,
1079
+ value="Replace on random steps",
1080
+ interactive=True,
1081
+ )
1082
+ with gr.Row(visible=True) as random_params_2:
1083
+ random_pct_2 = gr.Number(label="Percentage (%)", value=50, minimum=0, maximum=100, precision=0, interactive=True)
1084
+ with gr.Row(visible=False) as every_k_params_2:
1085
+ step_k_2 = gr.Number(label="k", value=2, minimum=1, precision=0, interactive=True)
1086
+ with gr.Row(visible=False) as part_params_2:
1087
+ start_step_2 = gr.Number(label="Start (int=step; 0,0-1,0=ratio)", value=0, minimum=0, precision=2, interactive=True)
1088
+ end_step_2 = gr.Number(label="End (int=step; 0,0-1,0=ratio)", value=0.5, minimum=0, precision=2, interactive=True)
1089
+
1090
+ add_model_3_btn = gr.Button("+ Add another model", size="sm", visible=False)
1091
+
1092
+ with gr.Column(visible=False) as routing_block_3:
1093
+ with gr.Group():
1094
+ gr.Markdown("#### Route to Model 3")
1095
+ routing_model_3 = gr.Dropdown(
1096
+ label="Model (type 3+ chars to search)",
1097
+ choices=[],
1098
+ allow_custom_value=True,
1099
+ interactive=True,
1100
+ )
1101
+ with gr.Row():
1102
+ routing_price_3_input = gr.Number(label="Input", precision=3, scale=1)
1103
+ routing_price_3_cache_read = gr.Number(label="Cache Read", precision=3, scale=1)
1104
+ routing_price_3_cache_creation = gr.Number(label="Cache Creation", precision=3, scale=1)
1105
+ routing_price_3_completion = gr.Number(label="Completion", precision=3, scale=1)
1106
+ strategy_3 = gr.Dropdown(
1107
+ label="Strategy",
1108
+ choices=STRATEGY_CHOICES,
1109
+ value="Replace on random steps",
1110
+ interactive=True,
1111
+ )
1112
+ with gr.Row(visible=True) as random_params_3:
1113
+ random_pct_3 = gr.Number(label="Percentage (%)", value=50, minimum=0, maximum=100, precision=0, interactive=True)
1114
+ with gr.Row(visible=False) as every_k_params_3:
1115
+ step_k_3 = gr.Number(label="k", value=2, minimum=1, precision=0, interactive=True)
1116
+ with gr.Row(visible=False) as part_params_3:
1117
+ start_step_3 = gr.Number(label="Start (int=step; 0,0-1,0=ratio)", value=0, minimum=0, precision=2, interactive=True)
1118
+ end_step_3 = gr.Number(label="End (int=step; 0,0-1,0=ratio)", value=0.5, minimum=0, precision=2, interactive=True)
1119
+
1120
+ def on_strategy_change(strategy):
1121
+ return (
1122
+ gr.update(visible=strategy == "Replace on random steps"),
1123
+ gr.update(visible=strategy == "Replace every step k"),
1124
+ gr.update(visible=strategy == "Replace part of trajectory"),
1125
+ )
1126
+
1127
+ def toggle_routing_section():
1128
+ return gr.update(visible=True)
1129
+
1130
+ add_routing_btn.click(
1131
+ fn=toggle_routing_section,
1132
+ outputs=[routing_section],
1133
+ )
1134
+
1135
+ strategy_1.change(
1136
+ fn=on_strategy_change,
1137
+ inputs=[strategy_1],
1138
+ outputs=[random_params_1, every_k_params_1, part_params_1],
1139
+ )
1140
+
1141
+ strategy_2.change(
1142
+ fn=on_strategy_change,
1143
+ inputs=[strategy_2],
1144
+ outputs=[random_params_2, every_k_params_2, part_params_2],
1145
+ )
1146
+
1147
+ strategy_3.change(
1148
+ fn=on_strategy_change,
1149
+ inputs=[strategy_3],
1150
+ outputs=[random_params_3, every_k_params_3, part_params_3],
1151
+ )
1152
+
1153
+ def filter_models(query):
1154
+ """Filter models based on search query (starts at 3 chars)"""
1155
+ if not query or len(query) < 3:
1156
+ return gr.update(choices=[])
1157
+ all_models = get_litellm_model_list()
1158
+ query_lower = query.lower()
1159
+ filtered = [m for m in all_models if query_lower in m.lower()][:50]
1160
+ return gr.update(choices=filtered)
1161
+
1162
+ routing_model_1.input(fn=filter_models, inputs=[routing_model_1], outputs=[routing_model_1])
1163
+ routing_model_2.input(fn=filter_models, inputs=[routing_model_2], outputs=[routing_model_2])
1164
+ routing_model_3.input(fn=filter_models, inputs=[routing_model_3], outputs=[routing_model_3])
1165
+
1166
+ def get_routing_prices_with_labels(model_name):
1167
+ """Get all 4 prices for a routing model with found/estimated labels"""
1168
+ if not model_name:
1169
+ return (
1170
+ gr.update(value=0, label="Input"),
1171
+ gr.update(value=0, label="Cache Read"),
1172
+ gr.update(value=0, label="Cache Creation"),
1173
+ gr.update(value=0, label="Completion"),
1174
+ )
1175
+
1176
+ prices = get_litellm_prices()
1177
+ model_prices = prices.get(model_name, {})
1178
+
1179
+ input_price = model_prices.get("input_cost_per_token", 0) * 1e6
1180
+ cache_read = model_prices.get("cache_read_input_token_cost", 0) * 1e6
1181
+ cache_creation = model_prices.get("cache_creation_input_token_cost", 0) * 1e6
1182
+ completion = model_prices.get("output_cost_per_token", 0) * 1e6
1183
+
1184
+ input_found = input_price > 0
1185
+ cache_read_found = cache_read > 0
1186
+ cache_creation_found = cache_creation > 0
1187
+ completion_found = completion > 0
1188
+
1189
+ if not cache_read_found and input_price > 0:
1190
+ cache_read = input_price * 0.1
1191
+ if not cache_creation_found and input_price > 0:
1192
+ cache_creation = input_price * 1.25
1193
+
1194
+ def label(name, found):
1195
+ return f"βœ… {name}" if found else f"❌ {name}"
1196
+
1197
+ return (
1198
+ gr.update(value=input_price, label=label("Input", input_found)),
1199
+ gr.update(value=cache_read, label=label("Cache Read", cache_read_found)),
1200
+ gr.update(value=cache_creation, label=label("Cache Creation", cache_creation_found)),
1201
+ gr.update(value=completion, label=label("Completion", completion_found)),
1202
+ )
1203
+
1204
+ def on_routing_model_1_select(model_name):
1205
+ prices = get_routing_prices_with_labels(model_name)
1206
+ show_btn = bool(model_name)
1207
+ return *prices, gr.update(visible=show_btn)
1208
+
1209
+ def on_routing_model_2_select(model_name):
1210
+ prices = get_routing_prices_with_labels(model_name)
1211
+ show_btn = bool(model_name)
1212
+ return *prices, gr.update(visible=show_btn)
1213
+
1214
+ def on_routing_model_3_select(model_name):
1215
+ return get_routing_prices_with_labels(model_name)
1216
+
1217
+ routing_model_1.change(
1218
+ fn=on_routing_model_1_select,
1219
+ inputs=[routing_model_1],
1220
+ outputs=[routing_price_1_input, routing_price_1_cache_read, routing_price_1_cache_creation, routing_price_1_completion, add_model_2_btn],
1221
+ )
1222
+
1223
+ add_model_2_btn.click(
1224
+ fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
1225
+ outputs=[routing_block_2, add_model_2_btn],
1226
+ )
1227
+
1228
+ routing_model_2.change(
1229
+ fn=on_routing_model_2_select,
1230
+ inputs=[routing_model_2],
1231
+ outputs=[routing_price_2_input, routing_price_2_cache_read, routing_price_2_cache_creation, routing_price_2_completion, add_model_3_btn],
1232
+ )
1233
+
1234
+ add_model_3_btn.click(
1235
+ fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
1236
+ outputs=[routing_block_3, add_model_3_btn],
1237
+ )
1238
+
1239
+ routing_model_3.change(
1240
+ fn=on_routing_model_3_select,
1241
+ inputs=[routing_model_3],
1242
+ outputs=[routing_price_3_input, routing_price_3_cache_read, routing_price_3_cache_creation, routing_price_3_completion],
1243
+ )
1244
+
1245
  def update_calculated_options_visibility(source):
1246
  is_calc = source == "Calculated"
1247
  return gr.update(visible=is_calc), gr.update(visible=is_calc)
 
1255
  leaderboard_table.select(
1256
  fn=on_row_select,
1257
  inputs=[leaderboard_table],
1258
+ outputs=[selected_folder, selected_name, analyze_btn, price_input, price_cache_read, price_cache_creation, price_completion, detected_model, thinking_overhead],
 
 
 
 
 
 
1259
  )
1260
 
1261
+ def load_and_analyze(folder, input_price, cache_read_price, cache_creation_price, completion_price, source, overhead, with_cache, progress=gr.Progress()):
1262
  empty_result = (
1263
+ "",
1264
  gr.update(visible=False),
1265
  None, None, None, None, None, None,
1266
  None,
1267
+ gr.update(visible=False),
1268
  )
1269
 
1270
  if not folder:
1271
  yield empty_result
1272
  return
1273
 
1274
+ if not check_trajectories_downloaded(folder):
1275
+ yield (
1276
+ "⏳ Downloading trajectories...",
1277
+ gr.update(visible=False),
1278
+ None, None, None, None, None, None,
1279
+ None,
1280
+ gr.update(visible=False),
1281
+ )
1282
+ status, _ = download_trajectories_from_s3(folder)
1283
+ if "❌" in status:
1284
+ yield (
1285
+ status,
1286
+ gr.update(visible=False),
1287
+ None, None, None, None, None, None,
1288
+ None,
1289
+ gr.update(visible=False),
1290
+ )
1291
+ return
1292
+
1293
  yield (
1294
+ "⏳ Loading trajectories...",
1295
  gr.update(visible=True),
1296
  None, None, None, None, None, None,
1297
  None,
1298
+ gr.update(visible=False),
1299
  )
1300
 
1301
  df_meta = load_all_trajectories(folder)
 
1313
  df = apply_no_cache(df)
1314
 
1315
  if df.empty:
1316
+ yield (
1317
+ "❌ No trajectories found",
1318
+ gr.update(visible=False),
1319
+ None, None, None, None, None, None,
1320
+ None,
1321
+ gr.update(visible=False),
1322
+ )
1323
  return
1324
 
1325
  fig_steps, fig_cost, fig_tokens, fig_tokens_cost, fig_stacked = create_basic_histograms(
 
1328
  fig_cost_breakdown = create_cost_breakdown(df, input_price, cache_read_price, cache_creation_price, completion_price)
1329
 
1330
  yield (
1331
+ f"βœ… Loaded {len(df)} trajectories",
1332
  gr.update(visible=True),
1333
  fig_steps, fig_cost, fig_tokens, fig_tokens_cost, fig_stacked, fig_cost_breakdown,
1334
  state_data,
1335
+ gr.update(visible=True),
1336
  )
1337
 
1338
  analyze_btn.click(
1339
  fn=load_and_analyze,
1340
  inputs=[selected_folder, price_input, price_cache_read, price_cache_creation, price_completion, token_source, thinking_overhead, use_cache],
1341
  outputs=[
1342
+ download_status,
1343
  analysis_section,
1344
  plot_steps, plot_cost, plot_tokens, plot_tokens_cost, plot_stacked, plot_cost_breakdown,
1345
  trajectories_state,
1346
+ add_routing_btn,
1347
  ],
1348
  )
1349