File size: 5,687 Bytes
2a4b179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# -*- coding: utf-8 -*-
"""sign_identifier_gradio_final.ipynb

Gradio interface for image classification using a classmate’s model.
Model: cassieli226/sign-identification-automl
"""

# !pip install autogluon.multimodal gradio huggingface_hub pillow pandas --quiet

import os, pathlib, shutil, zipfile, tempfile, io
import pandas as pd
from PIL import Image
import gradio as gr
import huggingface_hub
from autogluon.multimodal import MultiModalPredictor

# -----------------------------
# Config
# -----------------------------
# Hugging Face model repo and the zipped AutoGluon predictor file inside it.
MODEL_REPO_ID = "cassieli226/sign-identification-automl"
ZIP_FILENAME = "autogluon_predictor_dir.zip"

# Local download cache, and the sub-folder the archive is unpacked into.
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = pathlib.Path(CACHE_DIR, "predictor_native")

# Images whose PNG encoding exceeds this many megabytes are rejected.
MAX_SIZE_MB = 5

# -----------------------------
# Model loading
# -----------------------------
def _find_predictor_root(extract_dir: pathlib.Path) -> pathlib.Path:
    """Return the predictor directory inside an unpacked archive.

    Archives are often built around a single wrapper folder; when that is the
    case the wrapper itself is the predictor root. Archiver residue (macOS
    ``__MACOSX`` folders and dot-files such as ``.DS_Store``) is ignored while
    deciding this, so it cannot hide an otherwise-single wrapper folder.
    Otherwise *extract_dir* itself is returned.
    """
    entries = [
        entry
        for entry in extract_dir.iterdir()
        if entry.name != "__MACOSX" and not entry.name.startswith(".")
    ]
    if len(entries) == 1 and entries[0].is_dir():
        return entries[0]
    return extract_dir


