IgorSlinko committed on
Commit
ad3271e
·
1 Parent(s): 10ece01

Auto-recalculate cost charts when prices change

Browse files

- Add create_cost_by_type_chart() function for reuse
- Store trajectories DataFrame in gr.State
- Add .change() handlers on all 4 price fields
- Recalculate 'Total Cost by Token Type' and 'Cost Breakdown per Instance' on price change

Files changed (1) hide show
  1. app.py +69 -33
app.py CHANGED
@@ -250,6 +250,54 @@ def load_all_trajectories(folder: str) -> pd.DataFrame:
250
  return df
251
 
252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  def create_basic_histograms(df: pd.DataFrame, input_price: float, cache_read_price: float, cache_creation_price: float, completion_price: float):
254
  if df.empty:
255
  return None, None, None, None, None
@@ -330,39 +378,8 @@ def create_basic_histograms(df: pd.DataFrame, input_price: float, cache_read_pri
330
  font=dict(size=12),
331
  )
332
 
333
- # Cost by token type
334
- cost_uncached_input = total_uncached_input * input_price / 1e6
335
- cost_cache_read = total_cache_read * cache_read_price / 1e6
336
- cost_cache_creation = total_cache_creation * cache_creation_price / 1e6
337
- cost_completion = total_completion * completion_price / 1e6
338
-
339
- cost_data = pd.DataFrame({
340
- "Token Type": ["Uncached Input", "Cache Read", "Cache Creation", "Completion"],
341
- "Cost ($)": [cost_uncached_input, cost_cache_read, cost_cache_creation, cost_completion],
342
- })
343
-
344
- fig_tokens_cost = px.bar(
345
- cost_data,
346
- x="Token Type",
347
- y="Cost ($)",
348
- title="Total Cost by Token Type ($)",
349
- color="Token Type",
350
- color_discrete_sequence=["#EF553B", "#19D3F3", "#FFA15A", "#AB63FA"],
351
- )
352
- fig_tokens_cost.update_layout(
353
- xaxis_title="Token Type",
354
- yaxis_title="Cost ($)",
355
- showlegend=False,
356
- margin=dict(l=40, r=20, t=40, b=40),
357
- )
358
-
359
- total_cost = cost_uncached_input + cost_cache_read + cost_cache_creation + cost_completion
360
- fig_tokens_cost.add_annotation(
361
- text=f"Total: ${total_cost:.2f}",
362
- xref="paper", yref="paper",
363
- x=0.95, y=0.95, showarrow=False,
364
- font=dict(size=12),
365
- )
366
 
367
  df_sorted = df.sort_values("cache_read_tokens", ascending=False).reset_index(drop=True)
368
  df_sorted["instance_idx"] = range(len(df_sorted))
@@ -627,6 +644,7 @@ def build_app():
627
  empty_result = (
628
  gr.update(visible=False),
629
  None, None, None, None, None, None,
 
630
  )
631
 
632
  if not folder:
@@ -636,6 +654,7 @@ def build_app():
636
  yield (
637
  gr.update(visible=True),
638
  None, None, None, None, None, None,
 
639
  )
640
 
641
  df = load_all_trajectories(folder)
@@ -651,6 +670,7 @@ def build_app():
651
  yield (
652
  gr.update(visible=True),
653
  fig_steps, fig_cost, fig_tokens, fig_tokens_cost, fig_stacked, fig_cost_breakdown,
 
654
  )
655
 
656
  analyze_btn.click(
@@ -659,9 +679,25 @@ def build_app():
659
  outputs=[
660
  analysis_section,
661
  plot_steps, plot_cost, plot_tokens, plot_tokens_cost, plot_stacked, plot_cost_breakdown,
 
662
  ],
663
  )
664
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
665
  return app
666
 
667
 
 
250
  return df
251
 
252
 
