Spaces:
Sleeping
Sleeping
Upload app.py with huggingface_hub
Browse files
app.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, pathlib, zipfile, tempfile
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import gradio as gr
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
| 6 |
+
import autogluon.multimodal as ag
|
| 7 |
+
|
| 8 |
+
MODEL_REPO_ID = "samder03/2025-24679-image-autogluon-predictor"
|
| 9 |
+
|
| 10 |
+
CLASS_LABELS = {
|
| 11 |
+
0: "No Stop Sign",
|
| 12 |
+
1: "Stop Sign",
|
| 13 |
+
"class_0": "No Stop Sign",
|
| 14 |
+
"class_1": "Stop Sign",
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
def _human_label(c):
|
| 18 |
+
try:
|
| 19 |
+
ci = int(c)
|
| 20 |
+
return CLASS_LABELS.get(ci, str(c))
|
| 21 |
+
except Exception:
|
| 22 |
+
return CLASS_LABELS.get(str(c), str(c))
|
| 23 |
+
|
| 24 |
+
def _locate_predictor_dir_from_repo_folder(repo_dir: str) -> str:
|
| 25 |
+
rd = pathlib.Path(repo_dir)
|
| 26 |
+
for p in rd.rglob("predictor.pkl"):
|
| 27 |
+
return str(p.parent)
|
| 28 |
+
return ""
|
| 29 |
+
|
| 30 |
+
def _prepare_predictor_dir() -> str:
    """Download the model repo from the Hub and return a local directory that
    contains the AutoGluon predictor files.

    Raises:
        FileNotFoundError: when the repo contains neither a predictor
            directory nor a .zip archive to extract one from.
    """
    snapshot = snapshot_download(repo_id=MODEL_REPO_ID, repo_type="model")

    # Preferred layout: the predictor files are checked into the repo as-is.
    found = _locate_predictor_dir_from_repo_folder(snapshot)
    if found:
        return found

    # Fallback: try to find a zip and extract
    archives = list(pathlib.Path(snapshot).rglob("*.zip"))
    if not archives:
        raise FileNotFoundError("Could not find a predictor directory or .zip in the model repo.")

    extract_dir = tempfile.mkdtemp(prefix="ag_img_predictor_")
    with zipfile.ZipFile(str(archives[0]), "r") as archive:
        archive.extractall(extract_dir)

    # If the archive wrapped everything in one top-level folder, descend into it.
    contents = list(pathlib.Path(extract_dir).iterdir())
    if len(contents) == 1 and contents[0].is_dir():
        return str(contents[0])
    return extract_dir
|
| 49 |
+
|
| 50 |
+
# Resolve (download / extract) the model artifacts once at import time, then
# load the AutoGluon predictor into module scope so every request reuses the
# same instance instead of reloading the model per call.
PREDICTOR_DIR = _prepare_predictor_dir()
PREDICTOR = ag.MultiModalPredictor.load(PREDICTOR_DIR)
|
| 52 |
+
|
| 53 |
+
def _ensure_rgb(img: Image.Image) -> Image.Image:
|
| 54 |
+
return img.convert("RGB") if img.mode != "RGB" else img
|
| 55 |
+
|
| 56 |
+
def _resize_shorter(img: Image.Image, shorter: int) -> Image.Image:
|
| 57 |
+
w, h = img.size
|
| 58 |
+
if min(w, h) == shorter:
|
| 59 |
+
return img
|
| 60 |
+
if w < h:
|
| 61 |
+
new_w = shorter
|
| 62 |
+
new_h = int(h * (shorter / w))
|
| 63 |
+
else:
|
| 64 |
+
new_h = shorter
|
| 65 |
+
new_w = int(w * (shorter / h))
|
| 66 |
+
return img.resize((new_w, new_h), Image.BICUBIC)
|
| 67 |
+
|
| 68 |
+
def _center_crop(img: Image.Image, size: int) -> Image.Image:
    """Crop the largest centered square (capped at `size`) and resample it to
    size x size."""
    width, height = img.size
    side = min(width, height, size)
    # Center the square within the original image.
    left = (width - side) // 2
    top = (height - side) // 2
    square = img.crop((left, top, left + side, top + side))
    # NOTE: when the source is smaller than `size`, this upscales the crop.
    return square.resize((size, size), Image.BICUBIC)
|
| 74 |
+
|
| 75 |
+
def _validate_image(pil_img: Image.Image, max_pixels: int = 8_000_000):
|
| 76 |
+
if pil_img is None:
|
| 77 |
+
return False, "No image provided."
|
| 78 |
+
if pil_img.width * pil_img.height > max_pixels:
|
| 79 |
+
return False, f"Image too large (>{max_pixels:,} pixels). Please upload a smaller image."
|
| 80 |
+
return True, ""
|
| 81 |
+
|
| 82 |
+
def preprocess(pil_img: Image.Image, resize_shorter: int, do_center_crop: bool, crop_size: int) -> Image.Image:
    """Apply the inference-time pipeline: RGB conversion, shorter-side
    resize, and (optionally) a center crop to a square."""
    out = _resize_shorter(_ensure_rgb(pil_img), resize_shorter)
    return _center_crop(out, crop_size) if do_center_crop else out
|
| 88 |
+
|
| 89 |
+
def do_predict(pil_img: Image.Image, resize_shorter: int, do_center_crop: bool, crop_size: int, top_k: int):
    """Run the classifier on one uploaded image.

    Returns a 3-tuple for the Gradio outputs: (original image,
    preprocessed image, {label: probability} dict of the top-k classes).
    On invalid input returns (None, None, {error message: 1.0}).
    """
    ok, msg = _validate_image(pil_img)
    if not ok:
        # Surface the specific validation message; previously `msg` was
        # computed and then discarded in favor of a generic "Error" label.
        return None, None, {f"Error: {msg}": 1.0}

    pre_img = preprocess(pil_img, resize_shorter, do_center_crop, crop_size)

    # The predictor consumes file paths, so write the preprocessed image to a
    # temp file; the directory is removed as soon as inference finishes
    # (the old mkdtemp-per-call approach leaked a directory on every request).
    with tempfile.TemporaryDirectory(prefix="ag_img_run_") as tmpdir:
        pre_path = pathlib.Path(tmpdir) / "preprocessed.png"
        pre_img.save(pre_path)
        df = pd.DataFrame({"image": [str(pre_path)]})
        proba_df = PREDICTOR.predict_proba(df)

    # Map raw class columns to human-readable labels, then rank by probability.
    proba_df = proba_df.rename(columns={c: _human_label(c) for c in proba_df.columns})
    row = proba_df.iloc[0]
    items = sorted(row.items(), key=lambda kv: float(kv[1]), reverse=True)[:max(1, int(top_k))]
    pretty = {k: float(v) for k, v in items}

    # Return the in-memory images directly instead of re-reading them from disk.
    return pil_img, pre_img, pretty
|
| 111 |
+
|
| 112 |
+
# Build the Gradio UI: an input column (image + tunable preprocessing
# parameters) and an output row (original, preprocessed, probabilities).
with gr.Blocks(title="Stop Sign Classifier") as demo:
    gr.Markdown("# Stop Sign? — Image Classifier (Classmate Model)")
    gr.Markdown(
        "Upload a PNG/JPG (or use webcam). You’ll see the **original** and the **preprocessed** image, "
        "plus ranked class probabilities."
    )

    with gr.Row():
        image_in = gr.Image(type="pil", label="Input image (PNG/JPG)", sources=["upload", "webcam"])
        with gr.Column():
            # Preprocessing knobs, collapsed by default; defaults mirror a
            # typical 384px shorter-side resize + square center crop.
            with gr.Accordion("Inference Parameters", open=False):
                resize_shorter = gr.Slider(64, 1024, value=384, step=16, label="Resize (shorter side)")
                do_center_crop = gr.Checkbox(value=True, label="Center-crop to square")
                crop_size = gr.Slider(64, 1024, value=384, step=16, label="Crop size (if center-crop)")
                # Binary classifier, so Top-K is capped at 2.
                top_k = gr.Slider(1, 2, value=2, step=1, label="Top-K classes to display")

    with gr.Row():
        img_orig = gr.Image(label="Original")
        img_proc = gr.Image(label="Preprocessed")
    proba_pretty = gr.Label(num_top_classes=2, label="Class probabilities (Top-K)")

    # Re-run inference whenever the input image changes (upload or webcam).
    image_in.change(
        fn=do_predict,
        inputs=[image_in, resize_shorter, do_center_crop, crop_size, top_k],
        outputs=[img_orig, img_proc, proba_pretty]
    )

    # Clickable example images fetched from Wikimedia; not cached because
    # the predictions depend on the live preprocessing parameters.
    gr.Examples(
        examples=[
            ["https://upload.wikimedia.org/wikipedia/commons/thumb/2/23/Stop_sign_light_red.svg/640px-Stop_sign_light_red.svg.png"],
            ["https://upload.wikimedia.org/wikipedia/commons/thumb/1/12/No_parking_sign.png/480px-No_parking_sign.png"],
            ["https://upload.wikimedia.org/wikipedia/commons/thumb/4/4f/Stop_sign_in_the_United_States.jpg/640px-Stop_sign_in_the_United_States.jpg"],
        ],
        inputs=[image_in],
        label="Example images",
        examples_per_page=6,
        cache_examples=False,
    )

# Standard script entry point: launch the Gradio server when run directly.
if __name__ == "__main__":
    demo.launch()
|