| from __future__ import annotations |
|
|
import json
import os
import subprocess
import sys
from collections import deque
from pathlib import Path
from typing import Generator
|
|
| import gradio as gr |
| import joblib |
| import matplotlib.pyplot as plt |
| import pandas as pd |
| import shap |
|
|
# Every location is anchored at the directory that holds this script rather
# than the current working directory, so the app can be started from anywhere.
APP_DIR = Path(__file__).parent.resolve()

# UI resources.
STYLE_FILE = APP_DIR.joinpath("style.css")
ASSETS_DIR = APP_DIR.joinpath("assets")

# Data, trained-model and generated-output directories.
DATA_DIR = APP_DIR.joinpath("data")
MODELS_DIR = APP_DIR.joinpath("models")
OUT_DIR = APP_DIR.joinpath("outputs")
FIG_DIR = OUT_DIR.joinpath("figures")
TAB_DIR = OUT_DIR.joinpath("tables")

# Individual artifacts: the fitted pipeline, its JSON metadata, the SHAP
# background sample, and the example CSV offered on the batch tab.
MODEL_FILE = MODELS_DIR.joinpath("pipeline.joblib")
META_FILE = MODELS_DIR.joinpath("model_meta.json")
BG_FILE = MODELS_DIR.joinpath("background_sample.csv")
TEMPLATE_CSV = DATA_DIR.joinpath("batch_template.csv")
|
|
# Example customer used to pre-fill the single-prediction form and to build the
# batch template CSV. Key order matters: it defines FEATURES, the column order
# in which rows are handed to the model.
DEFAULTS = dict(
    AGE=42,
    OPEN_ACC_DUR=120,
    GENDER_CD="1",
    HASNT_HOME_ADDRESS_INF="N",
    HASNT_MOBILE_TEL_NUM_INF="N",
    LOCAL_CUR_MON_AVG_BAL=25000.0,
    LOCAL_FIX_MON_AVG_BAL=18000.0,
    LOCAL_SAV_CUR_ALL_BAL=28000.0,
    POS_CONSUME_TX_AMT=5000.0,
    ATM_ALL_TX_NUM=6,
    COUNTER_ALL_TX_NUM=2,
)
FEATURES = [*DEFAULTS]

# Currently loaded model pipeline and its metadata; both stay None until
# refresh_model_state() finds the artifacts on disk.
PIPE = None
META = None
|
|
|
|
| def ensure_template_csv() -> None: |
| if not TEMPLATE_CSV.exists(): |
| pd.DataFrame([DEFAULTS]).to_csv(TEMPLATE_CSV, index=False) |
|
|
|
|
def load_assets() -> tuple[object | None, dict | None]:
    """Read the trained pipeline and its JSON metadata, if they exist.

    Returns:
        ``(pipe, meta)``: the object deserialised from MODEL_FILE by joblib and
        the parsed contents of META_FILE. Each is None when its file is absent.
    """
    pipe = None
    meta = None
    if MODEL_FILE.exists():
        # joblib files are pickles: only load artifacts produced by this app's own pipeline.
        pipe = joblib.load(MODEL_FILE)
    if META_FILE.exists():
        meta = json.loads(META_FILE.read_text(encoding="utf-8"))
    return pipe, meta
|
|
|
|
def refresh_model_state() -> str:
    """Reload the model artifacts from disk into PIPE/META and describe the result.

    Returns a Markdown status line for the UI: a warning when no trained
    pipeline is available yet, otherwise a "ready" message.
    """
    global PIPE, META
    loaded = load_assets()
    PIPE, META = loaded
    if PIPE is not None:
        return "✅ 模型已加载,可以进行单条预测、批量预测和 SHAP 解释。"
    return "⚠️ 当前为演示状态:请先在 Pipeline 标签页点击 **Run Pipeline** 生成模型。"
