# Hugging Face Space: pokemon/app.py (author: scottymcgee, commit ddb44ab "Update app.py")
# -*- coding: utf-8 -*-
"""
This application loads a trained AutoGluon TabularPredictor that was built on the ecopus/pokemon_cards dataset and exposes it through a Gradio interface. Users can enter details of a Pokémon card—including its name, release year, set, artwork style, condition, set-number equivalent, and market value—and the model will instantly predict whether the card is considered a collector’s item (“Yes” or “No”). The interface also displays the model’s class probabilities so users can see how confident the model is about each prediction.
Dataset reference:
https://huggingface.co/datasets/ecopus/pokemon_cards
"""
# ----------------------------
# Imports
# ----------------------------
import os
import shutil
import zipfile
import pathlib
from typing import Any, Dict, List, Optional
import pandas as pd
import gradio as gr
import huggingface_hub
import autogluon.tabular
# Optional: pull choices/ranges from the dataset (falls back if unavailable)
try:
    # `datasets` is optional: it is used only to derive dropdown choices and
    # slider ranges from the live dataset in get_dataset_metadata().
    from datasets import load_dataset
    HAS_DATASETS = True
except Exception:
    # Any import failure (package missing, incompatible version) simply
    # switches the app to its hard-coded fallback metadata.
    HAS_DATASETS = False
