Danielescp commited on
Commit
893e047
Β·
verified Β·
1 Parent(s): 7a28183

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +234 -735
app.py CHANGED
@@ -1,758 +1,257 @@
1
- import os
2
- import re
3
- import json
4
- import time
5
- import traceback
6
- from pathlib import Path
7
- from typing import Dict, Any, List, Tuple
 
 
8
 
9
- import pandas as pd
10
  import gradio as gr
11
- import papermill as pm
12
- import plotly.graph_objects as go
13
-
14
- # Optional LLM (HuggingFace Inference API)
15
- try:
16
- from huggingface_hub import InferenceClient
17
- except Exception:
18
- InferenceClient = None
19
-
20
- # =========================================================
21
- # CONFIG
22
- # =========================================================
23
-
24
- BASE_DIR = Path(__file__).resolve().parent
25
-
26
- NB1 = os.environ.get("NB1", "datacreation.ipynb").strip()
27
- NB2 = os.environ.get("NB2", "pythonanalysis.ipynb").strip()
28
-
29
- RUNS_DIR = BASE_DIR / "runs"
30
- ART_DIR = BASE_DIR / "artifacts"
31
- PY_FIG_DIR = ART_DIR / "py" / "figures"
32
- PY_TAB_DIR = ART_DIR / "py" / "tables"
33
-
34
- PAPERMILL_TIMEOUT = int(os.environ.get("PAPERMILL_TIMEOUT", "1800"))
35
- MAX_PREVIEW_ROWS = int(os.environ.get("MAX_FILE_PREVIEW_ROWS", "50"))
36
- MAX_LOG_CHARS = int(os.environ.get("MAX_LOG_CHARS", "8000"))
37
-
38
- HF_API_KEY = os.environ.get("HF_API_KEY", "").strip()
39
- MODEL_NAME = os.environ.get("MODEL_NAME", "deepseek-ai/DeepSeek-R1").strip()
40
- HF_PROVIDER = os.environ.get("HF_PROVIDER", "novita").strip()
41
- N8N_WEBHOOK_URL = os.environ.get("N8N_WEBHOOK_URL", "").strip()
42
-
43
- LLM_ENABLED = bool(HF_API_KEY) and InferenceClient is not None
44
- llm_client = (
45
- InferenceClient(provider=HF_PROVIDER, api_key=HF_API_KEY)
46
- if LLM_ENABLED
47
- else None
48
- )
49
-
50
- # =========================================================
51
- # HELPERS
52
- # =========================================================
53
-
54
- def ensure_dirs():
55
- for p in [RUNS_DIR, ART_DIR, PY_FIG_DIR, PY_TAB_DIR]:
56
- p.mkdir(parents=True, exist_ok=True)
57
-
58
- def stamp():
59
- return time.strftime("%Y%m%d-%H%M%S")
60
-
61
- def tail(text: str, n: int = MAX_LOG_CHARS) -> str:
62
- return (text or "")[-n:]
63
-
64
- def _ls(dir_path: Path, exts: Tuple[str, ...]) -> List[str]:
65
- if not dir_path.is_dir():
66
- return []
67
- return sorted(p.name for p in dir_path.iterdir() if p.is_file() and p.suffix.lower() in exts)
68
-
69
- def _read_csv(path: Path) -> pd.DataFrame:
70
- return pd.read_csv(path, nrows=MAX_PREVIEW_ROWS)
71
-
72
- def _read_json(path: Path):
73
- with path.open(encoding="utf-8") as f:
74
- return json.load(f)
75
-
76
- def artifacts_index() -> Dict[str, Any]:
77
- return {
78
- "python": {
79
- "figures": _ls(PY_FIG_DIR, (".png", ".jpg", ".jpeg")),
80
- "tables": _ls(PY_TAB_DIR, (".csv", ".json")),
81
- },
82
- }
83
-
84
- # =========================================================
85
- # PIPELINE RUNNERS
86
- # =========================================================
87
-
88
- def run_notebook(nb_name: str) -> str:
89
- ensure_dirs()
90
- nb_in = BASE_DIR / nb_name
91
- if not nb_in.exists():
92
- return f"ERROR: {nb_name} not found."
93
- nb_out = RUNS_DIR / f"run_{stamp()}_{nb_name}"
94
- pm.execute_notebook(
95
- input_path=str(nb_in),
96
- output_path=str(nb_out),
97
- cwd=str(BASE_DIR),
98
- log_output=True,
99
- progress_bar=False,
100
- request_save_on_cell_execute=True,
101
- execution_timeout=PAPERMILL_TIMEOUT,
102
- )
103
- return f"Executed {nb_name}"
104
-
105
-
106
- def run_datacreation() -> str:
107
- try:
108
- log = run_notebook(NB1)
109
- csvs = [f.name for f in BASE_DIR.glob("*.csv")]
110
- return f"OK {log}\n\nCSVs now in /app:\n" + "\n".join(f" - {c}" for c in sorted(csvs))
111
- except Exception as e:
112
- return f"FAILED {e}\n\n{traceback.format_exc()[-2000:]}"
113
-
114
-
115
- def run_pythonanalysis() -> str:
116
- try:
117
- log = run_notebook(NB2)
118
- idx = artifacts_index()
119
- figs = idx["python"]["figures"]
120
- tabs = idx["python"]["tables"]
121
- return (
122
- f"OK {log}\n\n"
123
- f"Figures: {', '.join(figs) or '(none)'}\n"
124
- f"Tables: {', '.join(tabs) or '(none)'}"
125
- )
126
- except Exception as e:
127
- return f"FAILED {e}\n\n{traceback.format_exc()[-2000:]}"
128
-
129
-
130
- def run_full_pipeline() -> str:
131
- logs = []
132
- logs.append("=" * 50)
133
- logs.append("STEP 1/2: Data Creation (web scraping + synthetic data)")
134
- logs.append("=" * 50)
135
- logs.append(run_datacreation())
136
- logs.append("")
137
- logs.append("=" * 50)
138
- logs.append("STEP 2/2: Python Analysis (sentiment, ARIMA, dashboard)")
139
- logs.append("=" * 50)
140
- logs.append(run_pythonanalysis())
141
- return "\n".join(logs)
142
-
143
-
144
- # =========================================================
145
- # GALLERY LOADERS
146
- # =========================================================
147
-
148
- def _load_all_figures() -> List[Tuple[str, str]]:
149
- """Return list of (filepath, caption) for Gallery."""
150
- items = []
151
- for p in sorted(PY_FIG_DIR.glob("*.png")):
152
- items.append((str(p), p.stem.replace('_', ' ').title()))
153
- return items
154
-
155
-
156
- def _load_table_safe(path: Path) -> pd.DataFrame:
157
- try:
158
- if path.suffix == ".json":
159
- obj = _read_json(path)
160
- if isinstance(obj, dict):
161
- return pd.DataFrame([obj])
162
- return pd.DataFrame(obj)
163
- return _read_csv(path)
164
- except Exception as e:
165
- return pd.DataFrame([{"error": str(e)}])
166
-
167
-
168
- def refresh_gallery():
169
- """Called when user clicks Refresh on Gallery tab."""
170
- figures = _load_all_figures()
171
- idx = artifacts_index()
172
-
173
- table_choices = list(idx["python"]["tables"])
174
-
175
- default_df = pd.DataFrame()
176
- if table_choices:
177
- default_df = _load_table_safe(PY_TAB_DIR / table_choices[0])
178
-
179
- return (
180
- figures if figures else [],
181
- gr.update(choices=table_choices, value=table_choices[0] if table_choices else None),
182
- default_df,
183
  )
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
- def on_table_select(choice: str):
187
- if not choice:
188
- return pd.DataFrame([{"hint": "Select a table above."}])
189
- path = PY_TAB_DIR / choice
190
- if not path.exists():
191
- return pd.DataFrame([{"error": f"File not found: {choice}"}])
192
- return _load_table_safe(path)
193
-
194
-
195
- # =========================================================
196
- # KPI LOADER
197
- # =========================================================
198
-
199
- def load_kpis() -> Dict[str, Any]:
200
- for candidate in [PY_TAB_DIR / "kpis.json", PY_FIG_DIR / "kpis.json"]:
201
- if candidate.exists():
202
- try:
203
- return _read_json(candidate)
204
- except Exception:
205
- pass
206
- return {}
207
-
208
-
209
- # =========================================================
210
- # AI DASHBOARD -- LLM picks what to display
211
- # =========================================================
212
-
213
- DASHBOARD_SYSTEM = """You are an AI dashboard assistant for a book-sales analytics app.
214
- The user asks questions or requests about their data. You have access to pre-computed
215
- artifacts from a Python analysis pipeline.
216
 
217
- AVAILABLE ARTIFACTS (only reference ones that exist):
218
- {artifacts_json}
219
 
220
- KPI SUMMARY: {kpis_json}
221
 
222
- YOUR JOB:
223
- 1. Answer the user's question conversationally using the KPIs and your knowledge of the artifacts.
224
- 2. At the END of your response, output a JSON block (fenced with ```json ... ```) that tells
225
- the dashboard which artifact to display. The JSON must have this shape:
226
- {{"show": "figure"|"table"|"none", "scope": "python", "filename": "..."}}
227
 
228
- - Use "show": "figure" to display a chart image.
229
- - Use "show": "table" to display a CSV/JSON table.
230
- - Use "show": "none" if no artifact is relevant.
231
 
232
- RULES:
233
- - If the user asks about sales trends or forecasting by title, show sales_trends or arima figures.
234
- - If the user asks about sentiment, show sentiment figure or sentiment_counts table.
235
- - If the user asks about forecast accuracy or ARIMA, show arima figures.
236
- - If the user asks about top sellers, show top_titles_by_units_sold.csv.
237
- - If the user asks a general data question, pick the most relevant artifact.
238
- - Keep your answer concise (2-4 sentences), then the JSON block.
239
  """