|
|
|
|
def gauge_html(prob: float) -> str:
    """Render a churn probability as a coloured horizontal progress-bar gauge.

    The bar width and the percentage label are clamped to 0-100 %. The colour
    is chosen from the raw (unclamped) probability using the same bands as the
    risk levels: green below 0.35, amber below 0.65, red otherwise.
    """
    scaled = prob * 100.0
    pct = max(0.0, min(100.0, scaled))
    if prob < 0.35:
        color = "#16a34a"  # green: low risk
    elif prob < 0.65:
        color = "#f59e0b"  # amber: medium risk
    else:
        color = "#dc2626"  # red: high risk
    return f"""
<div style='background:rgba(255,255,255,0.88);padding:16px;border-radius:18px'>
<div style='font-size:18px;font-weight:700;margin-bottom:8px'>Churn Probability Gauge</div>
<div style='width:100%;height:20px;background:#e5e7eb;border-radius:999px;overflow:hidden'>
<div style='width:{pct:.1f}%;height:20px;background:{color};border-radius:999px'></div>
</div>
<div style='margin-top:10px;font-size:28px;font-weight:800;color:{color}'>{pct:.1f}%</div>
</div>
"""
|
|
|
|
def input_df(age, open_acc_dur, gender_cd, hasnt_home_address_inf, hasnt_mobile_tel_num_inf,
             local_cur_mon_avg_bal, local_fix_mon_avg_bal, local_sav_cur_all_bal,
             pos_consume_tx_amt, atm_all_tx_num, counter_all_tx_num) -> pd.DataFrame:
    """Pack raw form values into a one-row DataFrame keyed by the model's column names.

    Each value is coerced to the type used for that column throughout the app:
    age, durations and transaction counts -> int; the gender code and the two
    "HASNT_*" flags -> str; balances and amounts -> float. Conversion errors
    (e.g. ``int(None)``) propagate to the caller.
    """
    # (column, converter, raw value), listed in model column order.
    spec = (
        ("AGE", int, age),
        ("OPEN_ACC_DUR", int, open_acc_dur),
        ("GENDER_CD", str, gender_cd),
        ("HASNT_HOME_ADDRESS_INF", str, hasnt_home_address_inf),
        ("HASNT_MOBILE_TEL_NUM_INF", str, hasnt_mobile_tel_num_inf),
        ("LOCAL_CUR_MON_AVG_BAL", float, local_cur_mon_avg_bal),
        ("LOCAL_FIX_MON_AVG_BAL", float, local_fix_mon_avg_bal),
        ("LOCAL_SAV_CUR_ALL_BAL", float, local_sav_cur_all_bal),
        ("POS_CONSUME_TX_AMT", float, pos_consume_tx_amt),
        ("ATM_ALL_TX_NUM", int, atm_all_tx_num),
        ("COUNTER_ALL_TX_NUM", int, counter_all_tx_num),
    )
    record = {column: convert(raw) for column, convert, raw in spec}
    return pd.DataFrame([record])
|
|
|
|
def predict_single(age, open_acc_dur, gender_cd, hasnt_home_address_inf, hasnt_mobile_tel_num_inf,
                   local_cur_mon_avg_bal, local_fix_mon_avg_bal, local_sav_cur_all_bal,
                   pos_consume_tx_amt, atm_all_tx_num, counter_all_tx_num):
    """Score the customer described by the Single Prediction form.

    Returns a 4-tuple matching the outputs wired in build_ui():
    (JSON payload, Markdown summary, gauge HTML, update for the Feature
    Importance image). The last element is ``gr.update()``, a no-op, so that
    predicting leaves the chart rendered in the Pipeline tab intact; the slot is
    kept only so the tuple length still matches the four wired outputs.
    """
    # gr.update() with no arguments leaves the target component unchanged;
    # returning None in this slot would clear the Feature Importance image.
    keep_image = gr.update()
    if PIPE is None:
        return {"error": "Run Pipeline first."}, "请先运行 Pipeline。", gauge_html(0.0), keep_image
    df = input_df(age, open_acc_dur, gender_cd, hasnt_home_address_inf, hasnt_mobile_tel_num_inf,
                  local_cur_mon_avg_bal, local_fix_mon_avg_bal, local_sav_cur_all_bal,
                  pos_consume_tx_amt, atm_all_tx_num, counter_all_tx_num)
    prob = float(PIPE.predict_proba(df)[0, 1])
    # Same 0.5 decision threshold as predict_batch and same risk bands as gauge_html.
    pred = int(prob >= 0.5)
    risk = "低风险" if prob < 0.35 else ("中风险" if prob < 0.65 else "高风险")
    payload = {
        "churn_probability": round(prob, 6),
        "predicted_label": pred,
        "risk_level": risk,
    }
    summary = f"**预测结果**:{'流失' if pred == 1 else '留存'}  \n\n**概率**:{prob:.2%}  \n**风险等级**:{risk}"
    return payload, summary, gauge_html(prob), keep_image
