jerpelhan commited on
Commit
1bf734e
·
1 Parent(s): ef5932e

Updating sdk version and resolving compatibility issues -- image_prompter is removed, gradio_image_annotation added

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. demo_gradio.py +213 -44
  3. requirements.txt +2 -2
README.md CHANGED
@@ -1,7 +1,7 @@
1
  ---
2
  title: GeCo2 Gradio Demo
3
  sdk: gradio
4
- sdk_version: "4.44.1"
5
  python_version: "3.10.13"
6
  app_file: demo_gradio.py
7
  ---
 
1
  ---
2
  title: GeCo2 Gradio Demo
3
  sdk: gradio
4
+ sdk_version: "5.50.0"
5
  python_version: "3.10.13"
6
  app_file: demo_gradio.py
7
  ---
demo_gradio.py CHANGED
@@ -1,7 +1,7 @@
1
  import spaces
2
  import torch
3
  import gradio as gr
4
- from gradio_image_prompter import ImagePrompter
5
  from torch.nn import DataParallel
6
  from models.counter_infer import build_model
7
  from utils.arg_parser import get_argparser
@@ -14,10 +14,55 @@ import numpy as np
14
  import colorsys
15
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  _MODEL = None
18
  _ARGS = None
19
  _WEIGHTS_PATH = None
20
 
 
21
  def _get_args():
22
  global _ARGS
23
  if _ARGS is None:
@@ -26,6 +71,7 @@ def _get_args():
26
  _ARGS = args
27
  return _ARGS
28
 
 
29
  def _get_weights_path():
30
  global _WEIGHTS_PATH
31
  if _WEIGHTS_PATH is None:
@@ -36,6 +82,7 @@ def _get_weights_path():
36
  )
37
  return _WEIGHTS_PATH
38
 
 
39
  def get_model_on_device(device: torch.device):
