emkessle commited on
Commit
9576f0a
·
verified ·
1 Parent(s): fd6e2a4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os # For reading environment variables
2
+ import shutil # For directory cleanup
3
+ import zipfile # For extracting model archives
4
+ import pathlib # For path manipulations
5
+ import tempfile # For creating temporary files/directories
6
+
7
+ import gradio # For interactive UI
8
+ import pandas # For tabular data handling
9
+ import PIL.Image # For image I/O
10
+
11
+ import huggingface_hub # For downloading model assets
12
+ import autogluon.multimodal # For loading AutoGluon image classifier
13
+
14
+ # Hardcoded Hub model (native zip)
15
+ MODEL_REPO_ID = "apsora/autoML_images_data"
16
+ ZIP_FILENAME = "autogluon_image_predictor_dir.zip"
17
+
18
+ # Local cache/extract dirs
19
+ CACHE_DIR = pathlib.Path("hf_assets")
20
+ EXTRACT_DIR = CACHE_DIR / "predictor_native"
21
+
22
+ # Download & load the native predictor
23
+ def _prepare_predictor_dir() -> str:
24
+ CACHE_DIR.mkdir(parents=True, exist_ok=True)
25
+ local_zip = huggingface_hub.hf_hub_download(
26
+ repo_id=MODEL_REPO_ID,
27
+ filename=ZIP_FILENAME,
28
+ repo_type="model",
29
+ local_dir=str(CACHE_DIR),
30
+ local_dir_use_symlinks=False,
31
+ )
32
+ if EXTRACT_DIR.exists():
33
+ shutil.rmtree(EXTRACT_DIR)
34
+ EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
35
+ with zipfile.ZipFile(local_zip, "r") as zf:
36
+ zf.extractall(str(EXTRACT_DIR))
37
+ contents = list(EXTRACT_DIR.iterdir())
38
+ predictor_root = contents[0] if (len(contents) == 1 and contents[0].is_dir()) else EXTRACT_DIR
39
+ return str(predictor_root)
40
+
41
+ PREDICTOR_DIR = _prepare_predictor_dir()
42
+ PREDICTOR = autogluon.multimodal.MultiModalPredictor.load(PREDICTOR_DIR)
43
+
44
+ # Explicit class labels (edit copy as desired)
45
+ CLASS_LABELS = {0: "no_tomato", 1: "tomato"}
46
+
47
+ # Helper to map model class -> human label
48
+ def _human_label(c):
49
+ try:
50
+ ci = int(c)
51
+ return CLASS_LABELS.get(ci, str(c))
52
+ except Exception:
53
+ return CLASS_LABELS.get(c, str(c))
54
+
55
+ # Do the prediction!
56
+ def do_predict(pil_img: PIL.Image.Image):
57
+ # Make sure there's actually an image to work with
58
+ if pil_img is None:
59
+ return "No image provided.", {}, pandas.DataFrame(columns=["Predicted label", "Confidence (%)"])
60
+
61
+ # IF we have something to work with, save it and prepare the input
62
+ tmpdir = pathlib.Path(tempfile.mkdtemp())
63
+ img_path = tmpdir / "input.png"
64
+ pil_img.save(img_path)
65
+
66
+ df = pandas.DataFrame({"image": [str(img_path)]}) # For AutoGluon expected input format
67
+
68
+ # For class probabilities
69
+ proba_df = PREDICTOR.predict_proba(df)
70
+
71
+ # For user-friendly column names
72
+ proba_df = proba_df.rename(columns={0: "no_tomato", 1: "tomato"})
73
+ row = proba_df.iloc[0]
74
+
75
+ # For pretty ranked dict expected by gr.Label
76
+ pretty_dict = {
77
+ "no_tomato": float(row.get("♻️ Recycling (0)", 0.0)),
78
+ "tomato": float(row.get("🗑️ Trash (1)", 0.0)),
79
+ }
80
+
81
+ return pretty_dict
82
+
83
+ # Representative example images! These can be local or links.
84
+ EXAMPLES = [
85
+ ["https://datasets-server.huggingface.co/assets/Iris314/Food_tomatoes_dataset/--/bcca6b7b9e4be8378efc59e98a95f1c44a5055ae/--/default/original/1/image/image.jpg?Expires=1758728666&Signature=lGRInfYT2bhH2zr~T-oKF6vTjts~uJz7rQFTF4eWoP2Yz-XzKQz2hpB-83pGTQm7k-TLKbgTwbKUHxoliG0V140W0YPMhlW-HQ6LBLAszEeSNwzwzaTNzQIso4bKfCy9KUBBMTFqRfbPA8QNw3Z~M4BiHS~7otx1SavJa-L53FbREbA4VZKanGNwyd0KSt6e~FfRZDgsKwbtEVoFuFQjID0q3GwBMLUXMQCk3gXCXftCiw5x1Et1EldmNEJsK38MA7ohKLcbr130NKyH0lo960QBcuVIkvhi7obweDKucGmRMUidmIPALl9I5ak8rAAy~jRXEr-nPR-c28foVUsMcw__&Key-Pair-Id=K3EI6M078Z3AC3"],
86
+ ["https://datasets-server.huggingface.co/assets/Iris314/Food_tomatoes_dataset/--/bcca6b7b9e4be8378efc59e98a95f1c44a5055ae/--/default/original/10/image/image.jpg?Expires=1758728666&Signature=bUO6jZnrnaGh2mflzrFtRwO~SXRr5G-Sy3Tdo1~W1oPS8yYBPJAMKGrOdii7~Y71g8XihQVtFecpyawyt0RAh2UKrdczFockU3vgBsA6YujYCqZkSYC0KFSM--wO5pGtGn1AHJcBw48MZpdfSeUbxJoZy6z5YecOYrtD6Rz7eUhaghrT1S-lIwqW~gOqyRy8Ue-3OC5DtkBtuKq5LCkXXtssfDBGB8qFOjG1Vba68Gi9XfLXyBashf7tFTgNXjl3iXFjDtEkuo17siZvh1dK~ImfNbKVcBVv1vfqUC2lQWl7Jc4jaK4h2lYMFuevCfaurwELOUZVt2I8KIb0FWdjMQ__&Key-Pair-Id=K3EI6M078Z3AC3"],
87
+ ["https://datasets-server.huggingface.co/assets/Iris314/Food_tomatoes_dataset/--/bcca6b7b9e4be8378efc59e98a95f1c44a5055ae/--/default/original/43/image/image.jpg?Expires=1758728666&Signature=gMODl-rcaEShOGiZGyASzkT6idSpnsw0J5kJC5wKVthRIYP~RuBogRgeJ3XXWzWsflcaq75Guo1vQ20M5TzIdtPoVbMBcN65pTZJhfKZN9VLkA04ujNMCqOabKUWpVvR~UqqUzz8NbCpBE0Mwiu36vb8bFjVVpzaxE6BW-m0iEAHasUgttnXn6jjB1-OlK6z0SnCsLToNYcl4X7OjaJ0q0NSU1I6LusjYlPqmOTIvRl8ba2YXzE22X6DGlhEr3WNxYUzjT8~6nbkKUO2kWHi~2~i8aK1QNh4prNMjgUzoHSZcuBm2gDCdalpuXQWIqeLiaHAmvuybQclG4zRLpQhHw__&Key-Pair-Id=K3EI6M078Z3AC3"]
88
+ ]
89
+
90
+ # Gradio UI
91
+ with gradio.Blocks() as demo:
92
+
93
+ # Provide an introduction
94
+ gradio.Markdown("# Tomato or Not?")
95
+ gradio.Markdown("""
96
+ This is a simple app that uses the model at apsora/autoML_images_data to classify whether an image has a tomato or not,
97
+ utilizing data found at Iris314/Food_tomatoes_dataset. To use the interface, upload an image in the area shown below.
98
+ """)
99
+
100
+ # Interface for the incoming image
101
+ image_in = gradio.Image(type="pil", label="Input image", sources=["upload", "webcam"])
102
+
103
+ # Interface elements to show htte result and probabilities
104
+ proba_pretty = gradio.Label(num_top_classes=2, label="Class probabilities")
105
+
106
+ # Whenever a new image is uploaded, update the result
107
+ image_in.change(fn=do_predict, inputs=[image_in], outputs=[proba_pretty])
108
+
109
+ # For clickable example images
110
+ gradio.Examples(
111
+ examples=EXAMPLES,
112
+ inputs=[image_in],
113
+ label="Representative examples",
114
+ examples_per_page=8,
115
+ cache_examples=False,
116
+ )
117
+
118
+ if __name__ == "__main__":
119
+ demo.launch()