def prepare_predictor_dir() -> str:
    """Download the zipped predictor from the Hub and unpack it locally.

    The archive ``ZIP_FILENAME`` from ``MODEL_REPO_ID`` is downloaded into
    ``CACHE_DIR`` and extracted into a freshly emptied ``EXTRACT_DIR``.

    Returns:
        The directory (as ``str``) to hand to ``MultiModalPredictor.load``:
        the archive's single top-level folder if it has one, otherwise
        ``EXTRACT_DIR`` itself.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    # `local_dir_use_symlinks` is deprecated (and ignored) in current
    # huggingface_hub releases, so it is no longer passed.
    local_zip = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
    )
    # Start from an empty directory so stale files from an earlier extraction
    # cannot be mixed in with the new ones.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(local_zip, "r") as zf:
        zf.extractall(str(EXTRACT_DIR))
    return str(_find_predictor_root(EXTRACT_DIR))

# Download/unpack and load the predictor once, at import time; every request
# handled below reuses this single PREDICTOR instance.
print("Loading predictor...")
PREDICTOR_DIR = prepare_predictor_dir()
PREDICTOR = MultiModalPredictor.load(PREDICTOR_DIR)
print("✅ Model loaded!")

# Build CLASS_MAP: str(prediction key) -> readable class name. It prefers a
# `label_generator.category_map` if the predictor exposes one, and otherwise
# maps positions in `class_labels` to the labels themselves.
# Best-effort: if neither source is usable, predictions fall back to raw keys,
# but the reason is reported instead of being silently swallowed.
try:
    _label_generator = getattr(PREDICTOR, "label_generator", None)
    if hasattr(_label_generator, "category_map"):
        CLASS_MAP = {str(k): str(v) for k, v in _label_generator.category_map.items()}
    else:
        CLASS_MAP = {str(i): str(lbl) for i, lbl in enumerate(PREDICTOR.class_labels)}
except Exception as err:  # deliberate top-level best-effort: log, then degrade
    print(f"⚠️ Could not build class map ({type(err).__name__}: {err}); showing raw labels.")
    CLASS_MAP = {}
print("Class map:", CLASS_MAP)

# -----------------------------
# Helpers
# -----------------------------
def _pil_to_tmp(img: Image.Image, resize_size=224) -> str:
    """Write *img* as a square RGB PNG to a new temporary file; return its path.

    Args:
        img: Source image; it is converted to RGB and resized (aspect ratio is
            not preserved) to ``resize_size`` x ``resize_size`` pixels.
        resize_size: Side length of the square output. Floats (e.g. from a
            Gradio slider) are truncated to ``int`` because PIL needs ints.

    Returns:
        Path of the PNG. The file is created with ``delete=False`` so it
        outlives this call; the caller owns it and should remove it.
    """
    side = int(resize_size)
    square = img.convert("RGB").resize((side, side))
    # Write through the handle and close it on exit: no descriptor is leaked,
    # and the path can be reopened by the predictor (required on Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        square.save(tmp, format="PNG")
    return tmp.name

def _size_mb_of_png(img: Image.Image) -> float:
    """Measure how large *img* would be once encoded as PNG, in megabytes.

    The encoding happens entirely in memory; the byte count written is
    divided by 1024**2.
    """
    bytes_per_mb = 1024 ** 2
    with io.BytesIO() as sink:
        img.save(sink, format="PNG")
        written = sink.tell()
    return written / bytes_per_mb

# -----------------------------
# Inference
# -----------------------------
def _format_results_html(top: pd.Series) -> str:
    """Render ranked predictions as the HTML result card.

    Args:
        top: Non-empty Series mapping display label -> probability (a fraction,
            shown as a percentage), already sorted best-first.

    Returns:
        HTML with a headline for the top-1 label and its confidence, followed
        by a bullet list of every entry in *top*.
    """
    from html import escape  # labels come from the model; never embed raw markup

    top_label = escape(str(top.index[0]))
    top_conf = float(top.iloc[0]) * 100
    rows = "".join(
        f"<li><b>{escape(str(cls))}</b>: {prob*100:.1f}%</li>" for cls, prob in top.items()
    )
    header = f"""
    <div style="padding:20px;background:#f0f9ff;border-radius:12px;border-left:5px solid #3b82f6;">
        <h2 style="color:#1e40af;margin:0 0 12px;">🔎 Prediction Results</h2>
        <div style="background:#3b82f6;color:white;padding:15px;border-radius:10px;margin-bottom:15px;text-align:center;">
            <div style="font-size:18px;">Predicted Sign</div>
            <div style="font-size:36px;font-weight:800;letter-spacing:.3px;">{top_label}</div>
            <div style="font-size:16px;opacity:.95;">Confidence: {top_conf:.1f}%</div>
        </div>
        <h4 style="color:#1e40af;margin:10px 0;">Top {len(top)} Predictions</h4>
        <ul style="margin:0 0 10px 18px;color:#111827;">
    """
    return header + rows + "</ul></div>"


def predict_image(img, resize_size=224, top_k=3, prob_threshold=0.05):
    """Classify one uploaded image and build the HTML result card.

    Args:
        img: PIL image from the Gradio input, or ``None`` if nothing was uploaded.
        resize_size: Side length (px) of the square the image is resized to
            before inference. Floats are truncated to ``int``.
        top_k: Maximum number of predictions to list (at least 1).
        prob_threshold: Only classes with probability strictly above this are
            listed; if none qualify, the unfiltered top-k is shown instead.

    Returns:
        ``(image, html)``: the uploaded image echoed back (``None`` on
        validation errors) and an HTML snippet with results or an error message.
    """
    if img is None:
        return None, "<div style='color:#b91c1c'>⚠️ Please upload an image.</div>"

    # Reject images whose PNG encoding exceeds the configured size limit.
    size_mb = _size_mb_of_png(img)
    if size_mb > MAX_SIZE_MB:
        return None, f"<div style='color:#b91c1c'>⚠️ File too large: {size_mb:.2f} MB (limit {MAX_SIZE_MB} MB).</div>"

    # Gradio sliders may deliver floats; PIL sizes and Series.head() need ints,
    # and top_k < 1 would leave nothing to report as the top-1 prediction.
    resize_size = int(resize_size)
    top_k = max(1, int(top_k))

    # The predictor is fed a file path, so round-trip through a temporary PNG
    # and remove it as soon as the prediction has finished.
    img_path = _pil_to_tmp(img, resize_size)
    try:
        proba_df = PREDICTOR.predict_proba(pd.DataFrame({"image": [img_path]}))
    finally:
        try:
            os.remove(img_path)
        except OSError:
            pass  # best-effort cleanup; must not mask the prediction outcome

    # Probabilities for the single input row, best first, keyed by readable names.
    probs = proba_df.iloc[0].sort_values(ascending=False)
    probs.index = [CLASS_MAP.get(str(key), str(key)) for key in probs.index]

    # Apply threshold + top-k, falling back to the raw ranking if nothing passes.
    above = probs[probs > prob_threshold]
    top = (probs if above.empty else above).head(top_k)

    return img, _format_results_html(top)

# -----------------------------
# Gradio UI
# -----------------------------
# Page layout: a header, then a row with two columns. The left column holds the
# upload widget and prediction controls; the right column holds the echoed
# image and the HTML result card. The custom CSS only sets the font stack.
with gr.Blocks(css="""
    .gradio-container { font-family: 'Segoe UI', system-ui, -apple-system, Arial, sans-serif; }
""") as demo:
    gr.HTML(
        "<h1 style='text-align:center;color:#1e40af;'>🚦 Traffic Sign Identifier</h1>"
        "<p style='text-align:center;color:#334155;'>Upload a traffic sign image to see predictions.</p>"
    )

    with gr.Row():
        # Left column: inputs. NOTE: resize_size / top_k / prob_threshold here
        # are Gradio components (module-level names), not the plain values
        # predict_image() receives as its parameters of the same names.
        with gr.Column():
            img_in = gr.Image(type="pil", image_mode="RGB", label="Upload Image", sources=["upload","webcam"])
            resize_size = gr.Slider(64, 512, value=224, step=32, label="Resize Size (px)")
            top_k = gr.Slider(1, 10, value=3, step=1, label="Top-k Predictions")
            prob_threshold = gr.Slider(0.0, 0.9, value=0.05, step=0.01, label="Probability Threshold")
            btn = gr.Button("🔍 Predict", variant="primary")
        # Right column: outputs. predict_image() returns (img, html), which
        # fills these two components in order.
        with gr.Column():
            orig_out = gr.Image(label="Original Image", image_mode="RGB")
            res_out = gr.HTML(label="Results")

    # Inputs map positionally onto predict_image(img, resize_size, top_k, prob_threshold).
    btn.click(
        fn=predict_image,
        inputs=[img_in, resize_size, top_k, prob_threshold],
        outputs=[orig_out, res_out],
    )

if __name__ == "__main__":
    # share=True requests a public Gradio share link in addition to the local server.
    demo.launch(share=True)