File size: 5,458 Bytes
51c4f14
 
 
 
 
 
 
 
5deec10
 
 
 
 
 
 
 
 
 
 
 
 
51c4f14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import os
import shutil
import zipfile
import pathlib
import pandas as pd
import gradio as gr
import huggingface_hub
import autogluon.tabular as agt
import shutil

# --- Model Loading ---
def _prepare_predictor_dir() -> str:
    """Download and extract the zipped AutoGluon predictor from the Hub.

    Clears any previously cached download first so stale or partial
    artifacts never survive, then fetches the zip, unpacks it, and
    returns the path to the extracted predictor directory.

    NOTE(review): this definition is shadowed by the identically named
    function defined later in this file — only the later one executes.
    The original body here was incomplete (it returned None despite the
    `-> str` annotation); it is completed so the file has no dead stub.
    """
    # Clear the cache before download so a fresh artifact is always used.
    if CACHE_DIR.exists():
        print(f"Clearing cache directory: {CACHE_DIR}")
        shutil.rmtree(CACHE_DIR)
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    local_zip = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
        local_dir_use_symlinks=False,
    )
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(local_zip, "r") as zf:
        zf.extractall(str(EXTRACT_DIR))
    # Some zips wrap everything in one top-level folder; descend into it.
    contents = list(EXTRACT_DIR.iterdir())
    predictor_root = contents[0] if (len(contents) == 1 and contents[0].is_dir()) else EXTRACT_DIR
    return str(predictor_root)




# --- Settings and Metadata ---
# Hugging Face model repo holding the zipped AutoGluon predictor.
# (Note: "autolguon" is spelled this way in the actual repo id.)
MODEL_REPO_ID = "bcueva/2024-24679-tabular-autolguon-predictor"
# Name of the zip artifact inside that repo.
ZIP_FILENAME = "autogluon_predictor_dir.zip"
# Local download cache and the directory the zip is extracted into.
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = CACHE_DIR / "predictor_native"

# Feature columns used by the model (order matches the UI inputs).
FEATURE_COLS = [
    "Capacity_ml",
    "Height_cm",
    "Diameter_cm",
    "Weight_g",
    "Material",
]

# The target variable to be predicted
TARGET_COL = "Use_Type"

# Define the possible values for the 'Material' feature
# (choices offered by the Material radio button in the UI).
MATERIAL_LABELS = [
    "Ceramic",
    "Glass",
    "Plastic",
    "Stainless Steel",

]

# Mapping the integer labels back to human-readable labels for the 'Use_Type' prediction
# These labels are based on the model's training data.
OUTCOME_LABELS = {
    0: "Hot",
    1: "Cold",
}

