GBDB02 commited on
Commit
8341f23
·
verified ·
1 Parent(s): 4d204be

Upload visualize.py

Browse files
Files changed (1) hide show
  1. visualize.py +439 -0
visualize.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ visualize.py — Visual diagnostics and statistics for the House Price Predictor.
3
+
4
+ Adds a "📊 Analytics" tab to the Gradio UI that shows:
5
+ 1. Feature Importance — XGBoost gain-based + Lasso coefficient bar charts
6
+ 2. Prediction Distribution — histogram + KDE of predicted prices
7
+ 3. Residual Analysis — residual vs predicted scatter + Q-Q plot
8
+ 4. Training Data Stats — target distribution, correlation heatmap, numeric summary
9
+ 5. Model Comparison — CV RMSE bar chart across the three base learners
10
+ """
11
+
12
+ import os
13
+ import io
14
+ import base64
15
+ import warnings
16
+
17
+ import numpy as np
18
+ import pandas as pd
19
+ import joblib
20
+ import matplotlib
21
+ matplotlib.use("Agg") # non-interactive backend for Gradio
22
+ import matplotlib.pyplot as plt
23
+ import matplotlib.gridspec as gridspec
24
+ from matplotlib.ticker import FuncFormatter
25
+ import scipy.stats as stats
26
+
27
+ warnings.filterwarnings("ignore")
28
+
29
+ # ── shared style ──────────────────────────────────────────────────────────────
30
+ PALETTE = ["#2D6A4F", "#40916C", "#74C69D", "#B7E4C7", "#D8F3DC"]
31
+ ACCENT = "#1B4332"
32
+ WARN = "#E76F51"
33
+ BG = "#F8F9FA"
34
+ GRID_CLR = "#DEE2E6"
35
+
36
+ def _style_ax(ax, title="", xlabel="", ylabel=""):
37
+ ax.set_facecolor(BG)
38
+ ax.grid(axis="y", color=GRID_CLR, linewidth=0.7, linestyle="--", zorder=0)
39
+ ax.spines[["top", "right"]].set_visible(False)
40
+ ax.spines[["left", "bottom"]].set_color(GRID_CLR)
41
+ if title: ax.set_title(title, fontsize=12, fontweight="bold", pad=10, color=ACCENT)
42
+ if xlabel: ax.set_xlabel(xlabel, fontsize=9, color="#495057")
43
+ if ylabel: ax.set_ylabel(ylabel, fontsize=9, color="#495057")
44
+ ax.tick_params(colors="#495057", labelsize=8)
45
+
46
+ def _fig_to_image(fig):
47
+ """Convert a matplotlib figure → PIL Image (Gradio gr.Image compatible)."""
48
+ buf = io.BytesIO()
49
+ fig.savefig(buf, format="png", dpi=130, bbox_inches="tight", facecolor=fig.get_facecolor())
50
+ buf.seek(0)
51
+ from PIL import Image
52
+ img = Image.open(buf)
53
+ plt.close(fig)
54
+ return img
55
+
56
+
57
+ # ── helpers ───────────────────────────────────────────────────────────────────
58
+
59
+ def _load_artifacts():
60
+ from config import MODEL_PATH, PREPROCESSOR_PATH, META_PATH
61
+ for p in (MODEL_PATH, PREPROCESSOR_PATH, META_PATH):
62
+ if not os.path.exists(p):
63
+ raise FileNotFoundError("No trained model found. Train the model first.")
64
+ return joblib.load(MODEL_PATH), joblib.load(PREPROCESSOR_PATH), joblib.load(META_PATH)
65
+
66
+
67
+ def _feature_names(preprocessor, meta):
68
+ """Reconstruct feature names after ColumnTransformer."""
69
+ num_feats = meta["numerical_features"]
70
+ try:
71
+ cat_enc = preprocessor.named_transformers_["cat"].named_steps["onehot"]
72
+ cat_feats = cat_enc.get_feature_names_out(meta["categorical_features"]).tolist()
73
+ except Exception:
74
+ cat_feats = []
75
+ return num_feats + cat_feats
76
+
77
+
78
+ # ══════════════════════════════════════════════════════════════════════════════
79
+ # PLOT 1 — Feature Importance
80
+ # ══════════════════════════════════════════════════════════════════════════════
81
+
82
+ def plot_feature_importance():
83
+ try:
84
+ ensemble, preprocessor, meta = _load_artifacts()
85
+ feature_names = _feature_names(preprocessor, meta)
86
+ n = 20 # top-N to show
87
+
88
+ estimators = dict(ensemble.named_estimators_)
89
+
90
+ fig, axes = plt.subplots(1, 2, figsize=(14, 6), facecolor="white")
91
+ fig.suptitle("Feature Importance", fontsize=15, fontweight="bold", color=ACCENT, y=1.01)
92
+
93
+ # ── XGBoost gain importance ──
94
+ ax = axes[0]
95
+ xgb_model = estimators.get("xgb")
96
+ if xgb_model is not None:
97
+ raw_imp = xgb_model.feature_importances_
98
+ n_feat = min(len(raw_imp), len(feature_names))
99
+ imp = pd.Series(raw_imp[:n_feat], index=feature_names[:n_feat])
100
+ top = imp.nlargest(n).sort_values()
101
+ bars = ax.barh(top.index, top.values, color=PALETTE[1], edgecolor="white", height=0.65)
102
+ for bar, val in zip(bars, top.values):
103
+ ax.text(val + top.values.max() * 0.01, bar.get_y() + bar.get_height() / 2,
104
+ f"{val:.4f}", va="center", fontsize=7, color=ACCENT)
105
+ _style_ax(ax, f"XGBoost — Top {n} Features (Gain)", "Importance", "")
106
+ else:
107
+ ax.text(0.5, 0.5, "XGBoost not available", ha="center", va="center")
108
+
109
+ # ── Lasso coefficients ──
110
+ ax = axes[1]
111
+ lasso_model = estimators.get("lasso")
112
+ if lasso_model is not None:
113
+ n_coef = min(len(lasso_model.coef_), len(feature_names))
114
+ coef = pd.Series(np.abs(lasso_model.coef_[:n_coef]), index=feature_names[:n_coef])
115
+ top = coef.nlargest(n).sort_values()
116
+ colors = [PALETTE[0] if v > 0 else WARN for v in top.values]
117
+ bars = ax.barh(top.index, top.values, color=colors, edgecolor="white", height=0.65)
118
+ for bar, val in zip(bars, top.values):
119
+ ax.text(val + top.values.max() * 0.01, bar.get_y() + bar.get_height() / 2,
120
+ f"{val:.4f}", va="center", fontsize=7, color=ACCENT)
121
+ _style_ax(ax, f"Lasso — Top {n} |Coefficients|", "|Coefficient|", "")
122
+ else:
123
+ ax.text(0.5, 0.5, "Lasso not available", ha="center", va="center")
124
+
125
+ fig.tight_layout()
126
+ return _fig_to_image(fig), "✅ Feature importance loaded."
127
+ except Exception as e:
128
+ return None, f"❌ {e}"
129
+
130
+
131
+ # ══════════════════════════════════════════════════════════════════════════════
132
+ # PLOT 2 — Prediction Distribution (requires test CSV)
133
+ # ══════════════════════════════════════════════════════════════════════════════
134
+
135
+ def plot_prediction_distribution(test_file):
136
+ try:
137
+ if test_file is None:
138
+ return None, "Please upload a test.csv file."
139
+ ensemble, preprocessor, meta = _load_artifacts()
140
+
141
+ from predict import _prepare
142
+ test_path = test_file.name if hasattr(test_file, "name") else test_file
143
+ test_df = pd.read_csv(test_path)
144
+ X_test = _prepare(test_df, meta)
145
+ preds = np.expm1(ensemble.predict(preprocessor.transform(X_test)))
146
+
147
+ fig, axes = plt.subplots(1, 2, figsize=(13, 5), facecolor="white")
148
+ fig.suptitle("Predicted Sale Price Distribution", fontsize=15, fontweight="bold", color=ACCENT)
149
+
150
+ # Histogram
151
+ ax = axes[0]
152
+ ax.hist(preds, bins=40, color=PALETTE[1], edgecolor="white", alpha=0.85)
153
+ ax.axvline(np.median(preds), color=WARN, linewidth=1.8, linestyle="--", label=f"Median: ${np.median(preds):,.0f}")
154
+ ax.axvline(np.mean(preds), color=ACCENT, linewidth=1.8, linestyle="-", label=f"Mean: ${np.mean(preds):,.0f}")
155
+ ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"${x/1e3:.0f}k"))
156
+ ax.legend(fontsize=8)
157
+ _style_ax(ax, "Histogram", "Predicted Price", "Count")
158
+
159
+ # Box + strip
160
+ ax = axes[1]
161
+ bp = ax.boxplot(preds, vert=True, patch_artist=True, widths=0.4,
162
+ boxprops=dict(facecolor=PALETTE[2], color=ACCENT),
163
+ medianprops=dict(color=WARN, linewidth=2),
164
+ whiskerprops=dict(color=ACCENT),
165
+ capprops=dict(color=ACCENT),
166
+ flierprops=dict(marker="o", color=PALETTE[0], alpha=0.3, markersize=3))
167
+ jitter = np.random.uniform(-0.15, 0.15, size=len(preds))
168
+ ax.scatter(1 + jitter, preds, alpha=0.12, s=6, color=PALETTE[0], zorder=3)
169
+ ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f"${y/1e3:.0f}k"))
170
+ _style_ax(ax, "Box Plot + Jitter", "", "Predicted Price")
171
+ ax.set_xticks([])
172
+
173
+ # Stats table below
174
+ stats_txt = (f"n={len(preds):,} min=${preds.min():,.0f} "
175
+ f"Q1=${np.percentile(preds,25):,.0f} median=${np.median(preds):,.0f} "
176
+ f"Q3=${np.percentile(preds,75):,.0f} max=${preds.max():,.0f}")
177
+ fig.text(0.5, -0.02, stats_txt, ha="center", fontsize=8, color="#6C757D")
178
+ fig.tight_layout()
179
+ return _fig_to_image(fig), f"✅ Predictions generated for {len(preds):,} houses."
180
+ except Exception as e:
181
+ return None, f"❌ {e}"
182
+
183
+
184
+ # ══════════════════════════════════════════════════════════════════════════════
185
+ # PLOT 3 — Residual Analysis (requires train CSV to compute in-sample)
186
+ # ══════════════════════════════════════════════════════════════════════════════
187
+
188
+ def plot_residuals(train_file):
189
+ try:
190
+ if train_file is None:
191
+ return None, "Please upload train.csv to compute residuals."
192
+ ensemble, preprocessor, meta = _load_artifacts()
193
+
194
+ from predict import _prepare
195
+ train_path = train_file.name if hasattr(train_file, "name") else train_file
196
+ train_df = pd.read_csv(train_path)
197
+
198
+ if "SalePrice" not in train_df.columns:
199
+ return None, "train.csv must contain a SalePrice column."
200
+
201
+ y_true = train_df["SalePrice"].copy()
202
+ train_df = train_df.drop(columns=["SalePrice"], errors="ignore")
203
+ X = _prepare(train_df, meta)
204
+ y_pred = np.expm1(ensemble.predict(preprocessor.transform(X)))
205
+ residuals = y_true.values - y_pred
206
+
207
+ fig, axes = plt.subplots(1, 3, figsize=(16, 5), facecolor="white")
208
+ fig.suptitle("Residual Analysis (In-Sample)", fontsize=15, fontweight="bold", color=ACCENT)
209
+
210
+ # Residuals vs Predicted
211
+ ax = axes[0]
212
+ ax.scatter(y_pred, residuals, alpha=0.25, s=12, color=PALETTE[1])
213
+ ax.axhline(0, color=WARN, linewidth=1.5, linestyle="--")
214
+ ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"${x/1e3:.0f}k"))
215
+ ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f"${y/1e3:.0f}k"))
216
+ _style_ax(ax, "Residuals vs Predicted", "Predicted Price", "Residual")
217
+
218
+ # Residual histogram
219
+ ax = axes[1]
220
+ ax.hist(residuals, bins=50, color=PALETTE[1], edgecolor="white", alpha=0.85)
221
+ ax.axvline(0, color=WARN, linewidth=1.5, linestyle="--")
222
+ ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"${x/1e3:.0f}k"))
223
+ _style_ax(ax, "Residual Distribution", "Residual", "Count")
224
+
225
+ # Q-Q plot
226
+ ax = axes[2]
227
+ (osm, osr), (slope, intercept, r) = stats.probplot(residuals, dist="norm")
228
+ ax.scatter(osm, osr, alpha=0.3, s=12, color=PALETTE[1])
229
+ line_x = np.array([osm[0], osm[-1]])
230
+ ax.plot(line_x, slope * line_x + intercept, color=WARN, linewidth=1.8)
231
+ _style_ax(ax, f"Q-Q Plot (R²={r**2:.3f})", "Theoretical Quantiles", "Sample Quantiles")
232
+
233
+ rmse = np.sqrt(np.mean(residuals**2))
234
+ mae = np.mean(np.abs(residuals))
235
+ fig.text(0.5, -0.02,
236
+ f"In-sample RMSE: ${rmse:,.0f} | MAE: ${mae:,.0f}",
237
+ ha="center", fontsize=9, color="#6C757D")
238
+ fig.tight_layout()
239
+ return _fig_to_image(fig), f"✅ Residuals computed. RMSE=${rmse:,.0f} MAE=${mae:,.0f}"
240
+ except Exception as e:
241
+ return None, f"❌ {e}"
242
+
243
+
244
+ # ══════════════════════════════════════════════════════════════════════════════
245
+ # PLOT 4 — Training Data Statistics (requires train CSV)
246
+ # ══════════════════════════════════════════════════════════════════════════════
247
+
248
+ def plot_data_stats(train_file):
249
+ try:
250
+ if train_file is None:
251
+ return None, "Please upload train.csv."
252
+ train_path = train_file.name if hasattr(train_file, "name") else train_file
253
+ df = pd.read_csv(train_path)
254
+
255
+ fig = plt.figure(figsize=(16, 10), facecolor="white")
256
+ fig.suptitle("Training Data Statistics", fontsize=15, fontweight="bold", color=ACCENT, y=1.01)
257
+ gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.45, wspace=0.35)
258
+
259
+ # ── SalePrice distribution ──
260
+ ax = fig.add_subplot(gs[0, 0])
261
+ ax.hist(df["SalePrice"], bins=50, color=PALETTE[1], edgecolor="white", alpha=0.85)
262
+ ax.axvline(df["SalePrice"].median(), color=WARN, linewidth=1.5, linestyle="--",
263
+ label=f"Median ${df['SalePrice'].median()/1e3:.0f}k")
264
+ ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"${x/1e3:.0f}k"))
265
+ ax.legend(fontsize=7)
266
+ _style_ax(ax, "SalePrice Distribution", "Sale Price", "Count")
267
+
268
+ # ── Log SalePrice ──
269
+ ax = fig.add_subplot(gs[0, 1])
270
+ log_price = np.log1p(df["SalePrice"])
271
+ ax.hist(log_price, bins=50, color=PALETTE[0], edgecolor="white", alpha=0.85)
272
+ _style_ax(ax, "log(SalePrice) Distribution", "log(1 + SalePrice)", "Count")
273
+
274
+ # ── Missing values (top 15) ──
275
+ ax = fig.add_subplot(gs[0, 2])
276
+ missing = (df.isnull().sum() / len(df) * 100).sort_values(ascending=False).head(15)
277
+ missing = missing[missing > 0]
278
+ if len(missing):
279
+ bars = ax.barh(missing.index[::-1], missing.values[::-1],
280
+ color=WARN, edgecolor="white", height=0.6)
281
+ for bar, val in zip(bars, missing.values[::-1]):
282
+ ax.text(val + 0.3, bar.get_y() + bar.get_height() / 2,
283
+ f"{val:.1f}%", va="center", fontsize=7, color=ACCENT)
284
+ _style_ax(ax, "Missing Values (top 15)", "Missing %", "")
285
+
286
+ # ── Overall Quality vs Price ──
287
+ ax = fig.add_subplot(gs[1, 0])
288
+ if "OverallQual" in df.columns:
289
+ groups = [df[df["OverallQual"] == q]["SalePrice"].values
290
+ for q in sorted(df["OverallQual"].unique())]
291
+ labels = sorted(df["OverallQual"].unique())
292
+ bp = ax.boxplot(groups, labels=labels, patch_artist=True,
293
+ boxprops=dict(facecolor=PALETTE[2], color=ACCENT),
294
+ medianprops=dict(color=WARN, linewidth=1.8),
295
+ whiskerprops=dict(color=ACCENT), capprops=dict(color=ACCENT),
296
+ flierprops=dict(marker=".", color=PALETTE[0], alpha=0.3, markersize=4))
297
+ ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f"${y/1e3:.0f}k"))
298
+ _style_ax(ax, "Price by Overall Quality", "Quality Score", "Sale Price")
299
+
300
+ # ── Correlation with SalePrice (top 12 numerics) ──
301
+ ax = fig.add_subplot(gs[1, 1])
302
+ num_df = df.select_dtypes(include=[np.number]).drop(columns=["Id"], errors="ignore")
303
+ corr = num_df.corr()["SalePrice"].drop("SalePrice").abs().sort_values(ascending=False).head(12)
304
+ corr_signed = num_df.corr()["SalePrice"].drop("SalePrice").loc[corr.index]
305
+ colors = [PALETTE[0] if v > 0 else WARN for v in corr_signed.values]
306
+ ax.barh(corr.index[::-1], corr.values[::-1], color=colors[::-1], edgecolor="white", height=0.65)
307
+ _style_ax(ax, "Top Correlations with SalePrice", "|Pearson r|", "")
308
+
309
+ # ── Scatter GrLivArea vs SalePrice ──
310
+ ax = fig.add_subplot(gs[1, 2])
311
+ if "GrLivArea" in df.columns:
312
+ sc = ax.scatter(df["GrLivArea"], df["SalePrice"],
313
+ alpha=0.25, s=10, c=df.get("OverallQual", pd.Series(5, index=df.index)),
314
+ cmap="YlGn", edgecolors="none")
315
+ plt.colorbar(sc, ax=ax, label="Overall Quality", shrink=0.8)
316
+ ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f"${y/1e3:.0f}k"))
317
+ _style_ax(ax, "GrLivArea vs SalePrice", "Above-Grade Living Area (sqft)", "Sale Price")
318
+
319
+ return _fig_to_image(fig), f"✅ Stats for {len(df):,} training samples loaded."
320
+ except Exception as e:
321
+ return None, f"❌ {e}"
322
+
323
+
324
+ # ══════════════════════════════════════════════════════════════════════════════
325
+ # PLOT 5 — Model CV Comparison (reads saved meta)
326
+ # ══════════════════════════════════════════════════════════════════════════════
327
+
328
+ def plot_model_comparison():
329
+ try:
330
+ _, _, meta = _load_artifacts()
331
+
332
+ cv_scores = meta.get("cv_scores", None)
333
+ if cv_scores is None:
334
+ return None, ("ℹ️ CV score details not stored in this model version.\n"
335
+ "Re-train to enable this chart.")
336
+
337
+ models = list(cv_scores.keys())
338
+ rmses = [cv_scores[m]["rmse"] for m in models]
339
+ stds = [cv_scores[m].get("std", 0) for m in models]
340
+
341
+ fig, ax = plt.subplots(figsize=(7, 4), facecolor="white")
342
+ x = np.arange(len(models))
343
+ bars = ax.bar(x, rmses, yerr=stds, color=PALETTE[:len(models)],
344
+ edgecolor="white", width=0.45, capsize=6,
345
+ error_kw=dict(ecolor=ACCENT, elinewidth=1.5))
346
+ for bar, val in zip(bars, rmses):
347
+ ax.text(bar.get_x() + bar.get_width() / 2, val + max(stds) * 0.05,
348
+ f"{val:.4f}", ha="center", va="bottom", fontsize=9, fontweight="bold", color=ACCENT)
349
+ ax.set_xticks(x)
350
+ ax.set_xticklabels(models, fontsize=10)
351
+ _style_ax(ax, "Cross-Validation RMSE (log scale)", "Model", "CV RMSE (log)")
352
+ fig.tight_layout()
353
+ return _fig_to_image(fig), "✅ Model comparison loaded."
354
+ except Exception as e:
355
+ return None, f"❌ {e}"
356
+
357
+
358
+ # ══════════════════════════════════════════════════════════════════════════════
359
+ # Gradio Tab builder — call this from app.py
360
+ # ══════════════════════════════════════════════════════════════════════════════
361
+
362
+ def build_analytics_tab():
363
+ """
364
+ Returns a gr.Tab block. Import and embed it inside the gr.Tabs() block in app.py.
365
+
366
+ Usage in app.py:
367
+ from visualize import build_analytics_tab
368
+ with gr.Tabs():
369
+ ...existing tabs...
370
+ build_analytics_tab()
371
+ """
372
+ import gradio as gr
373
+
374
+ with gr.Tab("📊 Analytics") as tab:
375
+ gr.Markdown(
376
+ "### Visual Diagnostics\n"
377
+ "Explore model internals, data statistics, predictions and residuals.\n"
378
+ "> **Tip:** Train the model first; some charts also need a CSV upload."
379
+ )
380
+
381
+ with gr.Tabs():
382
+
383
+ # ── Feature Importance ��─────────────────────────────────────────
384
+ with gr.Tab("Feature Importance"):
385
+ gr.Markdown("XGBoost gain-based importance **and** Lasso |coefficients|.")
386
+ btn_fi = gr.Button("Load Feature Importance", variant="primary")
387
+ img_fi = gr.Image(label="Feature Importance", type="pil")
388
+ msg_fi = gr.Markdown()
389
+ btn_fi.click(fn=plot_feature_importance, inputs=[], outputs=[img_fi, msg_fi])
390
+
391
+ # ── Prediction Distribution ─────────────────────────────────────
392
+ with gr.Tab("Prediction Distribution"):
393
+ gr.Markdown("Upload **test.csv** to visualise the distribution of predicted prices.")
394
+ f_pred = gr.File(label="Upload test.csv", file_types=[".csv"])
395
+ btn_pd = gr.Button("Generate Distribution", variant="primary")
396
+ img_pd = gr.Image(label="Prediction Distribution", type="pil")
397
+ msg_pd = gr.Markdown()
398
+ btn_pd.click(fn=plot_prediction_distribution, inputs=[f_pred], outputs=[img_pd, msg_pd])
399
+
400
+ # ── Residual Analysis ───────────────────────────────────────────
401
+ with gr.Tab("Residual Analysis"):
402
+ gr.Markdown("Upload **train.csv** to compute in-sample residuals.")
403
+ f_res = gr.File(label="Upload train.csv", file_types=[".csv"])
404
+ btn_res = gr.Button("Analyse Residuals", variant="primary")
405
+ img_res = gr.Image(label="Residual Analysis", type="pil")
406
+ msg_res = gr.Markdown()
407
+ btn_res.click(fn=plot_residuals, inputs=[f_res], outputs=[img_res, msg_res])
408
+
409
+ # ── Training Data Stats ─────────────────────────────────────────
410
+ with gr.Tab("Data Statistics"):
411
+ gr.Markdown("Upload **train.csv** to explore raw data distributions and correlations.")
412
+ f_stat = gr.File(label="Upload train.csv", file_types=[".csv"])
413
+ btn_st = gr.Button("Show Data Stats", variant="primary")
414
+ img_st = gr.Image(label="Data Statistics", type="pil")
415
+ msg_st = gr.Markdown()
416
+ btn_st.click(fn=plot_data_stats, inputs=[f_stat], outputs=[img_st, msg_st])
417
+
418
+ # ── Model Comparison ────────────────────────────────────────────
419
+ with gr.Tab("Model Comparison"):
420
+ gr.Markdown(
421
+ "CV RMSE across base learners.\n\n"
422
+ "> ⚠️ This chart requires re-training with the updated `train.py` "
423
+ "that saves `cv_scores` to the meta file (see instructions below).\n\n"
424
+ "**To enable this chart**, add the following lines in `train.py` "
425
+ "inside `joblib.dump(...)` for the meta dict:\n"
426
+ "```python\n"
427
+ '"cv_scores": {\n'
428
+ ' "Lasso": {"rmse": lasso_rmse, "std": 0},\n'
429
+ ' "Random Forest": {"rmse": rf_rmse, "std": 0},\n'
430
+ ' "XGBoost": {"rmse": xgb_rmse, "std": 0},\n'
431
+ '},\n'
432
+ "```"
433
+ )
434
+ btn_mc = gr.Button("Load Model Comparison", variant="primary")
435
+ img_mc = gr.Image(label="Model Comparison", type="pil")
436
+ msg_mc = gr.Markdown()
437
+ btn_mc.click(fn=plot_model_comparison, inputs=[], outputs=[img_mc, msg_mc])
438
+
439
+ return tab