File size: 3,693 Bytes
1454b2f
 
 
fb8ace6
1454b2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb8ace6
1454b2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# app.py
import os
import pandas as pd
import numpy as np
from huggingface_hub import hf_hub_download, HfApi
from autogluon.tabular import TabularPredictor
import gradio as gr
# Hugging Face model repo holding the trained AutoGluon predictor, and the
# local directory that downloaded/extracted artifacts are written to.
# NOTE(review): "autolguon" looks like a typo but is presumably the published
# repo id — confirm before renaming.
MODEL_REPO = "samder03/2025-24679-tabular-autolguon-predictor"
LOCAL_DIR = "/tmp/autogluon_predictor"
os.makedirs(LOCAL_DIR, exist_ok=True)

# List the repo's files to decide which artifact format to load.
api = HfApi()
files = api.list_repo_files(repo_id=MODEL_REPO)
print("Files found:", files)

from autogluon.common.loaders import load_pkl

predictor = None
if "autogluon_predictor.pkl" in files:
    # Preferred format: a single pickled predictor object.
    predictor_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename="autogluon_predictor.pkl",
        local_dir=LOCAL_DIR,
    )
    # SECURITY: unpickling executes arbitrary code embedded in the file. This
    # is only acceptable while MODEL_REPO is a source you trust/control.
    predictor = load_pkl.load(path=predictor_path)

elif "autogluon_predictor_dir.zip" in files:
    # Fallback format: a zipped predictor directory for TabularPredictor.load.
    zip_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename="autogluon_predictor_dir.zip",
        local_dir=LOCAL_DIR,
    )
    import zipfile
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(LOCAL_DIR)
    # The archive may wrap the predictor in a top-level folder. Load from the
    # first directory (in os.walk's top-down order, starting with LOCAL_DIR
    # itself) that contains AutoGluon's predictor.pkl; otherwise keep the
    # previous behaviour of loading LOCAL_DIR directly.
    predictor_root = LOCAL_DIR
    for dirpath, _subdirs, filenames in os.walk(LOCAL_DIR):
        if "predictor.pkl" in filenames:
            predictor_root = dirpath
            break
    predictor = TabularPredictor.load(predictor_root)

else:
    raise FileNotFoundError(
        "Could not find supported predictor file in repo. "
        f"Expected 'autogluon_predictor.pkl' or 'autogluon_predictor_dir.zip' "
        f"in {MODEL_REPO}; found: {files}"
    )
# Load the dataset only to infer the input schema for the widgets below
# (optional: the widgets could be hard-coded instead).
from datasets import load_dataset

ds_orig, ds_aug = (
    load_dataset("ecopus/pokemon_cards", split=split_name)
    for split_name in ("original", "augmented")
)
df_orig = pd.DataFrame(ds_orig)
df_aug = pd.DataFrame(ds_aug)
df = pd.concat([df_orig, df_aug])

# Target column name: use the public `label` attribute, falling back to the
# private `_label` when `label` is not available on the loaded predictor.
try:
    label_col = predictor.label
except AttributeError:
    label_col = predictor._label

# Every dataset column other than the target is treated as an input feature.
features = [name for name in df.columns if name != label_col]
# Derive a widget spec for every feature from the dataset (same simple mapping
# as in the notebook). Spec shapes:
#   ("bool", default)                       boolean columns (default = mode)
#   ("numeric", min, max, step, default)    numeric columns (default = median)
#   ("categorical", sorted_unique_strings)  at most 20 distinct non-null values
#   ("text",)                               anything else
feature_specs = {}
for c in features:
    col = df[c]
    # Test bool before numeric: pandas' is_numeric_dtype() is also True for
    # bool columns, which previously turned them into a 0..1 numeric range.
    if pd.api.types.is_bool_dtype(col):
        modes = col.mode()
        feature_specs[c] = ("bool", bool(modes.iloc[0]) if len(modes) else False)
    elif pd.api.types.is_numeric_dtype(col):
        minv, maxv = float(col.min()), float(col.max())
        span = maxv - minv
        # An all-missing column yields NaN bounds; max(nan, 0.01) would stay
        # NaN, so fall back to the minimum step explicitly.
        step = max(span / 100.0, 0.01) if pd.notna(span) else 0.01
        median = col.median()
        default = float(median) if pd.notna(median) else None
        feature_specs[c] = ("numeric", minv, maxv, step, default)
    else:
        uniques = sorted(pd.Series(col.dropna().unique()).astype(str).tolist())
        if len(uniques) <= 20:
            feature_specs[c] = ("categorical", uniques)
        else:
            feature_specs[c] = ("text",)
