# Commit 2a4b179 (verified) by maryzhang: Create app.py
# -*- coding: utf-8 -*-
"""sign_identifier_gradio_final.ipynb
Gradio interface for image classification using a classmate’s model.
Model: cassieli226/sign-identification-automl
"""
# !pip install autogluon.multimodal gradio huggingface_hub pillow pandas --quiet
import os, pathlib, shutil, zipfile, tempfile, io
import pandas as pd
from PIL import Image
import gradio as gr
import huggingface_hub
from autogluon.multimodal import MultiModalPredictor
# -----------------------------
# Config
# -----------------------------
# Hub repository that hosts the zipped predictor, and the archive's file name.
MODEL_REPO_ID = "cassieli226/sign-identification-automl"
ZIP_FILENAME = "autogluon_predictor_dir.zip"

# Local working directories: downloads land in CACHE_DIR, and the archive is
# unpacked into a sub-directory of it.
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = CACHE_DIR.joinpath("predictor_native")

# Largest accepted image, in megabytes, as compared against the PNG-encoded size.
MAX_SIZE_MB = 5
# -----------------------------
# Model loading
# -----------------------------
def prepare_predictor_dir() -> str:
    """Download the zipped predictor and unpack it into a clean directory.

    The archive named ``ZIP_FILENAME`` is fetched from ``MODEL_REPO_ID`` into
    ``CACHE_DIR``. ``EXTRACT_DIR`` is removed if it already exists, recreated,
    and the archive is extracted into it.

    Returns:
        The path of the sole entry inside ``EXTRACT_DIR`` when the extraction
        produced exactly one entry and that entry is a directory; otherwise
        the path of ``EXTRACT_DIR`` itself.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    archive_path = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
        local_dir_use_symlinks=False,
    )

    # Always start from an empty extraction directory.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)

    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(str(EXTRACT_DIR))

    # An archive that wraps everything in one folder: point at that folder.
    entries = list(EXTRACT_DIR.iterdir())
    if len(entries) == 1 and entries[0].is_dir():
        return str(entries[0])
    return str(EXTRACT_DIR)
# Runs at import time: download and unpack the predictor archive, then load it.
print("Loading predictor...")
# Directory returned by prepare_predictor_dir(); this is what the loader reads.
PREDICTOR_DIR = prepare_predictor_dir()
# Module-level predictor instance that predict_image() queries.
PREDICTOR = MultiModalPredictor.load(PREDICTOR_DIR)
print("✅ Model loaded!")
# Build a string-keyed lookup of readable class names. This is best-effort:
# if the predictor exposes neither source (or reading it fails), fall back to
# an empty map so predictions are shown with their raw labels.
try:
    if not hasattr(PREDICTOR, "label_generator") or not hasattr(PREDICTOR.label_generator, "category_map"):
        # No category map available: number the predictor's class labels.
        _label_pairs = enumerate(PREDICTOR.class_labels)
    else:
        _label_pairs = PREDICTOR.label_generator.category_map.items()
    CLASS_MAP = {str(key): str(name) for key, name in _label_pairs}
except Exception:
    CLASS_MAP = {}
print("Class map:", CLASS_MAP)
# -----------------------------
# Helpers
# -----------------------------
def _pil_to_tmp(img: Image.Image, resize_size=224) -> str:
    """Write a resized RGB copy of *img* to a temporary PNG file.

    Args:
        img: Image to convert. It is converted to RGB and resized to a
            ``resize_size`` x ``resize_size`` square (aspect ratio is not
            preserved).
        resize_size: Edge length in pixels. Coerced to ``int`` so that a
            float value (e.g. from a UI slider) is accepted.

    Returns:
        Path of the written PNG. The file is created with ``delete=False``,
        so removing it is the caller's responsibility.
    """
    edge = int(resize_size)
    resized = img.convert("RGB").resize((edge, edge))
    # Reserve a unique file name, then close the handle before saving by name.
    # Leaving the handle open leaked one descriptor per call and prevents the
    # file from being reopened on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        tmp_path = tmp.name
    resized.save(tmp_path, format="PNG")
    return tmp_path
def _size_mb_of_png(img: Image.Image) -> float:
    """Return the size of *img*, in megabytes (1024 * 1024 bytes), once PNG-encoded in memory."""
    bytes_per_mb = 1024 * 1024
    with io.BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        # The stream position after saving is the number of bytes written.
        encoded_bytes = png_buffer.tell()
    return encoded_bytes / bytes_per_mb
# -----------------------------
# Inference
# -----------------------------
def predict_image(img, resize_size=224, top_k=3, prob_threshold=0.05):
    """Classify an uploaded image and render the outcome as an HTML snippet.

    Args:
        img: PIL image supplied by the Gradio input, or ``None`` when nothing
            was uploaded.
        resize_size: Edge length in pixels the image is resized to before
            inference.
        top_k: Maximum number of classes to list; coerced to an integer of at
            least 1.
        prob_threshold: Classes whose probability is not strictly above this
            value are dropped, unless that would leave nothing to show, in
            which case the unfiltered ranking is used.

    Returns:
        ``(img, html)`` with the unchanged input image and the result markup,
        or ``(None, html)`` carrying a warning when the image is missing or
        its PNG-encoded size exceeds ``MAX_SIZE_MB``.
    """
    from html import escape  # local import so the block stands on its own

    if img is None:
        return None, "<div style='color:#b91c1c'>⚠️ Please upload an image.</div>"

    # Validate size
    size_mb = _size_mb_of_png(img)
    if size_mb > MAX_SIZE_MB:
        return None, f"<div style='color:#b91c1c'>⚠️ File too large: {size_mb:.2f} MB (limit {MAX_SIZE_MB} MB).</div>"

    # Slider values may arrive as floats; head() needs a positive integer.
    top_k = max(1, int(top_k))

    # Preprocess + predict. The temporary PNG only has to exist while the
    # predictor reads it, so remove it afterwards instead of leaking one file
    # per request.
    img_path = _pil_to_tmp(img, resize_size)
    try:
        proba_df = PREDICTOR.predict_proba(pd.DataFrame({"image": [img_path]}))
    finally:
        try:
            os.remove(img_path)
        except OSError:
            # Cleanup is best-effort; never mask the prediction result/error.
            pass

    ranked = proba_df.iloc[0].sort_values(ascending=False)
    # Map numeric indices to actual category names (unknown keys pass through).
    ranked.index = [CLASS_MAP.get(str(i), str(i)) for i in ranked.index]

    # Apply threshold + top-k, falling back to the raw ranking if the
    # threshold filters everything out.
    above_threshold = ranked[ranked > prob_threshold]
    top = (above_threshold if not above_threshold.empty else ranked).head(top_k)

    # Labels come from the model artifact; escape them before embedding in HTML.
    top_label = escape(str(top.index[0]))
    top_conf = float(top.iloc[0]) * 100

    # The markup lines are flush-left on purpose: they are part of the emitted
    # string and must not pick up the function's indentation.
    header = f"""
