Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import time | |
| import traceback | |
| import subprocess | |
| from pathlib import Path | |
| import gradio as gr | |
| import pandas as pd | |
| import joblib | |
| import papermill as pm | |
| import plotly.express as px | |
| from huggingface_hub import InferenceClient | |
| from jupyter_client.kernelspec import KernelSpecManager | |
def _env_int(name: str, default: int) -> int:
    """Read an integer environment variable, falling back to *default*.

    Unset, blank, or non-integer values yield *default* instead of raising
    ``ValueError`` at import time (which would stop the whole app from starting).
    """
    raw = os.environ.get(name, "").strip()
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError:
        return default


# Project layout: every path is anchored to the directory containing this file.
BASE_DIR = Path(__file__).resolve().parent
DATA_PATH = BASE_DIR / "bankChurn.csv"
MODEL_PATH = BASE_DIR / "models" / "pipeline.joblib"
PY_NOTEBOOK = BASE_DIR / "BankChurn_Version1.ipynb"
R_NOTEBOOK = BASE_DIR / "BankChurn_Version1_R.ipynb"
# Checked in order; the first existing script is the one run_pipeline() executes.
PIPELINE_CANDIDATES = [
    BASE_DIR / "scripts" / "pipeline.py",
    BASE_DIR / "pipeline.py",
]
RUNS_DIR = BASE_DIR / "runs"
ART_DIR = BASE_DIR / "artifacts"
PY_TAB_DIR = ART_DIR / "py" / "tables"
R_TAB_DIR = ART_DIR / "r" / "tables"

# Runtime configuration taken from environment variables.
# Passed to papermill as `execution_timeout` by run_notebook().
PAPERMILL_TIMEOUT = _env_int("PAPERMILL_TIMEOUT", 1800)
HF_API_KEY = os.environ.get("HF_API_KEY", "").strip()
# An unset *or blank* MODEL_NAME falls back to the default model id.
MODEL_NAME = os.environ.get("MODEL_NAME", "").strip() or "Qwen/Qwen2.5-7B-Instruct"
def ensure_dirs():
    """Create the run-output and table-artifact directories if they are missing."""
    targets = (RUNS_DIR, PY_TAB_DIR, R_TAB_DIR)
    for directory in targets:
        # parents=True also creates the shared ART_DIR parent; exist_ok makes repeat calls harmless.
        directory.mkdir(parents=True, exist_ok=True)
def stamp():
    """Return the current local time as ``YYYYmmdd-HHMMSS`` (used in run file names)."""
    pattern = "%Y%m%d-%H%M%S"
    return time.strftime(pattern, time.localtime())
def available_kernels():
    """Return the mapping from ``KernelSpecManager().find_kernel_specs()``.

    Only the keys (kernel names) are used by this module. Any failure during
    kernel discovery is swallowed and reported as an empty dict, so callers
    treat "no kernels installed" and "lookup failed" the same way.
    """
    try:
        specs = KernelSpecManager().find_kernel_specs()
    except Exception:
        specs = {}
    return specs
def python_kernel_name():
    """Return the first installed kernel among the common Python kernel names, or None."""
    installed = available_kernels()
    preferred = ("python3", "python", "python-3")
    return next((name for name in preferred if name in installed), None)
def r_kernel_name():
    """Return the first installed kernel among the common R kernel names, or None."""
    installed = available_kernels()
    preferred = ("ir", "irkernel", "r")
    return next((name for name in preferred if name in installed), None)
def run_notebook(nb_path: Path, label: str, kernel_name: str | None) -> str:
    """Execute a notebook with papermill and return a status message for the UI.

    Args:
        nb_path: Notebook to execute; the executed copy is written to RUNS_DIR.
        label: Human-readable name used in the status message.
        kernel_name: Jupyter kernel to run it with; a falsy value means no
            suitable kernel was detected.

    Returns:
        A "✅ ..." message naming the saved output notebook on success, or a
        "❌ ..." message for a missing notebook, a missing kernel (listing the
        kernels that are installed), or an execution error (with the tail of
        the traceback). Execution errors are caught, not raised.
    """
    ensure_dirs()
    if not nb_path.exists():
        return f"❌ {label} not found: {nb_path.name}"
    if not kernel_name:
        installed = sorted(available_kernels().keys())
        listing = ", ".join(installed) or "none"
        return f"❌ {label} kernel is not available.\nAvailable kernels: {listing}"

    try:
        out_path = RUNS_DIR / f"run_{stamp()}_{nb_path.name}"
        pm.execute_notebook(
            input_path=str(nb_path),
            output_path=str(out_path),
            cwd=str(BASE_DIR),
            kernel_name=kernel_name,
            log_output=True,
            progress_bar=False,
            request_save_on_cell_execute=True,
            execution_timeout=PAPERMILL_TIMEOUT,
        )
    except Exception as exc:
        # Keep only the last 3000 characters so the log box stays readable.
        trace_tail = traceback.format_exc()[-3000:]
        return f"❌ {label} failed.\n{str(exc)}\n\n{trace_tail}"
    else:
        return f"✅ {label} finished successfully.\nSaved run: {out_path.name}"
