Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) OpenMMLab. All rights reserved. | |
| """ | |
| BMP Demo script: sequentially runs detection, pose estimation, SAM-based mask refinement, and visualization. | |
| Usage: | |
| python bmp_demo.py <config.yaml> <input_image> [--output-root <dir>] [--device <device>] [--create-gif] | |
| """ | |
| import os | |
| import shutil | |
| from argparse import ArgumentParser, Namespace | |
| from pathlib import Path | |
| import mmcv | |
| import mmengine | |
| import numpy as np | |
| import yaml | |
| from demo_utils import DotDict, concat_instances, create_GIF, filter_instances, pose_nms, visualize_itteration | |
| from mm_utils import run_MMDetector, run_MMPose | |
| from mmdet.apis import init_detector | |
| from mmengine.logging import print_log | |
| from mmengine.structures import InstanceData | |
| from sam2_utils import prepare_model as prepare_sam2_model | |
| from sam2_utils import process_image_with_SAM | |
| from mmpose.apis import init_model as init_pose_estimator | |
| from mmpose.utils import adapt_mmdet_pipeline | |
# Fixed thresholds handed to the detection / pose helpers in process_one_image.
DEFAULT_DET_CAT_ID: int = 0  # category id of "person" passed to run_MMDetector
DEFAULT_NMS_THR: float = 0.3  # nms_thr argument of run_MMDetector
DEFAULT_BBOX_THR: float = 0.3  # bbox_thr argument of run_MMDetector
DEFAULT_KPT_THR: float = 0.3  # kpt_thr argument of run_MMPose
def parse_args(argv: "list[str] | None" = None) -> Namespace:
    """
    Parse command-line arguments for BMP demo.

    Args:
        argv (list[str] | None): Argument list to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, as ``ArgumentParser.parse_args`` does.

    Returns:
        Namespace: Contains bmp_config (Path), input (Path), output_root (Path),
            device (str) and create_gif (bool).
    """
    parser = ArgumentParser(description="BBoxMaskPose demo")
    parser.add_argument("bmp_config", type=Path, help="Path to BMP YAML config file")
    parser.add_argument("input", type=Path, help="Input image file")
    parser.add_argument(
        "--output-root",
        type=Path,
        default=None,
        help="Directory to save outputs (default: 'outputs' next to this script)",
    )
    parser.add_argument("--device", type=str, default="cuda:0", help="Device for inference (e.g., cuda:0 or cpu)")
    parser.add_argument("--create-gif", action="store_true", default=False, help="Create GIF of all BMP iterations")
    args = parser.parse_args(argv)
    if args.output_root is None:
        # Default to an "outputs" folder beside this script. Build it as a Path so
        # output_root has the same type whether or not --output-root was given.
        args.output_root = Path(__file__).parent / "outputs"
    return args
