File size: 4,028 Bytes
9576f0a
 
 
 
 
 
 
 
 
 
 
 
 
 
01f458f
9576f0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8643d3e
 
9576f0a
 
 
 
 
 
c77807c
 
 
9576f0a
 
 
 
 
 
 
 
bcad485
9576f0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os  # For reading environment variables
import shutil  # For directory cleanup
import zipfile  # For extracting model archives
import pathlib  # For path manipulations
import tempfile  # For creating temporary files/directories

import gradio  # For interactive UI
import pandas  # For tabular data handling
import PIL.Image # For image I/O

import huggingface_hub # For downloading model assets
import autogluon.multimodal  # For loading AutoGluon image classifier

# Hugging Face Hub location of the zipped native predictor (hardcoded)
MODEL_REPO_ID = "yl0628/autogluon-image-predictor"
ZIP_FILENAME = "autogluon_image_predictor_dir.zip"

# Local directories: where the archive is downloaded, and where it is unpacked
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = CACHE_DIR.joinpath("predictor_native")

# Download & extract the native predictor archive
def _prepare_predictor_dir() -> str:
    """Download the zipped predictor from the Hub and unpack it locally.

    The archive ``ZIP_FILENAME`` is fetched from ``MODEL_REPO_ID`` into
    ``CACHE_DIR`` and extracted into a freshly recreated ``EXTRACT_DIR``.

    Returns:
        Path (as a string) of the directory to load the predictor from: the
        archive's single top-level folder when it has exactly one, otherwise
        ``EXTRACT_DIR`` itself.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    # ``local_dir_use_symlinks`` is intentionally not passed: it is deprecated
    # and ignored by current huggingface_hub releases (it only emits a warning).
    archive_path = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
    )

    # Always start from an empty directory so files left over from a previous
    # extraction cannot mix with the new archive contents.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)

    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(str(EXTRACT_DIR))

    # Ignore archiver artifacts (macOS "__MACOSX", dotfiles such as
    # ".DS_Store") when deciding whether the archive wraps a single folder.
    entries = [
        entry
        for entry in EXTRACT_DIR.iterdir()
        if entry.name != "__MACOSX" and not entry.name.startswith(".")
    ]
    if len(entries) == 1 and entries[0].is_dir():
        return str(entries[0])
    return str(EXTRACT_DIR)

# Runs at import time: fetch and unpack the archive, then load the predictor
# from the extracted directory so it is available before the UI is built.
PREDICTOR_DIR = _prepare_predictor_dir()
PREDICTOR = autogluon.multimodal.MultiModalPredictor.load(PREDICTOR_DIR)

# Explicit class labels (edit copy as desired).
# NOTE(review): assumes the model's classes 0 / 1 mean "no tomato" / "tomato";
# confirm against the label encoding used when the model was trained.
CLASS_LABELS = {0: "no_tomato", 1: "tomato"}

# Helper to map model class -> human label
def _human_label(c): 
    try:
        ci = int(c)
        return CLASS_LABELS.get(ci, str(c))
    except Exception:
        return CLASS_LABELS.get(c, str(c))

# Do the prediction!
def do_predict(pil_img: PIL.Image.Image):
    """Classify an image and return class probabilities for ``gradio.Label``.

    Args:
        pil_img: Image supplied by the Gradio image input, or ``None`` when
            no image is present (for example after the input is cleared).

    Returns:
        A dict mapping every label in ``CLASS_LABELS`` to its predicted
        probability as a float (``0.0`` for a label the model did not report).
        An empty dict is returned when ``pil_img`` is ``None``.
    """
    # The callback is wired to exactly one output component, so every path
    # must return a single value; an empty dict simply clears the label.
    if pil_img is None:
        return {}

    # The predictor is given a file path, so write the image to a scratch
    # directory that is removed again as soon as the prediction is done.
    with tempfile.TemporaryDirectory() as tmpdir:
        img_path = pathlib.Path(tmpdir) / "input.png"
        pil_img.save(img_path)

        # Single-row frame with an "image" column holding the file path
        df = pandas.DataFrame({"image": [str(img_path)]})
        proba_df = PREDICTOR.predict_proba(df)

    # Translate model class columns into the user-facing label names
    proba_df = proba_df.rename(columns=_human_label)
    row = proba_df.iloc[0]

    # Label -> probability dict, the format gradio.Label ranks and displays
    return {label: float(row.get(label, 0.0)) for label in CLASS_LABELS.values()}

# Representative example images! These can be local or links.
# Each entry is a one-element row because the examples feed a single input.
EXAMPLES = [[image] for image in ("Tomato1.jpg", "Tomato2.jpg", "Carrots.jpg")]

# Gradio UI
with gradio.Blocks() as demo:

    # Provide an introduction
    # NOTE(review): the text below names kaitongg/best_tomato_model, while the
    # predictor is actually loaded from MODEL_REPO_ID -- confirm which is right.
    gradio.Markdown("# Tomato or Not?")
    gradio.Markdown("""
    This is a simple app that uses the model at kaitongg/best_tomato_model to classify whether an image has a tomato or not, 
    utilizing data found at Iris314/Food_tomatoes_dataset. To use the interface, upload an image in the area shown below.
    """)

    # Interface for the incoming image (upload or webcam), passed to the callback as a PIL image
    image_in = gradio.Image(type="pil", label="Input image", sources=["upload", "webcam"])

    # Interface element to show the class probabilities
    proba_pretty = gradio.Label(num_top_classes=2, label="Class probabilities")

    # Whenever the input image changes, rerun the prediction and update the label
    image_in.change(fn=do_predict, inputs=[image_in], outputs=[proba_pretty])

    # Clickable example images for the image input; example outputs are not cached
    gradio.Examples(
        examples=EXAMPLES,
        inputs=[image_in],
        label="Representative examples",
        examples_per_page=8,
        cache_examples=False,
    )

# Launch the app only when this file is executed as a script
if __name__ == "__main__":
    demo.launch()