# NOTE: removed Hugging Face Spaces page residue that was accidentally
# captured with this file (status badges, file size, commit hash, and
# the line-number gutter) — it is not part of the program.
# -*- coding: utf-8 -*-
"""sign_identifier_gradio_final.ipynb
Gradio interface for image classification using a classmate’s model.
Model: cassieli226/sign-identification-automl
"""
# !pip install autogluon.multimodal gradio huggingface_hub pillow pandas --quiet
import os, pathlib, shutil, zipfile, tempfile, io
import pandas as pd
from PIL import Image
import gradio as gr
import huggingface_hub
from autogluon.multimodal import MultiModalPredictor
# -----------------------------
# Config
# -----------------------------
# Hugging Face Hub repo that hosts the zipped AutoGluon predictor.
MODEL_REPO_ID = "cassieli226/sign-identification-automl"
# Name of the zip asset inside that repo.
ZIP_FILENAME = "autogluon_predictor_dir.zip"
# Local download cache for Hub assets.
CACHE_DIR = pathlib.Path("hf_assets")
# Directory the zip is unpacked into (recreated on every startup).
EXTRACT_DIR = CACHE_DIR / "predictor_native"
# Uploads whose PNG encoding exceeds this size are rejected.
MAX_SIZE_MB = 5
# -----------------------------
# Model loading
# -----------------------------
def prepare_predictor_dir() -> str:
    """Download and unpack the predictor archive from the Hub.

    Returns the directory to pass to ``MultiModalPredictor.load``: if the
    zip contains exactly one top-level folder, that folder; otherwise the
    extraction root itself.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    # hf_hub_download caches under CACHE_DIR and skips the download when the
    # file is already present. The deprecated ``local_dir_use_symlinks``
    # kwarg was dropped — recent huggingface_hub versions ignore it and
    # emit a FutureWarning.
    local_zip = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
    )
    # Re-extract from scratch so a stale or partial unpack is never loaded.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(local_zip, "r") as zf:
        zf.extractall(str(EXTRACT_DIR))
    # Some archives wrap the predictor in a single top-level directory;
    # unwrap it so the caller always gets the predictor root.
    contents = list(EXTRACT_DIR.iterdir())
    return str(contents[0]) if (len(contents) == 1 and contents[0].is_dir()) else str(EXTRACT_DIR)
# Download, unzip, and load the predictor once at import time so request
# handling never pays the model-loading cost.
print("Loading predictor...")
PREDICTOR_DIR = prepare_predictor_dir()
PREDICTOR = MultiModalPredictor.load(PREDICTOR_DIR)
print("✅ Model loaded!")
# Try to extract readable class names
# CLASS_MAP: {stringified class index -> human-readable label}. Prefer the
# predictor's label_generator.category_map; fall back to enumerating
# class_labels; fall back to {} (predictions then display raw indices).
try:
    if hasattr(PREDICTOR, "label_generator") and hasattr(PREDICTOR.label_generator, "category_map"):
        CLASS_MAP = {str(k): str(v) for k, v in PREDICTOR.label_generator.category_map.items()}
    else:
        CLASS_MAP = {str(i): str(lbl) for i, lbl in enumerate(PREDICTOR.class_labels)}
except Exception:
    # Broad catch is deliberate: readable names are cosmetic, and the app
    # should still start if the predictor lacks either attribute.
    CLASS_MAP = {}
print("Class map:", CLASS_MAP)
# -----------------------------
# Helpers
# -----------------------------
def _pil_to_tmp(img: Image.Image, resize_size: int = 224) -> str:
    """Resize *img* to a square and write it to a temporary PNG.

    Returns the file path. The caller is responsible for deleting the
    file when the prediction is done (see predict_image).
    """
    img = img.convert("RGB").resize((resize_size, resize_size))
    # mkstemp + explicit close instead of NamedTemporaryFile(delete=False):
    # the previous version never closed the handle (fd leak), and an open
    # handle also prevents re-opening the path by name on Windows.
    fd, path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    img.save(path, format="PNG")
    return path
def _size_mb_of_png(img: Image.Image) -> float:
    """Return the size in megabytes of *img* once encoded as PNG."""
    with io.BytesIO() as sink:
        img.save(sink, format="PNG")
        n_bytes = sink.tell()
    return n_bytes / (1024 * 1024)
# -----------------------------
# Inference
# -----------------------------
def predict_image(img, resize_size=224, top_k=3, prob_threshold=0.05):
    """Classify *img* and return ``(original_image, results_html)``.

    Args:
        img: PIL image from the Gradio component, or None if nothing
            was uploaded (returns an error panel in that case).
        resize_size: square side length the image is resized to.
        top_k: maximum number of predictions to display.
        prob_threshold: classes at or below this probability are hidden
            (unless nothing clears it, in which case top_k are shown).
    """
    if img is None:
        return None, "<div style='color:#b91c1c'>⚠️ Please upload an image.</div>"
    # Gradio sliders may deliver floats; pandas .head() and PIL .resize()
    # need real ints.
    resize_size = int(resize_size)
    top_k = int(top_k)
    # Validate size
    size_mb = _size_mb_of_png(img)
    if size_mb > MAX_SIZE_MB:
        return None, f"<div style='color:#b91c1c'>⚠️ File too large: {size_mb:.2f} MB (limit {MAX_SIZE_MB} MB).</div>"
    # Preprocess: AutoGluon's multimodal predictor takes image *paths*,
    # so the upload is staged to a temp file for the duration of the call.
    img_path = _pil_to_tmp(img, resize_size)
    try:
        df = pd.DataFrame({"image": [img_path]})
        # Predict probabilities
        proba_df = PREDICTOR.predict_proba(df)
    finally:
        # Previously the temp PNG was leaked on every request; remove it
        # as soon as the predictor has consumed it.
        try:
            os.remove(img_path)
        except OSError:
            pass
    probs = proba_df.iloc[0].sort_values(ascending=False)
    # Map numeric indices to actual category names
    probs.index = [CLASS_MAP.get(str(i), str(i)) for i in probs.index]
    # Apply threshold + top-k; if the threshold filters everything out,
    # fall back to the unfiltered top-k so the panel is never empty.
    filtered = probs[probs > prob_threshold]
    top = filtered.head(top_k) if not filtered.empty else probs.head(top_k)
    # Top-1
    top_label = top.index[0]
    top_conf = float(top.iloc[0]) * 100
    # HTML result
    html = f"""
    <div style="padding:20px;background:#f0f9ff;border-radius:12px;border-left:5px solid #3b82f6;">
      <h2 style="color:#1e40af;margin:0 0 12px;">🔎 Prediction Results</h2>
      <div style="background:#3b82f6;color:white;padding:15px;border-radius:10px;margin-bottom:15px;text-align:center;">
        <div style="font-size:18px;">Predicted Sign</div>
        <div style="font-size:36px;font-weight:800;letter-spacing:.3px;">{top_label}</div>
        <div style="font-size:16px;opacity:.95;">Confidence: {top_conf:.1f}%</div>
      </div>
      <h4 style="color:#1e40af;margin:10px 0;">Top {len(top)} Predictions</h4>
      <ul style="margin:0 0 10px 18px;color:#111827;">
    """
    for cls, prob in top.items():
        html += f"<li><b>{cls}</b>: {prob*100:.1f}%</li>"
    html += "</ul></div>"
    return img, html
# -----------------------------
# Gradio UI
# -----------------------------
# Assemble the Gradio UI: upload + tuning sliders on the left, the echoed
# original image and the HTML results panel on the right.
with gr.Blocks(css="""
.gradio-container { font-family: 'Segoe UI', system-ui, -apple-system, Arial, sans-serif; }
""") as demo:
    gr.HTML(
        "<h1 style='text-align:center;color:#1e40af;'>🚦 Traffic Sign Identifier</h1>"
        "<p style='text-align:center;color:#334155;'>Upload a traffic sign image to see predictions.</p>"
    )
    with gr.Row():
        with gr.Column():
            # Inputs mirror predict_image's parameters in order.
            img_in = gr.Image(type="pil", image_mode="RGB", label="Upload Image", sources=["upload","webcam"])
            resize_size = gr.Slider(64, 512, value=224, step=32, label="Resize Size (px)")
            top_k = gr.Slider(1, 10, value=3, step=1, label="Top-k Predictions")
            prob_threshold = gr.Slider(0.0, 0.9, value=0.05, step=0.01, label="Probability Threshold")
            btn = gr.Button("🔍 Predict", variant="primary")
        with gr.Column():
            # Outputs match predict_image's (image, html) return tuple.
            orig_out = gr.Image(label="Original Image", image_mode="RGB")
            res_out = gr.HTML(label="Results")
    btn.click(
        fn=predict_image,
        inputs=[img_in, resize_size, top_k, prob_threshold],
        outputs=[orig_out, res_out],
    )
if __name__ == "__main__":
    # share=True also publishes a temporary public gradio.live URL in
    # addition to the local server.
    demo.launch(share=True)