240
 
241
- JSON_BLOCK_RE = re.compile(r"```json\s*(\{.*?\})\s*```", re.DOTALL)
242
- FALLBACK_JSON_RE = re.compile(r"\{[^{}]*\"show\"[^{}]*\}", re.DOTALL)
243
-
244
-
245
- def _parse_display_directive(text: str) -> Dict[str, str]:
246
- m = JSON_BLOCK_RE.search(text)
247
- if m:
248
- try:
249
- return json.loads(m.group(1))
250
- except json.JSONDecodeError:
251
- pass
252
- m = FALLBACK_JSON_RE.search(text)
253
- if m:
254
- try:
255
- return json.loads(m.group(0))
256
- except json.JSONDecodeError:
257
- pass
258
- return {"show": "none"}
259
-
260
-
261
- def _clean_response(text: str) -> str:
262
- """Strip the JSON directive block from the displayed response."""
263
- return JSON_BLOCK_RE.sub("", text).strip()
264
-
265
-
266
- def _n8n_call(msg: str) -> Tuple[str, Dict]:
267
- """Call the student's n8n webhook and return (reply, directive)."""
268
- import requests as req
269
- try:
270
- resp = req.post(N8N_WEBHOOK_URL, json={"question": msg}, timeout=20)
271
- data = resp.json()
272
- answer = data.get("answer", "No response from n8n workflow.")
273
- chart = data.get("chart", "none")
274
- if chart and chart != "none":
275
- return answer, {"show": "figure", "chart": chart}
276
- return answer, {"show": "none"}
277
- except Exception as e:
278
- return f"n8n error: {e}. Falling back to keyword matching.", None
279
-
280
-
281
- def ai_chat(user_msg: str, history: list):
282
- """Chat function for the AI Dashboard tab."""
283
- if not user_msg or not user_msg.strip():
284
- return history, "", None, None
285
-
286
- idx = artifacts_index()
287
- kpis = load_kpis()
288
-
289
- # Priority: n8n webhook > HF LLM > keyword fallback
290
- if N8N_WEBHOOK_URL:
291
- reply, directive = _n8n_call(user_msg)
292
- if directive is None:
293
- reply_fb, directive = _keyword_fallback(user_msg, idx, kpis)
294
- reply += "\n\n" + reply_fb
295
- elif not LLM_ENABLED:
296
- reply, directive = _keyword_fallback(user_msg, idx, kpis)
297
- else:
298
- system = DASHBOARD_SYSTEM.format(
299
- artifacts_json=json.dumps(idx, indent=2),
300
- kpis_json=json.dumps(kpis, indent=2) if kpis else "(no KPIs yet, run the pipeline first)",
301
- )
302
- msgs = [{"role": "system", "content": system}]
303
- for entry in (history or [])[-6:]:
304
- msgs.append(entry)
305
- msgs.append({"role": "user", "content": user_msg})
306
-
307
- try:
308
- r = llm_client.chat_completion(
309
- model=MODEL_NAME,
310
- messages=msgs,
311
- temperature=0.3,
312
- max_tokens=600,
313
- stream=False,
314
  )
