Arkm20 commited on
Commit
3ac721e
·
verified ·
1 Parent(s): 48b6847

Update src/charts.py

Browse files
Files changed (1) hide show
  1. src/charts.py +115 -125
src/charts.py CHANGED
@@ -6,42 +6,35 @@ Dark theme, consistent palette, interactive.
6
  import numpy as np
7
  import pandas as pd
8
  import plotly.graph_objects as go
9
- import plotly.express as px
10
  from plotly.subplots import make_subplots
11
 
12
  # ---------------------------------------------------------------------------
13
  # Design tokens
14
  # ---------------------------------------------------------------------------
15
- BG = "#0d0f14"
16
- SURFACE = "#14171f"
17
  SURFACE2 = "#1c202c"
18
- BORDER = "#2a2f3d"
19
- TEXT = "#e4e6ef"
20
  TEXT_MUTED = "#7a7f94"
21
- TEAL = "#00d4aa"
22
  TEAL_DIM = "#007a63"
23
- RED = "#ff4d6a"
24
- AMBER = "#f5a623"
25
- PURPLE = "#a78bfa"
26
- BLUE = "#60a5fa"
27
- GREEN = "#34d399"
28
 
29
- FONT = "JetBrains Mono, monospace"
30
  FONT_BODY = "DM Sans, sans-serif"
31
 
 
 