def run_python():
    """Execute the Python analysis notebook on the detected Python kernel."""
    kernel = python_kernel_name()
    return run_notebook(PY_NOTEBOOK, "Python notebook", kernel)
def run_r():
    """Execute the R analysis notebook on the detected R kernel."""
    kernel = r_kernel_name()
    return run_notebook(R_NOTEBOOK, "R notebook", kernel)
def run_pipeline():
    """Run the first existing script in PIPELINE_CANDIDATES and report its output.

    The script runs in a child process, with BASE_DIR as the working directory,
    using the same interpreter as this app. The combined stdout/stderr is
    captured, and its last 5000 characters are included in the returned message.

    Returns:
        A "✅ ..." message when the exit code is 0. Otherwise a "❌ ..." message
        for a missing script, a non-zero exit code, or a failure to launch the
        process. Nothing is raised.
    """
    import sys  # sys is not among this module's top-level imports

    target = next((c for c in PIPELINE_CANDIDATES if c.exists()), None)
    if target is None:
        return "❌ pipeline.py not found."

    # A bare "python" is resolved through PATH and may be a different interpreter
    # (lacking this app's packages) than the one running the app.
    # sys.executable can be empty or None if Python cannot determine it.
    interpreter = sys.executable or "python"
    try:
        proc = subprocess.run(
            [interpreter, str(target)],
            cwd=str(BASE_DIR),
            capture_output=True,
            text=True,
            check=False,
        )
    except Exception as e:
        return f"❌ Pipeline failed.\n{str(e)}"

    log = (proc.stdout or "") + ("\n" + proc.stderr if proc.stderr else "")
    tail = log[-5000:]
    if proc.returncode == 0:
        return f"✅ Pipeline finished successfully.\n\n{tail}"
    return f"❌ Pipeline failed with exit code {proc.returncode}.\n\n{tail}"
def run_all():
    """Run the Python notebook, the R notebook, then the pipeline; return one combined log.

    The steps run sequentially, and each runs regardless of whether the
    previous one reported a failure.
    """
    steps = (
        ("=== Run Python ===", run_python),
        ("=== Run R ===", run_r),
        ("=== Run Pipeline ===", run_pipeline),
    )
    sections = []
    for header, step in steps:
        sections.append(f"{header}\n{step()}")
    # A blank line separates consecutive sections.
    return "\n\n".join(sections)
def load_model():
    """Return the fitted model stored at MODEL_PATH, or None when no model file exists yet.

    NOTE: joblib.load unpickles the file, so only point MODEL_PATH at
    artifacts produced by this app's own pipeline.
    """
    if not MODEL_PATH.exists():
        return None
    return joblib.load(MODEL_PATH)
def encode_gender(gender):
    """Map a gender value to the model's numeric code: 1 for male, 0 otherwise.

    Matching ignores case and surrounding whitespace. "M", "MALE" and "1" map
    to 1. Female values, ``None`` and anything unrecognised all map to 0.
    """
    if gender is None:
        return 0
    token = str(gender).strip().upper()
    return 1 if token in {"M", "MALE", "1"} else 0
