ccm committed on
Commit
1994b85
·
verified ·
1 Parent(s): 1f8f859

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -52
app.py CHANGED
@@ -4,21 +4,21 @@ import zipfile # For extracting model archives
4
  import pathlib # For path manipulations
5
  import tempfile # For creating temporary files/directories
6
 
7
- import gradio as gr # For interactive UI
8
- import pandas as pd # For tabular data handling
9
- from PIL import Image # For image I/O
10
 
11
  import huggingface_hub as hf # For downloading model assets
12
  from autogluon.multimodal import MultiModalPredictor # For loading AutoGluon image classifier
13
 
14
  # Hardcoded Hub model (native zip)
15
- MODEL_REPO_ID = "ccm/2025-24679-image-autogluon-predictor" # For pointing to your HF repo
16
- ZIP_FILENAME = "autogluon_image_predictor_dir.zip" # For specifying the zipped predictor dir
17
- HF_TOKEN = os.getenv("HF_TOKEN", None) # For private repos (optional)
18
 
19
  # Local cache/extract dirs
20
- CACHE_DIR = pathlib.Path("hf_assets") # For caching downloaded assets
21
- EXTRACT_DIR = CACHE_DIR / "predictor_native" # For extracted predictor directory
22
 
23
  # Download & load the native predictor
24
  def _prepare_predictor_dir() -> str: # For ensuring predictor directory is ready
@@ -40,32 +40,32 @@ def _prepare_predictor_dir() -> str: # For ensuring predictor directory is read
40
  predictor_root = contents[0] if (len(contents) == 1 and contents[0].is_dir()) else EXTRACT_DIR
41
  return str(predictor_root)
42
 
43
- PREDICTOR_DIR = _prepare_predictor_dir() # For path to predictor root
44
- PREDICTOR = MultiModalPredictor.load(PREDICTOR_DIR) # For loading the image classifier
45
 
46
  # Explicit class labels (edit copy as desired)
47
- CLASS_LABELS = {0: "♻️ Recycling", 1: "🗑️ Trash"} # For mapping numeric predictions to human labels
48
 
49
  # Helper to map model class -> human label
50
- def _human_label(c): # For robust label mapping
51
  try:
52
  ci = int(c)
53
  return CLASS_LABELS.get(ci, str(c))
54
  except Exception:
55
  return CLASS_LABELS.get(c, str(c))
56
 
57
- # Inference
58
- def do_predict(pil_img: Image.Image): # For running a single-image prediction
 
59
  if pil_img is None:
60
- return "No image provided.", {}, pd.DataFrame(columns=["Predicted label", "Confidence (%)"])
61
- tmpdir = pathlib.Path(tempfile.mkdtemp()) # For a temporary working directory
62
- img_path = tmpdir / "input.png" # For a temporary image path
63
- pil_img.save(img_path)
64
 
65
- df = pd.DataFrame({"image": [str(img_path)]}) # For AutoGluon expected input format
 
 
 
66
 
67
- y_pred = PREDICTOR.predict(df) # For predicted class ids
68
- pred_label = _human_label(y_pred.iloc[0]) # For human-readable predicted label
69
 
70
  try:
71
  proba_df = PREDICTOR.predict_proba(df) # For class probabilities
@@ -78,49 +78,31 @@ def do_predict(pil_img: Image.Image): # For running a single-image prediction
78
  "🗑️ Trash": float(row.get("🗑️ Trash (1)", 0.0)),
79
  }
80
  pretty_dict = dict(sorted(pretty_dict.items(), key=lambda kv: kv[1], reverse=True))
81
- confidence_pct = round(pretty_dict.get(pred_label.replace(" (0)", "").replace(" (1)", ""), 0.0) * 100, 2)
82
  except Exception:
83
- proba_df = None
84
  pretty_dict = {}
85
- confidence_pct = 100.0 # For default when probabilities are unavailable
86
 
87
- md = f"**Prediction:** {pred_label}" # For concise summary line
88
- if pretty_dict:
89
- md += f" \n**Confidence:** {confidence_pct}%"
90
-
91
- compact = pd.DataFrame([{"Predicted label": pred_label, "Confidence (%)": confidence_pct}]) # For compact table
92
- return md, pretty_dict, (proba_df if proba_df is not None else pd.DataFrame())
93
 
94
  # Representative example images (replace with your own)
95
  EXAMPLES = [
96
- ["https://c8.alamy.com/comp/2AEA4K9/a-garbage-and-recycling-can-on-the-campus-of-carnegie-mellon-university-pittsburgh-pennsylvania-usa-2AEA4K9.jpg"], # For campus bins example
97
- ["https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSvid9M7DynMcoUsX0KBMxooLvrKQJwREiw6g&s"], # For generic bin example
98
  ]
99
 
100
  # Gradio UI
101
- with gr.Blocks(
102
- title="AutoGluon Image Classification — Recycling vs Trash", # For browser tab title
103
- css=".gradio-container {max-width: 900px !important;}" # For a slightly wider layout
104
- ) as demo:
105
- gr.Markdown( # For top-of-app instructions
106
- "## AutoGluon Image Classification — Recycling vs Trash\n"
107
- f"**Model:** `{MODEL_REPO_ID}/{ZIP_FILENAME}` \n"
108
- "Drop an image and the prediction updates automatically.\n\n"
109
- "- **Class 0 → ♻️ Recycling**\n"
110
- "- **Class 1 → 🗑️ Trash**"
111
- )
112
 