315
- raw = (
316
- r["choices"][0]["message"]["content"]
317
- if isinstance(r, dict)
318
- else r.choices[0].message.content
 
 
 
 
 
 
 
 
 
 
 
 
319
  )
320
- directive = _parse_display_directive(raw)
321
- reply = _clean_response(raw)
322
- except Exception as e:
323
- reply = f"LLM error: {e}. Falling back to keyword matching."
324
- reply_fb, directive = _keyword_fallback(user_msg, idx, kpis)
325
- reply += "\n\n" + reply_fb
326
-
327
- # Resolve artifacts β€” build interactive Plotly charts when possible
328
- chart_out = None
329
- tab_out = None
330
- show = directive.get("show", "none")
331
- fname = directive.get("filename", "")
332
- chart_name = directive.get("chart", "")
333
-
334
- # Interactive chart builders keyed by name
335
- chart_builders = {
336
- "sales": build_sales_chart,
337
- "sentiment": build_sentiment_chart,
338
- "top_sellers": build_top_sellers_chart,
339
- }
340
-
341
- if chart_name and chart_name in chart_builders:
342
- chart_out = chart_builders[chart_name]()
343
- elif show == "figure" and fname:
344
- # Fallback: try to match filename to a chart builder
345
- if "sales_trend" in fname:
346
- chart_out = build_sales_chart()
347
- elif "sentiment" in fname:
348
- chart_out = build_sentiment_chart()
349
- elif "arima" in fname or "forecast" in fname:
350
- chart_out = build_sales_chart() # closest interactive equivalent
351
- else:
352
- chart_out = _empty_chart(f"No interactive chart for {fname}")
353
-
354
- if show == "table" and fname:
355
- fp = PY_TAB_DIR / fname
356
- if fp.exists():
357
- tab_out = _load_table_safe(fp)
358
- else:
359
- reply += f"\n\n*(Could not find table: {fname})*"
360
-
361
- new_history = (history or []) + [
362
- {"role": "user", "content": user_msg},
363
- {"role": "assistant", "content": reply},
364
- ]
365
-
366
- return new_history, "", chart_out, tab_out
367
-
368
-
369
- def _keyword_fallback(msg: str, idx: Dict, kpis: Dict) -> Tuple[str, Dict]:
370
- """Simple keyword matcher when LLM is unavailable."""
371
- msg_lower = msg.lower()
372
-
373
- if not idx["python"]["figures"] and not idx["python"]["tables"]:
374
- return (
375
- "No artifacts found yet. Please run the pipeline first (Tab 1), "
376
- "then come back here to explore the results.",
377
- {"show": "none"},
378
- )
379
-
380
- kpi_text = ""
381
- if kpis:
382
- total = kpis.get("total_units_sold", 0)
383
- kpi_text = (
384
- f"Quick summary: **{kpis.get('n_titles', '?')}** book titles across "
385
- f"**{kpis.get('n_months', '?')}** months, with **{total:,.0f}** total units sold."
386
- )
387
-
388
- if any(w in msg_lower for w in ["trend", "sales trend", "monthly sale"]):
389
- return (
390
- f"Here are the sales trends. {kpi_text}",
391
- {"show": "figure", "chart": "sales"},
392
- )
393
-
394
- if any(w in msg_lower for w in ["sentiment", "review", "positive", "negative"]):
395
- return (
396
- f"Here is the sentiment distribution across sampled book titles. {kpi_text}",
397
- {"show": "figure", "chart": "sentiment"},
398
- )
399
-
400
- if any(w in msg_lower for w in ["arima", "forecast", "predict"]):
401
- return (
402
- f"Here are the sales trends and forecasts. {kpi_text}",
403
- {"show": "figure", "chart": "sales"},
404
- )
405
-
406
- if any(w in msg_lower for w in ["top", "best sell", "popular", "rank"]):
407
- return (
408
- f"Here are the top-selling titles by units sold. {kpi_text}",
409
- {"show": "table", "scope": "python", "filename": "top_titles_by_units_sold.csv"},
410
- )
411
-
412
- if any(w in msg_lower for w in ["price", "pricing", "decision"]):
413
- return (
414
- f"Here are the pricing decisions. {kpi_text}",
415
- {"show": "table", "scope": "python", "filename": "pricing_decisions.csv"},
416
- )
417
 