def predict(age, gender, balance):
    """Score one customer with the saved model and return a readable verdict.

    The input row is aligned to the column names the model exposes in
    ``feature_names_in_``, falling back to ``["Age", "Balance"]``. Columns whose
    lower-cased name is recognised receive the user's age, balance or encoded
    gender. Every other expected column is filled with 0.

    Returns:
        "Churn Risk: Yes" or "Churn Risk: No", with " | Probability: xx.xx%"
        appended when ``predict_proba`` succeeds. Also returns "Please run the
        pipeline first." when no model file exists, or "Prediction failed: ..."
        when the model rejects the constructed row.
    """
    model = load_model()
    if model is None:
        return "Please run the pipeline first."

    # Try to align with whatever features the saved model expects.
    if hasattr(model, "feature_names_in_"):
        feature_names = list(model.feature_names_in_)
    else:
        feature_names = ["Age", "Balance"]

    values = {}
    for col in feature_names:
        key = str(col).lower()
        if key == "age":
            values[col] = age
        elif key in {"balance", "local_cur_mon_avg_bal"}:
            values[col] = balance
        elif key in {"gender", "gender_cd", "sex"}:
            values[col] = encode_gender(gender)
        else:
            # NOTE(review): a neutral 0 may not suit categorical features the model expects.
            values[col] = 0
    X = pd.DataFrame([values], columns=feature_names)

    try:
        pred = model.predict(X)[0]
    except Exception as e:
        # Report the failure in the result box, as the other UI callbacks do,
        # instead of letting the exception escape the event handler.
        return f"Prediction failed: {e}"
    label = "Yes" if pred == 1 else "No"

    if hasattr(model, "predict_proba"):
        try:
            # Column 1 is taken as the churn class; assumes classes_ == [0, 1] — TODO confirm.
            prob = float(model.predict_proba(X)[0][1])
        except Exception:
            pass  # best effort: fall back to the label-only result below
        else:
            return f"Churn Risk: {label} | Probability: {prob:.2%}"
    return f"Churn Risk: {label}"
def load_data():
    """Return the churn dataset from DATA_PATH, or a 3-row sample when the CSV is absent.

    The sample keeps the dashboard and AI summary usable before real data is uploaded.
    """
    if not DATA_PATH.exists():
        sample = {
            "AGE": [25, 45, 33],
            "LOCAL_CUR_MON_AVG_BAL": [1000, 5000, 2300],
            "GENDER_CD": ["M", "F", "M"],
            "CHURN_CUST_IND": [0, 1, 0],
        }
        return pd.DataFrame(sample)
    return pd.read_csv(DATA_PATH)
def _find_column(df: pd.DataFrame, candidates):
    """Return the first of *candidates* found in ``df.columns``, or None.

    Exact (case-sensitive) matches are tried first, in candidate order. Only if
    none match is a case-insensitive comparison attempted (e.g. a column named
    ``"EXITED"`` satisfies the candidate ``"Exited"``). The label is returned
    exactly as it appears in *df*.
    """
    for name in candidates:
        if name in df.columns:
            return name
    # Fold every label once; on collisions the left-most column wins.
    folded = {}
    for col in df.columns:
        folded.setdefault(str(col).casefold(), col)
    for name in candidates:
        key = name.casefold()
        if key in folded:
            return folded[key]
    return None


def get_target_col(df: pd.DataFrame):
    """Return the churn label column name, or None if none is recognised."""
    return _find_column(df, ("CHURN_CUST_IND", "Exited", "churn", "target"))


def get_age_col(df: pd.DataFrame):
    """Return the customer age column name, or None if none is recognised."""
    return _find_column(df, ("AGE", "Age", "age"))


def get_balance_col(df: pd.DataFrame):
    """Return the account balance column name, or None if none is recognised."""
    return _find_column(df, ("LOCAL_CUR_MON_AVG_BAL", "Balance", "balance"))


def get_segment_col(df: pd.DataFrame):
    """Return the column used to segment churn rates, or None if none is recognised."""
    return _find_column(df, ("Geography", "GENDER_CD", "gender", "SEGMENT"))
def _read_json(path: Path):
    """Load a JSON analysis table into a DataFrame.

    Supported layouts:
      * a list of records (``[{"col": v}, ...]``): one row per record;
      * a column-oriented object whose values are all lists or all objects,
        e.g. ``{"col": [v1, v2]}`` or pandas' default ``DataFrame.to_json()``
        output ``{"col": {"0": v1, "1": v2}}``: one row per entry;
      * any other object, e.g. a flat ``{"metric": value}`` summary: a single row.
    """
    with open(path, "r", encoding="utf-8") as f:
        obj = json.load(f)
    if not isinstance(obj, dict):
        return pd.DataFrame(obj)

    values = list(obj.values())
    is_columnar = bool(values) and (
        all(isinstance(v, list) for v in values)
        or all(isinstance(v, dict) for v in values)
    )
    if is_columnar:
        try:
            return pd.DataFrame(obj)
        except ValueError:
            # e.g. lists of unequal length: fall back to a single-row table.
            pass
    return pd.DataFrame([obj])
