XRachel committed on
Commit
13fba44
·
verified ·
1 Parent(s): a696af2

Upload 3 files

Browse files
Files changed (3) hide show
  1. BankChurn_Version1.ipynb +1 -1
  2. BankChurn_Version1_R.ipynb +1 -1
  3. app.py +68 -51
BankChurn_Version1.ipynb CHANGED
@@ -1 +1 @@
1
- {"cells": [{"cell_type": "code", "metadata": {}, "source": ["import pandas as pd\n", "from pathlib import Path\n", "df = pd.read_csv('bankChurn.csv')\n", "summary = df.groupby('Geography')['Exited'].mean().reset_index()\n", "summary['Exited'] = summary['Exited']*100\n", "out = Path('artifacts/py/tables')\n", "out.mkdir(parents=True, exist_ok=True)\n", "summary.to_csv(out/'churn_by_geo.csv', index=False)\n", "summary\n"], "outputs": [], "execution_count": null}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
 
1
+ {"cells": [{"cell_type": "code", "metadata": {}, "source": ["import pandas as pd\n", "from pathlib import Path\n", "\n", "df = pd.read_csv('bankChurn.csv')\n", "target_col = 'CHURN_CUST_IND' if 'CHURN_CUST_IND' in df.columns else 'Exited'\n", "segment_col = 'GENDER_CD' if 'GENDER_CD' in df.columns else df.select_dtypes(include='object').columns[0]\n", "age_col = 'AGE' if 'AGE' in df.columns else 'Age'\n", "\n", "out = Path('artifacts/py/tables')\n", "out.mkdir(parents=True, exist_ok=True)\n", "\n", "seg = df.groupby(segment_col)[target_col].mean().reset_index()\n", "seg[target_col] = seg[target_col] * 100\n", "seg.to_csv(out / 'python_churn_by_segment.csv', index=False)\n", "\n", "age = df[[age_col, target_col]].dropna().copy()\n", "age['AgeBand'] = pd.cut(age[age_col], bins=[18,30,40,50,60,70,120], include_lowest=True)\n", "age = age.groupby('AgeBand')[target_col].mean().reset_index()\n", "age[target_col] = age[target_col] * 100\n", "age.to_csv(out / 'python_churn_by_age.csv', index=False)\n", "\n", "seg\n"], "outputs": [], "execution_count": null}], "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}}, "nbformat": 4, "nbformat_minor": 5}
BankChurn_Version1_R.ipynb CHANGED
@@ -1 +1 @@
1
- {"cells": [{"cell_type": "code", "metadata": {"language": "R"}, "source": ["library(readr)\n", "library(dplyr)\n", "dir.create('artifacts/r/tables', recursive=TRUE, showWarnings=FALSE)\n", "bankChurn <- read_csv('bankChurn.csv')\n", "summary_geo <- bankChurn |> group_by(Geography) |> summarise(churn_rate = mean(Exited))\n", "write_csv(summary_geo, 'artifacts/r/tables/r_churn_geo.csv')\n", "summary_geo\n"], "outputs": [], "execution_count": null}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
 
1
+ {"cells": [{"cell_type": "code", "metadata": {"language": "R"}, "source": ["library(readr)\n", "library(dplyr)\n", "library(tidyr)\n", "library(forcats)\n", "\n", "dir.create('artifacts/r/tables', recursive=TRUE, showWarnings=FALSE)\n", "bankChurn <- read_csv('bankChurn.csv')\n", "\n", "target_col <- if ('CHURN_CUST_IND' %in% names(bankChurn)) 'CHURN_CUST_IND' else 'Exited'\n", "segment_col <- if ('GENDER_CD' %in% names(bankChurn)) 'GENDER_CD' else names(bankChurn)[sapply(bankChurn, is.character)][1]\n", "age_col <- if ('AGE' %in% names(bankChurn)) 'AGE' else 'Age'\n", "\n", "summary_seg <- bankChurn |> group_by(.data[[segment_col]]) |> summarise(churn_rate = mean(.data[[target_col]], na.rm = TRUE), .groups='drop')\n", "write_csv(summary_seg, 'artifacts/r/tables/r_churn_by_segment.csv')\n", "\n", "summary_age <- bankChurn |> mutate(AgeBand = cut(.data[[age_col]], breaks=c(18,30,40,50,60,70,120), include.lowest=TRUE)) |> group_by(AgeBand) |> summarise(churn_rate = mean(.data[[target_col]], na.rm = TRUE), .groups='drop')\n", "write_csv(summary_age, 'artifacts/r/tables/r_churn_by_age.csv')\n", "\n", "summary_seg\n"], "outputs": [], "execution_count": null}], "metadata": {"kernelspec": {"display_name": "ir", "language": "R", "name": "ir"}}, "nbformat": 4, "nbformat_minor": 5}
app.py CHANGED
@@ -26,9 +26,7 @@ PIPELINE_CANDIDATES = [
26
 
27
  RUNS_DIR = BASE_DIR / "runs"
28
  ART_DIR = BASE_DIR / "artifacts"
29
- PY_FIG_DIR = ART_DIR / "py" / "figures"
30
  PY_TAB_DIR = ART_DIR / "py" / "tables"
31
- R_FIG_DIR = ART_DIR / "r" / "figures"
32
  R_TAB_DIR = ART_DIR / "r" / "tables"
33
 
34
  PAPERMILL_TIMEOUT = int(os.environ.get("PAPERMILL_TIMEOUT", "1800"))
@@ -37,7 +35,7 @@ MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct").strip()
37
 
38
 
39
  def ensure_dirs():
40
- for p in [RUNS_DIR, PY_FIG_DIR, PY_TAB_DIR, R_FIG_DIR, R_TAB_DIR]:
41
  p.mkdir(parents=True, exist_ok=True)
42
 
43
 
@@ -158,13 +156,41 @@ def load_data():
158
  if DATA_PATH.exists():
159
  return pd.read_csv(DATA_PATH)
160
  return pd.DataFrame({
161
- "Age": [25, 45, 33],
162
- "Balance": [1000, 5000, 2300],
163
- "Geography": ["France", "Germany", "Spain"],
164
- "Exited": [0, 1, 0],
165
  })
166
 
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  def _read_json(path: Path):
169
  with open(path, "r", encoding="utf-8") as f:
170
  obj = json.load(f)
@@ -217,12 +243,7 @@ def build_interactive_plot(df: pd.DataFrame, title: str):
217
 
218
  if len(numeric_cols) >= 2:
219
  chart_df = df[numeric_cols[:2]].dropna().copy().head(300)
220
- fig = px.scatter(
221
- chart_df,
222
- x=numeric_cols[0],
223
- y=numeric_cols[1],
224
- title=title
225
- )
226
  fig.update_layout(height=380)
227
  return fig
228
 
@@ -235,49 +256,51 @@ def build_interactive_plot(df: pd.DataFrame, title: str):
235
 
236
 
237
  def build_overview_charts(df: pd.DataFrame):
238
- geo_fig = px.scatter(title="Churn by Geography (%)")
239
- age_fig = px.scatter(title="Churn by Age Band (%)")
 
 
 
 
240
 
241
- if {"Geography", "Exited"}.issubset(df.columns):
242
- geo_df = df.groupby("Geography", as_index=False)["Exited"].mean()
243
- geo_df["Exited"] = (geo_df["Exited"] * 100).round(2)
244
- geo_fig = px.bar(geo_df, x="Geography", y="Exited", title="Churn by Geography (%)")
245
- geo_fig.update_layout(height=380)
246
 
247
- if {"Age", "Exited"}.issubset(df.columns):
248
  temp = df.copy()
249
- temp["AgeBand"] = pd.cut(
250
- temp["Age"],
251
- bins=[18, 30, 40, 50, 60, 70],
252
- include_lowest=True
253
- )
254
- age_df = temp.groupby("AgeBand").agg(churn_rate=("Exited", "mean")).reset_index()
255
  age_df["AgeBand"] = age_df["AgeBand"].astype(str)
256
  age_df["churn_rate"] = (age_df["churn_rate"] * 100).round(2)
257
  age_fig = px.line(age_df, x="AgeBand", y="churn_rate", title="Churn by Age Band (%)", markers=True)
258
  age_fig.update_layout(height=380)
259
 
260
- return geo_fig, age_fig
261
 
262
 
263
  def build_dashboard():
264
  df = load_data()
 
 
265
 
266
  summary_lines = [
267
  "### Executive Summary",
268
  f"- Total Customers: **{len(df)}**",
269
  ]
270
- if "Exited" in df.columns:
271
- summary_lines.append(f"- Churn Rate: **{round(df['Exited'].mean() * 100, 2)}%**")
272
- summary_lines.append(f"- Churned Customers: **{int(df['Exited'].sum())}**")
273
- if "Balance" in df.columns:
274
- summary_lines.append(f"- Average Balance: **{round(df['Balance'].mean(), 2)}**")
275
 
276
  kernels = ", ".join(sorted(available_kernels().keys())) or "none"
277
  summary_lines.append(f"- Available Kernels: **{kernels}**")
278
  summary_md = "\n".join(summary_lines)
279
 
280
- geo_fig, age_fig = build_overview_charts(df)
281
 
282
  py_name, py_df = load_latest_table(PY_TAB_DIR)
283
  r_name, r_df = load_latest_table(R_TAB_DIR)
@@ -293,7 +316,7 @@ def build_dashboard():
293
  if r_df is None:
294
  r_df = pd.DataFrame([{"info": "No R table found in artifacts/r/tables"}])
295
 
296
- return summary_md, geo_fig, age_fig, py_status, py_plot, py_df, r_status, r_plot, r_df
297
 
298
 
299
  def generate_ai_insight(question: str):
@@ -301,10 +324,13 @@ def generate_ai_insight(question: str):
301
  return "HF_API_KEY is not configured in Space Secrets."
302
 
303
  df = load_data()
 
 
304
  summary = {
305
  "rows": int(len(df)),
306
- "churn_rate": round(float(df["Exited"].mean() * 100), 2) if "Exited" in df.columns else None,
307
- "avg_balance": round(float(df["Balance"].mean()), 2) if "Balance" in df.columns else None,
 
308
  }
309
  py_name, py_df = load_latest_table(PY_TAB_DIR)
310
  r_name, r_df = load_latest_table(R_TAB_DIR)
@@ -344,11 +370,7 @@ Return:
344
  )
345
  return response.choices[0].message.content.strip()
346
  except Exception:
347
- return client.text_generation(
348
- prompt,
349
- model=MODEL_NAME,
350
- max_new_tokens=350,
351
- )
352
  except Exception as e:
353
  return f"AI request failed: {str(e)}"
354
 
@@ -372,12 +394,7 @@ def build_ui():
372
  btn_py = gr.Button("Run Python", variant="secondary")
373
  btn_r = gr.Button("Run R", variant="secondary")
374
  btn_all = gr.Button("Run All", variant="primary")
375
- exec_log = gr.Textbox(
376
- label="Execution Log",
377
- lines=18,
378
- max_lines=28,
379
- interactive=False,
380
- )
381
  btn_py.click(run_python, outputs=[exec_log])
382
  btn_r.click(run_r, outputs=[exec_log])
383
  btn_all.click(run_all, outputs=[exec_log])
@@ -387,7 +404,7 @@ def build_ui():
387
  summary_md = gr.Markdown()
388
 
389
  with gr.Row():
390
- geo_plot = gr.Plot(label="Churn by Geography")
391
  age_plot = gr.Plot(label="Churn by Age Band")
392
 
393
  with gr.Row():
@@ -404,11 +421,11 @@ def build_ui():
404
 
405
  refresh_btn.click(
406
  build_dashboard,
407
- outputs=[summary_md, geo_plot, age_plot, py_status, py_plot, py_table, r_status, r_plot, r_table],
408
  )
409
  demo.load(
410
  build_dashboard,
411
- outputs=[summary_md, geo_plot, age_plot, py_status, py_plot, py_table, r_status, r_plot, r_table],
412
  )
413
 
414
  with gr.Tab("Prediction"):
 
26
 
27
  RUNS_DIR = BASE_DIR / "runs"
28
  ART_DIR = BASE_DIR / "artifacts"
 
29
  PY_TAB_DIR = ART_DIR / "py" / "tables"
 
30
  R_TAB_DIR = ART_DIR / "r" / "tables"
31
 
32
  PAPERMILL_TIMEOUT = int(os.environ.get("PAPERMILL_TIMEOUT", "1800"))
 
35
 
36
 
37
  def ensure_dirs():
38
+ for p in [RUNS_DIR, PY_TAB_DIR, R_TAB_DIR]:
39
  p.mkdir(parents=True, exist_ok=True)
40
 
41
 
 
156
  if DATA_PATH.exists():
157
  return pd.read_csv(DATA_PATH)
158
  return pd.DataFrame({
159
+ "AGE": [25, 45, 33],
160
+ "LOCAL_CUR_MON_AVG_BAL": [1000, 5000, 2300],
161
+ "GENDER_CD": ["M", "F", "M"],
162
+ "CHURN_CUST_IND": [0, 1, 0],
163
  })
164
 
165
 
166
+ def get_target_col(df: pd.DataFrame):
167
+ for c in ["CHURN_CUST_IND", "Exited", "churn", "target"]:
168
+ if c in df.columns:
169
+ return c
170
+ return None
171
+
172
+
173
+ def get_age_col(df: pd.DataFrame):
174
+ for c in ["AGE", "Age", "age"]:
175
+ if c in df.columns:
176
+ return c
177
+ return None
178
+
179
+
180
+ def get_balance_col(df: pd.DataFrame):
181
+ for c in ["LOCAL_CUR_MON_AVG_BAL", "Balance", "balance"]:
182
+ if c in df.columns:
183
+ return c
184
+ return None
185
+
186
+
187
+ def get_segment_col(df: pd.DataFrame):
188
+ for c in ["Geography", "GENDER_CD", "gender", "SEGMENT"]:
189
+ if c in df.columns:
190
+ return c
191
+ return None
192
+
193
+
194
  def _read_json(path: Path):
195
  with open(path, "r", encoding="utf-8") as f:
196
  obj = json.load(f)
 
243
 
244
  if len(numeric_cols) >= 2:
245
  chart_df = df[numeric_cols[:2]].dropna().copy().head(300)
246
+ fig = px.scatter(chart_df, x=numeric_cols[0], y=numeric_cols[1], title=title)
 
 
 
 
 
247
  fig.update_layout(height=380)
248
  return fig
249
 
 
256
 
257
 
258
  def build_overview_charts(df: pd.DataFrame):
259
+ target_col = get_target_col(df)
260
+ age_col = get_age_col(df)
261
+ segment_col = get_segment_col(df)
262
+
263
+ seg_fig = px.scatter(title="Churn by Segment")
264
+ age_fig = px.scatter(title="Churn by Age Band")
265
 
266
+ if target_col and segment_col:
267
+ seg_df = df.groupby(segment_col, as_index=False)[target_col].mean()
268
+ seg_df[target_col] = (seg_df[target_col] * 100).round(2)
269
+ seg_fig = px.bar(seg_df, x=segment_col, y=target_col, title=f"Churn by {segment_col} (%)")
270
+ seg_fig.update_layout(height=380)
271
 
272
+ if target_col and age_col:
273
  temp = df.copy()
274
+ temp["AgeBand"] = pd.cut(temp[age_col], bins=[18, 30, 40, 50, 60, 70, 120], include_lowest=True)
275
+ age_df = temp.groupby("AgeBand").agg(churn_rate=(target_col, "mean")).reset_index()
 
 
 
 
276
  age_df["AgeBand"] = age_df["AgeBand"].astype(str)
277
  age_df["churn_rate"] = (age_df["churn_rate"] * 100).round(2)
278
  age_fig = px.line(age_df, x="AgeBand", y="churn_rate", title="Churn by Age Band (%)", markers=True)
279
  age_fig.update_layout(height=380)
280
 
281
+ return seg_fig, age_fig
282
 
283
 
284
  def build_dashboard():
285
  df = load_data()
286
+ target_col = get_target_col(df)
287
+ balance_col = get_balance_col(df)
288
 
289
  summary_lines = [
290
  "### Executive Summary",
291
  f"- Total Customers: **{len(df)}**",
292
  ]
293
+ if target_col:
294
+ summary_lines.append(f"- Churn Rate: **{round(df[target_col].mean() * 100, 2)}%**")
295
+ summary_lines.append(f"- Churned Customers: **{int(df[target_col].sum())}**")
296
+ if balance_col:
297
+ summary_lines.append(f"- Average Balance: **{round(df[balance_col].mean(), 2)}**")
298
 
299
  kernels = ", ".join(sorted(available_kernels().keys())) or "none"
300
  summary_lines.append(f"- Available Kernels: **{kernels}**")
301
  summary_md = "\n".join(summary_lines)
302
 
303
+ seg_fig, age_fig = build_overview_charts(df)
304
 
305
  py_name, py_df = load_latest_table(PY_TAB_DIR)
306
  r_name, r_df = load_latest_table(R_TAB_DIR)
 
316
  if r_df is None:
317
  r_df = pd.DataFrame([{"info": "No R table found in artifacts/r/tables"}])
318
 
319
+ return summary_md, seg_fig, age_fig, py_status, py_plot, py_df, r_status, r_plot, r_df
320
 
321
 
322
  def generate_ai_insight(question: str):
 
324
  return "HF_API_KEY is not configured in Space Secrets."
325
 
326
  df = load_data()
327
+ target_col = get_target_col(df)
328
+ balance_col = get_balance_col(df)
329
  summary = {
330
  "rows": int(len(df)),
331
+ "churn_rate": round(float(df[target_col].mean() * 100), 2) if target_col else None,
332
+ "avg_balance": round(float(df[balance_col].mean()), 2) if balance_col else None,
333
+ "target_column": target_col,
334
  }
335
  py_name, py_df = load_latest_table(PY_TAB_DIR)
336
  r_name, r_df = load_latest_table(R_TAB_DIR)
 
370
  )
371
  return response.choices[0].message.content.strip()
372
  except Exception:
373
+ return client.text_generation(prompt, model=MODEL_NAME, max_new_tokens=350)
 
 
 
 
374
  except Exception as e:
375
  return f"AI request failed: {str(e)}"
376
 
 
394
  btn_py = gr.Button("Run Python", variant="secondary")
395
  btn_r = gr.Button("Run R", variant="secondary")
396
  btn_all = gr.Button("Run All", variant="primary")
397
+ exec_log = gr.Textbox(label="Execution Log", lines=18, max_lines=28, interactive=False)
 
 
 
 
 
398
  btn_py.click(run_python, outputs=[exec_log])
399
  btn_r.click(run_r, outputs=[exec_log])
400
  btn_all.click(run_all, outputs=[exec_log])
 
404
  summary_md = gr.Markdown()
405
 
406
  with gr.Row():
407
+ seg_plot = gr.Plot(label="Churn by Segment")
408
  age_plot = gr.Plot(label="Churn by Age Band")
409
 
410
  with gr.Row():
 
421
 
422
  refresh_btn.click(
423
  build_dashboard,
424
+ outputs=[summary_md, seg_plot, age_plot, py_status, py_plot, py_table, r_status, r_plot, r_table],
425
  )
426
  demo.load(
427
  build_dashboard,
428
+ outputs=[summary_md, seg_plot, age_plot, py_status, py_plot, py_table, r_status, r_plot, r_table],
429
  )
430
 
431
  with gr.Tab("Prediction"):