XRachel commited on
Commit
b8dd6a5
·
verified ·
1 Parent(s): 51167cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +184 -31
app.py CHANGED
@@ -1,67 +1,220 @@
1
-
2
  import os
 
3
  import subprocess
 
 
 
4
  import gradio as gr
5
  import pandas as pd
 
6
  import joblib
7
 
8
- MODEL_PATH="models/pipeline.joblib"
 
 
9
 
 
 
 
 
10
  def load_model():
11
  if os.path.exists(MODEL_PATH):
12
  return joblib.load(MODEL_PATH)
13
  return None
14
 
15
- model=load_model()
16
 
17
- def predict(age,balance):
18
- global model
19
- model=load_model()
20
  if model is None:
21
- return "Run pipeline first"
22
- df=pd.DataFrame([[age,balance]],columns=["Age","Balance"])
23
- p=model.predict(df)[0]
24
- return f"Prediction: {p}"
 
25
 
26
  def run_pipeline():
27
- proc=subprocess.Popen(
28
- ["python","scripts/pipeline.py"],
29
  stdout=subprocess.PIPE,
30
  stderr=subprocess.STDOUT,
31
  text=True
32
  )
33
- log=""
34
  for line in proc.stdout:
35
- log+=line
36
  yield log
37
 
38
- def build_ui():
39
- css=open("style.css").read()
40
 
41
- with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  gr.HTML(f"<style>{css}</style>")
44
 
45
- gr.Markdown("# 🏦 Bank Churn Predictor")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  with gr.Tab("Pipeline"):
48
- gr.Markdown("Train model and view execution log")
49
- btn=gr.Button("Run Pipeline")
50
- log=gr.Textbox(lines=20,label="Execution Log")
51
- btn.click(run_pipeline,outputs=log)
52
 
53
  with gr.Tab("Prediction"):
54
- age=gr.Number(label="Age")
55
- balance=gr.Number(label="Balance")
56
- btn=gr.Button("Predict")
57
- out=gr.Textbox()
 
58
 
