Miroslav Purkrabek committed on
Commit
7ebd068
·
1 Parent(s): e0c4840

first code with BMP demo

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +181 -9
  3. demo/demo_utils.py +37 -0
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: ProbPose Demo
3
  emoji: 🐠
4
  colorFrom: gray
5
  colorTo: yellow
 
1
  ---
2
+ title: BBoxMaskPose Demo
3
  emoji: 🐠
4
  colorFrom: gray
5
  colorTo: yellow
app.py CHANGED
@@ -1,22 +1,194 @@
1
  import gradio as gr
2
  import spaces
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  @spaces.GPU(duration=60)
5
  def process_image_with_BMP(
6
- image,
7
- ):
8
  """
9
- Performs object detection using SAHI with a specified YOLOv11 model.
 
10
  Args:
11
- image (PIL.Image.Image): The input image for detection.
 
 
 
 
 
 
12
  Returns:
13
- tuple: A tuple containing two PIL.Image.Image objects:
14
- - The image with standard YOLO inference results.
15
- - The image with SAHI sliced YOLO inference results.
16
  """
17
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- return image, image
20
 
21
 
22
  with gr.Blocks() as app:
 
1
  import gradio as gr
2
  import spaces
3
 
4
+
5
+
6
+ # Copyright (c) OpenMMLab. All rights reserved.
7
+ """
8
+ BMP Demo script: sequentially runs detection, pose estimation, SAM-based mask refinement, and visualization.
9
+ Usage:
10
+ python bmp_demo.py <config.yaml> <input_image> [--output-root <dir>]
11
+ """
12
+
13
+ import os
14
+ import shutil
15
+ from argparse import ArgumentParser, Namespace
16
+ from pathlib import Path
17
+
18
+ import mmcv
19
+ import mmengine
20
+ import numpy as np
21
+ import yaml
22
+ from demo.demo_utils import DotDict, concat_instances, create_GIF, filter_instances, pose_nms, visualize_demo
23
+ from demo.mm_utils import run_MMDetector, run_MMPose
24
+ from mmdet.apis import init_detector
25
+ from mmengine.logging import print_log
26
+ from mmengine.structures import InstanceData
27
+ from demo.sam2_utils import prepare_model as prepare_sam2_model
28
+ from demo.sam2_utils import process_image_with_SAM
29
+
30
+ from mmpose.apis import init_model as init_pose_estimator
31
+ from mmpose.utils import adapt_mmdet_pipeline
32
+
33
+ # Default thresholds
34
+ DEFAULT_CAT_ID: int = 0
35
+
36
+ DEFAULT_BBOX_THR: float = 0.3
37
+ DEFAULT_NMS_THR: float = 0.3
38
+ DEFAULT_KPT_THR: float = 0.3
39
+
40
+ # Global models variable
41
+ det_model = None
42
+ pose_model = None
43
+ sam2_model = None
44
+
45
+ def _parse_yaml_config(yaml_path: Path) -> DotDict:
46
+ """
47
+ Load BMP configuration from a YAML file.
48
+
49
+ Args:
50
+ yaml_path (Path): Path to YAML config.
51
+ Returns:
52
+ DotDict: Nested config dictionary.
53
+ """
54
+ with open(yaml_path, "r") as f:
55
+ cfg = yaml.safe_load(f)
56
+ return DotDict(cfg)
57
+
58
+ def load_models(bmp_config):
59
+ device = 'gpu'
60
+
61
+ global det_model, pose_model, sam2_model
62
+
63
+ # build detectors
64
+ det_model = init_detector(bmp_config.detector.det_config, bmp_config.detector.det_checkpoint, device='gpu')
65
+ det_model.cfg = adapt_mmdet_pipeline(det_model.cfg)
66
+
67
+
68
+ # build pose estimator
69
+ pose_model = init_pose_estimator(
70
+ bmp_config.pose_estimator.pose_config,
71
+ bmp_config.pose_estimator.pose_checkpoint,
72
+ device=device,
73
+ cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=False))),
74
+ )
75
+
76
+ sam2_model = prepare_sam2_model(
77
+ model_cfg=bmp_config.sam2.sam2_config,
78
+ model_checkpoint=bmp_config.sam2.sam2_checkpoint,
79
+ )
80
+
81
+ return det_model, pose_model, sam2_model
82
+
83
  @spaces.GPU(duration=60)
84
  def process_image_with_BMP(
85
+ img: np.ndarray
86
+ ) -> tuple[np.ndarray, np.ndarray]:
87
  """
88
+ Run the full BMP pipeline on a single image: detection, pose, SAM mask refinement, and visualization.
89
+
90
  Args:
91
+ args (Namespace): Parsed CLI arguments.
92
+ bmp_config (DotDict): Configuration parameters.
93
+ img_path (Path): Path to the input image.
94
+ detector: Primary MMDetection model.
95
+ detector_prime: Secondary MMDetection model for iterations.
96
+ pose_estimator: MMPose model for keypoint estimation.
97
+ sam2_model: SAM model for mask refinement.
98
  Returns:
99
+ InstanceData: Final merged detections and refined masks.
 
 
100
  """
