IgorSlinko commited on
Commit
dd3c0f8
·
1 Parent(s): f7c61dd

Add cumulative statistics to routing charts

Browse files

- Show total tokens/cost with breakdown by model
- Annotations: Total, Base Model, Model 1, Model 2...
- Fix model_name mismatch in load_all_trajectories_calculated
- Use step_model from steps instead of reading model_name separately

Files changed (1) hide show
  1. app.py +45 -23
app.py CHANGED
@@ -265,25 +265,6 @@ def load_all_trajectories_calculated(folder: str) -> pd.DataFrame:
265
 
266
  trajectory_steps = load_all_trajectory_steps(folder)
267
 
268
- output_dir = TRAJS_DIR / folder
269
- traj_files = list(output_dir.glob("*/*.traj.json"))
270
- if not traj_files:
271
- traj_files = list(output_dir.glob("*/*.traj"))
272
- if not traj_files:
273
- traj_files = list(output_dir.glob("*.traj.json"))
274
- if not traj_files:
275
- traj_files = list(output_dir.glob("*.traj"))
276
-
277
- model_name = ""
278
- if traj_files:
279
- try:
280
- with open(traj_files[0], "r") as f:
281
- first_data = json.load(f)
282
- config = first_data.get("info", {}).get("config", {}).get("model", {})
283
- model_name = config.get("cost_calc_model_override", config.get("model_name", ""))
284
- except Exception:
285
- pass
286
-
287
  rows = []
288
  for instance_id, steps in trajectory_steps.items():
289
  if not steps:
@@ -291,7 +272,8 @@ def load_all_trajectories_calculated(folder: str) -> pd.DataFrame:
291
 
292
  try:
293
  model_totals = calculate_routing_tokens(steps)
294
- totals = model_totals.get(model_name, {})
 
295
 
296
  cache_read = totals.get("cache_read", 0)
297
  uncached_input = totals.get("uncached_input", 0)
@@ -302,7 +284,7 @@ def load_all_trajectories_calculated(folder: str) -> pd.DataFrame:
302
 