32
  BASE_LAYOUT = dict(
33
  paper_bgcolor=BG,
34
  plot_bgcolor=SURFACE,
35
  font=dict(family=FONT_BODY, color=TEXT, size=12),
36
  margin=dict(l=48, r=24, t=48, b=40),
37
- xaxis=dict(
38
- gridcolor=BORDER, zerolinecolor=BORDER,
39
- tickfont=dict(family=FONT, size=10, color=TEXT_MUTED),
40
- ),
41
- yaxis=dict(
42
- gridcolor=BORDER, zerolinecolor=BORDER,
43
- tickfont=dict(family=FONT, size=10, color=TEXT_MUTED),
44
- ),
45
  legend=dict(
46
  bgcolor=SURFACE2, bordercolor=BORDER, borderwidth=1,
47
  font=dict(family=FONT_BODY, size=11),
@@ -52,9 +45,41 @@ BASE_LAYOUT = dict(
52
  ),
53
  )
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- def _apply_base(fig):
57
- fig.update_layout(**BASE_LAYOUT)
 
 
 
 
 
 
 
 
 
 
58
  return fig
59
 
60
 
@@ -71,37 +96,32 @@ def nav_chart(nav_df: pd.DataFrame, benchmark_df: pd.DataFrame, initial_cash: fl
71
  subplot_titles=["Portfolio NAV", "Drawdown"],
72
  )
73
 
74
- dates = nav_df["Date"].astype(str)
75
- nav = nav_df["NAV"].values
76
  nav_norm = nav / initial_cash * 100
77
 
78
- # NAV line
79
  fig.add_trace(go.Scatter(
80
  x=dates, y=nav_norm,
81
  name="Sniper Portfolio",
82
  line=dict(color=TEAL, width=2),
83
  fill="tozeroy",
84
- fillcolor=f"rgba(0,212,170,0.06)",
85
  hovertemplate="<b>%{x}</b><br>NAV: %{y:.1f}<extra></extra>",
86
  ), row=1, col=1)
87
 
88
- # Benchmark
89
  if benchmark_df is not None and not benchmark_df.empty and "Benchmark NAV" in benchmark_df.columns:
90
  b_norm = benchmark_df["Benchmark NAV"].values / initial_cash * 100
91
- b_label = "Benchmark"
92
  fig.add_trace(go.Scatter(
93
  x=benchmark_df["Date"].astype(str), y=b_norm,
94
- name=b_label,
95
  line=dict(color=TEXT_MUTED, width=1.5, dash="dash"),
96
  hovertemplate="<b>%{x}</b><br>Benchmark: %{y:.1f}<extra></extra>",
97
  ), row=1, col=1)
98
 
99
- # Cash baseline
100
  fig.add_hline(y=100, line=dict(color=BORDER, width=1, dash="dot"), row=1, col=1)
101
 
102
- # Drawdown
103
  peak = pd.Series(nav).cummax()
104
- dd = (pd.Series(nav) - peak) / peak * 100
105
  colors = [RED if v < -5 else AMBER if v < -2 else TEXT_MUTED for v in dd.values]
106
 
107
  fig.add_trace(go.Bar(
@@ -115,9 +135,10 @@ def nav_chart(nav_df: pd.DataFrame, benchmark_df: pd.DataFrame, initial_cash: fl
115
  fig.update_layout(
116
  **BASE_LAYOUT,
117
  title=dict(text="Portfolio Performance", font=dict(size=16, family=FONT_BODY, color=TEXT)),
118
- xaxis2=dict(gridcolor=BORDER, tickfont=dict(family=FONT, size=10, color=TEXT_MUTED)),
119
- yaxis=dict(ticksuffix="", gridcolor=BORDER),
120
- yaxis2=dict(ticksuffix="%", gridcolor=BORDER),
 
121
  height=520,
122
  )
123
  return fig
@@ -128,26 +149,19 @@ def monthly_returns_heatmap(nav_df: pd.DataFrame, initial_cash: float) -> go.Fig
128
  nav_df["Date"] = pd.to_datetime(nav_df["Date"])
129
  nav_df = nav_df.set_index("Date").sort_index()
130
 
131
- monthly = nav_df["NAV"].resample("ME").last()
132
  monthly_ret = monthly.pct_change() * 100
133
 
134
  if monthly_ret.empty:
135
- fig = go.Figure()
136
- _apply_base(fig)
137
- fig.add_annotation(text="Insufficient data for monthly heatmap",
138
- xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
139
- font=dict(color=TEXT_MUTED))
140
- return fig
141
-
142
- years = sorted(monthly_ret.index.year.unique())
143
- months = list(range(1, 13))
144
  month_labels = ["Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec"]
145
 
146
- z = []
147
- text = []
148
  for yr in years:
149
- row_z = []
150
- row_t = []
151
  for mo in months:
152
  mask = (monthly_ret.index.year == yr) & (monthly_ret.index.month == mo)
153
  if mask.sum() > 0:
@@ -168,33 +182,24 @@ def monthly_returns_heatmap(nav_df: pd.DataFrame, initial_cash: float) -> go.Fig
168
  [0.52, "#374151"], [0.65, "#065f46"], [1.0, "#064e3b"],
169
  ],
170
  zmid=0,
171
- colorbar=dict(
172
- tickfont=dict(family=FONT, size=10, color=TEXT_MUTED),
173
- ticksuffix="%",
174
- ),
175
  hovertemplate="<b>%{y} %{x}</b><br>Return: %{text}<extra></extra>",
176
  ))
177
  fig.update_layout(
178
  **BASE_LAYOUT,
179
  title=dict(text="Monthly Returns", font=dict(size=15, family=FONT_BODY, color=TEXT)),
180
  height=max(200, 60 + len(years) * 38),
181
- yaxis=dict(autorange="reversed", gridcolor=BORDER,
182
- tickfont=dict(family=FONT, size=11, color=TEXT_MUTED)),
183
- xaxis=dict(side="top", gridcolor=BORDER,
184
- tickfont=dict(family=FONT_BODY, size=11, color=TEXT_MUTED)),
185
  )
186
  return fig
187
 
188
 
189
  def exit_reasons_chart(trades_df: pd.DataFrame) -> go.Figure:
190
  if trades_df.empty or "Exit Reason" not in trades_df.columns:
191
- fig = go.Figure()
192
- _apply_base(fig)
193
- fig.add_annotation(text="No trades to display", xref="paper", yref="paper",
194
- x=0.5, y=0.5, showarrow=False, font=dict(color=TEXT_MUTED))
195
- return fig
196
 
197
- counts = trades_df["Exit Reason"].value_counts()
198
  color_map = {"Stop Loss": RED, "Take Profit": GREEN, "Time Horizon": AMBER}
199
 
200
  fig = go.Figure(go.Bar(
@@ -211,42 +216,39 @@ def exit_reasons_chart(trades_df: pd.DataFrame) -> go.Figure:
211
  title=dict(text="Exit Breakdown", font=dict(size=15, family=FONT_BODY, color=TEXT)),
212
  height=300,
213
  showlegend=False,
214
- yaxis=dict(gridcolor=BORDER),
215
- xaxis=dict(gridcolor="rgba(0,0,0,0)"),
216
  )
217
  return fig
218
 
219
 
220
  def trade_return_distribution(trades_df: pd.DataFrame) -> go.Figure:
221
  if trades_df.empty or "Return %" not in trades_df.columns:
222
- fig = go.Figure()
223
- _apply_base(fig)
224
- return fig
225
 
226
  returns = trades_df["Return %"].dropna()
227
- wins = returns[returns > 0]
228
- losses = returns[returns <= 0]
229
 
230
  fig = go.Figure()
231
  if len(losses) > 0:
232
  fig.add_trace(go.Histogram(
233
- x=losses, name="Losses", marker_color=RED,
234
- opacity=0.7, nbinsx=20,
235
  hovertemplate="Return: %{x:.1f}%<br>Count: %{y}<extra></extra>",
236
  ))
237
  if len(wins) > 0:
238
  fig.add_trace(go.Histogram(
239
- x=wins, name="Wins", marker_color=GREEN,
240
- opacity=0.7, nbinsx=20,
241
  hovertemplate="Return: %{x:.1f}%<br>Count: %{y}<extra></extra>",
242
  ))
243
  fig.add_vline(x=0, line=dict(color=BORDER, width=1.5, dash="dash"))
244
  fig.update_layout(
245
  **BASE_LAYOUT,
246
  title=dict(text="Return Distribution", font=dict(size=15, family=FONT_BODY, color=TEXT)),
247
- barmode="overlay", height=300,
248
- xaxis=dict(title="Return %", gridcolor=BORDER),
249
- yaxis=dict(title="Trades", gridcolor=BORDER),
 
250
  )
251
  return fig
252
 
@@ -256,18 +258,17 @@ def trade_return_distribution(trades_df: pd.DataFrame) -> go.Figure:
256
  # ---------------------------------------------------------------------------
257
 
258
  def radar_chart(dimension_results: list) -> go.Figure:
259
- names = [d.name.replace("_", " ").title() for d in dimension_results]
260
  scores = [d.score for d in dimension_results]
261
 
262
- # Close the polygon
263
- names_closed = names + [names[0]]
264
  scores_closed = scores + [scores[0]]
265
 
266
  fig = go.Figure()
267
  fig.add_trace(go.Scatterpolar(
268
  r=scores_closed, theta=names_closed,
269
  fill="toself",
270
- fillcolor=f"rgba(0,212,170,0.15)",
271
  line=dict(color=TEAL, width=2),
272
  marker=dict(color=TEAL, size=7),
273
  hovertemplate="<b>%{theta}</b><br>Score: %{r:.1f}<extra></extra>",
@@ -292,23 +293,21 @@ def radar_chart(dimension_results: list) -> go.Figure:
292
  height=420,
293
  margin=dict(l=60, r=60, t=60, b=60),
294
  showlegend=False,
 
295
  )
296
  return fig
297
 
298
 
299
  def reliability_diagram(reliability_bins: dict) -> go.Figure:
300
- bins = reliability_bins.get("bin_centers", [])
301
  actual = reliability_bins.get("actual_freqs", [])
302
  counts = reliability_bins.get("bin_counts", [])
303
 
304
  if not bins:
305
- fig = go.Figure()
306
- _apply_base(fig)
307
- return fig
308
 
309
  fig = go.Figure()
310
 
311
- # Perfect calibration diagonal
312
  fig.add_trace(go.Scatter(
313
  x=[0, 1], y=[0, 1], mode="lines",
314
  line=dict(color=BORDER, dash="dash", width=1.5),
@@ -316,7 +315,6 @@ def reliability_diagram(reliability_bins: dict) -> go.Figure:
316
  hoverinfo="skip",
317
  ))
318
 
319
- # Actual calibration
320
  valid = [(b, a, c) for b, a, c in zip(bins, actual, counts) if c > 0]
321
  if valid:
322
  vb, va, vc = zip(*valid)
@@ -335,59 +333,52 @@ def reliability_diagram(reliability_bins: dict) -> go.Figure:
335
  **BASE_LAYOUT,
336
  title=dict(text="Reliability Diagram (Calibration)", font=dict(size=15, family=FONT_BODY, color=TEXT)),
337
  height=340,
338
- xaxis=dict(title="Mean Predicted Probability", range=[-0.02, 1.02], gridcolor=BORDER),
339
- yaxis=dict(title="Actual Positive Fraction", range=[-0.02, 1.02], gridcolor=BORDER),
340
  )
341
  return fig
342
 
343
 
344
  def regime_heatmap(regime_scores: dict) -> go.Figure:
345
- labels_order = ["Bear / Low VIX", "Bear / High VIX", "Bull / Low VIX", "Bull / High VIX"]
346
  x_labels = ["Low VIX", "High VIX"]
347
- y_labels = ["Bear Market", "Bull Market"]
348
 
349
- z = [[None, None], [None, None]]
350
  text = [["", ""], ["", ""]]
351
 
352
  for name, data in regime_scores.items():
353
  auc = data.get("auc")
354
- n = data.get("n", 0)
355
- if "Bear" in name:
356
- row = 0
357
- else:
358
- row = 1
359
  col = 1 if "High VIX" in name else 0
360
- z[row][col] = auc
361
- text[row][col] = f"AUC: {auc:.3f}<br>n={n:,}" if auc is not None else f"n={n}<br>insufficient"
 
362
 
363
  fig = go.Figure(go.Heatmap(
364
  z=z, x=x_labels, y=y_labels,
365
  text=text, texttemplate="%{text}",
366
  colorscale=[[0, "#7f1d1d"], [0.5, SURFACE2], [1, "#064e3b"]],
367
  zmin=0.4, zmax=0.8, zmid=0.55,
368
- colorbar=dict(
369
- title="AUC",
370
- tickfont=dict(family=FONT, size=10, color=TEXT_MUTED),
371
- ),
372
  hovertemplate="<b>%{y} / %{x}</b><br>%{text}<extra></extra>",
373
  ))
374
  fig.update_layout(
375
  **BASE_LAYOUT,
376
- title=dict(text="Regime Robustness (AUC by Market Condition)", font=dict(size=15, family=FONT_BODY, color=TEXT)),
 
377
  height=280,
378
- yaxis=dict(autorange="reversed", gridcolor=BORDER),
379
- xaxis=dict(gridcolor=BORDER),
380
  )
381
  return fig
382
 
383
 
384
  def feature_psi_chart(psi_df: pd.DataFrame, top_n: int = 25) -> go.Figure:
385
  if psi_df.empty:
386
- fig = go.Figure()
387
- _apply_base(fig)
388
- return fig
389
 
390
- top = psi_df.head(top_n)
391
  colors = []
392
  for s in top["Status"]:
393
  if "🔴" in s:
@@ -404,31 +395,30 @@ def feature_psi_chart(psi_df: pd.DataFrame, top_n: int = 25) -> go.Figure:
404
  marker_color=colors,
405
  hovertemplate="<b>%{y}</b><br>PSI: %{x:.4f}<extra></extra>",
406
  ))
