File size: 5,352 Bytes
468b15b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b357f64
 
 
 
4a909ca
b357f64
 
468b15b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os, pathlib, zipfile, tempfile
import pandas as pd
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download, snapshot_download
import autogluon.multimodal as ag

# Hugging Face Hub repository that holds the trained AutoGluon predictor.
MODEL_REPO_ID = "samder03/2025-24679-image-autogluon-predictor"

# Display names for the class identifiers that may appear as predict_proba
# columns. Both the integer labels and their "class_<n>" string spellings are
# listed so either form resolves to the same human-readable name.
CLASS_LABELS = dict(
    [
        (0, "No Stop Sign"),
        (1, "Stop Sign"),
        ("class_0", "No Stop Sign"),
        ("class_1", "Stop Sign"),
    ]
)

def _human_label(c):
    """Return the display name for a predictor class identifier.

    Identifiers that ``int()`` accepts (e.g. ``1`` or ``"1"``) are looked up as
    integers in ``CLASS_LABELS``. Anything else (e.g. ``"class_1"``) is looked
    up by its string form. Unknown identifiers fall back to ``str(c)``.

    Note: ``int()`` truncates floats, so ``0.9`` is looked up as ``0``.
    """
    try:
        ci = int(c)
    except (TypeError, ValueError, OverflowError):
        # Not integer-convertible (e.g. "class_0" or None): use the string key.
        return CLASS_LABELS.get(str(c), str(c))
    return CLASS_LABELS.get(ci, str(c))

def _locate_predictor_dir_from_repo_folder(repo_dir: str) -> str:
    """Return the directory under *repo_dir* that contains ``predictor.pkl``.

    The search is recursive. When several ``predictor.pkl`` files exist, the
    shallowest one wins, with ties broken by path. This makes the choice
    independent of filesystem enumeration order.

    Returns:
        The containing directory as a string, or ``""`` if none is found.
    """
    candidates = sorted(
        pathlib.Path(repo_dir).rglob("predictor.pkl"),
        key=lambda p: (len(p.parts), str(p)),
    )
    return str(candidates[0].parent) if candidates else ""

def _prepare_predictor_dir() -> str:
    """Download the model repo and return a local directory holding the predictor.

    Resolution order:
      1. A folder inside the Hub snapshot that already contains ``predictor.pkl``.
      2. Otherwise, the first ``*.zip`` in the snapshot (sorted by path) is
         extracted into a fresh temporary directory, which is then searched
         for ``predictor.pkl``.
      3. If the archive contains no ``predictor.pkl``, fall back to its single
         top-level folder when there is exactly one, else the extraction root.

    Raises:
        FileNotFoundError: if the snapshot has neither a predictor directory
            nor a ``.zip`` archive.
    """
    repo_dir = snapshot_download(repo_id=MODEL_REPO_ID, repo_type="model")
    pred_dir = _locate_predictor_dir_from_repo_folder(repo_dir)
    if pred_dir:
        return pred_dir

    # Fallback: the predictor was uploaded as a zip archive. Sort the matches
    # so the chosen archive does not depend on filesystem enumeration order.
    zips = sorted(pathlib.Path(repo_dir).rglob("*.zip"))
    if not zips:
        raise FileNotFoundError("Could not find a predictor directory or .zip in the model repo.")

    workdir = tempfile.mkdtemp(prefix="ag_img_predictor_")
    with zipfile.ZipFile(zips[0], "r") as zf:
        zf.extractall(workdir)

    # Search for predictor.pkl rather than trusting the archive layout. The
    # archive may include sibling folders (e.g. macOS "__MACOSX" metadata) or
    # nest the predictor more than one level deep.
    pred_dir = _locate_predictor_dir_from_repo_folder(workdir)
    if pred_dir:
        return pred_dir

    entries = list(pathlib.Path(workdir).iterdir())
    if len(entries) == 1 and entries[0].is_dir():
        return str(entries[0])
    return workdir

# Resolve the predictor directory and load the predictor once, at import time.
# Resolving may download the Hub snapshot and, if needed, extract a zip.
# do_predict reuses this single PREDICTOR for every request.
PREDICTOR_DIR = _prepare_predictor_dir()
PREDICTOR = ag.MultiModalPredictor.load(PREDICTOR_DIR)

def _ensure_rgb(img: Image.Image) -> Image.Image:
    """Return *img* in RGB mode, converting only when it is in another mode."""
    if img.mode == "RGB":
        # Already RGB: hand back the same object, no copy.
        return img
    return img.convert("RGB")

def _resize_shorter(img: Image.Image, shorter: int) -> Image.Image:
    """Scale *img* so its shorter side equals *shorter*, preserving aspect ratio.

    Returns *img* unchanged when its shorter side already equals *shorter*.
    The longer side is rounded to the nearest pixel, and resampling is bicubic.
    """
    w, h = img.size
    if min(w, h) == shorter:
        return img
    # Compute the integer product first and divide once. The original form,
    # ``h * (shorter / w)``, compounds float error and then truncates. That can
    # land one pixel short, e.g. turning a square input into a non-square one.
    if w < h:
        new_w, new_h = shorter, round(h * shorter / w)
    else:
        new_w, new_h = round(w * shorter / h), shorter
    return img.resize((new_w, new_h), Image.BICUBIC)

def _center_crop(img: Image.Image, size: int) -> Image.Image:
    """Cut a centered square and scale it to ``size`` x ``size`` (bicubic).

    The square's side is ``min(width, height, size)``. When the image is
    smaller than *size*, the final resize upsamples the crop.
    """
    width, height = img.size
    edge = min(width, height, size)
    x0 = (width - edge) // 2
    y0 = (height - edge) // 2
    box = (x0, y0, x0 + edge, y0 + edge)
    square = img.crop(box)
    return square.resize((size, size), Image.BICUBIC)