40
  """
41
  Lazily build and load model, then move to the requested device.
@@ -63,22 +110,140 @@ def get_model_on_device(device: torch.device):
63
  return _MODEL
64
 
65
 
66
- # **Function to Process Image Once**
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  @spaces.GPU
68
  def process_image_once(inputs, enable_mask):
69
-
 
 
 
 
 
 
 
70
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71
  model = get_model_on_device(device)
72
 
 
 
 
 
73
  image = inputs["image"]
74
- drawn_boxes = inputs["points"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  image_tensor = torch.tensor(image).to(device)
76
  image_tensor = image_tensor.permute(2, 0, 1).float() / 255.0
77
  image_tensor = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image_tensor)
78
 
79
- bboxes_tensor = torch.tensor([[box[0], box[1], box[3], box[4]] for box in drawn_boxes], dtype=torch.float32).to(
80
- device
81
- )
 
82
 
83
  img, bboxes, scale = resize_and_pad(image_tensor, bboxes_tensor, size=1024.0)
84
  img = img.unsqueeze(0).to(device)
@@ -88,13 +253,8 @@ def process_image_once(inputs, enable_mask):
88
  model.module.return_masks = enable_mask
89
  outputs, _, _, _, masks = model(img, bboxes)
90
 
91
- # ------------------------------------------------------------------
92
- # ZeroGPU requirement: return ONLY CPU-native objects to main process.
93
- # Do NOT return CUDA tensors, and avoid returning output dicts that may
94
- # contain additional CUDA tensors beyond pred_boxes/box_v.
95
- # ------------------------------------------------------------------
96
  out0 = outputs[0]
97
-
98
  pred_boxes_cpu = out0["pred_boxes"].detach().float().cpu()
99
  box_v_cpu = out0["box_v"].detach().float().cpu()
100
 
@@ -108,7 +268,6 @@ def process_image_once(inputs, enable_mask):
108
  else:
109
  masks_cpu = [None]
110
 
111
- # img is only used for shape in post_process, so return a CPU tensor
112
  img_cpu = img.detach().cpu()
113
 
114
  return image, outputs_cpu, masks_cpu, img_cpu, float(scale), drawn_boxes
@@ -123,22 +282,13 @@ def _hsv_to_rgb255(h, s, v):
123
 
124
 
125
  def instance_colors(i: int):
126
- """
127
- Pastel palette per instance.
128
- - Mask: pastel fill
129
- - Box: same hue, slightly more saturated (but still pastel-ish)
130
- Deterministic hue stepping (golden ratio) for stable and distinct colors.
131
- """
132
  h = (i * 0.618033988749895) % 1.0
133
- mask_rgb = _hsv_to_rgb255(h, s=0.28, v=1.00) # soft pastel
134
- box_rgb = _hsv_to_rgb255(h, s=0.42, v=0.95) # slightly stronger pastel
135
  return mask_rgb, box_rgb
136
 
137
 
138
  def overlay_single_mask(base_rgba: Image.Image, mask_bool: np.ndarray, rgb, alpha=0.45):
139
- """
140
- Alpha-composite a single instance mask (boolean HxW) in given rgb onto base_rgba.
141
- """
142
  if mask_bool.dtype != np.bool_:
143
  mask_bool = mask_bool.astype(bool)
144
 
@@ -153,12 +303,19 @@ def overlay_single_mask(base_rgba: Image.Image, mask_bool: np.ndarray, rgb, alph
153
  return Image.alpha_composite(base_rgba, overlay_img)
154
 
155
 
156
- # **Post-process and Update Output**
 
 
157
  def post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold):
158
  idx = 0
159
  threshold = 1 / threshold
160
 
161
  score = outputs[idx]["box_v"]
 
 
 
 
 
162
  score_mask = score > score.max() / threshold
163
 
164
  keep = ops.nms(
@@ -171,20 +328,17 @@ def post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, th
171
  pred_boxes = torch.clamp(pred_boxes, 0, 1)
172
  pred_boxes = (pred_boxes / scale * img.shape[-1]).tolist()
173
 
174
- # Base image as RGBA for compositing
175
  image = Image.fromarray((image).astype(np.uint8)).convert("RGBA")
176
 
177
- # --- Masks: per-instance, pastel, matching box hue ---
178
  if enable_mask and masks is not None and masks[idx] is not None:
179
  masks_sel = masks[idx][score_mask[0]] if score_mask.ndim > 1 else masks[idx][score_mask]
180
- masks_sel = masks_sel[keep] # align with pred_boxes
181
 
182
  target_h = int(img.shape[2] / scale)
183
  target_w = int(img.shape[3] / scale)
184
  resize_nearest = T.Resize((target_h, target_w), interpolation=T.InterpolationMode.NEAREST)
185
 
186
  W, H = image.size
187
-
188
  for i in range(masks_sel.shape[0]):
189
  mask_i = masks_sel[i]
190
  if mask_i.ndim == 3:
@@ -197,37 +351,38 @@ def post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, th
197
  mask_rgb, _ = instance_colors(i)
198
  image = overlay_single_mask(image, mask_bool, mask_rgb, alpha=0.45)
199
 
200
- # --- Boxes: thin, pastel, no labels/text ---
201
  draw = ImageDraw.Draw(image)
202
- box_width = 2 # thin and clean
203
 
204
  for i, box in enumerate(pred_boxes):
205
  _, box_rgb = instance_colors(i)
206
  x1, y1, x2, y2 = map(float, box)
207
  draw.rectangle([x1, y1, x2, y2], outline=box_rgb, width=box_width)
208
 
209
- # --- Exemplar boxes (user-drawn): keep clear but unobtrusive, no text ---
210
- exemplar_outline = (255, 255, 255, 255) # white
211
- exemplar_inner = (0, 0, 0, 255) # black
212
  for box in drawn_boxes:
213
  x1, y1, x2, y2 = box[0], box[1], box[3], box[4]
214
  draw.rectangle([x1, y1, x2, y2], outline=exemplar_outline, width=2)
215
  draw.rectangle([x1 + 1, y1 + 1, x2 - 1, y2 - 1], outline=exemplar_inner, width=1)
216
 
217
- # Return without any text/labels on the image
218
  return image.convert("RGB"), len(pred_boxes)
219
 
220
 
221
- iface = gr.Blocks(title="GeCo2 Gradio Demo")
 
 
 
 
 
 
 
222
 
223
  with iface:
224
  gr.Markdown(
225
  """
226
  # GeCo2: Generalized-Scale Object Counting with Gradual Query Aggregation
227
-
228
- GeCo2 is a few-shot, category-agnostic detection counter. With only a small number of exemplars, GeCo2 can detect and count all instances of the target object in an image wihtout any retraining.
229
-
230
-
231
  1) Upload an image.