def load_latest_table(table_dir: Path):
    """Load the most recently modified ``.csv`` / ``.json`` file in *table_dir*.

    Returns:
        ``(file_name, DataFrame)`` for the newest matching file;
        ``(file_name, DataFrame([{"error": ...}]))`` if that file cannot be parsed;
        ``(None, None)`` if *table_dir* is not a directory or holds no matching files.
    """
    if not table_dir.is_dir():
        return None, None
    # is_file() skips sub-directories whose names merely end in .csv/.json.
    tables = [
        p for p in table_dir.iterdir()
        if p.is_file() and p.suffix.lower() in {".csv", ".json"}
    ]
    if not tables:
        return None, None
    path = max(tables, key=lambda p: p.stat().st_mtime)
    try:
        if path.suffix.lower() == ".csv":
            df = pd.read_csv(path)
        else:
            df = _read_json(path)
    except Exception as e:
        return path.name, pd.DataFrame([{"error": str(e)}])
    return path.name, df
def build_interactive_plot(df: pd.DataFrame, title: str):
    """Pick a simple Plotly chart that fits the shape of an analysis table.

    * first non-numeric + first numeric column: bar chart (at most 20 distinct
      categories) or line chart, over at most 100 complete rows;
    * otherwise two or more numeric columns: scatter of the first two, over at
      most 300 complete rows;
    * otherwise exactly one numeric column: histogram.
    Titled placeholder figures are returned for a missing or empty table, a
    table with no complete rows, and a table with no numeric column.
    """
    if df is None or df.empty:
        return px.scatter(title=f"{title}: no data")
    numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
    cat_cols = [c for c in df.columns if c not in numeric_cols]

    if cat_cols and numeric_cols:
        x, y = cat_cols[0], numeric_cols[0]
        chart_df = df[[x, y]].dropna().head(100)
        if chart_df.empty:
            return px.scatter(title=f"{title}: no usable rows")
        if chart_df[x].nunique() <= 20:
            fig = px.bar(chart_df, x=x, y=y, title=title)
        else:
            fig = px.line(chart_df, x=x, y=y, title=title, markers=True)
    elif len(numeric_cols) >= 2:
        x, y = numeric_cols[0], numeric_cols[1]
        chart_df = df[[x, y]].dropna().head(300)
        # Same guard as the categorical branch: don't render an empty scatter.
        if chart_df.empty:
            return px.scatter(title=f"{title}: no usable rows")
        fig = px.scatter(chart_df, x=x, y=y, title=title)
    elif len(numeric_cols) == 1:
        fig = px.histogram(df, x=numeric_cols[0], title=title)
    else:
        return px.scatter(title=f"{title}: unsupported table structure")

    # Only real charts get the fixed height; placeholders keep Plotly's default.
    fig.update_layout(height=380)
    return fig
def build_overview_charts(df: pd.DataFrame):
    """Build the two headline churn charts for the dashboard.

    Returns:
        ``(seg_fig, age_fig)``. ``seg_fig`` is a bar chart of the mean target
        (as %) per value of the detected segment column. ``age_fig`` is a line
        chart of the mean target (as %) per age band. Either figure stays an
        empty, titled placeholder when the columns it needs are not found.
    """
    target_col = get_target_col(df)
    age_col = get_age_col(df)
    segment_col = get_segment_col(df)
    seg_fig = px.scatter(title="Churn by Segment")
    age_fig = px.scatter(title="Churn by Age Band")

    if target_col and segment_col:
        # Mean target * 100 is read as a churn percentage;
        # assumes a numeric 0/1 churn flag — TODO confirm for new datasets.
        seg_df = df.groupby(segment_col, as_index=False)[target_col].mean()
        seg_df[target_col] = (seg_df[target_col] * 100).round(2)
        seg_fig = px.bar(seg_df, x=segment_col, y=target_col, title=f"Churn by {segment_col} (%)")
        seg_fig.update_layout(height=380)

    if target_col and age_col:
        temp = df.copy()
        # Bands 18-30, 30-40, ..., 70-120 (18 included); other ages become NaN
        # and are left out of the grouping.
        temp["AgeBand"] = pd.cut(temp[age_col], bins=[18, 30, 40, 50, 60, 70, 120], include_lowest=True)
        # observed=False is passed explicitly. It keeps every band (empty ones
        # get a NaN rate), which is the behaviour pandas has deprecated as the
        # implicit default for categorical groupers.
        age_df = (
            temp.groupby("AgeBand", observed=False)
            .agg(churn_rate=(target_col, "mean"))
            .reset_index()
        )
        age_df["AgeBand"] = age_df["AgeBand"].astype(str)
        age_df["churn_rate"] = (age_df["churn_rate"] * 100).round(2)
        age_fig = px.line(age_df, x="AgeBand", y="churn_rate", title="Churn by Age Band (%)", markers=True)
        age_fig.update_layout(height=380)

    return seg_fig, age_fig
