| import os |
| import json |
| import time |
| import traceback |
| import subprocess |
| from pathlib import Path |
|
|
| import gradio as gr |
| import pandas as pd |
| import joblib |
| import papermill as pm |
| import plotly.express as px |
| from huggingface_hub import InferenceClient |
| from jupyter_client.kernelspec import KernelSpecManager |
|
|
# Directory that contains this file; every other path is resolved against it.
BASE_DIR = Path(__file__).resolve().parent

# Input dataset, serialized model, and the two analysis notebooks.
DATA_PATH = BASE_DIR.joinpath("bankChurn.csv")
MODEL_PATH = BASE_DIR.joinpath("models", "pipeline.joblib")
PY_NOTEBOOK = BASE_DIR.joinpath("BankChurn_Version1.ipynb")
R_NOTEBOOK = BASE_DIR.joinpath("BankChurn_Version1_R.ipynb")
# Locations probed, in this order, for the pipeline script.
PIPELINE_CANDIDATES = [
    BASE_DIR.joinpath(*parts)
    for parts in (("scripts", "pipeline.py"), ("pipeline.py",))
]

# Output locations: executed notebook copies and per-language result tables.
RUNS_DIR = BASE_DIR.joinpath("runs")
ART_DIR = BASE_DIR.joinpath("artifacts")
PY_TAB_DIR = ART_DIR.joinpath("py", "tables")
R_TAB_DIR = ART_DIR.joinpath("r", "tables")

# Settings read from the environment once, at import time.
# PAPERMILL_TIMEOUT is handed to papermill as ``execution_timeout``.
PAPERMILL_TIMEOUT = int(os.getenv("PAPERMILL_TIMEOUT", "1800"))
HF_API_KEY = os.getenv("HF_API_KEY", "").strip()
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct").strip()
|
|
|
|
def ensure_dirs():
    """Create the run directory and both table directories if missing."""
    for directory in (RUNS_DIR, PY_TAB_DIR, R_TAB_DIR):
        # parents=True builds intermediate folders; exist_ok makes reruns safe.
        directory.mkdir(parents=True, exist_ok=True)
|
|
|
|
def stamp():
    """Return the current local time formatted as ``YYYYMMDD-HHMMSS``."""
    now = time.localtime()
    return time.strftime("%Y%m%d-%H%M%S", now)
|
|
|
|
def available_kernels():
    """Return the kernel specs reported by Jupyter, or ``{}`` on any failure.

    The value is whatever ``KernelSpecManager().find_kernel_specs()`` returns;
    callers in this file treat it as a mapping keyed by kernel name.
    """
    try:
        manager = KernelSpecManager()
        specs = manager.find_kernel_specs()
    except Exception:
        # Best effort: a failing kernel lookup is reported as "no kernels".
        specs = {}
    return specs
|
|
|
|
def python_kernel_name():
    """Return the first available Python kernel name, or ``None``.

    Candidates are tried in the order ``python3``, ``python``, ``python-3``.
    """
    installed = available_kernels()
    preferred = ("python3", "python", "python-3")
    return next((name for name in preferred if name in installed), None)
|
|
|
|
def r_kernel_name():
    """Return the first available R kernel name, or ``None``.

    Candidates are tried in the order ``ir``, ``irkernel``, ``r``.
    """
    installed = available_kernels()
    preferred = ("ir", "irkernel", "r")
    return next((name for name in preferred if name in installed), None)
|
|
|
|
def run_notebook(nb_path: Path, label: str, kernel_name: str | None) -> str:
    """Execute a notebook with papermill and return a status message.

    Args:
        nb_path: Notebook to execute.
        label: Human-readable name used in the returned message.
        kernel_name: Jupyter kernel to run with; a falsy value means no
            suitable kernel was found.

    Returns:
        A message starting with "✅" on success or "❌" on failure. Failures
        are reported in the message rather than raised.
    """
    ensure_dirs()

    if not nb_path.exists():
        return f"❌ {label} not found: {nb_path.name}"

    if not kernel_name:
        installed = sorted(available_kernels().keys())
        kernels = ", ".join(installed) or "none"
        return f"❌ {label} kernel is not available.\nAvailable kernels: {kernels}"

    try:
        # The executed copy is written to the runs directory with a timestamp.
        out_path = RUNS_DIR / f"run_{stamp()}_{nb_path.name}"
        pm.execute_notebook(
            input_path=str(nb_path),
            output_path=str(out_path),
            cwd=str(BASE_DIR),
            kernel_name=kernel_name,
            log_output=True,
            progress_bar=False,
            request_save_on_cell_execute=True,
            execution_timeout=PAPERMILL_TIMEOUT,
        )
    except Exception as e:
        # Keep only the tail of the traceback so the log box stays readable.
        tail = traceback.format_exc()[-3000:]
        return f"❌ {label} failed.\n{e}\n\n{tail}"
    else:
        return f"✅ {label} finished successfully.\nSaved run: {out_path.name}"
