Spaces:
Sleeping
Sleeping
| import os, pathlib, zipfile, tempfile | |
| import pandas as pd | |
| import gradio as gr | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download, snapshot_download | |
| import autogluon.multimodal as ag | |
| MODEL_REPO_ID = "samder03/2025-24679-image-autogluon-predictor" | |
| CLASS_LABELS = { | |
| 0: "No Stop Sign", | |
| 1: "Stop Sign", | |
| "class_0": "No Stop Sign", | |
| "class_1": "Stop Sign", | |
| } | |
| def _human_label(c): | |
| try: | |
| ci = int(c) | |
| return CLASS_LABELS.get(ci, str(c)) | |
| except Exception: | |
| return CLASS_LABELS.get(str(c), str(c)) | |
| def _locate_predictor_dir_from_repo_folder(repo_dir: str) -> str: | |
| rd = pathlib.Path(repo_dir) | |
| for p in rd.rglob("predictor.pkl"): | |
| return str(p.parent) | |
| return "" | |
def _prepare_predictor_dir() -> str:
    """Download ``MODEL_REPO_ID`` and return the directory to load the predictor from.

    Resolution order:
      1. A folder in the downloaded snapshot containing ``predictor.pkl``.
      2. Otherwise the first ``*.zip`` in the snapshot (sorted by path) is
         extracted into a new temp directory, then resolved by looking for
         ``predictor.pkl`` again, else by unwrapping a single top-level
         folder, else the extraction root itself.

    Raises:
        FileNotFoundError: if the repo has neither a predictor folder nor a zip.
    """
    repo_dir = snapshot_download(repo_id=MODEL_REPO_ID, repo_type="model")
    pred_dir = _locate_predictor_dir_from_repo_folder(repo_dir)
    if pred_dir:
        return pred_dir

    # Fallback: try to find a zip and extract. Sorting makes the choice
    # deterministic when the repo holds more than one archive.
    zips = sorted(pathlib.Path(repo_dir).rglob("*.zip"))
    if not zips:
        raise FileNotFoundError("Could not find a predictor directory or .zip in the model repo.")

    # NOTE: this temp dir is intentionally not removed; the returned path must
    # still exist when the caller loads the predictor from it.
    workdir = tempfile.mkdtemp(prefix="ag_img_predictor_")
    with zipfile.ZipFile(zips[0], "r") as zf:
        zf.extractall(workdir)

    pred_dir = _locate_predictor_dir_from_repo_folder(workdir)
    if pred_dir:
        return pred_dir

    # Skip macOS archive metadata so a zipped single folder still unwraps to it.
    entries = [e for e in pathlib.Path(workdir).iterdir() if e.name != "__MACOSX"]
    if len(entries) == 1 and entries[0].is_dir():
        return str(entries[0])
    return workdir
# Resolve the model location once at import time, then load it; do_predict
# reuses this single module-level predictor for every request.
PREDICTOR_DIR = _prepare_predictor_dir()
PREDICTOR = ag.MultiModalPredictor.load(path=PREDICTOR_DIR)
def _ensure_rgb(img: Image.Image) -> Image.Image:
    """Return *img* in RGB mode, converting only when it is not RGB already."""
    if img.mode == "RGB":
        return img
    return img.convert("RGB")
def _resize_shorter(img: Image.Image, shorter: int) -> Image.Image:
    """Scale *img* so its shorter side equals *shorter*, keeping aspect ratio.

    Returns *img* unchanged when the shorter side already matches, or when a
    dimension is zero (nothing to scale). The longer side is rounded to the
    nearest pixel instead of truncated, so float error cannot drop a pixel
    (e.g. 223.9999 -> 224 rather than 223). Uses bicubic resampling.
    """
    shorter = int(shorter)  # UI sliders may pass floats; PIL expects integer sizes
    w, h = img.size
    if min(w, h) == shorter or w == 0 or h == 0:
        return img
    if w < h:
        new_w, new_h = shorter, round(h * shorter / w)
    else:
        new_w, new_h = round(w * shorter / h), shorter
    return img.resize((new_w, new_h), Image.BICUBIC)
def _center_crop(img: Image.Image, size: int) -> Image.Image:
    """Cut the central square of *img* and resize it to ``size x size``.

    The square's edge is the smallest of width, height and *size*, so an image
    smaller than *size* is cropped to its shorter side and then upscaled.
    """
    width, height = img.size
    edge = min(width, height, size)
    x0 = (width - edge) // 2
    y0 = (height - edge) // 2
    square = img.crop((x0, y0, x0 + edge, y0 + edge))
    return square.resize((size, size), Image.BICUBIC)
