Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*-
"""sign_identifier_gradio_final.ipynb

Gradio interface for image classification using a classmate’s model.
Model: cassieli226/sign-identification-automl
"""
# !pip install autogluon.multimodal gradio huggingface_hub pillow pandas --quiet
import html
import io
import os
import pathlib
import shutil
import tempfile
import zipfile

import gradio as gr
import huggingface_hub
import pandas as pd
from autogluon.multimodal import MultiModalPredictor
from PIL import Image
| # ----------------------------- | |
| # Config | |
| # ----------------------------- | |
# Hugging Face model repository that hosts the zipped predictor.
MODEL_REPO_ID = "cassieli226/sign-identification-automl"
# Name of the archive file fetched from that repository.
ZIP_FILENAME = "autogluon_predictor_dir.zip"
# Local directory the archive is downloaded into.
CACHE_DIR = pathlib.Path("hf_assets")
# Directory (inside CACHE_DIR) the archive is unpacked into.
EXTRACT_DIR = CACHE_DIR.joinpath("predictor_native")
# Largest accepted image, in MB, measured on its in-memory PNG encoding.
MAX_SIZE_MB = 5
| # ----------------------------- | |
| # Model loading | |
| # ----------------------------- | |
def prepare_predictor_dir() -> str:
    """Download the predictor archive and unpack it into a fresh directory.

    The archive ZIP_FILENAME is fetched from MODEL_REPO_ID into CACHE_DIR,
    any previous EXTRACT_DIR is removed, and the archive is extracted there.

    Returns:
        The path of the single top-level directory inside the archive when
        the archive contains exactly one directory and nothing else;
        otherwise the path of EXTRACT_DIR itself.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    archive_path = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
        local_dir_use_symlinks=False,
    )

    # Start from an empty extraction directory so files from an earlier
    # run never mix with the freshly extracted ones.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)

    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(str(EXTRACT_DIR))

    # An archive that wraps everything in one folder is unwrapped so the
    # returned path points at the folder's contents.
    entries = list(EXTRACT_DIR.iterdir())
    if len(entries) == 1 and entries[0].is_dir():
        return str(entries[0])
    return str(EXTRACT_DIR)
def _build_class_map(predictor) -> dict:
    """Return a ``{str(key): str(label)}`` mapping of the predictor's classes.

    Uses ``predictor.label_generator.category_map`` when both attributes
    exist; otherwise enumerates ``predictor.class_labels``. Any failure
    while reading either source yields an empty mapping (best effort).
    """
    try:
        has_category_map = hasattr(predictor, "label_generator") and hasattr(
            predictor.label_generator, "category_map"
        )
        if has_category_map:
            pairs = predictor.label_generator.category_map.items()
        else:
            pairs = enumerate(predictor.class_labels)
        return {str(key): str(label) for key, label in pairs}
    except Exception:
        return {}


# The model is downloaded and loaded once, at import time.
print("Loading predictor...")
PREDICTOR_DIR = prepare_predictor_dir()
PREDICTOR = MultiModalPredictor.load(PREDICTOR_DIR)
print("✅ Model loaded!")

# Try to extract readable class names
CLASS_MAP = _build_class_map(PREDICTOR)
print("Class map:", CLASS_MAP)
| # ----------------------------- | |
| # Helpers | |
| # ----------------------------- | |
def _pil_to_tmp(img: Image.Image, resize_size=224) -> str:
    """Write a resized RGB copy of *img* to a temporary PNG and return its path.

    Args:
        img: Source image.
        resize_size: Edge length, in pixels, of the square output image.
            Coerced to ``int`` because PIL's ``resize`` needs integer
            dimensions and the value may arrive as a float from the UI.

    Returns:
        Path of the written PNG. The file is created with ``delete=False``,
        so the caller is responsible for removing it.
    """
    edge = int(resize_size)
    # The image is stretched to a square; the aspect ratio is not preserved.
    resized = img.convert("RGB").resize((edge, edge))

    # Reserve a unique file name, then close the handle before saving.
    # Leaving it open leaks a file descriptor per call and prevents the
    # path from being reopened on platforms such as Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        tmp_path = tmp.name

    resized.save(tmp_path, format="PNG")
    return tmp_path
def _size_mb_of_png(img: Image.Image) -> float:
    """Return the size, in MB (1024 * 1024 bytes), of *img* encoded as PNG.

    The image is encoded into an in-memory buffer, so the result reflects
    the PNG re-encoding rather than the size of any file it came from.
    """
    bytes_per_mb = 1024 * 1024
    with io.BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        # The stream position after saving is the number of bytes written.
        encoded_bytes = png_buffer.tell()
    return encoded_bytes / bytes_per_mb
| # ----------------------------- | |
| # Inference | |
| # ----------------------------- | |
def predict_image(img, resize_size=224, top_k=3, prob_threshold=0.05):
    """Classify an uploaded image and render the best predictions as HTML.

    Args:
        img: PIL image from the upload component, or ``None`` when nothing
            was uploaded.
        resize_size: Edge length in pixels the image is resized to before
            inference.
        top_k: Maximum number of classes to list.
        prob_threshold: A class is listed only when its probability is
            strictly greater than this value. If no class qualifies, the
            ``top_k`` most probable classes are listed instead.

    Returns:
        ``(image, html)``: the unmodified input image and an HTML result
        card, or ``(None, html)`` with a warning when the input is missing
        or its PNG encoding exceeds MAX_SIZE_MB.
    """
    if img is None:
        return None, "<div style='color:#b91c1c'>⚠️ Please upload an image.</div>"

    # Validate size
    size_mb = _size_mb_of_png(img)
    if size_mb > MAX_SIZE_MB:
        return None, f"<div style='color:#b91c1c'>⚠️ File too large: {size_mb:.2f} MB (limit {MAX_SIZE_MB} MB).</div>"

    # UI controls may deliver floats; head() and the resize need integers.
    top_k = int(top_k)

    # Preprocess into a temporary PNG and predict probabilities. The file
    # is only needed for the duration of the prediction call, so it is
    # removed afterwards instead of accumulating one file per request.
    img_path = _pil_to_tmp(img, int(resize_size))
    try:
        proba_df = PREDICTOR.predict_proba(pd.DataFrame({"image": [img_path]}))
    finally:
        try:
            os.remove(img_path)
        except OSError:
            # Best-effort cleanup: a failed delete must not mask the
            # prediction result or the prediction error.
            pass

    probs = proba_df.iloc[0].sort_values(ascending=False)

    # Map numeric indices to actual category names
    probs.index = [CLASS_MAP.get(str(i), str(i)) for i in probs.index]

    # Apply threshold + top-k
    filtered = probs[probs > prob_threshold]
    top = filtered.head(top_k) if not filtered.empty else probs.head(top_k)

    # Top-1. Labels come from the downloaded model, so they are escaped
    # before being placed into markup.
    top_label = html.escape(str(top.index[0]))
    top_conf = float(top.iloc[0]) * 100

    list_items = "".join(
        f"<li><b>{html.escape(str(cls))}</b>: {prob*100:.1f}%</li>"
        for cls, prob in top.items()
    )

    # HTML result
    markup = f"""
    <div style="padding:20px;background:#f0f9ff;border-radius:12px;border-left:5px solid #3b82f6;">
    <h2 style="color:#1e40af;margin:0 0 12px;">🔎 Prediction Results</h2>
    <div style="background:#3b82f6;color:white;padding:15px;border-radius:10px;margin-bottom:15px;text-align:center;">
    <div style="font-size:18px;">Predicted Sign</div>
    <div style="font-size:36px;font-weight:800;letter-spacing:.3px;">{top_label}</div>
    <div style="font-size:16px;opacity:.95;">Confidence: {top_conf:.1f}%</div>
    </div>
    <h4 style="color:#1e40af;margin:10px 0;">Top {len(top)} Predictions</h4>
    <ul style="margin:0 0 10px 18px;color:#111827;">
    """
    markup += list_items + "</ul></div>"
    return img, markup
| # ----------------------------- | |
| # Gradio UI | |
| # ----------------------------- | |
with gr.Blocks(
    css="""