407
- fig.add_vline(x=0.2, line=dict(color=RED, width=1.5, dash="dot"),
408
- annotation=dict(text="High drift", font=dict(color=RED, size=10), y=1.02))
409
- fig.add_vline(x=0.1, line=dict(color=AMBER, width=1, dash="dot"),
410
- annotation=dict(text="Watch", font=dict(color=AMBER, size=10), y=0.95))
411
  fig.update_layout(
412
  **BASE_LAYOUT,
413
- title=dict(text=f"Feature PSI — Top {top_n} by Drift", font=dict(size=15, family=FONT_BODY, color=TEXT)),
 
414
  height=max(300, 20 * len(top) + 100),
415
- xaxis=dict(title="Population Stability Index", gridcolor=BORDER),
416
- yaxis=dict(gridcolor="rgba(0,0,0,0)", autorange="reversed",
417
- tickfont=dict(family=FONT, size=10, color=TEXT_MUTED)),
418
  showlegend=False,
 
 
 
 
419
  )
420
  return fig
421
 
422
 
423
  def multi_model_comparison(results: list, model_labels: list) -> go.Figure:
424
- """Bar chart comparing multiple model scores across dimensions."""
425
  if not results:
426
- fig = go.Figure()
427
- _apply_base(fig)
428
- return fig
429
 