# --- Model Loading ---
def _prepare_predictor_dir() -> str:
    """Fetch the zipped predictor from the Hugging Face Hub and unpack it.

    Returns the filesystem path of the directory that
    ``TabularPredictor.load`` should be pointed at.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    zip_path = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
        local_dir_use_symlinks=False,
    )
    # Re-extract from scratch on every launch so stale files never linger.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path, "r") as archive:
        archive.extractall(str(EXTRACT_DIR))
    # The archive may wrap everything in a single top-level folder; if so,
    # descend into it, otherwise use the extraction directory itself.
    entries = list(EXTRACT_DIR.iterdir())
    if len(entries) == 1 and entries[0].is_dir():
        return str(entries[0])
    return str(EXTRACT_DIR)

# Resolve the predictor directory once at import time, then load the model.
PREDICTOR_DIR = _prepare_predictor_dir()
# require_py_version_match=False skips AutoGluon's strict Python-version
# check, so a predictor saved under a different Python minor version loads.
PREDICTOR = agt.TabularPredictor.load(PREDICTOR_DIR, require_py_version_match=False)

# --- Prediction Function ---
def _human_label(c):
    """Return a human-readable outcome name for a raw model prediction.

    Tries to interpret *c* as an integer key into OUTCOME_LABELS first,
    then as a direct key, and finally falls back to ``str(c)``.
    """
    try:
        mapped = OUTCOME_LABELS.get(int(c))
    except (ValueError, TypeError):
        mapped = None
    if mapped is not None:
        return mapped
    if c in OUTCOME_LABELS:
        return OUTCOME_LABELS[c]
    return str(c)

def do_predict(capacity_ml, height_cm, diameter_cm, weight_g, material):
    """Format user input for the model and return class probabilities.

    Args:
        capacity_ml: Cup capacity in millilitres.
        height_cm: Cup height in centimetres.
        diameter_cm: Cup diameter in centimetres.
        weight_g: Cup weight in grams.
        material: One of MATERIAL_LABELS.

    Returns:
        A ``{label: probability}`` dict suitable for ``gr.Label``, sorted
        by descending probability. If probabilities are unavailable, the
        hard prediction is returned with probability 1.0 instead of None
        (which would leave the Label component empty).
    """
    # 'Cup_ID' was a column in the model's training data, so the predictor
    # expects it; a dummy value is supplied because it carries no signal
    # for a single ad-hoc prediction.
    row = {
        "Cup_ID": 0,  # Dummy value
        "Capacity_ml": capacity_ml,
        "Height_cm": height_cm,
        "Diameter_cm": diameter_cm,
        "Weight_g": weight_g,
        "Material": material,
    }
    X = pd.DataFrame([row])

    # Get the raw prediction and its label
    pred_series = PREDICTOR.predict(X)
    raw_pred = pred_series.iloc[0]
    pred_label = _human_label(raw_pred)

    # Get prediction probabilities
    try:
        proba = PREDICTOR.predict_proba(X)
        # Binary problems can come back as a Series; normalize to a frame.
        if isinstance(proba, pd.Series):
            proba = proba.to_frame().T

        # Format probabilities into a dictionary for Gradio
        proba_dict = {
            _human_label(cls): float(val) for cls, val in proba.iloc[0].items()
        }
        proba_dict = dict(sorted(proba_dict.items(), key=lambda kv: kv[1], reverse=True))

    except Exception as e:
        print(f"Could not get probabilities: {e}")
        # Best-effort fallback: show the hard prediction with full
        # confidence rather than returning None to the Label component.
        proba_dict = {pred_label: 1.0}

    return proba_dict

# --- Gradio UI ---
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown("# Cup Use Predictor ☕️")
    gr.Markdown("""
    Enter the physical properties of a cup to predict its intended use type (Hot or Cold).
    This app uses a pre-trained **AutoGluon** model to classify the cup's purpose.
    """)

    # Categorical input: one of MATERIAL_LABELS.
    with gr.Column():
        material = gr.Radio(choices=MATERIAL_LABELS, value="Ceramic", label="Material")

    # Numeric inputs, laid out two per column.
    with gr.Row():
      with gr.Column():
        capacity_ml = gr.Number(value=350, label="Capacity (ml)")
        height_cm = gr.Number(value=10.0, label="Height (cm)")

      with gr.Column():
        diameter_cm = gr.Number(value=8.0, label="Diameter (cm)")
        weight_g = gr.Number(value=250, label="Weight (g)")

    predict_btn = gr.Button("Predict Use Type")
    # gr.Label renders the {label: probability} dict returned by do_predict.
    output_label = gr.Label(num_top_classes=2, label="Prediction")

    # Order must match do_predict's parameter order.
    inputs = [capacity_ml, height_cm, diameter_cm, weight_g, material]

    predict_btn.click(fn=do_predict, inputs=inputs, outputs=output_label)

    gr.Examples(
        examples=[
            [478, 8, 7.7, 315, "Ceramic"], # short, heavy ceramic mug
            [442, 13.8, 6.4, 155, "Glass"], # tall, light drinking glass
            [392, 18, 5.7, 61, "Plastic"], # very light, tall plastic cup
            [302, 17.5, 5.5, 783, "Stainless Steel"], # heavy steel tumbler
        ],
        inputs=inputs,
        label="Representative Examples",
        examples_per_page=5,
    )

if __name__ == "__main__":
    demo.launch()