"""Gradio app that predicts a cup's use type (Hot or Cold) from its physical
properties, using an AutoGluon tabular predictor downloaded from Hugging Face.
"""
import os
import shutil
import zipfile
import pathlib
import pandas as pd
import gradio as gr
import huggingface_hub
import autogluon.tabular as agt

# --- Settings and Metadata ---
# Hugging Face model repository that hosts the zipped predictor directory.
MODEL_REPO_ID = "bcueva/2024-24679-tabular-autolguon-predictor"
# Name of the zip archive inside that repository.
ZIP_FILENAME = "autogluon_predictor_dir.zip"
# Local directory the archive is downloaded into.
CACHE_DIR = pathlib.Path("hf_assets")
# Directory the archive is extracted into before the predictor is loaded.
EXTRACT_DIR = CACHE_DIR / "predictor_native"

# Features for the model (the UI collects one input per column).
FEATURE_COLS = [
    "Capacity_ml",
    "Height_cm",
    "Diameter_cm",
    "Weight_g",
    "Material",
]

# The target variable to be predicted.
TARGET_COL = "Use_Type"

# Choices offered for the 'Material' feature.
MATERIAL_LABELS = [
    "Ceramic",
    "Glass",
    "Plastic",
    "Stainless Steel",
]

# Mapping from the model's integer labels to human-readable labels for the
# 'Use_Type' prediction. These are assumed to match the label encoding used
# when the model was trained — confirm against the training data.
OUTCOME_LABELS = {
    0: "Hot",
    1: "Cold",
}


# --- Model Loading ---
def _prepare_predictor_dir() -> str:
    """Download the zipped predictor from Hugging Face and extract it locally.

    The archive ``ZIP_FILENAME`` is fetched from ``MODEL_REPO_ID`` into
    ``CACHE_DIR`` and unpacked into a freshly emptied ``EXTRACT_DIR``.

    Returns:
        Path (as ``str``) to pass to ``TabularPredictor.load``. This is the
        single top-level folder inside the archive when there is exactly one.
        Otherwise it is ``EXTRACT_DIR`` itself.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    local_zip = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
        local_dir_use_symlinks=False,
    )

    # Start from an empty extraction directory so leftovers from a previous
    # archive cannot mix with the new one.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(local_zip, "r") as zf:
        zf.extractall(str(EXTRACT_DIR))

    # Zips created on macOS often carry a "__MACOSX" metadata folder and
    # hidden dot-files. Ignore them so a zipped predictor folder is still
    # recognised as the single top-level directory.
    contents = [
        p
        for p in EXTRACT_DIR.iterdir()
        if p.name != "__MACOSX" and not p.name.startswith(".")
    ]
    if len(contents) == 1 and contents[0].is_dir():
        return str(contents[0])
    return str(EXTRACT_DIR)


PREDICTOR_DIR = _prepare_predictor_dir()
# require_py_version_match=False skips AutoGluon's check that the predictor
# was saved under the same Python version as the one running this app.
PREDICTOR = agt.TabularPredictor.load(PREDICTOR_DIR, require_py_version_match=False)


# --- Prediction Function ---
def _human_label(c):
    """Map a class label produced by the model to a human-readable name.

    The label is looked up in ``OUTCOME_LABELS`` first as an integer, so
    ``0``, ``"0"`` and ``0.0`` all match. It is then looked up as-is.
    Anything still unmapped is returned as ``str(c)``.
    """
    try:
        ci = int(c)
        if ci in OUTCOME_LABELS:
            return OUTCOME_LABELS[ci]
    except (ValueError, TypeError, OverflowError):
        # Not integer-like (e.g. already a text label, NaN or inf).
        pass
    if c in OUTCOME_LABELS:
        return OUTCOME_LABELS[c]
    return str(c)


def do_predict(capacity_ml, height_cm, diameter_cm, weight_g, material):
    """Predict a cup's use type from the values entered in the UI.

    Args:
        capacity_ml: Capacity in millilitres.
        height_cm: Height in centimetres.
        diameter_cm: Diameter in centimetres.
        weight_g: Weight in grams.
        material: Material name, one of ``MATERIAL_LABELS``.

    Returns:
        A ``{label: probability}`` dict ordered from most to least likely,
        for display in ``gr.Label``. If class probabilities cannot be
        computed, the predicted label string is returned instead, so the UI
        still shows a result rather than nothing.
    """
    # 'Cup_ID' is not collected from the user. The model's training data
    # included that column, so a dummy value is sent as a placeholder.
    row = {
        "Cup_ID": 0,  # Dummy value
        "Capacity_ml": capacity_ml,
        "Height_cm": height_cm,
        "Diameter_cm": diameter_cm,
        "Weight_g": weight_g,
        "Material": material,
    }
    X = pd.DataFrame([row])

    # Hard class prediction. It is also the fallback UI value below.
    pred_label = _human_label(PREDICTOR.predict(X).iloc[0])

    try:
        # as_multiclass=True (AutoGluon's default, made explicit) returns a
        # DataFrame with one column per class, even for binary problems.
        # Each column name is a real class label.
        proba = PREDICTOR.predict_proba(X, as_multiclass=True)
        proba_dict = {
            _human_label(cls): float(val) for cls, val in proba.iloc[0].items()
        }
    except Exception as e:
        # Best effort at the UI boundary: probabilities are optional, so log
        # the problem and show the hard label instead of an empty result.
        print(f"Could not get probabilities: {e}")
        return pred_label

    return dict(sorted(proba_dict.items(), key=lambda kv: kv[1], reverse=True))


# --- Gradio UI ---
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown("# Cup Use Predictor ☕️")
    gr.Markdown("""
    Enter the physical properties of a cup to predict its intended use type (Hot or Cold).
    This app uses a pre-trained **AutoGluon** model to classify the cup's purpose.
    """)

    with gr.Column():
        material = gr.Radio(choices=MATERIAL_LABELS, value="Ceramic", label="Material")
        with gr.Row():
            with gr.Column():
                capacity_ml = gr.Number(value=350, label="Capacity (ml)")
                height_cm = gr.Number(value=10.0, label="Height (cm)")
            with gr.Column():
                diameter_cm = gr.Number(value=8.0, label="Diameter (cm)")
                weight_g = gr.Number(value=250, label="Weight (g)")
        predict_btn = gr.Button("Predict Use Type")
        output_label = gr.Label(num_top_classes=2, label="Prediction")

    # Order must match do_predict's positional parameters.
    inputs = [capacity_ml, height_cm, diameter_cm, weight_g, material]
    predict_btn.click(fn=do_predict, inputs=inputs, outputs=output_label)

    # Each description follows from the row's numbers. The model's predicted
    # use for each example is not asserted here.
    gr.Examples(
        examples=[
            [478, 8, 7.7, 315, "Ceramic"],  # Squat, fairly heavy ceramic mug
            [442, 13.8, 6.4, 155, "Glass"],  # Tall, narrow drinking glass
            [392, 18, 5.7, 61, "Plastic"],  # Tall, very light plastic cup
            [302, 17.5, 5.5, 783, "Stainless Steel"],  # Tall, heavy steel tumbler
        ],
        inputs=inputs,
        label="Representative Examples",
        examples_per_page=5,
    )

if __name__ == "__main__":
    demo.launch()