File size: 3,535 Bytes
9f13279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import pathlib, shutil, zipfile, tempfile
import pandas
import gradio
import huggingface_hub
import autogluon.multimodal
import PIL.Image

# Hugging Face Hub model repo that stores the zipped AutoGluon predictor,
# and the name of that archive inside the repo.
MODEL_REPO_ID = "george2cool36/hw2_image_automl_autogluon"
ZIP_FILENAME = "ag_image_predictor_dir.zip"

# Local folder the archive is downloaded into, and the sub-folder it is
# unpacked into.
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = CACHE_DIR.joinpath("predictor_native")

# Display text for each integer class id (looked up by _human_label).
CLASS_LABELS = {
    0: "πŸ›‘ Has Stop Sign",
    1: "βœ… No Stop Sign",
}

def _prepare_predictor_dir() -> str:
    """Download the zipped predictor from the Hub and unpack it locally.

    ``ZIP_FILENAME`` is fetched from the model repo ``MODEL_REPO_ID`` into
    ``CACHE_DIR`` and extracted into a freshly recreated ``EXTRACT_DIR``
    (any previous extraction is deleted first).

    Returns:
        Path, as ``str``, of the directory to hand to
        ``MultiModalPredictor.load``. This is the archive's single top-level
        folder when it has exactly one (ignoring ``__MACOSX`` and hidden
        entries), otherwise ``EXTRACT_DIR`` itself.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    # `local_dir_use_symlinks` is deprecated (and ignored) by recent
    # huggingface_hub releases, so it is no longer passed.
    local_zip = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
    )
    # Recreate the target so files from an earlier archive cannot mix with
    # the freshly extracted predictor.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(local_zip, "r") as zf:
        zf.extractall(str(EXTRACT_DIR))
    # Archives created by macOS Finder add a `__MACOSX` metadata folder, and
    # other tools may add hidden files; skip them so a zipped predictor
    # folder is still recognised as the single root.
    contents = [
        entry
        for entry in EXTRACT_DIR.iterdir()
        if entry.name != "__MACOSX" and not entry.name.startswith(".")
    ]
    if len(contents) == 1 and contents[0].is_dir():
        return str(contents[0])
    return str(EXTRACT_DIR)

# Executed at import time: the predictor is downloaded and loaded once,
# before the UI below is built.
PREDICTOR_DIR = _prepare_predictor_dir()
PREDICTOR = autogluon.multimodal.MultiModalPredictor.load(PREDICTOR_DIR)

def _human_label(c):
    try:
        ci = int(c)
        return CLASS_LABELS.get(ci, str(c))
    except Exception:
        return CLASS_LABELS.get(c, str(c))

def do_predict(pil_img: PIL.Image.Image):
    """Classify one image and format the result for the Gradio outputs.

    The image is written to a temporary PNG because the predictor is given a
    DataFrame whose ``"image"`` column holds a file path. The temporary
    directory is deleted once both predictor calls have finished.

    Args:
        pil_img: Image from the input component, or ``None`` when there is
            no image.

    Returns:
        A ``(probabilities, markdown)`` pair. ``probabilities`` maps each
        display label to its probability (for ``gradio.Label``).
        ``markdown`` states the predicted label and its confidence in percent.
    """
    if pil_img is None:
        return {}, "No image provided."

    # Unlike a bare mkdtemp(), TemporaryDirectory is cleaned up afterwards,
    # so repeated predictions no longer accumulate files on disk.
    with tempfile.TemporaryDirectory() as tmp:
        img_path = pathlib.Path(tmp) / "input.png"
        pil_img.save(img_path)
        df = pandas.DataFrame({"image": [str(img_path)]})
        # Both calls receive the file path, so they run before cleanup.
        proba_df = PREDICTOR.predict_proba(df)
        predicted_class = PREDICTOR.predict(df).iloc[0]

    # Derive the display keys from the class columns via _human_label, so
    # they always match the predicted label looked up below (and string-typed
    # class columns such as "0"/"1" are handled too).
    probabilities = {
        _human_label(cls): float(p) for cls, p in proba_df.iloc[0].items()
    }
    pred_label = _human_label(predicted_class)
    confidence = probabilities.get(pred_label, 0.0)

    md = f"**Prediction:** {pred_label}"
    md += f"  \n**Confidence:** {round(confidence * 100, 2)}%"
    return probabilities, md

# Sample images (remote URLs) offered through gradio.Examples below.
EXAMPLES = [
    ["https://www.kingsrivercasting.com/images/stories/virtuemart/product/STOP%20SIGN%20(5).jpg"],
    ["https://www.trafficsafetywarehouse.com/Resources/images/traffic-sign-shapes.jpeg"],
    ["https://di-uploads-pod16.dealerinspire.com/toyotaofnorthcharlotte/uploads/2020/08/yield-road-sign.jpg"],
]


with gradio.Blocks() as demo:
    gradio.Markdown("# Has Stop Sign or Not?")
    # Adjacent string literals are joined with no separator, so each piece
    # ends with a space; without it the text read "multimodalpredictor".
    gradio.Markdown(
        "This is a simple app that demonstrates how to use an autogluon multimodal "
        "predictor in a gradio space to predict whether an image contains a stop sign. To use, "
        "just upload a photo. The result should be generated automatically."
    )

    image_in = gradio.Image(type="pil", label="Input image", sources=["upload", "webcam"])

    proba_pretty = gradio.Label(num_top_classes=2, label="Class probabilities")
    prediction_output = gradio.Markdown()

    inputs = [image_in]
    outputs = [proba_pretty, prediction_output]
    # Predict whenever an input's value changes; do_predict handles a None
    # image (no input) explicitly.
    for comp in inputs:
        comp.change(fn=do_predict, inputs=inputs, outputs=outputs)

    gradio.Examples(
        examples=EXAMPLES,
        inputs=inputs,
        label="Representative examples",
        examples_per_page=8,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch(debug=False)