430
  dim_names = [d.name.replace("_", " ").title() for d in results[0].dimensions]
431
- palette = [TEAL, PURPLE, AMBER, BLUE, GREEN, RED]
432
 
433
  fig = go.Figure()
434
  for i, (result, label) in enumerate(zip(results, model_labels)):
@@ -447,7 +437,7 @@ def multi_model_comparison(results: list, model_labels: list) -> go.Figure:
447
  title=dict(text="Model Comparison", font=dict(size=16, family=FONT_BODY, color=TEXT)),
448
  barmode="group",
449
  height=380,
450
- yaxis=dict(range=[0, 105], title="Score (0–100)", gridcolor=BORDER),
451
- xaxis=dict(gridcolor="rgba(0,0,0,0)"),
452
  )
453
  return fig
 
6
  import numpy as np
7
  import pandas as pd
8
  import plotly.graph_objects as go
 
9
  from plotly.subplots import make_subplots
10
 
11
  # ---------------------------------------------------------------------------
12
  # Design tokens
13
  # ---------------------------------------------------------------------------
14
+ BG = "#0d0f14"
15
+ SURFACE = "#14171f"
16
  SURFACE2 = "#1c202c"
17
+ BORDER = "#2a2f3d"
18
+ TEXT = "#e4e6ef"
19
  TEXT_MUTED = "#7a7f94"
