kevinkyi commited on
Commit
468b15b
·
verified ·
1 Parent(s): 937498b

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +152 -0
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, pathlib, zipfile, tempfile
2
+ import pandas as pd
3
+ import gradio as gr
4
+ from PIL import Image
5
+ from huggingface_hub import hf_hub_download, snapshot_download
6
+ import autogluon.multimodal as ag
7
+
8
+ MODEL_REPO_ID = "samder03/2025-24679-image-autogluon-predictor"
9
+
10
+ CLASS_LABELS = {
11
+ 0: "No Stop Sign",
12
+ 1: "Stop Sign",
13
+ "class_0": "No Stop Sign",
14
+ "class_1": "Stop Sign",
15
+ }
16
+
17
+ def _human_label(c):
18
+ try:
19
+ ci = int(c)
20
+ return CLASS_LABELS.get(ci, str(c))
21
+ except Exception:
22
+ return CLASS_LABELS.get(str(c), str(c))
23
+
24
+ def _locate_predictor_dir_from_repo_folder(repo_dir: str) -> str:
25
+ rd = pathlib.Path(repo_dir)
26
+ for p in rd.rglob("predictor.pkl"):
27
+ return str(p.parent)
28
+ return ""
29
+
30
+ def _prepare_predictor_dir() -> str:
31
+ repo_dir = snapshot_download(repo_id=MODEL_REPO_ID, repo_type="model")
32
+ pred_dir = _locate_predictor_dir_from_repo_folder(repo_dir)
33
+ if pred_dir:
34
+ return pred_dir
35
+
36
+ # Fallback: try to find a zip and extract
37
+ zips = list(pathlib.Path(repo_dir).rglob("*.zip"))
38
+ if not zips:
39
+ raise FileNotFoundError("Could not find a predictor directory or .zip in the model repo.")
40
+ zip_path = str(zips[0])
41
+
42
+ workdir = tempfile.mkdtemp(prefix="ag_img_predictor_")
43
+ with zipfile.ZipFile(zip_path, "r") as zf:
44
+ zf.extractall(workdir)
45
+ entries = list(pathlib.Path(workdir).iterdir())
46
+ if len(entries) == 1 and entries[0].is_dir():
47
+ return str(entries[0])
48
+ return workdir
49
+
50
+ PREDICTOR_DIR = _prepare_predictor_dir()
51
+ PREDICTOR = ag.MultiModalPredictor.load(PREDICTOR_DIR)
52
+
53
+ def _ensure_rgb(img: Image.Image) -> Image.Image:
54
+ return img.convert("RGB") if img.mode != "RGB" else img
55
+
56
+ def _resize_shorter(img: Image.Image, shorter: int) -> Image.Image:
57
+ w, h = img.size
58
+ if min(w, h) == shorter:
59
+ return img
60
+ if w < h:
61
+ new_w = shorter
62
+ new_h = int(h * (shorter / w))
63
+ else:
64
+ new_h = shorter
65
+ new_w = int(w * (shorter / h))
66
+ return img.resize((new_w, new_h), Image.BICUBIC)
67
+
68
+ def _center_crop(img: Image.Image, size: int) -> Image.Image:
69
+ w, h = img.size
70
+ side = min(w, h, size)
71
+ left = (w - side) // 2
72
+ top = (h - side) // 2
73
+ return img.crop((left, top, left + side, top + side)).resize((size, size), Image.BICUBIC)
74
+
75
+ def _validate_image(pil_img: Image.Image, max_pixels: int = 8_000_000):
76
+ if pil_img is None:
77
+ return False, "No image provided."
78
+ if pil_img.width * pil_img.height > max_pixels:
79
+ return False, f"Image too large (>{max_pixels:,} pixels). Please upload a smaller image."
80
+ return True, ""
81
+
82
+ def preprocess(pil_img: Image.Image, resize_shorter: int, do_center_crop: bool, crop_size: int) -> Image.Image:
83
+ img = _ensure_rgb(pil_img)
84
+ img = _resize_shorter(img, resize_shorter)
85
+ if do_center_crop:
86
+ img = _center_crop(img, crop_size)
87
+ return img
88
+
89
+ def do_predict(pil_img: Image.Image, resize_shorter: int, do_center_crop: bool, crop_size: int, top_k: int):
90
+ ok, msg = _validate_image(pil_img)
91
+ if not ok:
92
+ return None, None, {"Error": 1.0}
93
+
94
+ pre_img = preprocess(pil_img, resize_shorter, do_center_crop, crop_size)
95
+
96
+ tmpdir = pathlib.Path(tempfile.mkdtemp(prefix="ag_img_run_"))
97
+ orig_path = tmpdir / "original.png"
98
+ pre_path = tmpdir / "preprocessed.png"
99
+ pil_img.save(orig_path)
100
+ pre_img.save(pre_path)
101
+
102
+ df = pd.DataFrame({"image": [str(pre_path)]})
103
+ proba_df = PREDICTOR.predict_proba(df)
104
+ proba_df = proba_df.rename(columns={c: _human_label(c) for c in proba_df.columns})
105
+ row = proba_df.iloc[0]
106
+
107
+ items = sorted(row.items(), key=lambda kv: float(kv[1]), reverse=True)[:max(1, int(top_k))]
108
+ pretty = {k: float(v) for k, v in items}
109
+
110
+ return Image.open(orig_path), Image.open(pre_path), pretty
111
+
112
+ with gr.Blocks(title="Stop Sign Classifier") as demo:
113
+ gr.Markdown("# Stop Sign? — Image Classifier (Classmate Model)")
114
+ gr.Markdown(
115
+ "Upload a PNG/JPG (or use webcam). You’ll see the **original** and the **preprocessed** image, "
116
+ "plus ranked class probabilities."
117
+ )
118
+
119
+ with gr.Row():
120
+ image_in = gr.Image(type="pil", label="Input image (PNG/JPG)", sources=["upload", "webcam"])
121
+ with gr.Column():
122
+ with gr.Accordion("Inference Parameters", open=False):
123
+ resize_shorter = gr.Slider(64, 1024, value=384, step=16, label="Resize (shorter side)")
124
+ do_center_crop = gr.Checkbox(value=True, label="Center-crop to square")
125
+ crop_size = gr.Slider(64, 1024, value=384, step=16, label="Crop size (if center-crop)")
126
+ top_k = gr.Slider(1, 2, value=2, step=1, label="Top-K classes to display")
127
+
128
+ with gr.Row():
129
+ img_orig = gr.Image(label="Original")
130
+ img_proc = gr.Image(label="Preprocessed")
131
+ proba_pretty = gr.Label(num_top_classes=2, label="Class probabilities (Top-K)")
132
+
133
+ image_in.change(
134
+ fn=do_predict,
135
+ inputs=[image_in, resize_shorter, do_center_crop, crop_size, top_k],
136
+ outputs=[img_orig, img_proc, proba_pretty]
137
+ )
138
+
139
+ gr.Examples(
140
+ examples=[
141
+ ["https://upload.wikimedia.org/wikipedia/commons/thumb/2/23/Stop_sign_light_red.svg/640px-Stop_sign_light_red.svg.png"],
142
+ ["https://upload.wikimedia.org/wikipedia/commons/thumb/1/12/No_parking_sign.png/480px-No_parking_sign.png"],
143
+ ["https://upload.wikimedia.org/wikipedia/commons/thumb/4/4f/Stop_sign_in_the_United_States.jpg/640px-Stop_sign_in_the_United_States.jpg"],
144
+ ],
145
+ inputs=[image_in],
146
+ label="Example images",
147
+ examples_per_page=6,
148
+ cache_examples=False,
149
+ )
150
+
151
+ if __name__ == "__main__":
152
+ demo.launch()