|
|
|
|
def predict_batch(file_obj):
    """Score every row of an uploaded CSV and write the results to disk.

    Args:
        file_obj: Value of the ``gr.File`` upload. Either a path (str /
            ``os.PathLike``) or an object exposing the uploaded file's path as
            ``.name``; None when nothing has been uploaded.

    Returns:
        ``(preview, result_path, message)``. ``preview`` is the first 50 result
        rows and ``result_path`` is the written CSV path (both None on failure).
        The result CSV is the input plus ``churn_proba`` and ``churn_pred``
        (threshold 0.5) columns, saved as OUT_DIR/batch_predictions.csv.
    """
    if PIPE is None:
        return None, None, "请先运行 Pipeline。"
    if file_obj is None:
        return None, None, "请先上传 CSV。"

    # Accept plain paths as well as upload wrappers that carry the path in `.name`.
    csv_path = file_obj if isinstance(file_obj, (str, os.PathLike)) else file_obj.name
    categorical = ("GENDER_CD", "HASNT_HOME_ADDRESS_INF", "HASNT_MOBILE_TEL_NUM_INF")
    try:
        # Read the code/flag columns as text. Otherwise read_csv infers e.g.
        # GENDER_CD=1 as int64, while the single-prediction form (input_df) and
        # the SHAP wrapper feed these columns to the model as strings. Empty
        # cells still come back as NaN.
        df = pd.read_csv(csv_path, dtype={c: str for c in categorical})
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
        return None, None, f"CSV 读取失败:{exc}"

    missing = [c for c in FEATURES if c not in df.columns]
    if missing:
        return None, None, f"CSV 缺少列:{missing}"
    if df.empty:
        return None, None, "CSV 中没有数据行。"

    x = df[FEATURES].copy()
    try:
        proba = PIPE.predict_proba(x)[:, 1]
    except ValueError as exc:
        # Invalid cell values surface here; report them instead of crashing the handler.
        return None, None, f"批量预测失败:{exc}"

    out = df.copy()
    out["churn_proba"] = proba
    out["churn_pred"] = (proba >= 0.5).astype(int)
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    out_path = OUT_DIR / "batch_predictions.csv"
    out.to_csv(out_path, index=False)
    return out.head(50), str(out_path), "批量预测完成。"
|
|
|
|
def make_feature_importance_plot():
    """Draw the pipeline's feature-importance table as a horizontal bar chart.

    Reads TAB_DIR/feature_importance.csv (columns ``feature`` and
    ``importance``) and saves FIG_DIR/feature_importance_runtime.png.

    Returns:
        The PNG path as a string, or None when the table has not been produced yet.
    """
    fp = TAB_DIR / "feature_importance.csv"
    if not fp.exists():
        return None
    fi = pd.read_csv(fp)

    FIG_DIR.mkdir(parents=True, exist_ok=True)
    fig_path = FIG_DIR / "feature_importance_runtime.png"
    # Use an explicit Figure/Axes instead of pyplot's implicit "current figure",
    # so this chart does not interact with other plots built in the same process.
    fig, ax = plt.subplots(figsize=(8, 4.5))
    try:
        # Rows are reversed so the first table row ends up at the top of the chart.
        ax.barh(fi["feature"][::-1], fi["importance"][::-1])
        ax.set_title("Feature Importance")
        ax.set_xlabel("Importance")
        fig.tight_layout()
        fig.savefig(fig_path, dpi=160)
    finally:
        # Always release the figure, even if drawing fails, so figures do not pile up.
        plt.close(fig)
    return str(fig_path)