418
- if any(w in msg_lower for w in ["dashboard", "overview", "summary", "kpi"]):
419
- return (
420
- f"Dashboard overview: {kpi_text}\n\nAsk me about sales trends, sentiment, forecasts, "
421
- "pricing, or top sellers to see specific visualizations.",
422
- {"show": "table", "scope": "python", "filename": "df_dashboard.csv"},
423
- )
424
 
425
- # Default
426
- return (
427
- f"I can show you various analyses. {kpi_text}\n\n"
428
- "Try asking about: **sales trends**, **sentiment**, **ARIMA forecasts**, "
429
- "**pricing decisions**, **top sellers**, or **dashboard overview**.",
430
- {"show": "none"},
431
  )
432
 
433
-
434
- # =========================================================
435
- # KPI CARDS (BubbleBusters style)
436
- # =========================================================
437
-
438
- def render_kpi_cards() -> str:
439
- kpis = load_kpis()
440
- if not kpis:
441
- return (
442
- '<div style="background:rgba(255,255,255,.65);backdrop-filter:blur(16px);'
443
- 'border-radius:20px;padding:28px;text-align:center;'
444
- 'border:1.5px solid rgba(255,255,255,.7);'
445
- 'box-shadow:0 8px 32px rgba(124,92,191,.08);">'
446
- '<div style="font-size:36px;margin-bottom:10px;">πŸ“Š</div>'
447
- '<div style="color:#a48de8;font-size:14px;'
448
- 'font-weight:800;margin-bottom:6px;">No data yet</div>'
449
- '<div style="color:#9d8fc4;font-size:12px;">'
450
- 'Run the pipeline to populate these cards.</div>'
451
- '</div>'
452
- )
453
-
454
- def card(icon, label, value, colour):
455
- return f"""
456
- <div style="background:rgba(255,255,255,.72);backdrop-filter:blur(16px);
457
- border-radius:20px;padding:18px 14px 16px;text-align:center;
458
- border:1.5px solid rgba(255,255,255,.8);
459
- box-shadow:0 4px 16px rgba(124,92,191,.08);
460
- border-top:3px solid {colour};">
461
- <div style="font-size:26px;margin-bottom:7px;line-height:1;">{icon}</div>
462
- <div style="color:#9d8fc4;font-size:9.5px;text-transform:uppercase;
463
- letter-spacing:1.8px;margin-bottom:7px;font-weight:800;">{label}</div>
464
- <div style="color:#2d1f4e;font-size:16px;font-weight:800;">{value}</div>
465
- </div>"""
466
-
467
- kpi_config = [
468
- ("n_titles", "πŸ“š", "Book Titles", "#a48de8"),
469
- ("n_months", "πŸ“…", "Time Periods", "#7aa6f8"),
470
- ("total_units_sold", "πŸ“¦", "Units Sold", "#6ee7c7"),
471
- ("total_revenue", "πŸ’°", "Revenue", "#3dcba8"),
472
- ]
473
-
474
- html = (
475
- '<div style="display:grid;grid-template-columns:repeat(auto-fit,minmax(140px,1fr));'
476
- 'gap:12px;margin-bottom:24px;">'
477
- )
478
- for key, icon, label, colour in kpi_config:
479
- val = kpis.get(key)
480
- if val is None:
481
- continue
482
- if isinstance(val, (int, float)) and val > 100:
483
- val = f"{val:,.0f}"
484
- html += card(icon, label, str(val), colour)
485
- # Extra KPIs not in config
486
- known = {k for k, *_ in kpi_config}
487
- for key, val in kpis.items():
488
- if key not in known:
489
- label = key.replace("_", " ").title()
490
- if isinstance(val, (int, float)) and val > 100:
491
- val = f"{val:,.0f}"
492
- html += card("πŸ“ˆ", label, str(val), "#8fa8f8")
493
- html += "</div>"
494
- return html
495
-
496
-
497
- # =========================================================
498
- # INTERACTIVE PLOTLY CHARTS (BubbleBusters style)
499
- # =========================================================
500
-
501
- CHART_PALETTE = ["#7c5cbf", "#2ec4a0", "#e8537a", "#e8a230", "#5e8fef",
502
- "#c45ea8", "#3dbacc", "#a0522d", "#6aaa3a", "#d46060"]
503
-
504
- def _styled_layout(**kwargs) -> dict:
505
- defaults = dict(
506
- template="plotly_white",
507
- paper_bgcolor="rgba(255,255,255,0.95)",
508
- plot_bgcolor="rgba(255,255,255,0.98)",
509
- font=dict(family="system-ui, sans-serif", color="#2d1f4e", size=12),
510
- margin=dict(l=60, r=20, t=70, b=70),
511
- legend=dict(
512
- orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1,
513
- bgcolor="rgba(255,255,255,0.92)",
514
- bordercolor="rgba(124,92,191,0.35)", borderwidth=1,
515
- ),
516
- title=dict(font=dict(size=15, color="#4b2d8a")),
517
  )