232
  2) Draw bounding boxes on the target object (preferably ~3 instances).
233
  3) Click **Count**.
@@ -244,7 +399,17 @@ GeCo2 is a few-shot, category-agnostic detection counter. With only a small numb
244
  drawn_boxes_state = gr.State()
245
 
246
  with gr.Row():
247
- image_prompter = ImagePrompter()
 
 
 
 
 
 
 
 
 
 
248
  image_output = gr.Image(type="pil")
249
 
250
  with gr.Row():
@@ -256,6 +421,8 @@ GeCo2 is a few-shot, category-agnostic detection counter. With only a small numb
256
 
257
  def initial_process(inputs, enable_mask, threshold):
258
  image, outputs, masks, img, scale, drawn_boxes = process_image_once(inputs, enable_mask)
 
 
259
  return (
260
  *post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold),
261
  image,
@@ -267,11 +434,13 @@ GeCo2 is a few-shot, category-agnostic detection counter. With only a small numb
267
  )
268
 
269
  def update_threshold(threshold, image, outputs, masks, img, scale, drawn_boxes, enable_mask):
 
 
270
  return post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold)
271
 
272
  count_button.click(
273
  initial_process,
274
- [image_prompter, enable_mask, threshold],
275
  [image_output, count_output, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state],
276
  )
277
 
@@ -288,4 +457,4 @@ GeCo2 is a few-shot, category-agnostic detection counter. With only a small numb
288
  )
289
 
290
  if __name__ == "__main__":
291
- iface.launch()
 
1
  import spaces
2
  import torch
3
  import gradio as gr
4
+ from gradio_image_annotation import image_annotator
5
  from torch.nn import DataParallel
6
  from models.counter_infer import build_model
7
  from utils.arg_parser import get_argparser
 
14
  import colorsys
15
 
16
 
17
+ # -----------------------------
18
+ # Minimal UI + force "Create" mode (press C a few times)
19
+ # -----------------------------
20
+ JS_FORCE_CREATE_MODE = r"""
21
+ function () {
22
+ const pressC = () => {
23
+ const ev = new KeyboardEvent("keydown", {
24
+ key: "c",
25
+ code: "KeyC",
26
+ bubbles: true
27
+ });
28
+ document.dispatchEvent(ev);
29
+ };
30
+
31
+ let tries = 0;
32
+ const t = setInterval(() => {
33
+ tries++;
34
+ pressC();
35
+ if (tries > 20) clearInterval(t);
36
+ }, 200);
37
+ }
38
+ """
39
+
40
+ CSS_MINIMAL_UI = """
41
+ /* Hide labels, instructions, help text */
42
+ .gradio-container label,
43
+ .gradio-container .block-label,
44
+ .gradio-container .markdown,
45
+ .gradio-container p {
46
+ display: none !important;
47
+ }
48
+
49
+ /* Reduce rounding of UI containers */
50
+ .gradio-container [class*="rounded"] {
51
+ border-radius: 4px !important;
52
+ }
53
+
54
+ /* Reduce padding */
55
+ .gradio-container [class*="p-4"] {
56
+ padding: 0.25rem !important;
57
+ }
58
+ """
59
+
60
+
61
  _MODEL = None
62
  _ARGS = None
63
  _WEIGHTS_PATH = None
64
 
65
+
66
  def _get_args():
67
  global _ARGS
68
  if _ARGS is None:
 
71
  _ARGS = args
72
  return _ARGS
73
 
74
+
75
  def _get_weights_path():
76
  global _WEIGHTS_PATH
77
  if _WEIGHTS_PATH is None:
 
82
  )
83
  return _WEIGHTS_PATH
84
 
85
+
86
  def get_model_on_device(device: torch.device):