101
+ bmp_config = _parse_yaml_config(Path("configs/bmp_D3.yaml"))
102
+ load_models(bmp_config)
103
+
104
+ img_for_detection = img.copy()
105
+ all_detections = None
106
+ for iteration in range(bmp_config.num_bmp_iters):
107
+
108
+ # Step 1: Detection
109
+ det_instances = run_MMDetector(
110
+ det_model,
111
+ img_for_detection,
112
+ det_cat_id=DEFAULT_CAT_ID,
113
+ bbox_thr=DEFAULT_BBOX_THR,
114
+ nms_thr=DEFAULT_NMS_THR,
115
+ )
116
+ if len(det_instances.bboxes) == 0:
117
+ continue
118
+
119
+ # Step 2: Pose estimation
120
+ pose_instances = run_MMPose(
121
+ pose_model,
122
+ img.copy(),
123
+ detections=det_instances,
124
+ kpt_thr=DEFAULT_KPT_THR,
125
+ )
126
+
127
+ # Restrict to first 17 COCO keypoints
128
+ pose_instances.keypoints = pose_instances.keypoints[:, :17, :]
129
+ pose_instances.keypoint_scores = pose_instances.keypoint_scores[:, :17]
130
+ pose_instances.keypoints = np.concatenate(
131
+ [pose_instances.keypoints, pose_instances.keypoint_scores[:, :, None]], axis=-1
132
+ )
133
+
134
+ # Step 3: Pose-NMS and SAM refinement
135
+ all_keypoints = (
136
+ pose_instances.keypoints
137
+ if all_detections is None
138
+ else np.concatenate([all_detections.keypoints, pose_instances.keypoints], axis=0)
139
+ )
140
+ all_bboxes = (
141
+ pose_instances.bboxes
142
+ if all_detections is None
143
+ else np.concatenate([all_detections.bboxes, pose_instances.bboxes], axis=0)
144
+ )
145
+ num_valid_kpts = np.sum(all_keypoints[:, :, 2] > bmp_config.sam2.prompting.confidence_thr, axis=1)
146
+ keep_indices = pose_nms(
147
+ DotDict({"confidence_thr": bmp_config.sam2.prompting.confidence_thr, "oks_thr": bmp_config.oks_nms_thr}),
148
+ image_kpts=all_keypoints,
149
+ image_bboxes=all_bboxes,
150
+ num_valid_kpts=num_valid_kpts,
151
+ )
152
+ keep_indices = sorted(keep_indices) # Sort by original index
153
+ num_old_detections = 0 if all_detections is None else len(all_detections.bboxes)
154
+ keep_new_indices = [i - num_old_detections for i in keep_indices if i >= num_old_detections]
155
+ keep_old_indices = [i for i in keep_indices if i < num_old_detections]
156
+ if len(keep_new_indices) == 0:
157
+ print_log("No new instances passed pose NMS, skipping SAM refinement.", logger="current")
158
+ continue
159
+ # filter new detections and compute scores
160
+ new_dets = filter_instances(pose_instances, keep_new_indices)
161
+ new_dets.scores = pose_instances.keypoint_scores[keep_new_indices].mean(axis=-1)
162
+ old_dets = None
163
+ if len(keep_old_indices) > 0:
164
+ old_dets = filter_instances(all_detections, keep_old_indices)
165
+
166
+ new_detections = process_image_with_SAM(
167
+ DotDict(bmp_config.sam2.prompting),
168
+ img.copy(),
169
+ sam2_model,
170
+ new_dets,
171
+ old_dets if old_dets is not None else None,
172
+ )
173
+
174
+ # Merge detections
175
+ if all_detections is None:
176
+ all_detections = new_detections
177
+ else:
178
+ all_detections = concat_instances(all_detections, new_dets)
179
+
180
+ # Step 4: Visualization
181
+ img_for_detection, _, _ = visualize_demo(
182
+ img.copy(),
183
+ all_detections,
184
+ )
185
+
186
+ _, rtmdet_result, bmp_result = visualize_demo(
187
+ img.copy(),
188
+ all_detections,
189
+ )
190
 
191
+ return rtmdet_result, bmp_result
192
 
193
 
194
  with gr.Blocks() as app:
demo/demo_utils.py CHANGED
@@ -292,6 +292,43 @@ def visualize_itteration(
292
  return masked_out
293
 
294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  def create_GIF(
296
  img_path: Path,
297
  output_root: Path,
 
292
  return masked_out
293
 
294
 
295
+ def visualize_demo(
296
+ img: np.ndarray, detections: Any,
297
+ ) -> Optional[np.ndarray]:
298
+ """
299
+ Generate and save visualization images for each BMP iteration.
300
+
301
+ Args:
302
+ img (np.ndarray): Original input image.
303
+ detections: InstanceData containing bboxes, scores, masks, keypoints.
304
+ iteration_idx (int): Current iteration index (0-based).
305
+ output_root (Path): Directory to save output images.
306
+ img_name (str): Base name of the image without extension.
307
+ with_text (bool): Whether to overlay text labels.
308
+
309
+ Returns:
310
+ Optional[np.ndarray]: The masked-out image if generated, else None.
311
+ """
312
+ bboxes = detections.bboxes
313
+ scores = detections.scores
314
+ pred_masks = detections.pred_masks
315
+ refined_masks = detections.refined_masks
316
+ keypoints = detections.keypoints
317
+
318
+ returns = []
319
+ for vis_def in [
320
+ {"type": "mask-out", "masks": refined_masks, "label": ""},
321
+ {"type": "bbox+mask", "masks": pred_masks, "label": "RTMDet-L"},
322
+ {"type": "mask+pose", "masks": refined_masks, "label": "BMP"},
323
+ ]:
324
+ vis_img, colors = _visualize_predictions(
325
+ img.copy(), bboxes, scores, vis_def["masks"], keypoints, vis_type=vis_def["type"], mask_is_binary=True
326
+ )
327
+ returns.append(vis_img)
328
+
329
+ return returns
330
+
331
+
332
  def create_GIF(
333
  img_path: Path,
334
  output_root: Path,