# ----------------------------
# Settings: point to your trained AutoGluon predictor on the Hub
# ----------------------------
# Hub repo holding the zipped AutoGluon predictor and the zip's filename.
MODEL_REPO_ID = "samder03/2025-24679-tabular-autolguon-predictor" # <- CHANGE ME
ZIP_FILENAME = "autogluon_predictor_dir.zip" # <- CHANGE if different
# Local working directories: download cache and the unzipped predictor dir.
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = CACHE_DIR / "predictor_native"
# Columns must match training-time names exactly:
FEATURE_COLS = [
    "Card", # string
    "Year", # int
    "Card Set", # string
    "Artwork Style", # string
    "Condition", # string
    "Set Number Eq", # float
    "Market Value", # float
]
TARGET_COL = "Collector's Item" # binary: "Yes"/"No" in the dataset
# ----------------------------
# Load predictor (download zip from Hub, then autogluon load)
# ----------------------------
def _prepare_predictor_dir() -> str:
    """Download the zipped AutoGluon predictor from the Hub and unpack it.

    Returns:
        Path (as str) of the directory TabularPredictor.load() should use:
        the extraction dir itself, or — when the archive wraps everything in
        a single top-level folder — that inner folder.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    # NOTE: the previous `local_dir_use_symlinks=False` kwarg is deprecated
    # and ignored by current huggingface_hub releases (it only emitted a
    # FutureWarning); with `local_dir` set, real files are written there.
    local_zip = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
    )
    # Start from a clean extraction dir so stale files from a previous run
    # can never mix with the new archive's contents.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(local_zip, "r") as zf:
        zf.extractall(str(EXTRACT_DIR))
    # Some archives nest the predictor in one top-level folder; descend into
    # it so AutoGluon finds the predictor files at the returned root.
    contents = list(EXTRACT_DIR.iterdir())
    predictor_root = contents[0] if (len(contents) == 1 and contents[0].is_dir()) else EXTRACT_DIR
    return str(predictor_root)
# If loading locally instead of the Hub, comment these two lines and set:
# PREDICTOR_DIR = "/path/to/AutogluonModels/ag-<run>"
PREDICTOR_DIR = _prepare_predictor_dir()
# require_py_version_match=False lets us load a predictor trained under a
# different Python minor version (common between training box and Spaces).
PREDICTOR = autogluon.tabular.TabularPredictor.load(PREDICTOR_DIR, require_py_version_match=False)
# ----------------------------
# Helpers
# ----------------------------
OUTCOME_LABELS = {
"Yes": "Yes", "No": "No",
1: "Yes", 0: "No",
"1": "Yes", "0": "No",
True: "Yes", False: "No",
}
def _human_label(x: Any) -> str:
return OUTCOME_LABELS.get(x, str(x))
def _normalize_proba_keys(row_probs: Dict[Any, float]) -> Dict[str, float]:
normalized: Dict[str, float] = {}
for k, v in row_probs.items():
key = _human_label(k)
normalized[key] = float(v) + float(normalized.get(key, 0.0))
# sort high->low
return dict(sorted(normalized.items(), key=lambda kv: kv[1], reverse=True))
# ----------------------------
# Dataset-driven choices/ranges (with safe fallbacks if offline)
# ----------------------------
def get_dataset_metadata() -> dict:
    """
    Try to pull unique choices and numeric ranges from ecopus/pokemon_cards.
    Falls back to hard-coded sensible defaults if the dataset lib or network is unavailable.
    """
    # Sensible offline defaults; each entry is overwritten below when the
    # dataset loads and provides non-empty values.
    meta = {
        "card_examples": ["Charizard", "Pikachu", "Mew", "Ivysaur"],
        "card_sets": [
            "Base Set", "Pokemon 151", "Evolutions", "Prismatic Evolutions",
            "Journey Together", "Destined Rivals", "Stellar Crown", "BREAKpoint",
            "EX Sandstorm", "Double Crisis", "McDonalds"
        ],
        "art_styles": [
            "Standard", "Holo", "Reverse Holo", "Full Art",
            "Full Art Gold", "Full Art Rainbow", "Alternate Art", "Trainer Gallery", "Promo",
            # include obvious typo seen in a sample row to avoid surprises:
            "Standart"
        ],
        "conditions": ["Mint", "Near Mint", "Lightly Played", "Heavily Played"],
        "year_min": 1995,
        "year_max": 2025,
        "sne_min": 0.04,
        "sne_max": 1.50,
        "mv_min": 0.08,
        "mv_max": 133.00,
        "examples_rows": [],  # list of example rows matching FEATURE_COLS order
    }
    if not HAS_DATASETS:
        return meta
    try:
        ds = load_dataset("ecopus/pokemon_cards")
        # Merge all available splits into a single frame.
        df_all = pd.concat(
            [pd.DataFrame(ds[split]) for split in ds.keys()],
            ignore_index=True,
        )

        # Coerce types safely (in case commas exist in displayed values);
        # both helpers return None on failure so dropna() can discard bad rows.
        def _to_int(x):
            try:
                return int(str(x).replace(",", ""))
            except Exception:
                return None

        def _to_float(x):
            try:
                return float(str(x).replace(",", ""))
            except Exception:
                return None

        # Unique categorical choices straight from the data.
        if "Card Set" in df_all.columns:
            sets = sorted({str(s) for s in df_all["Card Set"].dropna().unique()})
            if sets:
                meta["card_sets"] = sets
        if "Artwork Style" in df_all.columns:
            styles = sorted({str(s) for s in df_all["Artwork Style"].dropna().unique()})
            if styles:
                # includes 'Standart' if present in the data
                meta["art_styles"] = styles
        if "Condition" in df_all.columns:
            conds = sorted({str(s) for s in df_all["Condition"].dropna().unique()})
            if conds:
                meta["conditions"] = conds
        # Numeric ranges for the sliders.
        if "Year" in df_all.columns:
            years = df_all["Year"].map(_to_int).dropna().tolist()
            if years:
                meta["year_min"] = min(years)
                meta["year_max"] = max(years)
        if "Set Number Eq" in df_all.columns:
            sne = df_all["Set Number Eq"].map(_to_float).dropna().tolist()
            if sne:
                meta["sne_min"] = float(min(sne))
                meta["sne_max"] = float(max(sne))
        if "Market Value" in df_all.columns:
            mv = df_all["Market Value"].map(_to_float).dropna().tolist()
            if mv:
                meta["mv_min"] = float(min(mv))
                meta["mv_max"] = float(max(mv))
        # Example rows (grab up to 5 complete rows in FEATURE_COLS order).
        if all(c in df_all.columns for c in FEATURE_COLS):
            meta["examples_rows"] = df_all[FEATURE_COLS].dropna().head(5).values.tolist()
        # Some card names to seed the textbox suggestions.
        if "Card" in df_all.columns:
            meta["card_examples"] = df_all["Card"].dropna().astype(str).head(8).tolist()
    except Exception:
        # Deliberate best-effort: any failure (network, schema drift) keeps
        # the offline defaults rather than crashing the app at startup.
        pass
    return meta


META = get_dataset_metadata()
# ----------------------------
# Prediction function
# ----------------------------
def do_predict(card_name: str,
               year: float,
               card_set: str,
               artwork_style: str,
               condition: str,
               set_number_eq: float,
               market_value: float):
    """Predict collector's-item status for one card.

    Returns:
        Dict mapping class label ("Yes"/"No") to probability, sorted
        highest-first — the shape gr.Label expects.
    """
    # Robustness fix: Gradio fires .change() while fields are mid-edit, and a
    # cleared gr.Number (or empty slider state) arrives as None, which made
    # int(year)/float(...) raise TypeError. Skip until all numerics are set.
    if year is None or set_number_eq is None or market_value is None:
        return {}
    # Build a single-row DataFrame exactly matching training columns.
    row = {
        "Card": str(card_name).strip(),
        "Year": int(year),
        "Card Set": str(card_set).strip(),
        "Artwork Style": str(artwork_style).strip(),
        "Condition": str(condition).strip(),
        "Set Number Eq": float(set_number_eq),
        "Market Value": float(market_value),
    }
    X = pd.DataFrame([row], columns=FEATURE_COLS)
    # Predicted class label, mapped to "Yes"/"No" for display.
    raw_pred = PREDICTOR.predict(X).iloc[0]
    pred_label = _human_label(raw_pred)
    # Class probabilities (if the underlying model exposes them).
    try:
        proba = PREDICTOR.predict_proba(X)
        if isinstance(proba, pd.Series):  # AutoGluon can return Series for binary
            proba = proba.to_frame().T
    except Exception:
        proba = None
    proba_dict = None
    if proba is not None:
        proba_dict = _normalize_proba_keys(proba.iloc[0].to_dict())
    # If probabilities missing, fabricate 100% on predicted class for UX.
    if not proba_dict:
        proba_dict = {pred_label: 1.0, ("No" if pred_label == "Yes" else "Yes"): 0.0}
    return proba_dict
# ----------------------------
# Build Gradio UI
# ----------------------------
# Gradio layout: component creation order defines on-screen order; the
# `inputs` list order must match do_predict's parameter order.
with gr.Blocks() as demo:
    gr.Markdown("# Pokémon Card → Collector's Item Predictor (Yes/No)")
    gr.Markdown(
        "Enter a card's details to predict whether it's a **collector's item**. "
        "This GUI mirrors the columns in the dataset "
        "[ecopus/pokemon_cards](https://huggingface.co/datasets/ecopus/pokemon_cards)."
    )
    # Row 1: free-text card name + card set. allow_custom_value lets users
    # type sets that were not among the dataset-derived choices.
    with gr.Row():
        card_name = gr.Textbox(
            label="Card",
            value=(META["card_examples"][0] if META["card_examples"] else "Charizard"),
            placeholder="e.g., Charizard"
        )
        card_set = gr.Dropdown(
            choices=META["card_sets"],
            value=(META["card_sets"][0] if META["card_sets"] else None),
            label="Card Set",
            allow_custom_value=True,
        )
    # Row 2: release year slider plus categorical style/condition dropdowns.
    with gr.Row():
        year = gr.Slider(
            minimum=int(META["year_min"]),
            maximum=int(META["year_max"]),
            step=1,
            value=min(2024, int(META["year_max"])),
            label="Year"
        )
        artwork_style = gr.Dropdown(
            choices=META["art_styles"],
            value=(META["art_styles"][0] if META["art_styles"] else None),
            label="Artwork Style",
            allow_custom_value=True,
        )
        condition = gr.Dropdown(
            choices=META["conditions"],
            value=(META["conditions"][0] if META["conditions"] else None),
            label="Condition",
            allow_custom_value=True,
        )
    # Row 3: numeric features with dataset-derived slider bounds.
    with gr.Row():
        set_number_eq = gr.Slider(
            minimum=float(META["sne_min"]),
            maximum=float(META["sne_max"]),
            step=0.001,
            value=0.536,
            label="Set Number Eq"
        )
        market_value = gr.Number(
            value=round(min(100.00, float(META["mv_max"])), 2),
            precision=2,
            label="Market Value (USD)"
        )
    # Output: gr.Label renders the {class: probability} dict from do_predict.
    proba_pretty = gr.Label(num_top_classes=2, label="Class probabilities (Yes/No)")
    # Live prediction: there is no submit button — changing any input
    # re-runs do_predict with the current value of every input.
    inputs = [card_name, year, card_set, artwork_style, condition, set_number_eq, market_value]
    for comp in inputs:
        comp.change(fn=do_predict, inputs=inputs, outputs=[proba_pretty])
    # Representative examples from the dataset if available, else a few hand-crafted ones
    examples = META["examples_rows"] if META["examples_rows"] else [
        ["Charizard", 1999, "Base Set", "Holo", "Near Mint", 0.85, 450.00],
        ["Pikachu", 2024, "Pokemon 151", "Full Art", "Near Mint", 1.05, 47.45],
        ["Ivysaur", 2025, "Pokemon 151", "Full Art", "Near Mint", 1.106, 30.77],
        ["Mew", 2024, "Pokemon 151", "Full Art Gold", "Mint", 1.242, 16.51],
        ["Spheal", 2014, "Evolutions", "Reverse Holo", "Lightly Played", 0.226, 0.12],
    ]
    gr.Examples(
        examples=examples,
        inputs=inputs,
        label="Representative examples (from the dataset or sensible defaults)",
        examples_per_page=min(5, len(examples)),
        cache_examples=False,
    )
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script (not on import).
    demo.launch()