|
|
|
|
def explain_single(age, open_acc_dur, gender_cd, hasnt_home_address_inf, hasnt_mobile_tel_num_inf,
                   local_cur_mon_avg_bal, local_fix_mon_avg_bal, local_sav_cur_all_bal,
                   pos_consume_tx_amt, atm_all_tx_num, counter_all_tx_num):
    """Explain the current form input with a SHAP waterfall plot.

    Returns:
        ``(png_path, message)``. ``png_path`` is FIG_DIR/shap_waterfall.png as a
        string, or None when no model is loaded or the SHAP background sample
        file (BG_FILE) is missing.
    """
    if PIPE is None or not BG_FILE.exists():
        return None, "请先运行 Pipeline。"

    row = input_df(age, open_acc_dur, gender_cd, hasnt_home_address_inf, hasnt_mobile_tel_num_inf,
                   local_cur_mon_avg_bal, local_fix_mon_avg_bal, local_sav_cur_all_bal,
                   pos_consume_tx_amt, atm_all_tx_num, counter_all_tx_num)
    # Only the first 40 background rows are used, presumably to bound SHAP's
    # runtime, which grows with the background size.
    background = pd.read_csv(BG_FILE)[FEATURES].head(40)

    # Column split mirroring input_df: code/flag columns are str, the rest numeric.
    # Computed once here rather than on every model call SHAP makes.
    categorical = ("GENDER_CD", "HASNT_HOME_ADDRESS_INF", "HASNT_MOBILE_TEL_NUM_INF")
    numeric = [col for col in FEATURES if col not in categorical]

    def f(x):
        """Model wrapper handed to SHAP: churn probability for each row of ``x``."""
        # x is presumably an array of (masked) samples in FEATURES order. Building a
        # DataFrame from it loses per-column dtypes, so restore them before predicting.
        x_df = pd.DataFrame(x, columns=FEATURES)
        for c in categorical:
            x_df[c] = x_df[c].astype(str)
        for c in numeric:
            x_df[c] = pd.to_numeric(x_df[c], errors="coerce")
        return PIPE.predict_proba(x_df)[:, 1]

    explainer = shap.Explainer(f, background, feature_names=FEATURES)
    sv = explainer(row)

    FIG_DIR.mkdir(parents=True, exist_ok=True)
    out_path = FIG_DIR / "shap_waterfall.png"
    # The figure is created up front so the waterfall (show=False) is drawn into
    # it and can be saved with pyplot below.
    plt.figure(figsize=(9, 4.8))
    try:
        shap.plots.waterfall(sv[0], max_display=10, show=False)
        plt.tight_layout()
        plt.savefig(out_path, dpi=160, bbox_inches="tight")
    finally:
        # Close even when plotting fails so figures don't accumulate in the server process.
        plt.close()
    prob = float(PIPE.predict_proba(row)[0, 1])
    txt = f"SHAP 解释已生成。该客户流失概率约为 **{prob:.2%}**。"
    return str(out_path), txt