|
|
|
|
def run_python():
    """Run the Python notebook and return the status message."""
    kernel = python_kernel_name()
    return run_notebook(PY_NOTEBOOK, "Python notebook", kernel)
|
|
|
|
def run_r():
    """Run the R notebook and return the status message."""
    kernel = r_kernel_name()
    return run_notebook(R_NOTEBOOK, "R notebook", kernel)
|
|
|
|
def run_pipeline():
    """Run the first existing pipeline script and return a status message.

    The script is looked up in ``PIPELINE_CANDIDATES`` order and executed
    with the interpreter that is running this application, from ``BASE_DIR``.

    Returns:
        A message starting with "✅" when the script exits with code 0, or
        "❌" when the script is missing, exits non-zero, or cannot be
        launched. The last 5000 characters of combined stdout/stderr are
        appended when the script ran.
    """
    import sys  # local import: only needed to locate the running interpreter

    target = next((c for c in PIPELINE_CANDIDATES if c.exists()), None)
    if target is None:
        return "❌ pipeline.py not found."

    # Use the current interpreter so the script sees the same environment
    # and installed packages; a bare "python" may be absent from PATH or
    # point at a different installation. Fall back to "python" only when
    # the interpreter path cannot be determined.
    interpreter = sys.executable or "python"
    try:
        proc = subprocess.run(
            [interpreter, str(target)],
            cwd=str(BASE_DIR),
            capture_output=True,
            text=True,
            check=False,
        )
    except Exception as e:
        return f"❌ Pipeline failed.\n{e}"

    log = proc.stdout or ""
    if proc.stderr:
        log += "\n" + proc.stderr
    tail = log[-5000:]
    if proc.returncode == 0:
        return f"✅ Pipeline finished successfully.\n\n{tail}"
    return f"❌ Pipeline failed with exit code {proc.returncode}.\n\n{tail}"
|
|
|
|
def run_all():
    """Run the Python notebook, the R notebook and the pipeline in sequence.

    Returns:
        One string with a ``=== <stage> ===`` header followed by that stage's
        status message, stages separated by a blank line.
    """
    stages = (
        ("Run Python", run_python),
        ("Run R", run_r),
        ("Run Pipeline", run_pipeline),
    )
    lines = []
    for title, runner in stages:
        if lines:
            lines.append("")  # blank separator between stages
        lines.append(f"=== {title} ===")
        lines.append(runner())
    return "\n".join(lines)
|
|
|
|
def load_model():
    """Load the model stored at ``MODEL_PATH``, or return ``None`` if absent."""
    if not MODEL_PATH.exists():
        return None
    # NOTE: joblib.load unpickles the file; only load models from a trusted source.
    return joblib.load(MODEL_PATH)
|
|
|
|
def encode_gender(gender):
    """Map a gender value to ``1`` or ``0``.

    ``M``, ``MALE`` and ``1`` (case-insensitive, surrounding whitespace
    ignored) map to 1. Everything else, including ``None``, maps to 0.
    """
    if gender is None:
        return 0
    token = str(gender).strip().upper()
    return 1 if token in ("M", "MALE", "1") else 0