518
- defaults.update(kwargs)
519
- return defaults
520
-
521
-
522
- def _empty_chart(title: str) -> go.Figure:
523
- fig = go.Figure()
524
- fig.update_layout(
525
- title=title, height=420, template="plotly_white",
526
- paper_bgcolor="rgba(255,255,255,0.95)",
527
- annotations=[dict(text="Run the pipeline to generate data",
528
- x=0.5, y=0.5, xref="paper", yref="paper", showarrow=False,
529
- font=dict(size=14, color="rgba(124,92,191,0.5)"))],
530
- )
531
- return fig
532
-
533
-
534
- def build_sales_chart() -> go.Figure:
535
- path = PY_TAB_DIR / "df_dashboard.csv"
536
- if not path.exists():
537
- return _empty_chart("Sales Trends β€” run the pipeline first")
538
- df = pd.read_csv(path)
539
- date_col = next((c for c in df.columns if "month" in c.lower() or "date" in c.lower()), None)
540
- val_cols = [c for c in df.columns if c != date_col and df[c].dtype in ("float64", "int64")]
541
- if not date_col or not val_cols:
542
- return _empty_chart("Could not auto-detect columns in df_dashboard.csv")
543
- df[date_col] = pd.to_datetime(df[date_col], errors="coerce")
544
- fig = go.Figure()
545
- for i, col in enumerate(val_cols):
546
- fig.add_trace(go.Scatter(
547
- x=df[date_col], y=df[col], name=col.replace("_", " ").title(),
548
- mode="lines+markers", line=dict(color=CHART_PALETTE[i % len(CHART_PALETTE)], width=2),
549
- marker=dict(size=4),
550
- hovertemplate=f"<b>{col.replace('_',' ').title()}</b><br>%{{x|%b %Y}}: %{{y:,.0f}}<extra></extra>",
551
- ))
552
- fig.update_layout(**_styled_layout(height=450, hovermode="x unified",
553
- title=dict(text="Monthly Overview")))
554
- fig.update_xaxes(gridcolor="rgba(124,92,191,0.15)", showgrid=True)
555
- fig.update_yaxes(gridcolor="rgba(124,92,191,0.15)", showgrid=True)
556
- return fig
557
-
558
-
559
- def build_sentiment_chart() -> go.Figure:
560
- path = PY_TAB_DIR / "sentiment_counts_sampled.csv"
561
- if not path.exists():
562
- return _empty_chart("Sentiment Distribution β€” run the pipeline first")
563
- df = pd.read_csv(path)
564
- title_col = df.columns[0]
565
- sent_cols = [c for c in ["negative", "neutral", "positive"] if c in df.columns]
566
- if not sent_cols:
567
- return _empty_chart("No sentiment columns found in CSV")
568
- colors = {"negative": "#e8537a", "neutral": "#5e8fef", "positive": "#2ec4a0"}
569
- fig = go.Figure()
570
- for col in sent_cols:
571
- fig.add_trace(go.Bar(
572
- name=col.title(), y=df[title_col], x=df[col],
573
- orientation="h", marker_color=colors.get(col, "#888"),
574
- hovertemplate=f"<b>{col.title()}</b>: %{{x}}<extra></extra>",
575
- ))
576
- fig.update_layout(**_styled_layout(
577
- height=max(400, len(df) * 28), barmode="stack",
578
- title=dict(text="Sentiment Distribution by Book"),
579
- ))
580
- fig.update_xaxes(title="Number of Reviews")
581
- fig.update_yaxes(autorange="reversed")
582
- return fig
583
-
584
-
585
- def build_top_sellers_chart() -> go.Figure:
586
- path = PY_TAB_DIR / "top_titles_by_units_sold.csv"
587
- if not path.exists():
588
- return _empty_chart("Top Sellers β€” run the pipeline first")
589
- df = pd.read_csv(path).head(15)
590
- title_col = next((c for c in df.columns if "title" in c.lower()), df.columns[0])
591
- val_col = next((c for c in df.columns if "unit" in c.lower() or "sold" in c.lower()), df.columns[-1])
592
- fig = go.Figure(go.Bar(
593
- y=df[title_col], x=df[val_col], orientation="h",
594
- marker=dict(color=df[val_col], colorscale=[[0, "#c5b4f0"], [1, "#7c5cbf"]]),
595
- hovertemplate="<b>%{y}</b><br>Units: %{x:,.0f}<extra></extra>",
596
- ))
597
- fig.update_layout(**_styled_layout(
598
- height=max(400, len(df) * 30),
599
- title=dict(text="Top Selling Titles"), showlegend=False,
600
- ))
601
- fig.update_yaxes(autorange="reversed")
602
- fig.update_xaxes(title="Total Units Sold")
603
- return fig
604
-
605
-
606
- def refresh_dashboard():
607
- return render_kpi_cards(), build_sales_chart(), build_sentiment_chart(), build_top_sellers_chart()
608
-
609
-
610
- # =========================================================
611
- # UI
612
- # =========================================================
613
-
614
- ensure_dirs()
615
-
616
- def load_css() -> str:
617
- css_path = BASE_DIR / "style.css"
618
- return css_path.read_text(encoding="utf-8") if css_path.exists() else ""
619
-
620
-
621
- with gr.Blocks(title="AIBDM 2026 Workshop App") as demo:
622
 