303
  rows.append({
304
  "instance_id": instance_id,
305
- "model_name": model_name,
306
  "api_calls": len(steps),
307
  "instance_cost": 0,
308
  "prompt_tokens": prompt_tokens,
@@ -1099,6 +1081,7 @@ def create_routed_token_chart(base_tokens: dict, additional_models: list):
1099
 
1100
  fig = go.Figure()
1101
 
 
1102
  base_values = [
1103
  base_tokens.get("uncached_input", 0) / 1e6,
1104
  base_tokens.get("cache_read", 0) / 1e6,
@@ -1107,7 +1090,11 @@ def create_routed_token_chart(base_tokens: dict, additional_models: list):
1107
  ]
1108
  fig.add_trace(go.Bar(name="Base Model", x=categories, y=base_values, marker_color=colors[0]))
1109
 
 
 
1110
  for i, (model_name, tokens) in enumerate(additional_models):
 
 
1111
  values = [
1112
  tokens.get("uncached_input", 0) / 1e6,
1113
  tokens.get("cache_read", 0) / 1e6,
@@ -1117,13 +1104,28 @@ def create_routed_token_chart(base_tokens: dict, additional_models: list):
1117
  color = colors[(i + 1) % len(colors)]
1118
  fig.add_trace(go.Bar(name=model_name or f"Model {i+1}", x=categories, y=values, marker_color=color))
1119
 
 
 
 
 
 
1120
  fig.update_layout(
1121
  title="Tokens by Type (per Model)",
1122
  yaxis_title="Tokens (M)",
1123
  barmode="group",
1124
- margin=dict(l=40, r=40, t=60, b=40),
1125
  legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
1126
  )
 
 
 
 
 
 
 
 
 
 
1127
  return fig
1128
 
1129
 
@@ -1142,6 +1144,7 @@ def create_routed_cost_chart(base_costs: dict, additional_models: list):
1142
 
1143
  fig = go.Figure()
1144
 
 
1145
  base_values = [
1146
  base_costs.get("uncached_input", 0),
1147
  base_costs.get("cache_read", 0),
@@ -1150,7 +1153,11 @@ def create_routed_cost_chart(base_costs: dict, additional_models: list):
1150
  ]
1151
  fig.add_trace(go.Bar(name="Base Model", x=categories, y=base_values, marker_color=colors[0]))
1152
 
 
 
1153
  for i, (model_name, costs) in enumerate(additional_models):
 
 
1154
  values = [
1155
  costs.get("uncached_input", 0),
1156
  costs.get("cache_read", 0),
@@ -1160,13 +1167,28 @@ def create_routed_cost_chart(base_costs: dict, additional_models: list):
1160
  color = colors[(i + 1) % len(colors)]
1161
  fig.add_trace(go.Bar(name=model_name or f"Model {i+1}", x=categories, y=values, marker_color=color))
1162
 
 
 
 
 
 
1163
  fig.update_layout(
1164
  title="Cost by Type (per Model) ($)",
1165
  yaxis_title="Cost ($)",
1166
  barmode="group",
1167
- margin=dict(l=40, r=40, t=60, b=40),
1168
  legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
1169
  )
 
 
 
 
 
 
 
 
 
 
1170
  return fig
1171
 
1172
 
 
265
 
266
  trajectory_steps = load_all_trajectory_steps(folder)
267
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  rows = []
269
  for instance_id, steps in trajectory_steps.items():
270
  if not steps:
 
272
 
273
  try:
274
  model_totals = calculate_routing_tokens(steps)
275
+ step_model = steps[0].get("model", "") if steps else ""
276
+ totals = model_totals.get(step_model, {})
277
 
278
  cache_read = totals.get("cache_read", 0)
279
  uncached_input = totals.get("uncached_input", 0)
 
284
 
285
  rows.append({
286
  "instance_id": instance_id,
287
+ "model_name": step_model,
288
  "api_calls": len(steps),
289
  "instance_cost": 0,
290
  "prompt_tokens": prompt_tokens,
 
1081
 
1082
  fig = go.Figure()
1083
 
1084
+ base_total = sum(base_tokens.get(k, 0) for k in ["uncached_input", "cache_read", "cache_creation", "completion"])
1085
  base_values = [
1086
  base_tokens.get("uncached_input", 0) / 1e6,
1087
  base_tokens.get("cache_read", 0) / 1e6,
 
1090
  ]
1091
  fig.add_trace(go.Bar(name="Base Model", x=categories, y=base_values, marker_color=colors[0]))
1092
 
1093
+ model_totals = [("Base Model", base_total)]
1094
+
1095
  for i, (model_name, tokens) in enumerate(additional_models):
1096
+ model_total = sum(tokens.get(k, 0) for k in ["uncached_input", "cache_read", "cache_creation", "completion"])
1097
+ model_totals.append((model_name or f"Model {i+1}", model_total))
1098
  values = [
1099
  tokens.get("uncached_input", 0) / 1e6,
1100
  tokens.get("cache_read", 0) / 1e6,
 
1104
  color = colors[(i + 1) % len(colors)]
1105
  fig.add_trace(go.Bar(name=model_name or f"Model {i+1}", x=categories, y=values, marker_color=color))
1106
 
1107
+ grand_total = sum(t for _, t in model_totals)
1108
+ annotation_lines = [f"<b>Total: {grand_total/1e6:.2f}M</b>"]
1109
+ for name, total in model_totals:
1110
+ annotation_lines.append(f"{name}: {total/1e6:.2f}M")
1111
+
1112
  fig.update_layout(
1113
  title="Tokens by Type (per Model)",
1114
  yaxis_title="Tokens (M)",
1115
  barmode="group",
1116
+ margin=dict(l=40, r=40, t=80, b=40),
1117
  legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
1118
  )
1119
+ fig.add_annotation(
1120
+ text="<br>".join(annotation_lines),
1121
+ xref="paper", yref="paper",
1122
+ x=0.02, y=0.98, showarrow=False,
1123
+ font=dict(size=11),
1124
+ align="left",
1125
+ bgcolor="rgba(255,255,255,0.8)",
1126
+ bordercolor="gray",
1127
+ borderwidth=1,
1128
+ )
1129
  return fig
1130
 
1131
 
 
1144
 
1145
  fig = go.Figure()
1146
 
1147
+ base_total = sum(base_costs.get(k, 0) for k in ["uncached_input", "cache_read", "cache_creation", "completion"])
1148
  base_values = [
1149
  base_costs.get("uncached_input", 0),
1150
  base_costs.get("cache_read", 0),
 
1153
  ]
1154
  fig.add_trace(go.Bar(name="Base Model", x=categories, y=base_values, marker_color=colors[0]))
1155
 
1156
+ model_totals = [("Base Model", base_total)]
1157
+
1158
  for i, (model_name, costs) in enumerate(additional_models):
1159
+ model_total = sum(costs.get(k, 0) for k in ["uncached_input", "cache_read", "cache_creation", "completion"])
1160
+ model_totals.append((model_name or f"Model {i+1}", model_total))
1161
  values = [
1162
  costs.get("uncached_input", 0),
1163
  costs.get("cache_read", 0),
 
1167
  color = colors[(i + 1) % len(colors)]
1168
  fig.add_trace(go.Bar(name=model_name or f"Model {i+1}", x=categories, y=values, marker_color=color))
1169
 
1170
+ grand_total = sum(t for _, t in model_totals)
1171
+ annotation_lines = [f"<b>Total: ${grand_total:.2f}</b>"]
1172
+ for name, total in model_totals:
1173
+ annotation_lines.append(f"{name}: ${total:.2f}")
1174
+
1175
  fig.update_layout(
1176
  title="Cost by Type (per Model) ($)",
1177
  yaxis_title="Cost ($)",
1178
  barmode="group",
1179
+ margin=dict(l=40, r=40, t=80, b=40),
1180
  legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
1181
  )
1182
+ fig.add_annotation(
1183
+ text="<br>".join(annotation_lines),
1184
+ xref="paper", yref="paper",
1185
+ x=0.02, y=0.98, showarrow=False,
1186
+ font=dict(size=11),
1187
+ align="left",
1188
+ bgcolor="rgba(255,255,255,0.8)",
1189
+ bordercolor="gray",
1190
+ borderwidth=1,
1191
+ )
1192
  return fig
1193
 
1194