20
+ TEAL = "#00d4aa"
21
  TEAL_DIM = "#007a63"
22
+ RED = "#ff4d6a"
23
+ AMBER = "#f5a623"
24
+ PURPLE = "#a78bfa"
25
+ BLUE = "#60a5fa"
26
+ GREEN = "#34d399"
27
 
28
+ FONT = "JetBrains Mono, monospace"
29
  FONT_BODY = "DM Sans, sans-serif"
30
 
31
+ # Base layout WITHOUT xaxis/yaxis — those are always passed per-chart
32
+ # to avoid "multiple values for keyword argument" when callers also pass them.
33
  BASE_LAYOUT = dict(
34
  paper_bgcolor=BG,
35
  plot_bgcolor=SURFACE,
36
  font=dict(family=FONT_BODY, color=TEXT, size=12),
37
  margin=dict(l=48, r=24, t=48, b=40),
 
 
 
 
 
 
 
 
38
  legend=dict(
39
  bgcolor=SURFACE2, bordercolor=BORDER, borderwidth=1,
40
  font=dict(family=FONT_BODY, size=11),
 
45
  ),
46
  )
47
 
48
+ # Reusable axis style helpers
49
+ _XAXIS = dict(gridcolor=BORDER, zerolinecolor=BORDER,
50
+ tickfont=dict(family=FONT, size=10, color=TEXT_MUTED))
51
+ _YAXIS = dict(gridcolor=BORDER, zerolinecolor=BORDER,
52
+ tickfont=dict(family=FONT, size=10, color=TEXT_MUTED))
53
+
54
+
55
+ def _ax(**kwargs):
56
+ """Merged x-axis style dict."""
57
+ return {**_XAXIS, **kwargs}
58
+
59
+
60
+ def _ay(**kwargs):
61
+ """Merged y-axis style dict."""
62
+ return {**_YAXIS, **kwargs}
63
+
64
+
65
+ def _apply_base(fig, **extra):
66
+ """Apply BASE_LAYOUT + any per-chart overrides."""
67
+ fig.update_layout(**BASE_LAYOUT, **extra)
68
+ return fig
69
+
70
 
71
+ def _empty_fig(msg="No data"):
72
+ fig = go.Figure()
73
+ fig.update_layout(
74
+ **BASE_LAYOUT,
75
+ xaxis=_ax(), yaxis=_ay(),
76
+ annotations=[dict(
77
+ text=msg, xref="paper", yref="paper",
78
+ x=0.5, y=0.5, showarrow=False,
79
+ font=dict(size=13, color=TEXT_MUTED),
80
+ )],
81
+ height=300,
82
+ )
83
  return fig
84
 
85
 
 
96
  subplot_titles=["Portfolio NAV", "Drawdown"],
97
  )
98
 