87
  """
88
  Lazily build and load model, then move to the requested device.
 
110
  return _MODEL
111
 
112
 
113
+ # -----------------------------
114
+ # Rotation helper (in case annotator reports orientation)
115
+ # -----------------------------
116
+ def _rotate_image_and_boxes(image_np: np.ndarray, boxes: list[dict], angle: int):
117
+ """
118
+ angle is in 90-degree steps. The gradio_image_annotation README demonstrates:
119
+ np.rot90(image, k=-angle)
120
+ so angle=1 => rotate clockwise 90 deg.
121
+ """
122
+ if angle is None:
123
+ return image_np, boxes
124
+
125
+ a = int(angle) % 4
126
+ if a == 0:
127
+ return image_np, boxes
128
+
129
+ H, W = image_np.shape[:2]
130
+
131
+ # rotate image using the same convention as the component docs
132
+ image_rot = np.rot90(image_np, k=-a)
133
+
134
+ def clamp_box(xmin, ymin, xmax, ymax, newW, newH):
135
+ xmin = max(0, min(newW, xmin))
136
+ xmax = max(0, min(newW, xmax))
137
+ ymin = max(0, min(newH, ymin))
138
+ ymax = max(0, min(newH, ymax))
139
+ # ensure ordering
140
+ if xmax < xmin:
141
+ xmin, xmax = xmax, xmin
142
+ if ymax < ymin:
143
+ ymin, ymax = ymax, ymin
144
+ return xmin, ymin, xmax, ymax
145
+
146
+ boxes_rot = []
147
+ if a == 1:
148
+ # 90 deg clockwise: (x,y) -> (H - 1 - y, x)
149
+ newH, newW = W, H
150
+ for b in boxes:
151
+ xmin, ymin, xmax, ymax = b["xmin"], b["ymin"], b["xmax"], b["ymax"]
152
+ nxmin = H - ymax
153
+ nxmax = H - ymin
154
+ nymin = xmin
155
+ nymax = xmax
156
+ nxmin, nymin, nxmax, nymax = clamp_box(nxmin, nymin, nxmax, nymax, newW, newH)
157
+ bb = dict(b)
158
+ bb.update({"xmin": nxmin, "ymin": nymin, "xmax": nxmax, "ymax": nymax})
159
+ boxes_rot.append(bb)
160
+
161
+ elif a == 2:
162
+ # 180 deg: (x,y) -> (W - 1 - x, H - 1 - y)
163
+ newH, newW = H, W
164
+ for b in boxes:
165
+ xmin, ymin, xmax, ymax = b["xmin"], b["ymin"], b["xmax"], b["ymax"]
166
+ nxmin = W - xmax
167
+ nxmax = W - xmin
168
+ nymin = H - ymax
169
+ nymax = H - ymin
170
+ nxmin, nymin, nxmax, nymax = clamp_box(nxmin, nymin, nxmax, nymax, newW, newH)
171
+ bb = dict(b)
172
+ bb.update({"xmin": nxmin, "ymin": nymin, "xmax": nxmax, "ymax": nymax})
173
+ boxes_rot.append(bb)
174
+
175
+ else: # a == 3
176
+ # 90 deg counter-clockwise: (x,y) -> (y, W - 1 - x)
177
+ newH, newW = W, H
178
+ for b in boxes:
179
+ xmin, ymin, xmax, ymax = b["xmin"], b["ymin"], b["xmax"], b["ymax"]
180
+ nxmin = ymin
181
+ nxmax = ymax
182
+ nymin = W - xmax
183
+ nymax = W - xmin
184
+ nxmin, nymin, nxmax, nymax = clamp_box(nxmin, nymin, nxmax, nymax, newW, newH)
185
+ bb = dict(b)
186
+ bb.update({"xmin": nxmin, "ymin": nymin, "xmax": nxmax, "ymax": nymax})
187
+ boxes_rot.append(bb)
188
+
189
+ return image_rot, boxes_rot
190
+
191
+
192
+ # -----------------------------
193
+ # Function to Process Image Once (GPU)
194
+ # -----------------------------
195
  @spaces.GPU
196
  def process_image_once(inputs, enable_mask):
197
+ """
198
+ inputs is AnnotatedImageValue-like dict from gradio_image_annotation:
199
+ {
200
+ "image": np.ndarray | PIL | str,
201
+ "boxes": [ {xmin,ymin,xmax,ymax,label?,color?}, ... ],
202
+ "orientation": int?
203
+ }
204
+ """
205
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
206
  model = get_model_on_device(device)
207
 
208
+ if inputs is None or inputs.get("image", None) is None:
209
+ # keep behavior simple: return empty outputs
210
+ return None, [{"pred_boxes": torch.empty(0, 4), "box_v": torch.empty(0)}], [None], torch.empty(1), 1.0, []
211
+
212
  image = inputs["image"]
213
+ boxes = inputs.get("boxes", []) or []
214
+
215
+ # Ensure numpy image
216
+ if isinstance(image, Image.Image):
217
+ image = np.array(image)
218
+ elif isinstance(image, str):
219
+ # If you ever allow URL/path returns, you’d need to load it here.
220
+ # For now, enforce image_type="numpy" in the UI so this does not occur.
221
+ raise ValueError("Annotator returned image as str. Set image_type='numpy' on image_annotator.")
222
+
223
+ # Handle orientation if provided (rare but supported by component)
224
+ angle = inputs.get("orientation", None)
225
+ if angle is not None:
226
+ image, boxes = _rotate_image_and_boxes(image, boxes, angle)
227
+
228
+ # Convert boxes dicts to your legacy list format so downstream code stays unchanged:
229
+ # drawn_boxes elements must support [0],[1],[3],[4] usage in your code.
230
+ # We'll encode as: [x1, y1, 0, x2, y2]
231
+ drawn_boxes = []
232
+ for b in boxes:
233
+ drawn_boxes.append([float(b["xmin"]), float(b["ymin"]), 0.0, float(b["xmax"]), float(b["ymax"])])
234
+
235
+ # If no boxes, keep consistent behavior (model call would likely fail)
236
+ if len(drawn_boxes) == 0:
237
+ return image, [{"pred_boxes": torch.empty(0, 4), "box_v": torch.empty(0)}], [None], torch.empty(1), 1.0, []
238
+
239
  image_tensor = torch.tensor(image).to(device)
240
  image_tensor = image_tensor.permute(2, 0, 1).float() / 255.0
241
  image_tensor = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image_tensor)
242
 
243
+ bboxes_tensor = torch.tensor(
244
+ [[box[0], box[1], box[3], box[4]] for box in drawn_boxes],
245
+ dtype=torch.float32,
246
+ ).to(device)
247
 
248
  img, bboxes, scale = resize_and_pad(image_tensor, bboxes_tensor, size=1024.0)
249
  img = img.unsqueeze(0).to(device)
 
253
  model.module.return_masks = enable_mask
254
  outputs, _, _, _, masks = model(img, bboxes)
255
 
256
+ # Return ONLY CPU-native objects to main process.
 
 
 
 
257
  out0 = outputs[0]
 
258
  pred_boxes_cpu = out0["pred_boxes"].detach().float().cpu()
259
  box_v_cpu = out0["box_v"].detach().float().cpu()
260
 
 
268
  else:
269
  masks_cpu = [None]
270
 
 
271
  img_cpu = img.detach().cpu()
272
 
273
  return image, outputs_cpu, masks_cpu, img_cpu, float(scale), drawn_boxes
 
282
 
283
 
284
  def instance_colors(i: int):
 
 
 
 
 
 
285
  h = (i * 0.618033988749895) % 1.0
286
+ mask_rgb = _hsv_to_rgb255(h, s=0.28, v=1.00)
287
+ box_rgb = _hsv_to_rgb255(h, s=0.42, v=0.95)
288
  return mask_rgb, box_rgb
289
 
290
 
291
  def overlay_single_mask(base_rgba: Image.Image, mask_bool: np.ndarray, rgb, alpha=0.45):
 
 
 
292
  if mask_bool.dtype != np.bool_:
293
  mask_bool = mask_bool.astype(bool)
294
 
 
303
  return Image.alpha_composite(base_rgba, overlay_img)
304
 
305
 
306
+ # -----------------------------
307
+ # Post-process and Update Output
308
+ # -----------------------------
309
  def post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold):
310
  idx = 0
311
  threshold = 1 / threshold
312
 
313
  score = outputs[idx]["box_v"]
314
+ if score.numel() == 0:
315
+ # no predictions
316
+ image_pil = Image.fromarray((image).astype(np.uint8)).convert("RGB")
317
+ return image_pil, 0
318
+
319
  score_mask = score > score.max() / threshold
320
 
321
  keep = ops.nms(
 
328
  pred_boxes = torch.clamp(pred_boxes, 0, 1)
329
  pred_boxes = (pred_boxes / scale * img.shape[-1]).tolist()
330
 
 
331
  image = Image.fromarray((image).astype(np.uint8)).convert("RGBA")
332
 
 
333
  if enable_mask and masks is not None and masks[idx] is not None:
334
  masks_sel = masks[idx][score_mask[0]] if score_mask.ndim > 1 else masks[idx][score_mask]
335
+ masks_sel = masks_sel[keep]
336
 
337
  target_h = int(img.shape[2] / scale)
338
  target_w = int(img.shape[3] / scale)
339
  resize_nearest = T.Resize((target_h, target_w), interpolation=T.InterpolationMode.NEAREST)
340
 
341
  W, H = image.size
 
342
  for i in range(masks_sel.shape[0]):
343
  mask_i = masks_sel[i]
344
  if mask_i.ndim == 3:
 
351
  mask_rgb, _ = instance_colors(i)
352
  image = overlay_single_mask(image, mask_bool, mask_rgb, alpha=0.45)
353
 
 
354
  draw = ImageDraw.Draw(image)
355
+ box_width = 2
356
 
357
  for i, box in enumerate(pred_boxes):
358
  _, box_rgb = instance_colors(i)
359
  x1, y1, x2, y2 = map(float, box)
360
  draw.rectangle([x1, y1, x2, y2], outline=box_rgb, width=box_width)
361
 
362
+ exemplar_outline = (255, 255, 255, 255)
363
+ exemplar_inner = (0, 0, 0, 255)
 
364
  for box in drawn_boxes:
365
  x1, y1, x2, y2 = box[0], box[1], box[3], box[4]
366
  draw.rectangle([x1, y1, x2, y2], outline=exemplar_outline, width=2)
367
  draw.rectangle([x1 + 1, y1 + 1, x2 - 1, y2 - 1], outline=exemplar_inner, width=1)
368
 
 
369
  return image.convert("RGB"), len(pred_boxes)
370
 
371
 
372
+ # -----------------------------
373
+ # Gradio UI
374
+ # -----------------------------
375
+ iface = gr.Blocks(
376
+ title="GeCo2 Gradio Demo",
377
+ js=JS_FORCE_CREATE_MODE,
378
+ css=CSS_MINIMAL_UI,
379
+ )
380
 
381
  with iface:
382
  gr.Markdown(
383
  """