113
- image_in = gr.Image(type="pil", label="Input image", sources=["upload", "clipboard", "webcam"]) # For image input
 
114
 
115
- pred_md = gr.Markdown() # For concise summary
116
- proba_pretty = gr.Label(num_top_classes=2, label="Class probabilities") # For pretty ranked probs
117
- proba_full = gr.Dataframe(label="Class probabilities (table)", interactive=False) # For full table view
118
- compact = gr.Dataframe(label="Prediction (compact)", interactive=False) # For 1-row summary table
119
 
120
- image_in.change(fn=do_predict, inputs=[image_in], outputs=[pred_md, proba_pretty, proba_full]) # For live inference
121
- image_in.change(fn=do_predict, inputs=[image_in], outputs=[pred_md, proba_pretty, compact]) # For compact table sync
122
 
123
- gr.Examples( # For clickable example images
 
124
  examples=EXAMPLES,
125
  inputs=[image_in],
126
  label="Representative examples",
@@ -128,5 +110,5 @@ with gr.Blocks(
128
  cache_examples=False,
129
  )
130
 
131
- if __name__ == "__main__": # For launching the app locally
132
  demo.launch(server_name="0.0.0.0", show_api=False)
 
4
  import pathlib # For path manipulations
5
  import tempfile # For creating temporary files/directories
6
 
7
+ import gradio # For interactive UI
8
+ import pandas # For tabular data handling
9
+ import PIL # For image I/O
10
 
11
  import huggingface_hub as hf # For downloading model assets
12
  from autogluon.multimodal import MultiModalPredictor # For loading AutoGluon image classifier
13
 
14
  # Hardcoded Hub model (native zip)
15
+ MODEL_REPO_ID = "ccm/2025-24679-image-autogluon-predictor"
16
+ ZIP_FILENAME = "autogluon_image_predictor_dir.zip"
17
+ HF_TOKEN = os.getenv("HF_TOKEN", None)
18
 
19
  # Local cache/extract dirs
20
+ CACHE_DIR = pathlib.Path("hf_assets")
21
+ EXTRACT_DIR = CACHE_DIR / "predictor_native"
22
 
23
  # Download & load the native predictor
24
  def _prepare_predictor_dir() -> str: # For ensuring predictor directory is ready
 
40
  predictor_root = contents[0] if (len(contents) == 1 and contents[0].is_dir()) else EXTRACT_DIR
41
  return str(predictor_root)
42
 
43
+ PREDICTOR_DIR = _prepare_predictor_dir()
44
+ PREDICTOR = MultiModalPredictor.load(PREDICTOR_DIR)
45
 
46
  # Explicit class labels (edit copy as desired)
47
+ CLASS_LABELS = {0: "♻️ Recycling", 1: "🗑️ Trash"}
48
 
49
  # Helper to map model class -> human label
50
+ def _human_label(c):
51
  try:
52
  ci = int(c)
53
  return CLASS_LABELS.get(ci, str(c))
54
  except Exception:
55
  return CLASS_LABELS.get(c, str(c))
56
 
57
+ # Do the prediction!
58
+ def do_predict(pil_img: PIL.Image.Image):
59
+ # Make sure there's actually an image to work with
60
  if pil_img is None:
61
+ return "No image provided.", {}, pandas.DataFrame(columns=["Predicted label", "Confidence (%)"])
 
 
 
62
 
63
+ # If we have something to work with, save it and prepare the input
64
+ tmpdir = pathlib.Path(tempfile.mkdtemp())
65
+ img_path = tmpdir / "input.png"
66
+ pil_img.save(img_path)
67
 
68
+ df = pandas.DataFrame({"image": [str(img_path)]}) # For AutoGluon expected input format
 
69
 
70
  try:
71
  proba_df = PREDICTOR.predict_proba(df) # For class probabilities
 
78
  "🗑️ Trash": float(row.get("🗑️ Trash (1)", 0.0)),
79
  }
80
  pretty_dict = dict(sorted(pretty_dict.items(), key=lambda kv: kv[1], reverse=True))
 
81
  except Exception:
 
82
  pretty_dict = {}
 
83
 
84
+ return pretty_dict
 
 
 
 
 
85
 
86
  # Representative example images (replace with your own)
87
  EXAMPLES = [
88
+ ["https://c8.alamy.com/comp/2AEA4K9/a-garbage-and-recycling-can-on-the-campus-of-carnegie-mellon-university-pittsburgh-pennsylvania-usa-2AEA4K9.jpg"],
89
+ ["https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSvid9M7DynMcoUsX0KBMxooLvrKQJwREiw6g&s"],
90
  ]
91
 
92
  # Gradio UI
93
+ with gradio.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
94
 
95
+ # Interface for the incoming image
96
+ image_in = gradio.Image(type="pil", label="Input image", sources=["upload", "webcam"])
97
 
98
+ # Interface elements to show the result and probabilities
99
+ proba_pretty = gradio.Label(num_top_classes=2, label="Class probabilities")
 
 
100
 
101
+ # Whenever a new image is uploaded, update the result
102
+ image_in.change(fn=do_predict, inputs=[image_in], outputs=[proba_pretty])
103
 
104
+ # For clickable example images
105
+ gradio.Examples(
106
  examples=EXAMPLES,
107
  inputs=[image_in],
108
  label="Representative examples",
 
110
  cache_examples=False,
111
  )
112
 
113
+ if __name__ == "__main__":
114
  demo.launch(server_name="0.0.0.0", show_api=False)