Spaces:
Sleeping
Sleeping
| """ | |
| Restaurant Health Grade Predictor | |
| ---------------------------------- | |
| A Gradio app that predicts health inspection grades (A/B/C) | |
| using a placeholder Random Forest model trained on synthetic data. | |
| Requirements: | |
| pip install gradio scikit-learn matplotlib numpy pandas | |
| """ | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
| from sklearn.ensemble import RandomForestClassifier | |
| from sklearn.preprocessing import LabelEncoder | |
| import warnings | |
# Process-wide blanket filter: every warning category is silenced.
# NOTE(review): this also hides library deprecations and runtime warnings
# (e.g. resource warnings); consider narrowing it to the specific noisy ones.
warnings.simplefilter("ignore")

# ──────────────────────────────────────────────────────────────────────────────
# 1. Build a placeholder Random Forest model on synthetic data
# ──────────────────────────────────────────────────────────────────────────────

# Cuisine options offered in the UI dropdown.
CUISINE_TYPES = [
    "American",
    "Chinese",
    "Italian",
    "Mexican",
    "Japanese",
    "Indian",
    "Thai",
    "Mediterranean",
    "French",
    "Korean",
]

# Violation options offered in the UI dropdown; the first entry is the
# clean-inspection case.
VIOLATION_CODES = [
    "No Violation",
    "02A - No food safety certificate",
    "04L - Evidence of mice or rats",
    "06C - Food not protected",
    "08A - Facility not sanitized",
    "10B - Plumbing not properly installed",
    "15L - Workers not using proper hygiene",
]

# Grade names, indexed by the integer class labels 0 / 1 / 2 used in training.
GRADE_LABELS = list("ABC")
# Categorical encoders fitted on the fixed option lists.
# LabelEncoder numbers its classes in sorted order (its ``classes_`` attribute
# is sorted), not in list order: e.g. "No Violation" sorts after the
# "02A"…"15L" codes because digits order before uppercase letters.
cuisine_enc = LabelEncoder()
cuisine_enc.fit(CUISINE_TYPES)

violation_enc = LabelEncoder()
violation_enc.fit(VIOLATION_CODES)
def encode_inputs(cuisine: str, violation: str, score: float) -> np.ndarray:
    """Turn one set of UI inputs into the single feature row the model expects.

    Categorical inputs are encoded as their *position* in ``CUISINE_TYPES`` /
    ``VIOLATION_CODES``. This is the same integer scheme that
    ``generate_synthetic_data`` trains on, where the violation penalty is
    ``3 * index`` and index 0 ("No Violation") carries no penalty. A fitted
    ``LabelEncoder`` would instead number classes alphabetically, which puts
    "No Violation" last and gives it the *largest* penalty.

    Args:
        cuisine: One of ``CUISINE_TYPES``.
        violation: One of ``VIOLATION_CODES``.
        score: Inspection score (the UI slider offers 0–100).

    Returns:
        Array of shape ``(1, 3)``: ``[[cuisine_index, violation_index, score]]``.

    Raises:
        ValueError: If ``cuisine`` or ``violation`` is not a known option.
    """
    try:
        c = CUISINE_TYPES.index(cuisine)
    except ValueError:
        raise ValueError(f"Unknown cuisine type: {cuisine!r}") from None
    try:
        v = VIOLATION_CODES.index(violation)
    except ValueError:
        raise ValueError(f"Unknown violation code: {violation!r}") from None
    return np.array([[c, v, score]])
def generate_synthetic_data(n: int = 2000, seed: int = 42) -> tuple:
    """Generate a labelled synthetic training set for the placeholder model.

    Each row is ``[cuisine_index, violation_index, score]``. The two indices are
    drawn uniformly over ``CUISINE_TYPES`` / ``VIOLATION_CODES`` and the score
    uniformly from ``[0, 100)``. Labels follow a fixed rule (cuisine is ignored)::

        effective = score - 3 * violation_index
        grade     = 0 (A) if effective >= 60
                    1 (B) if 40 <= effective < 60
                    2 (C) otherwise

    Args:
        n: Number of rows to generate.
        seed: Seed for ``numpy.random.default_rng`` so the data is reproducible.

    Returns:
        ``(X, y)``. ``X`` has shape ``(n, 3)``. ``y`` has shape ``(n,)`` and
        holds integer indices into ``GRADE_LABELS``.
    """
    rng = np.random.default_rng(seed)
    # Keep the draw order fixed: changing it changes the data for a given seed.
    cuisines = rng.integers(0, len(CUISINE_TYPES), n)
    violations = rng.integers(0, len(VIOLATION_CODES), n)
    scores = rng.uniform(0, 100, n)

    # Score drives the grade; a higher violation index → larger penalty → worse
    # grade. np.select picks the first matching condition, mirroring if/elif.
    effective = scores - violations * 3
    y = np.select([effective >= 60, effective >= 40], [0, 1], default=2)

    X = np.column_stack([cuisines, violations, scores])
    return X, y