# Build one Gradio input widget per feature, in `features` order, so that
# predict_record can zip the submitted values back onto the column names.
# Widget choice and defaults are taken from `feature_specs`.
inputs = []
input_names = []
for c in features:
    input_names.append(c)
    col = df[c]
    spec = feature_specs.get(c, ("text",))
    kind = spec[0]
    # Test bool first: pandas' is_numeric_dtype() also returns True for bool
    # columns, so checking numeric first made the Checkbox branch unreachable.
    if pd.api.types.is_bool_dtype(col):
        default = spec[1] if kind == "bool" else False
        inputs.append(gr.Checkbox(value=default, label=c))
    elif pd.api.types.is_numeric_dtype(col):
        median = spec[4] if kind == "numeric" else None
        # Pre-fill with the column median; leave empty if it is missing/NaN.
        inputs.append(gr.Number(value=median if pd.notna(median) else None, label=c))
    elif kind == "categorical":
        # Offer the observed values, but still accept unseen ones the way the
        # previous free-text box did.
        inputs.append(
            gr.Dropdown(choices=spec[1], value=None, allow_custom_value=True, label=c)
        )
    else:
        inputs.append(gr.Textbox(value="", label=c))

# No submodels are exposed for this predictor, so the only extra control is
# whether to show class probabilities or just the single predicted label.
prob_toggle = gr.Checkbox(value=True, label="Return probabilities (vs. hard label)")


def predict_record(*args):
    """Run the predictor on one record assembled from the Gradio inputs.

    Args:
        *args: One value per name in ``input_names`` (same order), followed by
            the value of ``prob_toggle`` as the last argument.

    Returns:
        A value ``gr.Label`` can render: a ``{class_name: probability}`` dict
        when probabilities are requested and ``predict_proba`` succeeds,
        otherwise the predicted label as a string.
    """
    *feature_values, return_prob = args
    record = dict(zip(input_names, feature_values))
    df_in = pd.DataFrame([record])

    if return_prob:
        try:
            proba = predictor.predict_proba(df_in)
        except Exception as err:
            # Deliberate best-effort fallback (e.g. a predictor that cannot
            # produce probabilities): report it and return the hard label.
            print(f"predict_proba failed ({err!r}); returning hard label instead.")
        else:
            # gr.Label interprets a dict as {class: confidence}; convert keys
            # and values to plain Python types so it can sort and render them.
            return {str(cls): float(p) for cls, p in proba.iloc[0].items()}

    # gr.Label only accepts str/int/float scalars; numpy scalars from .iloc
    # are not Python ints, so normalise to str.
    label = predictor.predict(df_in).iloc[0]
    return str(label)

# Wire every feature widget plus the probability toggle into predict_record;
# the toggle is appended last, matching how predict_record reads its args.
iface = gr.Interface(
    fn=predict_record,
    inputs=[*inputs, prob_toggle],
    outputs=gr.Label(label="Prediction", num_top_classes=3),
    title="Pokémon Card Collector's Item Predictor (AutoGluon)",
    description="Predicts whether a Pokémon card is a collector's item.",
)

if __name__ == "__main__":
    # Bind on all interfaces; the port comes from $PORT, defaulting to 7860.
    port = int(os.environ.get("PORT", 7860))
    iface.launch(server_name="0.0.0.0", server_port=port)