99
+ dates = nav_df["Date"].astype(str)
100
+ nav = nav_df["NAV"].values
101
  nav_norm = nav / initial_cash * 100
102
 
 
103
  fig.add_trace(go.Scatter(
104
  x=dates, y=nav_norm,
105
  name="Sniper Portfolio",
106
  line=dict(color=TEAL, width=2),
107
  fill="tozeroy",
108
+ fillcolor="rgba(0,212,170,0.06)",
109
  hovertemplate="<b>%{x}</b><br>NAV: %{y:.1f}<extra></extra>",
110
  ), row=1, col=1)
111
 
 
112
  if benchmark_df is not None and not benchmark_df.empty and "Benchmark NAV" in benchmark_df.columns:
113
  b_norm = benchmark_df["Benchmark NAV"].values / initial_cash * 100
 
114
  fig.add_trace(go.Scatter(
115
  x=benchmark_df["Date"].astype(str), y=b_norm,
116
+ name="Benchmark",
117
  line=dict(color=TEXT_MUTED, width=1.5, dash="dash"),
118
  hovertemplate="<b>%{x}</b><br>Benchmark: %{y:.1f}<extra></extra>",
119
  ), row=1, col=1)
120
 
 
121
  fig.add_hline(y=100, line=dict(color=BORDER, width=1, dash="dot"), row=1, col=1)
122
 
 
123
  peak = pd.Series(nav).cummax()
124
+ dd = (pd.Series(nav) - peak) / peak * 100
125
  colors = [RED if v < -5 else AMBER if v < -2 else TEXT_MUTED for v in dd.values]
126
 
127
  fig.add_trace(go.Bar(
 
135
  fig.update_layout(
136
  **BASE_LAYOUT,
137
  title=dict(text="Portfolio Performance", font=dict(size=16, family=FONT_BODY, color=TEXT)),
138
+ xaxis=_ax(),
139
+ yaxis=_ay(),
140
+ xaxis2=_ax(),
141
+ yaxis2=_ay(ticksuffix="%"),
142
  height=520,
143
  )
144
  return fig
 
149
  nav_df["Date"] = pd.to_datetime(nav_df["Date"])
150
  nav_df = nav_df.set_index("Date").sort_index()
151
 
152
+ monthly = nav_df["NAV"].resample("ME").last()
153
  monthly_ret = monthly.pct_change() * 100
154
 
155
  if monthly_ret.empty:
156
+ return _empty_fig("Insufficient data for monthly heatmap")
157
+
158
+ years = sorted(monthly_ret.index.year.unique())
159
+ months = list(range(1, 13))
 
 
 
 
 
160
  month_labels = ["Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec"]
161
 
162
+ z, text = [], []
 
163
  for yr in years:
164
+ row_z, row_t = [], []
 
165
  for mo in months:
166
  mask = (monthly_ret.index.year == yr) & (monthly_ret.index.month == mo)
167
  if mask.sum() > 0:
 
182
  [0.52, "#374151"], [0.65, "#065f46"], [1.0, "#064e3b"],
183
  ],
184
  zmid=0,
185
+ colorbar=dict(tickfont=dict(family=FONT, size=10, color=TEXT_MUTED), ticksuffix="%"),
 
 
 
186
  hovertemplate="<b>%{y} %{x}</b><br>Return: %{text}<extra></extra>",
187
  ))
188
  fig.update_layout(
189
  **BASE_LAYOUT,
190
  title=dict(text="Monthly Returns", font=dict(size=15, family=FONT_BODY, color=TEXT)),
191
  height=max(200, 60 + len(years) * 38),
192
+ xaxis=_ax(side="top"),
193
+ yaxis=_ay(autorange="reversed"),
 
 
194
  )
195
  return fig
196
 
197
 
198
  def exit_reasons_chart(trades_df: pd.DataFrame) -> go.Figure:
199
  if trades_df.empty or "Exit Reason" not in trades_df.columns:
200
+ return _empty_fig("No trades to display")
 
 
 
 
201
 
202
+ counts = trades_df["Exit Reason"].value_counts()
203
  color_map = {"Stop Loss": RED, "Take Profit": GREEN, "Time Horizon": AMBER}
204
 