# Train once at import time so every request reuses the same fitted model.
print("Training placeholder Random Forest model …")
X_train, y_train = generate_synthetic_data()
# random_state pins the forest so, together with the seeded synthetic data,
# the model is identical across restarts; n_jobs=-1 uses all available cores.
model = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
model.fit(X_train, y_train)
print("Model ready ✓")
# ──────────────────────────────────────────────────────────────────────────────
# 2. Prediction + chart function
# ──────────────────────────────────────────────────────────────────────────────

# Bar colour per grade: green for A, amber for B, red for C.
GRADE_COLORS = dict(
    A="#2ECC71",
    B="#F39C12",
    C="#E74C3C",
)
def _plot_grade_probabilities(proba, pred_idx, grade, confidence):
    """Draw the per-grade probability bar chart and return it as a Figure.

    Args:
        proba: Probabilities aligned with ``GRADE_LABELS`` (values in [0, 1]).
        pred_idx: Index of the predicted grade; its bar gets a white outline.
        grade: Predicted grade label, shown in the title.
        confidence: Probability of the predicted grade, in percent.

    Returns:
        A ``matplotlib.figure.Figure`` suitable for ``gr.Plot``.
    """
    # A bare Figure is not registered with pyplot, so it is freed once nothing
    # else references it. Figures made by ``plt.subplots`` stay referenced by
    # pyplot until explicitly closed, leaking one figure per request.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(6, 3.5))
    fig.patch.set_facecolor("#1A1A2E")
    ax = fig.subplots()
    ax.set_facecolor("#16213E")

    percentages = proba * 100
    bars = ax.bar(
        GRADE_LABELS,
        percentages,
        color=[GRADE_COLORS[g] for g in GRADE_LABELS],
        width=0.5,
        edgecolor="none",
        zorder=3,
    )

    # Outline the predicted grade's bar so it stands out.
    pred_bar = bars[pred_idx]
    pred_bar.set_linewidth(2.5)
    pred_bar.set_edgecolor("white")

    # Percentage label just above each bar.
    for bar, pct in zip(bars, percentages):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 1.5,
            f"{pct:.1f}%",
            ha="center", va="bottom",
            color="white", fontsize=11, fontweight="bold",
        )

    ax.set_ylim(0, 110)  # headroom above 100 % for the value labels
    ax.set_xlabel("Predicted Grade", color="#AAAACC", fontsize=11, labelpad=8)
    ax.set_ylabel("Probability (%)", color="#AAAACC", fontsize=11, labelpad=8)
    ax.set_title(
        f"Model Confidence — Predicted Grade: {grade} ({confidence:.1f}%)",
        color="white", fontsize=13, fontweight="bold", pad=12,
    )
    ax.tick_params(colors="white", labelsize=12)
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.yaxis.grid(True, color="#2A2A4A", linewidth=0.8, zorder=0)
    ax.set_axisbelow(True)
    # Lay out *this* figure; ``plt.tight_layout()`` acts on pyplot's shared
    # "current figure" instead.
    fig.tight_layout()
    return fig


def _format_summary(grade, confidence, cuisine, violation, score):
    """Return the Markdown summary (grade, confidence, echoed inputs)."""
    emoji = {"A": "🟢", "B": "🟡", "C": "🔴"}[grade]
    return (
        f"{emoji} Predicted Health Grade: **{grade}**\n\n"
        f"Confidence: {confidence:.1f}%\n\n"
        f"---\n"
        f"| Input | Value |\n"
        f"|---|---|\n"
        f"| Cuisine | {cuisine} |\n"
        f"| Violation | {violation} |\n"
        f"| Inspection Score | {score:.1f} |\n\n"
        f"*Note: This uses a placeholder Random Forest model trained on "
        f"synthetic data. Replace `generate_synthetic_data()` and re-train "
        f"with real inspection records for production use.*"
    )


