jerpelhan committed
Commit 0e137ec · 1 Parent(s): dcddf2d

Updated demo, added AMP for faster inference, added examples
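A note on the AMP change: the demo now runs the forward pass under `torch.inference_mode()` with `torch.autocast` enabled only on CUDA, so CPU Spaces take the same code path unchanged. A minimal sketch of the pattern, with an illustrative `run_inference` helper that is not part of this repo:

```python
import torch

@torch.inference_mode()  # drops autograd bookkeeping entirely during inference
def run_inference(model: torch.nn.Module, batch: torch.Tensor) -> torch.Tensor:
    use_amp = (batch.device.type == "cuda")
    # autocast runs eligible ops in float16 on CUDA; with enabled=False it is a no-op
    with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
        return model(batch)
```

Gating on `device.type` keeps the context manager harmless when the Space is scheduled on CPU hardware.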
demo_gradio.py CHANGED
@@ -1,8 +1,8 @@
 import spaces
 import torch
+import torch.nn.functional as F
 import gradio as gr
 from gradio_image_annotation import image_annotator
-from torch.nn import DataParallel
 from models.counter_infer import build_model
 from utils.arg_parser import get_argparser
 from utils.data import resize_and_pad
@@ -13,55 +13,11 @@ from huggingface_hub import hf_hub_download
 import numpy as np
 import colorsys
 
-
-# -----------------------------
-# Minimal UI + force "Create" mode (press C a few times)
 # -----------------------------
-JS_FORCE_CREATE_MODE = r"""
-function () {
-    const pressC = () => {
-        const ev = new KeyboardEvent("keydown", {
-            key: "c",
-            code: "KeyC",
-            bubbles: true
-        });
-        document.dispatchEvent(ev);
-    };
-
-    let tries = 0;
-    const t = setInterval(() => {
-        tries++;
-        pressC();
-        if (tries > 20) clearInterval(t);
-    }, 200);
-}
-"""
-
-CSS_MINIMAL_UI = """
-/* Hide labels, instructions, help text */
-.gradio-container label,
-.gradio-container .block-label,
-.gradio-container .markdown,
-.gradio-container p {
-    display: none !important;
-}
-
-/* Reduce rounding of UI containers */
-.gradio-container [class*="rounded"] {
-    border-radius: 4px !important;
-}
-
-/* Reduce padding */
-.gradio-container [class*="p-4"] {
-    padding: 0.25rem !important;
-}
-"""
-
-
 _MODEL = None
 _ARGS = None
 _WEIGHTS_PATH = None
-
+# -----------------------------
 
 def _get_args():
     global _ARGS
@@ -83,6 +39,36 @@ def _get_weights_path():
     return _WEIGHTS_PATH
 
 
+def _strip_module_prefix(state_dict: dict) -> dict:
+    """
+    If weights were saved from torch.nn.DataParallel, keys are often prefixed with 'module.'.
+    When loading into a non-DataParallel model, strip that prefix.
+    """
+    if not isinstance(state_dict, dict) or len(state_dict) == 0:
+        return state_dict
+
+    # Only strip if it looks like DP
+    has_module = any(k.startswith("module.") for k in state_dict.keys())
+    if not has_module:
+        return state_dict
+
+    return {k[len("module.") :]: v for k, v in state_dict.items()}
+
+
+def _extract_state_dict(ckpt) -> dict:
+    """
+    Robustly extract a state_dict from typical checkpoint formats.
+    """
+    if isinstance(ckpt, dict):
+        # Common keys
+        if "model" in ckpt and isinstance(ckpt["model"], dict):
+            return ckpt["model"]
+        if "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict):
+            return ckpt["state_dict"]
+        # Fallback: checkpoint itself is the state_dict
+        return ckpt
+
+
 def get_model_on_device(device: torch.device):
     """
     Lazily build and load model, then move to the requested device.
@@ -95,18 +81,19 @@ def get_model_on_device(device: torch.device):
 
     # Build on CPU first to avoid CUDA init in the wrong process
     model = build_model(args)
-    model = DataParallel(model)  # wrap before loading; matches your original
 
     weights_path = _get_weights_path()
-    ckpt = torch.load(weights_path, map_location="cpu", weights_only=True)
-    state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
-    model.load_state_dict(state, strict=False)
+    ckpt = torch.load(weights_path, map_location="cpu")  # keep compatibility across torch versions
+    state = _extract_state_dict(ckpt)
+    state = _strip_module_prefix(state)
 
+    model.load_state_dict(state, strict=False)
     model.eval()
     _MODEL = model
 
-    # Ensure correct device for this invocation
     _MODEL = _MODEL.to(device)
+    if device.type == "cuda":
+        torch.backends.cudnn.benchmark = True
     return _MODEL
 
 
@@ -114,11 +101,6 @@ def get_model_on_device(device: torch.device):
 # Rotation helper (in case annotator reports orientation)
 # -----------------------------
 def _rotate_image_and_boxes(image_np: np.ndarray, boxes: list[dict], angle: int):
-    """
-    angle is in 90-degree steps. The gradio_image_annotation README demonstrates:
-        np.rot90(image, k=-angle)
-    so angle=1 => rotate clockwise 90 deg.
-    """
     if angle is None:
         return image_np, boxes
 
@@ -136,7 +118,6 @@ def _rotate_image_and_boxes(image_np: np.ndarray, boxes: list[dict], angle: int)
         xmax = max(0, min(newW, xmax))
         ymin = max(0, min(newH, ymin))
         ymax = max(0, min(newH, ymax))
-        # ensure ordering
         if xmax < xmin:
             xmin, xmax = xmax, xmin
         if ymax < ymin:
@@ -212,27 +193,25 @@ def process_image_once(inputs, enable_mask):
     image = inputs["image"]
     boxes = inputs.get("boxes", []) or []
 
-    # Ensure numpy image
+    # Ensure numpy image (support numpy, PIL, OR local path string)
     if isinstance(image, Image.Image):
-        image = np.array(image)
+        image = np.array(image.convert("RGB"))
     elif isinstance(image, str):
-        # If you ever allow URL/path returns, you’d need to load it here.
-        # For now, enforce image_type="numpy" in the UI so this does not occur.
-        raise ValueError("Annotator returned image as str. Set image_type='numpy' on image_annotator.")
+        image = np.array(Image.open(image).convert("RGB"))
+    elif isinstance(image, np.ndarray):
+        pass
+    else:
+        raise ValueError(f"Unsupported image type from annotator: {type(image)}")
 
-    # Handle orientation if provided (rare but supported by component)
     angle = inputs.get("orientation", None)
    if angle is not None:
         image, boxes = _rotate_image_and_boxes(image, boxes, angle)
 
-    # Convert boxes dicts to your legacy list format so downstream code stays unchanged:
-    # drawn_boxes elements must support [0],[1],[3],[4] usage in your code.
-    # We'll encode as: [x1, y1, 0, x2, y2]
     drawn_boxes = []
     for b in boxes:
         drawn_boxes.append([float(b["xmin"]), float(b["ymin"]), 0.0, float(b["xmax"]), float(b["ymax"])])
 
-    # If no boxes, keep consistent behavior (model call would likely fail)
+    # If no boxes, do not call model (caller will handle warning)
     if len(drawn_boxes) == 0:
         return image, [{"pred_boxes": torch.empty(0, 4), "box_v": torch.empty(0)}], [None], torch.empty(1), 1.0, []
 
@@ -249,19 +228,17 @@ def process_image_once(inputs, enable_mask):
     img = img.unsqueeze(0).to(device)
     bboxes = bboxes.unsqueeze(0).to(device)
 
-    with torch.no_grad():
-        model.module.return_masks = enable_mask
+    # Faster inference mode
+    use_amp = (device.type == "cuda")
+    with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
+        model.return_masks = enable_mask
         outputs, _, _, _, masks = model(img, bboxes)
 
     # Return ONLY CPU-native objects to main process.
     out0 = outputs[0]
     pred_boxes_cpu = out0["pred_boxes"].detach().float().cpu()
     box_v_cpu = out0["box_v"].detach().float().cpu()
-
-    outputs_cpu = [{
-        "pred_boxes": pred_boxes_cpu,
-        "box_v": box_v_cpu,
-    }]
+    outputs_cpu = [{"pred_boxes": pred_boxes_cpu, "box_v": box_v_cpu}]
 
     if enable_mask and masks is not None and masks[0] is not None:
         masks_cpu = [masks[0].detach().float().cpu()]
@@ -369,13 +346,25 @@ def post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, th
     return image.convert("RGB"), len(pred_boxes)
 
 
+# -----------------------------
+# Examples: gallery click -> set annotator value
+# -----------------------------
+EXAMPLE_PATHS = ["material/1.jpg", "material/2.jpg", "material/3.jpg", "material/4.jpg", "material/5.jpg"]
+
+def load_example_from_gallery(evt: gr.SelectData):
+    """
+    When user clicks a thumbnail in the gallery, load that image into the annotator.
+    """
+    idx = int(evt.index)
+    path = EXAMPLE_PATHS[idx]
+    return {"image": path, "boxes": []}
+
+
 # -----------------------------
 # Gradio UI
 # -----------------------------
 iface = gr.Blocks(
     title="GeCo2 Gradio Demo",
-    # js=JS_FORCE_CREATE_MODE,
-    # css=CSS_MINIMAL_UI,
 )
 
 with iface:
@@ -383,7 +372,8 @@ with iface:
         """
 # GeCo2: Generalized-Scale Object Counting with Gradual Query Aggregation
 GeCo2 is a few-shot, category-agnostic detection counter. With only a small number of exemplars, GeCo2 can detect and count all instances of the target object in an image without any retraining.
-1) Upload an image.
+
+1) Upload an image or click an example below.
 2) Draw bounding boxes on the target object (preferably ~3 instances).
 3) Click **Count**.
 4) If needed, adjust the threshold.
@@ -399,7 +389,6 @@ GeCo2 is a few-shot, category-agnostic detection counter. With only a small numb
     drawn_boxes_state = gr.State()
 
     with gr.Row():
-        # New annotator component
         annotator = image_annotator(
             value=None,
             image_type="numpy",  # ensures inputs["image"] is a numpy array
@@ -408,7 +397,7 @@ GeCo2 is a few-shot, category-agnostic detection counter. With only a small numb
             use_default_label=True,
             enable_keyboard_shortcuts=True,
             interactive=True,
-            show_label=False,  # hide label text on boxes
+            show_label=False,
         )
         image_output = gr.Image(type="pil")
 
@@ -419,12 +408,52 @@ GeCo2 is a few-shot, category-agnostic detection counter. With only a small numb
 
     count_button = gr.Button("Count")
 
+    gallery = gr.Gallery(
+        value=EXAMPLE_PATHS,
+        columns=5,
+        height=450,
+        label="Examples (click an image to load it into the annotator)",
+        show_label=True,
+        allow_preview=False,
+    )
+
+    gallery.select(
+        fn=load_example_from_gallery,
+        inputs=None,
+        outputs=annotator,
+    )
+
     def initial_process(inputs, enable_mask, threshold):
+        # Validate: must have at least one box
+        if inputs is None or inputs.get("image", None) is None:
+            gr.Warning("please delineate at least one target category object")
+            return None, 0, None, None, None, None, None, None
+
+        img_val = inputs.get("image", None)
+        boxes = inputs.get("boxes", []) or []
+
+        if len(boxes) == 0:
+            # Try to show current image in the output even if no boxes
+            if isinstance(img_val, str):
+                preview = Image.open(img_val).convert("RGB")
+            elif isinstance(img_val, Image.Image):
+                preview = img_val.convert("RGB")
+            elif isinstance(img_val, np.ndarray):
+                preview = Image.fromarray(img_val.astype(np.uint8)).convert("RGB")
+            else:
+                preview = None
+
+            gr.Warning("please delineate at least one target category object")
+            return preview, 0, None, None, None, None, None, None
+
         image, outputs, masks, img, scale, drawn_boxes = process_image_once(inputs, enable_mask)
         if image is None:
             return None, 0, None, None, None, None, None, None
+
+        out_img, cnt = post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold)
         return (
-            *post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold),
+            out_img,
+            cnt,
             image,
             outputs,
             masks,
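The checkpoint-loading change above replaces the old `DataParallel` wrapper with two helpers, `_extract_state_dict` and `_strip_module_prefix`. Condensed into one illustrative function (the name `load_checkpoint` and the single `"model"` key are assumptions; the real helpers also try `"state_dict"`), the idea is:

```python
import torch

def load_checkpoint(model: torch.nn.Module, path: str) -> None:
    ckpt = torch.load(path, map_location="cpu")
    # Checkpoints may be a bare state_dict or nested under a key such as "model"
    state = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt
    # Weights saved from nn.DataParallel carry a "module." prefix on every key
    state = {k.removeprefix("module."): v for k, v in state.items()}
    model.load_state_dict(state, strict=False)
```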
material/1.jpg ADDED
material/2.jpg ADDED
material/3.jpg ADDED
material/4.jpg ADDED
material/5.jpg ADDED
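These five material/*.jpg files back the new examples gallery. The wiring is a `gr.Gallery` whose `select` event reads the clicked index from `gr.SelectData` and pushes that path into another component; a standalone sketch follows (a plain `gr.Image` stands in here for the custom annotator, and the paths are illustrative):

```python
import gradio as gr

EXAMPLES = ["material/1.jpg", "material/2.jpg", "material/3.jpg"]

def on_select(evt: gr.SelectData) -> str:
    # evt.index is the position of the clicked thumbnail in the gallery
    return EXAMPLES[int(evt.index)]

with gr.Blocks() as demo:
    gallery = gr.Gallery(value=EXAMPLES, columns=3, allow_preview=False)
    picked = gr.Image(type="filepath")
    # Select events carry their payload in evt, so inputs=None suffices
    gallery.select(fn=on_select, inputs=None, outputs=picked)

demo.launch()
```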
models/counter_infer.py CHANGED
@@ -8,7 +8,7 @@ from torch import nn
 from torch.nn import functional as F
 from torchvision.ops import roi_align
 from torchvision.transforms import Resize
-
+from torch.cuda.amp import autocast
 from utils.box_ops import boxes_with_scores
 from .query_generator import C_base
 from .sam_mask import MaskProcessor
@@ -128,15 +128,23 @@ class CNT(nn.Module):
         prototype_embeddings_l2 = torch.cat([exemplars_l2, shape], dim=1)
         hq_prototype_embeddings = [prototype_embeddings_l1, prototype_embeddings_l2]
 
-        # adapt image feature with prototypes
-        adapted_f, adapted_f_aux = self.adapt_features(
-            image_embeddings=src,
-            image_pe=self.sam_prompt_encoder.get_dense_pe(),
-            prototype_embeddings=prototype_embeddings,
-            hq_features=feats['backbone_fpn'],
-            hq_prototypes=hq_prototype_embeddings,
-            hq_pos=feats['vision_pos_enc'],
-        )
+        with autocast(enabled=False):
+            if src.dtype != torch.float32:
+                src = src.float()
+            prototype_embeddings = prototype_embeddings.float()
+            hq_prototype_embeddings = [hq.float() for hq in hq_prototype_embeddings]
+            feats['backbone_fpn'] = [f.float() for f in feats['backbone_fpn']]
+            feats['vision_pos_enc'] = [f.float() for f in feats['vision_pos_enc']]
+
+            # adapt image feature with prototypes
+            adapted_f, adapted_f_aux = self.adapt_features(
+                image_embeddings=src,
+                image_pe=self.sam_prompt_encoder.get_dense_pe(),
+                prototype_embeddings=prototype_embeddings,
+                hq_features=feats['backbone_fpn'],
+                hq_prototypes=hq_prototype_embeddings,
+                hq_pos=feats['vision_pos_enc'],
+            )
         # Predict class [fg, bg] and l,r,t,b
         bs, c, w, h = adapted_f.shape
         adapted_f = adapted_f.view(bs, self.emb_dim, -1).permute(0, 2, 1)
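This hunk is the flip side of enabling AMP in the demo: feature adaptation is pinned to float32 with `autocast(enabled=False)` plus explicit `.float()` casts, because tensors arriving from the surrounding float16 autocast region may be half precision. A minimal sketch of the fp32-island pattern, using a toy two-layer module (names illustrative):

```python
import torch
from torch import nn
from torch.cuda.amp import autocast  # same import the diff uses

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Linear(16, 16)
        self.sensitive = nn.Linear(16, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.backbone(x)  # under AMP this matmul may produce float16
        with autocast(enabled=False):  # fp32 island inside the AMP region
            return self.sensitive(h.float())  # upcast values created in fp16
```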