|
|
|
|
def run_pipeline_stream() -> Generator[tuple[str, str, str], None, None]:
    """Run scripts/pipeline.py in a child process and stream its output to the UI.

    Yields ``(log_text, pipeline_status_md, model_status_md)``:
    - once right after launch;
    - once per output line, with log_text holding the most recent 400 lines;
    - a final time after the process exits, once the model artifacts have been
      reloaded from disk.
    """
    # Snapshot the model state *before* the child starts rewriting the artifacts.
    # Reloading on every log line would re-read pipeline.joblib per line and can
    # hit a partially written file, aborting the stream mid-run.
    model_status = refresh_model_state()
    log_lines: deque[str] = deque(maxlen=400)  # keeps only the newest 400 lines
    # sys.executable runs the pipeline with this app's own interpreter/virtualenv;
    # a bare "python" depends on PATH and may resolve to a different environment.
    cmd = [sys.executable, "-u", str(APP_DIR / "scripts" / "pipeline.py")]
    # Pin UTF-8 on both ends of the pipe so non-ASCII log lines decode identically
    # regardless of the platform's locale encoding.
    env = {**os.environ, "PYTHONIOENCODING": "utf-8"}
    proc = subprocess.Popen(
        cmd,
        cwd=str(APP_DIR),
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        encoding="utf-8",
        errors="replace",
        bufsize=1,
        env=env,
    )
    try:
        assert proc.stdout is not None  # guaranteed by stdout=PIPE; narrows the type
        yield "", "⏳ Pipeline 正在运行...", model_status
        for line in proc.stdout:
            log_lines.append(line.rstrip("\n"))
            yield "\n".join(log_lines), "⏳ Pipeline 正在运行...", model_status
        rc = proc.wait()
    finally:
        # If the consumer stops iterating early (e.g. the event is cancelled),
        # don't leave the child process running or its pipe open.
        if proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait()
        if proc.stdout is not None:
            proc.stdout.close()
    status = "✅ Pipeline 运行完成。" if rc == 0 else f"❌ Pipeline 失败,退出码 {rc}。"
    yield "\n".join(log_lines), status, refresh_model_state()