253
def create_cost_by_type_chart(df: pd.DataFrame, input_price: float, cache_read_price: float, cache_creation_price: float, completion_price: float):
    """Create the "Total Cost by Token Type" bar chart.

    Kept as a standalone function so the chart can be rebuilt on its own
    when only the prices change.

    Args:
        df: Trajectories table. The columns ``prompt_tokens``,
            ``cache_read_tokens``, ``cache_creation_tokens`` and
            ``completion_tokens`` are read. May be ``None``.
        input_price: Price per 1e6 uncached input tokens.
        cache_read_price: Price per 1e6 cache-read tokens.
        cache_creation_price: Price per 1e6 cache-creation tokens.
        completion_price: Price per 1e6 completion tokens.

    Returns:
        The figure produced by ``px.bar`` with a "Total: $..." annotation,
        or ``None`` when ``df`` is ``None`` or empty.
    """
    # The stored state is None until an analysis has run, so guard against
    # that before touching DataFrame attributes.
    if df is None or df.empty:
        return None

    total_completion = df["completion_tokens"].sum()
    total_cache_read = df["cache_read_tokens"].sum()
    total_cache_creation = df["cache_creation_tokens"].sum()

    # Uncached input is whatever part of the prompt was neither read from nor
    # written to the cache; clip at zero so inconsistent rows cannot produce a
    # negative token count. Computed as a standalone Series so the caller's
    # DataFrame is neither copied nor mutated.
    uncached_input = (
        df["prompt_tokens"] - df["cache_read_tokens"] - df["cache_creation_tokens"]
    ).clip(lower=0)
    total_uncached_input = uncached_input.sum()

    # Prices are quoted per million tokens.
    cost_uncached_input = total_uncached_input * input_price / 1e6
    cost_cache_read = total_cache_read * cache_read_price / 1e6
    cost_cache_creation = total_cache_creation * cache_creation_price / 1e6
    cost_completion = total_completion * completion_price / 1e6

    cost_data = pd.DataFrame({
        "Token Type": ["Uncached Input", "Cache Read", "Cache Creation", "Completion"],
        "Cost ($)": [cost_uncached_input, cost_cache_read, cost_cache_creation, cost_completion],
    })

    fig = px.bar(
        cost_data,
        x="Token Type",
        y="Cost ($)",
        title="Total Cost by Token Type ($)",
        color="Token Type",
        color_discrete_sequence=["#EF553B", "#19D3F3", "#FFA15A", "#AB63FA"],
    )
    fig.update_layout(
        xaxis_title="Token Type",
        yaxis_title="Cost ($)",
        showlegend=False,
        margin=dict(l=40, r=20, t=40, b=40),
    )

    # Show the grand total in the top-right corner of the plot area.
    total_cost = cost_uncached_input + cost_cache_read + cost_cache_creation + cost_completion
    fig.add_annotation(
        text=f"Total: ${total_cost:.2f}",
        xref="paper", yref="paper",
        x=0.95, y=0.95, showarrow=False,
        font=dict(size=12),
    )

    return fig
299
+
300
+
301
  def create_basic_histograms(df: pd.DataFrame, input_price: float, cache_read_price: float, cache_creation_price: float, completion_price: float):
302
  if df.empty:
303
  return None, None, None, None, None
 
378
  font=dict(size=12),
379
  )
380
 
381
+ # Cost by token type (use separate function)
382
+ fig_tokens_cost = create_cost_by_type_chart(df, input_price, cache_read_price, cache_creation_price, completion_price)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
 
384
  df_sorted = df.sort_values("cache_read_tokens", ascending=False).reset_index(drop=True)
385
  df_sorted["instance_idx"] = range(len(df_sorted))
 
644
  empty_result = (
645
  gr.update(visible=False),
646
  None, None, None, None, None, None,
647
+ None, # trajectories_state
648
  )
649
 
650
  if not folder:
 
654
  yield (
655
  gr.update(visible=True),
656
  None, None, None, None, None, None,
657
+ None,
658
  )
659
 
660
  df = load_all_trajectories(folder)
 
670
  yield (
671
  gr.update(visible=True),
672
  fig_steps, fig_cost, fig_tokens, fig_tokens_cost, fig_stacked, fig_cost_breakdown,
673
+ df, # save to state
674
  )
675
 
676
  analyze_btn.click(
 
679
  outputs=[
680
  analysis_section,
681
  plot_steps, plot_cost, plot_tokens, plot_tokens_cost, plot_stacked, plot_cost_breakdown,
682
+ trajectories_state,
683
  ],
684
  )
685
 
686
def recalculate_costs(df, input_price, cache_read_price, cache_creation_price, completion_price):
    """Rebuild the two price-dependent charts from the stored trajectories.

    Used as the handler for changes to any of the four price fields.

    Args:
        df: Trajectories DataFrame held in the state, or ``None`` when no
            analysis has been run yet.
        input_price: Price for uncached input tokens.
        cache_read_price: Price for cache-read tokens.
        cache_creation_price: Price for cache-creation tokens.
        completion_price: Price for completion tokens.

    Returns:
        ``(fig_tokens_cost, fig_cost_breakdown)``, or ``(None, None)`` when
        there is no data or any price is missing.
    """
    if df is None or (isinstance(df, pd.DataFrame) and df.empty):
        return None, None

    # NOTE(review): presumably the price fields are numeric inputs that can
    # yield None while blank - confirm against the component definitions.
    # A missing price would otherwise raise TypeError inside the cost math.
    prices = (input_price, cache_read_price, cache_creation_price, completion_price)
    if any(price is None for price in prices):
        return None, None

    fig_tokens_cost = create_cost_by_type_chart(df, input_price, cache_read_price, cache_creation_price, completion_price)
    fig_cost_breakdown = create_cost_breakdown(df, input_price, cache_read_price, cache_creation_price, completion_price)
    return fig_tokens_cost, fig_cost_breakdown
692
+
693
+ price_inputs = [trajectories_state, price_input, price_cache_read, price_cache_creation, price_completion]
694
+ price_outputs = [plot_tokens_cost, plot_cost_breakdown]
695
+
696
+ price_input.change(fn=recalculate_costs, inputs=price_inputs, outputs=price_outputs)
697
+ price_cache_read.change(fn=recalculate_costs, inputs=price_inputs, outputs=price_outputs)
698
+ price_cache_creation.change(fn=recalculate_costs, inputs=price_inputs, outputs=price_outputs)
699
+ price_completion.change(fn=recalculate_costs, inputs=price_inputs, outputs=price_outputs)
700
+
701
  return app
702
 
703