maryzhang commited on
Commit
2a4b179
·
verified ·
1 Parent(s): a87c5f9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -0
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """sign_identifier_gradio_final.ipynb
3
+
4
+ Gradio interface for image classification using a classmate’s model.
5
+ Model: cassieli226/sign-identification-automl
6
+ """
7
+
8
+ # !pip install autogluon.multimodal gradio huggingface_hub pillow pandas --quiet
9
+
10
+ import os, pathlib, shutil, zipfile, tempfile, io
11
+ import pandas as pd
12
+ from PIL import Image
13
+ import gradio as gr
14
+ import huggingface_hub
15
+ from autogluon.multimodal import MultiModalPredictor
16
+
17
+ # -----------------------------
18
+ # Config
19
+ # -----------------------------
20
+ MODEL_REPO_ID = "cassieli226/sign-identification-automl"
21
+ ZIP_FILENAME = "autogluon_predictor_dir.zip"
22
+ CACHE_DIR = pathlib.Path("hf_assets")
23
+ EXTRACT_DIR = CACHE_DIR / "predictor_native"
24
+ MAX_SIZE_MB = 5
25
+
26
+ # -----------------------------
27
+ # Model loading
28
+ # -----------------------------
29
+ def prepare_predictor_dir() -> str:
30
+ CACHE_DIR.mkdir(parents=True, exist_ok=True)
31
+ local_zip = huggingface_hub.hf_hub_download(
32
+ repo_id=MODEL_REPO_ID,
33
+ filename=ZIP_FILENAME,
34
+ repo_type="model",
35
+ local_dir=str(CACHE_DIR),
36
+ local_dir_use_symlinks=False,
37
+ )
38
+ if EXTRACT_DIR.exists():
39
+ shutil.rmtree(EXTRACT_DIR)
40
+ EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
41
+ with zipfile.ZipFile(local_zip, "r") as zf:
42
+ zf.extractall(str(EXTRACT_DIR))
43
+ contents = list(EXTRACT_DIR.iterdir())
44
+ return str(contents[0]) if (len(contents) == 1 and contents[0].is_dir()) else str(EXTRACT_DIR)
45
+
46
+ print("Loading predictor...")
47
+ PREDICTOR_DIR = prepare_predictor_dir()
48
+ PREDICTOR = MultiModalPredictor.load(PREDICTOR_DIR)
49
+ print("✅ Model loaded!")
50
+
51
+ # Try to extract readable class names
52
+ try:
53
+ if hasattr(PREDICTOR, "label_generator") and hasattr(PREDICTOR.label_generator, "category_map"):
54
+ CLASS_MAP = {str(k): str(v) for k, v in PREDICTOR.label_generator.category_map.items()}
55
+ else:
56
+ CLASS_MAP = {str(i): str(lbl) for i, lbl in enumerate(PREDICTOR.class_labels)}
57
+ except Exception:
58
+ CLASS_MAP = {}
59
+ print("Class map:", CLASS_MAP)
60
+
61
+ # -----------------------------
62
+ # Helpers
63
+ # -----------------------------
64
+ def _pil_to_tmp(img: Image.Image, resize_size=224) -> str:
65
+ img = img.convert("RGB").resize((resize_size, resize_size))
66
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
67
+ img.save(tmp.name, format="PNG")
68
+ return tmp.name
69
+
70
+ def _size_mb_of_png(img: Image.Image) -> float:
71
+ buf = io.BytesIO()
72
+ img.save(buf, format="PNG")
73
+ return buf.tell() / (1024 * 1024)
74
+
75
+ # -----------------------------
76
+ # Inference
77
+ # -----------------------------
78
+ def predict_image(img, resize_size=224, top_k=3, prob_threshold=0.05):
79
+ if img is None:
80
+ return None, "<div style='color:#b91c1c'>⚠️ Please upload an image.</div>"
81
+
82
+ # Validate size
83
+ size_mb = _size_mb_of_png(img)
84
+ if size_mb > MAX_SIZE_MB:
85
+ return None, f"<div style='color:#b91c1c'>⚠️ File too large: {size_mb:.2f} MB (limit {MAX_SIZE_MB} MB).</div>"
86
+
87
+ # Preprocess
88
+ img_path = _pil_to_tmp(img, resize_size)
89
+ df = pd.DataFrame({"image": [img_path]})
90
+
91
+ # Predict probabilities
92
+ proba_df = PREDICTOR.predict_proba(df)
93
+ probs = proba_df.iloc[0].sort_values(ascending=False)
94
+
95
+ # Map numeric indices to actual category names
96
+ probs.index = [CLASS_MAP.get(str(i), str(i)) for i in probs.index]
97
+
98
+ # Apply threshold + top-k
99
+ filtered = probs[probs > prob_threshold]
100
+ top = filtered.head(top_k) if not filtered.empty else probs.head(top_k)
101
+
102
+ # Top-1
103
+ top_label = top.index[0]
104
+ top_conf = float(top.iloc[0]) * 100
105
+
106
+ # HTML result
107
+ html = f"""
108
+ <div style="padding:20px;background:#f0f9ff;border-radius:12px;border-left:5px solid #3b82f6;">
109
+ <h2 style="color:#1e40af;margin:0 0 12px;">🔎 Prediction Results</h2>
110
+ <div style="background:#3b82f6;color:white;padding:15px;border-radius:10px;margin-bottom:15px;text-align:center;">
111
+ <div style="font-size:18px;">Predicted Sign</div>
112
+ <div style="font-size:36px;font-weight:800;letter-spacing:.3px;">{top_label}</div>
113
+ <div style="font-size:16px;opacity:.95;">Confidence: {top_conf:.1f}%</div>
114
+ </div>
115
+ <h4 style="color:#1e40af;margin:10px 0;">Top {len(top)} Predictions</h4>
116
+ <ul style="margin:0 0 10px 18px;color:#111827;">
117
+ """
118
+ for cls, prob in top.items():
119
+ html += f"<li><b>{cls}</b>: {prob*100:.1f}%</li>"
120
+ html += "</ul></div>"
121
+
122
+ return img, html
123
+
124
+ # -----------------------------
125
+ # Gradio UI
126
+ # -----------------------------
127
+ with gr.Blocks(css="""
128
+ .gradio-container { font-family: 'Segoe UI', system-ui, -apple-system, Arial, sans-serif; }
129
+ """) as demo:
130
+ gr.HTML(
131
+ "<h1 style='text-align:center;color:#1e40af;'>🚦 Traffic Sign Identifier</h1>"
132
+ "<p style='text-align:center;color:#334155;'>Upload a traffic sign image to see predictions.</p>"
133
+ )
134
+
135
+ with gr.Row():
136
+ with gr.Column():
137
+ img_in = gr.Image(type="pil", image_mode="RGB", label="Upload Image", sources=["upload","webcam"])
138
+ resize_size = gr.Slider(64, 512, value=224, step=32, label="Resize Size (px)")
139
+ top_k = gr.Slider(1, 10, value=3, step=1, label="Top-k Predictions")
140
+ prob_threshold = gr.Slider(0.0, 0.9, value=0.05, step=0.01, label="Probability Threshold")
141
+ btn = gr.Button("🔍 Predict", variant="primary")
142
+ with gr.Column():
143
+ orig_out = gr.Image(label="Original Image", image_mode="RGB")
144
+ res_out = gr.HTML(label="Results")
145
+
146
+ btn.click(
147
+ fn=predict_image,
148
+ inputs=[img_in, resize_size, top_k, prob_threshold],
149
+ outputs=[orig_out, res_out],
150
+ )
151
+
152
+ if __name__ == "__main__":
153
+ demo.launch(share=True)