|
|
|
|
def predict(age, gender, balance):
    """Predict churn for one customer and return a display string.

    The feature row is built from the model's ``feature_names_in_`` when the
    model exposes it, otherwise from ``["Age", "Balance"]``. Columns are
    matched case-insensitively to age, balance and gender; any other column
    is filled with 0.

    Args:
        age: Customer age as entered in the UI.
        gender: Gender value, encoded with ``encode_gender``.
        balance: Account balance as entered in the UI.

    Returns:
        "Churn Risk: Yes|No", with a probability suffix when the model offers
        ``predict_proba``; a hint when no model file exists; or a
        "Prediction failed" message when the model rejects the input.
    """
    model = load_model()
    if model is None:
        return "Please run the pipeline first."

    feature_names = list(getattr(model, "feature_names_in_", ["Age", "Balance"]))

    values = {}
    for col in feature_names:
        key = str(col).lower()  # tolerate non-string feature names
        if key == "age":
            values[col] = age
        elif key in ("balance", "local_cur_mon_avg_bal"):
            values[col] = balance
        elif key in ("gender", "gender_cd", "sex"):
            values[col] = encode_gender(gender)
        else:
            # Features the form does not collect default to 0.
            values[col] = 0

    X = pd.DataFrame([values], columns=feature_names)

    # Report model errors as a message, like the other UI handlers in this
    # file, instead of letting the exception escape.
    try:
        pred = model.predict(X)[0]
    except Exception as e:
        return f"Prediction failed: {e}"

    label = "Yes" if pred == 1 else "No"

    if hasattr(model, "predict_proba"):
        try:
            prob = float(model.predict_proba(X)[0][1])
        except Exception:
            # Best effort: the probability is optional, fall back to the label.
            prob = None
        if prob is not None:
            return f"Churn Risk: {label} | Probability: {prob:.2%}"

    return f"Churn Risk: {label}"
|
|
|
|
def load_data():
    """Read the dataset at ``DATA_PATH``, or return a three-row sample frame.

    The sample is used only when the CSV file does not exist.
    """
    if not DATA_PATH.exists():
        sample = {
            "AGE": [25, 45, 33],
            "LOCAL_CUR_MON_AVG_BAL": [1000, 5000, 2300],
            "GENDER_CD": ["M", "F", "M"],
            "CHURN_CUST_IND": [0, 1, 0],
        }
        return pd.DataFrame(sample)
    return pd.read_csv(DATA_PATH)
|
|
|
|
def get_target_col(df: pd.DataFrame):
    """Return the first known churn-target column present in *df*, else ``None``."""
    known = ("CHURN_CUST_IND", "Exited", "churn", "target")
    return next((name for name in known if name in df.columns), None)
|
|
|
|
def get_age_col(df: pd.DataFrame):
    """Return the first known age column present in *df*, else ``None``."""
    known = ("AGE", "Age", "age")
    return next((name for name in known if name in df.columns), None)
|
|
|
|
def get_balance_col(df: pd.DataFrame):
    """Return the first known balance column present in *df*, else ``None``."""
    known = ("LOCAL_CUR_MON_AVG_BAL", "Balance", "balance")
    return next((name for name in known if name in df.columns), None)
|
|
|
|
def get_segment_col(df: pd.DataFrame):
    """Return the first known segment column present in *df*, else ``None``."""
    known = ("Geography", "GENDER_CD", "gender", "SEGMENT")
    return next((name for name in known if name in df.columns), None)
|
|
|
|
| def _read_json(path: Path): |
| with open(path, "r", encoding="utf-8") as f: |
| obj = json.load(f) |
| if isinstance(obj, dict): |
| return pd.DataFrame([obj]) |
| return pd.DataFrame(obj) |
|
|
|
|
def load_latest_table(table_dir: Path):
    """Load the most recently modified CSV/JSON table in *table_dir*.

    Args:
        table_dir: Directory to scan (not recursive).

    Returns:
        ``(file_name, DataFrame)`` for the newest ``.csv`` or ``.json`` file.
        ``(None, None)`` when the directory is missing or holds no such file.
        If the file cannot be parsed, the frame contains a single ``error``
        column with the exception text.
    """
    # is_dir() also rejects a path that exists but is a regular file, which
    # would otherwise make iterdir() raise.
    if not table_dir.is_dir():
        return None, None

    # Only regular files qualify: a sub-directory named like "x.csv" must not
    # be picked up as the latest table.
    candidates = [
        entry
        for entry in table_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() in {".csv", ".json"}
    ]
    if not candidates:
        return None, None

    latest = max(candidates, key=lambda entry: entry.stat().st_mtime)
    try:
        if latest.suffix.lower() == ".csv":
            df = pd.read_csv(latest)
        else:
            df = _read_json(latest)
    except Exception as e:
        # Surface the parse problem in the table itself instead of raising.
        return latest.name, pd.DataFrame([{"error": str(e)}])
    return latest.name, df