def build_dashboard():
    """Compute every output of the "Interactive Dashboard" tab.

    Returns a 9-tuple in the order the UI binds it: summary markdown, segment
    chart, age-band chart, then status markdown / chart / table for the latest
    Python artifact, then the same three for the latest R artifact.
    """
    df = load_data()
    target_col = get_target_col(df)
    balance_col = get_balance_col(df)

    lines = ["### Executive Summary", f"- Total Customers: **{len(df)}**"]
    if target_col:
        churn_pct = round(df[target_col].mean() * 100, 2)
        lines += [
            f"- Churn Rate: **{churn_pct}%**",
            f"- Churned Customers: **{int(df[target_col].sum())}**",
        ]
    if balance_col:
        lines.append(f"- Average Balance: **{round(df[balance_col].mean(), 2)}**")
    kernel_list = ", ".join(sorted(available_kernels().keys())) or "none"
    lines.append(f"- Available Kernels: **{kernel_list}**")
    summary_md = "\n".join(lines)

    seg_fig, age_fig = build_overview_charts(df)

    py_name, py_df = load_latest_table(PY_TAB_DIR)
    r_name, r_df = load_latest_table(R_TAB_DIR)
    py_status = f"### Python Analysis Output\nLatest table: **{py_name or 'none found'}**"
    r_status = f"### R Analysis Output\nLatest table: **{r_name or 'none found'}**"
    # Charts are built from the raw results so a missing table renders as "no data".
    py_plot = build_interactive_plot(py_df, "Python Analysis Chart")
    r_plot = build_interactive_plot(r_df, "R Analysis Chart")
    # Only the tables get an explanatory placeholder row.
    py_table = py_df if py_df is not None else pd.DataFrame([{"info": "No Python table found in artifacts/py/tables"}])
    r_table = r_df if r_df is not None else pd.DataFrame([{"info": "No R table found in artifacts/r/tables"}])

    return summary_md, seg_fig, age_fig, py_status, py_plot, py_table, r_status, r_plot, r_table
def _build_insight_prompt(question: str) -> str:
    """Assemble the LLM prompt from the question, a dataset summary and the latest analysis tables.

    At most the first 8 rows of each of the latest Python and R tables are included as CSV.
    """
    df = load_data()
    target_col = get_target_col(df)
    balance_col = get_balance_col(df)
    summary = {
        "rows": int(len(df)),
        "churn_rate": round(float(df[target_col].mean() * 100), 2) if target_col else None,
        "avg_balance": round(float(df[balance_col].mean()), 2) if balance_col else None,
        "target_column": target_col,
    }
    py_name, py_df = load_latest_table(PY_TAB_DIR)
    r_name, r_df = load_latest_table(R_TAB_DIR)
    py_block = py_df.head(8).to_csv(index=False) if py_df is not None else 'No Python analysis table available.'
    r_block = r_df.head(8).to_csv(index=False) if r_df is not None else 'No R analysis table available.'
    return f"""
You are a bank churn strategy assistant.
Use the dataset summary and analysis outputs to answer in concise business language.
Question: {question}
Dataset summary:
{json.dumps(summary, ensure_ascii=False)}
Python latest table:
{py_name}
{py_block}
R latest table:
{r_name}
{r_block}
Return:
1. Key finding
2. Customer retention action
3. One risk or caveat
"""


def generate_ai_insight(question: str):
    """Ask the configured Hugging Face model to interpret the churn data and analysis outputs.

    The chat-completion API is tried first. If it raises, plain text
    generation is tried with the same prompt.

    Returns:
        The model's answer. If HF_API_KEY is unset, a configuration message
        is returned instead. If the API calls fail, "AI request failed: ..."
        is returned; when both calls fail it includes both errors.
    """
    if not HF_API_KEY:
        return "HF_API_KEY is not configured in Space Secrets."

    prompt = _build_insight_prompt(question)
    messages = [
        {"role": "system", "content": "You are a precise analytics assistant."},
        {"role": "user", "content": prompt},
    ]

    try:
        client = InferenceClient(api_key=HF_API_KEY)
    except Exception as e:
        return f"AI request failed: {str(e)}"

    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            max_tokens=350,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        chat_error = e  # `e` is unbound after the except block; keep it for reporting

    try:
        return client.text_generation(prompt, model=MODEL_NAME, max_new_tokens=350)
    except Exception as e:
        # Surface both failures so the chat-completion error isn't hidden behind the fallback's.
        return f"AI request failed: {str(e)}\nChat completion error: {chat_error}"
