import math
import os
import tempfile
from pathlib import Path

import gradio as gr
import pandas as pd


# Feature columns a batch-upload CSV must provide. Their order matches the
# positional parameters of churn_probability().
REQUIRED_COLUMNS = (
    "CreditScore Age Balance NumOfProducts IsActiveMember EstimatedSalary"
).split()


def clamp(x, low=0.0, high=1.0):
    """Limit *x* to the closed interval [low, high].

    Equivalent to ``max(low, min(high, x))``: a value at or above *high*
    yields *high* itself, a value at or below *low* yields *low* itself.
    Because every NaN comparison is False, a NaN input falls through to
    *high* (given low < high).
    """
    upper_capped = x if x < high else high
    return upper_capped if upper_capped > low else low


def churn_probability(
    credit_score,
    age,
    balance,
    num_products,
    is_active_member,
    estimated_salary,
):
    """Estimate the probability that a bank customer churns.

    Each feature adds a fixed, hard-coded weight to a linear score, which a
    logistic sigmoid maps to a probability.

    Args:
        credit_score: Credit score; only the shortfall below 650 adds risk.
        age: Age in years; only the years above 40 add risk.
        balance: Account balance; counts linearly up to 200,000 (the term
            is capped there).
        num_products: Number of products held (truncated with ``int``);
            1 or fewer raises risk, 2 or more lowers it.
        is_active_member: ``1`` lowers risk, any other value raises it
            (truncated with ``int``).
        estimated_salary: Estimated salary; counts linearly up to 200,000.

    Returns:
        A float in [0.0, 1.0].

    Raises:
        ValueError: If any feature is NaN (for example an empty CSV cell),
            rather than returning a meaningless probability.
    """
    score = -1.2
    score += max(age - 40, 0) * 0.04
    score += max(650 - credit_score, 0) * 0.003
    score += min(balance / 100000.0, 2.0) * 0.45
    score += 0.35 if int(num_products) <= 1 else -0.15
    score += -0.55 if int(is_active_member) == 1 else 0.35
    score += min(estimated_salary / 200000.0, 1.0) * 0.08
    if math.isnan(score):
        raise ValueError("churn_probability() received a NaN feature value")
    # Numerically stable sigmoid: exp() only ever sees a non-positive
    # argument, so a very negative score (e.g. a huge negative balance)
    # underflows to 0.0 instead of raising OverflowError. Both branches
    # already lie in [0, 1], so no clamping is needed.
    if score >= 0:
        return 1.0 / (1.0 + math.exp(-score))
    z = math.exp(score)
    return z / (1.0 + z)


def risk_label(prob):
    """Bucket a churn probability into a coarse risk level.

    Returns "Low" below 0.30, "Medium" below 0.60, and "High" otherwise
    (including NaN, which fails every ``<`` comparison).
    """
    for upper_bound, level in ((0.30, "Low"), (0.60, "Medium")):
        if prob < upper_bound:
            return level
    return "High"


def predict_single(
    credit_score,
    age,
    balance,
    num_products,
    is_active_member,
    estimated_salary,
):
    """Score one customer from the "Single Prediction" form.

    Args:
        credit_score, age, balance, num_products, is_active_member,
        estimated_salary: Raw widget values. The numeric ones are converted
            with ``float`` and ``is_active_member`` with ``int`` before
            being passed to churn_probability().

    Returns:
        A ``(summary_markdown, metrics_table)`` tuple. The table has
        ``metric``/``value`` columns holding the rounded probability, the
        0/1 class (1 when probability >= 0.5) and the risk level. If any
        field is empty (``None``), the summary asks the user to fill in
        every field and the table is empty, instead of raising
        ``TypeError`` from ``float(None)``.
    """
    values = (
        credit_score,
        age,
        balance,
        num_products,
        is_active_member,
        estimated_salary,
    )
    if any(v is None for v in values):
        empty = pd.DataFrame({"metric": [], "value": []})
        return "Please fill in every field before predicting.", empty

    p = churn_probability(
        float(credit_score),
        float(age),
        float(balance),
        float(num_products),
        int(is_active_member),
        float(estimated_salary),
    )
    label = 1 if p >= 0.5 else 0
    risk = risk_label(p)
    summary = (
        "### Prediction Result\n\n"
        f"- Churn probability: **{p:.1%}**\n"
        f"- Predicted class: **{label}**\n"
        f"- Risk level: **{risk}**"
    )
    table = pd.DataFrame(
        {
            "metric": ["churn_probability", "predicted_class", "risk_level"],
            "value": [round(p, 4), label, risk],
        }
    )
    return summary, table