def predict_grade(cuisine: str, violation: str, score: float):
    """Run inference and return a grade label and a probability bar chart.

    Args:
        cuisine: Selected cuisine (one of ``CUISINE_TYPES``).
        violation: Selected violation (one of ``VIOLATION_CODES``).
        score: Inspection score from the UI slider (0–100).

    Returns:
        ``(summary, fig)``: a Markdown string for ``gr.Markdown`` and a
        matplotlib ``Figure`` for ``gr.Plot``.
    """
    X = encode_inputs(cuisine, violation, score)
    raw = model.predict_proba(X)[0]
    # predict_proba's columns follow model.classes_, i.e. only the grades that
    # occurred in the training labels. Scatter them into a full A/B/C vector so
    # positions always match GRADE_LABELS, even if a grade was absent.
    proba = np.zeros(len(GRADE_LABELS))
    proba[np.asarray(model.classes_, dtype=int)] = raw

    pred_idx = int(np.argmax(proba))
    grade = GRADE_LABELS[pred_idx]
    confidence = proba[pred_idx] * 100

    fig = _plot_grade_probabilities(proba, pred_idx, grade, confidence)
    summary = _format_summary(grade, confidence, cuisine, violation, score)
    return summary, fig
# ──────────────────────────────────────────────────────────────────────────────
# 3. Gradio UI
# ──────────────────────────────────────────────────────────────────────────────

# Markdown header rendered at the top of the app.
DESCRIPTION = """
## 🍽️ Restaurant Health Grade Predictor
Enter inspection details below to get a predicted **A / B / C** health grade
and a probability breakdown from the Random Forest model.
"""
with gr.Blocks(
    title="Health Grade Predictor",
    theme=gr.themes.Soft(
        primary_hue="violet",
        secondary_hue="slate",
        neutral_hue="slate",
    ),
    css="""
    .predict-btn { font-size: 1.1rem !important; padding: 0.7rem !important; }
    #grade-output .prose { font-size: 1.05rem !important; }
    """,
) as demo:
    # Components render in construction order: header, a two-column row
    # (inputs left, outputs right), then the examples and the rules footnote.
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=1):
            cuisine_input = gr.Dropdown(
                choices=CUISINE_TYPES,
                value="American",
                label="🍜 Cuisine Type",
            )
            violation_input = gr.Dropdown(
                choices=VIOLATION_CODES,
                value="No Violation",
                label="⚠️ Violation Code",
            )
            score_input = gr.Slider(
                minimum=0,
                maximum=100,
                value=85,
                step=0.5,
                label="📊 Inspection Score (0 = worst, 100 = best)",
            )
            predict_btn = gr.Button(
                "🔍 Predict Grade",
                variant="primary",
                elem_classes="predict-btn",  # styled by the css above
            )
        with gr.Column(scale=2):
            grade_output = gr.Markdown(
                value="*Fill in the inputs and click **Predict Grade**.*",
                elem_id="grade-output",  # targeted by the css above
            )
            chart_output = gr.Plot(label="Grade Probability Distribution")

    predict_btn.click(
        fn=predict_grade,
        inputs=[cuisine_input, violation_input, score_input],
        outputs=[grade_output, chart_output],
    )

    # cache_examples=True precomputes each example's outputs by calling
    # predict_grade, so clicking an example shows a stored result.
    gr.Examples(
        examples=[
            ["Italian", "No Violation", 95],
            ["Chinese", "04L - Evidence of mice or rats", 55],
            ["Mexican", "08A - Facility not sanitized", 40],
            ["Japanese", "02A - No food safety certificate", 72],
            ["Mediterranean", "15L - Workers not using proper hygiene", 30],
        ],
        inputs=[cuisine_input, violation_input, score_input],
        outputs=[grade_output, chart_output],
        fn=predict_grade,
        cache_examples=True,
        label="📋 Quick Examples",
    )

    # Blank lines between items keep them as separate Markdown paragraphs;
    # without them the whole footnote collapses into one run-on line.
    gr.Markdown(
        "---\n\n"
        "**How grades work (synthetic rules used for training)**\n\n"
        "`Effective Score = Inspection Score − (Violation Code Index × 3)`\n\n"
        "• **A** → Effective ≥ 60 | **B** → 40 ≤ Effective < 60 "
        "| **C** → Effective < 40\n\n"
        "Replace `generate_synthetic_data()` with a real labelled dataset "
        "to make this production-ready."
    )

if __name__ == "__main__":
    # share=False: serve locally only, without a public share link.
    demo.launch(share=False)