def build_ui():
    """Assemble the Gradio app and return the (not yet launched) Blocks object.

    Tabs: Analysis Runner (notebook/pipeline execution), Interactive Dashboard
    (refreshed on page load and on demand), Prediction, and AI Insight.
    Optional CSS is read from ``style.css`` next to this file and injected
    through an HTML ``<style>`` tag.
    """
    css_path = BASE_DIR / "style.css"
    css = css_path.read_text(encoding="utf-8") if css_path.exists() else ""
    with gr.Blocks(title="Bank Churn Intelligence Hub") as demo:
        gr.HTML(f"<style>{css}</style>")
        gr.Markdown(
            "# 🏦 Bank Churn Intelligence Hub\n"
            "*Run Python and R analyses, refresh the dashboard, and ask AI for retention ideas.*",
            elem_id="escp_title",
        )

        with gr.Tab("Analysis Runner"):
            gr.Markdown("### Execute the analysis workflow")
            with gr.Row():
                btn_py = gr.Button("Run Python", variant="secondary")
                btn_r = gr.Button("Run R", variant="secondary")
                # Lets the pipeline run on its own; otherwise it is only
                # reachable through "Run All", after both notebooks.
                btn_pipe = gr.Button("Run Pipeline", variant="secondary")
                btn_all = gr.Button("Run All", variant="primary")
            exec_log = gr.Textbox(label="Execution Log", lines=18, max_lines=28, interactive=False)
            btn_py.click(run_python, outputs=[exec_log])
            btn_r.click(run_r, outputs=[exec_log])
            btn_pipe.click(run_pipeline, outputs=[exec_log])
            btn_all.click(run_all, outputs=[exec_log])

        with gr.Tab("Interactive Dashboard"):
            refresh_btn = gr.Button("Refresh Dashboard", variant="primary")
            summary_md = gr.Markdown()
            with gr.Row():
                seg_plot = gr.Plot(label="Churn by Segment")
                age_plot = gr.Plot(label="Churn by Age Band")
            with gr.Row():
                py_status = gr.Markdown()
                r_status = gr.Markdown()
            with gr.Row():
                py_plot = gr.Plot(label="Python Analysis Plot")
                r_plot = gr.Plot(label="R Analysis Plot")
            with gr.Row():
                py_table = gr.Dataframe(label="Python Analysis Table", interactive=True)
                r_table = gr.Dataframe(label="R Analysis Table", interactive=True)
            # Must match the order of build_dashboard()'s return tuple.
            dashboard_outputs = [
                summary_md, seg_plot, age_plot,
                py_status, py_plot, py_table,
                r_status, r_plot, r_table,
            ]
            refresh_btn.click(build_dashboard, outputs=dashboard_outputs)
            demo.load(build_dashboard, outputs=dashboard_outputs)

        with gr.Tab("Prediction"):
            with gr.Row():
                age = gr.Number(label="Age", value=35)
                gender = gr.Dropdown(
                    choices=["M", "F"],
                    value="M",
                    label="Gender",
                )
                balance = gr.Number(label="Balance", value=5000)
            pred_btn = gr.Button("Predict", variant="primary")
            pred_out = gr.Textbox(label="Prediction Result")
            pred_btn.click(predict, inputs=[age, gender, balance], outputs=[pred_out])

        with gr.Tab("AI Insight"):
            gr.Markdown("### Ask AI to interpret the Python and R analysis outputs")
            ai_q = gr.Textbox(
                label="Question",
                placeholder="What does the latest Python and R analysis suggest about churn risk?",
            )
            ai_btn = gr.Button("Generate AI Insight", variant="primary")
            ai_out = gr.Textbox(label="AI Response", lines=12)
            ai_btn.click(generate_ai_insight, inputs=[ai_q], outputs=[ai_out])

    return demo
if __name__ == "__main__":
    # Create the output folders before any callback writes into them.
    ensure_dirs()
    demo = build_ui()
    demo.queue()  # route events through Gradio's request queue
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces
        server_port=int(os.environ.get("PORT", 7860)),
    )