|
|
|
|
def build_ui():
    """Assemble and return the Gradio Blocks app (it is not launched here).

    Side effects when called: writes the batch template CSV if missing,
    registers ASSETS_DIR as a Gradio static path, and loads any model artifacts
    via refresh_model_state(). The app has four tabs: Pipeline, Single
    Prediction, CSV Batch and Explainability.
    """
    ensure_template_csv()
    # Let Gradio serve files under assets/ (presumably referenced by style.css for the background — confirm).
    gr.set_static_paths(paths=[str(ASSETS_DIR)])
    css = STYLE_FILE.read_text(encoding="utf-8") if STYLE_FILE.exists() else ""
    model_status = refresh_model_state()
    with gr.Blocks() as demo:
        # Custom CSS is injected as an inline <style> tag.
        gr.HTML(f"<style>{css}</style>")
        # NOTE(review): the original indentation of this function was lost; the nesting
        # below (tabs and footer inside #main_panel) is a reconstruction — confirm the layout.
        with gr.Column(elem_id="main_panel"):
            gr.Markdown("# 🏦 Bank Churn Pro Demo\n全屏背景 + Pipeline 日志 + 特征重要性 + 概率仪表盘 + CSV 批量预测 + SHAP 解释")
            # Status lines updated by run_pipeline_stream (model state / pipeline state).
            model_state_md = gr.Markdown(model_status)
            pipeline_status_md = gr.Markdown("尚未运行 Pipeline。")

            with gr.Tabs():
                with gr.Tab("Pipeline"):
                    gr.Markdown("点击按钮执行 3 步流水线:数据准备 → 模型训练与特征重要性 → 验证与 SHAP 背景缓存")
                    run_btn = gr.Button("▶ Run Pipeline", variant="primary")
                    log_box = gr.Textbox(label="Pipeline Step 1/2/3 日志", lines=22, interactive=False)
                    fi_image = gr.Image(label="Feature Importance 图", type="filepath")
                    # Stream log + statuses while the pipeline runs, then redraw the
                    # feature-importance chart once the generator finishes.
                    run_btn.click(fn=run_pipeline_stream, inputs=[], outputs=[log_box, pipeline_status_md, model_state_md]).then(fn=make_feature_importance_plot, inputs=[], outputs=fi_image)

                with gr.Tab("Single Prediction"):
                    with gr.Row():
                        # Left column: one input per feature, pre-filled from DEFAULTS.
                        with gr.Column():
                            age = gr.Slider(18, 100, value=DEFAULTS["AGE"], step=1, label="AGE")
                            open_acc_dur = gr.Slider(0, 400, value=DEFAULTS["OPEN_ACC_DUR"], step=1, label="OPEN_ACC_DUR")
                            gender_cd = gr.Dropdown(choices=["0", "1"], value=DEFAULTS["GENDER_CD"], label="GENDER_CD")
                            hasnt_home = gr.Dropdown(choices=["N", "Y"], value=DEFAULTS["HASNT_HOME_ADDRESS_INF"], label="HASNT_HOME_ADDRESS_INF")
                            hasnt_mobile = gr.Dropdown(choices=["N", "Y"], value=DEFAULTS["HASNT_MOBILE_TEL_NUM_INF"], label="HASNT_MOBILE_TEL_NUM_INF")
                            local_cur = gr.Number(value=DEFAULTS["LOCAL_CUR_MON_AVG_BAL"], label="LOCAL_CUR_MON_AVG_BAL")
                            local_fix = gr.Number(value=DEFAULTS["LOCAL_FIX_MON_AVG_BAL"], label="LOCAL_FIX_MON_AVG_BAL")
                            local_sav = gr.Number(value=DEFAULTS["LOCAL_SAV_CUR_ALL_BAL"], label="LOCAL_SAV_CUR_ALL_BAL")
                            pos_amt = gr.Number(value=DEFAULTS["POS_CONSUME_TX_AMT"], label="POS_CONSUME_TX_AMT")
                            atm_num = gr.Slider(0, 100, value=DEFAULTS["ATM_ALL_TX_NUM"], step=1, label="ATM_ALL_TX_NUM")
                            counter_num = gr.Slider(0, 100, value=DEFAULTS["COUNTER_ALL_TX_NUM"], step=1, label="COUNTER_ALL_TX_NUM")
                            pred_btn = gr.Button("Predict", variant="primary")
                        # Right column: prediction outputs.
                        with gr.Column():
                            pred_json = gr.JSON(label="Prediction JSON")
                            pred_md = gr.Markdown()
                            gauge = gr.HTML(label="Gauge")
                    # NOTE(review): fi_image (Pipeline tab) is the 4th output here, so whatever
                    # predict_single returns in that slot is written into the Feature Importance image.
                    pred_btn.click(
                        fn=predict_single,
                        inputs=[age, open_acc_dur, gender_cd, hasnt_home, hasnt_mobile, local_cur, local_fix, local_sav, pos_amt, atm_num, counter_num],
                        outputs=[pred_json, pred_md, gauge, fi_image],
                    )

                with gr.Tab("CSV Batch"):
                    gr.Markdown("上传包含以下列的 CSV:" + ", ".join(FEATURES))
                    with gr.Row():
                        batch_file = gr.File(label="Upload CSV", file_types=[".csv"])
                        # Pre-populated with the example file written by ensure_template_csv().
                        template_file = gr.File(value=str(TEMPLATE_CSV), label="Template CSV")
                    batch_btn = gr.Button("Run Batch Prediction")
                    batch_df = gr.Dataframe(label="Preview (Top 50)")
                    batch_out_file = gr.File(label="Download Result CSV")
                    batch_msg = gr.Markdown()
                    batch_btn.click(fn=predict_batch, inputs=[batch_file], outputs=[batch_df, batch_out_file, batch_msg])

                with gr.Tab("Explainability"):
                    gr.Markdown("使用当前表单中的同一组输入生成 SHAP waterfall 图。")
                    explain_btn = gr.Button("Generate SHAP Explainability")
                    shap_image = gr.Image(label="SHAP Explainability", type="filepath")
                    shap_md = gr.Markdown()
                    # Reuses the Single Prediction input components, so the explanation
                    # is for exactly the values currently in that form.
                    explain_btn.click(
                        fn=explain_single,
                        inputs=[age, open_acc_dur, gender_cd, hasnt_home, hasnt_mobile, local_cur, local_fix, local_sav, pos_amt, atm_num, counter_num],
                        outputs=[shap_image, shap_md],
                    )

            gr.Markdown("<div class='footer-note'>提示:首次进入请先运行 Pipeline,再使用预测、批量预测和解释功能。</div>")
    return demo
|
|
|
|
if __name__ == "__main__":
    # Build the interface, enable Gradio's request queue, and listen on all
    # interfaces; PORT lets the hosting environment choose the port (default 7860).
    demo = build_ui()
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", "7860")),
    )
|
|