623
  gr.Markdown(
624
- "# SE21 App Template\n"
625
- "*This is an app template for SE21 students*",
626
- elem_id="escp_title",
 
 
 
 
 
627
  )
628
 
629
- # ===========================================================
630
- # TAB 1 -- Pipeline Runner
631
- # ===========================================================
632
- with gr.Tab("Pipeline Runner"):
633
- gr.Markdown()
634
-
635
- with gr.Row():
636
- with gr.Column(scale=1):
637
- btn_nb1 = gr.Button("Step 1: Data Creation", variant="secondary")
638
- with gr.Column(scale=1):
639
- btn_nb2 = gr.Button("Step 2: Python Analysis", variant="secondary")
640
-
641
- with gr.Row():
642
- btn_all = gr.Button("Run Full Pipeline (Both Steps)", variant="primary")
643
-
644
- run_log = gr.Textbox(
645
- label="Execution Log",
646
- lines=18,
647
- max_lines=30,
648
- interactive=False,
649
- )
650
-
651
- btn_nb1.click(run_datacreation, outputs=[run_log])
652
- btn_nb2.click(run_pythonanalysis, outputs=[run_log])
653
- btn_all.click(run_full_pipeline, outputs=[run_log])
654
-
655
- # ===========================================================
656
- # TAB 2 -- Dashboard (KPIs + Interactive Charts + Gallery)
657
- # ===========================================================
658
- with gr.Tab("Dashboard"):
659
- kpi_html = gr.HTML(value=render_kpi_cards)
660
-
661
- refresh_btn = gr.Button("Refresh Dashboard", variant="primary")
662
-
663
- gr.Markdown("#### Interactive Charts")
664
- chart_sales = gr.Plot(label="Monthly Overview")
665
- chart_sentiment = gr.Plot(label="Sentiment Distribution")
666
- chart_top = gr.Plot(label="Top Sellers")
667
-
668
- gr.Markdown("#### Static Figures (from notebooks)")
669
- gallery = gr.Gallery(
670
- label="Generated Figures",
671
- columns=2,
672
- height=480,
673
- object_fit="contain",
674
- )
675
-
676
- gr.Markdown("#### Data Tables")
677
- table_dropdown = gr.Dropdown(
678
- label="Select a table to view",
679
- choices=[],
680
- interactive=True,
681
- )
682
- table_display = gr.Dataframe(
683
- label="Table Preview",
684
- interactive=False,
685
- )
686
-
687
- def _on_refresh():
688
- kpi, c1, c2, c3 = refresh_dashboard()
689
- figs, dd, df = refresh_gallery()
690
- return kpi, c1, c2, c3, figs, dd, df
691
-
692
- refresh_btn.click(
693
- _on_refresh,
694
- outputs=[kpi_html, chart_sales, chart_sentiment, chart_top,
695
- gallery, table_dropdown, table_display],
696
- )
697
- table_dropdown.change(
698
- on_table_select,
699
- inputs=[table_dropdown],
700
- outputs=[table_display],
701
- )
702
-
703
- # ===========================================================
704
- # TAB 3 -- AI Dashboard
705
- # ===========================================================
706
- with gr.Tab('"AI" Dashboard'):
707
- _ai_status = (
708
- "Connected to your **n8n workflow**." if N8N_WEBHOOK_URL
709
- else "**LLM active.**" if LLM_ENABLED
710
- else "Using **keyword matching**. Upgrade options: "
711
- "set `N8N_WEBHOOK_URL` to connect your n8n workflow, "
712
- "or set `HF_API_KEY` for direct LLM access."
713
- )
714
- gr.Markdown(
715
- "### Ask questions, get interactive visualisations\n\n"
716
- f"Type a question and the system will pick the right interactive chart or table. {_ai_status}"
717
- )
718
-
719
- with gr.Row(equal_height=True):
720
- with gr.Column(scale=1):
721
- chatbot = gr.Chatbot(
722
- label="Conversation",
723
- height=380,
724
- )
725
- user_input = gr.Textbox(
726
- label="Ask about your data",
727
- placeholder="e.g. Show me sales trends / What are the top sellers? / Sentiment analysis",
728
- lines=1,
729
- )
730
- gr.Examples(
731
- examples=[
732
- "Show me the sales trends",
733
- "What does the sentiment look like?",
734
- "Which titles sell the most?",
735
- "Show the ARIMA forecasts",
736
- "What are the pricing decisions?",
737
- "Give me a dashboard overview",
738
- ],
739
- inputs=user_input,
740
- )
741
-
742
- with gr.Column(scale=1):
743
- ai_figure = gr.Plot(
744
- label="Interactive Chart",
745
- )
746
- ai_table = gr.Dataframe(
747
- label="Data Table",
748
- interactive=False,
749
- )
750
-
751
- user_input.submit(
752
- ai_chat,
753
- inputs=[user_input, chatbot],
754
- outputs=[chatbot, user_input, ai_figure, ai_table],
755
- )
756
-
757
-
758
- demo.launch(css=load_css(), allowed_paths=[str(BASE_DIR)])
 
