Phileassss commited on
Commit
486816b
·
verified ·
1 Parent(s): 2ef3a03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +482 -281
app.py CHANGED
@@ -7,9 +7,10 @@ from pathlib import Path
7
  from typing import Dict, Any, List, Tuple
8
 
9
  import pandas as pd
 
 
10
  import gradio as gr
11
  import papermill as pm
12
- import plotly.graph_objects as go
13
 
14
  # Optional LLM (HuggingFace Inference API)
15
  try:
@@ -102,7 +103,6 @@ def run_notebook(nb_name: str) -> str:
102
  )
103
  return f"Executed {nb_name}"
104
 
105
-
106
  def run_datacreation() -> str:
107
  try:
108
  log = run_notebook(NB1)
@@ -111,7 +111,6 @@ def run_datacreation() -> str:
111
  except Exception as e:
112
  return f"FAILED {e}\n\n{traceback.format_exc()[-2000:]}"
113
 
114
-
115
  def run_pythonanalysis() -> str:
116
  try:
117
  log = run_notebook(NB2)
@@ -126,93 +125,130 @@ def run_pythonanalysis() -> str:
126
  except Exception as e:
127
  return f"FAILED {e}\n\n{traceback.format_exc()[-2000:]}"
128
 
129
-
130
  def run_full_pipeline() -> str:
131
  logs = []
132
  logs.append("=" * 50)
133
- logs.append("STEP 1/2: Data Creation (web scraping + synthetic data)")
134
  logs.append("=" * 50)
135
  logs.append(run_datacreation())
136
  logs.append("")
137
  logs.append("=" * 50)
138
- logs.append("STEP 2/2: Python Analysis (sentiment, ARIMA, dashboard)")
139
  logs.append("=" * 50)
140
  logs.append(run_pythonanalysis())
141
  return "\n".join(logs)
142
 
143
-
144
  # =========================================================
145
- # GALLERY LOADERS
146
  # =========================================================
147
 
148
- def _load_all_figures() -> List[Tuple[str, str]]:
149
- """Return list of (filepath, caption) for Gallery."""
150
- items = []
151
- for p in sorted(PY_FIG_DIR.glob("*.png")):
152
- items.append((str(p), p.stem.replace('_', ' ').title()))
153
- return items
154
-
155
-
156
- def _load_table_safe(path: Path) -> pd.DataFrame:
157
- try:
158
- if path.suffix == ".json":
159
- obj = _read_json(path)
160
- if isinstance(obj, dict):
161
- return pd.DataFrame([obj])
162
- return pd.DataFrame(obj)
163
- return _read_csv(path)
164
- except Exception as e:
165
- return pd.DataFrame([{"error": str(e)}])
166
-
167
-
168
- def refresh_gallery():
169
- """Called when user clicks Refresh on Gallery tab."""
170
- figures = _load_all_figures()
171
- idx = artifacts_index()
172
-
173
- table_choices = list(idx["python"]["tables"])
174
-
175
- default_df = pd.DataFrame()
176
- if table_choices:
177
- default_df = _load_table_safe(PY_TAB_DIR / table_choices[0])
178
-
179
- return (
180
- figures if figures else [],
181
- gr.update(choices=table_choices, value=table_choices[0] if table_choices else None),
182
- default_df,
 
 
 
 
 
 
183
  )
184
 
 
185
 
186
- def on_table_select(choice: str):
187
- if not choice:
188
- return pd.DataFrame([{"hint": "Select a table above."}])
189
- path = PY_TAB_DIR / choice
190
- if not path.exists():
191
- return pd.DataFrame([{"error": f"File not found: {choice}"}])
192
- return _load_table_safe(path)
193
-
194
-
195
- # =========================================================
196
- # KPI LOADER
197
- # =========================================================
198
-
199
- def load_kpis() -> Dict[str, Any]:
200
- for candidate in [PY_TAB_DIR / "kpis.json", PY_FIG_DIR / "kpis.json"]:
201
- if candidate.exists():
202
- try:
203
- return _read_json(candidate)
204
- except Exception:
205
- pass
206
- return {}
 
 
207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
  # =========================================================
210
- # AI DASHBOARD -- LLM picks what to display
211
  # =========================================================
212
 
213
- DASHBOARD_SYSTEM = """You are an AI dashboard assistant for a book-sales analytics app.
214
- The user asks questions or requests about their data. You have access to pre-computed
215
- artifacts from a Python analysis pipeline.
216
 
217
  AVAILABLE ARTIFACTS (only reference ones that exist):
218
  {artifacts_json}
@@ -222,26 +258,21 @@ KPI SUMMARY: {kpis_json}
222
  YOUR JOB:
223
  1. Answer the user's question conversationally using the KPIs and your knowledge of the artifacts.
224
  2. At the END of your response, output a JSON block (fenced with ```json ... ```) that tells
225
- the dashboard which artifact to display. The JSON must have this shape:
226
  {{"show": "figure"|"table"|"none", "scope": "python", "filename": "..."}}
227
 
228
- - Use "show": "figure" to display a chart image.
229
- - Use "show": "table" to display a CSV/JSON table.
230
- - Use "show": "none" if no artifact is relevant.
231
-
232
  RULES:
233
- - If the user asks about sales trends or forecasting by title, show sales_trends or arima figures.
234
- - If the user asks about sentiment, show sentiment figure or sentiment_counts table.
235
- - If the user asks about forecast accuracy or ARIMA, show arima figures.
236
- - If the user asks about top sellers, show top_titles_by_units_sold.csv.
237
- - If the user asks a general data question, pick the most relevant artifact.
238
- - Keep your answer concise (2-4 sentences), then the JSON block.
239
  """
240
 
241
  JSON_BLOCK_RE = re.compile(r"```json\s*(\{.*?\})\s*```", re.DOTALL)
242
  FALLBACK_JSON_RE = re.compile(r"\{[^{}]*\"show\"[^{}]*\}", re.DOTALL)
243
 
244
-
245
  def _parse_display_directive(text: str) -> Dict[str, str]:
246
  m = JSON_BLOCK_RE.search(text)
247
  if m:
@@ -257,14 +288,10 @@ def _parse_display_directive(text: str) -> Dict[str, str]:
257
  pass
258
  return {"show": "none"}
259
 
260
-
261
  def _clean_response(text: str) -> str:
262
- """Strip the JSON directive block from the displayed response."""
263
  return JSON_BLOCK_RE.sub("", text).strip()
264
 
265
-
266
  def _n8n_call(msg: str) -> Tuple[str, Dict]:
267
- """Call the student's n8n webhook and return (reply, directive)."""
268
  import requests as req
269
  try:
270
  resp = req.post(N8N_WEBHOOK_URL, json={"question": msg}, timeout=20)
@@ -277,16 +304,13 @@ def _n8n_call(msg: str) -> Tuple[str, Dict]:
277
  except Exception as e:
278
  return f"n8n error: {e}. Falling back to keyword matching.", None
279
 
280
-
281
  def ai_chat(user_msg: str, history: list):
282
- """Chat function for the AI Dashboard tab."""
283
  if not user_msg or not user_msg.strip():
284
  return history, "", None, None
285
 
286
  idx = artifacts_index()
287
  kpis = load_kpis()
288
 
289
- # Priority: n8n webhook > HF LLM > keyword fallback
290
  if N8N_WEBHOOK_URL:
291
  reply, directive = _n8n_call(user_msg)
292
  if directive is None:
@@ -297,20 +321,16 @@ def ai_chat(user_msg: str, history: list):
297
  else:
298
  system = DASHBOARD_SYSTEM.format(
299
  artifacts_json=json.dumps(idx, indent=2),
300
- kpis_json=json.dumps(kpis, indent=2) if kpis else "(no KPIs yet, run the pipeline first)",
301
  )
302
  msgs = [{"role": "system", "content": system}]
303
  for entry in (history or [])[-6:]:
304
  msgs.append(entry)
305
  msgs.append({"role": "user", "content": user_msg})
306
-
307
  try:
308
  r = llm_client.chat_completion(
309
- model=MODEL_NAME,
310
- messages=msgs,
311
- temperature=0.3,
312
- max_tokens=600,
313
- stream=False,
314
  )
315
  raw = (
316
  r["choices"][0]["message"]["content"]
@@ -324,30 +344,28 @@ def ai_chat(user_msg: str, history: list):
324
  reply_fb, directive = _keyword_fallback(user_msg, idx, kpis)
325
  reply += "\n\n" + reply_fb
326
 
327
- # Resolve artifacts — build interactive Plotly charts when possible
328
  chart_out = None
329
  tab_out = None
330
  show = directive.get("show", "none")
331
  fname = directive.get("filename", "")
332
  chart_name = directive.get("chart", "")
333
 
334
- # Interactive chart builders keyed by name
335
  chart_builders = {
336
- "sales": build_sales_chart,
337
- "sentiment": build_sentiment_chart,
338
- "top_sellers": build_top_sellers_chart,
 
339
  }
340
 
341
  if chart_name and chart_name in chart_builders:
342
  chart_out = chart_builders[chart_name]()
343
  elif show == "figure" and fname:
344
- # Fallback: try to match filename to a chart builder
345
- if "sales_trend" in fname:
346
- chart_out = build_sales_chart()
347
- elif "sentiment" in fname:
 
348
  chart_out = build_sentiment_chart()
349
- elif "arima" in fname or "forecast" in fname:
350
- chart_out = build_sales_chart() # closest interactive equivalent
351
  else:
352
  chart_out = _empty_chart(f"No interactive chart for {fname}")
353
 
@@ -362,92 +380,98 @@ def ai_chat(user_msg: str, history: list):
362
  {"role": "user", "content": user_msg},
363
  {"role": "assistant", "content": reply},
364
  ]
365
-
366
  return new_history, "", chart_out, tab_out
367
 
368
-
369
  def _keyword_fallback(msg: str, idx: Dict, kpis: Dict) -> Tuple[str, Dict]:
370
- """Simple keyword matcher when LLM is unavailable."""
371
  msg_lower = msg.lower()
372
 
373
- if not idx["python"]["figures"] and not idx["python"]["tables"]:
374
- return (
375
- "No artifacts found yet. Please run the pipeline first (Tab 1), "
376
- "then come back here to explore the results.",
377
- {"show": "none"},
378
- )
379
-
380
  kpi_text = ""
381
  if kpis:
382
- total = kpis.get("total_units_sold", 0)
383
- kpi_text = (
384
- f"Quick summary: **{kpis.get('n_titles', '?')}** book titles across "
385
- f"**{kpis.get('n_months', '?')}** months, with **{total:,.0f}** total units sold."
386
- )
387
 
388
- if any(w in msg_lower for w in ["trend", "sales trend", "monthly sale"]):
389
  return (
390
- f"Here are the sales trends. {kpi_text}",
391
- {"show": "figure", "chart": "sales"},
392
  )
393
-
394
- if any(w in msg_lower for w in ["sentiment", "review", "positive", "negative"]):
395
  return (
396
- f"Here is the sentiment distribution across sampled book titles. {kpi_text}",
397
- {"show": "figure", "chart": "sentiment"},
398
  )
399
-
400
- if any(w in msg_lower for w in ["arima", "forecast", "predict"]):
401
  return (
402
- f"Here are the sales trends and forecasts. {kpi_text}",
403
- {"show": "figure", "chart": "sales"},
404
  )
405
-
406
- if any(w in msg_lower for w in ["top", "best sell", "popular", "rank"]):
407
  return (
408
- f"Here are the top-selling titles by units sold. {kpi_text}",
409
- {"show": "table", "scope": "python", "filename": "top_titles_by_units_sold.csv"},
410
  )
411
-
412
- if any(w in msg_lower for w in ["price", "pricing", "decision"]):
413
  return (
414
- f"Here are the pricing decisions. {kpi_text}",
415
- {"show": "table", "scope": "python", "filename": "pricing_decisions.csv"},
416
  )
417
-
418
  if any(w in msg_lower for w in ["dashboard", "overview", "summary", "kpi"]):
419
  return (
420
- f"Dashboard overview: {kpi_text}\n\nAsk me about sales trends, sentiment, forecasts, "
421
- "pricing, or top sellers to see specific visualizations.",
422
- {"show": "table", "scope": "python", "filename": "df_dashboard.csv"},
423
  )
424
-
425
- # Default
426
  return (
427
- f"I can show you various analyses. {kpi_text}\n\n"
428
- "Try asking about: **sales trends**, **sentiment**, **ARIMA forecasts**, "
429
- "**pricing decisions**, **top sellers**, or **dashboard overview**.",
430
  {"show": "none"},
431
  )
432
 
433
-
434
  # =========================================================
435
- # KPI CARDS (BubbleBusters style)
436
  # =========================================================
437
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
  def render_kpi_cards() -> str:
439
  kpis = load_kpis()
440
  if not kpis:
441
  return (
442
  '<div style="background:rgba(255,255,255,.65);backdrop-filter:blur(16px);'
443
  'border-radius:20px;padding:28px;text-align:center;'
444
- 'border:1.5px solid rgba(255,255,255,.7);'
445
- 'box-shadow:0 8px 32px rgba(124,92,191,.08);">'
446
  '<div style="font-size:36px;margin-bottom:10px;">📊</div>'
447
- '<div style="color:#a48de8;font-size:14px;'
448
- 'font-weight:800;margin-bottom:6px;">No data yet</div>'
449
- '<div style="color:#9d8fc4;font-size:12px;">'
450
- 'Run the pipeline to populate these cards.</div>'
451
  '</div>'
452
  )
453
 
@@ -456,19 +480,18 @@ def render_kpi_cards() -> str:
456
  <div style="background:rgba(255,255,255,.72);backdrop-filter:blur(16px);
457
  border-radius:20px;padding:18px 14px 16px;text-align:center;
458
  border:1.5px solid rgba(255,255,255,.8);
459
- box-shadow:0 4px 16px rgba(124,92,191,.08);
460
  border-top:3px solid {colour};">
461
- <div style="font-size:26px;margin-bottom:7px;line-height:1;">{icon}</div>
462
  <div style="color:#9d8fc4;font-size:9.5px;text-transform:uppercase;
463
  letter-spacing:1.8px;margin-bottom:7px;font-weight:800;">{label}</div>
464
  <div style="color:#2d1f4e;font-size:16px;font-weight:800;">{value}</div>