def predict_batch(file):
    """Score every row of an uploaded CSV and write the results to a new CSV.

    Args:
        file: The upload widget's value: ``None`` when nothing was uploaded,
            otherwise either a filepath string or an object exposing the
            path as ``.name``. Both shapes are accepted because they differ
            across Gradio versions and ``gr.File`` ``type`` settings.

    Returns:
        ``(output_csv_path_or_None, status_message)``. The output CSV is the
        input with ``churn_probability`` (rounded to 4 places) and
        ``churn_prediction`` (1 if probability >= 0.5, else 0) appended.
    """
    if file is None:
        return None, "Please upload a CSV file."
    path = getattr(file, "name", file)
    try:
        df = pd.read_csv(path)
    except Exception as e:  # report any read/parse failure to the user
        return None, f"Could not read CSV: {e}"

    missing = [c for c in REQUIRED_COLUMNS if c not in df.columns]
    if missing:
        return None, f"Missing required columns: {missing}"

    # Coerce up front so blank or non-numeric cells are reported clearly
    # instead of crashing (or being mis-scored) partway through the rows.
    features = df[REQUIRED_COLUMNS].apply(pd.to_numeric, errors="coerce")
    bad_rows = [
        i + 1 for i, bad in enumerate(features.isna().any(axis=1)) if bad
    ]
    if bad_rows:
        shown = ", ".join(str(r) for r in bad_rows[:10])
        extra = "" if len(bad_rows) <= 10 else f" (and {len(bad_rows) - 10} more)"
        return None, (
            "Missing or non-numeric values in required columns at data "
            f"row(s) {shown}{extra}."
        )

    # REQUIRED_COLUMNS order matches churn_probability's parameter order.
    probs = [
        churn_probability(*row)
        for row in features.itertuples(index=False, name=None)
    ]

    out = df.copy()
    out["churn_probability"] = [round(p, 4) for p in probs]
    out["churn_prediction"] = [1 if p >= 0.5 else 0 for p in probs]

    # A fresh directory per request stops concurrent users from overwriting
    # (or downloading) each other's results, while the downloaded file keeps
    # a readable name. NOTE(review): these directories are not cleaned up.
    output_dir = Path(tempfile.mkdtemp(prefix="bank_churn_"))
    output_path = output_dir / "bank_churn_predictions.csv"
    out.to_csv(output_path, index=False)
    return str(output_path), f"Done. Processed {len(out)} rows."


def sample_csv():
    """Write a two-row example input CSV and return its path as a string.

    The file is written to ``tempfile.gettempdir()`` instead of a hard-coded
    ``/tmp``, which usually does not exist on Windows and ignores
    ``TMPDIR``. Its columns are REQUIRED_COLUMNS, so it is a valid upload
    for the batch tab.
    """
    rows = [
        [600, 45, 50000, 1, 0, 70000],
        [720, 31, 12000, 2, 1, 85000],
    ]
    df = pd.DataFrame(rows, columns=REQUIRED_COLUMNS)
    path = Path(tempfile.gettempdir()) / "sample_bank_churn_input.csv"
    df.to_csv(path, index=False)
    return str(path)


def build_ui():
    """Assemble the Gradio Blocks app with single-row and CSV batch tabs.

    Returns:
        The ``gr.Blocks`` instance; launching it is left to the caller.
    """
    with gr.Blocks() as app:
        gr.Markdown("# 🏦 Bank Churn Simple App")
        gr.Markdown(
            "This is a lightweight version built to reduce Hugging Face startup issues."
        )

        with gr.Tab("Single Prediction"):
            # Listed in predict_single's positional parameter order; list
            # elements are created left to right, preserving widget layout.
            feature_inputs = [
                gr.Slider(300, 900, value=650, step=1, label="CreditScore"),
                gr.Slider(18, 100, value=40, step=1, label="Age"),
                gr.Number(value=50000, label="Balance"),
                gr.Slider(1, 4, value=2, step=1, label="NumOfProducts"),
                gr.Dropdown(choices=[0, 1], value=1, label="IsActiveMember"),
                gr.Number(value=80000, label="EstimatedSalary"),
            ]
            run_single = gr.Button("Predict")
            single_outputs = [gr.Markdown(), gr.Dataframe()]
            run_single.click(
                fn=predict_single,
                inputs=feature_inputs,
                outputs=single_outputs,
            )

        with gr.Tab("CSV Batch Prediction"):
            gr.Markdown("Required columns: " + ", ".join(REQUIRED_COLUMNS))
            upload = gr.File(label="Upload CSV", file_types=[".csv"])
            run_batch = gr.Button("Run Batch Prediction")
            results_file = gr.File(label="Download Results")
            status = gr.Markdown()
            make_sample = gr.Button("Download Sample CSV")

            run_batch.click(
                fn=predict_batch,
                inputs=[upload],
                outputs=[results_file, status],
            )
            # The sample CSV is delivered through the same download widget.
            make_sample.click(fn=sample_csv, outputs=[results_file])

    return app


if __name__ == "__main__":
    demo = build_ui()
    # Bind to all network interfaces; the port comes from $PORT and falls
    # back to 7860 when it is unset.
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", "7860")),
    )