def _validate_image(pil_img: Image.Image, max_pixels: int = 8_000_000):
    """Check that an uploaded image can be processed.

    Returns:
        ``(True, "")`` when an image is present with at most *max_pixels*
        pixels; otherwise ``(False, reason)`` with a user-facing message.
    """
    if pil_img is None:
        return False, "No image provided."
    too_large = pil_img.width * pil_img.height > max_pixels
    if not too_large:
        return True, ""
    return False, f"Image too large (>{max_pixels:,} pixels). Please upload a smaller image."
def preprocess(pil_img: Image.Image, resize_shorter: int, do_center_crop: bool, crop_size: int) -> Image.Image:
    """Apply the inference-time transforms shown in the UI.

    Converts to RGB, scales the shorter side to *resize_shorter*, then, if
    *do_center_crop* is set, centre-crops to a ``crop_size`` square.
    """
    resized = _resize_shorter(_ensure_rgb(pil_img), resize_shorter)
    return _center_crop(resized, crop_size) if do_center_crop else resized
def do_predict(pil_img: Image.Image, resize_shorter: int, do_center_crop: bool, crop_size: int, top_k: int):
    """Classify one image and return the three values the UI displays.

    Args:
        pil_img: Uploaded image, or ``None`` when the input was cleared.
        resize_shorter: Target length of the shorter side before cropping.
        do_center_crop: Whether to centre-crop to a square after resizing.
        crop_size: Side length of the square crop.
        top_k: Number of highest-probability classes to return (at least 1).

    Returns:
        ``(original, preprocessed, probabilities)`` where ``probabilities``
        maps human-readable class names to floats, highest first. With no
        image all three are ``None`` (clears the outputs); when validation
        fails the images are ``None`` and the validation message is used as
        the label so the user sees why.
    """
    if pil_img is None:
        # Input was cleared: clear the outputs rather than showing an error.
        return None, None, None
    ok, msg = _validate_image(pil_img)
    if not ok:
        return None, None, {msg: 1.0}

    # Slider values may arrive as floats; the resize/crop helpers need ints.
    pre_img = preprocess(pil_img, int(resize_shorter), bool(do_center_crop), int(crop_size))

    # The predictor is given the image as a file path. A self-cleaning temp
    # dir keeps repeated requests from accumulating files on disk.
    with tempfile.TemporaryDirectory(prefix="ag_img_run_") as tmpdir:
        pre_path = pathlib.Path(tmpdir) / "preprocessed.png"
        pre_img.save(pre_path)
        df = pd.DataFrame({"image": [str(pre_path)]})
        proba_df = PREDICTOR.predict_proba(df)

    proba_df = proba_df.rename(columns={c: _human_label(c) for c in proba_df.columns})
    row = proba_df.iloc[0]
    k = max(1, int(top_k))
    ranked = sorted(row.items(), key=lambda kv: float(kv[1]), reverse=True)[:k]
    # Return the in-memory images directly instead of re-opening files, which
    # would leave lazily-read file handles open.
    return pil_img, pre_img, {label: float(p) for label, p in ranked}
with gr.Blocks(title="Stop Sign Classifier") as demo:
    gr.Markdown("# Stop Sign? — Image Classifier (Classmate Model)")
    gr.Markdown(
        "Upload a PNG/JPG (or use webcam). You’ll see the **original** and the **preprocessed** image, "
        "plus ranked class probabilities."
    )
    with gr.Row():
        image_in = gr.Image(type="pil", label="Input image (PNG/JPG)", sources=["upload", "webcam"])
        with gr.Column():
            with gr.Accordion("Inference Parameters", open=False):
                resize_shorter = gr.Slider(64, 1024, value=384, step=16, label="Resize (shorter side)")
                do_center_crop = gr.Checkbox(value=True, label="Center-crop to square")
                crop_size = gr.Slider(64, 1024, value=384, step=16, label="Crop size (if center-crop)")
                top_k = gr.Slider(1, 2, value=2, step=1, label="Top-K classes to display")
    with gr.Row():
        img_orig = gr.Image(label="Original")
        img_proc = gr.Image(label="Preprocessed")
    proba_pretty = gr.Label(num_top_classes=2, label="Class probabilities (Top-K)")

    predict_inputs = [image_in, resize_shorter, do_center_crop, crop_size, top_k]
    predict_outputs = [img_orig, img_proc, proba_pretty]
    image_in.change(fn=do_predict, inputs=predict_inputs, outputs=predict_outputs)
    # Re-run when an inference parameter changes; otherwise the controls have
    # no effect until a new image is uploaded. Sliders use `release` so a drag
    # triggers one prediction instead of one per intermediate value.
    for slider in (resize_shorter, crop_size, top_k):
        slider.release(fn=do_predict, inputs=predict_inputs, outputs=predict_outputs)
    do_center_crop.change(fn=do_predict, inputs=predict_inputs, outputs=predict_outputs)

    gr.Examples(
        examples=[["examples/example_0.png"], ["examples/example_11.png"], ["examples/example_23.png"]],
        inputs=[image_in],
        label="Example images",
        examples_per_page=6,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()