def parse_yaml_config(yaml_path: Path) -> DotDict:
    """
    Load BMP configuration from a YAML file.

    Args:
        yaml_path (Path): Path to YAML config.

    Returns:
        DotDict: Nested config dictionary.

    Raises:
        ValueError: If the file is empty or its top level is not a mapping.
    """
    # Decode explicitly instead of relying on the platform's locale-dependent default.
    with open(yaml_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    # safe_load returns None for an empty file and a list/scalar for non-mapping
    # documents; fail here with a clear message rather than deep inside main().
    if not isinstance(cfg, dict):
        raise ValueError(
            "BMP config {} must contain a YAML mapping at the top level, got {}.".format(
                yaml_path, type(cfg).__name__
            )
        )
    return DotDict(cfg)
def process_one_image(
    args: Namespace,
    bmp_config: DotDict,
    img_path: Path,
    detector: object,
    detector_prime: object,
    pose_estimator: object,
    sam2_model: object,
) -> InstanceData:
    """
    Run the full BMP pipeline on a single image: detection, pose, SAM mask refinement, and visualization.

    Each of the ``bmp_config.num_bmp_iters`` iterations:
      1. detects instances (``detector`` on the first iteration, ``detector_prime`` afterwards),
      2. estimates poses for them on the original image,
      3. runs pose NMS over previously kept + new instances and refines the surviving
         new ones with SAM, then appends them to the running result,
      4. visualizes the result; the returned image is what the detector sees next iteration.

    Args:
        args (Namespace): Parsed CLI arguments (``output_root`` and ``create_gif`` are used).
        bmp_config (DotDict): Configuration parameters.
        img_path (Path): Path to the input image; a ``str`` is also accepted.
        detector: Primary MMDetection model.
        detector_prime: Secondary MMDetection model for iterations.
        pose_estimator: MMPose model for keypoint estimation.
        sam2_model: SAM model for mask refinement.

    Returns:
        InstanceData: Final merged detections and refined masks, or ``None`` when no
        iteration produced any instance.

    Raises:
        ValueError: If the image cannot be read.
    """
    img_path = Path(img_path)

    # Load image
    img = mmcv.imread(str(img_path), channel_order="bgr")
    if img is None:
        raise ValueError("Failed to read image from {}.".format(img_path))

    # Prepare a fresh per-image output directory (anything left from a previous run is deleted).
    output_dir = os.path.join(args.output_root, img_path.stem)
    shutil.rmtree(str(output_dir), ignore_errors=True)
    mmengine.mkdir_or_exist(str(output_dir))

    img_for_detection = img.copy()
    all_detections = None
    for iteration in range(bmp_config.num_bmp_iters):
        print_log("BMP Iteration {}/{} started".format(iteration + 1, bmp_config.num_bmp_iters), logger="current")

        # Step 1: Detection (D on the first iteration, D' afterwards)
        det_instances = run_MMDetector(
            detector if iteration == 0 else detector_prime,
            img_for_detection,
            det_cat_id=DEFAULT_DET_CAT_ID,
            bbox_thr=DEFAULT_BBOX_THR,
            nms_thr=DEFAULT_NMS_THR,
        )
        print_log("Detected {} instances".format(len(det_instances.bboxes)), logger="current")
        if len(det_instances.bboxes) == 0:
            print_log("No detections found, skipping.", logger="current")
            continue

        # Step 2: Pose estimation, always on the original image
        pose_instances = run_MMPose(
            pose_estimator,
            img.copy(),
            detections=det_instances,
            kpt_thr=DEFAULT_KPT_THR,
        )
        # Restrict to first 17 COCO keypoints, then append each keypoint's score as
        # the last channel. The confidence test below reads channel 2, which assumes
        # keypoints were (x, y) before this concat — TODO confirm for other pose models.
        pose_instances.keypoints = pose_instances.keypoints[:, :17, :]
        pose_instances.keypoint_scores = pose_instances.keypoint_scores[:, :17]
        pose_instances.keypoints = np.concatenate(
            [pose_instances.keypoints, pose_instances.keypoint_scores[:, :, None]], axis=-1
        )

        # Step 3: Pose-NMS over [previously kept instances, new instances], then SAM refinement
        if all_detections is None:
            num_old_detections = 0
            all_keypoints = pose_instances.keypoints
            all_bboxes = pose_instances.bboxes
        else:
            num_old_detections = len(all_detections.bboxes)
            all_keypoints = np.concatenate([all_detections.keypoints, pose_instances.keypoints], axis=0)
            all_bboxes = np.concatenate([all_detections.bboxes, pose_instances.bboxes], axis=0)

        confidence_thr = bmp_config.sam2.prompting.confidence_thr
        num_valid_kpts = np.sum(all_keypoints[:, :, 2] > confidence_thr, axis=1)
        keep_indices = pose_nms(
            DotDict({"confidence_thr": confidence_thr, "oks_thr": bmp_config.oks_nms_thr}),
            image_kpts=all_keypoints,
            image_bboxes=all_bboxes,
            num_valid_kpts=num_valid_kpts,
        )
        keep_indices = sorted(keep_indices)  # Sort by original index

        # Old instances occupy indices [0, num_old_detections); new ones follow them.
        keep_new_indices = [i - num_old_detections for i in keep_indices if i >= num_old_detections]
        keep_old_indices = [i for i in keep_indices if i < num_old_detections]
        if len(keep_new_indices) == 0:
            # Also skips the visualization call for this iteration.
            print_log("No new instances passed pose NMS, skipping SAM refinement.", logger="current")
            continue

        # Filter new detections; score each one by its mean keypoint confidence.
        new_dets = filter_instances(pose_instances, keep_new_indices)
        new_dets.scores = pose_instances.keypoint_scores[keep_new_indices].mean(axis=-1)
        old_dets = filter_instances(all_detections, keep_old_indices) if keep_old_indices else None
        print_log(
            "Pose NMS reduced instances to {:d} ({:d}+{:d}) instances".format(
                len(new_dets.bboxes) + num_old_detections, num_old_detections, len(new_dets.bboxes)
            ),
            logger="current",
        )

        new_detections = process_image_with_SAM(
            DotDict(bmp_config.sam2.prompting),
            img.copy(),
            sam2_model,
            new_dets,
            old_dets,
        )

        # Merge detections: append the SAM-refined instances in every iteration so the
        # refinement is never discarded. All previously kept instances are retained.
        if all_detections is None:
            all_detections = new_detections
        else:
            all_detections = concat_instances(all_detections, new_detections)

        # Step 4: Visualization; its output image is the detector input for the next iteration.
        img_for_detection = visualize_itteration(
            img.copy(),
            all_detections,
            iteration_idx=iteration,
            output_root=str(output_dir),
            img_name=img_path.stem,
        )
        print_log("Iteration {} completed".format(iteration + 1), logger="current")

    # Create GIF of iterations if requested
    if args.create_gif:
        image_file = os.path.join(output_dir, "{:s}.jpg".format(img_path.stem))
        create_GIF(
            img_path=str(image_file),
            output_root=str(output_dir),
            bmp_x=bmp_config.num_bmp_iters,
        )
    return all_detections
def main() -> None:
    """
    Command-line entry point: build every model named in the BMP config and run the pipeline on ``args.input``.
    """
    cli = parse_args()
    cfg = parse_yaml_config(cli.bmp_config)

    # The output root has to exist before per-image folders are created inside it.
    mmengine.mkdir_or_exist(str(cli.output_root))

    # Detector D (first BMP iteration).
    det_cfg = cfg.detector
    detector = init_detector(det_cfg.det_config, det_cfg.det_checkpoint, device=cli.device)
    detector.cfg = adapt_mmdet_pipeline(detector.cfg)

    # Detector D' is reused from D when it is not fully specified or is identical to D.
    prime_same_as_d = (
        det_cfg.det_config == det_cfg.det_prime_config
        and det_cfg.det_checkpoint == det_cfg.det_prime_checkpoint
    )
    prime_missing = det_cfg.det_prime_config is None or det_cfg.det_prime_checkpoint is None
    if prime_same_as_d or prime_missing:
        print_log("Using the same detector as D and D'", logger="current")
        detector_prime = detector
    else:
        detector_prime = init_detector(det_cfg.det_prime_config, det_cfg.det_prime_checkpoint, device=cli.device)
        detector_prime.cfg = adapt_mmdet_pipeline(detector_prime.cfg)
        print_log("Using a different detector for D'", logger="current")

    # Pose estimator, with heatmap output disabled in its test config.
    pose_cfg = cfg.pose_estimator
    pose_estimator = init_pose_estimator(
        pose_cfg.pose_config,
        pose_cfg.pose_checkpoint,
        device=cli.device,
        cfg_options={"model": {"test_cfg": {"output_heatmaps": False}}},
    )

    # NOTE(review): cli.device is not passed here — confirm which device SAM2 runs on.
    sam_cfg = cfg.sam2
    sam2 = prepare_sam2_model(model_cfg=sam_cfg.sam2_config, model_checkpoint=sam_cfg.sam2_checkpoint)

    # Run inference on one image; the returned detections are not used here.
    process_one_image(cli, cfg, cli.input, detector, detector_prime, pose_estimator, sam2)


if __name__ == "__main__":
    main()