def _validate_image(pil_img: Image.Image, max_pixels: int = 8_000_000):
    """Check that an image was supplied and fits within the pixel budget.

    Returns:
        ``(True, "")`` when the image is acceptable, otherwise
        ``(False, reason)`` with a user-facing message.
    """
    if pil_img is None:
        return False, "No image provided."
    pixel_count = pil_img.width * pil_img.height
    if pixel_count <= max_pixels:
        return True, ""
    return False, f"Image too large (>{max_pixels:,} pixels). Please upload a smaller image."

def preprocess(pil_img: Image.Image, resize_shorter: int, do_center_crop: bool, crop_size: int) -> Image.Image:
    """Apply the inference-time transform.

    Steps: convert to RGB, resize the shorter side to *resize_shorter*, then
    center-crop to a *crop_size* square when *do_center_crop* is true.
    """
    resized = _resize_shorter(_ensure_rgb(pil_img), resize_shorter)
    return _center_crop(resized, crop_size) if do_center_crop else resized

def do_predict(pil_img: Image.Image, resize_shorter: int, do_center_crop: bool, crop_size: int, top_k: int):
    """Classify one image and return what the UI displays.

    Args:
        pil_img: image from the Gradio input; ``None`` when the input is cleared.
        resize_shorter: target length of the shorter side before cropping.
        do_center_crop: whether to center-crop to a ``crop_size`` square.
        crop_size: side length of the square crop.
        top_k: number of highest-probability classes to return (at least 1).

    Returns:
        A tuple ``(original, preprocessed, probabilities)``. ``probabilities``
        maps display label to probability, sorted descending and truncated to
        ``top_k``. All three are ``None`` when no image is provided.

    Raises:
        gr.Error: if the image fails validation (e.g. too many pixels). The
            reason is shown to the user instead of being discarded.
    """
    if pil_img is None:
        # The input was cleared: blank the outputs instead of showing an error.
        return None, None, None
    ok, msg = _validate_image(pil_img)
    if not ok:
        raise gr.Error(msg)

    pre_img = preprocess(pil_img, resize_shorter, do_center_crop, crop_size)

    # The predictor is fed an image file path. A self-cleaning temp dir stops
    # files from piling up on disk across repeated predictions. predict_proba
    # runs inside the `with`, so the file still exists while it is read.
    with tempfile.TemporaryDirectory(prefix="ag_img_run_") as tmpdir:
        pre_path = pathlib.Path(tmpdir) / "preprocessed.png"
        pre_img.save(pre_path)
        df = pd.DataFrame({"image": [str(pre_path)]})
        proba_df = PREDICTOR.predict_proba(df)

    row = proba_df.rename(columns={c: _human_label(c) for c in proba_df.columns}).iloc[0]
    k = max(1, int(top_k))
    ranked = sorted(row.items(), key=lambda kv: float(kv[1]), reverse=True)[:k]
    pretty = {label: float(prob) for label, prob in ranked}

    # Return the in-memory images directly. Re-opening them from disk would
    # leave file handles open and would require the temp files to persist.
    return pil_img, pre_img, pretty

with gr.Blocks(title="Stop Sign Classifier") as demo:
    # Page header and usage hint.
    gr.Markdown("# Stop Sign? — Image Classifier (Classmate Model)")
    gr.Markdown(
        "Upload a PNG/JPG (or use webcam). You’ll see the **original** and the **preprocessed** image, "
        "plus ranked class probabilities."
    )

    # Input image alongside the collapsible preprocessing / display controls.
    with gr.Row():
        image_in = gr.Image(type="pil", label="Input image (PNG/JPG)", sources=["upload", "webcam"])
        with gr.Column():
            with gr.Accordion("Inference Parameters", open=False):
                resize_shorter = gr.Slider(64, 1024, value=384, step=16, label="Resize (shorter side)")
                do_center_crop = gr.Checkbox(value=True, label="Center-crop to square")
                crop_size = gr.Slider(64, 1024, value=384, step=16, label="Crop size (if center-crop)")
                top_k = gr.Slider(1, 2, value=2, step=1, label="Top-K classes to display")

    # Outputs: original and preprocessed images, then ranked probabilities.
    with gr.Row():
        img_orig = gr.Image(label="Original")
        img_proc = gr.Image(label="Preprocessed")
    # num_top_classes matches the top_k slider's maximum (2).
    proba_pretty = gr.Label(num_top_classes=2, label="Class probabilities (Top-K)")

    # Re-run prediction when the image changes and also when any inference
    # parameter changes, so edits take effect without re-uploading the image.
    # Sliders use `.release` so dragging does not fire a model call for every
    # intermediate value.
    gr.on(
        triggers=[
            image_in.change,
            do_center_crop.change,
            resize_shorter.release,
            crop_size.release,
            top_k.release,
        ],
        fn=do_predict,
        inputs=[image_in, resize_shorter, do_center_crop, crop_size, top_k],
        outputs=[img_orig, img_proc, proba_pretty],
    )

    gr.Examples(
        examples=[["examples/example_0.png"], ["examples/example_11.png"], ["examples/example_23.png"]],
        inputs=[image_in],
        label="Example images",
        examples_per_page=6,
        cache_examples=False,
    )


if __name__ == "__main__":
    # Start the Gradio app only when this file is executed as a script.
    demo.launch()