465
  </div>"""
466
 
467
  kpi_config = [
468
- ("n_titles", "📚", "Book Titles", "#a48de8"),
469
- ("n_months", "📅", "Time Periods", "#7aa6f8"),
470
- ("total_units_sold", "📦", "Units Sold", "#6ee7c7"),
471
- ("total_revenue", "💰", "Revenue", "#3dcba8"),
472
  ]
473
 
474
  html = (
@@ -480,27 +503,20 @@ def render_kpi_cards() -> str:
480
  if val is None:
481
  continue
482
  if isinstance(val, (int, float)) and val > 100:
483
- val = f"{val:,.0f}"
484
  html += card(icon, label, str(val), colour)
485
- # Extra KPIs not in config
486
- known = {k for k, *_ in kpi_config}
487
- for key, val in kpis.items():
488
- if key not in known:
489
- label = key.replace("_", " ").title()
490
- if isinstance(val, (int, float)) and val > 100:
491
- val = f"{val:,.0f}"
492
- html += card("📈", label, str(val), "#8fa8f8")
493
  html += "</div>"
494
  return html
495
 
496
-
497
  # =========================================================
498
- # INTERACTIVE PLOTLY CHARTS (BubbleBusters style)
499
  # =========================================================
500
 
501
  CHART_PALETTE = ["#7c5cbf", "#2ec4a0", "#e8537a", "#e8a230", "#5e8fef",
502
  "#c45ea8", "#3dbacc", "#a0522d", "#6aaa3a", "#d46060"]
503
 
 
 
504
  def _styled_layout(**kwargs) -> dict:
505
  defaults = dict(
506
  template="plotly_white",
@@ -518,94 +534,226 @@ def _styled_layout(**kwargs) -> dict:
518
  defaults.update(kwargs)
519
  return defaults
520
 
521
-
522
  def _empty_chart(title: str) -> go.Figure:
523
  fig = go.Figure()
524
  fig.update_layout(
525
  title=title, height=420, template="plotly_white",
526
  paper_bgcolor="rgba(255,255,255,0.95)",
527
- annotations=[dict(text="Run the pipeline to generate data",
 
528
  x=0.5, y=0.5, xref="paper", yref="paper", showarrow=False,
529
  font=dict(size=14, color="rgba(124,92,191,0.5)"))],
530
  )
531
  return fig
532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
533
 
534
- def build_sales_chart() -> go.Figure:
535
- path = PY_TAB_DIR / "df_dashboard.csv"
536
- if not path.exists():
537
- return _empty_chart("Sales Trends — run the pipeline first")
538
- df = pd.read_csv(path)
539
- date_col = next((c for c in df.columns if "month" in c.lower() or "date" in c.lower()), None)
540
- val_cols = [c for c in df.columns if c != date_col and df[c].dtype in ("float64", "int64")]
541
- if not date_col or not val_cols:
542
- return _empty_chart("Could not auto-detect columns in df_dashboard.csv")
543
- df[date_col] = pd.to_datetime(df[date_col], errors="coerce")
544
  fig = go.Figure()
545
- for i, col in enumerate(val_cols):
546
- fig.add_trace(go.Scatter(
547
- x=df[date_col], y=df[col], name=col.replace("_", " ").title(),
548
- mode="lines+markers", line=dict(color=CHART_PALETTE[i % len(CHART_PALETTE)], width=2),
549
- marker=dict(size=4),
550
- hovertemplate=f"<b>{col.replace('_',' ').title()}</b><br>%{{x|%b %Y}}: %{{y:,.0f}}<extra></extra>",
551
- ))
552
- fig.update_layout(**_styled_layout(height=450, hovermode="x unified",
553
- title=dict(text="Monthly Overview")))
554
- fig.update_xaxes(gridcolor="rgba(124,92,191,0.15)", showgrid=True)
555
- fig.update_yaxes(gridcolor="rgba(124,92,191,0.15)", showgrid=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
556
  return fig
557
 
 
 
 
 
558
 
559
- def build_sentiment_chart() -> go.Figure:
560
- path = PY_TAB_DIR / "sentiment_counts_sampled.csv"
561
- if not path.exists():
562
- return _empty_chart("Sentiment Distribution — run the pipeline first")
563
- df = pd.read_csv(path)
564
- title_col = df.columns[0]
565
- sent_cols = [c for c in ["negative", "neutral", "positive"] if c in df.columns]
566
- if not sent_cols:
567
- return _empty_chart("No sentiment columns found in CSV")
568
- colors = {"negative": "#e8537a", "neutral": "#5e8fef", "positive": "#2ec4a0"}
569
  fig = go.Figure()
570
- for col in sent_cols:
571
- fig.add_trace(go.Bar(
572
- name=col.title(), y=df[title_col], x=df[col],
573
- orientation="h", marker_color=colors.get(col, "#888"),
574
- hovertemplate=f"<b>{col.title()}</b>: %{{x}}<extra></extra>",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
575
  ))
 
576
  fig.update_layout(**_styled_layout(
577
- height=max(400, len(df) * 28), barmode="stack",
578
- title=dict(text="Sentiment Distribution by Book"),
 
 
579
  ))
580
- fig.update_xaxes(title="Number of Reviews")
581
- fig.update_yaxes(autorange="reversed")
582
  return fig
583
 
 
 
 
 
584
 
585
- def build_top_sellers_chart() -> go.Figure:
586
- path = PY_TAB_DIR / "top_titles_by_units_sold.csv"
587
- if not path.exists():
588
- return _empty_chart("Top Sellers — run the pipeline first")
589
- df = pd.read_csv(path).head(15)
590
- title_col = next((c for c in df.columns if "title" in c.lower()), df.columns[0])
591
- val_col = next((c for c in df.columns if "unit" in c.lower() or "sold" in c.lower()), df.columns[-1])
592
- fig = go.Figure(go.Bar(
593
- y=df[title_col], x=df[val_col], orientation="h",
594
- marker=dict(color=df[val_col], colorscale=[[0, "#c5b4f0"], [1, "#7c5cbf"]]),
595
- hovertemplate="<b>%{y}</b><br>Units: %{x:,.0f}<extra></extra>",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
596
  ))
597
  fig.update_layout(**_styled_layout(
598
- height=max(400, len(df) * 30),
599
- title=dict(text="Top Selling Titles"), showlegend=False,
 
600
  ))
601
- fig.update_yaxes(autorange="reversed")
602
- fig.update_xaxes(title="Total Units Sold")
603
  return fig
604
 
 
 
 
 
 
 
 
 
 
 
605
 
606
  def refresh_dashboard():
607
- return render_kpi_cards(), build_sales_chart(), build_sentiment_chart(), build_top_sellers_chart()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
608
 
 
 
 
 
 
 
 
609
 
610
  # =========================================================
611
  # UI
@@ -617,26 +765,90 @@ def load_css() -> str:
617
  css_path = BASE_DIR / "style.css"
618
  return css_path.read_text(encoding="utf-8") if css_path.exists() else ""
619
 
620
-
621
- with gr.Blocks(title="AIBDM 2026 Workshop App") as demo:
622
 
623
  gr.Markdown(
624
- "# SE21 App Template\n"
625
- "*This is an app template for SE21 students*",
626
  elem_id="escp_title",
627
  )
628
 
629
  # ===========================================================
630
- # TAB 1 -- Pipeline Runner
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
631
  # ===========================================================
632
  with gr.Tab("Pipeline Runner"):
633
- gr.Markdown()
634
 
635
  with gr.Row():
636
  with gr.Column(scale=1):
637
- btn_nb1 = gr.Button("Step 1: Data Creation", variant="secondary")
638
  with gr.Column(scale=1):
639
- btn_nb2 = gr.Button("Step 2: Python Analysis", variant="secondary")
640
 
641
  with gr.Row():
642
  btn_all = gr.Button("Run Full Pipeline (Both Steps)", variant="primary")
@@ -653,25 +865,22 @@ with gr.Blocks(title="AIBDM 2026 Workshop App") as demo:
653
  btn_all.click(run_full_pipeline, outputs=[run_log])
654
 
655
  # ===========================================================
656
- # TAB 2 -- Dashboard (KPIs + Interactive Charts + Gallery)
657
  # ===========================================================
658
  with gr.Tab("Dashboard"):
659
  kpi_html = gr.HTML(value=render_kpi_cards)
660
 
661
  refresh_btn = gr.Button("Refresh Dashboard", variant="primary")
662
 
663
- gr.Markdown("#### Interactive Charts")
664
- chart_sales = gr.Plot(label="Monthly Overview")
665
- chart_sentiment = gr.Plot(label="Sentiment Distribution")
666
- chart_top = gr.Plot(label="Top Sellers")
667
 
668
- gr.Markdown("#### Static Figures (from notebooks)")
669
- gallery = gr.Gallery(
670
- label="Generated Figures",
671
- columns=2,
672
- height=480,
673
- object_fit="contain",
674
- )
675
 
676
  gr.Markdown("#### Data Tables")
677
  table_dropdown = gr.Dropdown(
@@ -679,19 +888,21 @@ with gr.Blocks(title="AIBDM 2026 Workshop App") as demo:
679
  choices=[],
680
  interactive=True,
681
  )
682
- table_display = gr.Dataframe(
683
- label="Table Preview",
684
- interactive=False,
685
- )
686
 
687
  def _on_refresh():
688
- kpi, c1, c2, c3 = refresh_dashboard()
689
  figs, dd, df = refresh_gallery()
690
- return kpi, c1, c2, c3, figs, dd, df
 
 
 
 
 
691
 
692
  refresh_btn.click(
693
  _on_refresh,
694
- outputs=[kpi_html, chart_sales, chart_sentiment, chart_top,
695
  gallery, table_dropdown, table_display],
696
  )
697
  table_dropdown.change(
@@ -701,52 +912,42 @@ with gr.Blocks(title="AIBDM 2026 Workshop App") as demo:
701
  )
702
 
703
  # ===========================================================
704
- # TAB 3 -- AI Dashboard
705
  # ===========================================================
706
  with gr.Tab('"AI" Dashboard'):
707
  _ai_status = (
708
  "Connected to your **n8n workflow**." if N8N_WEBHOOK_URL
709
  else "**LLM active.**" if LLM_ENABLED
710
- else "Using **keyword matching**. Upgrade options: "
711
- "set `N8N_WEBHOOK_URL` to connect your n8n workflow, "
712
- "or set `HF_API_KEY` for direct LLM access."
713
  )
714
  gr.Markdown(
715
  "### Ask questions, get interactive visualisations\n\n"
716
- f"Type a question and the system will pick the right interactive chart or table. {_ai_status}"
717
  )
718
 
719
  with gr.Row(equal_height=True):
720
  with gr.Column(scale=1):
721
- chatbot = gr.Chatbot(
722
- label="Conversation",
723
- height=380,
724
- )
725
  user_input = gr.Textbox(
726
  label="Ask about your data",
727
- placeholder="e.g. Show me sales trends / What are the top sellers? / Sentiment analysis",
728
  lines=1,
729
  )
730
  gr.Examples(
731
  examples=[
732
- "Show me the sales trends",
733
- "What does the sentiment look like?",
734
- "Which titles sell the most?",
735
- "Show the ARIMA forecasts",
736
- "What are the pricing decisions?",
737
  "Give me a dashboard overview",
 
738
  ],
739
  inputs=user_input,
740
  )
741
 
742
  with gr.Column(scale=1):
743
- ai_figure = gr.Plot(
744
- label="Interactive Chart",
745
- )
746
- ai_table = gr.Dataframe(
747
- label="Data Table",
748
- interactive=False,
749
- )
750
 
751
  user_input.submit(
752
  ai_chat,
@@ -755,4 +956,4 @@ with gr.Blocks(title="AIBDM 2026 Workshop App") as demo:
755
  )
756
 
757
 
758
- demo.launch(css=load_css(), allowed_paths=[str(BASE_DIR)])
 
7
  from typing import Dict, Any, List, Tuple
8
 
9
  import pandas as pd
10
+ import numpy as np
11
+ import plotly.graph_objects as go
12
  import gradio as gr
13
  import papermill as pm
 
14
 
15
  # Optional LLM (HuggingFace Inference API)
16
  try:
 
103
  )
104
  return f"Executed {nb_name}"
105
 
 
106
  def run_datacreation() -> str:
107
  try:
108
  log = run_notebook(NB1)
 
111
  except Exception as e:
112
  return f"FAILED {e}\n\n{traceback.format_exc()[-2000:]}"
113
 
 
114
  def run_pythonanalysis() -> str:
115
  try:
116
  log = run_notebook(NB2)
 
125
  except Exception as e:
126
  return f"FAILED {e}\n\n{traceback.format_exc()[-2000:]}"
127
 
 
128
  def run_full_pipeline() -> str:
129
  logs = []
130
  logs.append("=" * 50)
131
+ logs.append("STEP 1/2: Data Creation & Synthetic Enrichment")
132
  logs.append("=" * 50)
133
  logs.append(run_datacreation())
134
  logs.append("")
135
  logs.append("=" * 50)
136
+ logs.append("STEP 2/2: Salary Analysis & Predictions")
137
  logs.append("=" * 50)
138
  logs.append(run_pythonanalysis())
139
  return "\n".join(logs)
140
 
 
141
  # =========================================================
142
+ # SALARY PREDICTION (inline — no Flask needed)
143
  # =========================================================
144
 
145
+ def get_experience_group(exp: float) -> str:
146
+ if exp <= 5: return "0-5 years"
147
+ elif exp <= 10: return "6-10 years"
148
+ elif exp <= 15: return "11-15 years"
149
+ elif exp <= 20: return "16-20 years"
150
+ else: return "20+"
151
+
152
+ def get_career_tier(age: float, exp: float, edu: str) -> str:
153
+ edu_score = {"Bachelor's": 1, "Master's": 2, "PhD": 3}.get(edu, 1)
154
+ if exp >= 15 and edu_score >= 2: return "senior"
155
+ elif exp >= 7 or edu_score == 3: return "mid"
156
+ else: return "junior"
157
+
158
+ def predict_salary_formula(age: float, exp: float, edu: str, job_title: str, gender: str) -> Tuple[float, str, str, str]:
159
+ """
160
+ Rule-based salary estimator aligned with the Random Forest model logic.
161
+ Returns (predicted_salary, career_tier, experience_group, explanation)
162
+ """
163
+ tier = get_career_tier(age, exp, edu)
164
+ exp_group = get_experience_group(exp)
165
+
166
+ edu_bonus = {"Bachelor's": 0, "Master's": 8000, "PhD": 18000}.get(edu, 0)
167
+ tier_bonus = {"junior": 0, "mid": 15000, "senior": 35000}.get(tier, 0)
168
+ base = 25000 + (age * 1100) + (exp * 4200) + edu_bonus + tier_bonus
169
+
170
+ # Job title keyword bonus
171
+ job_lower = job_title.lower()
172
+ if any(k in job_lower for k in ["director", "vp", "chief", "head"]):
173
+ base += 30000
174
+ elif any(k in job_lower for k in ["manager", "lead", "senior"]):
175
+ base += 15000
176
+ elif any(k in job_lower for k in ["junior", "intern", "assistant"]):
177
+ base -= 8000
178
+
179
+ salary = round(max(20000, base), 2)
180
+
181
+ explanation = (
182
+ f"Based on your profile: **{tier.capitalize()}** career tier | "
183
+ f"**{exp_group}** experience | **{edu}** education\n\n"
184
+ f"Key drivers: Age ({age}y) + Experience ({exp}y) + Education bonus (${edu_bonus:,}) "
185
+ f"+ Seniority bonus (${tier_bonus:,})"
186
  )
187
 
188
+ return salary, tier, exp_group, explanation
189
 
190
+ def predict_salary_n8n(age: float, exp: float, edu: str, job_title: str, gender: str) -> Tuple[str, str]:
191
+ """Call n8n webhook if available, otherwise use formula."""
192
+ import requests as req
193
+ if N8N_WEBHOOK_URL:
194
+ try:
195
+ resp = req.post(
196
+ N8N_WEBHOOK_URL,
197
+ json={"age": age, "experience": exp, "education": edu},
198
+ timeout=15
199
+ )
200
+ data = resp.json()
201
+ salary = data.get("predicted_salary", None)
202
+ tier = data.get("career_tier", get_career_tier(age, exp, edu))
203
+ exp_group = data.get("experience_group", get_experience_group(exp))
204
+ if salary:
205
+ explanation = (
206
+ f"Prediction from **Random Forest model** via n8n automation\n\n"
207
+ f"Career tier: **{tier}** | Experience group: **{exp_group}**"
208
+ )
209
+ return salary, tier, exp_group, explanation
210
+ except Exception:
211
+ pass
212
+ return predict_salary_formula(age, exp, edu, job_title, gender)
213
 
214
+ def run_prediction(age, exp, edu, job_title, gender):
215
+ """Main prediction function called by Gradio."""
216
+ if not job_title:
217
+ job_title = "Employee"
218
+ try:
219
+ salary, tier, exp_group, explanation = predict_salary_n8n(
220
+ float(age), float(exp), edu, job_title, gender
221
+ )
222
+ salary_str = f"${salary:,.2f}"
223
+ tier_color = {"junior": "#45FFCA", "mid": "#D09CFA", "senior": "#FF9B9B"}.get(tier, "#888")
224
+ result_html = f"""
225
+ <div style="background: white; border-radius: 12px; padding: 24px; border: 2px solid {tier_color};">
226
+ <div style="font-size: 36px; font-weight: 800; color: #2d1f4e; text-align: center;">
227
+ {salary_str}
228
+ </div>
229
+ <div style="text-align: center; margin-top: 8px;">
230
+ <span style="background: {tier_color}; color: #2d1f4e; padding: 4px 16px;
231
+ border-radius: 20px; font-weight: 700; font-size: 14px;">
232
+ {tier.upper()} TIER
233
+ </span>
234
+ <span style="margin-left: 8px; background: #f3f4f6; color: #374151;
235
+ padding: 4px 16px; border-radius: 20px; font-size: 13px;">
236
+ {exp_group}
237
+ </span>
238
+ </div>
239
+ </div>
240
+ """
241
+ return result_html, explanation
242
+ except Exception as e:
243
+ return f"<div style='color:red;'>Error: {e}</div>", ""
244
 
245
  # =========================================================
246
+ # AI DASHBOARD SYSTEM PROMPT
247
  # =========================================================
248
 
249
+ DASHBOARD_SYSTEM = """You are an AI dashboard assistant for a salary prediction analytics app.
250
+ The user asks questions about employee salary data, career tiers, and predictions.
251
+ You have access to pre-computed artifacts from a Python analysis pipeline.
252
 