384
  # GeCo2: Generalized-Scale Object Counting with Gradual Query Aggregation
385
+ GeCo2 is a few-shot, category-agnostic detection counter. With only a small number of exemplars, GeCo2 can detect and count all instances of the target object in an image without any retraining.
 
 
 
386
  1) Upload an image.
387
  2) Draw bounding boxes on the target object (preferably ~3 instances).
388
  3) Click **Count**.
 
399
  drawn_boxes_state = gr.State()
400
 
401
  with gr.Row():
402
+ # New annotator component
403
+ annotator = image_annotator(
404
+ value=None,
405
+ image_type="numpy", # ensures inputs["image"] is a numpy array
406
+ label_list=["Object"],
407
+ label_colors=[(0, 255, 0)],
408
+ use_default_label=True,
409
+ enable_keyboard_shortcuts=True,
410
+ interactive=True,
411
+ show_label=False, # hide label text on boxes
412
+ )
413
  image_output = gr.Image(type="pil")
414
 
415
  with gr.Row():
 
421
 
422
  def initial_process(inputs, enable_mask, threshold):
423
  image, outputs, masks, img, scale, drawn_boxes = process_image_once(inputs, enable_mask)
424
+ if image is None:
425
+ return None, 0, None, None, None, None, None, None
426
  return (
427
  *post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold),
428
  image,
 
434
  )
435
 
436
  def update_threshold(threshold, image, outputs, masks, img, scale, drawn_boxes, enable_mask):
437
+ if image is None or outputs is None or img is None:
438
+ return None, 0
439
  return post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold)
440
 
441
  count_button.click(
442
  initial_process,
443
+ [annotator, enable_mask, threshold],
444
  [image_output, count_output, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state],
445
  )
446
 
 
457
  )
458
 
459
  if __name__ == "__main__":
460
+ iface.queue().launch()
requirements.txt CHANGED
@@ -110,5 +110,5 @@ websockets==12.0
110
  zipp==3.21.0
111
  spaces
112
  gradio_client
113
- gradio>=4.0.0,<5
114
- gradio_image_prompter @ https://huggingface.co/datasets/jerpelhan/geco2-assets/resolve/main/wheels/gradio_image_prompter-0.1.0-py3-none-any.whl
 
110
  zipp==3.21.0
111
  spaces
112
  gradio_client
113
+ gradio==5.50.0
114
+ gradio_image_annotation