IgorSlinko committed on
Commit
a75cc98
·
1 Parent(s): 51fc1ab

Fix routing bugs and unify chart formatting

Browse files

- Fix double-click bug: add intermediate yield with loading state
- Fix Original Cost mismatch: use df-based calculation (same as Total Cost by Token)
- "Let's ROUTE!!" button is already disabled until a routing model is selected
- Unify all token charts to use millions (M) with consistent formatting
- Update hover templates to show values in M format
- Consistent legend positioning across all charts

Files changed (1) hide show
  1. app.py +44 -15
app.py CHANGED
@@ -803,27 +803,27 @@ def create_basic_histograms(df: pd.DataFrame, input_price: float, cache_read_pri
803
 
804
  token_data = pd.DataFrame({
805
  "Token Type": ["Uncached Input", "Cache Read", "Cache Creation", "Completion"],
806
- "Total Tokens": [total_uncached_input, total_cache_read, total_cache_creation, total_completion],
807
  })
808
 
809
  fig_tokens = px.bar(
810
  token_data,
811
  x="Token Type",
812
- y="Total Tokens",
813
  title="Total Tokens by Type",
814
  color="Token Type",
815
  color_discrete_sequence=["#EF553B", "#19D3F3", "#FFA15A", "#AB63FA"],
816
  )
817
  fig_tokens.update_layout(
818
  xaxis_title="Token Type",
819
- yaxis_title="Total Tokens",
820
  showlegend=False,
821
  margin=dict(l=40, r=20, t=40, b=40),
822
  )
823
 
824
- total_all = token_data["Total Tokens"].sum()
825
  fig_tokens.add_annotation(
826
- text=f"Total: {total_all:,.0f}",
827
  xref="paper", yref="paper",
828
  x=0.95, y=0.95, showarrow=False,
829
  font=dict(size=12),
@@ -844,40 +844,40 @@ def create_basic_histograms(df: pd.DataFrame, input_price: float, cache_read_pri
844
  fig_stacked.add_trace(go.Bar(
845
  name="Uncached Input",
846
  x=df_sorted["trajectory_idx"],
847
- y=df_sorted["uncached_input_tokens"],
848
  marker_color="#EF553B",
849
- hovertemplate="Trajectory: %{x}<br>Uncached Input: %{y:,.0f}<extra></extra>",
850
  ))
851
 
852
  fig_stacked.add_trace(go.Bar(
853
  name="Cache Read",
854
  x=df_sorted["trajectory_idx"],
855
- y=df_sorted["cache_read_tokens"],
856
  marker_color="#19D3F3",
857
- hovertemplate="Trajectory: %{x}<br>Cache Read: %{y:,.0f}<extra></extra>",
858
  ))
859
 
860
  fig_stacked.add_trace(go.Bar(
861
  name="Cache Creation",
862
  x=df_sorted["trajectory_idx"],
863
- y=df_sorted["cache_creation_tokens"],
864
  marker_color="#FFA15A",
865
- hovertemplate="Trajectory: %{x}<br>Cache Creation: %{y:,.0f}<extra></extra>",
866
  ))
867
 
868
  fig_stacked.add_trace(go.Bar(
869
  name="Completion",
870
  x=df_sorted["trajectory_idx"],
871
- y=df_sorted["completion_tokens"],
872
  marker_color="#AB63FA",
873
- hovertemplate="Trajectory: %{x}<br>Completion: %{y:,.0f}<extra></extra>",
874
  ))
875
 
876
  fig_stacked.update_layout(
877
  barmode="stack",
878
  title="Tokens per Trajectory (stacked)",
879
  xaxis_title="Trajectory (sorted by total tokens)",
880
- yaxis_title="Tokens",
881
  legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
882
  margin=dict(l=50, r=20, t=60, b=40),
883
  )
@@ -1528,6 +1528,24 @@ def build_app():
1528
  )
1529
  return
1530
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1531
  base_prices = {
1532
  "input": base_input,
1533
  "cache_read": base_cache_read,
@@ -1621,7 +1639,11 @@ def build_app():
1621
 
1622
  total_base_cost = calc_cost(total_base_tokens, base_prices)
1623
  total_routing_cost = calc_cost(total_routing_tokens, routing_prices)
1624
- total_original_cost = calc_cost(total_original_tokens, base_prices)
 
 
 
 
1625
 
1626
  total_routed_cost = total_base_cost + total_routing_cost
1627
  savings = total_original_cost - total_routed_cost
@@ -1645,6 +1667,13 @@ def build_app():
1645
  additional_token_models = [(routing_model_1_val, total_routing_tokens)]
1646
  additional_cost_models = [(routing_model_1_val, routing_costs)]
1647
 
 
 
 
 
 
 
 
1648
  tokens_chart = create_routed_token_chart(total_base_tokens, additional_token_models)
1649
  cost_chart = create_routed_cost_chart(base_costs, additional_cost_models)
1650
 
 
803
 
804
  token_data = pd.DataFrame({
805
  "Token Type": ["Uncached Input", "Cache Read", "Cache Creation", "Completion"],
806
+ "Tokens (M)": [total_uncached_input / 1e6, total_cache_read / 1e6, total_cache_creation / 1e6, total_completion / 1e6],
807
  })
808
 
809
  fig_tokens = px.bar(
810
  token_data,
811
  x="Token Type",
812
+ y="Tokens (M)",
813
  title="Total Tokens by Type",
814
  color="Token Type",
815
  color_discrete_sequence=["#EF553B", "#19D3F3", "#FFA15A", "#AB63FA"],
816
  )
817
  fig_tokens.update_layout(
818
  xaxis_title="Token Type",
819
+ yaxis_title="Tokens (M)",
820
  showlegend=False,
821
  margin=dict(l=40, r=20, t=40, b=40),
822
  )
823
 
824
+ total_all = total_uncached_input + total_cache_read + total_cache_creation + total_completion
825
  fig_tokens.add_annotation(
826
+ text=f"Total: {total_all/1e6:.2f}M",
827
  xref="paper", yref="paper",
828
  x=0.95, y=0.95, showarrow=False,
829
  font=dict(size=12),
 
844
  fig_stacked.add_trace(go.Bar(
845
  name="Uncached Input",
846
  x=df_sorted["trajectory_idx"],
847
+ y=df_sorted["uncached_input_tokens"] / 1e6,
848
  marker_color="#EF553B",
849
+ hovertemplate="Trajectory: %{x}<br>Uncached Input: %{y:.3f}M<extra></extra>",
850
  ))
851
 
852
  fig_stacked.add_trace(go.Bar(
853
  name="Cache Read",
854
  x=df_sorted["trajectory_idx"],
855
+ y=df_sorted["cache_read_tokens"] / 1e6,
856
  marker_color="#19D3F3",
857
+ hovertemplate="Trajectory: %{x}<br>Cache Read: %{y:.3f}M<extra></extra>",
858
  ))
859
 
860
  fig_stacked.add_trace(go.Bar(
861
  name="Cache Creation",
862
  x=df_sorted["trajectory_idx"],
863
+ y=df_sorted["cache_creation_tokens"] / 1e6,
864
  marker_color="#FFA15A",
865
+ hovertemplate="Trajectory: %{x}<br>Cache Creation: %{y:.3f}M<extra></extra>",
866
  ))
867
 
868
  fig_stacked.add_trace(go.Bar(
869
  name="Completion",
870
  x=df_sorted["trajectory_idx"],
871
+ y=df_sorted["completion_tokens"] / 1e6,
872
  marker_color="#AB63FA",
873
+ hovertemplate="Trajectory: %{x}<br>Completion: %{y:.3f}M<extra></extra>",
874
  ))
875
 
876
  fig_stacked.update_layout(
877
  barmode="stack",
878
  title="Tokens per Trajectory (stacked)",
879
  xaxis_title="Trajectory (sorted by total tokens)",
880
+ yaxis_title="Tokens (M)",
881
  legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
882
  margin=dict(l=50, r=20, t=60, b=40),
883
  )
 
1528
  )
1529
  return
1530
 
1531
+ df_key = "meta" if source == "Metadata" else "calculated"
1532
+ df = state_data.get(df_key)
1533
+ if df is not None and not df.empty:
1534
+ if source == "Calculated":
1535
+ df = apply_thinking_overhead(df.copy(), overhead)
1536
+ if not with_cache:
1537
+ df = apply_no_cache(df)
1538
+ df_temp = df.copy()
1539
+ df_temp["uncached_input"] = (df_temp["prompt_tokens"] - df_temp["cache_read_tokens"] - df_temp["cache_creation_tokens"]).clip(lower=0)
1540
+ total_original_cost_from_df = (
1541
+ df_temp["uncached_input"].sum() * base_input / 1e6 +
1542
+ df["cache_read_tokens"].sum() * base_cache_read / 1e6 +
1543
+ df["cache_creation_tokens"].sum() * base_cache_creation / 1e6 +
1544
+ df["completion_tokens"].sum() * base_completion / 1e6
1545
+ )
1546
+ else:
1547
+ total_original_cost_from_df = None
1548
+
1549
  base_prices = {
1550
  "input": base_input,
1551
  "cache_read": base_cache_read,
 
1639
 
1640
  total_base_cost = calc_cost(total_base_tokens, base_prices)
1641
  total_routing_cost = calc_cost(total_routing_tokens, routing_prices)
1642
+
1643
+ if total_original_cost_from_df is not None:
1644
+ total_original_cost = total_original_cost_from_df
1645
+ else:
1646
+ total_original_cost = calc_cost(total_original_tokens, base_prices)
1647
 
1648
  total_routed_cost = total_base_cost + total_routing_cost
1649
  savings = total_original_cost - total_routed_cost
 
1667
  additional_token_models = [(routing_model_1_val, total_routing_tokens)]
1668
  additional_cost_models = [(routing_model_1_val, routing_costs)]
1669
 
1670
+ yield (
1671
+ gr.update(visible=True, value="⏳ Creating charts..."),
1672
+ gr.update(visible=True),
1673
+ None,
1674
+ None,
1675
+ )
1676
+
1677
  tokens_chart = create_routed_token_chart(total_base_tokens, additional_token_models)
1678
  cost_chart = create_routed_cost_chart(base_costs, additional_cost_models)
1679