253
  AVAILABLE ARTIFACTS (only reference ones that exist):
254
  {artifacts_json}
 
258
  YOUR JOB:
259
  1. Answer the user's question conversationally using the KPIs and your knowledge of the artifacts.
260
  2. At the END of your response, output a JSON block (fenced with ```json ... ```) that tells
261
+ the dashboard which artifact to display:
262
  {{"show": "figure"|"table"|"none", "scope": "python", "filename": "..."}}
263
 
 
 
 
 
264
  RULES:
265
+ - If asked about salary by gender/education/experience, show the relevant chart or table.
266
+ - If asked about career tiers, show tier distribution.
267
+ - If asked about sentiment/feedback, show vader analysis.
268
+ - If asked about salary growth or progression, show progression data.
269
+ - If asked about correlations, show the correlation heatmap.
270
+ - Keep answers concise (2-4 sentences), then the JSON block.
271
  """
272
 
273
  JSON_BLOCK_RE = re.compile(r"```json\s*(\{.*?\})\s*```", re.DOTALL)
274
  FALLBACK_JSON_RE = re.compile(r"\{[^{}]*\"show\"[^{}]*\}", re.DOTALL)
275
 
 
276
  def _parse_display_directive(text: str) -> Dict[str, str]:
277
  m = JSON_BLOCK_RE.search(text)
278
  if m:
 
288
  pass
289
  return {"show": "none"}
290
 
 
291
  def _clean_response(text: str) -> str:
 
292
  return JSON_BLOCK_RE.sub("", text).strip()
293
 
 
294
  def _n8n_call(msg: str) -> Tuple[str, Dict]:
 
295
  import requests as req
296
  try:
297
  resp = req.post(N8N_WEBHOOK_URL, json={"question": msg}, timeout=20)
 
304
  except Exception as e:
305
  return f"n8n error: {e}. Falling back to keyword matching.", None
306
 
 
307
  def ai_chat(user_msg: str, history: list):
 
308
  if not user_msg or not user_msg.strip():
309
  return history, "", None, None
310
 
311
  idx = artifacts_index()
312
  kpis = load_kpis()
313
 
 
314
  if N8N_WEBHOOK_URL:
315
  reply, directive = _n8n_call(user_msg)
316
  if directive is None:
 
321
  else:
322
  system = DASHBOARD_SYSTEM.format(
323
  artifacts_json=json.dumps(idx, indent=2),
324
+ kpis_json=json.dumps(kpis, indent=2) if kpis else "(no KPIs yet)",
325
  )
326
  msgs = [{"role": "system", "content": system}]
327
  for entry in (history or [])[-6:]:
328
  msgs.append(entry)
329
  msgs.append({"role": "user", "content": user_msg})
 
330
  try:
331
  r = llm_client.chat_completion(
332
+ model=MODEL_NAME, messages=msgs,
333
+ temperature=0.3, max_tokens=600, stream=False,
 
 
 
334
  )
335
  raw = (
336
  r["choices"][0]["message"]["content"]
 
344
  reply_fb, directive = _keyword_fallback(user_msg, idx, kpis)
345
  reply += "\n\n" + reply_fb
346
 
 
347
  chart_out = None
348
  tab_out = None
349
  show = directive.get("show", "none")
350
  fname = directive.get("filename", "")
351
  chart_name = directive.get("chart", "")
352
 
 
353
  chart_builders = {
354
+ "salary_by_tier": build_salary_by_tier_chart,
355
+ "salary_progression": build_salary_progression_chart,
356
+ "sentiment": build_sentiment_chart,
357
+ "career_distribution": build_career_distribution_chart,
358
  }
359
 
360
  if chart_name and chart_name in chart_builders:
361
  chart_out = chart_builders[chart_name]()
362
  elif show == "figure" and fname:
363
+ if "tier" in fname or "career" in fname:
364
+ chart_out = build_salary_by_tier_chart()
365
+ elif "progression" in fname or "growth" in fname:
366
+ chart_out = build_salary_progression_chart()
367
+ elif "sentiment" in fname or "vader" in fname:
368
  chart_out = build_sentiment_chart()
 
 
369
  else:
370
  chart_out = _empty_chart(f"No interactive chart for {fname}")
371
 
 
380
  {"role": "user", "content": user_msg},
381
  {"role": "assistant", "content": reply},
382
  ]
 
383
  return new_history, "", chart_out, tab_out
384
 
 
385
  def _keyword_fallback(msg: str, idx: Dict, kpis: Dict) -> Tuple[str, Dict]:
 
386
  msg_lower = msg.lower()
387
 
 
 
 
 
 
 
 
388
  kpi_text = ""
389
  if kpis:
390
+ n_emp = kpis.get("n_employees", "?")
391
+ avg_sal = kpis.get("avg_salary", "?")
392
+ kpi_text = f"Quick summary: **{n_emp}** employees with average salary **${avg_sal:,.0f}**." if isinstance(avg_sal, (int, float)) else ""
 
 
393
 
394
+ if any(w in msg_lower for w in ["tier", "senior", "junior", "mid", "career"]):
395
  return (
396
+ f"Here is the salary distribution by career tier. {kpi_text}",
397
+ {"show": "figure", "chart": "salary_by_tier"},
398
  )
399
+ if any(w in msg_lower for w in ["progression", "growth", "over time", "year", "lstm", "arima"]):
 
400
  return (
401
+ f"Here is the salary progression over time. {kpi_text}",
402
+ {"show": "figure", "chart": "salary_progression"},
403
  )
404
+ if any(w in msg_lower for w in ["sentiment", "vader", "feedback", "positive", "negative"]):
 
405
  return (
406
+ f"Here is the employee feedback sentiment analysis. {kpi_text}",
407
+ {"show": "figure", "chart": "sentiment"},
408
  )
409
+ if any(w in msg_lower for w in ["distribution", "gender", "education", "experience"]):
 
410
  return (
411
+ f"Here is the career tier distribution. {kpi_text}",
412
+ {"show": "figure", "chart": "career_distribution"},
413
  )
414
+ if any(w in msg_lower for w in ["table", "data", "employee", "list"]):
 
415
  return (
416
+ f"Here is the employee analysis data. {kpi_text}",
417
+ {"show": "table", "scope": "python", "filename": "employee_analysis_ready.csv"},
418
  )
 
419
  if any(w in msg_lower for w in ["dashboard", "overview", "summary", "kpi"]):
420
  return (
421
+ f"Dashboard overview: {kpi_text}\n\nAsk me about career tiers, salary progression, "
422
+ "sentiment analysis, or employee data.",
423
+ {"show": "figure", "chart": "career_distribution"},
424
  )
 
 
425
  return (
426
+ f"I can show you various salary analyses. {kpi_text}\n\n"
427
+ "Try asking about: **career tiers**, **salary progression**, **feedback sentiment**, "
428
+ "**salary by gender/education**, or **employee overview**.",
429
  {"show": "none"},
430
  )
431
 
 
432
  # =========================================================
433
+ # KPI CARDS
434
  # =========================================================
435
 
436
+ def load_kpis() -> Dict[str, Any]:
437
+ # Try loading from files first
438
+ for candidate in [PY_TAB_DIR / "kpis.json", PY_FIG_DIR / "kpis.json"]:
439
+ if candidate.exists():
440
+ try:
441
+ return _read_json(candidate)
442
+ except Exception:
443
+ pass
444
+
445
+ # Build KPIs from CSVs directly
446
+ kpis = {}
447
+ for csv_name in ["employee_analysis_ready.csv", BASE_DIR / "employee_analysis_ready.csv"]:
448
+ path = Path(csv_name) if isinstance(csv_name, str) else csv_name
449
+ if not path.exists():
450
+ path = BASE_DIR / "employee_analysis_ready.csv"
451
+ if path.exists():
452
+ try:
453
+ df = pd.read_csv(path)
454
+ kpis["n_employees"] = int(len(df))
455
+ if "Salary" in df.columns:
456
+ kpis["avg_salary"] = round(float(df["Salary"].mean()), 2)
457
+ kpis["max_salary"] = round(float(df["Salary"].max()), 2)
458
+ if "salary_growth" in df.columns:
459
+ kpis["avg_salary_growth"] = round(float(df["salary_growth"].mean()), 2)
460
+ except Exception:
461
+ pass
462
+ break
463
+ return kpis
464
+
465
  def render_kpi_cards() -> str:
466
  kpis = load_kpis()
467
  if not kpis:
468
  return (
469
  '<div style="background:rgba(255,255,255,.65);backdrop-filter:blur(16px);'
470
  'border-radius:20px;padding:28px;text-align:center;'
471
+ 'border:1.5px solid rgba(255,255,255,.7);">'
 
472
  '<div style="font-size:36px;margin-bottom:10px;">📊</div>'
473
+ '<div style="color:#a48de8;font-size:14px;font-weight:800;margin-bottom:6px;">No data yet</div>'
474
+ '<div style="color:#9d8fc4;font-size:12px;">Run the pipeline to populate these cards.</div>'
 
 
475
  '</div>'
476
  )
477
 
 
480
  <div style="background:rgba(255,255,255,.72);backdrop-filter:blur(16px);
481
  border-radius:20px;padding:18px 14px 16px;text-align:center;
482
  border:1.5px solid rgba(255,255,255,.8);
 
483
  border-top:3px solid {colour};">
484
+ <div style="font-size:26px;margin-bottom:7px;">{icon}</div>
485
  <div style="color:#9d8fc4;font-size:9.5px;text-transform:uppercase;
486
  letter-spacing:1.8px;margin-bottom:7px;font-weight:800;">{label}</div>
487
  <div style="color:#2d1f4e;font-size:16px;font-weight:800;">{value}</div>
488
  </div>"""