|
|
|
|
def build_interactive_plot(df: pd.DataFrame, title: str):
    """Pick a Plotly chart for an arbitrary table based on its column types.

    * At least one non-numeric and one numeric column: the first of each is
      plotted from the first 100 complete rows, as a bar chart when the
      non-numeric column has at most 20 distinct values, otherwise as a line.
    * Two or more numeric columns only: scatter of the first two, using the
      first 300 complete rows.
    * Exactly one numeric column only: histogram over the whole frame.
    * Otherwise, or for a missing/empty frame: an empty titled scatter.
    """
    if df is None or df.empty:
        return px.scatter(title=f"{title}: no data")

    # Split the columns into numeric and everything else, keeping their order.
    numeric_cols, cat_cols = [], []
    for col in df.columns:
        bucket = numeric_cols if pd.api.types.is_numeric_dtype(df[col]) else cat_cols
        bucket.append(col)

    if numeric_cols and cat_cols:
        x_col, y_col = cat_cols[0], numeric_cols[0]
        chart_df = df[[x_col, y_col]].dropna().copy().head(100)
        if chart_df.empty:
            return px.scatter(title=f"{title}: no usable rows")
        if chart_df[x_col].nunique() > 20:
            fig = px.line(chart_df, x=x_col, y=y_col, title=title, markers=True)
        else:
            fig = px.bar(chart_df, x=x_col, y=y_col, title=title)
    elif len(numeric_cols) >= 2:
        first, second = numeric_cols[:2]
        chart_df = df[[first, second]].dropna().copy().head(300)
        fig = px.scatter(chart_df, x=first, y=second, title=title)
    elif len(numeric_cols) == 1:
        fig = px.histogram(df, x=numeric_cols[0], title=title)
    else:
        return px.scatter(title=f"{title}: unsupported table structure")

    fig.update_layout(height=380)
    return fig
|
|
|
|
def build_overview_charts(df: pd.DataFrame):
    """Build the churn-by-segment and churn-by-age-band overview figures.

    Args:
        df: Dataset; the target, age and segment columns are detected with
            the ``get_*_col`` helpers.

    Returns:
        ``(seg_fig, age_fig)``. Each is an empty titled scatter when the
        columns it needs are not present.
    """
    target_col = get_target_col(df)
    age_col = get_age_col(df)
    segment_col = get_segment_col(df)

    # Placeholders returned when a chart cannot be built.
    seg_fig = px.scatter(title="Churn by Segment")
    age_fig = px.scatter(title="Churn by Age Band")

    if target_col and segment_col:
        seg_df = df.groupby(segment_col, as_index=False)[target_col].mean()
        seg_df[target_col] = (seg_df[target_col] * 100).round(2)
        seg_fig = px.bar(seg_df, x=segment_col, y=target_col, title=f"Churn by {segment_col} (%)")
        seg_fig.update_layout(height=380)

    if target_col and age_col:
        temp = df.copy()
        # Ages outside 18-120 fall in no band and are left out of the chart.
        temp["AgeBand"] = pd.cut(
            temp[age_col],
            bins=[18, 30, 40, 50, 60, 70, 120],
            include_lowest=True,
        )
        # AgeBand is categorical: state observed=False explicitly so empty
        # bands stay on the axis and the result does not depend on the
        # pandas version's default (which also emits a FutureWarning).
        age_df = (
            temp.groupby("AgeBand", observed=False)
            .agg(churn_rate=(target_col, "mean"))
            .reset_index()
        )
        age_df["AgeBand"] = age_df["AgeBand"].astype(str)
        age_df["churn_rate"] = (age_df["churn_rate"] * 100).round(2)
        age_fig = px.line(age_df, x="AgeBand", y="churn_rate", title="Churn by Age Band (%)", markers=True)
        age_fig.update_layout(height=380)

    return seg_fig, age_fig