205
  fig = go.Figure(go.Bar(
 
216
  title=dict(text="Exit Breakdown", font=dict(size=15, family=FONT_BODY, color=TEXT)),
217
  height=300,
218
  showlegend=False,
219
+ xaxis=_ax(gridcolor="rgba(0,0,0,0)"),
220
+ yaxis=_ay(),
221
  )
222
  return fig
223
 
224
 
225
  def trade_return_distribution(trades_df: pd.DataFrame) -> go.Figure:
226
  if trades_df.empty or "Return %" not in trades_df.columns:
227
+ return _empty_fig("No trades to display")
 
 
228
 
229
  returns = trades_df["Return %"].dropna()
230
+ wins = returns[returns > 0]
231
+ losses = returns[returns <= 0]
232
 
233
  fig = go.Figure()
234
  if len(losses) > 0:
235
  fig.add_trace(go.Histogram(
236
+ x=losses, name="Losses", marker_color=RED, opacity=0.7, nbinsx=20,
 
237
  hovertemplate="Return: %{x:.1f}%<br>Count: %{y}<extra></extra>",
238
  ))
239
  if len(wins) > 0:
240
  fig.add_trace(go.Histogram(
241
+ x=wins, name="Wins", marker_color=GREEN, opacity=0.7, nbinsx=20,
 
242
  hovertemplate="Return: %{x:.1f}%<br>Count: %{y}<extra></extra>",
243
  ))
244
  fig.add_vline(x=0, line=dict(color=BORDER, width=1.5, dash="dash"))
245
  fig.update_layout(
246
  **BASE_LAYOUT,
247
  title=dict(text="Return Distribution", font=dict(size=15, family=FONT_BODY, color=TEXT)),
248
+ barmode="overlay",
249
+ height=300,
250
+ xaxis=_ax(title="Return %"),
251
+ yaxis=_ay(title="Trades"),
252
  )
253
  return fig
254
 
 
258
  # ---------------------------------------------------------------------------
259
 
260
  def radar_chart(dimension_results: list) -> go.Figure:
261
+ names = [d.name.replace("_", " ").title() for d in dimension_results]
262
  scores = [d.score for d in dimension_results]
263
 
264
+ names_closed = names + [names[0]]
 
265
  scores_closed = scores + [scores[0]]
266
 
267
  fig = go.Figure()
268
  fig.add_trace(go.Scatterpolar(
269
  r=scores_closed, theta=names_closed,
270
  fill="toself",
271
+ fillcolor="rgba(0,212,170,0.15)",
272
  line=dict(color=TEAL, width=2),
273
  marker=dict(color=TEAL, size=7),
274
  hovertemplate="<b>%{theta}</b><br>Score: %{r:.1f}<extra></extra>",
 
293
  height=420,
294
  margin=dict(l=60, r=60, t=60, b=60),
295
  showlegend=False,
296
+ hoverlabel=dict(bgcolor=SURFACE2, bordercolor=BORDER, font=dict(family=FONT_BODY, size=12)),
297
  )
298
  return fig
299
 
300
 
301
  def reliability_diagram(reliability_bins: dict) -> go.Figure:
302
+ bins = reliability_bins.get("bin_centers", [])
303
  actual = reliability_bins.get("actual_freqs", [])
304
  counts = reliability_bins.get("bin_counts", [])
305
 
306
  if not bins:
307
+ return _empty_fig("No calibration data")
 
 
308
 
309
  fig = go.Figure()
310
 
 
311
  fig.add_trace(go.Scatter(
312
  x=[0, 1], y=[0, 1], mode="lines",
313
  line=dict(color=BORDER, dash="dash", width=1.5),
 
315
  hoverinfo="skip",
316
  ))
317
 
 
318
  valid = [(b, a, c) for b, a, c in zip(bins, actual, counts) if c > 0]
319
  if valid:
320
  vb, va, vc = zip(*valid)
 
333
  **BASE_LAYOUT,
334
  title=dict(text="Reliability Diagram (Calibration)", font=dict(size=15, family=FONT_BODY, color=TEXT)),
335
  height=340,
336
+ xaxis=_ax(title="Mean Predicted Probability", range=[-0.02, 1.02]),
337
+ yaxis=_ay(title="Actual Positive Fraction", range=[-0.02, 1.02]),
338
  )
339
  return fig
340
 
341
 
342
  def regime_heatmap(regime_scores: dict) -> go.Figure:
 
343
  x_labels = ["Low VIX", "High VIX"]
344
+ y_labels = ["Bear Market", "Bull Market"]
345
 
346
+ z = [[None, None], [None, None]]
347
  text = [["", ""], ["", ""]]
348
 
349
  for name, data in regime_scores.items():
350
  auc = data.get("auc")
351
+ n = data.get("n", 0)
352
+ row = 1 if "Bull" in name else 0
 
 
 
353
  col = 1 if "High VIX" in name else 0
354
+ z[row][col] = auc
355
+ text[row][col] = (f"AUC: {auc:.3f}<br>n={n:,}" if auc is not None
356
+ else f"n={n}<br>insufficient")
357
 
358
  fig = go.Figure(go.Heatmap(
359
  z=z, x=x_labels, y=y_labels,
360
  text=text, texttemplate="%{text}",
361
  colorscale=[[0, "#7f1d1d"], [0.5, SURFACE2], [1, "#064e3b"]],
362
  zmin=0.4, zmax=0.8, zmid=0.55,
363
+ colorbar=dict(title="AUC", tickfont=dict(family=FONT, size=10, color=TEXT_MUTED)),
 
 
 
364
  hovertemplate="<b>%{y} / %{x}</b><br>%{text}<extra></extra>",
365
  ))
366
  fig.update_layout(
367
  **BASE_LAYOUT,
368
+ title=dict(text="Regime Robustness (AUC by Market Condition)",
369
+ font=dict(size=15, family=FONT_BODY, color=TEXT)),
370
  height=280,
371
+ xaxis=_ax(),
372
+ yaxis=_ay(autorange="reversed"),
373
  )
374
  return fig
375
 
376
 
377
  def feature_psi_chart(psi_df: pd.DataFrame, top_n: int = 25) -> go.Figure:
378
  if psi_df.empty:
379
+ return _empty_fig("No PSI data available")
 
 
380
 
381
+ top = psi_df.head(top_n)
382
  colors = []
383
  for s in top["Status"]:
384
  if "🔴" in s:
 
395
  marker_color=colors,
396
  hovertemplate="<b>%{y}</b><br>PSI: %{x:.4f}<extra></extra>",
397
  ))