.gradio-container { font-family: 'Segoe UI', system-ui, -apple-system, Arial, sans-serif; }
"""
) as demo:
    # Page title and one-line instructions.
    gr.HTML(
        "<h1 style='text-align:center;color:#1e40af;'>🚦 Traffic Sign Identifier</h1>"
        "<p style='text-align:center;color:#334155;'>Upload a traffic sign image to see predictions.</p>"
    )

    with gr.Row():
        # Left column: the input image and the inference settings.
        with gr.Column():
            img_in = gr.Image(
                type="pil",
                image_mode="RGB",
                label="Upload Image",
                sources=["upload", "webcam"],
            )
            resize_size = gr.Slider(
                minimum=64,
                maximum=512,
                value=224,
                step=32,
                label="Resize Size (px)",
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=10,
                value=3,
                step=1,
                label="Top-k Predictions",
            )
            prob_threshold = gr.Slider(
                minimum=0.0,
                maximum=0.9,
                value=0.05,
                step=0.01,
                label="Probability Threshold",
            )
            btn = gr.Button("🔍 Predict", variant="primary")

        # Right column: the echoed image and the HTML result card.
        with gr.Column():
            orig_out = gr.Image(label="Original Image", image_mode="RGB")
            res_out = gr.HTML(label="Results")

    # Input order matches the positional parameters of predict_image;
    # its two return values fill the two output components in order.
    btn.click(
        fn=predict_image,
        inputs=[img_in, resize_size, top_k, prob_threshold],
        outputs=[orig_out, res_out],
    )

if __name__ == "__main__":
    demo.launch(share=True)