<div style="padding:20px;background:#f0f9ff;border-radius:12px;border-left:5px solid #3b82f6;">
<h2 style="color:#1e40af;margin:0 0 12px;">🔎 Prediction Results</h2>
<div style="background:#3b82f6;color:white;padding:15px;border-radius:10px;margin-bottom:15px;text-align:center;">
<div style="font-size:18px;">Predicted Sign</div>
<div style="font-size:36px;font-weight:800;letter-spacing:.3px;">{top_label}</div>
<div style="font-size:16px;opacity:.95;">Confidence: {top_conf:.1f}%</div>
</div>
<h4 style="color:#1e40af;margin:10px 0;">Top {len(top)} Predictions</h4>
<ul style="margin:0 0 10px 18px;color:#111827;">
"""
    items = "".join(
        f"<li><b>{escape(str(cls))}</b>: {prob*100:.1f}%</li>"
        for cls, prob in top.items()
    )
    return img, header + items + "</ul></div>"
# -----------------------------
# Gradio UI
# -----------------------------
# Page-wide stylesheet passed to gr.Blocks.
_PAGE_CSS = """
.gradio-container { font-family: 'Segoe UI', system-ui, -apple-system, Arial, sans-serif; }
"""

# Title and subtitle rendered above the controls.
_PAGE_HEADER = (
    "<h1 style='text-align:center;color:#1e40af;'>🚦 Traffic Sign Identifier</h1>"
    "<p style='text-align:center;color:#334155;'>Upload a traffic sign image to see predictions.</p>"
)

with gr.Blocks(css=_PAGE_CSS) as demo:
    gr.HTML(value=_PAGE_HEADER)

    with gr.Row():
        # Left column: image input and inference settings.
        with gr.Column():
            img_in = gr.Image(
                type="pil",
                image_mode="RGB",
                label="Upload Image",
                sources=["upload", "webcam"],
            )
            resize_size = gr.Slider(
                minimum=64, maximum=512, value=224, step=32, label="Resize Size (px)"
            )
            top_k = gr.Slider(
                minimum=1, maximum=10, value=3, step=1, label="Top-k Predictions"
            )
            prob_threshold = gr.Slider(
                minimum=0.0, maximum=0.9, value=0.05, step=0.01, label="Probability Threshold"
            )
            btn = gr.Button(value="🔍 Predict", variant="primary")

        # Right column: echo of the input image and the HTML result panel.
        with gr.Column():
            orig_out = gr.Image(label="Original Image", image_mode="RGB")
            res_out = gr.HTML(label="Results")

    # Wire the button to the inference function; argument order matches
    # predict_image's parameters and its (image, html) return value.
    btn.click(
        predict_image,
        [img_in, resize_size, top_k, prob_threshold],
        [orig_out, res_out],
    )

if __name__ == "__main__":
    demo.launch(share=True)