1
+ """
2
+ Restaurant Health Grade Predictor
3
+ ----------------------------------
4
+ A Gradio app that predicts health inspection grades (A/B/C)
5
+ using a placeholder Random Forest model trained on synthetic data.
6
+
7
+ Requirements:
8
+ pip install gradio scikit-learn matplotlib numpy pandas
9
+ """
10
 
 
11
  import gradio as gr
12
+ import numpy as np
13
+ import pandas as pd
14
+ import matplotlib.pyplot as plt
15
+ import matplotlib.patches as mpatches
16
+ from sklearn.ensemble import RandomForestClassifier
17
+ from sklearn.preprocessing import LabelEncoder
18
+ import warnings
19
+
20
+ warnings.filterwarnings("ignore")
21
+
22
+ # ──────────────────────────────────────────────────────────────────────────────
23
+ # 1. Build a placeholder Random Forest model on synthetic data
24
+ # ──────────────────────────────────────────────────────────────────────────────
25
+
26
+ CUISINE_TYPES = [
27
+ "American", "Chinese", "Italian", "Mexican", "Japanese",
28
+ "Indian", "Thai", "Mediterranean", "French", "Korean",
29
+ ]
30
+
31
+ VIOLATION_CODES = [
32
+ "No Violation",
33
+ "02A - No food safety certificate",
34
+ "04L - Evidence of mice or rats",
35
+ "06C - Food not protected",
36
+ "08A - Facility not sanitized",
37
+ "10B - Plumbing not properly installed",
38
+ "15L - Workers not using proper hygiene",
39
+ ]
40
+
41
+ GRADE_LABELS = ["A", "B", "C"]
42
+
43
+ # Encode categorical features
44
+ cuisine_enc = LabelEncoder().fit(CUISINE_TYPES)
45
+ violation_enc = LabelEncoder().fit(VIOLATION_CODES)
46
+
47
+ def encode_inputs(cuisine: str, violation: str, score: float) -> np.ndarray:
48
+ c = cuisine_enc.transform([cuisine])[0]
49
+ v = violation_enc.transform([violation])[0]
50
+ return np.array([[c, v, score]])
51
+
52
+
53
+ def generate_synthetic_data(n: int = 2000, seed: int = 42) -> tuple:
54
+ rng = np.random.default_rng(seed)
55
+ cuisines = rng.integers(0, len(CUISINE_TYPES), n)
56
+ violations = rng.integers(0, len(VIOLATION_CODES), n)
57
+ scores = rng.uniform(0, 100, n)
58
+
59
+ # Grade logic: score drives grade; violations add noise
60
+ grades = []
61
+ for i in range(n):
62
+ base = scores[i]
63
+ penalty = violations[i] * 3 # higher code β†’ worse grade
64
+ effective = base - penalty
65
+ if effective >= 60:
66
+ grades.append(0) # A
67
+ elif effective >= 40:
68
+ grades.append(1) # B
69
+ else:
70
+ grades.append(2) # C
71
+
72
+ X = np.column_stack([cuisines, violations, scores])
73
+ y = np.array(grades)
74
+ return X, y
75
+
76
+
77
+ print("Training placeholder Random Forest model …")
78
+ X_train, y_train = generate_synthetic_data()
79
+ model = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
80
+ model.fit(X_train, y_train)
81
+ print("Model ready βœ“")
82
+
83
+
84
+ # ──────────────────────────────────────────────────────────────────────────────
85
+ # 2. Prediction + chart function
86
+ # ──────────────────────────────────────────────────────────────────────────────
87
+
88
+ GRADE_COLORS = {
89
+ "A": "#2ECC71", # green
90
+ "B": "#F39C12", # amber
91
+ "C": "#E74C3C", # red
92
+ }
93
+
94
+ def predict_grade(cuisine: str, violation: str, score: float):
95
+ """Run inference and return a grade label and a probability bar chart."""
96
+ X = encode_inputs(cuisine, violation, score)
97
+ proba = model.predict_proba(X)[0] # shape (3,)
98
+ pred_idx = int(np.argmax(proba))
99
+ grade = GRADE_LABELS[pred_idx]
100
+ confidence = proba[pred_idx] * 100
101
+
102
+ # ── build the bar chart ──────────────────────────────────────────────────
103
+ fig, ax = plt.subplots(figsize=(6, 3.5))
104
+ fig.patch.set_facecolor("#1A1A2E")
105
+ ax.set_facecolor("#16213E")
106
+
107
+ bar_colors = [GRADE_COLORS[g] for g in GRADE_LABELS]
108
+ bars = ax.bar(
109
+ GRADE_LABELS,
110
+ proba * 100,
111
+ color=bar_colors,
112
+ width=0.5,
113
+ edgecolor="none",
114
+ zorder=3,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  )
116
 
117
+ # highlight the predicted grade with a glow border
118
+ pred_bar = bars[pred_idx]
119
+ pred_bar.set_linewidth(2.5)
120
+ pred_bar.set_edgecolor("white")
121
+
122
+ # value labels on bars
123
+ for bar, p in zip(bars, proba * 100):
124
+ ax.text(
125
+ bar.get_x() + bar.get_width() / 2,
126
+ bar.get_height() + 1.5,
127
+ f"{p:.1f}%",
128
+ ha="center", va="bottom",
129
+ color="white", fontsize=11, fontweight="bold",
130
+ )
131
 