489
 
490
  kpi_config = [
491
+ ("n_employees", "👥", "Employees", "#a48de8"),
492
+ ("avg_salary", "💰", "Avg Salary", "#2ec4a0"),
493
+ ("max_salary", "🏆", "Max Salary", "#e8537a"),
494
+ ("avg_salary_growth","📈", "Avg Salary Growth","#5e8fef"),
495
  ]
496
 
497
  html = (
 
503
  if val is None:
504
  continue
505
  if isinstance(val, (int, float)) and val > 100:
506
+ val = f"${val:,.0f}" if "salary" in key.lower() else f"{val:,.0f}"
507
  html += card(icon, label, str(val), colour)
 
 
 
 
 
 
 
 
508
  html += "</div>"
509
  return html
510
 
 
511
  # =========================================================
512
+ # INTERACTIVE PLOTLY CHARTS
513
  # =========================================================
514
 
515
  CHART_PALETTE = ["#7c5cbf", "#2ec4a0", "#e8537a", "#e8a230", "#5e8fef",
516
  "#c45ea8", "#3dbacc", "#a0522d", "#6aaa3a", "#d46060"]
517
 
518
+ TIER_COLORS = {"junior": "#45FFCA", "mid": "#D09CFA", "senior": "#FF9B9B"}
519
+
520
  def _styled_layout(**kwargs) -> dict:
521
  defaults = dict(
522
  template="plotly_white",
 
534
  defaults.update(kwargs)
535
  return defaults
536
 
 
537
  def _empty_chart(title: str) -> go.Figure:
538
  fig = go.Figure()
539
  fig.update_layout(
540
  title=title, height=420, template="plotly_white",
541
  paper_bgcolor="rgba(255,255,255,0.95)",
542
+ annotations=[dict(
543
+ text="Upload your CSV files or run the pipeline first",
544
  x=0.5, y=0.5, xref="paper", yref="paper", showarrow=False,
545
  font=dict(size=14, color="rgba(124,92,191,0.5)"))],
546
  )
547
  return fig
548
 
549
+ def _find_csv(candidates: List[str]) -> pd.DataFrame | None:
550
+ """Try multiple paths to find a CSV file."""
551
+ for name in candidates:
552
+ for prefix in [BASE_DIR, PY_TAB_DIR, Path(".")]:
553
+ path = prefix / name
554
+ if path.exists():
555
+ try:
556
+ return pd.read_csv(path)
557
+ except Exception:
558
+ pass
559
+ return None
560
+
561
+ def build_salary_by_tier_chart() -> go.Figure:
562
+ df = _find_csv(["employee_analysis_ready.csv"])
563
+ if df is None or "career_tier" not in df.columns or "Salary" not in df.columns:
564
+ return _empty_chart("Salary by Career Tier — upload employee_analysis_ready.csv")
565
+
566
+ tier_stats = df.groupby("career_tier")["Salary"].agg(["mean", "median", "std"]).reset_index()
567
+ tier_order = ["junior", "mid", "senior"]
568
+ tier_stats["career_tier"] = pd.Categorical(tier_stats["career_tier"], categories=tier_order, ordered=True)
569
+ tier_stats = tier_stats.sort_values("career_tier")
570
+
571
+ colors = [TIER_COLORS.get(t, "#888") for t in tier_stats["career_tier"]]
572
 
 
 
 
 
 
 
 
 
 
 
573
  fig = go.Figure()
574
+ fig.add_trace(go.Bar(
575
+ x=tier_stats["career_tier"],
576
+ y=tier_stats["mean"],
577
+ name="Avg Salary",
578
+ marker_color=colors,
579
+ error_y=dict(type="data", array=tier_stats["std"], visible=True),
580
+ hovertemplate="<b>%{x}</b><br>Avg: $%{y:,.0f}<extra></extra>",
581
+ text=[f"${v:,.0f}" for v in tier_stats["mean"]],
582
+ textposition="outside",
583
+ ))
584
+ fig.add_trace(go.Scatter(
585
+ x=tier_stats["career_tier"],
586
+ y=tier_stats["median"],
587
+ name="Median Salary",
588
+ mode="markers",
589
+ marker=dict(color="#2d1f4e", size=10, symbol="diamond"),
590
+ hovertemplate="<b>%{x}</b><br>Median: $%{y:,.0f}<extra></extra>",
591
+ ))
592
+ fig.update_layout(**_styled_layout(
593
+ height=450,
594
+ title=dict(text="Average Salary by Career Tier"),
595
+ yaxis_title="Salary ($)",
596
+ xaxis_title="Career Tier",
597
+ ))
598
  return fig
599
 
600
+ def build_salary_progression_chart() -> go.Figure:
601
+ df = _find_csv(["synthetic_salary_progression.csv"])
602
+ if df is None or "year" not in df.columns or "salary_that_year" not in df.columns:
603
+ return _empty_chart("Salary Progression — upload synthetic_salary_progression.csv")
604
 
 
 
 
 
 
 
 
 
 
 
605
  fig = go.Figure()
606
+ if "career_tier" in df.columns:
607
+ for i, tier in enumerate(["junior", "mid", "senior"]):
608
+ tier_df = df[df["career_tier"] == tier]
609
+ if tier_df.empty:
610
+ continue
611
+ avg_by_year = tier_df.groupby("year")["salary_that_year"].mean().reset_index()
612
+ fig.add_trace(go.Scatter(
613
+ x=avg_by_year["year"],
614
+ y=avg_by_year["salary_that_year"],
615
+ name=tier.capitalize(),
616
+ mode="lines+markers",
617
+ line=dict(color=TIER_COLORS.get(tier, CHART_PALETTE[i]), width=3),
618
+ marker=dict(size=6),
619
+ hovertemplate=f"<b>{tier.capitalize()}</b><br>Year: %{{x}}<br>Avg: $%{{y:,.0f}}<extra></extra>",
620
+ ))
621
+ else:
622
+ avg_by_year = df.groupby("year")["salary_that_year"].mean().reset_index()
623
+ fig.add_trace(go.Scatter(
624
+ x=avg_by_year["year"], y=avg_by_year["salary_that_year"],
625
+ name="Avg Salary", mode="lines+markers",
626
+ line=dict(color="#7c5cbf", width=3),
627
  ))
628
+
629
  fig.update_layout(**_styled_layout(
630
+ height=450,
631
+ title=dict(text="Salary Progression 2020–2024 by Career Tier"),
632
+ yaxis_title="Average Salary ($)",
633
+ xaxis_title="Year",
634
  ))
 
 
635
  return fig
636
 
637
+ def build_sentiment_chart() -> go.Figure:
638
+ df = _find_csv(["synthetic_employee_feedback.csv"])
639
+ if df is None or "feedback_comment" not in df.columns:
640
+ return _empty_chart("Sentiment Analysis — upload synthetic_employee_feedback.csv")
641
 
642
+ try:
643
+ from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
644
+ analyzer = SentimentIntensityAnalyzer()
645
+ df["vader_score"] = df["feedback_comment"].apply(
646
+ lambda x: analyzer.polarity_scores(str(x))["compound"]
647
+ )
648
+ df["sentiment"] = df["vader_score"].apply(
649
+ lambda s: "positive" if s >= 0.05 else "negative" if s <= -0.05 else "neutral"
650
+ )
651
+ except ImportError:
652
+ if "vader_score" not in df.columns:
653
+ return _empty_chart("Install vaderSentiment to see this chart")
654
+
655
+ if "career_tier" in df.columns:
656
+ sent_by_tier = df.groupby(["career_tier", "sentiment"]).size().unstack(fill_value=0).reset_index()
657
+ tier_order = ["junior", "mid", "senior"]
658
+ sent_by_tier["career_tier"] = pd.Categorical(sent_by_tier["career_tier"], categories=tier_order, ordered=True)
659
+ sent_by_tier = sent_by_tier.sort_values("career_tier")
660
+
661
+ colors_map = {"positive": "#45FFCA", "neutral": "#D3D1C7", "negative": "#FF9B9B"}
662
+ fig = go.Figure()
663
+ for sent in ["negative", "neutral", "positive"]:
664
+ if sent in sent_by_tier.columns:
665
+ fig.add_trace(go.Bar(
666
+ name=sent.capitalize(),
667
+ x=sent_by_tier["career_tier"],
668
+ y=sent_by_tier[sent],
669
+ marker_color=colors_map.get(sent, "#888"),
670
+ hovertemplate=f"<b>{sent.capitalize()}</b>: %{{y}}<extra></extra>",
671
+ ))
672
+ fig.update_layout(**_styled_layout(
673
+ height=450, barmode="stack",
674
+ title=dict(text="Feedback Sentiment Distribution by Career Tier"),
675
+ yaxis_title="Number of Comments",
676
+ xaxis_title="Career Tier",
677
+ ))
678
+ else:
679
+ sent_counts = df["sentiment"].value_counts()
680
+ fig = go.Figure(go.Bar(
681
+ x=sent_counts.index, y=sent_counts.values,
682
+ marker_color=["#45FFCA", "#D3D1C7", "#FF9B9B"][:len(sent_counts)],
683
+ ))
684
+ fig.update_layout(**_styled_layout(height=400, title=dict(text="Overall Feedback Sentiment")))
685
+
686
+ return fig
687
+
688
+ def build_career_distribution_chart() -> go.Figure:
689
+ df = _find_csv(["employee_analysis_ready.csv"])
690
+ if df is None or "career_tier" not in df.columns:
691
+ return _empty_chart("Career Distribution — upload employee_analysis_ready.csv")
692
+
693
+ fig = go.Figure()
694
+
695
+ # Pie chart of career tier distribution
696
+ tier_counts = df["career_tier"].value_counts()
697
+ colors = [TIER_COLORS.get(t, "#888") for t in tier_counts.index]
698
+
699
+ fig.add_trace(go.Pie(
700
+ labels=[t.capitalize() for t in tier_counts.index],
701
+ values=tier_counts.values,
702
+ marker=dict(colors=colors),
703
+ hole=0.4,
704
+ hovertemplate="<b>%{label}</b><br>Count: %{value}<br>Share: %{percent}<extra></extra>",
705
  ))
706
  fig.update_layout(**_styled_layout(
707
+ height=420,
708
+ title=dict(text="Career Tier Distribution"),
709
+ showlegend=True,
710
  ))
 
 
711
  return fig
712
 
713
+ def _load_table_safe(path: Path) -> pd.DataFrame:
714
+ try:
715
+ if path.suffix == ".json":
716
+ obj = _read_json(path)
717
+ if isinstance(obj, dict):
718
+ return pd.DataFrame([obj])
719
+ return pd.DataFrame(obj)
720
+ return _read_csv(path)
721
+ except Exception as e:
722
+ return pd.DataFrame([{"error": str(e)}])
723
 
724
  def refresh_dashboard():
725
+ return (
726
+ render_kpi_cards(),
727
+ build_salary_by_tier_chart(),
728
+ build_salary_progression_chart(),
729
+ build_sentiment_chart(),
730
+ build_career_distribution_chart(),
731
+ )
732
+
733
+ def refresh_gallery():
734
+ figures = []
735
+ for p in sorted(PY_FIG_DIR.glob("*.png")):
736
+ figures.append((str(p), p.stem.replace("_", " ").title()))
737
+
738
+ idx = artifacts_index()
739
+ table_choices = list(idx["python"]["tables"])
740
+ default_df = pd.DataFrame()
741
+ if table_choices:
742
+ default_df = _load_table_safe(PY_TAB_DIR / table_choices[0])
743
+
744
+ return (
745
+ figures if figures else [],
746
+ gr.update(choices=table_choices, value=table_choices[0] if table_choices else None),
747
+ default_df,
748
+ )
749
 
750
+ def on_table_select(choice: str):
751
+ if not choice:
752
+ return pd.DataFrame([{"hint": "Select a table above."}])
753
+ path = PY_TAB_DIR / choice
754
+ if not path.exists():
755
+ return pd.DataFrame([{"error": f"File not found: {choice}"}])
756
+ return _load_table_safe(path)
757
 
758
  # =========================================================
759
  # UI
 
765
  css_path = BASE_DIR / "style.css"
766
  return css_path.read_text(encoding="utf-8") if css_path.exists() else ""
767
 
768
+ with gr.Blocks(title="F2 Salary Predictor") as demo:
 
769
 
770
  gr.Markdown(
771
+ "# F2 Salary Predictor\n"
772
+ "*AI-powered salary prediction and employee analytics ESCP Big Data Project*",
773
  elem_id="escp_title",
774
  )
775
 
776
  # ===========================================================
777
+ # TAB 1 -- Salary Predictor
778
+ # ===========================================================
779
+ with gr.Tab("Salary Predictor"):
780
+ gr.Markdown("### Predict your salary based on your profile")
781
+ gr.Markdown(
782
+ "Enter your details below. The model uses age, experience, education, "
783
+ "and job title to estimate your expected salary using our trained Random Forest model."
784
+ )
785
+
786
+ with gr.Row():
787
+ with gr.Column(scale=1):
788
+ age_input = gr.Slider(
789
+ minimum=18, maximum=70, value=30, step=1,
790
+ label="Age"
791
+ )
792
+ exp_input = gr.Slider(
793
+ minimum=0, maximum=40, value=5, step=1,
794
+ label="Years of Experience"
795
+ )
796
+ edu_input = gr.Dropdown(
797
+ choices=["Bachelor's", "Master's", "PhD"],
798
+ value="Bachelor's",
799
+ label="Education Level"
800
+ )
801
+ job_input = gr.Textbox(
802
+ label="Job Title",
803
+ placeholder="e.g. Data Scientist, Software Engineer, Manager...",
804
+ value="Data Analyst"
805
+ )
806
+ gender_input = gr.Radio(
807
+ choices=["Male", "Female"],
808
+ value="Male",
809
+ label="Gender"
810
+ )
811
+ predict_btn = gr.Button("Predict Salary", variant="primary")
812
+
813
+ with gr.Column(scale=1):
814
+ result_html = gr.HTML(
815
+ value='<div style="background:#f9fafb;border-radius:12px;padding:40px;'
816
+ 'text-align:center;color:#9ca3af;font-size:14px;">'
817
+ 'Fill in your profile and click Predict Salary</div>'
818
+ )
819
+ explanation_box = gr.Markdown("")
820
+
821
+ predict_btn.click(
822
+ run_prediction,
823
+ inputs=[age_input, exp_input, edu_input, job_input, gender_input],
824
+ outputs=[result_html, explanation_box],
825
+ )
826
+
827
+ gr.Markdown("---")
828
+ gr.Markdown("#### Try these example profiles")
829
+ gr.Examples(
830
+ examples=[
831
+ [28, 3, "Bachelor's", "Junior Data Analyst", "Female"],
832
+ [35, 10, "Master's", "Senior Data Scientist", "Male"],
833
+ [45, 20, "PhD", "Director of Analytics", "Female"],
834
+ [22, 1, "Bachelor's", "Intern", "Male"],
835
+ [50, 25, "Master's", "VP of Engineering", "Male"],
836
+ ],
837
+ inputs=[age_input, exp_input, edu_input, job_input, gender_input],
838
+ label="Example Profiles",
839
+ )
840
+
841
+ # ===========================================================
842
+ # TAB 2 -- Pipeline Runner
843
  # ===========================================================
844
  with gr.Tab("Pipeline Runner"):
845
+ gr.Markdown("### Run the data pipeline")
846
 
847
  with gr.Row():
848
  with gr.Column(scale=1):
849
+ btn_nb1 = gr.Button("Step 1: Data Creation & Synthetic Enrichment", variant="secondary")
850
  with gr.Column(scale=1):
851
+ btn_nb2 = gr.Button("Step 2: Salary Analysis & Predictions", variant="secondary")
852
 
853
  with gr.Row():
854
  btn_all = gr.Button("Run Full Pipeline (Both Steps)", variant="primary")
 
865
  btn_all.click(run_full_pipeline, outputs=[run_log])
866
 
867
  # ===========================================================
868
+ # TAB 3 -- Dashboard
869
  # ===========================================================
870
  with gr.Tab("Dashboard"):
871
  kpi_html = gr.HTML(value=render_kpi_cards)
872
 
873
  refresh_btn = gr.Button("Refresh Dashboard", variant="primary")
874
 
875
+ gr.Markdown("#### Salary Analytics")
876
+ with gr.Row():
877
+ chart_tier = gr.Plot(label="Salary by Career Tier")
878
+ chart_dist = gr.Plot(label="Career Tier Distribution")
879
 
880
+ gr.Markdown("#### Temporal & Sentiment Analysis")
881
+ with gr.Row():
882
+ chart_prog = gr.Plot(label="Salary Progression 2020–2024")
883
+ chart_sent = gr.Plot(label="Feedback Sentiment by Tier")
 
 
 
884
 
885
  gr.Markdown("#### Data Tables")
886
  table_dropdown = gr.Dropdown(
 
888
  choices=[],
889
  interactive=True,
890
  )
891
+ table_display = gr.Dataframe(label="Table Preview", interactive=False)
 
 
 
892
 
893
  def _on_refresh():
894
+ kpi, c1, c2, c3, c4 = refresh_dashboard()
895
  figs, dd, df = refresh_gallery()
896
+ return kpi, c1, c2, c3, c4, figs, dd, df
897
+
898
+ gallery = gr.Gallery(
899
+ label="Generated Figures from Notebooks",
900
+ columns=2, height=480, object_fit="contain",
901
+ )
902
 
903
  refresh_btn.click(
904
  _on_refresh,
905
+ outputs=[kpi_html, chart_tier, chart_prog, chart_sent, chart_dist,
906
  gallery, table_dropdown, table_display],
907
  )
908
  table_dropdown.change(
 
912
  )
913
 
914
  # ===========================================================
915
+ # TAB 4 -- AI Dashboard
916
  # ===========================================================
917
  with gr.Tab('"AI" Dashboard'):
918
  _ai_status = (
919
  "Connected to your **n8n workflow**." if N8N_WEBHOOK_URL
920
  else "**LLM active.**" if LLM_ENABLED
921
+ else "Using **keyword matching**. Set `N8N_WEBHOOK_URL` in Space secrets to connect your n8n automation."
 
 
922
  )
923
  gr.Markdown(
924
  "### Ask questions, get interactive visualisations\n\n"
925
+ f"Type a question and the system will pick the right chart or table. {_ai_status}"
926
  )
927
 
928
  with gr.Row(equal_height=True):
929
  with gr.Column(scale=1):
930
+ chatbot = gr.Chatbot(label="Conversation", height=380)
 
 
 
931
  user_input = gr.Textbox(
932
  label="Ask about your data",
933
+ placeholder="e.g. Show me salary by career tier / What is the sentiment analysis?",
934
  lines=1,
935
  )
936
  gr.Examples(
937
  examples=[
938
+ "Show me salary by career tier",
939
+ "What does the sentiment analysis show?",
940
+ "Show me salary progression over time",
941
+ "What is the career tier distribution?",
 
942
  "Give me a dashboard overview",
943
+ "Show me the employee data table",
944
  ],
945
  inputs=user_input,
946
  )
947
 
948
  with gr.Column(scale=1):
949
+ ai_figure = gr.Plot(label="Interactive Chart")
950
+ ai_table = gr.Dataframe(label="Data Table", interactive=False)
 
 
 
 
 
951
 
952
  user_input.submit(
953
  ai_chat,
 
956
  )
957
 
958
 
959
+ demo.launch(css=load_css(), allowed_paths=[str(BASE_DIR)])