|
|
|
|
def build_dashboard():
    """Assemble every value shown on the dashboard tab.

    Returns:
        A 9-tuple: summary markdown, segment figure, age figure, Python
        status markdown, Python plot, Python table, R status markdown,
        R plot, R table.
    """
    data = load_data()
    churn_col = get_target_col(data)
    bal_col = get_balance_col(data)

    bullets = [f"- Total Customers: **{len(data)}**"]
    if churn_col:
        churn = data[churn_col]
        bullets.append(f"- Churn Rate: **{round(churn.mean() * 100, 2)}%**")
        bullets.append(f"- Churned Customers: **{int(churn.sum())}**")
    if bal_col:
        bullets.append(f"- Average Balance: **{round(data[bal_col].mean(), 2)}**")

    kernel_names = sorted(available_kernels().keys())
    kernels = ", ".join(kernel_names) or "none"
    bullets.append(f"- Available Kernels: **{kernels}**")
    summary_md = "\n".join(["### Executive Summary", *bullets])

    seg_fig, age_fig = build_overview_charts(data)

    py_name, py_df = load_latest_table(PY_TAB_DIR)
    r_name, r_df = load_latest_table(R_TAB_DIR)

    py_status = f"### Python Analysis Output\nLatest table: **{py_name or 'none found'}**"
    r_status = f"### R Analysis Output\nLatest table: **{r_name or 'none found'}**"

    # Plots are built from the raw result, so a missing table yields the
    # "no data" figure rather than a chart of the placeholder frame below.
    py_plot = build_interactive_plot(py_df, "Python Analysis Chart")
    r_plot = build_interactive_plot(r_df, "R Analysis Chart")

    py_table = py_df
    if py_table is None:
        py_table = pd.DataFrame([{"info": "No Python table found in artifacts/py/tables"}])
    r_table = r_df
    if r_table is None:
        r_table = pd.DataFrame([{"info": "No R table found in artifacts/r/tables"}])

    return (
        summary_md,
        seg_fig,
        age_fig,
        py_status,
        py_plot,
        py_table,
        r_status,
        r_plot,
        r_table,
    )
|
|
|
|
def generate_ai_insight(question: str):
    """Ask the configured Hugging Face model to interpret the analysis output.

    The prompt contains the user's question, a small dataset summary and the
    first 8 rows of the latest Python and R tables. A chat completion is
    tried first; if it fails, plain text generation is used as a fallback.

    Args:
        question: Free-text question from the UI.

    Returns:
        The model's answer, or an explanatory message when the API key is
        missing, the question is blank, or both request styles fail.
    """
    if not HF_API_KEY:
        return "HF_API_KEY is not configured in Space Secrets."

    # Do not spend an API call on an empty question.
    question = (question or "").strip()
    if not question:
        return "Please enter a question."

    df = load_data()
    target_col = get_target_col(df)
    balance_col = get_balance_col(df)
    summary = {
        "rows": int(len(df)),
        "churn_rate": round(float(df[target_col].mean() * 100), 2) if target_col else None,
        "avg_balance": round(float(df[balance_col].mean()), 2) if balance_col else None,
        "target_column": target_col,
    }
    py_name, py_df = load_latest_table(PY_TAB_DIR)
    r_name, r_df = load_latest_table(R_TAB_DIR)

    # Only a short preview of each table is sent to keep the prompt small.
    if py_df is not None:
        py_preview = py_df.head(8).to_csv(index=False)
    else:
        py_preview = "No Python analysis table available."
    if r_df is not None:
        r_preview = r_df.head(8).to_csv(index=False)
    else:
        r_preview = "No R analysis table available."

    prompt = f"""
You are a bank churn strategy assistant.
Use the dataset summary and analysis outputs to answer in concise business language.

Question: {question}

Dataset summary:
{json.dumps(summary, ensure_ascii=False)}

Python latest table:
{py_name}
{py_preview}

R latest table:
{r_name}
{r_preview}

Return:
1. Key finding
2. Customer retention action
3. One risk or caveat
"""
    try:
        client = InferenceClient(api_key=HF_API_KEY)
    except Exception as e:
        return f"AI request failed: {e}"

    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": "You are a precise analytics assistant."},
                {"role": "user", "content": prompt},
            ],
            max_tokens=350,
        )
        return response.choices[0].message.content.strip()
    except Exception as chat_err:
        # Fall back to plain text generation when the chat request fails.
        try:
            return client.text_generation(prompt, model=MODEL_NAME, max_new_tokens=350)
        except Exception as e:
            # Report both failures; the first one is often the real cause.
            return f"AI request failed: {e} (chat completion error: {chat_err})"
