File size: 3,586 Bytes
a821b84
cab3625
 
 
 
6901f02
cab3625
 
 
a821b84
cab3625
 
a821b84
968e3d4
a821b84
968e3d4
 
a821b84
 
968e3d4
a821b84
968e3d4
6901f02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a821b84
 
 
 
968e3d4
a821b84
 
 
968e3d4
 
 
6901f02
968e3d4
a821b84
968e3d4
 
 
 
 
 
 
 
 
 
 
 
a821b84
968e3d4
a821b84
6901f02
 
 
a821b84
 
968e3d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a821b84
6901f02
a821b84
cab3625
968e3d4
6901f02
 
cab3625
 
6901f02
 
968e3d4
6901f02
a821b84
 
6901f02
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#this was in part generated with gemini llm
import shutil
import zipfile
import pathlib
import tempfile

import gradio as gr
import pandas as pd
import PIL.Image

import huggingface_hub
import autogluon.multimodal

# --- Model and Path Configuration ---
MODEL_REPO_ID = "FaiyazAzam/24679-image-autolguon-predictor"  # HF Hub model repo holding the zipped predictor
ZIP_FILENAME = "autogluon_image_predictor_dir.zip"  # asset filename inside the repo
CACHE_DIR = pathlib.Path("hf_assets")  # local download cache directory
EXTRACT_DIR = CACHE_DIR / "predictor_native"  # where the zip is unpacked

# --- Model Loading ---
def _prepare_predictor_dir() -> str:
    """Download and unpack the zipped AutoGluon predictor from the HF Hub.

    Returns:
        str: Path to the predictor root directory, suitable for
        ``MultiModalPredictor.load``. If the archive wraps everything in a
        single top-level folder, that folder is returned instead of
        ``EXTRACT_DIR`` itself.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    # FIX: dropped `local_dir_use_symlinks=False` — the parameter is
    # deprecated (and ignored) in recent huggingface_hub releases; with
    # `local_dir` set, files are materialized into the directory by default.
    local_zip = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
    )
    # Re-extract from scratch so a stale or partial previous extraction
    # cannot corrupt the predictor directory.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
    # NOTE: extractall is acceptable here because the archive comes from our
    # own trusted model repo (no untrusted zip entries expected).
    with zipfile.ZipFile(local_zip, "r") as zf:
        zf.extractall(str(EXTRACT_DIR))
    # Some zips contain a single wrapper folder; descend into it if so.
    contents = list(EXTRACT_DIR.iterdir())
    if len(contents) == 1 and contents[0].is_dir():
        return str(contents[0])
    return str(EXTRACT_DIR)

# Download/extract once at import time and load the predictor eagerly, so the
# first user request does not pay the model-loading cost.
PREDICTOR_DIR = _prepare_predictor_dir()
PREDICTOR = autogluon.multimodal.MultiModalPredictor.load(PREDICTOR_DIR)

# --- Class Labels & Prediction Logic ---
# Maps the predictor's integer class indices to human-readable names.
CLASS_LABELS = {0: "Pen", 1: "Toy"}

def do_predict(pil_img: "PIL.Image.Image"):
    """Classify an uploaded image and return per-class probabilities.

    Args:
        pil_img: Image to classify, or ``None`` when the input is cleared.

    Returns:
        dict: Human-readable class name -> probability; empty dict when no
        image was supplied.
    """
    if pil_img is None:
        return {}

    # AutoGluon's predictor consumes file paths, so persist the image to a
    # temporary file that is cleaned up automatically on exit.
    with tempfile.TemporaryDirectory() as workdir:
        image_file = pathlib.Path(workdir) / "input.png"
        pil_img.save(image_file)
        frame = pd.DataFrame({"image": [str(image_file)]})
        probabilities = PREDICTOR.predict_proba(frame).iloc[0]
        labeled = {}
        for class_idx, prob in probabilities.items():
            if class_idx in CLASS_LABELS:
                labeled[CLASS_LABELS[class_idx]] = prob
        return labeled

# --- Gradio User Interface ---
# Hot-linked example images (two pens, one toy) shown beneath the input widget.
EXAMPLES = [
    ["https://www.penboutique.com/cdn/shop/articles/IMG_6759.jpg?v=1701974210&width=1920"],
    ["https://media.officedepot.com/images/f_auto,q_auto,e_sharpen,h_450/products/790761/790761_p_pilot_g_2_retractable_gel_pens/790761"],
    ["https://i5.walmartimages.com/seo/Disney-Pixar-Toy-Story-True-Talkers-Woody-Figure-with-15-Phrases_8c8c4a17-fb26-4f97-a284-1315c48c18ca.c35d5f2d8b932a490db9bb3f40977220.jpeg"]
]

# Build the Gradio UI: image input on the left, probability label on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Pen or Toy?")
    gr.Markdown(
        """
        This is a simple app that demonstrates how to use an AutoGluon Multimodal
        predictor in a Gradio Space to classify images. To use,
        just upload a photo or use one of the examples below. The result will be 
        generated automatically.
        """
    )
    with gr.Row():
        with gr.Column():
            image_in = gr.Image(type="pil", label="Input image", sources=["upload", "webcam"])
        with gr.Column():
            proba_pretty = gr.Label(num_top_classes=2, label="Class Probabilities")

    # Predict automatically whenever the image changes (upload/webcam/clear).
    image_in.change(fn=do_predict, inputs=[image_in], outputs=[proba_pretty])

    # FIX: Added `fn` and `outputs` to allow example caching to work correctly.
    gr.Examples(
        examples=EXAMPLES,
        inputs=[image_in],
        outputs=[proba_pretty],
        fn=do_predict,
        label="Representative examples",
        examples_per_page=8,
        cache_examples=True,
    )

# Start the Gradio web server when run as a script (e.g. inside the Space).
if __name__ == "__main__":
    demo.launch()