59
- btn.click(predict,[age,balance],out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  return demo
62
 
63
- if __name__=="__main__":
64
- demo=build_ui()
 
65
  demo.queue()
66
- port=int(os.environ.get("PORT",7860))
67
- demo.launch(server_name="0.0.0.0",server_port=port)
 
 
1
  import os
2
+ import json
3
  import subprocess
4
+ import urllib.request
5
+ import urllib.error
6
+
7
  import gradio as gr
8
  import pandas as pd
9
+ import numpy as np
10
  import joblib
11
 
12
+ MODEL_PATH = "models/pipeline.joblib"
13
+ PY_NOTEBOOK = "BankChurn_Version1.ipynb"
14
+ R_NOTEBOOK = "BankChurn_Version1_R.ipynb"
15
 
16
+
17
+ # =========================
18
+ # Model
19
+ # =========================
20
  def load_model():
21
  if os.path.exists(MODEL_PATH):
22
  return joblib.load(MODEL_PATH)
23
  return None
24
 
 
25
 
26
+ def predict(age, balance):
27
+ model = load_model()
 
28
  if model is None:
29
+ return "Please run the pipeline first."
30
+ df = pd.DataFrame([[age, balance]], columns=["Age", "Balance"])
31
+ pred = model.predict(df)[0]
32
+ return "Churn Risk: Yes" if pred == 1 else "Churn Risk: No"
33
+
34
 
35
  def run_pipeline():
36
+ proc = subprocess.Popen(
37
+ ["python", "scripts/pipeline.py"],
38
  stdout=subprocess.PIPE,
39
  stderr=subprocess.STDOUT,
40
  text=True
41
  )
42
+ log = ""
43
  for line in proc.stdout:
44
+ log += line
45
  yield log
46
 
 
 
47
 
48
+ # =========================
49
+ # Demo data for dashboard
50
+ # =========================
51
+ def make_demo_data():
52
+ np.random.seed(42)
53
+ n = 120
54
+ df = pd.DataFrame({
55
+ "CustomerID": range(1, n + 1),
56
+ "Age": np.random.randint(18, 70, n),
57
+ "Balance": np.random.randint(500, 10000, n),
58
+ "Tenure": np.random.randint(1, 10, n),
59
+ "Products": np.random.randint(1, 5, n),
60
+ "Geography": np.random.choice(["France", "Germany", "Spain"], n),
61
+ "Exited": np.random.choice([0, 1], n, p=[0.78, 0.22])
62
+ })
63
+ return df
64
+
65
+
66
+ demo_df = make_demo_data()
67
+
68
+ geo_df = demo_df.groupby("Geography", as_index=False)["Exited"].mean()
69
+ geo_df["Exited"] = (geo_df["Exited"] * 100).round(2)
70
+
71
+ age_df = demo_df.groupby(pd.cut(demo_df["Age"], bins=[18, 30, 40, 50, 60, 70])).agg(
72
+ churn_rate=("Exited", "mean")
73
+ ).reset_index()
74
+ age_df["AgeBand"] = age_df["Age"].astype(str)
75
+ age_df["churn_rate"] = (age_df["churn_rate"] * 100).round(2)
76
+
77
+ summary_md = f"""
78
+ ### Dashboard Summary
79
+ - Total Customers: **{len(demo_df)}**
80
+ - Churned Customers: **{int(demo_df['Exited'].sum())}**
81
+ - Churn Rate: **{round(demo_df['Exited'].mean() * 100, 2)}%**
82
+ - Avg Balance: **${round(demo_df['Balance'].mean(), 2)}**
83
+ """
84
+
85
+
86
+ # =========================
87
+ # Notebook preview
88
+ # =========================
89
+ def load_notebook_preview(path, max_cells=6):
90
+ if not os.path.exists(path):
91
+ return "File not found."
92
+ with open(path, "r", encoding="utf-8") as f:
93
+ nb = json.load(f)
94
+
95
+ parts = []
96
+ for i, cell in enumerate(nb.get("cells", [])[:max_cells]):
97
+ source = "".join(cell.get("source", []))
98
+ parts.append(f"# Cell {i+1} ({cell.get('cell_type','code')})\n{source}\n")
99
+ return "\n\n".join(parts)
100
+
101
+
102
+ py_preview = load_notebook_preview(PY_NOTEBOOK)
103
+ r_preview = load_notebook_preview(R_NOTEBOOK)
104
 
105
+
106
+ # =========================
107
+ # HF AI integration
108
+ # =========================
109
+ def hf_ai_insight(question):
110
+ api_key = os.getenv("HF_API_KEY")
111
+ if not api_key:
112
+ return "HF_API_KEY not found. Add it in Space Secrets first."
113
+
114
+ prompt = f"""
115
+ You are a banking analytics assistant.
116
+ Context:
117
+ - This app is about bank churn prediction.
118
+ - Give short, practical business insights.
119
+ User question: {question}
120
+ """
121
+
122
+ payload = json.dumps({"inputs": prompt}).encode("utf-8")
123
+ req = urllib.request.Request(
124
+ "https://api-inference.huggingface.co/models/google/flan-t5-base",
125
+ data=payload,
126
+ headers={
127
+ "Authorization": f"Bearer {api_key}",
128
+ "Content-Type": "application/json"
129
+ }
130
+ )
131
+
132
+ try:
133
+ with urllib.request.urlopen(req, timeout=60) as resp:
134
+ result = json.loads(resp.read().decode("utf-8"))
135
+ if isinstance(result, list) and len(result) > 0 and "generated_text" in result[0]:
136
+ return result[0]["generated_text"]
137
+ return str(result)
138
+ except urllib.error.HTTPError as e:
139
+ return f"HF API error: {e.read().decode('utf-8')}"
140
+ except Exception as e:
141
+ return f"Request failed: {str(e)}"
142
+
143
+
144
+ # =========================
145
+ # UI
146
+ # =========================
147
+ def build_ui():
148
+ css = open("style.css", "r", encoding="utf-8").read()
149
+
150
+ with gr.Blocks(title="Bank Churn Dashboard") as demo:
151
  gr.HTML(f"<style>{css}</style>")
152
 
153
+ gr.Markdown("""
154
+ # 🏦 Bank Churn Dashboard
155
+ Interactive churn analysis, model pipeline, prediction, and AI insight.
156
+ """)
157
+
158
+ with gr.Tab("Dashboard"):
159
+ gr.Markdown(summary_md)
160
+
161
+ with gr.Row():
162
+ churn_geo = gr.BarPlot(
163
+ value=geo_df,
164
+ x="Geography",
165
+ y="Exited",
166
+ title="Interactive Churn Rate by Geography (%)"
167
+ )
168
+ churn_age = gr.LinePlot(
169
+ value=age_df,
170
+ x="AgeBand",
171
+ y="churn_rate",
172
+ title="Interactive Churn Rate by Age Band (%)"
173
+ )
174
+
175
+ gr.Dataframe(
176
+ value=demo_df,
177
+ interactive=True,
178
+ label="Interactive Customer Table"
179
+ )
180
 
181
  with gr.Tab("Pipeline"):
182
+ gr.Markdown("Train the model and inspect the execution log.")
183
+ btn_run = gr.Button("Run Pipeline")
184
+ log = gr.Textbox(lines=18, label="Execution Log")
185
+ btn_run.click(run_pipeline, outputs=log)
186
 
187
  with gr.Tab("Prediction"):
188
+ age = gr.Number(label="Age", value=35)
189
+ balance = gr.Number(label="Balance", value=5000)
190
+ btn_pred = gr.Button("Predict Churn")
191
+ pred_out = gr.Textbox(label="Prediction Result")
192
+ btn_pred.click(predict, inputs=[age, balance], outputs=pred_out)
193
 
194
+ with gr.Tab("Analysis Files"):
195
+ gr.Markdown("### Python analysis notebook")
196
+ gr.File(value=PY_NOTEBOOK, label="Download Python Notebook")
197
+ gr.Code(value=py_preview, language="python", label="Python Notebook Preview")
198
+
199
+ gr.Markdown("### R analysis notebook")
200
+ gr.File(value=R_NOTEBOOK, label="Download R Notebook")
201
+ gr.Code(value=r_preview, language="r", label="R Notebook Preview")
202
+
203
+ with gr.Tab("AI Insight"):
204
+ gr.Markdown("Ask AI for a churn insight. Requires `HF_API_KEY` in Space Secrets.")
205
+ ai_q = gr.Textbox(
206
+ label="Ask something",
207
+ placeholder="Example: What customer segment should the bank focus on retaining first?"
208
+ )
209
+ ai_btn = gr.Button("Generate AI Insight")
210
+ ai_out = gr.Textbox(lines=8, label="AI Response")
211
+ ai_btn.click(hf_ai_insight, inputs=ai_q, outputs=ai_out)
212
 
213
  return demo
214
 
215
+
216
+ if __name__ == "__main__":
217
+ demo = build_ui()
218
  demo.queue()
219
+ port = int(os.environ.get("PORT", 7860))
220
+ demo.launch(server_name="0.0.0.0", server_port=port)