|
|
|
|
def build_ui():
    """Build and return the Gradio Blocks application.

    The app has four tabs: Analysis Runner, Interactive Dashboard, Prediction
    and AI Insight. The contents of ``style.css`` next to this file, when it
    exists, are injected through an inline ``<style>`` element.
    """
    stylesheet = BASE_DIR / "style.css"
    css = stylesheet.read_text(encoding="utf-8") if stylesheet.exists() else ""

    with gr.Blocks(title="Bank Churn Intelligence Hub") as app:
        gr.HTML(f"<style>{css}</style>")

        gr.Markdown(
            "# 🏦 Bank Churn Intelligence Hub\n"
            "*Run Python and R analyses, refresh the dashboard, and ask AI for retention ideas.*",
            elem_id="escp_title",
        )

        with gr.Tab("Analysis Runner"):
            gr.Markdown("### Execute the analysis workflow")
            with gr.Row():
                py_button = gr.Button("Run Python", variant="secondary")
                r_button = gr.Button("Run R", variant="secondary")
                all_button = gr.Button("Run All", variant="primary")
            log_box = gr.Textbox(label="Execution Log", lines=18, max_lines=28, interactive=False)
            # All three buttons write their status into the same log box.
            py_button.click(run_python, outputs=[log_box])
            r_button.click(run_r, outputs=[log_box])
            all_button.click(run_all, outputs=[log_box])

        with gr.Tab("Interactive Dashboard"):
            refresh_button = gr.Button("Refresh Dashboard", variant="primary")
            summary_box = gr.Markdown()

            with gr.Row():
                segment_chart = gr.Plot(label="Churn by Segment")
                age_chart = gr.Plot(label="Churn by Age Band")

            with gr.Row():
                py_status_box = gr.Markdown()
                r_status_box = gr.Markdown()

            with gr.Row():
                py_chart = gr.Plot(label="Python Analysis Plot")
                r_chart = gr.Plot(label="R Analysis Plot")

            with gr.Row():
                py_grid = gr.Dataframe(label="Python Analysis Table", interactive=True)
                r_grid = gr.Dataframe(label="R Analysis Table", interactive=True)

            # Listed in the order build_dashboard returns its values.
            dashboard_outputs = [
                summary_box,
                segment_chart,
                age_chart,
                py_status_box,
                py_chart,
                py_grid,
                r_status_box,
                r_chart,
                r_grid,
            ]
            # Populate on demand and once when the page loads.
            refresh_button.click(build_dashboard, outputs=dashboard_outputs)
            app.load(build_dashboard, outputs=dashboard_outputs)

        with gr.Tab("Prediction"):
            with gr.Row():
                age_input = gr.Number(label="Age", value=35)
                gender_input = gr.Dropdown(
                    choices=["M", "F"],
                    value="M",
                    label="Gender",
                )
                balance_input = gr.Number(label="Balance", value=5000)
            predict_button = gr.Button("Predict", variant="primary")
            prediction_box = gr.Textbox(label="Prediction Result")
            predict_button.click(
                predict,
                inputs=[age_input, gender_input, balance_input],
                outputs=[prediction_box],
            )

        with gr.Tab("AI Insight"):
            gr.Markdown("### Ask AI to interpret the Python and R analysis outputs")
            question_box = gr.Textbox(
                label="Question",
                placeholder="What does the latest Python and R analysis suggest about churn risk?",
            )
            insight_button = gr.Button("Generate AI Insight", variant="primary")
            answer_box = gr.Textbox(label="AI Response", lines=12)
            insight_button.click(generate_ai_insight, inputs=[question_box], outputs=[answer_box])

    return app
|
|
|
|
# Script entry point: prepare output folders, build the UI and serve it.
if __name__ == "__main__":
    ensure_dirs()
    demo = build_ui()
    demo.queue()
    # Listen on all interfaces; the port comes from PORT, defaulting to 7860.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": int(os.environ.get("PORT", 7860)),
    }
    demo.launch(**launch_options)
|
|