132
+ ax.set_ylim(0, 110)
133
+ ax.set_xlabel("Predicted Grade", color="#AAAACC", fontsize=11, labelpad=8)
134
+ ax.set_ylabel("Probability (%)", color="#AAAACC", fontsize=11, labelpad=8)
135
+ ax.set_title(
136
+ f"Model Confidence β€” Predicted Grade: {grade} ({confidence:.1f}%)",
137
+ color="white", fontsize=13, fontweight="bold", pad=12,
138
+ )
139
+ ax.tick_params(colors="white", labelsize=12)
140
+ for spine in ax.spines.values():
141
+ spine.set_visible(False)
142
+ ax.yaxis.grid(True, color="#2A2A4A", linewidth=0.8, zorder=0)
143
+ ax.set_axisbelow(True)
144
+
145
+ plt.tight_layout()
146
+
147
+ # ── compose the text output ───────────────────────────────────────────────
148
+ emoji = {"A": "🟒", "B": "🟑", "C": "πŸ”΄"}[grade]
149
+ summary = (
150
+ f"{emoji} Predicted Health Grade: **{grade}**\n\n"
151
+ f"Confidence: {confidence:.1f}%\n\n"
152
+ f"---\n"
153
+ f"| Input | Value |\n"
154
+ f"|---|---|\n"
155
+ f"| Cuisine | {cuisine} |\n"
156
+ f"| Violation | {violation} |\n"
157
+ f"| Inspection Score | {score:.1f} |\n\n"
158
+ f"*Note: This uses a placeholder Random Forest model trained on "
159
+ f"synthetic data. Replace `generate_synthetic_data()` and re-train "
160
+ f"with real inspection records for production use.*"
161
+ )
162
 
163
+ return summary, fig
 
164
 
 
165
 
166
+ # ──────────────────────────────────────────────────────────────────────────────
167
+ # 3. Gradio UI
168
+ # ──────────────────────────────────────────────────────────────────────────────
 
 
169
 
170
+ DESCRIPTION = """
171
+ ## 🍽️ Restaurant Health Grade Predictor
 
172
 
173
+ Enter inspection details below to get a predicted **A / B / C** health grade
174
+ and a probability breakdown from the Random Forest model.
 
 
 
 
 
175
  """
176
 
177
+ with gr.Blocks(
178
+ title="Health Grade Predictor",
179
+ theme=gr.themes.Soft(
180
+ primary_hue="violet",
181
+ secondary_hue="slate",
182
+ neutral_hue="slate",
183
+ ),
184
+ css="""
185
+ .predict-btn { font-size: 1.1rem !important; padding: 0.7rem !important; }
186
+ #grade-output .prose { font-size: 1.05rem !important; }
187
+ """,
188
+ ) as demo:
189
+
190
+ gr.Markdown(DESCRIPTION)
191
+
192
+ with gr.Row():
193
+ with gr.Column(scale=1):
194
+ cuisine_input = gr.Dropdown(
195
+ choices=CUISINE_TYPES,
196
+ value="American",
197
+ label="🍜 Cuisine Type",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  )
199
+ violation_input = gr.Dropdown(
200
+ choices=VIOLATION_CODES,
201
+ value="No Violation",
202
+ label="⚠️ Violation Code",
203
+ )
204
+ score_input = gr.Slider(
205
+ minimum=0,
206
+ maximum=100,
207
+ value=85,
208
+ step=0.5,
209
+ label="πŸ“Š Inspection Score (0 = worst, 100 = best)",
210
+ )
211
+ predict_btn = gr.Button(
212
+ "πŸ” Predict Grade",
213
+ variant="primary",
214
+ elem_classes="predict-btn",
215
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
+ with gr.Column(scale=2):
218
+ grade_output = gr.Markdown(
219
+ value="*Fill in the inputs and click **Predict Grade**.*",
220
+ elem_id="grade-output",
221
+ )
222
+ chart_output = gr.Plot(label="Grade Probability Distribution")
223
 
224
+ predict_btn.click(
225
+ fn=predict_grade,
226
+ inputs=[cuisine_input, violation_input, score_input],
227
+ outputs=[grade_output, chart_output],
 
 
228
  )
229
 
230
+ gr.Examples(
231
+ examples=[
232
+ ["Italian", "No Violation", 95],
233
+ ["Chinese", "04L - Evidence of mice or rats", 55],
234
+ ["Mexican", "08A - Facility not sanitized", 40],
235
+ ["Japanese", "02A - No food safety certificate",72],
236
+ ["Mediterranean","15L - Workers not using proper hygiene", 30],
237
+ ],
238
+ inputs=[cuisine_input, violation_input, score_input],
239
+ outputs=[grade_output, chart_output],
240
+ fn=predict_grade,
241
+ cache_examples=True,
242
+ label="πŸ“Œ Quick Examples",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
  gr.Markdown(
246
+ """
247
+ ---
248
+ **How grades work (synthetic rules used for training)**
249
+ `Effective Score = Inspection Score βˆ’ (Violation Code Index Γ— 3)`
250
+ β€’ **A** β†’ Effective β‰₯ 60 &nbsp;|&nbsp; **B** β†’ 40–59 &nbsp;|&nbsp; **C** β†’ < 40
251
+
252
+ Replace `generate_synthetic_data()` with a real labelled dataset to make this production-ready.
253
+ """
254
  )
255
 
256
+ if __name__ == "__main__":
257
+ demo.launch(share=False)