398
+ fig.add_vline(x=0.2, line=dict(color=RED, width=1.5, dash="dot"),
399
+ annotation=dict(text="High drift", font=dict(color=RED, size=10), y=1.02))
400
+ fig.add_vline(x=0.1, line=dict(color=AMBER, width=1, dash="dot"),
401
+ annotation=dict(text="Watch", font=dict(color=AMBER, size=10), y=0.95))
402
  fig.update_layout(
403
  **BASE_LAYOUT,
404
+ title=dict(text=f"Feature PSI — Top {top_n} by Drift",
405
+ font=dict(size=15, family=FONT_BODY, color=TEXT)),
406
  height=max(300, 20 * len(top) + 100),
 
 
 
407
  showlegend=False,
408
+ xaxis=_ax(title="Population Stability Index"),
409
+ yaxis=_ay(autorange="reversed",
410
+ tickfont=dict(family=FONT, size=10, color=TEXT_MUTED),
411
+ gridcolor="rgba(0,0,0,0)"),
412
  )
413
  return fig
414
 
415
 
416
  def multi_model_comparison(results: list, model_labels: list) -> go.Figure:
 
417
  if not results:
418
+ return _empty_fig("No results to compare")
 
 
419
 
420
  dim_names = [d.name.replace("_", " ").title() for d in results[0].dimensions]
421
+ palette = [TEAL, PURPLE, AMBER, BLUE, GREEN, RED]
422
 
423
  fig = go.Figure()
424
  for i, (result, label) in enumerate(zip(results, model_labels)):
 
437
  title=dict(text="Model Comparison", font=dict(size=16, family=FONT_BODY, color=TEXT)),
438
  barmode="group",
439
  height=380,
440
+ xaxis=_ax(gridcolor="rgba(0,0,0,0)"),
441
+ yaxis=_ay(range=[0, 105], title="Score (0–100)"),
442
  )
443
  return fig