Spaces:
Running on Zero
Running on Zero
Miroslav Purkrabek committed on
Commit ·
322535b
1
Parent(s): e5057aa
add BMPv2 code
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- BMPv2_README.md +311 -0
- SAM3D_INTEGRATION.md +302 -0
- app.py +188 -172
- bboxmaskpose/__init__.py +10 -0
- bboxmaskpose/api.py +515 -0
- {configs → bboxmaskpose/configs}/README.md +0 -0
- {configs → bboxmaskpose/configs}/bmp_D3.yaml +9 -2
- {configs → bboxmaskpose/configs}/bmp_J1.yaml +5 -0
- bboxmaskpose/configs/bmp_v2.yaml +34 -0
- {demo → bboxmaskpose}/demo_utils.py +30 -110
- {demo → bboxmaskpose}/posevis_lite.py +12 -12
- {sam2 → bboxmaskpose/sam2}/__init__.py +1 -1
- {sam2 → bboxmaskpose/sam2}/automatic_mask_generator.py +20 -50
- {sam2 → bboxmaskpose/sam2}/benchmark.py +3 -9
- {sam2 → bboxmaskpose/sam2}/build_sam.py +34 -9
- {sam2 → bboxmaskpose/sam2}/colorblind.py +8 -16
- bboxmaskpose/sam2/configs/sam-pose2seg/sam-pose2seg_hiera_b+.yaml +118 -0
- {sam2 → bboxmaskpose/sam2}/configs/sam2.1/sam2.1_hiera_b+.yaml +14 -14
- {sam2 → bboxmaskpose/sam2}/configs/sam2.1/sam2.1_hiera_l.yaml +14 -14
- {sam2 → bboxmaskpose/sam2}/configs/sam2.1/sam2.1_hiera_s.yaml +14 -14
- {sam2 → bboxmaskpose/sam2}/configs/sam2.1/sam2.1_hiera_t.yaml +14 -14
- bboxmaskpose/sam2/configs/sam2.1_training/sam2.1_hiera_b+_COCO+CIHP_finetune_sam-pose2seg.yaml +343 -0
- {sam2 → bboxmaskpose/sam2}/configs/sam2.1_training/sam2.1_hiera_b+_COCO_1024_prompt.yaml +15 -23
- {sam2 → bboxmaskpose/sam2}/configs/sam2.1_training/sam2.1_hiera_b+_COCO_finetune.yaml +15 -24
- {sam2 → bboxmaskpose/sam2}/configs/sam2.1_training/sam2.1_hiera_b+_COCO_finetune_prompt+decoder.yaml +15 -24
- {sam2 → bboxmaskpose/sam2}/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +15 -21
- {sam2 → bboxmaskpose/sam2}/configs/sam2/sam2_hiera_b+.yaml +14 -14
- {sam2 → bboxmaskpose/sam2}/configs/sam2/sam2_hiera_l.yaml +14 -14
- {sam2 → bboxmaskpose/sam2}/configs/sam2/sam2_hiera_s.yaml +14 -14
- {sam2 → bboxmaskpose/sam2}/configs/sam2/sam2_hiera_t.yaml +14 -14
- {sam2 → bboxmaskpose/sam2}/csrc/connected_components.cu +0 -0
- {sam2 → bboxmaskpose/sam2}/distinctipy.py +7 -14
- {sam2 → bboxmaskpose/sam2}/modeling/__init__.py +0 -0
- {sam2 → bboxmaskpose/sam2}/modeling/backbones/__init__.py +0 -0
- {sam2 → bboxmaskpose/sam2}/modeling/backbones/hieradet.py +10 -31
- {sam2 → bboxmaskpose/sam2}/modeling/backbones/image_encoder.py +1 -3
- {sam2 → bboxmaskpose/sam2}/modeling/backbones/utils.py +2 -6
- {sam2 → bboxmaskpose/sam2}/modeling/memory_attention.py +4 -7
- {sam2 → bboxmaskpose/sam2}/modeling/memory_encoder.py +3 -9
- {sam2 → bboxmaskpose/sam2}/modeling/position_encoding.py +8 -31
- {sam2 → bboxmaskpose/sam2}/modeling/sam/__init__.py +0 -0
- {sam2 → bboxmaskpose/sam2}/modeling/sam/mask_decoder.py +11 -32
- {sam2 → bboxmaskpose/sam2}/modeling/sam/pose_encoder.py +7 -19
- {sam2 → bboxmaskpose/sam2}/modeling/sam/prompt_encoder.py +21 -26
- {sam2 → bboxmaskpose/sam2}/modeling/sam/transformer.py +12 -30
- {sam2 → bboxmaskpose/sam2}/modeling/sam2_base.py +72 -246
- {sam2 → bboxmaskpose/sam2}/modeling/sam2_base_pose.py +45 -87
- {sam2 → bboxmaskpose/sam2}/modeling/sam2_utils.py +5 -13
- {sam2 → bboxmaskpose/sam2}/sam2_image_predictor.py +32 -81
- {sam2 → bboxmaskpose/sam2}/sam2_video_predictor.py +35 -298
BMPv2_README.md
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
</h1><div id="toc">
|
| 2 |
+
<ul align="center" style="list-style: none; padding: 0; margin: 0;">
|
| 3 |
+
<summary>
|
| 4 |
+
<h1 style="margin-bottom: 0.0em;">
|
| 5 |
+
BBoxMaskPose v2
|
| 6 |
+
</h1>
|
| 7 |
+
</summary>
|
| 8 |
+
</ul>
|
| 9 |
+
</div>
|
| 10 |
+
</h1><div id="toc">
|
| 11 |
+
<ul align="center" style="list-style: none; padding: 0; margin: 0;">
|
| 12 |
+
<summary>
|
| 13 |
+
<h2 style="margin-bottom: 0.2em;">
|
| 14 |
+
CVPR 2025 + ICCV 2025
|
| 15 |
+
</h2>
|
| 16 |
+
</summary>
|
| 17 |
+
</ul>
|
| 18 |
+
</div>
|
| 19 |
+
|
| 20 |
+
<div align="center">
|
| 21 |
+
<img src="data/assets/BMP_043+076+174.gif" alt="BBoxMaskPose v2 loop" height="500px">
|
| 22 |
+
|
| 23 |
+
[](https://mirapurkrabek.github.io/BBox-Mask-Pose/)
|
| 24 |
+
[](LICENSE)
|
| 25 |
+
[](https://youtu.be/U05yUP4b2LQ)
|
| 26 |
+
|
| 27 |
+
[](https://arxiv.org/abs/2412.02254)
|
| 28 |
+
[](https://arxiv.org/abs/2412.01562)
|
| 29 |
+
[](https://arxiv.org/abs/2601.08982)
|
| 30 |
+
[](https://arxiv.org/abs/2601.15200)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
<!-- Papers with code:
|
| 35 |
+
[](https://paperswithcode.com/sota/2d-human-pose-estimation-on-ochuman?p=detection-pose-estimation-and-segmentation-1)
|
| 36 |
+
[](https://paperswithcode.com/sota/human-instance-segmentation-on-ochuman?p=detection-pose-estimation-and-segmentation-1) -->
|
| 37 |
+
|
| 38 |
+
</div>
|
| 39 |
+
|
| 40 |
+
> [!CAUTION]
|
| 41 |
+
> This branch is a **work in progress**!
|
| 42 |
+
>
|
| 43 |
+
> Until merged with <code>main</code>, use on your own discretion. For stable version, please refer to <code>main</code> branch with BMPv1.
|
| 44 |
+
|
| 45 |
+
## 📢 News
|
| 46 |
+
|
| 47 |
+
- **Feb 2026**: Version 2.0 released, with improved (1) pose estimation, (2) SAM segmentation, and (3) wiring to 3D prediction.
|
| 48 |
+
- **Feb 2026**: SAM-pose2seg won a Best Paper Award at CVWW 2026 🎉
|
| 49 |
+
- **Jan 2026**: [BMPv2 paper](https://arxiv.org/abs/2601.15200) is available on arXiv
|
| 50 |
+
- **Aug 2025**: [HuggingFace Image Demo](https://huggingface.co/spaces/purkrmir/BBoxMaskPose-demo) is out! 🎮
|
| 51 |
+
- **Jul 2025**: Version 1.1 with easy-to-run image demo released
|
| 52 |
+
- **Jun 2025**: BMPv1 paper accepted to ICCV 2025! 🎉
|
| 53 |
+
- **Dec 2024**: BMPv1 code is available
|
| 54 |
+
- **Nov 2024**: The [project website](https://MiraPurkrabek.github.io/BBox-Mask-Pose) is online
|
| 55 |
+
|
| 56 |
+
## 📑 Table of Contents
|
| 57 |
+
|
| 58 |
+
- [Installation](#-installation)
|
| 59 |
+
- [Demo](#-demo)
|
| 60 |
+
- [API Examples](#api-examples)
|
| 61 |
+
- [Pre-trained Models](#-pre-trained-models)
|
| 62 |
+
- [Acknowledgments](#-acknowledgments)
|
| 63 |
+
- [Citation](#-citation)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
## 📋 Project Overview
|
| 67 |
+
|
| 68 |
+
Bounding boxes, masks, and poses capture complementary aspects of the human body. BBoxMaskPose links detection, segmentation, and pose estimation iteratively, where each prediction refines the others. PMPose combines probabilistic modeling with mask conditioning for robust pose estimation in crowds. Together, these components achieve state-of-the-art results on COCO and OCHuman, being the first method to exceed 50 AP on OCHuman.
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
### Repository Structure
|
| 72 |
+
|
| 73 |
+
The repository is organized into two main packages with stable public APIs:
|
| 74 |
+
|
| 75 |
+
```
|
| 76 |
+
BBoxMaskPose/
|
| 77 |
+
├── pmpose/ # PMPose package (pose estimation)
|
| 78 |
+
│ └── pmpose/
|
| 79 |
+
│ ├── api.py # PUBLIC API: PMPose class
|
| 80 |
+
│ ├── mm_utils.py # Internal utilities
|
| 81 |
+
│ └── posevis_lite.py # Visualization
|
| 82 |
+
├── mmpose/ # MMPose fork with our edits
|
| 83 |
+
├── bboxmaskpose/ # BBoxMaskPose package (full pipeline)
|
| 84 |
+
│ └── bboxmaskpose/
|
| 85 |
+
│ ├── api.py # PUBLIC API: BBoxMaskPose class
|
| 86 |
+
│ ├── sam2/ # SAM2 implementation
|
| 87 |
+
│ ├── configs/ # BMP configurations
|
| 88 |
+
│ └── *_utils.py # Internal utilities
|
| 89 |
+
├── demos/ # Public API demos
|
| 90 |
+
│ ├── PMPose_demo.py # PMPose usage example
|
| 91 |
+
│ ├── BMP_demo.py # BBoxMaskPose usage example
|
| 92 |
+
│ └── quickstart.ipynb # Interactive notebook
|
| 93 |
+
└── demo/ # Legacy demo (still functional)
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
Key contributions:
|
| 97 |
+
1. **MaskPose**: a pose estimation model conditioned by segmentation masks instead of bounding boxes, boosting performance in dense scenes without adding parameters
|
| 98 |
+
- Download pre-trained weights below
|
| 99 |
+
2. **BBox-MaskPose (BMP)**: method linking bounding boxes, segmentation masks, and poses to simultaneously address multi-body detection, segmentation and pose estimation
|
| 100 |
+
- Try the demo!
|
| 101 |
+
3. Fine-tuned RTMDet adapted for iterative detection (ignoring 'holes')
|
| 102 |
+
- Download pre-trained weights below
|
| 103 |
+
4. Support for multi-dataset training of ViTPose, previously implemented in the official ViTPose repository but absent in MMPose.
|
| 104 |
+
|
| 105 |
+
For more details, please visit our [project website](https://mirapurkrabek.github.io/BBox-Mask-Pose/).
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
## 🚀 Installation
|
| 110 |
+
|
| 111 |
+
### Docker Installation (Recommended)
|
| 112 |
+
|
| 113 |
+
The fastest way to get started with GPU support:
|
| 114 |
+
|
| 115 |
+
```bash
|
| 116 |
+
# Clone and build
|
| 117 |
+
git clone https://github.com/mirapurkrabek/BBoxMaskPose.git
|
| 118 |
+
cd BBoxMaskPose
|
| 119 |
+
docker-compose build
|
| 120 |
+
|
| 121 |
+
# Run the demo
|
| 122 |
+
docker-compose up
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
Requires: Docker Engine 19.03+, [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html), NVIDIA GPU with CUDA 12.1 support.
|
| 126 |
+
|
| 127 |
+
### Manual Installation
|
| 128 |
+
|
| 129 |
+
This project is built on top of [MMPose](https://github.com/open-mmlab/mmpose) and [SAM 2.1](https://github.com/facebookresearch/sam2).
|
| 130 |
+
Please refer to the [MMPose installation guide](https://mmpose.readthedocs.io/en/latest/installation.html) or [SAM installation guide](https://github.com/facebookresearch/sam2/blob/main/INSTALL.md) for detailed setup instructions.
|
| 131 |
+
|
| 132 |
+
Basic installation steps:
|
| 133 |
+
```bash
|
| 134 |
+
# Clone the repository
|
| 135 |
+
git clone https://github.com/mirapurkrabek/BBoxMaskPose.git BBoxMaskPose/
|
| 136 |
+
cd BBoxMaskPose
|
| 137 |
+
|
| 138 |
+
# Install your version of torch, torchvision, OpenCV and NumPy
|
| 139 |
+
pip install torch==2.1.2+cu121 torchvision==0.16.2+cu121 --extra-index-url https://download.pytorch.org/whl/cu121
|
| 140 |
+
pip install numpy==1.25.1 opencv-python==4.9.0.80
|
| 141 |
+
|
| 142 |
+
# Install MMLibrary
|
| 143 |
+
pip install -U openmim
|
| 144 |
+
mim install mmengine "mmcv==2.1.0" "mmdet==3.3.0" "mmpretrain==1.2.0"
|
| 145 |
+
|
| 146 |
+
# Install dependencies
|
| 147 |
+
pip install -r requirements.txt
|
| 148 |
+
pip install -e .
|
| 149 |
+
```
|
| 150 |
+
|
| 151 |
+
## 🎮 Demo
|
| 152 |
+
|
| 153 |
+
#### PMPose Demo (Pose Estimation Only)
|
| 154 |
+
```bash
|
| 155 |
+
python demos/PMPose_demo.py --image data/004806.jpg --device cuda
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
#### BBoxMaskPose Demo (Full Pipeline)
|
| 159 |
+
```bash
|
| 160 |
+
python demos/BMP_demo.py --image data/004806.jpg --device cuda
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
After running the demo, outputs are in `outputs/004806/`. The expected output should look like this:
|
| 164 |
+
<div align="center">
|
| 165 |
+
<a href="data/assets/004806_mask.jpg" target="_blank">
|
| 166 |
+
<img src="data/assets/004806_mask.jpg" alt="Detection results" width="200" />
|
| 167 |
+
</a>
|
| 168 |
+
    
|
| 169 |
+
<a href="data/assets/004806_pose.jpg" target="_blank">
|
| 170 |
+
<img src="data/assets/004806_pose.jpg" alt="Pose results" width="200" style="margin-right:10px;" />
|
| 171 |
+
</a>
|
| 172 |
+
</div>
|
| 173 |
+
|
| 174 |
+
#### BBoxMaskPose v2 Demo (Full Pipeline + 3D Mesh Recovery)
|
| 175 |
+
This demo extends BMP with [SAM-3D-Body](https://github.com/facebookresearch/sam-3d-body) for 3D human mesh recovery:
|
| 176 |
+
```bash
|
| 177 |
+
# Basic usage (auto-downloads checkpoint from HuggingFace)
|
| 178 |
+
python demos/BMPv2_demo.py --image data/004806.jpg --device cuda
|
| 179 |
+
|
| 180 |
+
# With local checkpoint
|
| 181 |
+
python demos/BMPv2_demo.py --image data/004806.jpg --device cuda \
|
| 182 |
+
--sam3d_checkpoint checkpoints/sam-3d-body-dinov3/model.ckpt \
|
| 183 |
+
--mhr_path checkpoints/sam-3d-body-dinov3/assets/mhr_model.pt
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
**SAM-3D-Body Installation (Optional):**
|
| 187 |
+
BMPv2 requires SAM-3D-Body for 3D mesh recovery. Install it separately:
|
| 188 |
+
```bash
|
| 189 |
+
# 1. Install dependencies
|
| 190 |
+
pip install -r requirements/sam3d.txt
|
| 191 |
+
|
| 192 |
+
# 2. Install detectron2
|
| 193 |
+
pip install 'git+https://github.com/facebookresearch/detectron2.git@a1ce2f9' --no-build-isolation --no-deps
|
| 194 |
+
|
| 195 |
+
# 3. Install MoGe (optional, for FOV estimation)
|
| 196 |
+
pip install git+https://github.com/microsoft/MoGe.git
|
| 197 |
+
|
| 198 |
+
# 4. Install adapted SAM-3D-Body repository
|
| 199 |
+
pip install git+https://github.com/MiraPurkrabek/sam-3d-body.git
|
| 200 |
+
|
| 201 |
+
# 5. Request access to checkpoints at https://huggingface.co/facebook/sam-3d-body-dinov3
|
| 202 |
+
```
|
| 203 |
+
|
| 204 |
+
For more details, see [SAM-3D-Body installation guide](https://github.com/facebookresearch/sam-3d-body/blob/main/INSTALL.md).
|
| 205 |
+
|
| 206 |
+
#### Jupyter Notebook
|
| 207 |
+
Interactive demo with both PMPose and BBoxMaskPose:
|
| 208 |
+
```bash
|
| 209 |
+
jupyter notebook demos/quickstart.ipynb
|
| 210 |
+
```
|
| 211 |
+
|
| 212 |
+
## API Examples
|
| 213 |
+
|
| 214 |
+
**PMPose API** - Pose estimation with bounding boxes:
|
| 215 |
+
```python
|
| 216 |
+
from pmpose import PMPose
|
| 217 |
+
|
| 218 |
+
# Initialize model
|
| 219 |
+
pose_model = PMPose(device="cuda", from_pretrained=True)
|
| 220 |
+
|
| 221 |
+
# Run inference
|
| 222 |
+
keypoints, presence, visibility, heatmaps = pose_model.predict(
|
| 223 |
+
image="demo/data/004806.jpg",
|
| 224 |
+
bboxes=[[100, 100, 300, 400]], # [x1, y1, x2, y2]
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
# Visualize
|
| 228 |
+
vis_img = pose_model.visualize(image="demo/data/004806.jpg", keypoints=keypoints)
|
| 229 |
+
```
|
| 230 |
+
|
| 231 |
+
**BBoxMaskPose API** - Full detection + pose + segmentation:
|
| 232 |
+
|
| 233 |
+
```python
|
| 234 |
+
from pmpose import PMPose
|
| 235 |
+
from bboxmaskpose import BBoxMaskPose
|
| 236 |
+
|
| 237 |
+
# Create pose model
|
| 238 |
+
pose_model = PMPose(device="cuda", from_pretrained=True)
|
| 239 |
+
|
| 240 |
+
# Inject into BMP
|
| 241 |
+
bmp_model = BBoxMaskPose(config="BMP_D3", device="cuda", pose_model=pose_model)
|
| 242 |
+
result = bmp_model.predict(image="demo/data/004806.jpg")
|
| 243 |
+
|
| 244 |
+
# Visualize
|
| 245 |
+
vis_img = bmp_model.visualize(image="demo/data/004806.jpg", result=result)
|
| 246 |
+
```
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
## 📦 Pre-trained Models
|
| 250 |
+
|
| 251 |
+
Pre-trained models are available on [VRG Hugging Face 🤗](https://huggingface.co/vrg-prague/BBoxMaskPose/).
|
| 252 |
+
To run the demo, you only need to download SAM weights with the [enclosed script](models/SAM/download_ckpts.sh).
|
| 253 |
+
Our detector and pose estimator will be downloaded during the runtime.
|
| 254 |
+
|
| 255 |
+
If you want to download our weights yourself, here are the links to our HuggingFace:
|
| 256 |
+
- ViTPose-b trained on COCO+MPII+AIC -- [download weights](https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/ViTPose-b-multi_mmpose20.pth)
|
| 257 |
+
- MaskPose-b -- [download weights](https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/MaskPose-b.pth)
|
| 258 |
+
- Fine-tuned RTMDet-L -- [download weights](https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/rtmdet-ins-l-mask.pth)
|
| 259 |
+
|
| 260 |
+
## 🙏 Acknowledgments
|
| 261 |
+
|
| 262 |
+
The code combines [MMDetection](https://github.com/open-mmlab/mmdetection), [MMPose 2.0](https://github.com/open-mmlab/mmpose), [ViTPose](https://github.com/ViTAE-Transformer/ViTPose), [SAM 2.1](https://github.com/facebookresearch/sam2) and [SAM-3D-Body](https://github.com/facebookresearch/sam-3d-body).
|
| 263 |
+
|
| 264 |
+
Our visualizations integrate [Distinctipy](https://github.com/alan-turing-institute/distinctipy) for automatic color selection.
|
| 265 |
+
|
| 266 |
+
This repository combines our work on BBoxMaskPose project with our previous work on [probabilistic 2D human pose estimation modelling](https://mirapurkrabek.github.io/ProbPose/).
|
| 267 |
+
|
| 268 |
+
## 📝 Citation
|
| 269 |
+
|
| 270 |
+
The code was implemented by [Miroslav Purkrábek](https://mirapurkrabek.github.io/) and Constantin Kolomiiets.
|
| 271 |
+
If you use this work, kindly cite it using the references provided below.
|
| 272 |
+
|
| 273 |
+
For questions, please use the Issues or Discussions.
|
| 274 |
+
|
| 275 |
+
```
|
| 276 |
+
@InProceedings{Purkrabek2025BMPv1,
|
| 277 |
+
author = {Purkrabek, Miroslav and Matas, Jiri},
|
| 278 |
+
title = {Detection, Pose Estimation and Segmentation for Multiple Bodies: Closing the Virtuous Circle},
|
| 279 |
+
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
|
| 280 |
+
month = {October},
|
| 281 |
+
year = {2025},
|
| 282 |
+
pages = {9004-9013}
|
| 283 |
+
}
|
| 284 |
+
```
|
| 285 |
+
|
| 286 |
+
```
|
| 287 |
+
@InProceedings{Purkrabek2026BMPv2,
|
| 288 |
+
author = {Purkrabek, Miroslav and Kolomiiets, Constantin and Matas, Jiri},
|
| 289 |
+
title = {BBoxMaskPose v2: Expanding Mutual Conditioning to 3D},
|
| 290 |
+
booktitle = {arXiv preprint arXiv:2601.15200},
|
| 291 |
+
year = {2026}
|
| 292 |
+
}
|
| 293 |
+
```
|
| 294 |
+
|
| 295 |
+
```
|
| 296 |
+
@article{yang2025sam3dbody,
|
| 297 |
+
title={SAM 3D Body: Robust Full-Body Human Mesh Recovery},
|
| 298 |
+
author={Yang, Xitong and Kukreja, Devansh and Pinkus, Don and Sagar, Anushka and Fan, Taosha and Park, Jinhyung and Shin, Soyong and Cao, Jinkun and Liu, Jiawei and Ugrinovic, Nicolas and Feiszli, Matt and Malik, Jitendra and Dollar, Piotr and Kitani, Kris},
|
| 299 |
+
journal={arXiv preprint; identifier to be added},
|
| 300 |
+
year={2025}
|
| 301 |
+
}
|
| 302 |
+
```
|
| 303 |
+
|
| 304 |
+
```
|
| 305 |
+
@InProceedings{Kolomiiets2026CVWW,
|
| 306 |
+
author = {Kolomiiets, Constantin and Purkrabek, Miroslav and Matas, Jiri},
|
| 307 |
+
title = {SAM-pose2seg: Pose-Guided Human Instance Segmentation in Crowds},
|
| 308 |
+
booktitle = {Computer Vision Winter Workshop (CVWW)},
|
| 309 |
+
year = {2026}
|
| 310 |
+
}
|
| 311 |
+
```
|
SAM3D_INTEGRATION.md
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SAM-3D-Body Integration Guide
|
| 2 |
+
|
| 3 |
+
This guide explains how to integrate and use SAM-3D-Body for 3D human mesh recovery within the BBoxMaskPose pipeline.
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
BBoxMaskPose v2 extends the original BMP pipeline with [SAM-3D-Body](https://github.com/facebookresearch/sam-3d-body) from Meta AI, enabling full 3D human mesh recovery from single images. The integration leverages BMP's high-quality 2D pose estimates and segmentation masks as prompts to SAM-3D-Body, resulting in accurate 3D reconstructions even in crowded scenes.
|
| 8 |
+
|
| 9 |
+
**Pipeline Flow:**
|
| 10 |
+
```
|
| 11 |
+
Input Image
|
| 12 |
+
↓
|
| 13 |
+
BBoxMaskPose (Detection + 2D Pose + Segmentation)
|
| 14 |
+
↓
|
| 15 |
+
2D Bboxes + Masks + Poses
|
| 16 |
+
↓
|
| 17 |
+
SAM-3D-Body (3D Mesh Recovery)
|
| 18 |
+
↓
|
| 19 |
+
3D Human Meshes (vertices, joints, faces)
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
## Installation
|
| 23 |
+
|
| 24 |
+
### Prerequisites
|
| 25 |
+
|
| 26 |
+
- BBoxMaskPose must be already installed and working
|
| 27 |
+
- CUDA-capable GPU recommended (CPU inference is very slow)
|
| 28 |
+
- Python 3.8+ (Python 3.11 recommended for SAM-3D-Body)
|
| 29 |
+
|
| 30 |
+
### Step 1: Install SAM-3D-Body Dependencies
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
# Navigate to BBoxMaskPose root directory
|
| 34 |
+
cd /path/to/BBoxMaskPose
|
| 35 |
+
|
| 36 |
+
# Install SAM-3D-Body dependencies
|
| 37 |
+
pip install -r requirements/sam3d.txt
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
### Step 2: Install Detectron2
|
| 41 |
+
|
| 42 |
+
SAM-3D-Body requires a specific version of Detectron2:
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
pip install 'git+https://github.com/facebookresearch/detectron2.git@a1ce2f9' \
|
| 46 |
+
--no-build-isolation --no-deps
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
### Step 3: Install MoGe (Optional but Recommended)
|
| 50 |
+
|
| 51 |
+
MoGe provides FOV (field-of-view) estimation for better camera calibration:
|
| 52 |
+
|
| 53 |
+
```bash
|
| 54 |
+
pip install git+https://github.com/microsoft/MoGe.git
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
### Step 4: Install SAM-3D-Body
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
# Install adapted SAM-3D-Body repository
|
| 61 |
+
pip install git+https://github.com/MiraPurkrabek/sam-3d-body.git
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
### Step 5: Get Model Checkpoints
|
| 65 |
+
|
| 66 |
+
SAM-3D-Body checkpoints are hosted on HuggingFace. You need to:
|
| 67 |
+
|
| 68 |
+
1. **Request access** at [facebook/sam-3d-body-dinov3](https://huggingface.co/facebook/sam-3d-body-dinov3)
|
| 69 |
+
2. **Wait for approval** (usually within 24 hours)
|
| 70 |
+
3. **Authenticate** with HuggingFace:
|
| 71 |
+
```bash
|
| 72 |
+
pip install huggingface_hub
|
| 73 |
+
huggingface-cli login
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
The BMPv2 demo will auto-download the checkpoint on first use, or you can download manually to the default location for auto-detection:
|
| 77 |
+
|
| 78 |
+
```bash
|
| 79 |
+
# Download checkpoint manually to default location (will be auto-detected)
|
| 80 |
+
mkdir -p checkpoints
|
| 81 |
+
huggingface-cli download facebook/sam-3d-body-dinov3 \
|
| 82 |
+
--local-dir checkpoints/sam-3d-body-dinov3
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
## Usage
|
| 86 |
+
|
| 87 |
+
### Basic Usage
|
| 88 |
+
|
| 89 |
+
Run the BMPv2 demo with automatic checkpoint handling:
|
| 90 |
+
|
| 91 |
+
```bash
|
| 92 |
+
python demos/BMPv2_demo.py --image data/004806.jpg --device cuda
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
**The demo will:**
|
| 96 |
+
1. Auto-detect checkpoint in `checkpoints/sam-3d-body-dinov3/` OR download from HuggingFace (~3.5 GB)
|
| 97 |
+
2. Run BMP pipeline to get 2D detections, poses, and masks
|
| 98 |
+
3. Run SAM-3D-Body to recover 3D meshes
|
| 99 |
+
4. Save visualizations to `demos/outputs/bboxmaskpose_v2/`
|
| 100 |
+
|
| 101 |
+
### Advanced Usage
|
| 102 |
+
|
| 103 |
+
#### Use Local Checkpoint (Auto-Detection)
|
| 104 |
+
|
| 105 |
+
Download checkpoint to the default location for automatic detection:
|
| 106 |
+
|
| 107 |
+
```bash
|
| 108 |
+
# The demo automatically detects checkpoints in this location
|
| 109 |
+
huggingface-cli download facebook/sam-3d-body-dinov3 \
|
| 110 |
+
--local-dir checkpoints/sam-3d-body-dinov3
|
| 111 |
+
|
| 112 |
+
# Then just run the demo - no checkpoint arguments needed!
|
| 113 |
+
python demos/BMPv2_demo.py --image data/004806.jpg --device cuda
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
#### Use Custom Checkpoint Path
|
| 117 |
+
|
| 118 |
+
If your checkpoint is in a different location:
|
| 119 |
+
|
| 120 |
+
```bash
|
| 121 |
+
python demos/BMPv2_demo.py \
|
| 122 |
+
--image data/004806.jpg \
|
| 123 |
+
--device cuda \
|
| 124 |
+
--sam3d_checkpoint /path/to/model.ckpt \
|
| 125 |
+
--mhr_path /path/to/mhr_model.pt
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
#### Speed vs Quality Trade-offs
|
| 129 |
+
|
| 130 |
+
```bash
|
| 131 |
+
# Fastest: body-only inference without mask conditioning
|
| 132 |
+
python demos/BMPv2_demo.py --image data/004806.jpg \
|
| 133 |
+
--inference_type body --no_mask_conditioning
|
| 134 |
+
|
| 135 |
+
# Balanced: body-only with mask conditioning
|
| 136 |
+
python demos/BMPv2_demo.py --image data/004806.jpg \
|
| 137 |
+
--inference_type body
|
| 138 |
+
|
| 139 |
+
# Best quality: full inference with mask conditioning (default)
|
| 140 |
+
python demos/BMPv2_demo.py --image data/004806.jpg
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
#### Disable Mask Conditioning
|
| 144 |
+
|
| 145 |
+
Faster but less accurate (doesn't use segmentation masks as prompts):
|
| 146 |
+
|
| 147 |
+
```bash
|
| 148 |
+
python demos/BMPv2_demo.py \
|
| 149 |
+
--image data/004806.jpg \
|
| 150 |
+
--no_mask_conditioning
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
#### Skip 3D Recovery
|
| 154 |
+
|
| 155 |
+
Run only BMP pipeline (useful for testing BMP without SAM-3D-Body):
|
| 156 |
+
|
| 157 |
+
```bash
|
| 158 |
+
python demos/BMPv2_demo.py \
|
| 159 |
+
--image data/004806.jpg \
|
| 160 |
+
--skip_3d
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
### Output Files
|
| 164 |
+
|
| 165 |
+
The demo saves the following visualizations:
|
| 166 |
+
|
| 167 |
+
- `{image_name}_bmp_pose.jpg` - 2D pose estimation results
|
| 168 |
+
- `{image_name}_bmp_mask.jpg` - Segmentation mask results
|
| 169 |
+
- `{image_name}_3d_mesh.jpg` - 3D mesh overlay on image
|
| 170 |
+
- `{image_name}_combined.jpg` - Side-by-side comparison of all results
|
| 171 |
+
|
| 172 |
+
## Programmatic API
|
| 173 |
+
|
| 174 |
+
You can also use SAM-3D-Body programmatically:
|
| 175 |
+
|
| 176 |
+
```python
|
| 177 |
+
from bboxmaskpose import BBoxMaskPose
|
| 178 |
+
from bboxmaskpose.sam3d_utils import SAM3DBodyWrapper, visualize_3d_meshes
|
| 179 |
+
|
| 180 |
+
# Step 1: Run BMP pipeline
|
| 181 |
+
bmp = BBoxMaskPose(config="bmp_D3", device="cuda")
|
| 182 |
+
result = bmp.predict(image="path/to/image.jpg")
|
| 183 |
+
|
| 184 |
+
# Step 2: Initialize SAM-3D-Body
|
| 185 |
+
sam3d = SAM3DBodyWrapper(device="cuda")
|
| 186 |
+
|
| 187 |
+
# Step 3: Predict 3D meshes from BMP outputs
|
| 188 |
+
outputs_3d = sam3d.predict(
|
| 189 |
+
image="path/to/image.jpg",
|
| 190 |
+
bboxes=result['bboxes'],
|
| 191 |
+
masks=result['masks'],
|
| 192 |
+
use_mask=True,
|
| 193 |
+
inference_type="full", # Options: "full", "body", "hand"
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# Step 4: Visualize results
|
| 197 |
+
import cv2
|
| 198 |
+
img = cv2.imread("path/to/image.jpg")
|
| 199 |
+
vis = visualize_3d_meshes(img, outputs_3d, sam3d.faces)
|
| 200 |
+
cv2.imwrite("output_3d.jpg", vis)
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
### Access 3D Mesh Data
|
| 204 |
+
|
| 205 |
+
Each element in `outputs_3d` is a dictionary containing:
|
| 206 |
+
|
| 207 |
+
```python
|
| 208 |
+
output_3d[0].keys()
|
| 209 |
+
# dict_keys(['vertices', 'joints', 'bbox', 'mask', ...])
|
| 210 |
+
|
| 211 |
+
# 3D mesh vertices in camera coordinates (V, 3)
|
| 212 |
+
vertices = outputs_3d[0]['vertices']
|
| 213 |
+
|
| 214 |
+
# 3D joint locations (J, 3)
|
| 215 |
+
joints_3d = outputs_3d[0]['joints']
|
| 216 |
+
|
| 217 |
+
# Mesh faces (shared across all people)
|
| 218 |
+
faces = sam3d.faces # (F, 3)
|
| 219 |
+
```
|
| 220 |
+
|
| 221 |
+
## Integration Architecture
|
| 222 |
+
|
| 223 |
+
### Wrapper Design
|
| 224 |
+
|
| 225 |
+
The integration follows BBoxMaskPose's modular design pattern:
|
| 226 |
+
|
| 227 |
+
```
|
| 228 |
+
bboxmaskpose/
|
| 229 |
+
├── sam3d_utils.py # SAM-3D-Body wrapper (new)
|
| 230 |
+
│ ├── SAM3DBodyWrapper # Main wrapper class
|
| 231 |
+
│ ├── visualize_3d_meshes # Visualization helper
|
| 232 |
+
│ └── check_sam3d_available
|
| 233 |
+
│
|
| 234 |
+
demos/
|
| 235 |
+
├── BMP_demo.py # Original BMP demo
|
| 236 |
+
└── BMPv2_demo.py # New demo with 3D (new)
|
| 237 |
+
```
|
| 238 |
+
|
| 239 |
+
### Why a Wrapper?
|
| 240 |
+
|
| 241 |
+
The `SAM3DBodyWrapper` class:
|
| 242 |
+
- **Simplifies** SAM-3D-Body's complex initialization
|
| 243 |
+
- **Adapts** BMP outputs (bboxes, masks) to SAM-3D-Body inputs
|
| 244 |
+
- **Handles** optional dependencies gracefully (no hard requirement)
|
| 245 |
+
- **Follows** BMP's design patterns (similar to PMPose wrapper)
|
| 246 |
+
|
| 247 |
+
### Key Design Decisions
|
| 248 |
+
|
| 249 |
+
1. **Optional Dependency**: SAM-3D-Body is not required for core BMP functionality
|
| 250 |
+
2. **No Code Duplication**: Reuses SAM-3D-Body's existing code via wrapper
|
| 251 |
+
3. **Mask Conditioning**: Leverages BMP's high-quality masks as prompts
|
| 252 |
+
4. **No Internal Detector**: Disables SAM-3D-Body's detector (BMP already detects)
|
| 253 |
+
|
| 254 |
+
## Troubleshooting
|
| 255 |
+
|
| 256 |
+
### Import Error: `sam_3d_body` not found
|
| 257 |
+
|
| 258 |
+
**Solution**: Install SAM-3D-Body following Step 4 above.
|
| 259 |
+
|
| 260 |
+
### HuggingFace Authentication Error
|
| 261 |
+
|
| 262 |
+
**Solution**:
|
| 263 |
+
1. Request access at https://huggingface.co/facebook/sam-3d-body-dinov3
|
| 264 |
+
2. Login: `huggingface-cli login`
|
| 265 |
+
|
| 266 |
+
### MoGe Import Error (FOV Estimator)
|
| 267 |
+
|
| 268 |
+
**Solution**: Either:
|
| 269 |
+
- Install MoGe: `pip install git+https://github.com/microsoft/MoGe.git`
|
| 270 |
+
- Or disable FOV estimation (uses default FOV instead)
|
| 271 |
+
|
| 272 |
+
### Detectron2 Build Errors
|
| 273 |
+
|
| 274 |
+
**Solution**: Make sure you have:
|
| 275 |
+
- CUDA toolkit installed and matching PyTorch CUDA version
|
| 276 |
+
- GCC/G++ compiler available
|
| 277 |
+
- Use the exact commit hash: `@a1ce2f9`
|
| 278 |
+
|
| 279 |
+
## References
|
| 280 |
+
|
| 281 |
+
- **SAM-3D-Body**: [GitHub](https://github.com/facebookresearch/sam-3d-body) | [Paper](https://ai.meta.com/research/publications/sam-3d-body-robust-full-body-human-mesh-recovery/)
|
| 282 |
+
- **BBoxMaskPose**: [GitHub](https://github.com/MiraPurkrabek/BBoxMaskPose) | [Paper](https://arxiv.org/abs/2601.15200)
|
| 283 |
+
|
| 284 |
+
## Citation
|
| 285 |
+
|
| 286 |
+
If you use this integration, please cite both works:
|
| 287 |
+
|
| 288 |
+
```bibtex
|
| 289 |
+
@article{yang2025sam3dbody,
|
| 290 |
+
title={SAM 3D Body: Robust Full-Body Human Mesh Recovery},
|
| 291 |
+
author={Yang, Xitong and Kukreja, Devansh and Pinkus, Don and Sagar, Anushka and Fan, Taosha and Park, Jinhyung and Shin, Soyong and Cao, Jinkun and Liu, Jiawei and Ugrinovic, Nicolas and Feiszli, Matt and Malik, Jitendra and Dollar, Piotr and Kitani, Kris},
|
| 292 |
+
journal={arXiv preprint; identifier to be added},
|
| 293 |
+
year={2025}
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
@InProceedings{Purkrabek2026BMPv2,
|
| 297 |
+
author = {Purkrabek, Miroslav and Kolomiiets, Constantin and Matas, Jiri},
|
| 298 |
+
title = {BBoxMaskPose v2: Expanding Mutual Conditioning to 3D},
|
| 299 |
+
booktitle = {arXiv preprint arXiv:2601.15200},
|
| 300 |
+
year = {2026}
|
| 301 |
+
}
|
| 302 |
+
```
|
app.py
CHANGED
|
@@ -1,188 +1,204 @@
|
|
| 1 |
-
|
| 2 |
-
import spaces
|
| 3 |
-
|
| 4 |
-
from pathlib import Path
|
| 5 |
|
|
|
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
-
import
|
| 8 |
-
from demo.demo_utils import DotDict, concat_instances, filter_instances, pose_nms, visualize_demo
|
| 9 |
-
from demo.mm_utils import run_MMDetector, run_MMPose
|
| 10 |
-
from mmdet.apis import init_detector
|
| 11 |
-
from demo.sam2_utils import prepare_model as prepare_sam2_model
|
| 12 |
-
from demo.sam2_utils import process_image_with_SAM
|
| 13 |
-
|
| 14 |
-
from mmpose.apis import init_model as init_pose_estimator
|
| 15 |
-
from mmpose.utils import adapt_mmdet_pipeline
|
| 16 |
-
|
| 17 |
-
# Default thresholds
|
| 18 |
-
DEFAULT_CAT_ID: int = 0
|
| 19 |
-
|
| 20 |
-
DEFAULT_BBOX_THR: float = 0.3
|
| 21 |
-
DEFAULT_NMS_THR: float = 0.3
|
| 22 |
-
DEFAULT_KPT_THR: float = 0.3
|
| 23 |
-
|
| 24 |
-
# Global models variable
|
| 25 |
-
det_model = None
|
| 26 |
-
pose_model = None
|
| 27 |
-
sam2_model = None
|
| 28 |
-
|
| 29 |
-
def _parse_yaml_config(yaml_path: Path) -> DotDict:
|
| 30 |
-
"""
|
| 31 |
-
Load BMP configuration from a YAML file.
|
| 32 |
-
|
| 33 |
-
Args:
|
| 34 |
-
yaml_path (Path): Path to YAML config.
|
| 35 |
-
Returns:
|
| 36 |
-
DotDict: Nested config dictionary.
|
| 37 |
-
"""
|
| 38 |
-
with open(yaml_path, "r") as f:
|
| 39 |
-
cfg = yaml.safe_load(f)
|
| 40 |
-
return DotDict(cfg)
|
| 41 |
-
|
| 42 |
-
def load_models(bmp_config):
|
| 43 |
-
device = 'cuda:0'
|
| 44 |
-
|
| 45 |
-
global det_model, pose_model, sam2_model
|
| 46 |
-
|
| 47 |
-
# build detectors
|
| 48 |
-
det_model = init_detector(bmp_config.detector.det_config, bmp_config.detector.det_checkpoint, device='cpu') # Detect with CPU because of installation issues on HF
|
| 49 |
-
det_model.cfg = adapt_mmdet_pipeline(det_model.cfg)
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
# build pose estimator
|
| 53 |
-
pose_model = init_pose_estimator(
|
| 54 |
-
bmp_config.pose_estimator.pose_config,
|
| 55 |
-
bmp_config.pose_estimator.pose_checkpoint,
|
| 56 |
-
device=device,
|
| 57 |
-
cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=False))),
|
| 58 |
-
)
|
| 59 |
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
)
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
@spaces.GPU(duration=60)
|
| 68 |
def process_image_with_BMP(
|
| 69 |
img: np.ndarray
|
| 70 |
) -> tuple[np.ndarray, np.ndarray]:
|
| 71 |
"""
|
| 72 |
-
Run
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
img_path (Path): Path to the input image.
|
| 78 |
-
detector: Primary MMDetection model.
|
| 79 |
-
detector_prime: Secondary MMDetection model for iterations.
|
| 80 |
-
pose_estimator: MMPose model for keypoint estimation.
|
| 81 |
-
sam2_model: SAM model for mask refinement.
|
| 82 |
-
Returns:
|
| 83 |
-
InstanceData: Final merged detections and refined masks.
|
| 84 |
"""
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
#
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
det_instances = run_MMDetector(
|
| 98 |
-
det_model,
|
| 99 |
-
img_for_detection,
|
| 100 |
-
det_cat_id=DEFAULT_CAT_ID,
|
| 101 |
-
bbox_thr=DEFAULT_BBOX_THR,
|
| 102 |
-
nms_thr=DEFAULT_NMS_THR,
|
| 103 |
-
)
|
| 104 |
-
if len(det_instances.bboxes) == 0:
|
| 105 |
-
continue
|
| 106 |
-
|
| 107 |
-
# Step 2: Pose estimation
|
| 108 |
-
pose_instances = run_MMPose(
|
| 109 |
-
pose_model,
|
| 110 |
-
img.copy(),
|
| 111 |
-
detections=det_instances,
|
| 112 |
-
kpt_thr=DEFAULT_KPT_THR,
|
| 113 |
-
)
|
| 114 |
-
|
| 115 |
-
# Restrict to first 17 COCO keypoints
|
| 116 |
-
pose_instances.keypoints = pose_instances.keypoints[:, :17, :]
|
| 117 |
-
pose_instances.keypoint_scores = pose_instances.keypoint_scores[:, :17]
|
| 118 |
-
pose_instances.keypoints = np.concatenate(
|
| 119 |
-
[pose_instances.keypoints, pose_instances.keypoint_scores[:, :, None]], axis=-1
|
| 120 |
-
)
|
| 121 |
-
|
| 122 |
-
# Step 3: Pose-NMS and SAM refinement
|
| 123 |
-
all_keypoints = (
|
| 124 |
-
pose_instances.keypoints
|
| 125 |
-
if all_detections is None
|
| 126 |
-
else np.concatenate([all_detections.keypoints, pose_instances.keypoints], axis=0)
|
| 127 |
-
)
|
| 128 |
-
all_bboxes = (
|
| 129 |
-
pose_instances.bboxes
|
| 130 |
-
if all_detections is None
|
| 131 |
-
else np.concatenate([all_detections.bboxes, pose_instances.bboxes], axis=0)
|
| 132 |
-
)
|
| 133 |
-
num_valid_kpts = np.sum(all_keypoints[:, :, 2] > bmp_config.sam2.prompting.confidence_thr, axis=1)
|
| 134 |
-
keep_indices = pose_nms(
|
| 135 |
-
DotDict({"confidence_thr": bmp_config.sam2.prompting.confidence_thr, "oks_thr": bmp_config.oks_nms_thr}),
|
| 136 |
-
image_kpts=all_keypoints,
|
| 137 |
-
image_bboxes=all_bboxes,
|
| 138 |
-
num_valid_kpts=num_valid_kpts,
|
| 139 |
-
)
|
| 140 |
-
keep_indices = sorted(keep_indices) # Sort by original index
|
| 141 |
-
num_old_detections = 0 if all_detections is None else len(all_detections.bboxes)
|
| 142 |
-
keep_new_indices = [i - num_old_detections for i in keep_indices if i >= num_old_detections]
|
| 143 |
-
keep_old_indices = [i for i in keep_indices if i < num_old_detections]
|
| 144 |
-
if len(keep_new_indices) == 0:
|
| 145 |
-
continue
|
| 146 |
-
# filter new detections and compute scores
|
| 147 |
-
new_dets = filter_instances(pose_instances, keep_new_indices)
|
| 148 |
-
new_dets.scores = pose_instances.keypoint_scores[keep_new_indices].mean(axis=-1)
|
| 149 |
-
old_dets = None
|
| 150 |
-
if len(keep_old_indices) > 0:
|
| 151 |
-
old_dets = filter_instances(all_detections, keep_old_indices)
|
| 152 |
-
|
| 153 |
-
new_detections = process_image_with_SAM(
|
| 154 |
-
DotDict(bmp_config.sam2.prompting),
|
| 155 |
-
img.copy(),
|
| 156 |
-
sam2_model,
|
| 157 |
-
new_dets,
|
| 158 |
-
old_dets if old_dets is not None else None,
|
| 159 |
-
)
|
| 160 |
-
|
| 161 |
-
# Merge detections
|
| 162 |
-
if all_detections is None:
|
| 163 |
-
all_detections = new_detections
|
| 164 |
-
else:
|
| 165 |
-
all_detections = concat_instances(all_detections, new_dets)
|
| 166 |
-
|
| 167 |
-
# Step 4: Visualization
|
| 168 |
-
img_for_detection, rtmdet_r, _ = visualize_demo(
|
| 169 |
-
img.copy(),
|
| 170 |
-
all_detections,
|
| 171 |
-
)
|
| 172 |
-
|
| 173 |
-
if iteration == 0:
|
| 174 |
-
rtmdet_result = rtmdet_r
|
| 175 |
-
|
| 176 |
-
_, _, bmp_result = visualize_demo(
|
| 177 |
-
img.copy(),
|
| 178 |
-
all_detections,
|
| 179 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
bmp_result = bmp_result[..., ::-1]
|
| 184 |
|
| 185 |
-
|
|
|
|
| 186 |
|
| 187 |
|
| 188 |
with gr.Blocks() as app:
|
|
@@ -281,4 +297,4 @@ with gr.Blocks() as app:
|
|
| 281 |
)
|
| 282 |
|
| 283 |
# Launch the demo
|
| 284 |
-
app.launch()
|
|
|
|
| 1 |
+
from typing import Any
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
+
import cv2
|
| 4 |
+
import gradio as gr
|
| 5 |
import numpy as np
|
| 6 |
+
import spaces
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
+
from bboxmaskpose import BBoxMaskPose
|
| 9 |
+
|
| 10 |
+
# Global BMP model singleton
|
| 11 |
+
bmp_model = None
|
| 12 |
+
bmp_model_config = "bmp_v2"
|
| 13 |
+
bmp_model_device = "cuda:0"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _to_numpy(value: Any):
|
| 17 |
+
"""Convert model outputs to numpy arrays when needed."""
|
| 18 |
+
if value is None:
|
| 19 |
+
return None
|
| 20 |
+
if isinstance(value, np.ndarray):
|
| 21 |
+
return value
|
| 22 |
+
if hasattr(value, "detach"):
|
| 23 |
+
return value.detach().cpu().numpy()
|
| 24 |
+
if hasattr(value, "cpu") and hasattr(value, "numpy"):
|
| 25 |
+
return value.cpu().numpy()
|
| 26 |
+
return np.asarray(value)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _empty_result(height: int, width: int) -> dict[str, np.ndarray]:
|
| 30 |
+
"""Create an empty result dictionary compatible with BBoxMaskPose.visualize()."""
|
| 31 |
+
return {
|
| 32 |
+
"bboxes": np.zeros((0, 4), dtype=np.float32),
|
| 33 |
+
"masks": np.zeros((0, height, width), dtype=np.uint8),
|
| 34 |
+
"keypoints": np.zeros((0, 17, 3), dtype=np.float32),
|
| 35 |
+
"presence": np.zeros((0, 17), dtype=np.float32),
|
| 36 |
+
"visibility": np.zeros((0, 17), dtype=np.float32),
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _normalize_result(result: dict, height: int, width: int) -> dict[str, np.ndarray]:
|
| 41 |
+
"""Normalize prediction dictionary into a robust shape for visualization."""
|
| 42 |
+
bboxes = _to_numpy(result.get("bboxes"))
|
| 43 |
+
keypoints = _to_numpy(result.get("keypoints"))
|
| 44 |
+
masks = _to_numpy(result.get("masks"))
|
| 45 |
+
presence = _to_numpy(result.get("presence"))
|
| 46 |
+
visibility = _to_numpy(result.get("visibility"))
|
| 47 |
+
|
| 48 |
+
if bboxes is None:
|
| 49 |
+
bboxes = np.zeros((0, 4), dtype=np.float32)
|
| 50 |
+
bboxes = np.asarray(bboxes, dtype=np.float32).reshape(-1, 4)
|
| 51 |
+
num_instances = bboxes.shape[0]
|
| 52 |
+
|
| 53 |
+
if keypoints is None:
|
| 54 |
+
keypoints = np.zeros((num_instances, 17, 3), dtype=np.float32)
|
| 55 |
+
keypoints = np.asarray(keypoints, dtype=np.float32)
|
| 56 |
+
if keypoints.ndim == 2:
|
| 57 |
+
keypoints = keypoints[None, ...]
|
| 58 |
+
if keypoints.shape[0] != num_instances:
|
| 59 |
+
keypoints = np.zeros((num_instances, 17, 3), dtype=np.float32)
|
| 60 |
+
if keypoints.shape[1] > 17:
|
| 61 |
+
keypoints = keypoints[:, :17, :]
|
| 62 |
+
if keypoints.shape[1] < 17:
|
| 63 |
+
padded = np.zeros((num_instances, 17, 3), dtype=np.float32)
|
| 64 |
+
padded[:, : keypoints.shape[1], : min(keypoints.shape[2], 3)] = keypoints[:, :, :3]
|
| 65 |
+
keypoints = padded
|
| 66 |
+
if keypoints.shape[2] == 2:
|
| 67 |
+
scores = np.ones((num_instances, 17, 1), dtype=np.float32)
|
| 68 |
+
keypoints = np.concatenate([keypoints, scores], axis=-1)
|
| 69 |
+
elif keypoints.shape[2] > 3:
|
| 70 |
+
keypoints = keypoints[:, :, :3]
|
| 71 |
+
|
| 72 |
+
if masks is None:
|
| 73 |
+
masks = np.zeros((num_instances, height, width), dtype=np.uint8)
|
| 74 |
+
masks = np.asarray(masks)
|
| 75 |
+
if masks.ndim == 2:
|
| 76 |
+
masks = masks[None, ...]
|
| 77 |
+
if masks.ndim == 4 and masks.shape[-1] == 1:
|
| 78 |
+
masks = masks.squeeze(-1)
|
| 79 |
+
if masks.shape[0] != num_instances:
|
| 80 |
+
masks = np.zeros((num_instances, height, width), dtype=np.uint8)
|
| 81 |
+
masks = masks.astype(np.uint8)
|
| 82 |
+
|
| 83 |
+
if presence is None:
|
| 84 |
+
presence = keypoints[:, :, 2]
|
| 85 |
+
presence = np.asarray(presence, dtype=np.float32).reshape(num_instances, -1)
|
| 86 |
+
if presence.shape[1] > 17:
|
| 87 |
+
presence = presence[:, :17]
|
| 88 |
+
if presence.shape[1] < 17:
|
| 89 |
+
padded_presence = np.zeros((num_instances, 17), dtype=np.float32)
|
| 90 |
+
padded_presence[:, : presence.shape[1]] = presence
|
| 91 |
+
presence = padded_presence
|
| 92 |
+
|
| 93 |
+
if visibility is None:
|
| 94 |
+
visibility = keypoints[:, :, 2]
|
| 95 |
+
visibility = np.asarray(visibility, dtype=np.float32).reshape(num_instances, -1)
|
| 96 |
+
if visibility.shape[1] > 17:
|
| 97 |
+
visibility = visibility[:, :17]
|
| 98 |
+
if visibility.shape[1] < 17:
|
| 99 |
+
padded_visibility = np.zeros((num_instances, 17), dtype=np.float32)
|
| 100 |
+
padded_visibility[:, : visibility.shape[1]] = visibility
|
| 101 |
+
visibility = padded_visibility
|
| 102 |
+
|
| 103 |
+
return {
|
| 104 |
+
"bboxes": bboxes,
|
| 105 |
+
"masks": masks,
|
| 106 |
+
"keypoints": keypoints,
|
| 107 |
+
"presence": presence,
|
| 108 |
+
"visibility": visibility,
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _extract_baseline_result(intermediates: Any, fallback_result: dict, height: int, width: int) -> dict[str, np.ndarray]:
|
| 113 |
+
"""Build baseline result from first intermediate pose output."""
|
| 114 |
+
if not intermediates:
|
| 115 |
+
return _normalize_result(fallback_result, height, width)
|
| 116 |
+
|
| 117 |
+
first_intermediate = intermediates[0] if len(intermediates) > 0 else None
|
| 118 |
+
if first_intermediate is None:
|
| 119 |
+
return _normalize_result(fallback_result, height, width)
|
| 120 |
+
|
| 121 |
+
pose_instances = first_intermediate.get("poses")
|
| 122 |
+
if pose_instances is None:
|
| 123 |
+
return _normalize_result(fallback_result, height, width)
|
| 124 |
+
|
| 125 |
+
result = {
|
| 126 |
+
"bboxes": _to_numpy(getattr(pose_instances, "bboxes", None)),
|
| 127 |
+
"keypoints": _to_numpy(getattr(pose_instances, "keypoints", None)),
|
| 128 |
+
"masks": _to_numpy(getattr(pose_instances, "masks", None)),
|
| 129 |
+
"presence": _to_numpy(getattr(pose_instances, "keypoint_prob", None)),
|
| 130 |
+
"visibility": _to_numpy(getattr(pose_instances, "keypoint_vis", None)),
|
| 131 |
+
}
|
| 132 |
+
return _normalize_result(result, height, width)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _blend_pose_and_mask(model: BBoxMaskPose, image_bgr: np.ndarray, result: dict[str, np.ndarray]) -> np.ndarray:
|
| 136 |
+
"""Render pose and mask overlays, then blend them into one image."""
|
| 137 |
+
pose_vis = model.visualize(image=image_bgr, result=result, vis_type="pose")
|
| 138 |
+
mask_vis = model.visualize(image=image_bgr, result=result, vis_type="mask")
|
| 139 |
+
return cv2.addWeighted(pose_vis, 0.5, mask_vis, 0.5, 0.0)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _get_bmp_model(config_name: str = "bmp_v2", device: str = "cuda:0") -> BBoxMaskPose:
|
| 143 |
+
"""Lazily initialize and reuse BBoxMaskPose model."""
|
| 144 |
+
global bmp_model, bmp_model_config, bmp_model_device
|
| 145 |
+
|
| 146 |
+
should_rebuild = (
|
| 147 |
+
bmp_model is None
|
| 148 |
+
or bmp_model_config != config_name
|
| 149 |
+
or bmp_model_device != device
|
| 150 |
)
|
| 151 |
+
if should_rebuild:
|
| 152 |
+
try:
|
| 153 |
+
bmp_model = BBoxMaskPose(config=config_name, device=device)
|
| 154 |
+
bmp_model_config = config_name
|
| 155 |
+
bmp_model_device = device
|
| 156 |
+
except Exception as exc:
|
| 157 |
+
raise RuntimeError(
|
| 158 |
+
f"Failed to initialize BBoxMaskPose with config='{config_name}', device='{device}'."
|
| 159 |
+
) from exc
|
| 160 |
+
|
| 161 |
+
return bmp_model
|
| 162 |
|
| 163 |
@spaces.GPU(duration=60)
|
| 164 |
def process_image_with_BMP(
|
| 165 |
img: np.ndarray
|
| 166 |
) -> tuple[np.ndarray, np.ndarray]:
|
| 167 |
"""
|
| 168 |
+
Run BMP inference using the public BBoxMaskPose API.
|
| 169 |
+
|
| 170 |
+
The function keeps the original Gradio interface:
|
| 171 |
+
- output 1: baseline-style result from first intermediate pass
|
| 172 |
+
- output 2: final BMP-refined result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
"""
|
| 174 |
+
if img is None:
|
| 175 |
+
raise ValueError("Input image is None.")
|
| 176 |
+
|
| 177 |
+
# Gradio image is RGB; BMP API expects BGR.
|
| 178 |
+
img_bgr = img[..., ::-1].copy()
|
| 179 |
+
height, width = img_bgr.shape[:2]
|
| 180 |
+
|
| 181 |
+
model = _get_bmp_model(config_name=bmp_model_config, device=bmp_model_device)
|
| 182 |
+
final_result = model.predict(
|
| 183 |
+
image=img_bgr,
|
| 184 |
+
bboxes=None,
|
| 185 |
+
return_intermediates=True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
)
|
| 187 |
+
normalized_final = _normalize_result(final_result, height, width)
|
| 188 |
+
|
| 189 |
+
# No-detection robustness: return original image for both outputs.
|
| 190 |
+
if normalized_final["bboxes"].shape[0] == 0:
|
| 191 |
+
original_rgb = img_bgr[..., ::-1]
|
| 192 |
+
return original_rgb, original_rgb
|
| 193 |
+
|
| 194 |
+
intermediates = final_result.get("intermediates", [])
|
| 195 |
+
baseline_result = _extract_baseline_result(intermediates, normalized_final, height, width)
|
| 196 |
|
| 197 |
+
baseline_vis = _blend_pose_and_mask(model, img_bgr, baseline_result)
|
| 198 |
+
bmp_vis = _blend_pose_and_mask(model, img_bgr, normalized_final)
|
|
|
|
| 199 |
|
| 200 |
+
# BGR -> RGB for Gradio
|
| 201 |
+
return baseline_vis[..., ::-1], bmp_vis[..., ::-1]
|
| 202 |
|
| 203 |
|
| 204 |
with gr.Blocks() as app:
|
|
|
|
| 297 |
)
|
| 298 |
|
| 299 |
# Launch the demo
|
| 300 |
+
app.launch()
|
bboxmaskpose/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) authors of BBoxMaskPose (BMPv2). All rights reserved.
|
| 2 |
+
"""
|
| 3 |
+
BBoxMaskPose package - Public API for detection, pose estimation, and segmentation.
|
| 4 |
+
|
| 5 |
+
This package provides a stable wrapper for the full BBoxMaskPose pipeline.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from .api import BBoxMaskPose
|
| 9 |
+
|
| 10 |
+
__all__ = ["BBoxMaskPose"]
|
bboxmaskpose/api.py
ADDED
|
@@ -0,0 +1,515 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) authors of BBoxMaskPose (BMPv2). All rights reserved.
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
Public API for BBoxMaskPose wrapper.
|
| 5 |
+
|
| 6 |
+
This module provides a stable, user-friendly interface for the full
|
| 7 |
+
BBoxMaskPose pipeline: detection, pose estimation, and mask refinement.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import glob
|
| 11 |
+
import os
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Dict, List, Optional, Union
|
| 14 |
+
|
| 15 |
+
import cv2
|
| 16 |
+
import mmengine
|
| 17 |
+
import numpy as np
|
| 18 |
+
import yaml
|
| 19 |
+
from mmdet.apis import inference_detector, init_detector
|
| 20 |
+
from mmengine.structures import InstanceData
|
| 21 |
+
|
| 22 |
+
from .demo_utils import DotDict, _visualize_predictions, concat_instances, filter_instances, pose_nms
|
| 23 |
+
from .posevis_lite import pose_visualization
|
| 24 |
+
|
| 25 |
+
# Import from BBoxMaskPose package
|
| 26 |
+
from .sam2_utils import prepare_model as prepare_sam2_model, process_image_with_SAM
|
| 27 |
+
|
| 28 |
+
BMP_ROOT = Path(__file__).parent.parent
|
| 29 |
+
|
| 30 |
+
# Note: PMPose will be imported when needed to avoid circular imports
|
| 31 |
+
# from pmpose import PMPose
|
| 32 |
+
|
| 33 |
+
# Default detector and pose config
|
| 34 |
+
DEFAULT_DET_CAT_ID: int = 0
|
| 35 |
+
DEFAULT_BBOX_THR: float = 0.3
|
| 36 |
+
DEFAULT_NMS_THR: float = 0.3
|
| 37 |
+
DEFAULT_KPT_THR: float = 0.3
|
| 38 |
+
|
| 39 |
+
# Pretrained config URLs (for future use)
|
| 40 |
+
PRETRAINED_CONFIGS = {
|
| 41 |
+
"bmp-d3": "BMP_D3",
|
| 42 |
+
"bmp-j1": "BMP_J1",
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class BBoxMaskPose:
|
| 47 |
+
"""
|
| 48 |
+
Public wrapper API for BBoxMaskPose pipeline.
|
| 49 |
+
|
| 50 |
+
This class provides a complete pipeline for detection, pose estimation,
|
| 51 |
+
and mask refinement using SAM2.
|
| 52 |
+
|
| 53 |
+
Example:
|
| 54 |
+
>>> bmp_model = BBoxMaskPose(config="BMP_D3", device="cuda")
|
| 55 |
+
>>> result = bmp_model.predict(
|
| 56 |
+
... image="path/to/image.jpg",
|
| 57 |
+
... return_intermediates=True
|
| 58 |
+
... )
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
def __init__(
|
| 62 |
+
self,
|
| 63 |
+
config: str = "BMP_D3",
|
| 64 |
+
device: str = "cuda",
|
| 65 |
+
config_path: Optional[str] = None,
|
| 66 |
+
pose_model=None, # Type hint removed to avoid import at module level
|
| 67 |
+
pretrained_id: Optional[str] = None,
|
| 68 |
+
n_kpts_to_work_with: Optional[int] = 17,
|
| 69 |
+
):
|
| 70 |
+
"""
|
| 71 |
+
Initialize BBoxMaskPose model.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
config (str): Config alias ('BMP_D3', 'BMP_J1'). Defaults to 'BMP_D3'.
|
| 75 |
+
device (str): Device for inference. Defaults to 'cuda'.
|
| 76 |
+
config_path (str, optional): Path to custom YAML config file.
|
| 77 |
+
pose_model (PMPose, optional): Pre-initialized PMPose instance.
|
| 78 |
+
If None, will create internal pose model.
|
| 79 |
+
pretrained_id (str, optional): Alias for pretrained config.
|
| 80 |
+
n_kpts_to_work_with (int, optional): Number of keypoints to work with.
|
| 81 |
+
Defaults to 17 (COCO keypoints).
|
| 82 |
+
"""
|
| 83 |
+
self.device = device
|
| 84 |
+
self.config_name = config
|
| 85 |
+
|
| 86 |
+
self.n_kpts_to_work_with = 17 # Hard-code 17 as no experiments were done with other values. Keep the argument for future flexibility, but ignore it for now.
|
| 87 |
+
if n_kpts_to_work_with != 17:
|
| 88 |
+
print(
|
| 89 |
+
f"Warning: n_kpts_to_work_with is set to {n_kpts_to_work_with}, but currently only 17 keypoints are supported. Ignoring this argument for now."
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# Determine config path
|
| 93 |
+
if config_path is not None:
|
| 94 |
+
self.config_path = config_path
|
| 95 |
+
else:
|
| 96 |
+
bmp_configs_root = os.path.join(BMP_ROOT, "bboxmaskpose", "configs")
|
| 97 |
+
config_file = f"{config}.yaml"
|
| 98 |
+
self.config_path = os.path.join(bmp_configs_root, config_file)
|
| 99 |
+
|
| 100 |
+
if not os.path.exists(self.config_path):
|
| 101 |
+
available_configs = glob.glob(os.path.join(bmp_configs_root, "*.yaml"))
|
| 102 |
+
available_configs = [os.path.basename(f).replace(".yaml", "") for f in available_configs]
|
| 103 |
+
raise FileNotFoundError(f"Config file not found: {self.config_path}. " f"Available configs: {', '.join(available_configs)}")
|
| 104 |
+
|
| 105 |
+
# Load config
|
| 106 |
+
self.config = self._load_config(self.config_path)
|
| 107 |
+
|
| 108 |
+
# Initialize or use provided pose model
|
| 109 |
+
if pose_model is not None:
|
| 110 |
+
self.pose_model = pose_model
|
| 111 |
+
self._owns_pose_model = False
|
| 112 |
+
else:
|
| 113 |
+
# Create internal PMPose instance
|
| 114 |
+
self.pose_model = self._create_pose_model()
|
| 115 |
+
self._owns_pose_model = True
|
| 116 |
+
|
| 117 |
+
# Initialize detector and SAM2
|
| 118 |
+
self.detector = None
|
| 119 |
+
self.detector_prime = None
|
| 120 |
+
self.sam2_model = None
|
| 121 |
+
self._initialize_models()
|
| 122 |
+
|
| 123 |
+
def _load_config(self, config_path: str) -> DotDict:
|
| 124 |
+
"""Load BMP configuration from YAML file."""
|
| 125 |
+
with open(config_path, "r") as f:
|
| 126 |
+
cfg = yaml.safe_load(f)
|
| 127 |
+
return DotDict(cfg)
|
| 128 |
+
|
| 129 |
+
def _create_pose_model(self):
|
| 130 |
+
"""Create internal PMPose model from config."""
|
| 131 |
+
# Import PMPose here to avoid circular imports
|
| 132 |
+
from pmpose import PMPose
|
| 133 |
+
|
| 134 |
+
# Extract pose config from BMP config
|
| 135 |
+
pose_config = self.config.pose_estimator.pose_config
|
| 136 |
+
pose_checkpoint = self.config.pose_estimator.pose_checkpoint
|
| 137 |
+
|
| 138 |
+
# Create PMPose instance with custom config
|
| 139 |
+
full_pose_config = str(BMP_ROOT / pose_config)
|
| 140 |
+
|
| 141 |
+
pose_model = PMPose(
|
| 142 |
+
device=self.device,
|
| 143 |
+
config_path=full_pose_config,
|
| 144 |
+
from_pretrained=True,
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
# Load checkpoint if it's a local path
|
| 148 |
+
if not pose_checkpoint.startswith("http"):
|
| 149 |
+
pose_model.load_from_file(pose_checkpoint)
|
| 150 |
+
|
| 151 |
+
return pose_model
|
| 152 |
+
|
| 153 |
+
def _initialize_models(self):
|
| 154 |
+
"""Initialize detector and SAM2 models."""
|
| 155 |
+
# Initialize detector
|
| 156 |
+
self.detector = init_detector(self.config.detector.det_config, self.config.detector.det_checkpoint, device=self.device)
|
| 157 |
+
|
| 158 |
+
# Adapt detector pipeline
|
| 159 |
+
from mmpose.utils import adapt_mmdet_pipeline
|
| 160 |
+
|
| 161 |
+
self.detector.cfg = adapt_mmdet_pipeline(self.detector.cfg)
|
| 162 |
+
|
| 163 |
+
# Initialize detector prime (may be same as detector)
|
| 164 |
+
if (
|
| 165 |
+
self.config.detector.det_config == self.config.detector.det_prime_config
|
| 166 |
+
and self.config.detector.det_checkpoint == self.config.detector.det_prime_checkpoint
|
| 167 |
+
) or (self.config.detector.det_prime_config is None or self.config.detector.det_prime_checkpoint is None):
|
| 168 |
+
self.detector_prime = self.detector
|
| 169 |
+
else:
|
| 170 |
+
self.detector_prime = init_detector(
|
| 171 |
+
self.config.detector.det_prime_config, self.config.detector.det_prime_checkpoint, device=self.device
|
| 172 |
+
)
|
| 173 |
+
self.detector_prime.cfg = adapt_mmdet_pipeline(self.detector_prime.cfg)
|
| 174 |
+
|
| 175 |
+
# Initialize SAM2
|
| 176 |
+
sam2_config_path = os.path.join(BMP_ROOT, "bboxmaskpose", "sam2", self.config.sam2.sam2_config)
|
| 177 |
+
self.sam2_model = prepare_sam2_model(
|
| 178 |
+
model_cfg=sam2_config_path,
|
| 179 |
+
model_checkpoint=self.config.sam2.sam2_checkpoint,
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
def predict(
|
| 183 |
+
self,
|
| 184 |
+
image: Union[str, np.ndarray],
|
| 185 |
+
bboxes: Optional[np.ndarray] = None,
|
| 186 |
+
return_intermediates: bool = False,
|
| 187 |
+
return_probmaps: bool = False,
|
| 188 |
+
) -> Dict:
|
| 189 |
+
"""
|
| 190 |
+
Run full BBoxMaskPose pipeline on image.
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
image: Image path (str) or BGR numpy array.
|
| 194 |
+
bboxes: Optional (N, 4) bboxes in [x1, y1, x2, y2] format.
|
| 195 |
+
If None, run detector.
|
| 196 |
+
return_intermediates: If True, return intermediate outputs.
|
| 197 |
+
return_probmaps: If True, request heatmaps from pose model.
|
| 198 |
+
|
| 199 |
+
Returns:
|
| 200 |
+
Dict with keys:
|
| 201 |
+
- 'bboxes': (N, 4) final bounding boxes
|
| 202 |
+
- 'masks': (N, H, W) refined binary masks
|
| 203 |
+
- 'keypoints': (N, K, 3) keypoints with scores
|
| 204 |
+
- 'presence': (N, K) presence probabilities
|
| 205 |
+
- 'visibility': (N, K) visibility flags
|
| 206 |
+
- 'detector': (optional) raw detector outputs
|
| 207 |
+
- 'sam2': (optional) intermediate SAM outputs
|
| 208 |
+
"""
|
| 209 |
+
# Load image
|
| 210 |
+
if isinstance(image, str):
|
| 211 |
+
img = cv2.imread(image)
|
| 212 |
+
if img is None:
|
| 213 |
+
raise ValueError(f"Failed to load image from {image}")
|
| 214 |
+
else:
|
| 215 |
+
img = image.copy()
|
| 216 |
+
|
| 217 |
+
# Run BMP iterations
|
| 218 |
+
all_detections = None
|
| 219 |
+
intermediate_results = [] if return_intermediates else None
|
| 220 |
+
|
| 221 |
+
for iteration in range(self.config.num_bmp_iters):
|
| 222 |
+
# Step 1: Detection
|
| 223 |
+
if iteration == 0 and bboxes is not None:
|
| 224 |
+
# Use provided bboxes for first iteration
|
| 225 |
+
det_instances = InstanceData(bboxes=bboxes, bbox_scores=np.ones(len(bboxes)), masks=None)
|
| 226 |
+
else:
|
| 227 |
+
# Run detector
|
| 228 |
+
det_instances = self._run_detector(
|
| 229 |
+
self.detector if iteration == 0 else self.detector_prime,
|
| 230 |
+
img if all_detections is None else self._mask_out_image(img, all_detections),
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
if len(det_instances.bboxes) == 0:
|
| 234 |
+
continue
|
| 235 |
+
|
| 236 |
+
# Step 2: Pose estimation using PMPose wrapper
|
| 237 |
+
pose_results = self._run_pose_estimation(img, det_instances, return_probmaps=return_probmaps)
|
| 238 |
+
|
| 239 |
+
# Step 3: Pose NMS and SAM refinement
|
| 240 |
+
new_detections, old_detections = self._refine_with_sam(
|
| 241 |
+
img,
|
| 242 |
+
pose_results,
|
| 243 |
+
all_detections,
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
# Merge detections
|
| 247 |
+
if all_detections is None:
|
| 248 |
+
all_detections = new_detections
|
| 249 |
+
else:
|
| 250 |
+
all_detections = concat_instances(old_detections, new_detections)
|
| 251 |
+
|
| 252 |
+
# Store intermediates if requested
|
| 253 |
+
if return_intermediates:
|
| 254 |
+
intermediate_results.append(
|
| 255 |
+
{
|
| 256 |
+
"iteration": iteration,
|
| 257 |
+
"detections": det_instances,
|
| 258 |
+
"poses": pose_results,
|
| 259 |
+
"refined": new_detections,
|
| 260 |
+
}
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
# Prepare final result
|
| 264 |
+
result = self._format_result(all_detections, img.shape[:2])
|
| 265 |
+
|
| 266 |
+
if return_intermediates:
|
| 267 |
+
result["intermediates"] = intermediate_results
|
| 268 |
+
|
| 269 |
+
return result
|
| 270 |
+
|
| 271 |
+
def _run_detector(
    self,
    detector,
    img: np.ndarray,
) -> "InstanceData":
    """Detect person instances in ``img`` with an MMDetection model.

    Keeps only predictions of the configured (person) category above the
    score threshold, sorts them by confidence, applies bounding-box NMS,
    and packs the survivors into an ``InstanceData``.

    Args:
        detector: An initialized MMDetection model.
        img: Input image as a BGR numpy array.

    Returns:
        InstanceData with ``bboxes`` (N, 4), ``bbox_scores`` (N,) and,
        when the detector predicts them, instance ``masks``.
    """
    from mmpose.evaluation.functional import nms

    predictions = inference_detector(detector, img).pred_instances.cpu().numpy()

    # Stack bboxes with their scores so NMS can consume a single (N, 5) array.
    scored_boxes = np.concatenate((predictions.bboxes, predictions.scores[:, None]), axis=1)

    # Keep only confident detections of the target category.
    selected = (predictions.labels == DEFAULT_DET_CAT_ID) & (predictions.scores > DEFAULT_BBOX_THR)
    if not selected.any():
        return InstanceData(bboxes=np.zeros((0, 4)), bbox_scores=np.zeros((0,)), masks=np.zeros((0, 1, 1)))

    scored_boxes = scored_boxes[selected]
    inst_masks = getattr(predictions, "masks", None)
    if inst_masks is not None:
        inst_masks = inst_masks[selected]

    # Order by descending confidence before NMS.
    by_score = np.argsort(scored_boxes[:, 4])[::-1]
    scored_boxes = scored_boxes[by_score]
    if inst_masks is not None:
        inst_masks = inst_masks[by_score]

    # Suppress overlapping boxes; keep masks aligned with the survivors.
    survivors = nms(scored_boxes, DEFAULT_NMS_THR)
    scored_boxes = scored_boxes[survivors]
    if inst_masks is not None:
        inst_masks = inst_masks[survivors]

    return InstanceData(bboxes=scored_boxes[:, :4], bbox_scores=scored_boxes[:, 4], masks=inst_masks)
|
| 310 |
+
|
| 311 |
+
def _run_pose_estimation(
    self,
    img: np.ndarray,
    det_instances: "InstanceData",
    return_probmaps: bool = False,
) -> "InstanceData":
    """Run pose estimation using the PMPose wrapper.

    Args:
        img: Input image as a BGR numpy array.
        det_instances: Detections with ``bboxes``, ``bbox_scores`` and
            optionally ``masks`` used to condition the pose model.
        return_probmaps: If True, attach per-keypoint probability
            heatmaps to the returned instances (``result.heatmaps``).

    Returns:
        InstanceData with ``keypoints`` (N, K, 3), ``keypoint_scores``,
        ``keypoint_vis``, ``keypoint_prob`` and the passed-through
        bboxes, bbox scores and masks, where K == ``self.n_kpts_to_work_with``.
    """
    bboxes = det_instances.bboxes
    masks = det_instances.masks

    if len(bboxes) == 0:
        # Keep the attribute set identical to the non-empty path so that
        # downstream consumers (filter/concat helpers, _format_result)
        # can rely on keypoint_vis / keypoint_prob being present.
        return InstanceData(
            keypoints=np.zeros((0, self.n_kpts_to_work_with, 3)),
            keypoint_scores=np.zeros((0, self.n_kpts_to_work_with)),
            bboxes=bboxes,
            bbox_scores=det_instances.bbox_scores,
            masks=masks,
            keypoint_vis=np.zeros((0, self.n_kpts_to_work_with)),
            keypoint_prob=np.zeros((0, self.n_kpts_to_work_with)),
        )

    # Call PMPose public API
    keypoints, probabilities, visibilities, heatmaps = self.pose_model.predict(
        img,
        bboxes,
        masks=masks,
        return_probmaps=return_probmaps,
    )

    # Restrict everything to the first `n_kpts_to_work_with` (COCO) keypoints.
    keypoints = keypoints[:, : self.n_kpts_to_work_with, :]
    probabilities = probabilities[:, : self.n_kpts_to_work_with]
    visibilities = visibilities[:, : self.n_kpts_to_work_with]

    if heatmaps is not None:
        heatmaps = heatmaps[:, : self.n_kpts_to_work_with, :, :]

    # Create InstanceData with results
    result = InstanceData(
        keypoints=keypoints,
        keypoint_scores=keypoints[:, :, 2],  # third channel carries per-keypoint confidence
        bboxes=bboxes,
        bbox_scores=det_instances.bbox_scores,
        masks=masks,
        keypoint_vis=visibilities,
        keypoint_prob=probabilities,
    )

    if return_probmaps and heatmaps is not None:
        result.heatmaps = heatmaps

    return result
|
| 361 |
+
|
| 362 |
+
def _refine_with_sam(
    self,
    img: np.ndarray,
    pose_instances: "InstanceData",
    all_detections: "Optional[InstanceData]",
) -> tuple:
    """Suppress duplicate poses via OKS-based NMS and refine survivors with SAM.

    Pools the newly estimated poses with previously accepted detections so
    the NMS sees both, then splits the kept indices back into "old"
    (already refined) and "new" instances.

    Returns:
        tuple: ``(refined_new_detections, kept_old_detections)``. The first
        element is ``None`` when every new pose was suppressed.
    """
    conf_thr = self.config.sam2.prompting.confidence_thr
    new_kpts = pose_instances.keypoints

    # Pool previously accepted instances with the new ones for joint NMS.
    if all_detections is None:
        pooled_kpts = new_kpts
        pooled_bboxes = pose_instances.bboxes
        n_old = 0
    else:
        pooled_kpts = np.concatenate([all_detections.keypoints, new_kpts], axis=0)
        pooled_bboxes = np.concatenate([all_detections.bboxes, pose_instances.bboxes], axis=0)
        n_old = len(all_detections.bboxes)

    valid_counts = np.sum(pooled_kpts[:, :, 2] > conf_thr, axis=1)

    kept = sorted(
        pose_nms(
            DotDict({"confidence_thr": conf_thr, "oks_thr": self.config.oks_nms_thr}),
            image_kpts=pooled_kpts,
            image_bboxes=pooled_bboxes,
            num_valid_kpts=valid_counts,
        )
    )

    # Split kept indices into old (already refined) and new instances.
    new_idx = [i - n_old for i in kept if i >= n_old]
    old_idx = [i for i in kept if i < n_old]

    if not new_idx:
        return None, all_detections

    # Keep only surviving new poses; score each by its mean keypoint score.
    survivors = filter_instances(pose_instances, new_idx)
    survivors.scores = pose_instances.keypoint_scores[new_idx].mean(axis=-1)

    kept_old = filter_instances(all_detections, old_idx) if old_idx else None

    # SAM refinement of the surviving new instances.
    refined = process_image_with_SAM(
        DotDict(self.config.sam2.prompting),
        img.copy(),
        self.sam2_model,
        survivors,
        kept_old,
    )

    return refined, kept_old
|
| 415 |
+
|
| 416 |
+
def _mask_out_image(
|
| 417 |
+
self,
|
| 418 |
+
img: np.ndarray,
|
| 419 |
+
detections: InstanceData,
|
| 420 |
+
) -> np.ndarray:
|
| 421 |
+
"""Mask out detected instances from image for next iteration."""
|
| 422 |
+
masked_img = img.copy()
|
| 423 |
+
if hasattr(detections, "refined_masks") and detections.refined_masks is not None:
|
| 424 |
+
for mask in detections.refined_masks:
|
| 425 |
+
if mask is not None:
|
| 426 |
+
masked_img[mask.astype(bool)] = 0
|
| 427 |
+
return masked_img
|
| 428 |
+
|
| 429 |
+
def _format_result(
|
| 430 |
+
self,
|
| 431 |
+
detections: Optional[InstanceData],
|
| 432 |
+
img_shape: tuple,
|
| 433 |
+
) -> Dict:
|
| 434 |
+
"""Format detection results into standard output dict."""
|
| 435 |
+
if detections is None or len(detections.bboxes) == 0:
|
| 436 |
+
return {
|
| 437 |
+
"bboxes": np.zeros((0, 4)),
|
| 438 |
+
"masks": np.zeros((0, img_shape[0], img_shape[1])),
|
| 439 |
+
"keypoints": np.zeros((0, 17, 3)),
|
| 440 |
+
"presence": np.zeros((0, 17)),
|
| 441 |
+
"visibility": np.zeros((0, 17)),
|
| 442 |
+
}
|
| 443 |
+
|
| 444 |
+
# Extract refined masks if available
|
| 445 |
+
if hasattr(detections, "refined_masks") and detections.refined_masks is not None:
|
| 446 |
+
masks = detections.refined_masks
|
| 447 |
+
elif hasattr(detections, "pred_masks") and detections.pred_masks is not None:
|
| 448 |
+
masks = detections.pred_masks
|
| 449 |
+
elif hasattr(detections, "masks") and detections.masks is not None:
|
| 450 |
+
masks = detections.masks
|
| 451 |
+
else:
|
| 452 |
+
masks = np.zeros((len(detections.bboxes), img_shape[0], img_shape[1]))
|
| 453 |
+
|
| 454 |
+
return {
|
| 455 |
+
"bboxes": detections.bboxes,
|
| 456 |
+
"masks": masks,
|
| 457 |
+
"keypoints": detections.keypoints,
|
| 458 |
+
"presence": detections.keypoint_prob,
|
| 459 |
+
"visibility": detections.keypoint_vis,
|
| 460 |
+
}
|
| 461 |
+
|
| 462 |
+
def visualize(
    self,
    image: "Union[str, np.ndarray]",
    result: "Dict",
    save_path: "Optional[str]" = None,
    vis_type: str = "pose",
) -> np.ndarray:
    """
    Render BBoxMaskPose predictions on top of the input image.

    Args:
        image: Image path (str) or BGR numpy array.
        result: Result dict from predict().
        save_path: Optional path to save visualization.
        vis_type: Type of visualization ("pose" or "mask").

    Returns:
        np.ndarray: Visualization image (BGR).
    """
    # Load (or defensively copy) the input image.
    if isinstance(image, str):
        canvas = cv2.imread(image)
        if canvas is None:
            raise ValueError(f"Failed to load image from {image}")
    else:
        canvas = image.copy()

    if vis_type == "mask":
        canvas, _ = _visualize_predictions(
            canvas,
            bboxes=result["bboxes"],
            scores=np.ones(len(result["bboxes"])),
            masks=result["masks"],
            poses=result["keypoints"],
            vis_type="mask",
            mask_is_binary=True,
        )
    else:
        # Draw skeletons with per-person colors on the first 17 COCO keypoints.
        coco_kpts = result["keypoints"][:, :17, :]
        canvas = pose_visualization(
            canvas,
            coco_kpts,
            width_multiplier=8,
            differ_individuals=True,
            keep_image_size=True,
        )

    if save_path is not None:
        cv2.imwrite(save_path, canvas)

    return canvas
|
{configs → bboxmaskpose/configs}/README.md
RENAMED
|
File without changes
|
{configs → bboxmaskpose/configs}/bmp_D3.yaml
RENAMED
|
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# BBoxMaskPose Hyperparameters from Experiment D3.
|
| 2 |
# For details, see the paper: https://arxiv.org/abs/2412.01562, Tab 8. in the supplementary.
|
| 3 |
|
|
@@ -11,8 +16,10 @@ detector:
|
|
| 11 |
det_prime_checkpoint: null
|
| 12 |
|
| 13 |
pose_estimator:
|
| 14 |
-
pose_config: 'mmpose/configs/
|
| 15 |
-
pose_checkpoint: '
|
|
|
|
|
|
|
| 16 |
|
| 17 |
sam2:
|
| 18 |
sam2_config: 'configs/samurai/sam2.1_hiera_b+.yaml' # Use SAMURAI as it has img_size 1024 (SAM-2.1 has 512)
|
|
|
|
| 1 |
+
######################################################################################
|
| 2 |
+
### THIS CONFIG IS DEPRECATED AND KEPT ONLY FOR REPRODUCTION OF BMPv1 EXPERIMENTS. ###
|
| 3 |
+
### FOR BMPv2 EXPERIMENTS, PLEASE USE THE bmp_v2.yaml CONFIG. ###
|
| 4 |
+
######################################################################################
|
| 5 |
+
|
| 6 |
# BBoxMaskPose Hyperparameters from Experiment D3.
|
| 7 |
# For details, see the paper: https://arxiv.org/abs/2412.01562, Tab 8. in the supplementary.
|
| 8 |
|
|
|
|
| 16 |
det_prime_checkpoint: null
|
| 17 |
|
| 18 |
pose_estimator:
|
| 19 |
+
pose_config: 'mmpose/configs/ProbMaskPose/PMPose-b-1.0.0.py'
|
| 20 |
+
pose_checkpoint: 'models/pose_estimators/PMPose-b-1.0.0.pth'
|
| 21 |
+
# pose_config: 'mmpose/configs/MaskPose/ViTb-multi_mask.py'
|
| 22 |
+
# pose_checkpoint: 'https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/MaskPose-b.pth'
|
| 23 |
|
| 24 |
sam2:
|
| 25 |
sam2_config: 'configs/samurai/sam2.1_hiera_b+.yaml' # Use SAMURAI as it has img_size 1024 (SAM-2.1 has 512)
|
{configs → bboxmaskpose/configs}/bmp_J1.yaml
RENAMED
|
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# BBoxMaskPose Hyperparameters from Experiment J1.
|
| 2 |
# For details, see the paper: https://arxiv.org/abs/2412.01562, Tab 8. in the supplementary.
|
| 3 |
|
|
|
|
| 1 |
+
######################################################################################
|
| 2 |
+
### THIS CONFIG IS DEPRECATED AND KEPT ONLY FOR REPRODUCTION OF BMPv1 EXPERIMENTS. ###
|
| 3 |
+
### FOR BMPv2 EXPERIMENTS, PLEASE USE THE bmp_v2.yaml CONFIG. ###
|
| 4 |
+
######################################################################################
|
| 5 |
+
|
| 6 |
# BBoxMaskPose Hyperparameters from Experiment J1.
|
| 7 |
# For details, see the paper: https://arxiv.org/abs/2412.01562, Tab 8. in the supplementary.
|
| 8 |
|
bboxmaskpose/configs/bmp_v2.yaml
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This configuration is good for the BMP loop as was used for most of the experiments.
|
| 2 |
+
detector:
|
| 3 |
+
det_config: 'mmpose/configs/mmdet/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py'
|
| 4 |
+
det_checkpoint: 'https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/rtmdet-ins-l-mask.pth'
|
| 5 |
+
|
| 6 |
+
# Detectors D and D' could be different.
|
| 7 |
+
det_prime_config: null
|
| 8 |
+
det_prime_checkpoint: null
|
| 9 |
+
|
| 10 |
+
pose_estimator:
|
| 11 |
+
pose_config: 'mmpose/configs/ProbMaskPose/PMPose-b-1.0.0.py'
|
| 12 |
+
pose_checkpoint: 'https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/PMPose/PMPose-b-1.0.0.pth'
|
| 13 |
+
|
| 14 |
+
sam2:
|
| 15 |
+
sam2_config: 'configs/sam-pose2seg/sam-pose2seg_hiera_b+.yaml'
|
| 16 |
+
sam2_checkpoint: 'https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/SAM-pose2seg_hiera_b%2B.pt'
|
| 17 |
+
prompting:
|
| 18 |
+
batch: False
|
| 19 |
+
use_bbox: False
|
| 20 |
+
num_pos_keypoints: 3
|
| 21 |
+
num_pos_keypoints_if_crowd: 3
|
| 22 |
+
num_neg_keypoints: 0
|
| 23 |
+
confidence_thr: 0.5 # not used
|
| 24 |
+
visibility_thr: 0.5 # not used
|
| 25 |
+
selection_method: 'k_most_visible'
|
| 26 |
+
extend_bbox: False
|
| 27 |
+
pose_mask_consistency: False
|
| 28 |
+
crowd_by_max_iou: False # Determine if the instance is in the multi-body scenario. If yes, use different amount of keypoints and NO BBOX. If no, use bbox according to 'use_bbox' argument.
|
| 29 |
+
crop: False
|
| 30 |
+
exclusive_masks: True
|
| 31 |
+
ignore_small_bboxes: False
|
| 32 |
+
|
| 33 |
+
num_bmp_iters: 2
|
| 34 |
+
oks_nms_thr: 0.8
|
{demo → bboxmaskpose}/demo_utils.py
RENAMED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
"""
|
| 2 |
Utilities for the BMP demo:
|
| 3 |
- Visualization of detections, masks, and poses
|
|
@@ -18,9 +19,10 @@ import numpy as np
|
|
| 18 |
from mmengine.logging import print_log
|
| 19 |
from mmengine.structures import InstanceData
|
| 20 |
from pycocotools import mask as Mask
|
| 21 |
-
from sam2.distinctipy import get_colors
|
| 22 |
from tqdm import tqdm
|
| 23 |
|
|
|
|
|
|
|
| 24 |
### Visualization hyperparameters
|
| 25 |
MIN_CONTOUR_AREA: int = 50
|
| 26 |
BBOX_WEIGHT: float = 0.9
|
|
@@ -38,6 +40,21 @@ except ImportError:
|
|
| 38 |
from .posevis_lite import pose_visualization
|
| 39 |
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
class DotDict(dict):
|
| 42 |
"""Dictionary with attribute access and nested dict wrapping."""
|
| 43 |
|
|
@@ -68,17 +85,7 @@ def filter_instances(instances: InstanceData, indices):
|
|
| 68 |
return None
|
| 69 |
data = {}
|
| 70 |
# Attributes to filter
|
| 71 |
-
for attr in
|
| 72 |
-
"bboxes",
|
| 73 |
-
"bbox_scores",
|
| 74 |
-
"keypoints",
|
| 75 |
-
"keypoint_scores",
|
| 76 |
-
"scores",
|
| 77 |
-
"pred_masks",
|
| 78 |
-
"refined_masks",
|
| 79 |
-
"sam_scores",
|
| 80 |
-
"sam_kpts",
|
| 81 |
-
]:
|
| 82 |
if hasattr(instances, attr):
|
| 83 |
arr = getattr(instances, attr)
|
| 84 |
data[attr] = arr[indices] if arr is not None else None
|
|
@@ -95,17 +102,7 @@ def concat_instances(instances1: InstanceData, instances2: InstanceData):
|
|
| 95 |
if instances2 is None:
|
| 96 |
return instances1
|
| 97 |
data = {}
|
| 98 |
-
for attr in
|
| 99 |
-
"bboxes",
|
| 100 |
-
"bbox_scores",
|
| 101 |
-
"keypoints",
|
| 102 |
-
"keypoint_scores",
|
| 103 |
-
"scores",
|
| 104 |
-
"pred_masks",
|
| 105 |
-
"refined_masks",
|
| 106 |
-
"sam_scores",
|
| 107 |
-
"sam_kpts",
|
| 108 |
-
]:
|
| 109 |
arr1 = getattr(instances1, attr, None)
|
| 110 |
arr2 = getattr(instances2, attr, None)
|
| 111 |
if arr1 is None and arr2 is None:
|
|
@@ -145,43 +142,20 @@ def _visualize_predictions(
|
|
| 145 |
"""
|
| 146 |
vis_types = vis_type.split("+")
|
| 147 |
|
| 148 |
-
#
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
# new_masks = []
|
| 152 |
-
# new_poses = []
|
| 153 |
-
# size_thr = img.shape[0] * img.shape[1] * 0.01
|
| 154 |
-
# for bbox, score, mask, pose in zip(bboxes, scores, masks, poses):
|
| 155 |
-
# area = mask.sum() # Assume binary mask. OK for demo purposes
|
| 156 |
-
# if area > size_thr:
|
| 157 |
-
# new_bboxes.append(bbox)
|
| 158 |
-
# new_scores.append(score)
|
| 159 |
-
# new_masks.append(mask)
|
| 160 |
-
# new_poses.append(pose)
|
| 161 |
-
# bboxes = np.array(new_bboxes)
|
| 162 |
-
# scores = np.array(new_scores)
|
| 163 |
-
# masks = new_masks
|
| 164 |
-
# poses = new_poses
|
| 165 |
-
|
| 166 |
if mask_is_binary:
|
| 167 |
poly_masks: List[Optional[List[np.ndarray]]] = []
|
| 168 |
for binary_mask in masks:
|
| 169 |
if binary_mask is not None:
|
| 170 |
-
contours, _ = cv2.findContours(
|
| 171 |
-
(binary_mask * 255).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
|
| 172 |
-
)
|
| 173 |
polys = [cnt.flatten() for cnt in contours if cv2.contourArea(cnt) >= MIN_CONTOUR_AREA]
|
| 174 |
else:
|
| 175 |
polys = None
|
| 176 |
poly_masks.append(polys)
|
| 177 |
masks = poly_masks # type: ignore
|
| 178 |
|
| 179 |
-
# Exclude white, black, and green colors from the palette as they are not distinctive
|
| 180 |
-
colors = (np.array(get_colors(len(bboxes), exclude_colors=[(0, 1, 0), (.5, .5, .5), (0, 0, 0), (1, 1, 1)], rng=0)) * 255).astype(
|
| 181 |
-
int
|
| 182 |
-
)
|
| 183 |
-
|
| 184 |
-
|
| 185 |
if "inv-mask" in vis_types:
|
| 186 |
stencil = np.zeros_like(img)
|
| 187 |
|
|
@@ -272,9 +246,7 @@ def visualize_itteration(
|
|
| 272 |
label = "BMP {:d}x: {}".format(iteration_idx + 1, vis_def["label"])
|
| 273 |
cv2.putText(vis_img, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 3)
|
| 274 |
cv2.putText(vis_img, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
|
| 275 |
-
out_path = os.path.join(
|
| 276 |
-
output_root, "{}_iter{}_{}.jpg".format(img_name, iteration_idx + 1, vis_def["label"].replace(" ", "_"))
|
| 277 |
-
)
|
| 278 |
cv2.imwrite(str(out_path), vis_img)
|
| 279 |
|
| 280 |
# Show prompting keypoints
|
|
@@ -311,43 +283,6 @@ def visualize_itteration(
|
|
| 311 |
return masked_out
|
| 312 |
|
| 313 |
|
| 314 |
-
def visualize_demo(
|
| 315 |
-
img: np.ndarray, detections: Any,
|
| 316 |
-
) -> Optional[np.ndarray]:
|
| 317 |
-
"""
|
| 318 |
-
Generate and save visualization images for each BMP iteration.
|
| 319 |
-
|
| 320 |
-
Args:
|
| 321 |
-
img (np.ndarray): Original input image.
|
| 322 |
-
detections: InstanceData containing bboxes, scores, masks, keypoints.
|
| 323 |
-
iteration_idx (int): Current iteration index (0-based).
|
| 324 |
-
output_root (Path): Directory to save output images.
|
| 325 |
-
img_name (str): Base name of the image without extension.
|
| 326 |
-
with_text (bool): Whether to overlay text labels.
|
| 327 |
-
|
| 328 |
-
Returns:
|
| 329 |
-
Optional[np.ndarray]: The masked-out image if generated, else None.
|
| 330 |
-
"""
|
| 331 |
-
bboxes = detections.bboxes
|
| 332 |
-
scores = detections.scores
|
| 333 |
-
pred_masks = detections.pred_masks
|
| 334 |
-
refined_masks = detections.refined_masks
|
| 335 |
-
keypoints = detections.keypoints
|
| 336 |
-
|
| 337 |
-
returns = []
|
| 338 |
-
for vis_def in [
|
| 339 |
-
{"type": "mask-out", "masks": refined_masks, "label": ""},
|
| 340 |
-
{"type": "mask+pose", "masks": pred_masks, "label": "RTMDet-L"},
|
| 341 |
-
{"type": "mask+pose", "masks": refined_masks, "label": "BMP"},
|
| 342 |
-
]:
|
| 343 |
-
vis_img, colors = _visualize_predictions(
|
| 344 |
-
img.copy(), bboxes, scores, vis_def["masks"], keypoints, vis_type=vis_def["type"], mask_is_binary=True
|
| 345 |
-
)
|
| 346 |
-
returns.append(vis_img)
|
| 347 |
-
|
| 348 |
-
return returns
|
| 349 |
-
|
| 350 |
-
|
| 351 |
def create_GIF(
|
| 352 |
img_path: Path,
|
| 353 |
output_root: Path,
|
|
@@ -419,7 +354,6 @@ def create_GIF(
|
|
| 419 |
# Add 'before' and 'after' images
|
| 420 |
after1_img = os.path.join(dirname, "{}_iter{}_Final_Poses.jpg".format(img_name_wo_ext, bmp_x))
|
| 421 |
after2_img = os.path.join(dirname, "{}_iter{}_SAM_Masks.jpg".format(img_name_wo_ext, bmp_x))
|
| 422 |
-
# gif_images.append(os.path.join(dirname, "black_image.jpg")) # Add black image at the end
|
| 423 |
gif_images.append(after1_img)
|
| 424 |
gif_images.append(after2_img)
|
| 425 |
gif_images.append(os.path.join(dirname, "black_image.jpg")) # Add black image at the end
|
|
@@ -457,10 +391,7 @@ def create_GIF(
|
|
| 457 |
right = "[{}:v]".format(i)
|
| 458 |
out = "[v{}]".format(i)
|
| 459 |
offset = (i - 1) * (display_dur + fade_dur) + display_dur
|
| 460 |
-
parts.append(
|
| 461 |
-
"{}{}xfade=transition=fade:".format(left, right)
|
| 462 |
-
+ "duration={}:offset={:.3f}{}".format(fade_dur, offset, out)
|
| 463 |
-
)
|
| 464 |
filter_complex = ";".join(parts)
|
| 465 |
|
| 466 |
# 3. make MP4 slideshow
|
|
@@ -544,9 +475,7 @@ def create_GIF(
|
|
| 544 |
print_log(f"GIF saved as '{gif_output_path}'", logger="current")
|
| 545 |
|
| 546 |
|
| 547 |
-
def _update_bbox_by_mask(
|
| 548 |
-
bbox: List[int], mask_poly: Optional[List[List[int]]], image_shape: Tuple[int, int, int]
|
| 549 |
-
) -> List[int]:
|
| 550 |
"""
|
| 551 |
Adjust bounding box to tightly fit mask polygon.
|
| 552 |
|
|
@@ -591,11 +520,6 @@ def pose_nms(config: Any, image_kpts: np.ndarray, image_bboxes: np.ndarray, num_
|
|
| 591 |
Returns:
|
| 592 |
np.ndarray: Indices of kept instances.
|
| 593 |
"""
|
| 594 |
-
# Sort image kpts by average score - lowest first
|
| 595 |
-
# scores = image_kpts[:, :, 2].mean(axis=1)
|
| 596 |
-
# sort_idx = np.argsort(scores)
|
| 597 |
-
# image_kpts = image_kpts[sort_idx, :, :]
|
| 598 |
-
|
| 599 |
# Compute OKS between all pairs of poses
|
| 600 |
oks_matrix = np.zeros((image_kpts.shape[0], image_kpts.shape[0]))
|
| 601 |
for i in range(image_kpts.shape[0]):
|
|
@@ -611,8 +535,7 @@ def pose_nms(config: Any, image_kpts: np.ndarray, image_bboxes: np.ndarray, num_
|
|
| 611 |
dt = {"keypoints": image_kpts[j].copy(), "bbox": gt_bbox_xyxy}
|
| 612 |
gt["keypoints"][:, 2] = (gt["keypoints"][:, 2] > config.confidence_thr) * 2
|
| 613 |
oks = compute_oks(gt, dt)
|
| 614 |
-
|
| 615 |
-
breakpoint()
|
| 616 |
oks_matrix[i, j] = oks
|
| 617 |
|
| 618 |
np.fill_diagonal(oks_matrix, -1)
|
|
@@ -653,13 +576,10 @@ def compute_oks(gt: Dict[str, Any], dt: Dict[str, Any], use_area: bool = True, p
|
|
| 653 |
Returns:
|
| 654 |
float: OKS score or mean OKS.
|
| 655 |
"""
|
| 656 |
-
sigmas = (
|
| 657 |
-
np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89])
|
| 658 |
-
/ 10.0
|
| 659 |
-
)
|
| 660 |
vars = (sigmas * 2) ** 2
|
| 661 |
k = len(sigmas)
|
| 662 |
-
visibility_condition = lambda x: x > 0
|
| 663 |
g = np.array(gt["keypoints"]).reshape(k, 3)
|
| 664 |
xg = g[:, 0]
|
| 665 |
yg = g[:, 1]
|
|
|
|
| 1 |
+
# Copyright (c) authors of BBoxMaskPose (BMPv2). All rights reserved.
|
| 2 |
"""
|
| 3 |
Utilities for the BMP demo:
|
| 4 |
- Visualization of detections, masks, and poses
|
|
|
|
| 19 |
from mmengine.logging import print_log
|
| 20 |
from mmengine.structures import InstanceData
|
| 21 |
from pycocotools import mask as Mask
|
|
|
|
| 22 |
from tqdm import tqdm
|
| 23 |
|
| 24 |
+
from bboxmaskpose.sam2.distinctipy import get_colors
|
| 25 |
+
|
| 26 |
### Visualization hyperparameters
|
| 27 |
MIN_CONTOUR_AREA: int = 50
|
| 28 |
BBOX_WEIGHT: float = 0.9
|
|
|
|
| 40 |
from .posevis_lite import pose_visualization
|
| 41 |
|
| 42 |
|
| 43 |
+
WHITELIST_ATTRIBUTES = [
|
| 44 |
+
"bboxes",
|
| 45 |
+
"bbox_scores",
|
| 46 |
+
"keypoints",
|
| 47 |
+
"keypoint_scores",
|
| 48 |
+
"scores",
|
| 49 |
+
"pred_masks",
|
| 50 |
+
"refined_masks",
|
| 51 |
+
"sam_scores",
|
| 52 |
+
"sam_kpts",
|
| 53 |
+
"keypoint_vis",
|
| 54 |
+
"keypoint_prob",
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
class DotDict(dict):
|
| 59 |
"""Dictionary with attribute access and nested dict wrapping."""
|
| 60 |
|
|
|
|
| 85 |
return None
|
| 86 |
data = {}
|
| 87 |
# Attributes to filter
|
| 88 |
+
for attr in WHITELIST_ATTRIBUTES:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
if hasattr(instances, attr):
|
| 90 |
arr = getattr(instances, attr)
|
| 91 |
data[attr] = arr[indices] if arr is not None else None
|
|
|
|
| 102 |
if instances2 is None:
|
| 103 |
return instances1
|
| 104 |
data = {}
|
| 105 |
+
for attr in WHITELIST_ATTRIBUTES:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
arr1 = getattr(instances1, attr, None)
|
| 107 |
arr2 = getattr(instances2, attr, None)
|
| 108 |
if arr1 is None and arr2 is None:
|
|
|
|
| 142 |
"""
|
| 143 |
vis_types = vis_type.split("+")
|
| 144 |
|
| 145 |
+
# Exclude white, black, and green colors from the palette as they are not distinctive
|
| 146 |
+
colors = (np.array(get_colors(len(bboxes), exclude_colors=[(0, 1, 0), (0, 0, 0), (1, 1, 1)], rng=0)) * 255).astype(int)
|
| 147 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
if mask_is_binary:
|
| 149 |
poly_masks: List[Optional[List[np.ndarray]]] = []
|
| 150 |
for binary_mask in masks:
|
| 151 |
if binary_mask is not None:
|
| 152 |
+
contours, _ = cv2.findContours((binary_mask * 255).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
|
|
|
|
|
|
| 153 |
polys = [cnt.flatten() for cnt in contours if cv2.contourArea(cnt) >= MIN_CONTOUR_AREA]
|
| 154 |
else:
|
| 155 |
polys = None
|
| 156 |
poly_masks.append(polys)
|
| 157 |
masks = poly_masks # type: ignore
|
| 158 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
if "inv-mask" in vis_types:
|
| 160 |
stencil = np.zeros_like(img)
|
| 161 |
|
|
|
|
| 246 |
label = "BMP {:d}x: {}".format(iteration_idx + 1, vis_def["label"])
|
| 247 |
cv2.putText(vis_img, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 3)
|
| 248 |
cv2.putText(vis_img, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
|
| 249 |
+
out_path = os.path.join(output_root, "{}_iter{}_{}.jpg".format(img_name, iteration_idx + 1, vis_def["label"].replace(" ", "_")))
|
|
|
|
|
|
|
| 250 |
cv2.imwrite(str(out_path), vis_img)
|
| 251 |
|
| 252 |
# Show prompting keypoints
|
|
|
|
| 283 |
return masked_out
|
| 284 |
|
| 285 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
def create_GIF(
|
| 287 |
img_path: Path,
|
| 288 |
output_root: Path,
|
|
|
|
| 354 |
# Add 'before' and 'after' images
|
| 355 |
after1_img = os.path.join(dirname, "{}_iter{}_Final_Poses.jpg".format(img_name_wo_ext, bmp_x))
|
| 356 |
after2_img = os.path.join(dirname, "{}_iter{}_SAM_Masks.jpg".format(img_name_wo_ext, bmp_x))
|
|
|
|
| 357 |
gif_images.append(after1_img)
|
| 358 |
gif_images.append(after2_img)
|
| 359 |
gif_images.append(os.path.join(dirname, "black_image.jpg")) # Add black image at the end
|
|
|
|
| 391 |
right = "[{}:v]".format(i)
|
| 392 |
out = "[v{}]".format(i)
|
| 393 |
offset = (i - 1) * (display_dur + fade_dur) + display_dur
|
| 394 |
+
parts.append("{}{}xfade=transition=fade:".format(left, right) + "duration={}:offset={:.3f}{}".format(fade_dur, offset, out))
|
|
|
|
|
|
|
|
|
|
| 395 |
filter_complex = ";".join(parts)
|
| 396 |
|
| 397 |
# 3. make MP4 slideshow
|
|
|
|
| 475 |
print_log(f"GIF saved as '{gif_output_path}'", logger="current")
|
| 476 |
|
| 477 |
|
| 478 |
+
def _update_bbox_by_mask(bbox: List[int], mask_poly: Optional[List[List[int]]], image_shape: Tuple[int, int, int]) -> List[int]:
|
|
|
|
|
|
|
| 479 |
"""
|
| 480 |
Adjust bounding box to tightly fit mask polygon.
|
| 481 |
|
|
|
|
| 520 |
Returns:
|
| 521 |
np.ndarray: Indices of kept instances.
|
| 522 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 523 |
# Compute OKS between all pairs of poses
|
| 524 |
oks_matrix = np.zeros((image_kpts.shape[0], image_kpts.shape[0]))
|
| 525 |
for i in range(image_kpts.shape[0]):
|
|
|
|
| 535 |
dt = {"keypoints": image_kpts[j].copy(), "bbox": gt_bbox_xyxy}
|
| 536 |
gt["keypoints"][:, 2] = (gt["keypoints"][:, 2] > config.confidence_thr) * 2
|
| 537 |
oks = compute_oks(gt, dt)
|
| 538 |
+
assert oks <= 1.0, f"OKS value {oks} exceeds 1.0, which indicates a bug in compute_oks"
|
|
|
|
| 539 |
oks_matrix[i, j] = oks
|
| 540 |
|
| 541 |
np.fill_diagonal(oks_matrix, -1)
|
|
|
|
| 576 |
Returns:
|
| 577 |
float: OKS score or mean OKS.
|
| 578 |
"""
|
| 579 |
+
sigmas = np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89]) / 10.0
|
|
|
|
|
|
|
|
|
|
| 580 |
vars = (sigmas * 2) ** 2
|
| 581 |
k = len(sigmas)
|
| 582 |
+
visibility_condition = lambda x: x > 0.3
|
| 583 |
g = np.array(gt["keypoints"]).reshape(k, 3)
|
| 584 |
xg = g[:, 0]
|
| 585 |
yg = g[:, 1]
|
{demo → bboxmaskpose}/posevis_lite.py
RENAMED
|
@@ -1,9 +1,13 @@
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 3 |
|
| 4 |
import cv2
|
| 5 |
import numpy as np
|
| 6 |
|
|
|
|
|
|
|
| 7 |
NEUTRAL_COLOR = (52, 235, 107)
|
| 8 |
|
| 9 |
LEFT_ARM_COLOR = (216, 235, 52)
|
|
@@ -85,14 +89,6 @@ def _draw_line(
|
|
| 85 |
start = np.array(start)[:2]
|
| 86 |
stop = np.array(stop)[:2]
|
| 87 |
if line_type.lower() == "solid":
|
| 88 |
-
img = cv2.line(
|
| 89 |
-
img,
|
| 90 |
-
(int(start[0]), int(start[1])),
|
| 91 |
-
(int(stop[0]), int(stop[1])),
|
| 92 |
-
color=(0, 0, 0),
|
| 93 |
-
thickness=thickness+1,
|
| 94 |
-
lineType=cv2.LINE_AA,
|
| 95 |
-
)
|
| 96 |
img = cv2.line(
|
| 97 |
img,
|
| 98 |
(int(start[0]), int(start[1])),
|
|
@@ -193,7 +189,14 @@ def pose_visualization(
|
|
| 193 |
if not isinstance(color, (list, tuple)):
|
| 194 |
color = [color for keypoint in keypoints]
|
| 195 |
else:
|
| 196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
max_padding = [0, 0, 0, 0]
|
| 199 |
for keypoint, clr in zip(keypoints, color):
|
|
@@ -243,12 +246,9 @@ def pose_visualization(
|
|
| 243 |
# If conf >= confidence_thr: conf = 2
|
| 244 |
vis_is_float = np.any(np.logical_and(keypoints[:, -1] > 0, keypoints[:, -1] < 1))
|
| 245 |
if keypoints.shape[1] == 3 and vis_is_float:
|
| 246 |
-
# print("before", keypoints[:, -1])
|
| 247 |
lower_idx = keypoints[:, -1] < confidence_thr
|
| 248 |
keypoints[lower_idx, -1] = 1
|
| 249 |
keypoints[~lower_idx, -1] = 2
|
| 250 |
-
# print("after", keypoints[:, -1])
|
| 251 |
-
# print("-"*20)
|
| 252 |
|
| 253 |
# All visibility values should be ints
|
| 254 |
keypoints[:, -1] = keypoints[:, -1].astype(int)
|
|
|
|
| 1 |
+
# Copyright (c) authors of BBoxMaskPose (BMPv2). All rights reserved.
|
| 2 |
+
|
| 3 |
import os
|
| 4 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 5 |
|
| 6 |
import cv2
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
+
from bboxmaskpose.sam2.distinctipy import get_colors
|
| 10 |
+
|
| 11 |
NEUTRAL_COLOR = (52, 235, 107)
|
| 12 |
|
| 13 |
LEFT_ARM_COLOR = (216, 235, 52)
|
|
|
|
| 89 |
start = np.array(start)[:2]
|
| 90 |
stop = np.array(stop)[:2]
|
| 91 |
if line_type.lower() == "solid":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
img = cv2.line(
|
| 93 |
img,
|
| 94 |
(int(start[0]), int(start[1])),
|
|
|
|
| 189 |
if not isinstance(color, (list, tuple)):
|
| 190 |
color = [color for keypoint in keypoints]
|
| 191 |
else:
|
| 192 |
+
if differ_individuals:
|
| 193 |
+
color = (
|
| 194 |
+
(np.array(get_colors(len(keypoints), exclude_colors=[(0, 1, 0), (0, 0, 0), (1, 1, 1)], rng=0)) * 255)
|
| 195 |
+
.astype(int)
|
| 196 |
+
.tolist()
|
| 197 |
+
)
|
| 198 |
+
else:
|
| 199 |
+
color = [None for keypoint in keypoints]
|
| 200 |
|
| 201 |
max_padding = [0, 0, 0, 0]
|
| 202 |
for keypoint, clr in zip(keypoints, color):
|
|
|
|
| 246 |
# If conf >= confidence_thr: conf = 2
|
| 247 |
vis_is_float = np.any(np.logical_and(keypoints[:, -1] > 0, keypoints[:, -1] < 1))
|
| 248 |
if keypoints.shape[1] == 3 and vis_is_float:
|
|
|
|
| 249 |
lower_idx = keypoints[:, -1] < confidence_thr
|
| 250 |
keypoints[lower_idx, -1] = 1
|
| 251 |
keypoints[~lower_idx, -1] = 2
|
|
|
|
|
|
|
| 252 |
|
| 253 |
# All visibility values should be ints
|
| 254 |
keypoints[:, -1] = keypoints[:, -1].astype(int)
|
{sam2 → bboxmaskpose/sam2}/__init__.py
RENAMED
|
@@ -8,4 +8,4 @@ from hydra import initialize_config_module
|
|
| 8 |
from hydra.core.global_hydra import GlobalHydra
|
| 9 |
|
| 10 |
if not GlobalHydra.instance().is_initialized():
|
| 11 |
-
initialize_config_module("sam2", version_base="1.2")
|
|
|
|
| 8 |
from hydra.core.global_hydra import GlobalHydra
|
| 9 |
|
| 10 |
if not GlobalHydra.instance().is_initialized():
|
| 11 |
+
initialize_config_module("bboxmaskpose.sam2", version_base="1.2")
|
{sam2 → bboxmaskpose/sam2}/automatic_mask_generator.py
RENAMED
|
@@ -11,9 +11,10 @@ import numpy as np
|
|
| 11 |
import torch
|
| 12 |
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
|
| 13 |
|
| 14 |
-
from sam2.modeling.sam2_base import SAM2Base
|
| 15 |
-
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
| 16 |
-
from sam2.utils.amg import (
|
|
|
|
| 17 |
area_from_rle,
|
| 18 |
batch_iterator,
|
| 19 |
batched_mask_to_box,
|
|
@@ -24,7 +25,6 @@ from sam2.utils.amg import (
|
|
| 24 |
generate_crop_boxes,
|
| 25 |
is_box_near_crop_edge,
|
| 26 |
mask_to_rle_pytorch,
|
| 27 |
-
MaskData,
|
| 28 |
remove_small_regions,
|
| 29 |
rle_to_mask,
|
| 30 |
uncrop_boxes_xyxy,
|
|
@@ -103,9 +103,7 @@ class SAM2AutomaticMaskGenerator:
|
|
| 103 |
multimask_output (bool): Whether to output multimask at each point of the grid.
|
| 104 |
"""
|
| 105 |
|
| 106 |
-
assert (points_per_side is None) != (
|
| 107 |
-
point_grids is None
|
| 108 |
-
), "Exactly one of points_per_side or point_grid must be provided."
|
| 109 |
if points_per_side is not None:
|
| 110 |
self.point_grids = build_all_layer_point_grids(
|
| 111 |
points_per_side,
|
|
@@ -161,7 +159,7 @@ class SAM2AutomaticMaskGenerator:
|
|
| 161 |
Returns:
|
| 162 |
(SAM2AutomaticMaskGenerator): The loaded model.
|
| 163 |
"""
|
| 164 |
-
from sam2.build_sam import build_sam2_hf
|
| 165 |
|
| 166 |
sam_model = build_sam2_hf(model_id, **kwargs)
|
| 167 |
return cls(sam_model, **kwargs)
|
|
@@ -197,9 +195,7 @@ class SAM2AutomaticMaskGenerator:
|
|
| 197 |
|
| 198 |
# Encode masks
|
| 199 |
if self.output_mode == "coco_rle":
|
| 200 |
-
mask_data["segmentations"] = [
|
| 201 |
-
coco_encode_rle(rle) for rle in mask_data["rles"]
|
| 202 |
-
]
|
| 203 |
elif self.output_mode == "binary_mask":
|
| 204 |
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
| 205 |
else:
|
|
@@ -223,9 +219,7 @@ class SAM2AutomaticMaskGenerator:
|
|
| 223 |
|
| 224 |
def _generate_masks(self, image: np.ndarray) -> MaskData:
|
| 225 |
orig_size = image.shape[:2]
|
| 226 |
-
crop_boxes, layer_idxs = generate_crop_boxes(
|
| 227 |
-
orig_size, self.crop_n_layers, self.crop_overlap_ratio
|
| 228 |
-
)
|
| 229 |
|
| 230 |
# Iterate over image crops
|
| 231 |
data = MaskData()
|
|
@@ -268,9 +262,7 @@ class SAM2AutomaticMaskGenerator:
|
|
| 268 |
# Generate masks for this crop in batches
|
| 269 |
data = MaskData()
|
| 270 |
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
|
| 271 |
-
batch_data = self._process_batch(
|
| 272 |
-
points, cropped_im_size, crop_box, orig_size, normalize=True
|
| 273 |
-
)
|
| 274 |
data.cat(batch_data)
|
| 275 |
del batch_data
|
| 276 |
self.predictor.reset_predictor()
|
|
@@ -302,15 +294,9 @@ class SAM2AutomaticMaskGenerator:
|
|
| 302 |
orig_h, orig_w = orig_size
|
| 303 |
|
| 304 |
# Run model on this batch
|
| 305 |
-
points = torch.as_tensor(
|
| 306 |
-
|
| 307 |
-
)
|
| 308 |
-
in_points = self.predictor._transforms.transform_coords(
|
| 309 |
-
points, normalize=normalize, orig_hw=im_size
|
| 310 |
-
)
|
| 311 |
-
in_labels = torch.ones(
|
| 312 |
-
in_points.shape[0], dtype=torch.int, device=in_points.device
|
| 313 |
-
)
|
| 314 |
masks, iou_preds, low_res_masks = self.predictor._predict(
|
| 315 |
in_points[:, None, :],
|
| 316 |
in_labels[:, None],
|
|
@@ -334,23 +320,15 @@ class SAM2AutomaticMaskGenerator:
|
|
| 334 |
data.filter(keep_mask)
|
| 335 |
|
| 336 |
# Calculate and filter by stability score
|
| 337 |
-
data["stability_score"] = calculate_stability_score(
|
| 338 |
-
data["masks"], self.mask_threshold, self.stability_score_offset
|
| 339 |
-
)
|
| 340 |
if self.stability_score_thresh > 0.0:
|
| 341 |
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
| 342 |
data.filter(keep_mask)
|
| 343 |
else:
|
| 344 |
# One step refinement using previous mask predictions
|
| 345 |
-
in_points = self.predictor._transforms.transform_coords(
|
| 346 |
-
|
| 347 |
-
)
|
| 348 |
-
labels = torch.ones(
|
| 349 |
-
in_points.shape[0], dtype=torch.int, device=in_points.device
|
| 350 |
-
)
|
| 351 |
-
masks, ious = self.refine_with_m2m(
|
| 352 |
-
in_points, labels, data["low_res_masks"], self.points_per_batch
|
| 353 |
-
)
|
| 354 |
data["masks"] = masks.squeeze(1)
|
| 355 |
data["iou_preds"] = ious.squeeze(1)
|
| 356 |
|
|
@@ -358,9 +336,7 @@ class SAM2AutomaticMaskGenerator:
|
|
| 358 |
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
| 359 |
data.filter(keep_mask)
|
| 360 |
|
| 361 |
-
data["stability_score"] = calculate_stability_score(
|
| 362 |
-
data["masks"], self.mask_threshold, self.stability_score_offset
|
| 363 |
-
)
|
| 364 |
if self.stability_score_thresh > 0.0:
|
| 365 |
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
| 366 |
data.filter(keep_mask)
|
|
@@ -370,9 +346,7 @@ class SAM2AutomaticMaskGenerator:
|
|
| 370 |
data["boxes"] = batched_mask_to_box(data["masks"])
|
| 371 |
|
| 372 |
# Filter boxes that touch crop boundaries
|
| 373 |
-
keep_mask = ~is_box_near_crop_edge(
|
| 374 |
-
data["boxes"], crop_box, [0, 0, orig_w, orig_h]
|
| 375 |
-
)
|
| 376 |
if not torch.all(keep_mask):
|
| 377 |
data.filter(keep_mask)
|
| 378 |
|
|
@@ -384,9 +358,7 @@ class SAM2AutomaticMaskGenerator:
|
|
| 384 |
return data
|
| 385 |
|
| 386 |
@staticmethod
|
| 387 |
-
def postprocess_small_regions(
|
| 388 |
-
mask_data: MaskData, min_area: int, nms_thresh: float
|
| 389 |
-
) -> MaskData:
|
| 390 |
"""
|
| 391 |
Removes small disconnected regions and holes in masks, then reruns
|
| 392 |
box NMS to remove any new duplicates.
|
|
@@ -438,9 +410,7 @@ class SAM2AutomaticMaskGenerator:
|
|
| 438 |
new_masks = []
|
| 439 |
new_iou_preds = []
|
| 440 |
|
| 441 |
-
for cur_points, cur_point_labels, low_res_mask in batch_iterator(
|
| 442 |
-
points_per_batch, points, point_labels, low_res_masks
|
| 443 |
-
):
|
| 444 |
best_masks, best_iou_preds, _ = self.predictor._predict(
|
| 445 |
cur_points[:, None, :],
|
| 446 |
cur_point_labels[:, None],
|
|
|
|
| 11 |
import torch
|
| 12 |
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
|
| 13 |
|
| 14 |
+
from bboxmaskpose.sam2.modeling.sam2_base import SAM2Base
|
| 15 |
+
from bboxmaskpose.sam2.sam2_image_predictor import SAM2ImagePredictor
|
| 16 |
+
from bboxmaskpose.sam2.utils.amg import (
|
| 17 |
+
MaskData,
|
| 18 |
area_from_rle,
|
| 19 |
batch_iterator,
|
| 20 |
batched_mask_to_box,
|
|
|
|
| 25 |
generate_crop_boxes,
|
| 26 |
is_box_near_crop_edge,
|
| 27 |
mask_to_rle_pytorch,
|
|
|
|
| 28 |
remove_small_regions,
|
| 29 |
rle_to_mask,
|
| 30 |
uncrop_boxes_xyxy,
|
|
|
|
| 103 |
multimask_output (bool): Whether to output multimask at each point of the grid.
|
| 104 |
"""
|
| 105 |
|
| 106 |
+
assert (points_per_side is None) != (point_grids is None), "Exactly one of points_per_side or point_grid must be provided."
|
|
|
|
|
|
|
| 107 |
if points_per_side is not None:
|
| 108 |
self.point_grids = build_all_layer_point_grids(
|
| 109 |
points_per_side,
|
|
|
|
| 159 |
Returns:
|
| 160 |
(SAM2AutomaticMaskGenerator): The loaded model.
|
| 161 |
"""
|
| 162 |
+
from bboxmaskpose.sam2.build_sam import build_sam2_hf
|
| 163 |
|
| 164 |
sam_model = build_sam2_hf(model_id, **kwargs)
|
| 165 |
return cls(sam_model, **kwargs)
|
|
|
|
| 195 |
|
| 196 |
# Encode masks
|
| 197 |
if self.output_mode == "coco_rle":
|
| 198 |
+
mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
|
|
|
|
|
|
|
| 199 |
elif self.output_mode == "binary_mask":
|
| 200 |
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
| 201 |
else:
|
|
|
|
| 219 |
|
| 220 |
def _generate_masks(self, image: np.ndarray) -> MaskData:
|
| 221 |
orig_size = image.shape[:2]
|
| 222 |
+
crop_boxes, layer_idxs = generate_crop_boxes(orig_size, self.crop_n_layers, self.crop_overlap_ratio)
|
|
|
|
|
|
|
| 223 |
|
| 224 |
# Iterate over image crops
|
| 225 |
data = MaskData()
|
|
|
|
| 262 |
# Generate masks for this crop in batches
|
| 263 |
data = MaskData()
|
| 264 |
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
|
| 265 |
+
batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size, normalize=True)
|
|
|
|
|
|
|
| 266 |
data.cat(batch_data)
|
| 267 |
del batch_data
|
| 268 |
self.predictor.reset_predictor()
|
|
|
|
| 294 |
orig_h, orig_w = orig_size
|
| 295 |
|
| 296 |
# Run model on this batch
|
| 297 |
+
points = torch.as_tensor(points, dtype=torch.float32, device=self.predictor.device)
|
| 298 |
+
in_points = self.predictor._transforms.transform_coords(points, normalize=normalize, orig_hw=im_size)
|
| 299 |
+
in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
masks, iou_preds, low_res_masks = self.predictor._predict(
|
| 301 |
in_points[:, None, :],
|
| 302 |
in_labels[:, None],
|
|
|
|
| 320 |
data.filter(keep_mask)
|
| 321 |
|
| 322 |
# Calculate and filter by stability score
|
| 323 |
+
data["stability_score"] = calculate_stability_score(data["masks"], self.mask_threshold, self.stability_score_offset)
|
|
|
|
|
|
|
| 324 |
if self.stability_score_thresh > 0.0:
|
| 325 |
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
| 326 |
data.filter(keep_mask)
|
| 327 |
else:
|
| 328 |
# One step refinement using previous mask predictions
|
| 329 |
+
in_points = self.predictor._transforms.transform_coords(data["points"], normalize=normalize, orig_hw=im_size)
|
| 330 |
+
labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
|
| 331 |
+
masks, ious = self.refine_with_m2m(in_points, labels, data["low_res_masks"], self.points_per_batch)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
data["masks"] = masks.squeeze(1)
|
| 333 |
data["iou_preds"] = ious.squeeze(1)
|
| 334 |
|
|
|
|
| 336 |
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
| 337 |
data.filter(keep_mask)
|
| 338 |
|
| 339 |
+
data["stability_score"] = calculate_stability_score(data["masks"], self.mask_threshold, self.stability_score_offset)
|
|
|
|
|
|
|
| 340 |
if self.stability_score_thresh > 0.0:
|
| 341 |
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
| 342 |
data.filter(keep_mask)
|
|
|
|
| 346 |
data["boxes"] = batched_mask_to_box(data["masks"])
|
| 347 |
|
| 348 |
# Filter boxes that touch crop boundaries
|
| 349 |
+
keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
|
|
|
|
|
|
|
| 350 |
if not torch.all(keep_mask):
|
| 351 |
data.filter(keep_mask)
|
| 352 |
|
|
|
|
| 358 |
return data
|
| 359 |
|
| 360 |
@staticmethod
|
| 361 |
+
def postprocess_small_regions(mask_data: MaskData, min_area: int, nms_thresh: float) -> MaskData:
|
|
|
|
|
|
|
| 362 |
"""
|
| 363 |
Removes small disconnected regions and holes in masks, then reruns
|
| 364 |
box NMS to remove any new duplicates.
|
|
|
|
| 410 |
new_masks = []
|
| 411 |
new_iou_preds = []
|
| 412 |
|
| 413 |
+
for cur_points, cur_point_labels, low_res_mask in batch_iterator(points_per_batch, points, point_labels, low_res_masks):
|
|
|
|
|
|
|
| 414 |
best_masks, best_iou_preds, _ = self.predictor._predict(
|
| 415 |
cur_points[:, None, :],
|
| 416 |
cur_point_labels[:, None],
|
{sam2 → bboxmaskpose/sam2}/benchmark.py
RENAMED
|
@@ -11,7 +11,7 @@ import numpy as np
|
|
| 11 |
import torch
|
| 12 |
from tqdm import tqdm
|
| 13 |
|
| 14 |
-
from sam2.build_sam import build_sam2_video_predictor
|
| 15 |
|
| 16 |
# Only cuda supported
|
| 17 |
assert torch.cuda.is_available()
|
|
@@ -28,19 +28,13 @@ sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
|
|
| 28 |
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
|
| 29 |
|
| 30 |
# Build video predictor with vos_optimized=True setting
|
| 31 |
-
predictor = build_sam2_video_predictor(
|
| 32 |
-
model_cfg, sam2_checkpoint, device=device, vos_optimized=True
|
| 33 |
-
)
|
| 34 |
|
| 35 |
|
| 36 |
# Initialize with video
|
| 37 |
video_dir = "notebooks/videos/bedroom"
|
| 38 |
# scan all the JPEG frame names in this directory
|
| 39 |
-
frame_names = [
|
| 40 |
-
p
|
| 41 |
-
for p in os.listdir(video_dir)
|
| 42 |
-
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
| 43 |
-
]
|
| 44 |
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
| 45 |
inference_state = predictor.init_state(video_path=video_dir)
|
| 46 |
|
|
|
|
| 11 |
import torch
|
| 12 |
from tqdm import tqdm
|
| 13 |
|
| 14 |
+
from bboxmaskpose.sam2.build_sam import build_sam2_video_predictor
|
| 15 |
|
| 16 |
# Only cuda supported
|
| 17 |
assert torch.cuda.is_available()
|
|
|
|
| 28 |
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
|
| 29 |
|
| 30 |
# Build video predictor with vos_optimized=True setting
|
| 31 |
+
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device, vos_optimized=True)
|
|
|
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
# Initialize with video
|
| 35 |
video_dir = "notebooks/videos/bedroom"
|
| 36 |
# scan all the JPEG frame names in this directory
|
| 37 |
+
frame_names = [p for p in os.listdir(video_dir) if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
| 39 |
inference_state = predictor.init_state(video_path=video_dir)
|
| 40 |
|
{sam2 → bboxmaskpose/sam2}/build_sam.py
RENAMED
|
@@ -6,14 +6,16 @@
|
|
| 6 |
|
| 7 |
import logging
|
| 8 |
import os
|
|
|
|
| 9 |
|
| 10 |
import torch
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
| 12 |
from hydra.utils import instantiate
|
| 13 |
from omegaconf import OmegaConf
|
| 14 |
|
| 15 |
-
import sam2
|
| 16 |
-
|
| 17 |
# Check if the user is running Python from the parent directory of the sam2 repo
|
| 18 |
# (i.e. the directory where this repo is cloned into) -- this is not supported since
|
| 19 |
# it could shadow the sam2 package and cause issues.
|
|
@@ -86,13 +88,26 @@ def build_sam2(
|
|
| 86 |
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
|
| 87 |
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
|
| 88 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
# Read config and init model
|
| 90 |
try:
|
| 91 |
-
|
|
|
|
|
|
|
| 92 |
except Exception as e:
|
| 93 |
logging.error(f"Error loading config: {e}")
|
| 94 |
cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
|
| 95 |
-
|
| 96 |
OmegaConf.resolve(cfg)
|
| 97 |
model = instantiate(cfg.model, _recursive_=True)
|
| 98 |
_load_checkpoint(model, ckpt_path)
|
|
@@ -161,14 +176,23 @@ def build_sam2_hf(model_id, **kwargs):
|
|
| 161 |
|
| 162 |
def build_sam2_video_predictor_hf(model_id, **kwargs):
|
| 163 |
config_name, ckpt_path = _hf_download(model_id)
|
| 164 |
-
return build_sam2_video_predictor(
|
| 165 |
-
|
| 166 |
-
|
|
|
|
|
|
|
| 167 |
|
| 168 |
|
| 169 |
def _load_checkpoint(model, ckpt_path):
|
| 170 |
if ckpt_path is not None:
|
| 171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
missing_keys, unexpected_keys = model.load_state_dict(sd)
|
| 173 |
if missing_keys:
|
| 174 |
logging.error(missing_keys)
|
|
@@ -176,4 +200,5 @@ def _load_checkpoint(model, ckpt_path):
|
|
| 176 |
if unexpected_keys:
|
| 177 |
logging.error(unexpected_keys)
|
| 178 |
raise RuntimeError()
|
|
|
|
| 179 |
logging.info("Loaded checkpoint sucessfully")
|
|
|
|
| 6 |
|
| 7 |
import logging
|
| 8 |
import os
|
| 9 |
+
import urllib.parse as urlparse
|
| 10 |
|
| 11 |
import torch
|
| 12 |
+
|
| 13 |
+
import bboxmaskpose.sam2 as sam2
|
| 14 |
+
from hydra import compose, initialize_config_dir
|
| 15 |
+
from hydra.core.global_hydra import GlobalHydra
|
| 16 |
from hydra.utils import instantiate
|
| 17 |
from omegaconf import OmegaConf
|
| 18 |
|
|
|
|
|
|
|
| 19 |
# Check if the user is running Python from the parent directory of the sam2 repo
|
| 20 |
# (i.e. the directory where this repo is cloned into) -- this is not supported since
|
| 21 |
# it could shadow the sam2 package and cause issues.
|
|
|
|
| 88 |
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
|
| 89 |
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
|
| 90 |
]
|
| 91 |
+
|
| 92 |
+
# IMPORTANT: compose() requires Hydra to be initialized with a config source.
|
| 93 |
+
# Also important if build_sam2() can be called multiple times in one process.
|
| 94 |
+
GlobalHydra.instance().clear()
|
| 95 |
+
|
| 96 |
+
# Point Hydra at the directory that contains the SAM2 yaml configs
|
| 97 |
+
config_dir = os.path.dirname(config_file)
|
| 98 |
+
|
| 99 |
+
# Hydra expects config_name WITHOUT .yaml
|
| 100 |
+
config_name = os.path.basename(config_file).replace(".yaml", "")
|
| 101 |
+
|
| 102 |
# Read config and init model
|
| 103 |
try:
|
| 104 |
+
with initialize_config_dir(version_base=None, config_dir=str(config_dir)):
|
| 105 |
+
cfg = compose(config_name=config_name, overrides=hydra_overrides_extra)
|
| 106 |
+
# cfg = compose(config_name=config_file)
|
| 107 |
except Exception as e:
|
| 108 |
logging.error(f"Error loading config: {e}")
|
| 109 |
cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
|
| 110 |
+
|
| 111 |
OmegaConf.resolve(cfg)
|
| 112 |
model = instantiate(cfg.model, _recursive_=True)
|
| 113 |
_load_checkpoint(model, ckpt_path)
|
|
|
|
| 176 |
|
| 177 |
def build_sam2_video_predictor_hf(model_id, **kwargs):
|
| 178 |
config_name, ckpt_path = _hf_download(model_id)
|
| 179 |
+
return build_sam2_video_predictor(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _is_url(path: str) -> bool:
|
| 183 |
+
return urlparse.urlparse(path).scheme != ""
|
| 184 |
|
| 185 |
|
| 186 |
def _load_checkpoint(model, ckpt_path):
|
| 187 |
if ckpt_path is not None:
|
| 188 |
+
|
| 189 |
+
if _is_url(ckpt_path):
|
| 190 |
+
sd = torch.hub.load_state_dict_from_url(ckpt_path, map_location="cpu", weights_only=True)["model"]
|
| 191 |
+
elif os.path.exists(ckpt_path):
|
| 192 |
+
sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
|
| 193 |
+
else:
|
| 194 |
+
raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
|
| 195 |
+
|
| 196 |
missing_keys, unexpected_keys = model.load_state_dict(sd)
|
| 197 |
if missing_keys:
|
| 198 |
logging.error(missing_keys)
|
|
|
|
| 200 |
if unexpected_keys:
|
| 201 |
logging.error(unexpected_keys)
|
| 202 |
raise RuntimeError()
|
| 203 |
+
|
| 204 |
logging.info("Loaded checkpoint sucessfully")
|
{sam2 → bboxmaskpose/sam2}/colorblind.py
RENAMED
|
@@ -1,7 +1,10 @@
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
Adapted from "The Color Blind Simulation function" by Matthew Wickline
|
| 3 |
and the Human - Computer Interaction Resource Network (http://hcirn.com/), 2000 - 2001.
|
| 4 |
"""
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
|
| 7 |
rBlind = {
|
|
@@ -261,16 +264,13 @@ def simulate_colors(colors, colorblind_type="Deuteranomaly", one_row=None, show=
|
|
| 261 |
:return:
|
| 262 |
"""
|
| 263 |
import matplotlib.pyplot as plt
|
| 264 |
-
|
| 265 |
from distinctipy import distinctipy
|
| 266 |
|
| 267 |
filtered_colors = [colorblind_filter(color, colorblind_type) for color in colors]
|
| 268 |
|
| 269 |
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
|
| 270 |
|
| 271 |
-
distinctipy.color_swatch(
|
| 272 |
-
colors, ax=axes[0], one_row=one_row, title="Viewed with Normal Sight"
|
| 273 |
-
)
|
| 274 |
|
| 275 |
distinctipy.color_swatch(
|
| 276 |
filtered_colors,
|
|
@@ -324,30 +324,22 @@ def simulate_clusters(
|
|
| 324 |
"""
|
| 325 |
import matplotlib.pyplot as plt
|
| 326 |
import pandas as pd
|
| 327 |
-
|
| 328 |
from distinctipy import distinctipy
|
| 329 |
|
| 330 |
if dataset not in ("s1", "s2", "s3", "s4", "a1", "a2", "a3", "b1"):
|
| 331 |
raise ValueError("dataset must be s1, s2, s3, s4, a1, a2, a3 or b1")
|
| 332 |
|
| 333 |
-
URL =
|
| 334 |
-
"https://raw.githubusercontent.com/alan-turing-institute/distinctipy/"
|
| 335 |
-
"main/distinctipy/datasets/"
|
| 336 |
-
)
|
| 337 |
df = pd.read_csv(URL + dataset + ".csv")
|
| 338 |
|
| 339 |
if colorblind_distinct:
|
| 340 |
-
orig_colors = distinctipy.get_colors(
|
| 341 |
-
df["cluster"].nunique(), colorblind_type=colorblind_type
|
| 342 |
-
)
|
| 343 |
else:
|
| 344 |
orig_colors = distinctipy.get_colors(df["cluster"].nunique())
|
| 345 |
|
| 346 |
orig_cmap = distinctipy.get_colormap(orig_colors)
|
| 347 |
|
| 348 |
-
filtered_colors = [
|
| 349 |
-
colorblind_filter(color, colorblind_type) for color in orig_colors
|
| 350 |
-
]
|
| 351 |
filtered_cmap = distinctipy.get_colormap(filtered_colors)
|
| 352 |
|
| 353 |
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
|
|
@@ -376,4 +368,4 @@ def _main():
|
|
| 376 |
|
| 377 |
|
| 378 |
if __name__ == "__main__":
|
| 379 |
-
_main()
|
|
|
|
| 1 |
+
# Adapted from the distinctipy repository (https://github.com/alan-turing-institute/distinctipy).
|
| 2 |
+
# Original authors: distinctipy contributors. Included with minor modifications.
|
| 3 |
"""
|
| 4 |
Adapted from "The Color Blind Simulation function" by Matthew Wickline
|
| 5 |
and the Human - Computer Interaction Resource Network (http://hcirn.com/), 2000 - 2001.
|
| 6 |
"""
|
| 7 |
+
|
| 8 |
import numpy as np
|
| 9 |
|
| 10 |
rBlind = {
|
|
|
|
| 264 |
:return:
|
| 265 |
"""
|
| 266 |
import matplotlib.pyplot as plt
|
|
|
|
| 267 |
from distinctipy import distinctipy
|
| 268 |
|
| 269 |
filtered_colors = [colorblind_filter(color, colorblind_type) for color in colors]
|
| 270 |
|
| 271 |
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
|
| 272 |
|
| 273 |
+
distinctipy.color_swatch(colors, ax=axes[0], one_row=one_row, title="Viewed with Normal Sight")
|
|
|
|
|
|
|
| 274 |
|
| 275 |
distinctipy.color_swatch(
|
| 276 |
filtered_colors,
|
|
|
|
| 324 |
"""
|
| 325 |
import matplotlib.pyplot as plt
|
| 326 |
import pandas as pd
|
|
|
|
| 327 |
from distinctipy import distinctipy
|
| 328 |
|
| 329 |
if dataset not in ("s1", "s2", "s3", "s4", "a1", "a2", "a3", "b1"):
|
| 330 |
raise ValueError("dataset must be s1, s2, s3, s4, a1, a2, a3 or b1")
|
| 331 |
|
| 332 |
+
URL = "https://raw.githubusercontent.com/alan-turing-institute/distinctipy/" "main/distinctipy/datasets/"
|
|
|
|
|
|
|
|
|
|
| 333 |
df = pd.read_csv(URL + dataset + ".csv")
|
| 334 |
|
| 335 |
if colorblind_distinct:
|
| 336 |
+
orig_colors = distinctipy.get_colors(df["cluster"].nunique(), colorblind_type=colorblind_type)
|
|
|
|
|
|
|
| 337 |
else:
|
| 338 |
orig_colors = distinctipy.get_colors(df["cluster"].nunique())
|
| 339 |
|
| 340 |
orig_cmap = distinctipy.get_colormap(orig_colors)
|
| 341 |
|
| 342 |
+
filtered_colors = [colorblind_filter(color, colorblind_type) for color in orig_colors]
|
|
|
|
|
|
|
| 343 |
filtered_cmap = distinctipy.get_colormap(filtered_colors)
|
| 344 |
|
| 345 |
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
|
|
|
|
| 368 |
|
| 369 |
|
| 370 |
if __name__ == "__main__":
|
| 371 |
+
_main()
|
bboxmaskpose/sam2/configs/sam-pose2seg/sam-pose2seg_hiera_b+.yaml
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: bboxmaskpose.sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 112
|
| 12 |
+
num_heads: 2
|
| 13 |
+
neck:
|
| 14 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
|
| 15 |
+
position_encoding:
|
| 16 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 17 |
+
num_pos_feats: 256
|
| 18 |
+
normalize: true
|
| 19 |
+
scale: null
|
| 20 |
+
temperature: 10000
|
| 21 |
+
d_model: 256
|
| 22 |
+
backbone_channel_list: [896, 448, 224, 112]
|
| 23 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 24 |
+
fpn_interp_model: nearest
|
| 25 |
+
|
| 26 |
+
memory_attention:
|
| 27 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
|
| 28 |
+
d_model: 256
|
| 29 |
+
pos_enc_at_input: true
|
| 30 |
+
layer:
|
| 31 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 32 |
+
activation: relu
|
| 33 |
+
dim_feedforward: 2048
|
| 34 |
+
dropout: 0.1
|
| 35 |
+
pos_enc_at_attn: false
|
| 36 |
+
self_attention:
|
| 37 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 38 |
+
rope_theta: 10000.0
|
| 39 |
+
feat_sizes: [64, 64]
|
| 40 |
+
embedding_dim: 256
|
| 41 |
+
num_heads: 1
|
| 42 |
+
downsample_rate: 1
|
| 43 |
+
dropout: 0.1
|
| 44 |
+
d_model: 256
|
| 45 |
+
pos_enc_at_cross_attn_keys: true
|
| 46 |
+
pos_enc_at_cross_attn_queries: false
|
| 47 |
+
cross_attention:
|
| 48 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 49 |
+
rope_theta: 10000.0
|
| 50 |
+
feat_sizes: [64, 64]
|
| 51 |
+
rope_k_repeat: True
|
| 52 |
+
embedding_dim: 256
|
| 53 |
+
num_heads: 1
|
| 54 |
+
downsample_rate: 1
|
| 55 |
+
dropout: 0.1
|
| 56 |
+
kv_in_dim: 64
|
| 57 |
+
num_layers: 4
|
| 58 |
+
|
| 59 |
+
memory_encoder:
|
| 60 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
|
| 61 |
+
out_dim: 64
|
| 62 |
+
position_encoding:
|
| 63 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 64 |
+
num_pos_feats: 64
|
| 65 |
+
normalize: true
|
| 66 |
+
scale: null
|
| 67 |
+
temperature: 10000
|
| 68 |
+
mask_downsampler:
|
| 69 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
|
| 70 |
+
kernel_size: 3
|
| 71 |
+
stride: 2
|
| 72 |
+
padding: 1
|
| 73 |
+
fuser:
|
| 74 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
|
| 75 |
+
layer:
|
| 76 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
|
| 77 |
+
dim: 256
|
| 78 |
+
kernel_size: 7
|
| 79 |
+
padding: 3
|
| 80 |
+
layer_scale_init_value: 1e-6
|
| 81 |
+
use_dwconv: True # depth-wise convs
|
| 82 |
+
num_layers: 2
|
| 83 |
+
|
| 84 |
+
num_maskmem: 7
|
| 85 |
+
image_size: 1024
|
| 86 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 87 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 88 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 89 |
+
use_mask_input_as_output_without_sam: true
|
| 90 |
+
# Memory
|
| 91 |
+
directly_add_no_mem_embed: true
|
| 92 |
+
no_obj_embed_spatial: true
|
| 93 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 94 |
+
use_high_res_features_in_sam: true
|
| 95 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 96 |
+
multimask_output_in_sam: true
|
| 97 |
+
# SAM heads
|
| 98 |
+
iou_prediction_use_sigmoid: True
|
| 99 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 100 |
+
use_obj_ptrs_in_encoder: true
|
| 101 |
+
add_tpos_enc_to_obj_ptrs: true
|
| 102 |
+
proj_tpos_enc_in_obj_ptrs: true
|
| 103 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
| 104 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 105 |
+
# object occlusion prediction
|
| 106 |
+
pred_obj_scores: true
|
| 107 |
+
pred_obj_scores_mlp: true
|
| 108 |
+
fixed_no_obj_ptr: true
|
| 109 |
+
# multimask tracking settings
|
| 110 |
+
multimask_output_for_tracking: true
|
| 111 |
+
use_multimask_token_for_obj_ptr: true
|
| 112 |
+
multimask_min_pt_num: 0
|
| 113 |
+
multimask_max_pt_num: 1
|
| 114 |
+
use_mlp_for_obj_ptr_proj: true
|
| 115 |
+
|
| 116 |
+
n_kpts_encoder: 8
|
| 117 |
+
# Compilation flag
|
| 118 |
+
# compile_image_encoder: False
|
{sam2 → bboxmaskpose/sam2}/configs/sam2.1/sam2.1_hiera_b+.yaml
RENAMED
|
@@ -2,18 +2,18 @@
|
|
| 2 |
|
| 3 |
# Model
|
| 4 |
model:
|
| 5 |
-
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
image_encoder:
|
| 7 |
-
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
scalp: 1
|
| 9 |
trunk:
|
| 10 |
-
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
embed_dim: 112
|
| 12 |
num_heads: 2
|
| 13 |
neck:
|
| 14 |
-
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 15 |
position_encoding:
|
| 16 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 17 |
num_pos_feats: 256
|
| 18 |
normalize: true
|
| 19 |
scale: null
|
|
@@ -24,17 +24,17 @@ model:
|
|
| 24 |
fpn_interp_model: nearest
|
| 25 |
|
| 26 |
memory_attention:
|
| 27 |
-
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 28 |
d_model: 256
|
| 29 |
pos_enc_at_input: true
|
| 30 |
layer:
|
| 31 |
-
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 32 |
activation: relu
|
| 33 |
dim_feedforward: 2048
|
| 34 |
dropout: 0.1
|
| 35 |
pos_enc_at_attn: false
|
| 36 |
self_attention:
|
| 37 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 38 |
rope_theta: 10000.0
|
| 39 |
feat_sizes: [64, 64]
|
| 40 |
embedding_dim: 256
|
|
@@ -45,7 +45,7 @@ model:
|
|
| 45 |
pos_enc_at_cross_attn_keys: true
|
| 46 |
pos_enc_at_cross_attn_queries: false
|
| 47 |
cross_attention:
|
| 48 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 49 |
rope_theta: 10000.0
|
| 50 |
feat_sizes: [64, 64]
|
| 51 |
rope_k_repeat: True
|
|
@@ -57,23 +57,23 @@ model:
|
|
| 57 |
num_layers: 4
|
| 58 |
|
| 59 |
memory_encoder:
|
| 60 |
-
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 61 |
out_dim: 64
|
| 62 |
position_encoding:
|
| 63 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 64 |
num_pos_feats: 64
|
| 65 |
normalize: true
|
| 66 |
scale: null
|
| 67 |
temperature: 10000
|
| 68 |
mask_downsampler:
|
| 69 |
-
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 70 |
kernel_size: 3
|
| 71 |
stride: 2
|
| 72 |
padding: 1
|
| 73 |
fuser:
|
| 74 |
-
_target_: sam2.modeling.memory_encoder.Fuser
|
| 75 |
layer:
|
| 76 |
-
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 77 |
dim: 256
|
| 78 |
kernel_size: 7
|
| 79 |
padding: 3
|
|
|
|
| 2 |
|
| 3 |
# Model
|
| 4 |
model:
|
| 5 |
+
_target_: bboxmaskpose.sam2.modeling.sam2_base.SAM2Base
|
| 6 |
image_encoder:
|
| 7 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
scalp: 1
|
| 9 |
trunk:
|
| 10 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
embed_dim: 112
|
| 12 |
num_heads: 2
|
| 13 |
neck:
|
| 14 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
|
| 15 |
position_encoding:
|
| 16 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 17 |
num_pos_feats: 256
|
| 18 |
normalize: true
|
| 19 |
scale: null
|
|
|
|
| 24 |
fpn_interp_model: nearest
|
| 25 |
|
| 26 |
memory_attention:
|
| 27 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
|
| 28 |
d_model: 256
|
| 29 |
pos_enc_at_input: true
|
| 30 |
layer:
|
| 31 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 32 |
activation: relu
|
| 33 |
dim_feedforward: 2048
|
| 34 |
dropout: 0.1
|
| 35 |
pos_enc_at_attn: false
|
| 36 |
self_attention:
|
| 37 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 38 |
rope_theta: 10000.0
|
| 39 |
feat_sizes: [64, 64]
|
| 40 |
embedding_dim: 256
|
|
|
|
| 45 |
pos_enc_at_cross_attn_keys: true
|
| 46 |
pos_enc_at_cross_attn_queries: false
|
| 47 |
cross_attention:
|
| 48 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 49 |
rope_theta: 10000.0
|
| 50 |
feat_sizes: [64, 64]
|
| 51 |
rope_k_repeat: True
|
|
|
|
| 57 |
num_layers: 4
|
| 58 |
|
| 59 |
memory_encoder:
|
| 60 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
|
| 61 |
out_dim: 64
|
| 62 |
position_encoding:
|
| 63 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 64 |
num_pos_feats: 64
|
| 65 |
normalize: true
|
| 66 |
scale: null
|
| 67 |
temperature: 10000
|
| 68 |
mask_downsampler:
|
| 69 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
|
| 70 |
kernel_size: 3
|
| 71 |
stride: 2
|
| 72 |
padding: 1
|
| 73 |
fuser:
|
| 74 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
|
| 75 |
layer:
|
| 76 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
|
| 77 |
dim: 256
|
| 78 |
kernel_size: 7
|
| 79 |
padding: 3
|
{sam2 → bboxmaskpose/sam2}/configs/sam2.1/sam2.1_hiera_l.yaml
RENAMED
|
@@ -2,12 +2,12 @@
|
|
| 2 |
|
| 3 |
# Model
|
| 4 |
model:
|
| 5 |
-
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
image_encoder:
|
| 7 |
-
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
scalp: 1
|
| 9 |
trunk:
|
| 10 |
-
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
embed_dim: 144
|
| 12 |
num_heads: 2
|
| 13 |
stages: [2, 6, 36, 4]
|
|
@@ -15,9 +15,9 @@ model:
|
|
| 15 |
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
window_spec: [8, 4, 16, 8]
|
| 17 |
neck:
|
| 18 |
-
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 19 |
position_encoding:
|
| 20 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 21 |
num_pos_feats: 256
|
| 22 |
normalize: true
|
| 23 |
scale: null
|
|
@@ -28,17 +28,17 @@ model:
|
|
| 28 |
fpn_interp_model: nearest
|
| 29 |
|
| 30 |
memory_attention:
|
| 31 |
-
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 32 |
d_model: 256
|
| 33 |
pos_enc_at_input: true
|
| 34 |
layer:
|
| 35 |
-
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 36 |
activation: relu
|
| 37 |
dim_feedforward: 2048
|
| 38 |
dropout: 0.1
|
| 39 |
pos_enc_at_attn: false
|
| 40 |
self_attention:
|
| 41 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 42 |
rope_theta: 10000.0
|
| 43 |
feat_sizes: [64, 64]
|
| 44 |
embedding_dim: 256
|
|
@@ -49,7 +49,7 @@ model:
|
|
| 49 |
pos_enc_at_cross_attn_keys: true
|
| 50 |
pos_enc_at_cross_attn_queries: false
|
| 51 |
cross_attention:
|
| 52 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 53 |
rope_theta: 10000.0
|
| 54 |
feat_sizes: [64, 64]
|
| 55 |
rope_k_repeat: True
|
|
@@ -61,23 +61,23 @@ model:
|
|
| 61 |
num_layers: 4
|
| 62 |
|
| 63 |
memory_encoder:
|
| 64 |
-
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 65 |
out_dim: 64
|
| 66 |
position_encoding:
|
| 67 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 68 |
num_pos_feats: 64
|
| 69 |
normalize: true
|
| 70 |
scale: null
|
| 71 |
temperature: 10000
|
| 72 |
mask_downsampler:
|
| 73 |
-
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 74 |
kernel_size: 3
|
| 75 |
stride: 2
|
| 76 |
padding: 1
|
| 77 |
fuser:
|
| 78 |
-
_target_: sam2.modeling.memory_encoder.Fuser
|
| 79 |
layer:
|
| 80 |
-
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 81 |
dim: 256
|
| 82 |
kernel_size: 7
|
| 83 |
padding: 3
|
|
|
|
| 2 |
|
| 3 |
# Model
|
| 4 |
model:
|
| 5 |
+
_target_: bboxmaskpose.sam2.modeling.sam2_base.SAM2Base
|
| 6 |
image_encoder:
|
| 7 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
scalp: 1
|
| 9 |
trunk:
|
| 10 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
embed_dim: 144
|
| 12 |
num_heads: 2
|
| 13 |
stages: [2, 6, 36, 4]
|
|
|
|
| 15 |
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
window_spec: [8, 4, 16, 8]
|
| 17 |
neck:
|
| 18 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
|
| 19 |
position_encoding:
|
| 20 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 21 |
num_pos_feats: 256
|
| 22 |
normalize: true
|
| 23 |
scale: null
|
|
|
|
| 28 |
fpn_interp_model: nearest
|
| 29 |
|
| 30 |
memory_attention:
|
| 31 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
|
| 32 |
d_model: 256
|
| 33 |
pos_enc_at_input: true
|
| 34 |
layer:
|
| 35 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 36 |
activation: relu
|
| 37 |
dim_feedforward: 2048
|
| 38 |
dropout: 0.1
|
| 39 |
pos_enc_at_attn: false
|
| 40 |
self_attention:
|
| 41 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 42 |
rope_theta: 10000.0
|
| 43 |
feat_sizes: [64, 64]
|
| 44 |
embedding_dim: 256
|
|
|
|
| 49 |
pos_enc_at_cross_attn_keys: true
|
| 50 |
pos_enc_at_cross_attn_queries: false
|
| 51 |
cross_attention:
|
| 52 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 53 |
rope_theta: 10000.0
|
| 54 |
feat_sizes: [64, 64]
|
| 55 |
rope_k_repeat: True
|
|
|
|
| 61 |
num_layers: 4
|
| 62 |
|
| 63 |
memory_encoder:
|
| 64 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
|
| 65 |
out_dim: 64
|
| 66 |
position_encoding:
|
| 67 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 68 |
num_pos_feats: 64
|
| 69 |
normalize: true
|
| 70 |
scale: null
|
| 71 |
temperature: 10000
|
| 72 |
mask_downsampler:
|
| 73 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
|
| 74 |
kernel_size: 3
|
| 75 |
stride: 2
|
| 76 |
padding: 1
|
| 77 |
fuser:
|
| 78 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
|
| 79 |
layer:
|
| 80 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
|
| 81 |
dim: 256
|
| 82 |
kernel_size: 7
|
| 83 |
padding: 3
|
{sam2 → bboxmaskpose/sam2}/configs/sam2.1/sam2.1_hiera_s.yaml
RENAMED
|
@@ -2,21 +2,21 @@
|
|
| 2 |
|
| 3 |
# Model
|
| 4 |
model:
|
| 5 |
-
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
image_encoder:
|
| 7 |
-
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
scalp: 1
|
| 9 |
trunk:
|
| 10 |
-
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
embed_dim: 96
|
| 12 |
num_heads: 1
|
| 13 |
stages: [1, 2, 11, 2]
|
| 14 |
global_att_blocks: [7, 10, 13]
|
| 15 |
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
neck:
|
| 17 |
-
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
position_encoding:
|
| 19 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
num_pos_feats: 256
|
| 21 |
normalize: true
|
| 22 |
scale: null
|
|
@@ -27,17 +27,17 @@ model:
|
|
| 27 |
fpn_interp_model: nearest
|
| 28 |
|
| 29 |
memory_attention:
|
| 30 |
-
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
d_model: 256
|
| 32 |
pos_enc_at_input: true
|
| 33 |
layer:
|
| 34 |
-
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
activation: relu
|
| 36 |
dim_feedforward: 2048
|
| 37 |
dropout: 0.1
|
| 38 |
pos_enc_at_attn: false
|
| 39 |
self_attention:
|
| 40 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
rope_theta: 10000.0
|
| 42 |
feat_sizes: [64, 64]
|
| 43 |
embedding_dim: 256
|
|
@@ -48,7 +48,7 @@ model:
|
|
| 48 |
pos_enc_at_cross_attn_keys: true
|
| 49 |
pos_enc_at_cross_attn_queries: false
|
| 50 |
cross_attention:
|
| 51 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
rope_theta: 10000.0
|
| 53 |
feat_sizes: [64, 64]
|
| 54 |
rope_k_repeat: True
|
|
@@ -60,23 +60,23 @@ model:
|
|
| 60 |
num_layers: 4
|
| 61 |
|
| 62 |
memory_encoder:
|
| 63 |
-
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 64 |
out_dim: 64
|
| 65 |
position_encoding:
|
| 66 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 67 |
num_pos_feats: 64
|
| 68 |
normalize: true
|
| 69 |
scale: null
|
| 70 |
temperature: 10000
|
| 71 |
mask_downsampler:
|
| 72 |
-
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 73 |
kernel_size: 3
|
| 74 |
stride: 2
|
| 75 |
padding: 1
|
| 76 |
fuser:
|
| 77 |
-
_target_: sam2.modeling.memory_encoder.Fuser
|
| 78 |
layer:
|
| 79 |
-
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 80 |
dim: 256
|
| 81 |
kernel_size: 7
|
| 82 |
padding: 3
|
|
|
|
| 2 |
|
| 3 |
# Model
|
| 4 |
model:
|
| 5 |
+
_target_: bboxmaskpose.sam2.modeling.sam2_base.SAM2Base
|
| 6 |
image_encoder:
|
| 7 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
scalp: 1
|
| 9 |
trunk:
|
| 10 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
embed_dim: 96
|
| 12 |
num_heads: 1
|
| 13 |
stages: [1, 2, 11, 2]
|
| 14 |
global_att_blocks: [7, 10, 13]
|
| 15 |
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
neck:
|
| 17 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
position_encoding:
|
| 19 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
num_pos_feats: 256
|
| 21 |
normalize: true
|
| 22 |
scale: null
|
|
|
|
| 27 |
fpn_interp_model: nearest
|
| 28 |
|
| 29 |
memory_attention:
|
| 30 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
d_model: 256
|
| 32 |
pos_enc_at_input: true
|
| 33 |
layer:
|
| 34 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
activation: relu
|
| 36 |
dim_feedforward: 2048
|
| 37 |
dropout: 0.1
|
| 38 |
pos_enc_at_attn: false
|
| 39 |
self_attention:
|
| 40 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
rope_theta: 10000.0
|
| 42 |
feat_sizes: [64, 64]
|
| 43 |
embedding_dim: 256
|
|
|
|
| 48 |
pos_enc_at_cross_attn_keys: true
|
| 49 |
pos_enc_at_cross_attn_queries: false
|
| 50 |
cross_attention:
|
| 51 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
rope_theta: 10000.0
|
| 53 |
feat_sizes: [64, 64]
|
| 54 |
rope_k_repeat: True
|
|
|
|
| 60 |
num_layers: 4
|
| 61 |
|
| 62 |
memory_encoder:
|
| 63 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
|
| 64 |
out_dim: 64
|
| 65 |
position_encoding:
|
| 66 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 67 |
num_pos_feats: 64
|
| 68 |
normalize: true
|
| 69 |
scale: null
|
| 70 |
temperature: 10000
|
| 71 |
mask_downsampler:
|
| 72 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
|
| 73 |
kernel_size: 3
|
| 74 |
stride: 2
|
| 75 |
padding: 1
|
| 76 |
fuser:
|
| 77 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
|
| 78 |
layer:
|
| 79 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
|
| 80 |
dim: 256
|
| 81 |
kernel_size: 7
|
| 82 |
padding: 3
|
{sam2 → bboxmaskpose/sam2}/configs/sam2.1/sam2.1_hiera_t.yaml
RENAMED
|
@@ -2,21 +2,21 @@
|
|
| 2 |
|
| 3 |
# Model
|
| 4 |
model:
|
| 5 |
-
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
image_encoder:
|
| 7 |
-
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
scalp: 1
|
| 9 |
trunk:
|
| 10 |
-
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
embed_dim: 96
|
| 12 |
num_heads: 1
|
| 13 |
stages: [1, 2, 7, 2]
|
| 14 |
global_att_blocks: [5, 7, 9]
|
| 15 |
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
neck:
|
| 17 |
-
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
position_encoding:
|
| 19 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
num_pos_feats: 256
|
| 21 |
normalize: true
|
| 22 |
scale: null
|
|
@@ -27,17 +27,17 @@ model:
|
|
| 27 |
fpn_interp_model: nearest
|
| 28 |
|
| 29 |
memory_attention:
|
| 30 |
-
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
d_model: 256
|
| 32 |
pos_enc_at_input: true
|
| 33 |
layer:
|
| 34 |
-
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
activation: relu
|
| 36 |
dim_feedforward: 2048
|
| 37 |
dropout: 0.1
|
| 38 |
pos_enc_at_attn: false
|
| 39 |
self_attention:
|
| 40 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
rope_theta: 10000.0
|
| 42 |
feat_sizes: [64, 64]
|
| 43 |
embedding_dim: 256
|
|
@@ -48,7 +48,7 @@ model:
|
|
| 48 |
pos_enc_at_cross_attn_keys: true
|
| 49 |
pos_enc_at_cross_attn_queries: false
|
| 50 |
cross_attention:
|
| 51 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
rope_theta: 10000.0
|
| 53 |
feat_sizes: [64, 64]
|
| 54 |
rope_k_repeat: True
|
|
@@ -60,23 +60,23 @@ model:
|
|
| 60 |
num_layers: 4
|
| 61 |
|
| 62 |
memory_encoder:
|
| 63 |
-
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 64 |
out_dim: 64
|
| 65 |
position_encoding:
|
| 66 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 67 |
num_pos_feats: 64
|
| 68 |
normalize: true
|
| 69 |
scale: null
|
| 70 |
temperature: 10000
|
| 71 |
mask_downsampler:
|
| 72 |
-
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 73 |
kernel_size: 3
|
| 74 |
stride: 2
|
| 75 |
padding: 1
|
| 76 |
fuser:
|
| 77 |
-
_target_: sam2.modeling.memory_encoder.Fuser
|
| 78 |
layer:
|
| 79 |
-
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 80 |
dim: 256
|
| 81 |
kernel_size: 7
|
| 82 |
padding: 3
|
|
|
|
| 2 |
|
| 3 |
# Model
|
| 4 |
model:
|
| 5 |
+
_target_: bboxmaskpose.sam2.modeling.sam2_base.SAM2Base
|
| 6 |
image_encoder:
|
| 7 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
scalp: 1
|
| 9 |
trunk:
|
| 10 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
embed_dim: 96
|
| 12 |
num_heads: 1
|
| 13 |
stages: [1, 2, 7, 2]
|
| 14 |
global_att_blocks: [5, 7, 9]
|
| 15 |
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
neck:
|
| 17 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
position_encoding:
|
| 19 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
num_pos_feats: 256
|
| 21 |
normalize: true
|
| 22 |
scale: null
|
|
|
|
| 27 |
fpn_interp_model: nearest
|
| 28 |
|
| 29 |
memory_attention:
|
| 30 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
d_model: 256
|
| 32 |
pos_enc_at_input: true
|
| 33 |
layer:
|
| 34 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
activation: relu
|
| 36 |
dim_feedforward: 2048
|
| 37 |
dropout: 0.1
|
| 38 |
pos_enc_at_attn: false
|
| 39 |
self_attention:
|
| 40 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
rope_theta: 10000.0
|
| 42 |
feat_sizes: [64, 64]
|
| 43 |
embedding_dim: 256
|
|
|
|
| 48 |
pos_enc_at_cross_attn_keys: true
|
| 49 |
pos_enc_at_cross_attn_queries: false
|
| 50 |
cross_attention:
|
| 51 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
rope_theta: 10000.0
|
| 53 |
feat_sizes: [64, 64]
|
| 54 |
rope_k_repeat: True
|
|
|
|
| 60 |
num_layers: 4
|
| 61 |
|
| 62 |
memory_encoder:
|
| 63 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
|
| 64 |
out_dim: 64
|
| 65 |
position_encoding:
|
| 66 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 67 |
num_pos_feats: 64
|
| 68 |
normalize: true
|
| 69 |
scale: null
|
| 70 |
temperature: 10000
|
| 71 |
mask_downsampler:
|
| 72 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
|
| 73 |
kernel_size: 3
|
| 74 |
stride: 2
|
| 75 |
padding: 1
|
| 76 |
fuser:
|
| 77 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
|
| 78 |
layer:
|
| 79 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
|
| 80 |
dim: 256
|
| 81 |
kernel_size: 7
|
| 82 |
padding: 3
|
bboxmaskpose/sam2/configs/sam2.1_training/sam2.1_hiera_b+_COCO+CIHP_finetune_sam-pose2seg.yaml
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
scratch:
|
| 4 |
+
resolution: 1024
|
| 5 |
+
train_batch_size: 1
|
| 6 |
+
num_train_workers: 10
|
| 7 |
+
num_frames: 1
|
| 8 |
+
max_num_objects: 1
|
| 9 |
+
base_lr: 5.0e-6
|
| 10 |
+
vision_lr: 3.0e-06
|
| 11 |
+
phases_per_epoch: 1
|
| 12 |
+
num_epochs: 15
|
| 13 |
+
|
| 14 |
+
dataset:
|
| 15 |
+
# PATHS to Dataset
|
| 16 |
+
img_folder: path/to/dataset
|
| 17 |
+
gt_folder: path/to/dataset
|
| 18 |
+
multiplier: 2
|
| 19 |
+
|
| 20 |
+
# Video transforms
|
| 21 |
+
vos:
|
| 22 |
+
train_transforms:
|
| 23 |
+
- _target_: training.dataset.transforms.ComposeAPI
|
| 24 |
+
transforms:
|
| 25 |
+
- _target_: training.dataset.transforms.RandomHorizontalFlip
|
| 26 |
+
consistent_transform: True
|
| 27 |
+
- _target_: training.dataset.transforms.RandomAffine
|
| 28 |
+
degrees: 25
|
| 29 |
+
shear: 20
|
| 30 |
+
image_interpolation: bilinear
|
| 31 |
+
consistent_transform: True
|
| 32 |
+
- _target_: training.dataset.transforms.RandomResizeAPI
|
| 33 |
+
sizes: ${scratch.resolution}
|
| 34 |
+
square: true
|
| 35 |
+
consistent_transform: True
|
| 36 |
+
- _target_: training.dataset.transforms.ColorJitter
|
| 37 |
+
consistent_transform: True
|
| 38 |
+
brightness: 0.1
|
| 39 |
+
contrast: 0.03
|
| 40 |
+
saturation: 0.03
|
| 41 |
+
hue: null
|
| 42 |
+
- _target_: training.dataset.transforms.RandomGrayscale
|
| 43 |
+
p: 0.05
|
| 44 |
+
consistent_transform: True
|
| 45 |
+
- _target_: training.dataset.transforms.ColorJitter
|
| 46 |
+
consistent_transform: False
|
| 47 |
+
brightness: 0.1
|
| 48 |
+
contrast: 0.05
|
| 49 |
+
saturation: 0.05
|
| 50 |
+
hue: null
|
| 51 |
+
- _target_: training.dataset.transforms.ToTensorAPI
|
| 52 |
+
- _target_: training.dataset.transforms.NormalizeAPI
|
| 53 |
+
mean: [0.485, 0.456, 0.406]
|
| 54 |
+
std: [0.229, 0.224, 0.225]
|
| 55 |
+
|
| 56 |
+
trainer:
|
| 57 |
+
_target_: training.trainer.Trainer
|
| 58 |
+
mode: train_only # change to train ? (a.k.a. train + val)
|
| 59 |
+
max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
|
| 60 |
+
accelerator: cuda
|
| 61 |
+
seed_value: 123
|
| 62 |
+
|
| 63 |
+
model:
|
| 64 |
+
_target_: training.model.sam2.SAM2Train
|
| 65 |
+
image_encoder:
|
| 66 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 67 |
+
scalp: 1
|
| 68 |
+
trunk:
|
| 69 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 70 |
+
embed_dim: 112
|
| 71 |
+
num_heads: 2
|
| 72 |
+
drop_path_rate: 0.1
|
| 73 |
+
neck:
|
| 74 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 75 |
+
position_encoding:
|
| 76 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 77 |
+
num_pos_feats: 256
|
| 78 |
+
normalize: true
|
| 79 |
+
scale: null
|
| 80 |
+
temperature: 10000
|
| 81 |
+
d_model: 256
|
| 82 |
+
backbone_channel_list: [896, 448, 224, 112]
|
| 83 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 84 |
+
fpn_interp_model: nearest
|
| 85 |
+
|
| 86 |
+
memory_attention:
|
| 87 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 88 |
+
d_model: 256
|
| 89 |
+
pos_enc_at_input: true
|
| 90 |
+
layer:
|
| 91 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 92 |
+
activation: relu
|
| 93 |
+
dim_feedforward: 2048
|
| 94 |
+
dropout: 0.1
|
| 95 |
+
pos_enc_at_attn: false
|
| 96 |
+
self_attention:
|
| 97 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 98 |
+
rope_theta: 10000.0
|
| 99 |
+
feat_sizes: [64, 64]
|
| 100 |
+
embedding_dim: 256
|
| 101 |
+
num_heads: 1
|
| 102 |
+
downsample_rate: 1
|
| 103 |
+
dropout: 0.1
|
| 104 |
+
d_model: 256
|
| 105 |
+
pos_enc_at_cross_attn_keys: true
|
| 106 |
+
pos_enc_at_cross_attn_queries: false
|
| 107 |
+
cross_attention:
|
| 108 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 109 |
+
rope_theta: 10000.0
|
| 110 |
+
feat_sizes: [64, 64]
|
| 111 |
+
rope_k_repeat: True
|
| 112 |
+
embedding_dim: 256
|
| 113 |
+
num_heads: 1
|
| 114 |
+
downsample_rate: 1
|
| 115 |
+
dropout: 0.1
|
| 116 |
+
kv_in_dim: 64
|
| 117 |
+
num_layers: 4
|
| 118 |
+
|
| 119 |
+
memory_encoder:
|
| 120 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 121 |
+
out_dim: 64
|
| 122 |
+
position_encoding:
|
| 123 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 124 |
+
num_pos_feats: 64
|
| 125 |
+
normalize: true
|
| 126 |
+
scale: null
|
| 127 |
+
temperature: 10000
|
| 128 |
+
mask_downsampler:
|
| 129 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 130 |
+
kernel_size: 3
|
| 131 |
+
stride: 2
|
| 132 |
+
padding: 1
|
| 133 |
+
fuser:
|
| 134 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 135 |
+
layer:
|
| 136 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 137 |
+
dim: 256
|
| 138 |
+
kernel_size: 7
|
| 139 |
+
padding: 3
|
| 140 |
+
layer_scale_init_value: 1e-6
|
| 141 |
+
use_dwconv: True # depth-wise convs
|
| 142 |
+
num_layers: 2
|
| 143 |
+
|
| 144 |
+
num_maskmem: 7
|
| 145 |
+
image_size: ${scratch.resolution}
|
| 146 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 147 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 148 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 149 |
+
use_mask_input_as_output_without_sam: true
|
| 150 |
+
# Memory
|
| 151 |
+
directly_add_no_mem_embed: true
|
| 152 |
+
no_obj_embed_spatial: true
|
| 153 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 154 |
+
use_high_res_features_in_sam: true
|
| 155 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 156 |
+
multimask_output_in_sam: true
|
| 157 |
+
# SAM heads
|
| 158 |
+
iou_prediction_use_sigmoid: True
|
| 159 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 160 |
+
use_obj_ptrs_in_encoder: true
|
| 161 |
+
add_tpos_enc_to_obj_ptrs: true
|
| 162 |
+
proj_tpos_enc_in_obj_ptrs: true
|
| 163 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
| 164 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 165 |
+
# object occlusion prediction
|
| 166 |
+
pred_obj_scores: true
|
| 167 |
+
pred_obj_scores_mlp: true
|
| 168 |
+
fixed_no_obj_ptr: true
|
| 169 |
+
# multimask tracking settings
|
| 170 |
+
multimask_output_for_tracking: true
|
| 171 |
+
use_multimask_token_for_obj_ptr: false
|
| 172 |
+
multimask_min_pt_num: 0
|
| 173 |
+
multimask_max_pt_num: 1
|
| 174 |
+
use_mlp_for_obj_ptr_proj: true
|
| 175 |
+
|
| 176 |
+
n_kpts_encoder: 8
|
| 177 |
+
# Compilation flag
|
| 178 |
+
# compile_image_encoder: False
|
| 179 |
+
|
| 180 |
+
####### Training specific params #######
|
| 181 |
+
# box/point input and corrections
|
| 182 |
+
prob_to_use_pt_input_for_train: 1.0
|
| 183 |
+
prob_to_use_pt_input_for_eval: 0.0
|
| 184 |
+
prob_to_use_box_input_for_train: 0.0 # 0.5*0.5 = 0.25 prob to use box instead of points
|
| 185 |
+
prob_to_use_box_input_for_eval: 0.0
|
| 186 |
+
prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors
|
| 187 |
+
num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame)
|
| 188 |
+
num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame
|
| 189 |
+
rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2
|
| 190 |
+
add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
|
| 191 |
+
# maximum 2 initial conditioning frames
|
| 192 |
+
num_init_cond_frames_for_train: 2
|
| 193 |
+
rand_init_cond_frames_for_train: True # random 1~2
|
| 194 |
+
num_correction_pt_per_frame: 7 ## CHANGED
|
| 195 |
+
use_act_ckpt_iterative_pt_sampling: false
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
num_init_cond_frames_for_eval: 1 # only mask on the first frame
|
| 200 |
+
forward_backbone_per_frame_for_eval: True
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
data:
|
| 204 |
+
train:
|
| 205 |
+
_target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
|
| 206 |
+
phases_per_epoch: ${scratch.phases_per_epoch}
|
| 207 |
+
batch_sizes:
|
| 208 |
+
- ${scratch.train_batch_size}
|
| 209 |
+
|
| 210 |
+
datasets:
|
| 211 |
+
- _target_: training.dataset.utils.RepeatFactorWrapper
|
| 212 |
+
dataset:
|
| 213 |
+
_target_: training.dataset.utils.ConcatDataset
|
| 214 |
+
datasets:
|
| 215 |
+
- _target_: training.dataset.vos_dataset.VOSDataset
|
| 216 |
+
transforms: ${vos.train_transforms}
|
| 217 |
+
training: true
|
| 218 |
+
video_dataset:
|
| 219 |
+
_target_: training.dataset.vos_raw_dataset.SA1BRawDataset
|
| 220 |
+
img_folder: ${dataset.img_folder}
|
| 221 |
+
gt_folder: ${dataset.gt_folder}
|
| 222 |
+
# file_list_txt: ${dataset.file_list_txt}
|
| 223 |
+
sampler:
|
| 224 |
+
_target_: training.dataset.vos_sampler.RandomUniformSampler
|
| 225 |
+
num_frames: ${scratch.num_frames}
|
| 226 |
+
max_num_objects: ${scratch.max_num_objects}
|
| 227 |
+
multiplier: ${dataset.multiplier}
|
| 228 |
+
shuffle: True
|
| 229 |
+
num_workers: ${scratch.num_train_workers}
|
| 230 |
+
pin_memory: True
|
| 231 |
+
drop_last: True
|
| 232 |
+
collate_fn:
|
| 233 |
+
_target_: training.utils.data_utils.collate_fn
|
| 234 |
+
_partial_: true
|
| 235 |
+
dict_key: all
|
| 236 |
+
|
| 237 |
+
# val:
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
optim:
|
| 241 |
+
amp:
|
| 242 |
+
enabled: True
|
| 243 |
+
amp_dtype: bfloat16
|
| 244 |
+
|
| 245 |
+
optimizer:
|
| 246 |
+
_target_: torch.optim.AdamW
|
| 247 |
+
|
| 248 |
+
gradient_clip:
|
| 249 |
+
_target_: training.optimizer.GradientClipper
|
| 250 |
+
max_norm: 0.1
|
| 251 |
+
norm_type: 2
|
| 252 |
+
|
| 253 |
+
param_group_modifiers:
|
| 254 |
+
- _target_: training.optimizer.layer_decay_param_modifier
|
| 255 |
+
_partial_: True
|
| 256 |
+
layer_decay_value: 0.9
|
| 257 |
+
apply_to: 'image_encoder.trunk'
|
| 258 |
+
overrides:
|
| 259 |
+
- pattern: '*pos_embed*'
|
| 260 |
+
value: 1.0
|
| 261 |
+
|
| 262 |
+
options:
|
| 263 |
+
lr:
|
| 264 |
+
- scheduler:
|
| 265 |
+
_target_: fvcore.common.param_scheduler.CosineParamScheduler
|
| 266 |
+
start_value: ${scratch.base_lr}
|
| 267 |
+
end_value: ${divide:${scratch.base_lr},10}
|
| 268 |
+
- scheduler:
|
| 269 |
+
_target_: fvcore.common.param_scheduler.CosineParamScheduler
|
| 270 |
+
start_value: ${scratch.vision_lr}
|
| 271 |
+
end_value: ${divide:${scratch.vision_lr},10}
|
| 272 |
+
param_names:
|
| 273 |
+
- 'image_encoder.*'
|
| 274 |
+
weight_decay:
|
| 275 |
+
- scheduler:
|
| 276 |
+
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
|
| 277 |
+
value: 0.1
|
| 278 |
+
- scheduler:
|
| 279 |
+
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
|
| 280 |
+
value: 0.0
|
| 281 |
+
param_names:
|
| 282 |
+
- '*bias*'
|
| 283 |
+
module_cls_names: ['torch.nn.LayerNorm']
|
| 284 |
+
|
| 285 |
+
loss:
|
| 286 |
+
all:
|
| 287 |
+
_target_: training.loss_fns.MultiStepMultiMasksAndIous
|
| 288 |
+
weight_dict:
|
| 289 |
+
loss_mask: 20
|
| 290 |
+
loss_dice: 1
|
| 291 |
+
loss_iou: 1
|
| 292 |
+
loss_class: 1
|
| 293 |
+
supervise_all_iou: true
|
| 294 |
+
iou_use_l1_loss: true
|
| 295 |
+
pred_obj_scores: true
|
| 296 |
+
focal_gamma_obj_score: 0.0
|
| 297 |
+
focal_alpha_obj_score: -1.0
|
| 298 |
+
|
| 299 |
+
distributed:
|
| 300 |
+
backend: nccl
|
| 301 |
+
find_unused_parameters: True
|
| 302 |
+
|
| 303 |
+
logging:
|
| 304 |
+
tensorboard_writer:
|
| 305 |
+
_target_: training.utils.logger.make_tensorboard_logger
|
| 306 |
+
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
| 307 |
+
flush_secs: 120
|
| 308 |
+
should_log: True
|
| 309 |
+
log_dir: ${launcher.experiment_log_dir}/logs
|
| 310 |
+
log_freq: 10
|
| 311 |
+
|
| 312 |
+
# initialize from a SAM 2 checkpoint
|
| 313 |
+
checkpoint:
|
| 314 |
+
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
| 315 |
+
save_freq: 0 # 0 only last checkpoint is saved.
|
| 316 |
+
model_weight_initializer:
|
| 317 |
+
_partial_: True
|
| 318 |
+
_target_: training.utils.checkpoint_utils.load_state_dict_into_model
|
| 319 |
+
strict: True
|
| 320 |
+
ignore_unexpected_keys: null
|
| 321 |
+
ignore_missing_keys: null
|
| 322 |
+
|
| 323 |
+
state_dict:
|
| 324 |
+
_target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
|
| 325 |
+
checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt ## CHANGED - PATH to SAM 2.1 checkpoint
|
| 326 |
+
ckpt_state_dict_keys: ['model']
|
| 327 |
+
|
| 328 |
+
launcher:
|
| 329 |
+
num_nodes: 1
|
| 330 |
+
gpus_per_node: 8
|
| 331 |
+
experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
|
| 332 |
+
|
| 333 |
+
# SLURM args if running on a cluster
|
| 334 |
+
submitit:
|
| 335 |
+
partition: null
|
| 336 |
+
account: null
|
| 337 |
+
qos: null
|
| 338 |
+
cpus_per_task: 10
|
| 339 |
+
use_cluster: false
|
| 340 |
+
timeout_hour: 24
|
| 341 |
+
name: null
|
| 342 |
+
port_range: [10000, 65000]
|
| 343 |
+
|
{sam2 → bboxmaskpose/sam2}/configs/sam2.1_training/sam2.1_hiera_b+_COCO_1024_prompt.yaml
RENAMED
|
@@ -11,14 +11,6 @@ scratch:
|
|
| 11 |
phases_per_epoch: 1
|
| 12 |
num_epochs: 40
|
| 13 |
|
| 14 |
-
dataset:
|
| 15 |
-
# PATHS to Dataset
|
| 16 |
-
img_folder: /mnt/personal/purkrmir/data/COCO/original/train2017/ # PATH to MOSE JPEGImages folder
|
| 17 |
-
gt_folder: /mnt/personal/purkrmir/data/COCO/original/annotations/ # PATH to MOSE Annotations folder
|
| 18 |
-
# img_folder: /datagrid/personal/purkrmir/data/COCO/original/val2017/ # PATH to MOSE JPEGImages folder
|
| 19 |
-
# gt_folder: /datagrid/personal/purkrmir/data/COCO/original/annotations/ # PATH to MOSE Annotations folder
|
| 20 |
-
file_list_txt: null # Optional PATH to filelist containing a subset of videos to be used for training
|
| 21 |
-
multiplier: 2
|
| 22 |
|
| 23 |
# Video transforms
|
| 24 |
vos:
|
|
@@ -69,19 +61,19 @@ trainer:
|
|
| 69 |
unfreeze_decoder: False
|
| 70 |
|
| 71 |
model:
|
| 72 |
-
_target_: training.model.sam2.
|
| 73 |
image_encoder:
|
| 74 |
-
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 75 |
scalp: 1
|
| 76 |
trunk:
|
| 77 |
-
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 78 |
embed_dim: 112
|
| 79 |
num_heads: 2
|
| 80 |
drop_path_rate: 0.1
|
| 81 |
neck:
|
| 82 |
-
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 83 |
position_encoding:
|
| 84 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 85 |
num_pos_feats: 256
|
| 86 |
normalize: true
|
| 87 |
scale: null
|
|
@@ -92,17 +84,17 @@ trainer:
|
|
| 92 |
fpn_interp_model: nearest
|
| 93 |
|
| 94 |
memory_attention:
|
| 95 |
-
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 96 |
d_model: 256
|
| 97 |
pos_enc_at_input: true
|
| 98 |
layer:
|
| 99 |
-
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 100 |
activation: relu
|
| 101 |
dim_feedforward: 2048
|
| 102 |
dropout: 0.1
|
| 103 |
pos_enc_at_attn: false
|
| 104 |
self_attention:
|
| 105 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 106 |
rope_theta: 10000.0
|
| 107 |
feat_sizes: [64, 64]
|
| 108 |
embedding_dim: 256
|
|
@@ -113,7 +105,7 @@ trainer:
|
|
| 113 |
pos_enc_at_cross_attn_keys: true
|
| 114 |
pos_enc_at_cross_attn_queries: false
|
| 115 |
cross_attention:
|
| 116 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 117 |
rope_theta: 10000.0
|
| 118 |
feat_sizes: [64, 64]
|
| 119 |
rope_k_repeat: True
|
|
@@ -125,23 +117,23 @@ trainer:
|
|
| 125 |
num_layers: 4
|
| 126 |
|
| 127 |
memory_encoder:
|
| 128 |
-
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 129 |
out_dim: 64
|
| 130 |
position_encoding:
|
| 131 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 132 |
num_pos_feats: 64
|
| 133 |
normalize: true
|
| 134 |
scale: null
|
| 135 |
temperature: 10000
|
| 136 |
mask_downsampler:
|
| 137 |
-
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 138 |
kernel_size: 3
|
| 139 |
stride: 2
|
| 140 |
padding: 1
|
| 141 |
fuser:
|
| 142 |
-
_target_: sam2.modeling.memory_encoder.Fuser
|
| 143 |
layer:
|
| 144 |
-
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 145 |
dim: 256
|
| 146 |
kernel_size: 7
|
| 147 |
padding: 3
|
|
@@ -325,7 +317,7 @@ trainer:
|
|
| 325 |
|
| 326 |
state_dict:
|
| 327 |
_target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
|
| 328 |
-
checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
|
| 329 |
ckpt_state_dict_keys: ['model']
|
| 330 |
|
| 331 |
launcher:
|
|
|
|
| 11 |
phases_per_epoch: 1
|
| 12 |
num_epochs: 40
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
# Video transforms
|
| 16 |
vos:
|
|
|
|
| 61 |
unfreeze_decoder: False
|
| 62 |
|
| 63 |
model:
|
| 64 |
+
_target_: training.model.bboxmaskpose.sam2.sam2Train
|
| 65 |
image_encoder:
|
| 66 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 67 |
scalp: 1
|
| 68 |
trunk:
|
| 69 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
|
| 70 |
embed_dim: 112
|
| 71 |
num_heads: 2
|
| 72 |
drop_path_rate: 0.1
|
| 73 |
neck:
|
| 74 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
|
| 75 |
position_encoding:
|
| 76 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 77 |
num_pos_feats: 256
|
| 78 |
normalize: true
|
| 79 |
scale: null
|
|
|
|
| 84 |
fpn_interp_model: nearest
|
| 85 |
|
| 86 |
memory_attention:
|
| 87 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
|
| 88 |
d_model: 256
|
| 89 |
pos_enc_at_input: true
|
| 90 |
layer:
|
| 91 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 92 |
activation: relu
|
| 93 |
dim_feedforward: 2048
|
| 94 |
dropout: 0.1
|
| 95 |
pos_enc_at_attn: false
|
| 96 |
self_attention:
|
| 97 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 98 |
rope_theta: 10000.0
|
| 99 |
feat_sizes: [64, 64]
|
| 100 |
embedding_dim: 256
|
|
|
|
| 105 |
pos_enc_at_cross_attn_keys: true
|
| 106 |
pos_enc_at_cross_attn_queries: false
|
| 107 |
cross_attention:
|
| 108 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 109 |
rope_theta: 10000.0
|
| 110 |
feat_sizes: [64, 64]
|
| 111 |
rope_k_repeat: True
|
|
|
|
| 117 |
num_layers: 4
|
| 118 |
|
| 119 |
memory_encoder:
|
| 120 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
|
| 121 |
out_dim: 64
|
| 122 |
position_encoding:
|
| 123 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 124 |
num_pos_feats: 64
|
| 125 |
normalize: true
|
| 126 |
scale: null
|
| 127 |
temperature: 10000
|
| 128 |
mask_downsampler:
|
| 129 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
|
| 130 |
kernel_size: 3
|
| 131 |
stride: 2
|
| 132 |
padding: 1
|
| 133 |
fuser:
|
| 134 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
|
| 135 |
layer:
|
| 136 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
|
| 137 |
dim: 256
|
| 138 |
kernel_size: 7
|
| 139 |
padding: 3
|
|
|
|
| 317 |
|
| 318 |
state_dict:
|
| 319 |
_target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
|
| 320 |
+
checkpoint_path: ./checkpoints/bboxmaskpose.sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
|
| 321 |
ckpt_state_dict_keys: ['model']
|
| 322 |
|
| 323 |
launcher:
|
{sam2 → bboxmaskpose/sam2}/configs/sam2.1_training/sam2.1_hiera_b+_COCO_finetune.yaml
RENAMED
|
@@ -11,15 +11,6 @@ scratch:
|
|
| 11 |
phases_per_epoch: 1
|
| 12 |
num_epochs: 40
|
| 13 |
|
| 14 |
-
dataset:
|
| 15 |
-
# PATHS to Dataset
|
| 16 |
-
img_folder: /mnt/personal/purkrmir/data/COCO/original/train2017/ # PATH to MOSE JPEGImages folder
|
| 17 |
-
gt_folder: /mnt/personal/purkrmir/data/COCO/original/annotations/ # PATH to MOSE Annotations folder
|
| 18 |
-
# img_folder: /datagrid/personal/purkrmir/data/COCO/original/val2017/ # PATH to MOSE JPEGImages folder
|
| 19 |
-
# gt_folder: /datagrid/personal/purkrmir/data/COCO/original/annotations/ # PATH to MOSE Annotations folder
|
| 20 |
-
file_list_txt: null # Optional PATH to filelist containing a subset of videos to be used for training
|
| 21 |
-
multiplier: 2
|
| 22 |
-
|
| 23 |
# Video transforms
|
| 24 |
vos:
|
| 25 |
train_transforms:
|
|
@@ -69,19 +60,19 @@ trainer:
|
|
| 69 |
unfreeze_decoder: False
|
| 70 |
|
| 71 |
model:
|
| 72 |
-
_target_: training.model.sam2.
|
| 73 |
image_encoder:
|
| 74 |
-
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 75 |
scalp: 1
|
| 76 |
trunk:
|
| 77 |
-
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 78 |
embed_dim: 112
|
| 79 |
num_heads: 2
|
| 80 |
drop_path_rate: 0.1
|
| 81 |
neck:
|
| 82 |
-
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 83 |
position_encoding:
|
| 84 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 85 |
num_pos_feats: 256
|
| 86 |
normalize: true
|
| 87 |
scale: null
|
|
@@ -92,17 +83,17 @@ trainer:
|
|
| 92 |
fpn_interp_model: nearest
|
| 93 |
|
| 94 |
memory_attention:
|
| 95 |
-
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 96 |
d_model: 256
|
| 97 |
pos_enc_at_input: true
|
| 98 |
layer:
|
| 99 |
-
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 100 |
activation: relu
|
| 101 |
dim_feedforward: 2048
|
| 102 |
dropout: 0.1
|
| 103 |
pos_enc_at_attn: false
|
| 104 |
self_attention:
|
| 105 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 106 |
rope_theta: 10000.0
|
| 107 |
feat_sizes: [64, 64]
|
| 108 |
embedding_dim: 256
|
|
@@ -113,7 +104,7 @@ trainer:
|
|
| 113 |
pos_enc_at_cross_attn_keys: true
|
| 114 |
pos_enc_at_cross_attn_queries: false
|
| 115 |
cross_attention:
|
| 116 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 117 |
rope_theta: 10000.0
|
| 118 |
feat_sizes: [64, 64]
|
| 119 |
rope_k_repeat: True
|
|
@@ -125,23 +116,23 @@ trainer:
|
|
| 125 |
num_layers: 4
|
| 126 |
|
| 127 |
memory_encoder:
|
| 128 |
-
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 129 |
out_dim: 64
|
| 130 |
position_encoding:
|
| 131 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 132 |
num_pos_feats: 64
|
| 133 |
normalize: true
|
| 134 |
scale: null
|
| 135 |
temperature: 10000
|
| 136 |
mask_downsampler:
|
| 137 |
-
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 138 |
kernel_size: 3
|
| 139 |
stride: 2
|
| 140 |
padding: 1
|
| 141 |
fuser:
|
| 142 |
-
_target_: sam2.modeling.memory_encoder.Fuser
|
| 143 |
layer:
|
| 144 |
-
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 145 |
dim: 256
|
| 146 |
kernel_size: 7
|
| 147 |
padding: 3
|
|
@@ -325,7 +316,7 @@ trainer:
|
|
| 325 |
|
| 326 |
state_dict:
|
| 327 |
_target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
|
| 328 |
-
checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
|
| 329 |
ckpt_state_dict_keys: ['model']
|
| 330 |
|
| 331 |
launcher:
|
|
|
|
| 11 |
phases_per_epoch: 1
|
| 12 |
num_epochs: 40
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
# Video transforms
|
| 15 |
vos:
|
| 16 |
train_transforms:
|
|
|
|
| 60 |
unfreeze_decoder: False
|
| 61 |
|
| 62 |
model:
|
| 63 |
+
_target_: training.model.bboxmaskpose.sam2.sam2Train
|
| 64 |
image_encoder:
|
| 65 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 66 |
scalp: 1
|
| 67 |
trunk:
|
| 68 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
|
| 69 |
embed_dim: 112
|
| 70 |
num_heads: 2
|
| 71 |
drop_path_rate: 0.1
|
| 72 |
neck:
|
| 73 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
|
| 74 |
position_encoding:
|
| 75 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 76 |
num_pos_feats: 256
|
| 77 |
normalize: true
|
| 78 |
scale: null
|
|
|
|
| 83 |
fpn_interp_model: nearest
|
| 84 |
|
| 85 |
memory_attention:
|
| 86 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
|
| 87 |
d_model: 256
|
| 88 |
pos_enc_at_input: true
|
| 89 |
layer:
|
| 90 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 91 |
activation: relu
|
| 92 |
dim_feedforward: 2048
|
| 93 |
dropout: 0.1
|
| 94 |
pos_enc_at_attn: false
|
| 95 |
self_attention:
|
| 96 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 97 |
rope_theta: 10000.0
|
| 98 |
feat_sizes: [64, 64]
|
| 99 |
embedding_dim: 256
|
|
|
|
| 104 |
pos_enc_at_cross_attn_keys: true
|
| 105 |
pos_enc_at_cross_attn_queries: false
|
| 106 |
cross_attention:
|
| 107 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 108 |
rope_theta: 10000.0
|
| 109 |
feat_sizes: [64, 64]
|
| 110 |
rope_k_repeat: True
|
|
|
|
| 116 |
num_layers: 4
|
| 117 |
|
| 118 |
memory_encoder:
|
| 119 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
|
| 120 |
out_dim: 64
|
| 121 |
position_encoding:
|
| 122 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 123 |
num_pos_feats: 64
|
| 124 |
normalize: true
|
| 125 |
scale: null
|
| 126 |
temperature: 10000
|
| 127 |
mask_downsampler:
|
| 128 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
|
| 129 |
kernel_size: 3
|
| 130 |
stride: 2
|
| 131 |
padding: 1
|
| 132 |
fuser:
|
| 133 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
|
| 134 |
layer:
|
| 135 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
|
| 136 |
dim: 256
|
| 137 |
kernel_size: 7
|
| 138 |
padding: 3
|
|
|
|
| 316 |
|
| 317 |
state_dict:
|
| 318 |
_target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
|
| 319 |
+
checkpoint_path: ./checkpoints/bboxmaskpose.sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
|
| 320 |
ckpt_state_dict_keys: ['model']
|
| 321 |
|
| 322 |
launcher:
|
{sam2 → bboxmaskpose/sam2}/configs/sam2.1_training/sam2.1_hiera_b+_COCO_finetune_prompt+decoder.yaml
RENAMED
|
@@ -11,15 +11,6 @@ scratch:
|
|
| 11 |
phases_per_epoch: 1
|
| 12 |
num_epochs: 40
|
| 13 |
|
| 14 |
-
dataset:
|
| 15 |
-
# PATHS to Dataset
|
| 16 |
-
img_folder: /mnt/personal/purkrmir/data/COCO/original/train2017/ # PATH to MOSE JPEGImages folder
|
| 17 |
-
gt_folder: /mnt/personal/purkrmir/data/COCO/original/annotations/ # PATH to MOSE Annotations folder
|
| 18 |
-
# img_folder: /datagrid/personal/purkrmir/data/COCO/original/train2017/ # PATH to MOSE JPEGImages folder
|
| 19 |
-
# gt_folder: /datagrid/personal/purkrmir/data/COCO/original/annotations/ # PATH to MOSE Annotations folder
|
| 20 |
-
file_list_txt: null # Optional PATH to filelist containing a subset of videos to be used for training
|
| 21 |
-
multiplier: 2
|
| 22 |
-
|
| 23 |
# Video transforms
|
| 24 |
vos:
|
| 25 |
train_transforms:
|
|
@@ -69,19 +60,19 @@ trainer:
|
|
| 69 |
unfreeze_decoder: True
|
| 70 |
|
| 71 |
model:
|
| 72 |
-
_target_: training.model.sam2.
|
| 73 |
image_encoder:
|
| 74 |
-
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 75 |
scalp: 1
|
| 76 |
trunk:
|
| 77 |
-
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 78 |
embed_dim: 112
|
| 79 |
num_heads: 2
|
| 80 |
drop_path_rate: 0.1
|
| 81 |
neck:
|
| 82 |
-
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 83 |
position_encoding:
|
| 84 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 85 |
num_pos_feats: 256
|
| 86 |
normalize: true
|
| 87 |
scale: null
|
|
@@ -92,17 +83,17 @@ trainer:
|
|
| 92 |
fpn_interp_model: nearest
|
| 93 |
|
| 94 |
memory_attention:
|
| 95 |
-
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 96 |
d_model: 256
|
| 97 |
pos_enc_at_input: true
|
| 98 |
layer:
|
| 99 |
-
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 100 |
activation: relu
|
| 101 |
dim_feedforward: 2048
|
| 102 |
dropout: 0.1
|
| 103 |
pos_enc_at_attn: false
|
| 104 |
self_attention:
|
| 105 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 106 |
rope_theta: 10000.0
|
| 107 |
feat_sizes: [64, 64]
|
| 108 |
embedding_dim: 256
|
|
@@ -113,7 +104,7 @@ trainer:
|
|
| 113 |
pos_enc_at_cross_attn_keys: true
|
| 114 |
pos_enc_at_cross_attn_queries: false
|
| 115 |
cross_attention:
|
| 116 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 117 |
rope_theta: 10000.0
|
| 118 |
feat_sizes: [64, 64]
|
| 119 |
rope_k_repeat: True
|
|
@@ -125,23 +116,23 @@ trainer:
|
|
| 125 |
num_layers: 4
|
| 126 |
|
| 127 |
memory_encoder:
|
| 128 |
-
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 129 |
out_dim: 64
|
| 130 |
position_encoding:
|
| 131 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 132 |
num_pos_feats: 64
|
| 133 |
normalize: true
|
| 134 |
scale: null
|
| 135 |
temperature: 10000
|
| 136 |
mask_downsampler:
|
| 137 |
-
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 138 |
kernel_size: 3
|
| 139 |
stride: 2
|
| 140 |
padding: 1
|
| 141 |
fuser:
|
| 142 |
-
_target_: sam2.modeling.memory_encoder.Fuser
|
| 143 |
layer:
|
| 144 |
-
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 145 |
dim: 256
|
| 146 |
kernel_size: 7
|
| 147 |
padding: 3
|
|
@@ -325,7 +316,7 @@ trainer:
|
|
| 325 |
|
| 326 |
state_dict:
|
| 327 |
_target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
|
| 328 |
-
checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
|
| 329 |
ckpt_state_dict_keys: ['model']
|
| 330 |
|
| 331 |
launcher:
|
|
|
|
| 11 |
phases_per_epoch: 1
|
| 12 |
num_epochs: 40
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
# Video transforms
|
| 15 |
vos:
|
| 16 |
train_transforms:
|
|
|
|
| 60 |
unfreeze_decoder: True
|
| 61 |
|
| 62 |
model:
|
| 63 |
+
_target_: training.model.bboxmaskpose.sam2.sam2Train
|
| 64 |
image_encoder:
|
| 65 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 66 |
scalp: 1
|
| 67 |
trunk:
|
| 68 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
|
| 69 |
embed_dim: 112
|
| 70 |
num_heads: 2
|
| 71 |
drop_path_rate: 0.1
|
| 72 |
neck:
|
| 73 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
|
| 74 |
position_encoding:
|
| 75 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 76 |
num_pos_feats: 256
|
| 77 |
normalize: true
|
| 78 |
scale: null
|
|
|
|
| 83 |
fpn_interp_model: nearest
|
| 84 |
|
| 85 |
memory_attention:
|
| 86 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
|
| 87 |
d_model: 256
|
| 88 |
pos_enc_at_input: true
|
| 89 |
layer:
|
| 90 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 91 |
activation: relu
|
| 92 |
dim_feedforward: 2048
|
| 93 |
dropout: 0.1
|
| 94 |
pos_enc_at_attn: false
|
| 95 |
self_attention:
|
| 96 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 97 |
rope_theta: 10000.0
|
| 98 |
feat_sizes: [64, 64]
|
| 99 |
embedding_dim: 256
|
|
|
|
| 104 |
pos_enc_at_cross_attn_keys: true
|
| 105 |
pos_enc_at_cross_attn_queries: false
|
| 106 |
cross_attention:
|
| 107 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 108 |
rope_theta: 10000.0
|
| 109 |
feat_sizes: [64, 64]
|
| 110 |
rope_k_repeat: True
|
|
|
|
| 116 |
num_layers: 4
|
| 117 |
|
| 118 |
memory_encoder:
|
| 119 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
|
| 120 |
out_dim: 64
|
| 121 |
position_encoding:
|
| 122 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 123 |
num_pos_feats: 64
|
| 124 |
normalize: true
|
| 125 |
scale: null
|
| 126 |
temperature: 10000
|
| 127 |
mask_downsampler:
|
| 128 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
|
| 129 |
kernel_size: 3
|
| 130 |
stride: 2
|
| 131 |
padding: 1
|
| 132 |
fuser:
|
| 133 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
|
| 134 |
layer:
|
| 135 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
|
| 136 |
dim: 256
|
| 137 |
kernel_size: 7
|
| 138 |
padding: 3
|
|
|
|
| 316 |
|
| 317 |
state_dict:
|
| 318 |
_target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
|
| 319 |
+
checkpoint_path: ./checkpoints/bboxmaskpose.sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
|
| 320 |
ckpt_state_dict_keys: ['model']
|
| 321 |
|
| 322 |
launcher:
|
{sam2 → bboxmaskpose/sam2}/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml
RENAMED
|
@@ -11,12 +11,6 @@ scratch:
|
|
| 11 |
phases_per_epoch: 1
|
| 12 |
num_epochs: 40
|
| 13 |
|
| 14 |
-
dataset:
|
| 15 |
-
# PATHS to Dataset
|
| 16 |
-
img_folder: /datagrid/personal/purkrmir/data/MOSE/train/JPEGImages/ # PATH to MOSE JPEGImages folder
|
| 17 |
-
gt_folder: /datagrid/personal/purkrmir/data/MOSE/train/Annotations/ # PATH to MOSE Annotations folder
|
| 18 |
-
file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training
|
| 19 |
-
multiplier: 2
|
| 20 |
|
| 21 |
# Video transforms
|
| 22 |
vos:
|
|
@@ -62,19 +56,19 @@ trainer:
|
|
| 62 |
seed_value: 123
|
| 63 |
|
| 64 |
model:
|
| 65 |
-
_target_: training.model.sam2.
|
| 66 |
image_encoder:
|
| 67 |
-
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 68 |
scalp: 1
|
| 69 |
trunk:
|
| 70 |
-
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 71 |
embed_dim: 112
|
| 72 |
num_heads: 2
|
| 73 |
drop_path_rate: 0.1
|
| 74 |
neck:
|
| 75 |
-
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 76 |
position_encoding:
|
| 77 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 78 |
num_pos_feats: 256
|
| 79 |
normalize: true
|
| 80 |
scale: null
|
|
@@ -85,17 +79,17 @@ trainer:
|
|
| 85 |
fpn_interp_model: nearest
|
| 86 |
|
| 87 |
memory_attention:
|
| 88 |
-
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 89 |
d_model: 256
|
| 90 |
pos_enc_at_input: true
|
| 91 |
layer:
|
| 92 |
-
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 93 |
activation: relu
|
| 94 |
dim_feedforward: 2048
|
| 95 |
dropout: 0.1
|
| 96 |
pos_enc_at_attn: false
|
| 97 |
self_attention:
|
| 98 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 99 |
rope_theta: 10000.0
|
| 100 |
feat_sizes: [64, 64]
|
| 101 |
embedding_dim: 256
|
|
@@ -106,7 +100,7 @@ trainer:
|
|
| 106 |
pos_enc_at_cross_attn_keys: true
|
| 107 |
pos_enc_at_cross_attn_queries: false
|
| 108 |
cross_attention:
|
| 109 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 110 |
rope_theta: 10000.0
|
| 111 |
feat_sizes: [64, 64]
|
| 112 |
rope_k_repeat: True
|
|
@@ -118,23 +112,23 @@ trainer:
|
|
| 118 |
num_layers: 4
|
| 119 |
|
| 120 |
memory_encoder:
|
| 121 |
-
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 122 |
out_dim: 64
|
| 123 |
position_encoding:
|
| 124 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 125 |
num_pos_feats: 64
|
| 126 |
normalize: true
|
| 127 |
scale: null
|
| 128 |
temperature: 10000
|
| 129 |
mask_downsampler:
|
| 130 |
-
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 131 |
kernel_size: 3
|
| 132 |
stride: 2
|
| 133 |
padding: 1
|
| 134 |
fuser:
|
| 135 |
-
_target_: sam2.modeling.memory_encoder.Fuser
|
| 136 |
layer:
|
| 137 |
-
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 138 |
dim: 256
|
| 139 |
kernel_size: 7
|
| 140 |
padding: 3
|
|
@@ -318,7 +312,7 @@ trainer:
|
|
| 318 |
|
| 319 |
state_dict:
|
| 320 |
_target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
|
| 321 |
-
checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
|
| 322 |
ckpt_state_dict_keys: ['model']
|
| 323 |
|
| 324 |
launcher:
|
|
|
|
| 11 |
phases_per_epoch: 1
|
| 12 |
num_epochs: 40
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
# Video transforms
|
| 16 |
vos:
|
|
|
|
| 56 |
seed_value: 123
|
| 57 |
|
| 58 |
model:
|
| 59 |
+
_target_: training.model.bboxmaskpose.sam2.sam2Train
|
| 60 |
image_encoder:
|
| 61 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 62 |
scalp: 1
|
| 63 |
trunk:
|
| 64 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
|
| 65 |
embed_dim: 112
|
| 66 |
num_heads: 2
|
| 67 |
drop_path_rate: 0.1
|
| 68 |
neck:
|
| 69 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
|
| 70 |
position_encoding:
|
| 71 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 72 |
num_pos_feats: 256
|
| 73 |
normalize: true
|
| 74 |
scale: null
|
|
|
|
| 79 |
fpn_interp_model: nearest
|
| 80 |
|
| 81 |
memory_attention:
|
| 82 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
|
| 83 |
d_model: 256
|
| 84 |
pos_enc_at_input: true
|
| 85 |
layer:
|
| 86 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 87 |
activation: relu
|
| 88 |
dim_feedforward: 2048
|
| 89 |
dropout: 0.1
|
| 90 |
pos_enc_at_attn: false
|
| 91 |
self_attention:
|
| 92 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 93 |
rope_theta: 10000.0
|
| 94 |
feat_sizes: [64, 64]
|
| 95 |
embedding_dim: 256
|
|
|
|
| 100 |
pos_enc_at_cross_attn_keys: true
|
| 101 |
pos_enc_at_cross_attn_queries: false
|
| 102 |
cross_attention:
|
| 103 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 104 |
rope_theta: 10000.0
|
| 105 |
feat_sizes: [64, 64]
|
| 106 |
rope_k_repeat: True
|
|
|
|
| 112 |
num_layers: 4
|
| 113 |
|
| 114 |
memory_encoder:
|
| 115 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
|
| 116 |
out_dim: 64
|
| 117 |
position_encoding:
|
| 118 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 119 |
num_pos_feats: 64
|
| 120 |
normalize: true
|
| 121 |
scale: null
|
| 122 |
temperature: 10000
|
| 123 |
mask_downsampler:
|
| 124 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
|
| 125 |
kernel_size: 3
|
| 126 |
stride: 2
|
| 127 |
padding: 1
|
| 128 |
fuser:
|
| 129 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
|
| 130 |
layer:
|
| 131 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
|
| 132 |
dim: 256
|
| 133 |
kernel_size: 7
|
| 134 |
padding: 3
|
|
|
|
| 312 |
|
| 313 |
state_dict:
|
| 314 |
_target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
|
| 315 |
+
checkpoint_path: ./checkpoints/bboxmaskpose.sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
|
| 316 |
ckpt_state_dict_keys: ['model']
|
| 317 |
|
| 318 |
launcher:
|
{sam2 → bboxmaskpose/sam2}/configs/sam2/sam2_hiera_b+.yaml
RENAMED
|
@@ -2,18 +2,18 @@
|
|
| 2 |
|
| 3 |
# Model
|
| 4 |
model:
|
| 5 |
-
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
image_encoder:
|
| 7 |
-
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
scalp: 1
|
| 9 |
trunk:
|
| 10 |
-
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
embed_dim: 112
|
| 12 |
num_heads: 2
|
| 13 |
neck:
|
| 14 |
-
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 15 |
position_encoding:
|
| 16 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 17 |
num_pos_feats: 256
|
| 18 |
normalize: true
|
| 19 |
scale: null
|
|
@@ -24,17 +24,17 @@ model:
|
|
| 24 |
fpn_interp_model: nearest
|
| 25 |
|
| 26 |
memory_attention:
|
| 27 |
-
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 28 |
d_model: 256
|
| 29 |
pos_enc_at_input: true
|
| 30 |
layer:
|
| 31 |
-
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 32 |
activation: relu
|
| 33 |
dim_feedforward: 2048
|
| 34 |
dropout: 0.1
|
| 35 |
pos_enc_at_attn: false
|
| 36 |
self_attention:
|
| 37 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 38 |
rope_theta: 10000.0
|
| 39 |
feat_sizes: [32, 32]
|
| 40 |
embedding_dim: 256
|
|
@@ -45,7 +45,7 @@ model:
|
|
| 45 |
pos_enc_at_cross_attn_keys: true
|
| 46 |
pos_enc_at_cross_attn_queries: false
|
| 47 |
cross_attention:
|
| 48 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 49 |
rope_theta: 10000.0
|
| 50 |
feat_sizes: [32, 32]
|
| 51 |
rope_k_repeat: True
|
|
@@ -57,23 +57,23 @@ model:
|
|
| 57 |
num_layers: 4
|
| 58 |
|
| 59 |
memory_encoder:
|
| 60 |
-
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 61 |
out_dim: 64
|
| 62 |
position_encoding:
|
| 63 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 64 |
num_pos_feats: 64
|
| 65 |
normalize: true
|
| 66 |
scale: null
|
| 67 |
temperature: 10000
|
| 68 |
mask_downsampler:
|
| 69 |
-
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 70 |
kernel_size: 3
|
| 71 |
stride: 2
|
| 72 |
padding: 1
|
| 73 |
fuser:
|
| 74 |
-
_target_: sam2.modeling.memory_encoder.Fuser
|
| 75 |
layer:
|
| 76 |
-
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 77 |
dim: 256
|
| 78 |
kernel_size: 7
|
| 79 |
padding: 3
|
|
|
|
| 2 |
|
| 3 |
# Model
|
| 4 |
model:
|
| 5 |
+
_target_: bboxmaskpose.sam2.modeling.sam2_base.SAM2Base
|
| 6 |
image_encoder:
|
| 7 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
scalp: 1
|
| 9 |
trunk:
|
| 10 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
embed_dim: 112
|
| 12 |
num_heads: 2
|
| 13 |
neck:
|
| 14 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
|
| 15 |
position_encoding:
|
| 16 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 17 |
num_pos_feats: 256
|
| 18 |
normalize: true
|
| 19 |
scale: null
|
|
|
|
| 24 |
fpn_interp_model: nearest
|
| 25 |
|
| 26 |
memory_attention:
|
| 27 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
|
| 28 |
d_model: 256
|
| 29 |
pos_enc_at_input: true
|
| 30 |
layer:
|
| 31 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 32 |
activation: relu
|
| 33 |
dim_feedforward: 2048
|
| 34 |
dropout: 0.1
|
| 35 |
pos_enc_at_attn: false
|
| 36 |
self_attention:
|
| 37 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 38 |
rope_theta: 10000.0
|
| 39 |
feat_sizes: [32, 32]
|
| 40 |
embedding_dim: 256
|
|
|
|
| 45 |
pos_enc_at_cross_attn_keys: true
|
| 46 |
pos_enc_at_cross_attn_queries: false
|
| 47 |
cross_attention:
|
| 48 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 49 |
rope_theta: 10000.0
|
| 50 |
feat_sizes: [32, 32]
|
| 51 |
rope_k_repeat: True
|
|
|
|
| 57 |
num_layers: 4
|
| 58 |
|
| 59 |
memory_encoder:
|
| 60 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
|
| 61 |
out_dim: 64
|
| 62 |
position_encoding:
|
| 63 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 64 |
num_pos_feats: 64
|
| 65 |
normalize: true
|
| 66 |
scale: null
|
| 67 |
temperature: 10000
|
| 68 |
mask_downsampler:
|
| 69 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
|
| 70 |
kernel_size: 3
|
| 71 |
stride: 2
|
| 72 |
padding: 1
|
| 73 |
fuser:
|
| 74 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
|
| 75 |
layer:
|
| 76 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
|
| 77 |
dim: 256
|
| 78 |
kernel_size: 7
|
| 79 |
padding: 3
|
{sam2 → bboxmaskpose/sam2}/configs/sam2/sam2_hiera_l.yaml
RENAMED
|
@@ -2,12 +2,12 @@
|
|
| 2 |
|
| 3 |
# Model
|
| 4 |
model:
|
| 5 |
-
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
image_encoder:
|
| 7 |
-
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
scalp: 1
|
| 9 |
trunk:
|
| 10 |
-
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
embed_dim: 144
|
| 12 |
num_heads: 2
|
| 13 |
stages: [2, 6, 36, 4]
|
|
@@ -15,9 +15,9 @@ model:
|
|
| 15 |
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
window_spec: [8, 4, 16, 8]
|
| 17 |
neck:
|
| 18 |
-
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 19 |
position_encoding:
|
| 20 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 21 |
num_pos_feats: 256
|
| 22 |
normalize: true
|
| 23 |
scale: null
|
|
@@ -28,17 +28,17 @@ model:
|
|
| 28 |
fpn_interp_model: nearest
|
| 29 |
|
| 30 |
memory_attention:
|
| 31 |
-
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 32 |
d_model: 256
|
| 33 |
pos_enc_at_input: true
|
| 34 |
layer:
|
| 35 |
-
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 36 |
activation: relu
|
| 37 |
dim_feedforward: 2048
|
| 38 |
dropout: 0.1
|
| 39 |
pos_enc_at_attn: false
|
| 40 |
self_attention:
|
| 41 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 42 |
rope_theta: 10000.0
|
| 43 |
feat_sizes: [32, 32]
|
| 44 |
embedding_dim: 256
|
|
@@ -49,7 +49,7 @@ model:
|
|
| 49 |
pos_enc_at_cross_attn_keys: true
|
| 50 |
pos_enc_at_cross_attn_queries: false
|
| 51 |
cross_attention:
|
| 52 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 53 |
rope_theta: 10000.0
|
| 54 |
feat_sizes: [32, 32]
|
| 55 |
rope_k_repeat: True
|
|
@@ -61,23 +61,23 @@ model:
|
|
| 61 |
num_layers: 4
|
| 62 |
|
| 63 |
memory_encoder:
|
| 64 |
-
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 65 |
out_dim: 64
|
| 66 |
position_encoding:
|
| 67 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 68 |
num_pos_feats: 64
|
| 69 |
normalize: true
|
| 70 |
scale: null
|
| 71 |
temperature: 10000
|
| 72 |
mask_downsampler:
|
| 73 |
-
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 74 |
kernel_size: 3
|
| 75 |
stride: 2
|
| 76 |
padding: 1
|
| 77 |
fuser:
|
| 78 |
-
_target_: sam2.modeling.memory_encoder.Fuser
|
| 79 |
layer:
|
| 80 |
-
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 81 |
dim: 256
|
| 82 |
kernel_size: 7
|
| 83 |
padding: 3
|
|
|
|
| 2 |
|
| 3 |
# Model
|
| 4 |
model:
|
| 5 |
+
_target_: bboxmaskpose.sam2.modeling.sam2_base.SAM2Base
|
| 6 |
image_encoder:
|
| 7 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
scalp: 1
|
| 9 |
trunk:
|
| 10 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
embed_dim: 144
|
| 12 |
num_heads: 2
|
| 13 |
stages: [2, 6, 36, 4]
|
|
|
|
| 15 |
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
window_spec: [8, 4, 16, 8]
|
| 17 |
neck:
|
| 18 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
|
| 19 |
position_encoding:
|
| 20 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 21 |
num_pos_feats: 256
|
| 22 |
normalize: true
|
| 23 |
scale: null
|
|
|
|
| 28 |
fpn_interp_model: nearest
|
| 29 |
|
| 30 |
memory_attention:
|
| 31 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
|
| 32 |
d_model: 256
|
| 33 |
pos_enc_at_input: true
|
| 34 |
layer:
|
| 35 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 36 |
activation: relu
|
| 37 |
dim_feedforward: 2048
|
| 38 |
dropout: 0.1
|
| 39 |
pos_enc_at_attn: false
|
| 40 |
self_attention:
|
| 41 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 42 |
rope_theta: 10000.0
|
| 43 |
feat_sizes: [32, 32]
|
| 44 |
embedding_dim: 256
|
|
|
|
| 49 |
pos_enc_at_cross_attn_keys: true
|
| 50 |
pos_enc_at_cross_attn_queries: false
|
| 51 |
cross_attention:
|
| 52 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 53 |
rope_theta: 10000.0
|
| 54 |
feat_sizes: [32, 32]
|
| 55 |
rope_k_repeat: True
|
|
|
|
| 61 |
num_layers: 4
|
| 62 |
|
| 63 |
memory_encoder:
|
| 64 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
|
| 65 |
out_dim: 64
|
| 66 |
position_encoding:
|
| 67 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 68 |
num_pos_feats: 64
|
| 69 |
normalize: true
|
| 70 |
scale: null
|
| 71 |
temperature: 10000
|
| 72 |
mask_downsampler:
|
| 73 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
|
| 74 |
kernel_size: 3
|
| 75 |
stride: 2
|
| 76 |
padding: 1
|
| 77 |
fuser:
|
| 78 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
|
| 79 |
layer:
|
| 80 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
|
| 81 |
dim: 256
|
| 82 |
kernel_size: 7
|
| 83 |
padding: 3
|
{sam2 → bboxmaskpose/sam2}/configs/sam2/sam2_hiera_s.yaml
RENAMED
|
@@ -2,21 +2,21 @@
|
|
| 2 |
|
| 3 |
# Model
|
| 4 |
model:
|
| 5 |
-
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
image_encoder:
|
| 7 |
-
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
scalp: 1
|
| 9 |
trunk:
|
| 10 |
-
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
embed_dim: 96
|
| 12 |
num_heads: 1
|
| 13 |
stages: [1, 2, 11, 2]
|
| 14 |
global_att_blocks: [7, 10, 13]
|
| 15 |
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
neck:
|
| 17 |
-
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
position_encoding:
|
| 19 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
num_pos_feats: 256
|
| 21 |
normalize: true
|
| 22 |
scale: null
|
|
@@ -27,17 +27,17 @@ model:
|
|
| 27 |
fpn_interp_model: nearest
|
| 28 |
|
| 29 |
memory_attention:
|
| 30 |
-
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
d_model: 256
|
| 32 |
pos_enc_at_input: true
|
| 33 |
layer:
|
| 34 |
-
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
activation: relu
|
| 36 |
dim_feedforward: 2048
|
| 37 |
dropout: 0.1
|
| 38 |
pos_enc_at_attn: false
|
| 39 |
self_attention:
|
| 40 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
rope_theta: 10000.0
|
| 42 |
feat_sizes: [32, 32]
|
| 43 |
embedding_dim: 256
|
|
@@ -48,7 +48,7 @@ model:
|
|
| 48 |
pos_enc_at_cross_attn_keys: true
|
| 49 |
pos_enc_at_cross_attn_queries: false
|
| 50 |
cross_attention:
|
| 51 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
rope_theta: 10000.0
|
| 53 |
feat_sizes: [32, 32]
|
| 54 |
rope_k_repeat: True
|
|
@@ -60,23 +60,23 @@ model:
|
|
| 60 |
num_layers: 4
|
| 61 |
|
| 62 |
memory_encoder:
|
| 63 |
-
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 64 |
out_dim: 64
|
| 65 |
position_encoding:
|
| 66 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 67 |
num_pos_feats: 64
|
| 68 |
normalize: true
|
| 69 |
scale: null
|
| 70 |
temperature: 10000
|
| 71 |
mask_downsampler:
|
| 72 |
-
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 73 |
kernel_size: 3
|
| 74 |
stride: 2
|
| 75 |
padding: 1
|
| 76 |
fuser:
|
| 77 |
-
_target_: sam2.modeling.memory_encoder.Fuser
|
| 78 |
layer:
|
| 79 |
-
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 80 |
dim: 256
|
| 81 |
kernel_size: 7
|
| 82 |
padding: 3
|
|
|
|
| 2 |
|
| 3 |
# Model
|
| 4 |
model:
|
| 5 |
+
_target_: bboxmaskpose.sam2.modeling.sam2_base.SAM2Base
|
| 6 |
image_encoder:
|
| 7 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
scalp: 1
|
| 9 |
trunk:
|
| 10 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
embed_dim: 96
|
| 12 |
num_heads: 1
|
| 13 |
stages: [1, 2, 11, 2]
|
| 14 |
global_att_blocks: [7, 10, 13]
|
| 15 |
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
neck:
|
| 17 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
position_encoding:
|
| 19 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
num_pos_feats: 256
|
| 21 |
normalize: true
|
| 22 |
scale: null
|
|
|
|
| 27 |
fpn_interp_model: nearest
|
| 28 |
|
| 29 |
memory_attention:
|
| 30 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
d_model: 256
|
| 32 |
pos_enc_at_input: true
|
| 33 |
layer:
|
| 34 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
activation: relu
|
| 36 |
dim_feedforward: 2048
|
| 37 |
dropout: 0.1
|
| 38 |
pos_enc_at_attn: false
|
| 39 |
self_attention:
|
| 40 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
rope_theta: 10000.0
|
| 42 |
feat_sizes: [32, 32]
|
| 43 |
embedding_dim: 256
|
|
|
|
| 48 |
pos_enc_at_cross_attn_keys: true
|
| 49 |
pos_enc_at_cross_attn_queries: false
|
| 50 |
cross_attention:
|
| 51 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
rope_theta: 10000.0
|
| 53 |
feat_sizes: [32, 32]
|
| 54 |
rope_k_repeat: True
|
|
|
|
| 60 |
num_layers: 4
|
| 61 |
|
| 62 |
memory_encoder:
|
| 63 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
|
| 64 |
out_dim: 64
|
| 65 |
position_encoding:
|
| 66 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 67 |
num_pos_feats: 64
|
| 68 |
normalize: true
|
| 69 |
scale: null
|
| 70 |
temperature: 10000
|
| 71 |
mask_downsampler:
|
| 72 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
|
| 73 |
kernel_size: 3
|
| 74 |
stride: 2
|
| 75 |
padding: 1
|
| 76 |
fuser:
|
| 77 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
|
| 78 |
layer:
|
| 79 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
|
| 80 |
dim: 256
|
| 81 |
kernel_size: 7
|
| 82 |
padding: 3
|
{sam2 → bboxmaskpose/sam2}/configs/sam2/sam2_hiera_t.yaml
RENAMED
|
@@ -2,21 +2,21 @@
|
|
| 2 |
|
| 3 |
# Model
|
| 4 |
model:
|
| 5 |
-
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
image_encoder:
|
| 7 |
-
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
scalp: 1
|
| 9 |
trunk:
|
| 10 |
-
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
embed_dim: 96
|
| 12 |
num_heads: 1
|
| 13 |
stages: [1, 2, 7, 2]
|
| 14 |
global_att_blocks: [5, 7, 9]
|
| 15 |
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
neck:
|
| 17 |
-
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
position_encoding:
|
| 19 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
num_pos_feats: 256
|
| 21 |
normalize: true
|
| 22 |
scale: null
|
|
@@ -27,17 +27,17 @@ model:
|
|
| 27 |
fpn_interp_model: nearest
|
| 28 |
|
| 29 |
memory_attention:
|
| 30 |
-
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
d_model: 256
|
| 32 |
pos_enc_at_input: true
|
| 33 |
layer:
|
| 34 |
-
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
activation: relu
|
| 36 |
dim_feedforward: 2048
|
| 37 |
dropout: 0.1
|
| 38 |
pos_enc_at_attn: false
|
| 39 |
self_attention:
|
| 40 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
rope_theta: 10000.0
|
| 42 |
feat_sizes: [32, 32]
|
| 43 |
embedding_dim: 256
|
|
@@ -48,7 +48,7 @@ model:
|
|
| 48 |
pos_enc_at_cross_attn_keys: true
|
| 49 |
pos_enc_at_cross_attn_queries: false
|
| 50 |
cross_attention:
|
| 51 |
-
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
rope_theta: 10000.0
|
| 53 |
feat_sizes: [32, 32]
|
| 54 |
rope_k_repeat: True
|
|
@@ -60,23 +60,23 @@ model:
|
|
| 60 |
num_layers: 4
|
| 61 |
|
| 62 |
memory_encoder:
|
| 63 |
-
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 64 |
out_dim: 64
|
| 65 |
position_encoding:
|
| 66 |
-
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 67 |
num_pos_feats: 64
|
| 68 |
normalize: true
|
| 69 |
scale: null
|
| 70 |
temperature: 10000
|
| 71 |
mask_downsampler:
|
| 72 |
-
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 73 |
kernel_size: 3
|
| 74 |
stride: 2
|
| 75 |
padding: 1
|
| 76 |
fuser:
|
| 77 |
-
_target_: sam2.modeling.memory_encoder.Fuser
|
| 78 |
layer:
|
| 79 |
-
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 80 |
dim: 256
|
| 81 |
kernel_size: 7
|
| 82 |
padding: 3
|
|
|
|
| 2 |
|
| 3 |
# Model
|
| 4 |
model:
|
| 5 |
+
_target_: bboxmaskpose.sam2.modeling.sam2_base.SAM2Base
|
| 6 |
image_encoder:
|
| 7 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
scalp: 1
|
| 9 |
trunk:
|
| 10 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
embed_dim: 96
|
| 12 |
num_heads: 1
|
| 13 |
stages: [1, 2, 7, 2]
|
| 14 |
global_att_blocks: [5, 7, 9]
|
| 15 |
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
neck:
|
| 17 |
+
_target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
position_encoding:
|
| 19 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
num_pos_feats: 256
|
| 21 |
normalize: true
|
| 22 |
scale: null
|
|
|
|
| 27 |
fpn_interp_model: nearest
|
| 28 |
|
| 29 |
memory_attention:
|
| 30 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
d_model: 256
|
| 32 |
pos_enc_at_input: true
|
| 33 |
layer:
|
| 34 |
+
_target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
activation: relu
|
| 36 |
dim_feedforward: 2048
|
| 37 |
dropout: 0.1
|
| 38 |
pos_enc_at_attn: false
|
| 39 |
self_attention:
|
| 40 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
rope_theta: 10000.0
|
| 42 |
feat_sizes: [32, 32]
|
| 43 |
embedding_dim: 256
|
|
|
|
| 48 |
pos_enc_at_cross_attn_keys: true
|
| 49 |
pos_enc_at_cross_attn_queries: false
|
| 50 |
cross_attention:
|
| 51 |
+
_target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
rope_theta: 10000.0
|
| 53 |
feat_sizes: [32, 32]
|
| 54 |
rope_k_repeat: True
|
|
|
|
| 60 |
num_layers: 4
|
| 61 |
|
| 62 |
memory_encoder:
|
| 63 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
|
| 64 |
out_dim: 64
|
| 65 |
position_encoding:
|
| 66 |
+
_target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 67 |
num_pos_feats: 64
|
| 68 |
normalize: true
|
| 69 |
scale: null
|
| 70 |
temperature: 10000
|
| 71 |
mask_downsampler:
|
| 72 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
|
| 73 |
kernel_size: 3
|
| 74 |
stride: 2
|
| 75 |
padding: 1
|
| 76 |
fuser:
|
| 77 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
|
| 78 |
layer:
|
| 79 |
+
_target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
|
| 80 |
dim: 256
|
| 81 |
kernel_size: 7
|
| 82 |
padding: 3
|
{sam2 → bboxmaskpose/sam2}/csrc/connected_components.cu
RENAMED
|
File without changes
|
{sam2 → bboxmaskpose/sam2}/distinctipy.py
RENAMED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import math
|
| 2 |
import random
|
| 3 |
|
|
@@ -125,9 +127,7 @@ def color_distance(c1, c2):
|
|
| 125 |
return distance
|
| 126 |
|
| 127 |
|
| 128 |
-
def distinct_color(
|
| 129 |
-
exclude_colors, pastel_factor=0.0, n_attempts=1000, colorblind_type=None, rng=None
|
| 130 |
-
):
|
| 131 |
"""
|
| 132 |
Generate a colour as distinct as possible from the colours defined in exclude_colors
|
| 133 |
Inspired by: https://gist.github.com/adewes/5884820
|
|
@@ -164,10 +164,7 @@ def distinct_color(
|
|
| 164 |
return get_random_color(pastel_factor=pastel_factor, rng=rng)
|
| 165 |
|
| 166 |
if colorblind_type:
|
| 167 |
-
exclude_colors = [
|
| 168 |
-
colorblind.colorblind_filter(color, colorblind_type)
|
| 169 |
-
for color in exclude_colors
|
| 170 |
-
]
|
| 171 |
|
| 172 |
max_distance = None
|
| 173 |
best_color = None
|
|
@@ -181,9 +178,7 @@ def distinct_color(
|
|
| 181 |
else:
|
| 182 |
compare_color = color
|
| 183 |
|
| 184 |
-
distance_to_nearest = min(
|
| 185 |
-
[color_distance(compare_color, c) for c in exclude_colors]
|
| 186 |
-
)
|
| 187 |
|
| 188 |
if (not max_distance) or (distance_to_nearest > max_distance):
|
| 189 |
max_distance = distance_to_nearest
|
|
@@ -202,9 +197,7 @@ def distinct_color(
|
|
| 202 |
else:
|
| 203 |
compare_color = color
|
| 204 |
|
| 205 |
-
distance_to_nearest = min(
|
| 206 |
-
[color_distance(compare_color, c) for c in exclude_colors]
|
| 207 |
-
)
|
| 208 |
|
| 209 |
if (not max_distance) or (distance_to_nearest > max_distance):
|
| 210 |
max_distance = distance_to_nearest
|
|
@@ -500,4 +493,4 @@ def get_colormap(list_of_colors, name="distinctipy"):
|
|
| 500 |
|
| 501 |
cmap = matplotlib.colors.ListedColormap(list_of_colors, name=name)
|
| 502 |
|
| 503 |
-
return cmap
|
|
|
|
| 1 |
+
# Adapted from the distinctipy repository (https://github.com/alan-turing-institute/distinctipy).
|
| 2 |
+
# Original authors: distinctipy contributors. Included with minor modifications.
|
| 3 |
import math
|
| 4 |
import random
|
| 5 |
|
|
|
|
| 127 |
return distance
|
| 128 |
|
| 129 |
|
| 130 |
+
def distinct_color(exclude_colors, pastel_factor=0.0, n_attempts=1000, colorblind_type=None, rng=None):
|
|
|
|
|
|
|
| 131 |
"""
|
| 132 |
Generate a colour as distinct as possible from the colours defined in exclude_colors
|
| 133 |
Inspired by: https://gist.github.com/adewes/5884820
|
|
|
|
| 164 |
return get_random_color(pastel_factor=pastel_factor, rng=rng)
|
| 165 |
|
| 166 |
if colorblind_type:
|
| 167 |
+
exclude_colors = [colorblind.colorblind_filter(color, colorblind_type) for color in exclude_colors]
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
max_distance = None
|
| 170 |
best_color = None
|
|
|
|
| 178 |
else:
|
| 179 |
compare_color = color
|
| 180 |
|
| 181 |
+
distance_to_nearest = min([color_distance(compare_color, c) for c in exclude_colors])
|
|
|
|
|
|
|
| 182 |
|
| 183 |
if (not max_distance) or (distance_to_nearest > max_distance):
|
| 184 |
max_distance = distance_to_nearest
|
|
|
|
| 197 |
else:
|
| 198 |
compare_color = color
|
| 199 |
|
| 200 |
+
distance_to_nearest = min([color_distance(compare_color, c) for c in exclude_colors])
|
|
|
|
|
|
|
| 201 |
|
| 202 |
if (not max_distance) or (distance_to_nearest > max_distance):
|
| 203 |
max_distance = distance_to_nearest
|
|
|
|
| 493 |
|
| 494 |
cmap = matplotlib.colors.ListedColormap(list_of_colors, name=name)
|
| 495 |
|
| 496 |
+
return cmap
|
{sam2 → bboxmaskpose/sam2}/modeling/__init__.py
RENAMED
|
File without changes
|
{sam2 → bboxmaskpose/sam2}/modeling/backbones/__init__.py
RENAMED
|
File without changes
|
{sam2 → bboxmaskpose/sam2}/modeling/backbones/hieradet.py
RENAMED
|
@@ -11,15 +11,10 @@ from typing import List, Tuple, Union
|
|
| 11 |
import torch
|
| 12 |
import torch.nn as nn
|
| 13 |
import torch.nn.functional as F
|
| 14 |
-
from iopath.common.file_io import g_pathmgr
|
| 15 |
-
|
| 16 |
-
from sam2.modeling.backbones.utils import (
|
| 17 |
-
PatchEmbed,
|
| 18 |
-
window_partition,
|
| 19 |
-
window_unpartition,
|
| 20 |
-
)
|
| 21 |
|
| 22 |
-
from sam2.modeling.
|
|
|
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
|
|
@@ -107,9 +102,7 @@ class MultiScaleBlock(nn.Module):
|
|
| 107 |
|
| 108 |
self.pool, self.q_stride = None, q_stride
|
| 109 |
if self.q_stride:
|
| 110 |
-
self.pool = nn.MaxPool2d(
|
| 111 |
-
kernel_size=q_stride, stride=q_stride, ceil_mode=False
|
| 112 |
-
)
|
| 113 |
|
| 114 |
self.attn = MultiScaleAttention(
|
| 115 |
dim,
|
|
@@ -218,16 +211,10 @@ class Hiera(nn.Module):
|
|
| 218 |
|
| 219 |
# Windowed positional embedding (https://arxiv.org/abs/2311.05613)
|
| 220 |
self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
|
| 221 |
-
self.pos_embed = nn.Parameter(
|
| 222 |
-
|
| 223 |
-
)
|
| 224 |
-
self.pos_embed_window = nn.Parameter(
|
| 225 |
-
torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
|
| 226 |
-
)
|
| 227 |
|
| 228 |
-
dpr = [
|
| 229 |
-
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
| 230 |
-
] # stochastic depth decay rule
|
| 231 |
|
| 232 |
cur_stage = 1
|
| 233 |
self.blocks = nn.ModuleList()
|
|
@@ -259,11 +246,7 @@ class Hiera(nn.Module):
|
|
| 259 |
embed_dim = dim_out
|
| 260 |
self.blocks.append(block)
|
| 261 |
|
| 262 |
-
self.channel_list =
|
| 263 |
-
[self.blocks[i].dim_out for i in self.stage_ends[::-1]]
|
| 264 |
-
if return_interm_layers
|
| 265 |
-
else [self.blocks[-1].dim_out]
|
| 266 |
-
)
|
| 267 |
|
| 268 |
if weights_path is not None:
|
| 269 |
with g_pathmgr.open(weights_path, "rb") as f:
|
|
@@ -274,9 +257,7 @@ class Hiera(nn.Module):
|
|
| 274 |
h, w = hw
|
| 275 |
window_embed = self.pos_embed_window
|
| 276 |
pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
|
| 277 |
-
pos_embed = pos_embed + window_embed.tile(
|
| 278 |
-
[x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
|
| 279 |
-
)
|
| 280 |
pos_embed = pos_embed.permute(0, 2, 3, 1)
|
| 281 |
return pos_embed
|
| 282 |
|
|
@@ -290,9 +271,7 @@ class Hiera(nn.Module):
|
|
| 290 |
outputs = []
|
| 291 |
for i, blk in enumerate(self.blocks):
|
| 292 |
x = blk(x)
|
| 293 |
-
if (i == self.stage_ends[-1]) or (
|
| 294 |
-
i in self.stage_ends and self.return_interm_layers
|
| 295 |
-
):
|
| 296 |
feats = x.permute(0, 3, 1, 2)
|
| 297 |
outputs.append(feats)
|
| 298 |
|
|
|
|
| 11 |
import torch
|
| 12 |
import torch.nn as nn
|
| 13 |
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
+
from bboxmaskpose.sam2.modeling.backbones.utils import PatchEmbed, window_partition, window_unpartition
|
| 16 |
+
from bboxmaskpose.sam2.modeling.sam2_utils import MLP, DropPath
|
| 17 |
+
from iopath.common.file_io import g_pathmgr
|
| 18 |
|
| 19 |
|
| 20 |
def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
|
|
|
|
| 102 |
|
| 103 |
self.pool, self.q_stride = None, q_stride
|
| 104 |
if self.q_stride:
|
| 105 |
+
self.pool = nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False)
|
|
|
|
|
|
|
| 106 |
|
| 107 |
self.attn = MultiScaleAttention(
|
| 108 |
dim,
|
|
|
|
| 211 |
|
| 212 |
# Windowed positional embedding (https://arxiv.org/abs/2311.05613)
|
| 213 |
self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
|
| 214 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size))
|
| 215 |
+
self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
|
|
|
|
|
|
| 218 |
|
| 219 |
cur_stage = 1
|
| 220 |
self.blocks = nn.ModuleList()
|
|
|
|
| 246 |
embed_dim = dim_out
|
| 247 |
self.blocks.append(block)
|
| 248 |
|
| 249 |
+
self.channel_list = [self.blocks[i].dim_out for i in self.stage_ends[::-1]] if return_interm_layers else [self.blocks[-1].dim_out]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
|
| 251 |
if weights_path is not None:
|
| 252 |
with g_pathmgr.open(weights_path, "rb") as f:
|
|
|
|
| 257 |
h, w = hw
|
| 258 |
window_embed = self.pos_embed_window
|
| 259 |
pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
|
| 260 |
+
pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
|
|
|
|
|
|
|
| 261 |
pos_embed = pos_embed.permute(0, 2, 3, 1)
|
| 262 |
return pos_embed
|
| 263 |
|
|
|
|
| 271 |
outputs = []
|
| 272 |
for i, blk in enumerate(self.blocks):
|
| 273 |
x = blk(x)
|
| 274 |
+
if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers):
|
|
|
|
|
|
|
| 275 |
feats = x.permute(0, 3, 1, 2)
|
| 276 |
outputs.append(feats)
|
| 277 |
|
{sam2 → bboxmaskpose/sam2}/modeling/backbones/image_encoder.py
RENAMED
|
@@ -117,9 +117,7 @@ class FpnNeck(nn.Module):
|
|
| 117 |
prev_features.to(dtype=torch.float32),
|
| 118 |
scale_factor=2.0,
|
| 119 |
mode=self.fpn_interp_model,
|
| 120 |
-
align_corners=(
|
| 121 |
-
None if self.fpn_interp_model == "nearest" else False
|
| 122 |
-
),
|
| 123 |
antialias=False,
|
| 124 |
)
|
| 125 |
prev_features = lateral_features + top_down_features
|
|
|
|
| 117 |
prev_features.to(dtype=torch.float32),
|
| 118 |
scale_factor=2.0,
|
| 119 |
mode=self.fpn_interp_model,
|
| 120 |
+
align_corners=(None if self.fpn_interp_model == "nearest" else False),
|
|
|
|
|
|
|
| 121 |
antialias=False,
|
| 122 |
)
|
| 123 |
prev_features = lateral_features + top_down_features
|
{sam2 → bboxmaskpose/sam2}/modeling/backbones/utils.py
RENAMED
|
@@ -50,9 +50,7 @@ def window_unpartition(windows, window_size, pad_hw, hw):
|
|
| 50 |
Hp, Wp = pad_hw
|
| 51 |
H, W = hw
|
| 52 |
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
| 53 |
-
x = windows.reshape(
|
| 54 |
-
B, Hp // window_size, Wp // window_size, window_size, window_size, -1
|
| 55 |
-
)
|
| 56 |
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
|
| 57 |
|
| 58 |
if Hp > H or Wp > W:
|
|
@@ -82,9 +80,7 @@ class PatchEmbed(nn.Module):
|
|
| 82 |
embed_dim (int): embed_dim (int): Patch embedding dimension.
|
| 83 |
"""
|
| 84 |
super().__init__()
|
| 85 |
-
self.proj = nn.Conv2d(
|
| 86 |
-
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
|
| 87 |
-
)
|
| 88 |
|
| 89 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 90 |
x = self.proj(x)
|
|
|
|
| 50 |
Hp, Wp = pad_hw
|
| 51 |
H, W = hw
|
| 52 |
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
| 53 |
+
x = windows.reshape(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
|
|
|
|
|
|
|
| 54 |
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
|
| 55 |
|
| 56 |
if Hp > H or Wp > W:
|
|
|
|
| 80 |
embed_dim (int): embed_dim (int): Patch embedding dimension.
|
| 81 |
"""
|
| 82 |
super().__init__()
|
| 83 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
|
|
|
|
|
|
|
| 84 |
|
| 85 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 86 |
x = self.proj(x)
|
{sam2 → bboxmaskpose/sam2}/modeling/memory_attention.py
RENAMED
|
@@ -7,11 +7,10 @@
|
|
| 7 |
from typing import Optional
|
| 8 |
|
| 9 |
import torch
|
| 10 |
-
from torch import
|
| 11 |
|
| 12 |
-
from sam2.modeling.
|
| 13 |
-
|
| 14 |
-
from sam2.modeling.sam2_utils import get_activation_fn, get_clones
|
| 15 |
|
| 16 |
|
| 17 |
class MemoryAttentionLayer(nn.Module):
|
|
@@ -132,9 +131,7 @@ class MemoryAttention(nn.Module):
|
|
| 132 |
curr_pos[0],
|
| 133 |
)
|
| 134 |
|
| 135 |
-
assert
|
| 136 |
-
curr.shape[1] == memory.shape[1]
|
| 137 |
-
), "Batch size must be the same for curr and memory"
|
| 138 |
|
| 139 |
output = curr
|
| 140 |
if self.pos_enc_at_input and curr_pos is not None:
|
|
|
|
| 7 |
from typing import Optional
|
| 8 |
|
| 9 |
import torch
|
| 10 |
+
from torch import Tensor, nn
|
| 11 |
|
| 12 |
+
from bboxmaskpose.sam2.modeling.sam2_utils import get_activation_fn, get_clones
|
| 13 |
+
from bboxmaskpose.sam2.modeling.sam.transformer import RoPEAttention
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
class MemoryAttentionLayer(nn.Module):
|
|
|
|
| 131 |
curr_pos[0],
|
| 132 |
)
|
| 133 |
|
| 134 |
+
assert curr.shape[1] == memory.shape[1], "Batch size must be the same for curr and memory"
|
|
|
|
|
|
|
| 135 |
|
| 136 |
output = curr
|
| 137 |
if self.pos_enc_at_input and curr_pos is not None:
|
{sam2 → bboxmaskpose/sam2}/modeling/memory_encoder.py
RENAMED
|
@@ -11,7 +11,7 @@ import torch
|
|
| 11 |
import torch.nn as nn
|
| 12 |
import torch.nn.functional as F
|
| 13 |
|
| 14 |
-
from sam2.modeling.sam2_utils import DropPath,
|
| 15 |
|
| 16 |
|
| 17 |
class MaskDownSampler(nn.Module):
|
|
@@ -89,16 +89,10 @@ class CXBlock(nn.Module):
|
|
| 89 |
groups=dim if use_dwconv else 1,
|
| 90 |
) # depthwise conv
|
| 91 |
self.norm = LayerNorm2d(dim, eps=1e-6)
|
| 92 |
-
self.pwconv1 = nn.Linear(
|
| 93 |
-
dim, 4 * dim
|
| 94 |
-
) # pointwise/1x1 convs, implemented with linear layers
|
| 95 |
self.act = nn.GELU()
|
| 96 |
self.pwconv2 = nn.Linear(4 * dim, dim)
|
| 97 |
-
self.gamma = (
|
| 98 |
-
nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
|
| 99 |
-
if layer_scale_init_value > 0
|
| 100 |
-
else None
|
| 101 |
-
)
|
| 102 |
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 103 |
|
| 104 |
def forward(self, x):
|
|
|
|
| 11 |
import torch.nn as nn
|
| 12 |
import torch.nn.functional as F
|
| 13 |
|
| 14 |
+
from bboxmaskpose.sam2.modeling.sam2_utils import DropPath, LayerNorm2d, get_clones
|
| 15 |
|
| 16 |
|
| 17 |
class MaskDownSampler(nn.Module):
|
|
|
|
| 89 |
groups=dim if use_dwconv else 1,
|
| 90 |
) # depthwise conv
|
| 91 |
self.norm = LayerNorm2d(dim, eps=1e-6)
|
| 92 |
+
self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
|
|
|
|
|
|
|
| 93 |
self.act = nn.GELU()
|
| 94 |
self.pwconv2 = nn.Linear(4 * dim, dim)
|
| 95 |
+
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) if layer_scale_init_value > 0 else None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 97 |
|
| 98 |
def forward(self, x):
|
{sam2 → bboxmaskpose/sam2}/modeling/position_encoding.py
RENAMED
|
@@ -8,7 +8,6 @@ import math
|
|
| 8 |
from typing import Any, Optional, Tuple
|
| 9 |
|
| 10 |
import numpy as np
|
| 11 |
-
|
| 12 |
import torch
|
| 13 |
from torch import nn
|
| 14 |
|
|
@@ -61,12 +60,8 @@ class PositionEmbeddingSine(nn.Module):
|
|
| 61 |
|
| 62 |
pos_x = x_embed[:, None] / dim_t
|
| 63 |
pos_y = y_embed[:, None] / dim_t
|
| 64 |
-
pos_x = torch.stack(
|
| 65 |
-
|
| 66 |
-
).flatten(1)
|
| 67 |
-
pos_y = torch.stack(
|
| 68 |
-
(pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
|
| 69 |
-
).flatten(1)
|
| 70 |
return pos_x, pos_y
|
| 71 |
|
| 72 |
@torch.no_grad()
|
|
@@ -92,16 +87,8 @@ class PositionEmbeddingSine(nn.Module):
|
|
| 92 |
if cache_key in self.cache:
|
| 93 |
return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
|
| 94 |
|
| 95 |
-
y_embed = (
|
| 96 |
-
|
| 97 |
-
.view(1, -1, 1)
|
| 98 |
-
.repeat(B, 1, W)
|
| 99 |
-
)
|
| 100 |
-
x_embed = (
|
| 101 |
-
torch.arange(1, W + 1, dtype=torch.float32, device=device)
|
| 102 |
-
.view(1, 1, -1)
|
| 103 |
-
.repeat(B, H, 1)
|
| 104 |
-
)
|
| 105 |
|
| 106 |
if self.normalize:
|
| 107 |
eps = 1e-6
|
|
@@ -113,12 +100,8 @@ class PositionEmbeddingSine(nn.Module):
|
|
| 113 |
|
| 114 |
pos_x = x_embed[:, :, :, None] / dim_t
|
| 115 |
pos_y = y_embed[:, :, :, None] / dim_t
|
| 116 |
-
pos_x = torch.stack(
|
| 117 |
-
|
| 118 |
-
).flatten(3)
|
| 119 |
-
pos_y = torch.stack(
|
| 120 |
-
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
| 121 |
-
).flatten(3)
|
| 122 |
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
| 123 |
self.cache[cache_key] = pos[0]
|
| 124 |
return pos
|
|
@@ -166,9 +149,7 @@ class PositionEmbeddingRandom(nn.Module):
|
|
| 166 |
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
|
| 167 |
return pe.permute(2, 0, 1) # C x H x W
|
| 168 |
|
| 169 |
-
def forward_with_coords(
|
| 170 |
-
self, coords_input: torch.Tensor, image_size: Tuple[int, int]
|
| 171 |
-
) -> torch.Tensor:
|
| 172 |
"""Positionally encode points that are not normalized to [0,1]."""
|
| 173 |
coords = coords_input.clone()
|
| 174 |
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
|
|
@@ -216,11 +197,7 @@ def apply_rotary_enc(
|
|
| 216 |
repeat_freqs_k: bool = False,
|
| 217 |
):
|
| 218 |
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
| 219 |
-
xk_ = (
|
| 220 |
-
torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
| 221 |
-
if xk.shape[-2] != 0
|
| 222 |
-
else None
|
| 223 |
-
)
|
| 224 |
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
| 225 |
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
| 226 |
if xk_ is None:
|
|
|
|
| 8 |
from typing import Any, Optional, Tuple
|
| 9 |
|
| 10 |
import numpy as np
|
|
|
|
| 11 |
import torch
|
| 12 |
from torch import nn
|
| 13 |
|
|
|
|
| 60 |
|
| 61 |
pos_x = x_embed[:, None] / dim_t
|
| 62 |
pos_y = y_embed[:, None] / dim_t
|
| 63 |
+
pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
|
| 64 |
+
pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
return pos_x, pos_y
|
| 66 |
|
| 67 |
@torch.no_grad()
|
|
|
|
| 87 |
if cache_key in self.cache:
|
| 88 |
return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
|
| 89 |
|
| 90 |
+
y_embed = torch.arange(1, H + 1, dtype=torch.float32, device=device).view(1, -1, 1).repeat(B, 1, W)
|
| 91 |
+
x_embed = torch.arange(1, W + 1, dtype=torch.float32, device=device).view(1, 1, -1).repeat(B, H, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
if self.normalize:
|
| 94 |
eps = 1e-6
|
|
|
|
| 100 |
|
| 101 |
pos_x = x_embed[:, :, :, None] / dim_t
|
| 102 |
pos_y = y_embed[:, :, :, None] / dim_t
|
| 103 |
+
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
| 104 |
+
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
| 106 |
self.cache[cache_key] = pos[0]
|
| 107 |
return pos
|
|
|
|
| 149 |
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
|
| 150 |
return pe.permute(2, 0, 1) # C x H x W
|
| 151 |
|
| 152 |
+
def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
|
|
|
|
|
|
|
| 153 |
"""Positionally encode points that are not normalized to [0,1]."""
|
| 154 |
coords = coords_input.clone()
|
| 155 |
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
|
|
|
|
| 197 |
repeat_freqs_k: bool = False,
|
| 198 |
):
|
| 199 |
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
| 200 |
+
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) if xk.shape[-2] != 0 else None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
| 202 |
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
| 203 |
if xk_ is None:
|
{sam2 → bboxmaskpose/sam2}/modeling/sam/__init__.py
RENAMED
|
File without changes
|
{sam2 → bboxmaskpose/sam2}/modeling/sam/mask_decoder.py
RENAMED
|
@@ -9,7 +9,7 @@ from typing import List, Optional, Tuple, Type
|
|
| 9 |
import torch
|
| 10 |
from torch import nn
|
| 11 |
|
| 12 |
-
from sam2.modeling.sam2_utils import
|
| 13 |
|
| 14 |
|
| 15 |
class MaskDecoder(nn.Module):
|
|
@@ -63,30 +63,19 @@ class MaskDecoder(nn.Module):
|
|
| 63 |
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
|
| 64 |
|
| 65 |
self.output_upscaling = nn.Sequential(
|
| 66 |
-
nn.ConvTranspose2d(
|
| 67 |
-
transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
|
| 68 |
-
),
|
| 69 |
LayerNorm2d(transformer_dim // 4),
|
| 70 |
activation(),
|
| 71 |
-
nn.ConvTranspose2d(
|
| 72 |
-
transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
|
| 73 |
-
),
|
| 74 |
activation(),
|
| 75 |
)
|
| 76 |
self.use_high_res_features = use_high_res_features
|
| 77 |
if use_high_res_features:
|
| 78 |
-
self.conv_s0 = nn.Conv2d(
|
| 79 |
-
|
| 80 |
-
)
|
| 81 |
-
self.conv_s1 = nn.Conv2d(
|
| 82 |
-
transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
|
| 83 |
-
)
|
| 84 |
|
| 85 |
self.output_hypernetworks_mlps = nn.ModuleList(
|
| 86 |
-
[
|
| 87 |
-
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
|
| 88 |
-
for i in range(self.num_mask_tokens)
|
| 89 |
-
]
|
| 90 |
)
|
| 91 |
|
| 92 |
self.iou_prediction_head = MLP(
|
|
@@ -188,12 +177,8 @@ class MaskDecoder(nn.Module):
|
|
| 188 |
)
|
| 189 |
s = 1
|
| 190 |
else:
|
| 191 |
-
output_tokens = torch.cat(
|
| 192 |
-
|
| 193 |
-
)
|
| 194 |
-
output_tokens = output_tokens.unsqueeze(0).expand(
|
| 195 |
-
sparse_prompt_embeddings.size(0), -1, -1
|
| 196 |
-
)
|
| 197 |
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
|
| 198 |
|
| 199 |
# Expand per-image data in batch direction to be per-mask
|
|
@@ -203,9 +188,7 @@ class MaskDecoder(nn.Module):
|
|
| 203 |
assert image_embeddings.shape[0] == tokens.shape[0]
|
| 204 |
src = image_embeddings
|
| 205 |
src = src + dense_prompt_embeddings
|
| 206 |
-
assert (
|
| 207 |
-
image_pe.size(0) == 1
|
| 208 |
-
), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
|
| 209 |
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
|
| 210 |
b, c, h, w = src.shape
|
| 211 |
|
|
@@ -226,9 +209,7 @@ class MaskDecoder(nn.Module):
|
|
| 226 |
|
| 227 |
hyper_in_list: List[torch.Tensor] = []
|
| 228 |
for i in range(self.num_mask_tokens):
|
| 229 |
-
hyper_in_list.append(
|
| 230 |
-
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
|
| 231 |
-
)
|
| 232 |
hyper_in = torch.stack(hyper_in_list, dim=1)
|
| 233 |
b, c, h, w = upscaled_embedding.shape
|
| 234 |
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
|
|
@@ -267,9 +248,7 @@ class MaskDecoder(nn.Module):
|
|
| 267 |
multimask_logits = all_mask_logits[:, 1:, :, :]
|
| 268 |
multimask_iou_scores = all_iou_scores[:, 1:]
|
| 269 |
best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
|
| 270 |
-
batch_inds = torch.arange(
|
| 271 |
-
multimask_iou_scores.size(0), device=all_iou_scores.device
|
| 272 |
-
)
|
| 273 |
best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
|
| 274 |
best_multimask_logits = best_multimask_logits.unsqueeze(1)
|
| 275 |
best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
|
|
|
|
| 9 |
import torch
|
| 10 |
from torch import nn
|
| 11 |
|
| 12 |
+
from bboxmaskpose.sam2.modeling.sam2_utils import MLP, LayerNorm2d
|
| 13 |
|
| 14 |
|
| 15 |
class MaskDecoder(nn.Module):
|
|
|
|
| 63 |
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
|
| 64 |
|
| 65 |
self.output_upscaling = nn.Sequential(
|
| 66 |
+
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
|
|
|
|
|
|
|
| 67 |
LayerNorm2d(transformer_dim // 4),
|
| 68 |
activation(),
|
| 69 |
+
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
|
|
|
|
|
|
|
| 70 |
activation(),
|
| 71 |
)
|
| 72 |
self.use_high_res_features = use_high_res_features
|
| 73 |
if use_high_res_features:
|
| 74 |
+
self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1)
|
| 75 |
+
self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
self.output_hypernetworks_mlps = nn.ModuleList(
|
| 78 |
+
[MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for i in range(self.num_mask_tokens)]
|
|
|
|
|
|
|
|
|
|
| 79 |
)
|
| 80 |
|
| 81 |
self.iou_prediction_head = MLP(
|
|
|
|
| 177 |
)
|
| 178 |
s = 1
|
| 179 |
else:
|
| 180 |
+
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
|
| 181 |
+
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
|
| 183 |
|
| 184 |
# Expand per-image data in batch direction to be per-mask
|
|
|
|
| 188 |
assert image_embeddings.shape[0] == tokens.shape[0]
|
| 189 |
src = image_embeddings
|
| 190 |
src = src + dense_prompt_embeddings
|
| 191 |
+
assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
|
|
|
|
|
|
|
| 192 |
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
|
| 193 |
b, c, h, w = src.shape
|
| 194 |
|
|
|
|
| 209 |
|
| 210 |
hyper_in_list: List[torch.Tensor] = []
|
| 211 |
for i in range(self.num_mask_tokens):
|
| 212 |
+
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
|
|
|
|
|
|
|
| 213 |
hyper_in = torch.stack(hyper_in_list, dim=1)
|
| 214 |
b, c, h, w = upscaled_embedding.shape
|
| 215 |
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
|
|
|
|
| 248 |
multimask_logits = all_mask_logits[:, 1:, :, :]
|
| 249 |
multimask_iou_scores = all_iou_scores[:, 1:]
|
| 250 |
best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
|
| 251 |
+
batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
|
|
|
|
|
|
|
| 252 |
best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
|
| 253 |
best_multimask_logits = best_multimask_logits.unsqueeze(1)
|
| 254 |
best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
|
{sam2 → bboxmaskpose/sam2}/modeling/sam/pose_encoder.py
RENAMED
|
@@ -9,9 +9,8 @@ from typing import Optional, Tuple, Type
|
|
| 9 |
import torch
|
| 10 |
from torch import nn
|
| 11 |
|
| 12 |
-
from sam2.modeling.position_encoding import PositionEmbeddingRandom
|
| 13 |
-
|
| 14 |
-
from sam2.modeling.sam2_utils import LayerNorm2d
|
| 15 |
|
| 16 |
|
| 17 |
class PoseEncoder(nn.Module):
|
|
@@ -44,9 +43,7 @@ class PoseEncoder(nn.Module):
|
|
| 44 |
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
|
| 45 |
|
| 46 |
self.num_point_embeddings: int = 17 # 17 COCO keypoints
|
| 47 |
-
point_embeddings = [
|
| 48 |
-
nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
|
| 49 |
-
]
|
| 50 |
self.point_embeddings = nn.ModuleList(point_embeddings)
|
| 51 |
self.not_a_point_embed = nn.Embedding(1, embed_dim)
|
| 52 |
|
|
@@ -89,17 +86,12 @@ class PoseEncoder(nn.Module):
|
|
| 89 |
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
|
| 90 |
points = torch.cat([points, padding_point], dim=1)
|
| 91 |
labels = torch.cat([labels, padding_label], dim=1)
|
| 92 |
-
point_embedding = self.pe_layer.forward_with_coords(
|
| 93 |
-
points, self.input_image_size
|
| 94 |
-
)
|
| 95 |
|
| 96 |
kpt_embeddings = torch.cat([self.point_embeddings[i].weight for i in range(self.num_point_embeddings)], dim=0)
|
| 97 |
negative_embedding = torch.zeros_like(point_embedding) + self.not_a_point_embed.weight
|
| 98 |
positive_embedding = point_embedding + kpt_embeddings
|
| 99 |
-
weighted_embedding = (
|
| 100 |
-
positive_embedding * labels.unsqueeze(-1).float() +
|
| 101 |
-
negative_embedding * (1 - labels.unsqueeze(-1).float())
|
| 102 |
-
)
|
| 103 |
|
| 104 |
point_embedding = torch.where(
|
| 105 |
(labels == 0).unsqueeze(-1),
|
|
@@ -112,9 +104,7 @@ class PoseEncoder(nn.Module):
|
|
| 112 |
"""Embeds box prompts."""
|
| 113 |
boxes = boxes + 0.5 # Shift to center of pixel
|
| 114 |
coords = boxes.reshape(-1, 2, 2)
|
| 115 |
-
corner_embedding = self.pe_layer.forward_with_coords(
|
| 116 |
-
coords, self.input_image_size
|
| 117 |
-
)
|
| 118 |
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
|
| 119 |
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
|
| 120 |
return corner_embedding
|
|
@@ -170,9 +160,7 @@ class PoseEncoder(nn.Module):
|
|
| 170 |
Bx(embed_dim)x(embed_H)x(embed_W)
|
| 171 |
"""
|
| 172 |
bs = self._get_batch_size(points, boxes, masks)
|
| 173 |
-
sparse_embeddings = torch.empty(
|
| 174 |
-
(bs, 0, self.embed_dim), device=self._get_device()
|
| 175 |
-
)
|
| 176 |
if points is not None:
|
| 177 |
coords, labels = points
|
| 178 |
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
|
|
|
|
| 9 |
import torch
|
| 10 |
from torch import nn
|
| 11 |
|
| 12 |
+
from bboxmaskpose.sam2.modeling.position_encoding import PositionEmbeddingRandom
|
| 13 |
+
from bboxmaskpose.sam2.modeling.sam2_utils import LayerNorm2d
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
class PoseEncoder(nn.Module):
|
|
|
|
| 43 |
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
|
| 44 |
|
| 45 |
self.num_point_embeddings: int = 17 # 17 COCO keypoints
|
| 46 |
+
point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
|
|
|
|
|
|
|
| 47 |
self.point_embeddings = nn.ModuleList(point_embeddings)
|
| 48 |
self.not_a_point_embed = nn.Embedding(1, embed_dim)
|
| 49 |
|
|
|
|
| 86 |
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
|
| 87 |
points = torch.cat([points, padding_point], dim=1)
|
| 88 |
labels = torch.cat([labels, padding_label], dim=1)
|
| 89 |
+
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
|
|
|
|
|
|
|
| 90 |
|
| 91 |
kpt_embeddings = torch.cat([self.point_embeddings[i].weight for i in range(self.num_point_embeddings)], dim=0)
|
| 92 |
negative_embedding = torch.zeros_like(point_embedding) + self.not_a_point_embed.weight
|
| 93 |
positive_embedding = point_embedding + kpt_embeddings
|
| 94 |
+
weighted_embedding = positive_embedding * labels.unsqueeze(-1).float() + negative_embedding * (1 - labels.unsqueeze(-1).float())
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
point_embedding = torch.where(
|
| 97 |
(labels == 0).unsqueeze(-1),
|
|
|
|
| 104 |
"""Embeds box prompts."""
|
| 105 |
boxes = boxes + 0.5 # Shift to center of pixel
|
| 106 |
coords = boxes.reshape(-1, 2, 2)
|
| 107 |
+
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
|
|
|
|
|
|
|
| 108 |
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
|
| 109 |
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
|
| 110 |
return corner_embedding
|
|
|
|
| 160 |
Bx(embed_dim)x(embed_H)x(embed_W)
|
| 161 |
"""
|
| 162 |
bs = self._get_batch_size(points, boxes, masks)
|
| 163 |
+
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
|
|
|
|
|
|
|
| 164 |
if points is not None:
|
| 165 |
coords, labels = points
|
| 166 |
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
|
{sam2 → bboxmaskpose/sam2}/modeling/sam/prompt_encoder.py
RENAMED
|
@@ -6,12 +6,12 @@
|
|
| 6 |
|
| 7 |
from typing import Optional, Tuple, Type
|
| 8 |
|
|
|
|
| 9 |
import torch
|
| 10 |
from torch import nn
|
| 11 |
|
| 12 |
-
from sam2.modeling.position_encoding import PositionEmbeddingRandom
|
| 13 |
-
|
| 14 |
-
from sam2.modeling.sam2_utils import LayerNorm2d
|
| 15 |
|
| 16 |
|
| 17 |
class PromptEncoder(nn.Module):
|
|
@@ -22,6 +22,7 @@ class PromptEncoder(nn.Module):
|
|
| 22 |
input_image_size: Tuple[int, int],
|
| 23 |
mask_in_chans: int,
|
| 24 |
activation: Type[nn.Module] = nn.GELU,
|
|
|
|
| 25 |
) -> None:
|
| 26 |
"""
|
| 27 |
Encodes prompts for input to SAM's mask decoder.
|
|
@@ -44,9 +45,7 @@ class PromptEncoder(nn.Module):
|
|
| 44 |
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
|
| 45 |
|
| 46 |
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
|
| 47 |
-
point_embeddings = [
|
| 48 |
-
nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
|
| 49 |
-
]
|
| 50 |
self.point_embeddings = nn.ModuleList(point_embeddings)
|
| 51 |
self.not_a_point_embed = nn.Embedding(1, embed_dim)
|
| 52 |
|
|
@@ -63,6 +62,7 @@ class PromptEncoder(nn.Module):
|
|
| 63 |
activation(),
|
| 64 |
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
|
| 65 |
)
|
|
|
|
| 66 |
self.no_mask_embed = nn.Embedding(1, embed_dim)
|
| 67 |
|
| 68 |
def get_dense_pe(self) -> torch.Tensor:
|
|
@@ -76,45 +76,41 @@ class PromptEncoder(nn.Module):
|
|
| 76 |
"""
|
| 77 |
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
|
| 78 |
|
| 79 |
-
def _embed_points(
|
| 80 |
-
self,
|
| 81 |
-
points: torch.Tensor,
|
| 82 |
-
labels: torch.Tensor,
|
| 83 |
-
pad: bool,
|
| 84 |
) -> torch.Tensor:
|
| 85 |
"""Embeds point prompts."""
|
|
|
|
| 86 |
points = points + 0.5 # Shift to center of pixel
|
| 87 |
if pad:
|
| 88 |
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
|
| 89 |
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
|
| 90 |
points = torch.cat([points, padding_point], dim=1)
|
| 91 |
labels = torch.cat([labels, padding_label], dim=1)
|
| 92 |
-
point_embedding = self.pe_layer.forward_with_coords(
|
| 93 |
-
points, self.input_image_size
|
| 94 |
-
)
|
| 95 |
|
|
|
|
| 96 |
point_embedding = torch.where(
|
| 97 |
(labels == -1).unsqueeze(-1),
|
| 98 |
torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
|
| 99 |
point_embedding,
|
| 100 |
)
|
| 101 |
-
point_embedding = torch.where(
|
| 102 |
(labels == 0).unsqueeze(-1),
|
| 103 |
point_embedding + self.point_embeddings[0].weight,
|
| 104 |
point_embedding,
|
| 105 |
)
|
| 106 |
point_embedding = torch.where(
|
| 107 |
-
(labels == 1).unsqueeze(-1),
|
| 108 |
point_embedding + self.point_embeddings[1].weight,
|
| 109 |
point_embedding,
|
| 110 |
)
|
| 111 |
point_embedding = torch.where(
|
| 112 |
-
(labels == 2).unsqueeze(-1),
|
| 113 |
point_embedding + self.point_embeddings[2].weight,
|
| 114 |
point_embedding,
|
| 115 |
)
|
| 116 |
point_embedding = torch.where(
|
| 117 |
-
(labels == 3).unsqueeze(-1),
|
| 118 |
point_embedding + self.point_embeddings[3].weight,
|
| 119 |
point_embedding,
|
| 120 |
)
|
|
@@ -124,9 +120,7 @@ class PromptEncoder(nn.Module):
|
|
| 124 |
"""Embeds box prompts."""
|
| 125 |
boxes = boxes + 0.5 # Shift to center of pixel
|
| 126 |
coords = boxes.reshape(-1, 2, 2)
|
| 127 |
-
corner_embedding = self.pe_layer.forward_with_coords(
|
| 128 |
-
coords, self.input_image_size
|
| 129 |
-
)
|
| 130 |
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
|
| 131 |
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
|
| 132 |
return corner_embedding
|
|
@@ -160,9 +154,9 @@ class PromptEncoder(nn.Module):
|
|
| 160 |
def forward(
|
| 161 |
self,
|
| 162 |
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
| 163 |
-
# skeletons: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
| 164 |
boxes: Optional[torch.Tensor],
|
| 165 |
masks: Optional[torch.Tensor],
|
|
|
|
| 166 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 167 |
"""
|
| 168 |
Embeds different types of prompts, returning both sparse and dense
|
|
@@ -182,12 +176,13 @@ class PromptEncoder(nn.Module):
|
|
| 182 |
Bx(embed_dim)x(embed_H)x(embed_W)
|
| 183 |
"""
|
| 184 |
bs = self._get_batch_size(points, boxes, masks)
|
| 185 |
-
sparse_embeddings = torch.empty(
|
| 186 |
-
(bs, 0, self.embed_dim), device=self._get_device()
|
| 187 |
-
)
|
| 188 |
if points is not None:
|
| 189 |
coords, labels = points
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
| 191 |
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
|
| 192 |
if boxes is not None:
|
| 193 |
box_embeddings = self._embed_boxes(boxes)
|
|
|
|
| 6 |
|
| 7 |
from typing import Optional, Tuple, Type
|
| 8 |
|
| 9 |
+
import numpy as np
|
| 10 |
import torch
|
| 11 |
from torch import nn
|
| 12 |
|
| 13 |
+
from bboxmaskpose.sam2.modeling.position_encoding import PositionEmbeddingRandom
|
| 14 |
+
from bboxmaskpose.sam2.modeling.sam2_utils import LayerNorm2d
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
class PromptEncoder(nn.Module):
|
|
|
|
| 22 |
input_image_size: Tuple[int, int],
|
| 23 |
mask_in_chans: int,
|
| 24 |
activation: Type[nn.Module] = nn.GELU,
|
| 25 |
+
n_kpts_encoder: int = -1,
|
| 26 |
) -> None:
|
| 27 |
"""
|
| 28 |
Encodes prompts for input to SAM's mask decoder.
|
|
|
|
| 45 |
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
|
| 46 |
|
| 47 |
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
|
| 48 |
+
point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
|
|
|
|
|
|
|
| 49 |
self.point_embeddings = nn.ModuleList(point_embeddings)
|
| 50 |
self.not_a_point_embed = nn.Embedding(1, embed_dim)
|
| 51 |
|
|
|
|
| 62 |
activation(),
|
| 63 |
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
|
| 64 |
)
|
| 65 |
+
self.n_kpts_encoder = n_kpts_encoder
|
| 66 |
self.no_mask_embed = nn.Embedding(1, embed_dim)
|
| 67 |
|
| 68 |
def get_dense_pe(self) -> torch.Tensor:
|
|
|
|
| 76 |
"""
|
| 77 |
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
|
| 78 |
|
| 79 |
+
def _embed_points( ## embeds the points into a high-dimensional space (e.g., 256-dim) using learned embeddings
|
| 80 |
+
self, points: torch.Tensor, labels: torch.Tensor, pad: bool, normalize: bool
|
|
|
|
|
|
|
|
|
|
| 81 |
) -> torch.Tensor:
|
| 82 |
"""Embeds point prompts."""
|
| 83 |
+
# print("EMBED points ", points) # KPTS OUTPUT
|
| 84 |
points = points + 0.5 # Shift to center of pixel
|
| 85 |
if pad:
|
| 86 |
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
|
| 87 |
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
|
| 88 |
points = torch.cat([points, padding_point], dim=1)
|
| 89 |
labels = torch.cat([labels, padding_label], dim=1)
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
+
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
|
| 92 |
point_embedding = torch.where(
|
| 93 |
(labels == -1).unsqueeze(-1),
|
| 94 |
torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
|
| 95 |
point_embedding,
|
| 96 |
)
|
| 97 |
+
point_embedding = torch.where( ## negative pts
|
| 98 |
(labels == 0).unsqueeze(-1),
|
| 99 |
point_embedding + self.point_embeddings[0].weight,
|
| 100 |
point_embedding,
|
| 101 |
)
|
| 102 |
point_embedding = torch.where(
|
| 103 |
+
(labels == 1).unsqueeze(-1), ## positive pts
|
| 104 |
point_embedding + self.point_embeddings[1].weight,
|
| 105 |
point_embedding,
|
| 106 |
)
|
| 107 |
point_embedding = torch.where(
|
| 108 |
+
(labels == 2).unsqueeze(-1), ## bbox top left
|
| 109 |
point_embedding + self.point_embeddings[2].weight,
|
| 110 |
point_embedding,
|
| 111 |
)
|
| 112 |
point_embedding = torch.where(
|
| 113 |
+
(labels == 3).unsqueeze(-1), ## bbox bottom right
|
| 114 |
point_embedding + self.point_embeddings[3].weight,
|
| 115 |
point_embedding,
|
| 116 |
)
|
|
|
|
| 120 |
"""Embeds box prompts."""
|
| 121 |
boxes = boxes + 0.5 # Shift to center of pixel
|
| 122 |
coords = boxes.reshape(-1, 2, 2)
|
| 123 |
+
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
|
|
|
|
|
|
|
| 124 |
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
|
| 125 |
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
|
| 126 |
return corner_embedding
|
|
|
|
| 154 |
def forward(
|
| 155 |
self,
|
| 156 |
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
|
|
|
| 157 |
boxes: Optional[torch.Tensor],
|
| 158 |
masks: Optional[torch.Tensor],
|
| 159 |
+
normalize: bool = True,
|
| 160 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 161 |
"""
|
| 162 |
Embeds different types of prompts, returning both sparse and dense
|
|
|
|
| 176 |
Bx(embed_dim)x(embed_H)x(embed_W)
|
| 177 |
"""
|
| 178 |
bs = self._get_batch_size(points, boxes, masks)
|
| 179 |
+
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
|
|
|
|
|
|
|
| 180 |
if points is not None:
|
| 181 |
coords, labels = points
|
| 182 |
+
coords = coords.to(self._get_device())
|
| 183 |
+
labels = labels.to(self._get_device())
|
| 184 |
+
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None), normalize=normalize)
|
| 185 |
+
# point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
|
| 186 |
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
|
| 187 |
if boxes is not None:
|
| 188 |
box_embeddings = self._embed_boxes(boxes)
|
{sam2 → bboxmaskpose/sam2}/modeling/sam/transformer.py
RENAMED
|
@@ -10,10 +10,10 @@ from typing import Tuple, Type
|
|
| 10 |
|
| 11 |
import torch
|
| 12 |
import torch.nn.functional as F
|
| 13 |
-
from torch import
|
| 14 |
|
| 15 |
-
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
|
| 16 |
-
from sam2.modeling.sam2_utils import MLP
|
| 17 |
|
| 18 |
|
| 19 |
class TwoWayTransformer(nn.Module):
|
|
@@ -57,9 +57,7 @@ class TwoWayTransformer(nn.Module):
|
|
| 57 |
)
|
| 58 |
)
|
| 59 |
|
| 60 |
-
self.final_attn_token_to_image = Attention(
|
| 61 |
-
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
|
| 62 |
-
)
|
| 63 |
self.norm_final_attn = nn.LayerNorm(embedding_dim)
|
| 64 |
|
| 65 |
def forward(
|
|
@@ -136,26 +134,18 @@ class TwoWayAttentionBlock(nn.Module):
|
|
| 136 |
self.self_attn = Attention(embedding_dim, num_heads)
|
| 137 |
self.norm1 = nn.LayerNorm(embedding_dim)
|
| 138 |
|
| 139 |
-
self.cross_attn_token_to_image = Attention(
|
| 140 |
-
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
|
| 141 |
-
)
|
| 142 |
self.norm2 = nn.LayerNorm(embedding_dim)
|
| 143 |
|
| 144 |
-
self.mlp = MLP(
|
| 145 |
-
embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
|
| 146 |
-
)
|
| 147 |
self.norm3 = nn.LayerNorm(embedding_dim)
|
| 148 |
|
| 149 |
self.norm4 = nn.LayerNorm(embedding_dim)
|
| 150 |
-
self.cross_attn_image_to_token = Attention(
|
| 151 |
-
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
|
| 152 |
-
)
|
| 153 |
|
| 154 |
self.skip_first_layer_pe = skip_first_layer_pe
|
| 155 |
|
| 156 |
-
def forward(
|
| 157 |
-
self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
|
| 158 |
-
) -> Tuple[Tensor, Tensor]:
|
| 159 |
# Self attention block
|
| 160 |
if self.skip_first_layer_pe:
|
| 161 |
queries = self.self_attn(q=queries, k=queries, v=queries)
|
|
@@ -206,9 +196,7 @@ class Attention(nn.Module):
|
|
| 206 |
self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
|
| 207 |
self.internal_dim = embedding_dim // downsample_rate
|
| 208 |
self.num_heads = num_heads
|
| 209 |
-
assert
|
| 210 |
-
self.internal_dim % num_heads == 0
|
| 211 |
-
), "num_heads must divide embedding_dim."
|
| 212 |
|
| 213 |
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
|
| 214 |
self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
|
|
@@ -263,18 +251,12 @@ class RoPEAttention(Attention):
|
|
| 263 |
):
|
| 264 |
super().__init__(*args, **kwargs)
|
| 265 |
|
| 266 |
-
self.compute_cis = partial(
|
| 267 |
-
compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
|
| 268 |
-
)
|
| 269 |
freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
|
| 270 |
-
self.freqs_cis = (
|
| 271 |
-
freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis
|
| 272 |
-
)
|
| 273 |
self.rope_k_repeat = rope_k_repeat
|
| 274 |
|
| 275 |
-
def forward(
|
| 276 |
-
self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
|
| 277 |
-
) -> Tensor:
|
| 278 |
# Input projections
|
| 279 |
q = self.q_proj(q)
|
| 280 |
k = self.k_proj(k)
|
|
|
|
| 10 |
|
| 11 |
import torch
|
| 12 |
import torch.nn.functional as F
|
| 13 |
+
from torch import Tensor, nn
|
| 14 |
|
| 15 |
+
from bboxmaskpose.sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
|
| 16 |
+
from bboxmaskpose.sam2.modeling.sam2_utils import MLP
|
| 17 |
|
| 18 |
|
| 19 |
class TwoWayTransformer(nn.Module):
|
|
|
|
| 57 |
)
|
| 58 |
)
|
| 59 |
|
| 60 |
+
self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
|
|
|
|
|
|
|
| 61 |
self.norm_final_attn = nn.LayerNorm(embedding_dim)
|
| 62 |
|
| 63 |
def forward(
|
|
|
|
| 134 |
self.self_attn = Attention(embedding_dim, num_heads)
|
| 135 |
self.norm1 = nn.LayerNorm(embedding_dim)
|
| 136 |
|
| 137 |
+
self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
|
|
|
|
|
|
|
| 138 |
self.norm2 = nn.LayerNorm(embedding_dim)
|
| 139 |
|
| 140 |
+
self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation)
|
|
|
|
|
|
|
| 141 |
self.norm3 = nn.LayerNorm(embedding_dim)
|
| 142 |
|
| 143 |
self.norm4 = nn.LayerNorm(embedding_dim)
|
| 144 |
+
self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
|
|
|
|
|
|
|
| 145 |
|
| 146 |
self.skip_first_layer_pe = skip_first_layer_pe
|
| 147 |
|
| 148 |
+
def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
|
|
|
|
|
|
|
| 149 |
# Self attention block
|
| 150 |
if self.skip_first_layer_pe:
|
| 151 |
queries = self.self_attn(q=queries, k=queries, v=queries)
|
|
|
|
| 196 |
self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
|
| 197 |
self.internal_dim = embedding_dim // downsample_rate
|
| 198 |
self.num_heads = num_heads
|
| 199 |
+
assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
|
|
|
|
|
|
|
| 200 |
|
| 201 |
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
|
| 202 |
self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
|
|
|
|
| 251 |
):
|
| 252 |
super().__init__(*args, **kwargs)
|
| 253 |
|
| 254 |
+
self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta)
|
|
|
|
|
|
|
| 255 |
freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
|
| 256 |
+
self.freqs_cis = freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis
|
|
|
|
|
|
|
| 257 |
self.rope_k_repeat = rope_k_repeat
|
| 258 |
|
| 259 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor:
|
|
|
|
|
|
|
| 260 |
# Input projections
|
| 261 |
q = self.q_proj(q)
|
| 262 |
k = self.k_proj(k)
|
{sam2 → bboxmaskpose/sam2}/modeling/sam2_base.py
RENAMED
|
@@ -4,20 +4,15 @@
|
|
| 4 |
# This source code is licensed under the license found in the
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
-
from loguru import logger
|
| 8 |
-
|
| 9 |
import torch
|
| 10 |
import torch.distributed
|
| 11 |
import torch.nn.functional as F
|
| 12 |
-
|
| 13 |
from torch.nn.init import trunc_normal_
|
| 14 |
|
| 15 |
-
from sam2.modeling.
|
| 16 |
-
from sam2.modeling.sam.
|
| 17 |
-
from sam2.modeling.sam.
|
| 18 |
-
from sam2.modeling.
|
| 19 |
-
|
| 20 |
-
from sam2.utils.kalman_filter import KalmanFilter
|
| 21 |
|
| 22 |
# a large negative value as a placeholder score for missing objects
|
| 23 |
NO_OBJ_SCORE = -1024.0
|
|
@@ -97,19 +92,10 @@ class SAM2Base(torch.nn.Module):
|
|
| 97 |
# extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
|
| 98 |
sam_mask_decoder_extra_args=None,
|
| 99 |
compile_image_encoder: bool = False,
|
| 100 |
-
|
| 101 |
-
samurai_mode: bool = False,
|
| 102 |
-
# Hyperparameters for SAMURAI
|
| 103 |
-
stable_frames_threshold: int = 15,
|
| 104 |
-
stable_ious_threshold: float = 0.3,
|
| 105 |
-
min_obj_score_logits: float = -1,
|
| 106 |
-
kf_score_weight: float = 0.15,
|
| 107 |
-
memory_bank_iou_threshold: float = 0.5,
|
| 108 |
-
memory_bank_obj_score_threshold: float = 0.0,
|
| 109 |
-
memory_bank_kf_score_threshold: float = 0.0,
|
| 110 |
):
|
| 111 |
super().__init__()
|
| 112 |
-
|
| 113 |
# Part 1: the image backbone
|
| 114 |
self.image_encoder = image_encoder
|
| 115 |
# Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
|
|
@@ -137,16 +123,12 @@ class SAM2Base(torch.nn.Module):
|
|
| 137 |
# Part 3: memory encoder for the previous frame's outputs
|
| 138 |
self.memory_encoder = memory_encoder
|
| 139 |
self.mem_dim = self.hidden_dim
|
| 140 |
-
if hasattr(self.memory_encoder, "out_proj") and hasattr(
|
| 141 |
-
self.memory_encoder.out_proj, "weight"
|
| 142 |
-
):
|
| 143 |
# if there is compression of memories along channel dim
|
| 144 |
self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
|
| 145 |
self.num_maskmem = num_maskmem # Number of memories accessible
|
| 146 |
# Temporal encoding of the memories
|
| 147 |
-
self.maskmem_tpos_enc = torch.nn.Parameter(
|
| 148 |
-
torch.zeros(num_maskmem, 1, 1, self.mem_dim)
|
| 149 |
-
)
|
| 150 |
trunc_normal_(self.maskmem_tpos_enc, std=0.02)
|
| 151 |
# a single token to indicate no memory embedding from previous frames
|
| 152 |
self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
|
|
@@ -194,37 +176,10 @@ class SAM2Base(torch.nn.Module):
|
|
| 194 |
|
| 195 |
self._build_sam_heads()
|
| 196 |
self.max_cond_frames_in_attn = max_cond_frames_in_attn
|
| 197 |
-
|
| 198 |
-
# Whether to use SAMURAI or original SAM 2
|
| 199 |
-
self.samurai_mode = samurai_mode
|
| 200 |
-
|
| 201 |
-
# Init Kalman Filter
|
| 202 |
-
self.kf = KalmanFilter()
|
| 203 |
-
self.kf_mean = None
|
| 204 |
-
self.kf_covariance = None
|
| 205 |
-
self.stable_frames = 0
|
| 206 |
-
|
| 207 |
-
# Debug purpose
|
| 208 |
-
self.history = {} # debug
|
| 209 |
-
self.frame_cnt = 0 # debug
|
| 210 |
-
|
| 211 |
-
# Hyperparameters for SAMURAI
|
| 212 |
-
self.stable_frames_threshold = stable_frames_threshold
|
| 213 |
-
self.stable_ious_threshold = stable_ious_threshold
|
| 214 |
-
self.min_obj_score_logits = min_obj_score_logits
|
| 215 |
-
self.kf_score_weight = kf_score_weight
|
| 216 |
-
self.memory_bank_iou_threshold = memory_bank_iou_threshold
|
| 217 |
-
self.memory_bank_obj_score_threshold = memory_bank_obj_score_threshold
|
| 218 |
-
self.memory_bank_kf_score_threshold = memory_bank_kf_score_threshold
|
| 219 |
-
|
| 220 |
-
print(f"\033[93mSAMURAI mode: {self.samurai_mode}\033[0m")
|
| 221 |
-
|
| 222 |
# Model compilation
|
| 223 |
if compile_image_encoder:
|
| 224 |
# Compile the forward function (not the full module) to allow loading checkpoints.
|
| 225 |
-
print(
|
| 226 |
-
"Image encoder compilation is enabled. First forward pass will be slow."
|
| 227 |
-
)
|
| 228 |
self.image_encoder.forward = torch.compile(
|
| 229 |
self.image_encoder.forward,
|
| 230 |
mode="max-autotune",
|
|
@@ -232,6 +187,15 @@ class SAM2Base(torch.nn.Module):
|
|
| 232 |
dynamic=False,
|
| 233 |
)
|
| 234 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
@property
|
| 236 |
def device(self):
|
| 237 |
return next(self.parameters()).device
|
|
@@ -257,7 +221,9 @@ class SAM2Base(torch.nn.Module):
|
|
| 257 |
),
|
| 258 |
input_image_size=(self.image_size, self.image_size),
|
| 259 |
mask_in_chans=16,
|
|
|
|
| 260 |
)
|
|
|
|
| 261 |
self.sam_mask_decoder = MaskDecoder(
|
| 262 |
num_multimask_outputs=3,
|
| 263 |
transformer=TwoWayTransformer(
|
|
@@ -276,13 +242,16 @@ class SAM2Base(torch.nn.Module):
|
|
| 276 |
use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
|
| 277 |
**(self.sam_mask_decoder_extra_args or {}),
|
| 278 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
if self.use_obj_ptrs_in_encoder:
|
| 280 |
# a linear projection on SAM output tokens to turn them into object pointers
|
| 281 |
self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
|
| 282 |
if self.use_mlp_for_obj_ptr_proj:
|
| 283 |
-
self.obj_ptr_proj = MLP(
|
| 284 |
-
self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
|
| 285 |
-
)
|
| 286 |
else:
|
| 287 |
self.obj_ptr_proj = torch.nn.Identity()
|
| 288 |
if self.proj_tpos_enc_in_obj_ptrs:
|
|
@@ -395,7 +364,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 395 |
high_res_features=high_res_features,
|
| 396 |
)
|
| 397 |
if self.pred_obj_scores:
|
| 398 |
-
is_obj_appearing = object_score_logits >
|
| 399 |
|
| 400 |
# Mask used for spatial memories is always a *hard* choice between obj and no obj,
|
| 401 |
# consistent with the actual mask prediction
|
|
@@ -416,87 +385,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 416 |
)
|
| 417 |
|
| 418 |
sam_output_token = sam_output_tokens[:, 0]
|
| 419 |
-
|
| 420 |
-
if multimask_output and self.samurai_mode:
|
| 421 |
-
if self.kf_mean is None and self.kf_covariance is None or self.stable_frames == 0:
|
| 422 |
-
best_iou_inds = torch.argmax(ious, dim=-1)
|
| 423 |
-
batch_inds = torch.arange(B, device=device)
|
| 424 |
-
low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
|
| 425 |
-
high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
|
| 426 |
-
non_zero_indices = torch.argwhere(high_res_masks[0][0] > 0.0)
|
| 427 |
-
if len(non_zero_indices) == 0:
|
| 428 |
-
high_res_bbox = [0, 0, 0, 0]
|
| 429 |
-
else:
|
| 430 |
-
y_min, x_min = non_zero_indices.min(dim=0).values
|
| 431 |
-
y_max, x_max = non_zero_indices.max(dim=0).values
|
| 432 |
-
high_res_bbox = [x_min.item(), y_min.item(), x_max.item(), y_max.item()]
|
| 433 |
-
self.kf_mean, self.kf_covariance = self.kf.initiate(self.kf.xyxy_to_xyah(high_res_bbox))
|
| 434 |
-
if sam_output_tokens.size(1) > 1:
|
| 435 |
-
sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
|
| 436 |
-
self.frame_cnt += 1
|
| 437 |
-
self.stable_frames += 1
|
| 438 |
-
elif self.stable_frames < self.stable_frames_threshold:
|
| 439 |
-
self.kf_mean, self.kf_covariance = self.kf.predict(self.kf_mean, self.kf_covariance)
|
| 440 |
-
best_iou_inds = torch.argmax(ious, dim=-1)
|
| 441 |
-
batch_inds = torch.arange(B, device=device)
|
| 442 |
-
low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
|
| 443 |
-
high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
|
| 444 |
-
non_zero_indices = torch.argwhere(high_res_masks[0][0] > 0.0)
|
| 445 |
-
if len(non_zero_indices) == 0:
|
| 446 |
-
high_res_bbox = [0, 0, 0, 0]
|
| 447 |
-
else:
|
| 448 |
-
y_min, x_min = non_zero_indices.min(dim=0).values
|
| 449 |
-
y_max, x_max = non_zero_indices.max(dim=0).values
|
| 450 |
-
high_res_bbox = [x_min.item(), y_min.item(), x_max.item(), y_max.item()]
|
| 451 |
-
if ious[0][best_iou_inds] > self.stable_ious_threshold:
|
| 452 |
-
self.kf_mean, self.kf_covariance = self.kf.update(self.kf_mean, self.kf_covariance, self.kf.xyxy_to_xyah(high_res_bbox))
|
| 453 |
-
self.stable_frames += 1
|
| 454 |
-
else:
|
| 455 |
-
self.stable_frames = 0
|
| 456 |
-
if sam_output_tokens.size(1) > 1:
|
| 457 |
-
sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
|
| 458 |
-
self.frame_cnt += 1
|
| 459 |
-
else:
|
| 460 |
-
self.kf_mean, self.kf_covariance = self.kf.predict(self.kf_mean, self.kf_covariance)
|
| 461 |
-
high_res_multibboxes = []
|
| 462 |
-
batch_inds = torch.arange(B, device=device)
|
| 463 |
-
for i in range(ious.shape[1]):
|
| 464 |
-
non_zero_indices = torch.argwhere(high_res_multimasks[batch_inds, i].unsqueeze(1)[0][0] > 0.0)
|
| 465 |
-
if len(non_zero_indices) == 0:
|
| 466 |
-
high_res_multibboxes.append([0, 0, 0, 0])
|
| 467 |
-
else:
|
| 468 |
-
y_min, x_min = non_zero_indices.min(dim=0).values
|
| 469 |
-
y_max, x_max = non_zero_indices.max(dim=0).values
|
| 470 |
-
high_res_multibboxes.append([x_min.item(), y_min.item(), x_max.item(), y_max.item()])
|
| 471 |
-
# compute the IoU between the predicted bbox and the high_res_multibboxes
|
| 472 |
-
kf_ious = torch.tensor(self.kf.compute_iou(self.kf_mean[:4], high_res_multibboxes), device=device)
|
| 473 |
-
# weighted iou
|
| 474 |
-
weighted_ious = self.kf_score_weight * kf_ious + (1 - self.kf_score_weight) * ious
|
| 475 |
-
best_iou_inds = torch.argmax(weighted_ious, dim=-1)
|
| 476 |
-
batch_inds = torch.arange(B, device=device)
|
| 477 |
-
low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
|
| 478 |
-
high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
|
| 479 |
-
if sam_output_tokens.size(1) > 1:
|
| 480 |
-
sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
|
| 481 |
-
|
| 482 |
-
if False:
|
| 483 |
-
# make all these on cpu
|
| 484 |
-
self.history[self.frame_cnt] = {
|
| 485 |
-
"kf_predicted_bbox": self.kf.xyah_to_xyxy(self.kf_mean[:4]),
|
| 486 |
-
# "multi_masks": high_res_multimasks.cpu(),
|
| 487 |
-
"ious": ious.cpu(),
|
| 488 |
-
"multi_bboxes": high_res_multibboxes,
|
| 489 |
-
"kf_ious": kf_ious,
|
| 490 |
-
"weighted_ious": weighted_ious.cpu(),
|
| 491 |
-
"final_selection": best_iou_inds.cpu(),
|
| 492 |
-
}
|
| 493 |
-
self.frame_cnt += 1
|
| 494 |
-
|
| 495 |
-
if ious[0][best_iou_inds] < self.stable_ious_threshold:
|
| 496 |
-
self.stable_frames = 0
|
| 497 |
-
else:
|
| 498 |
-
self.kf_mean, self.kf_covariance = self.kf.update(self.kf_mean, self.kf_covariance, self.kf.xyxy_to_xyah(high_res_multibboxes[best_iou_inds]))
|
| 499 |
-
elif multimask_output and not self.samurai_mode:
|
| 500 |
# take the best mask prediction (with the highest IoU estimation)
|
| 501 |
best_iou_inds = torch.argmax(ious, dim=-1)
|
| 502 |
batch_inds = torch.arange(B, device=device)
|
|
@@ -505,7 +394,6 @@ class SAM2Base(torch.nn.Module):
|
|
| 505 |
if sam_output_tokens.size(1) > 1:
|
| 506 |
sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
|
| 507 |
else:
|
| 508 |
-
best_iou_inds = 0
|
| 509 |
low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
|
| 510 |
|
| 511 |
# Extract object pointer from the SAM output token (with occlusion handling)
|
|
@@ -529,8 +417,6 @@ class SAM2Base(torch.nn.Module):
|
|
| 529 |
high_res_masks,
|
| 530 |
obj_ptr,
|
| 531 |
object_score_logits,
|
| 532 |
-
ious[0][best_iou_inds],
|
| 533 |
-
kf_ious[best_iou_inds] if kf_ious is not None else None,
|
| 534 |
)
|
| 535 |
|
| 536 |
def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
|
|
@@ -553,12 +439,10 @@ class SAM2Base(torch.nn.Module):
|
|
| 553 |
ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
|
| 554 |
if not self.use_obj_ptrs_in_encoder:
|
| 555 |
# all zeros as a dummy object pointer (of shape [B, C])
|
| 556 |
-
obj_ptr = torch.zeros(
|
| 557 |
-
mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device
|
| 558 |
-
)
|
| 559 |
else:
|
| 560 |
# produce an object pointer using the SAM decoder from the mask input
|
| 561 |
-
_, _, _, _, _, obj_ptr, _
|
| 562 |
backbone_features=backbone_features,
|
| 563 |
mask_inputs=self.mask_downsample(mask_inputs_float),
|
| 564 |
high_res_features=high_res_features,
|
|
@@ -591,12 +475,8 @@ class SAM2Base(torch.nn.Module):
|
|
| 591 |
if self.use_high_res_features_in_sam:
|
| 592 |
# precompute projected level 0 and level 1 features in SAM decoder
|
| 593 |
# to avoid running it again on every SAM click
|
| 594 |
-
backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
|
| 595 |
-
|
| 596 |
-
)
|
| 597 |
-
backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
|
| 598 |
-
backbone_out["backbone_fpn"][1]
|
| 599 |
-
)
|
| 600 |
return backbone_out
|
| 601 |
|
| 602 |
def _prepare_backbone_features(self, backbone_out):
|
|
@@ -657,63 +537,36 @@ class SAM2Base(torch.nn.Module):
|
|
| 657 |
# We also allow taking the memory frame non-consecutively (with stride>1), in which case
|
| 658 |
# we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame.
|
| 659 |
stride = 1 if self.training else self.memory_temporal_stride_for_eval
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
kf_score = output_dict["non_cond_frame_outputs"][i]["kf_score"] if "kf_score" in output_dict["non_cond_frame_outputs"][i] else None # Get motion score if available
|
| 668 |
-
# Check if the scores meet the criteria for being a valid index
|
| 669 |
-
if iou_score.item() > self.memory_bank_iou_threshold and \
|
| 670 |
-
obj_score.item() > self.memory_bank_obj_score_threshold and \
|
| 671 |
-
(kf_score is None or kf_score.item() > self.memory_bank_kf_score_threshold):
|
| 672 |
-
valid_indices.insert(0, i)
|
| 673 |
-
# Check the number of valid indices
|
| 674 |
-
if len(valid_indices) >= self.max_obj_ptrs_in_encoder - 1:
|
| 675 |
-
break
|
| 676 |
-
if frame_idx - 1 not in valid_indices:
|
| 677 |
-
valid_indices.append(frame_idx - 1)
|
| 678 |
-
for t_pos in range(1, self.num_maskmem): # Iterate over the number of mask memories
|
| 679 |
-
idx = t_pos - self.num_maskmem # Calculate the index for valid indices
|
| 680 |
-
if idx < -len(valid_indices): # Skip if index is out of bounds
|
| 681 |
-
continue
|
| 682 |
-
out = output_dict["non_cond_frame_outputs"].get(valid_indices[idx], None) # Get output for the valid index
|
| 683 |
-
if out is None: # If not found, check unselected outputs
|
| 684 |
-
out = unselected_cond_outputs.get(valid_indices[idx], None)
|
| 685 |
-
t_pos_and_prevs.append((t_pos, out)) # Append the temporal position and output to the list
|
| 686 |
-
else:
|
| 687 |
-
for t_pos in range(1, self.num_maskmem):
|
| 688 |
-
t_rel = self.num_maskmem - t_pos # how many frames before current frame
|
| 689 |
-
if t_rel == 1:
|
| 690 |
-
# for t_rel == 1, we take the last frame (regardless of r)
|
| 691 |
-
if not track_in_reverse:
|
| 692 |
-
# the frame immediately before this frame (i.e. frame_idx - 1)
|
| 693 |
-
prev_frame_idx = frame_idx - t_rel
|
| 694 |
-
else:
|
| 695 |
-
# the frame immediately after this frame (i.e. frame_idx + 1)
|
| 696 |
-
prev_frame_idx = frame_idx + t_rel
|
| 697 |
else:
|
| 698 |
-
#
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
|
| 716 |
-
|
|
|
|
|
|
|
|
|
|
| 717 |
|
| 718 |
for t_pos, prev in t_pos_and_prevs:
|
| 719 |
if prev is None:
|
|
@@ -726,9 +579,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 726 |
maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
|
| 727 |
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
|
| 728 |
# Temporal positional encoding
|
| 729 |
-
maskmem_enc =
|
| 730 |
-
maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
|
| 731 |
-
)
|
| 732 |
to_cat_memory_pos_embed.append(maskmem_enc)
|
| 733 |
|
| 734 |
# Construct the list of past object pointers
|
|
@@ -738,20 +589,14 @@ class SAM2Base(torch.nn.Module):
|
|
| 738 |
# (optionally, only include object pointers in the past during evaluation)
|
| 739 |
if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
|
| 740 |
ptr_cond_outputs = {
|
| 741 |
-
t: out
|
| 742 |
-
for t, out in selected_cond_outputs.items()
|
| 743 |
-
if (t >= frame_idx if track_in_reverse else t <= frame_idx)
|
| 744 |
}
|
| 745 |
else:
|
| 746 |
ptr_cond_outputs = selected_cond_outputs
|
| 747 |
pos_and_ptrs = [
|
| 748 |
# Temporal pos encoding contains how far away each pointer is from current frame
|
| 749 |
(
|
| 750 |
-
(
|
| 751 |
-
(frame_idx - t) * tpos_sign_mul
|
| 752 |
-
if self.use_signed_tpos_enc_to_obj_ptrs
|
| 753 |
-
else abs(frame_idx - t)
|
| 754 |
-
),
|
| 755 |
out["obj_ptr"],
|
| 756 |
)
|
| 757 |
for t, out in ptr_cond_outputs.items()
|
|
@@ -761,9 +606,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 761 |
t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
|
| 762 |
if t < 0 or (num_frames is not None and t >= num_frames):
|
| 763 |
break
|
| 764 |
-
out = output_dict["non_cond_frame_outputs"].get(
|
| 765 |
-
t, unselected_cond_outputs.get(t, None)
|
| 766 |
-
)
|
| 767 |
if out is not None:
|
| 768 |
pos_and_ptrs.append((t_diff, out["obj_ptr"]))
|
| 769 |
# If we have at least one object pointer, add them to the across attention
|
|
@@ -776,9 +619,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 776 |
if self.add_tpos_enc_to_obj_ptrs:
|
| 777 |
t_diff_max = max_obj_ptrs_in_encoder - 1
|
| 778 |
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
|
| 779 |
-
obj_pos = torch.tensor(pos_list).to(
|
| 780 |
-
device=device, non_blocking=True
|
| 781 |
-
)
|
| 782 |
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
|
| 783 |
obj_pos = self.obj_ptr_tpos_proj(obj_pos)
|
| 784 |
obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
|
|
@@ -786,9 +627,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 786 |
obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
|
| 787 |
if self.mem_dim < C:
|
| 788 |
# split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
|
| 789 |
-
obj_ptrs = obj_ptrs.reshape(
|
| 790 |
-
-1, B, C // self.mem_dim, self.mem_dim
|
| 791 |
-
)
|
| 792 |
obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
|
| 793 |
obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
|
| 794 |
to_cat_memory.append(obj_ptrs)
|
|
@@ -841,9 +680,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 841 |
# optionally, apply non-overlapping constraints to the masks (it's applied
|
| 842 |
# in the batch dimension and should only be used during eval, where all
|
| 843 |
# the objects come from the same video under batch size 1).
|
| 844 |
-
pred_masks_high_res = self._apply_non_overlapping_constraints(
|
| 845 |
-
pred_masks_high_res
|
| 846 |
-
)
|
| 847 |
# scale the raw mask logits with a temperature before applying sigmoid
|
| 848 |
binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
|
| 849 |
if binarize and not self.training:
|
|
@@ -856,18 +693,14 @@ class SAM2Base(torch.nn.Module):
|
|
| 856 |
mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
|
| 857 |
if self.sigmoid_bias_for_mem_enc != 0.0:
|
| 858 |
mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
|
| 859 |
-
maskmem_out = self.memory_encoder(
|
| 860 |
-
pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied
|
| 861 |
-
)
|
| 862 |
maskmem_features = maskmem_out["vision_features"]
|
| 863 |
maskmem_pos_enc = maskmem_out["vision_pos_enc"]
|
| 864 |
# add a no-object embedding to the spatial memory to indicate that the frame
|
| 865 |
# is predicted to be occluded (i.e. no object is appearing in the frame)
|
| 866 |
if self.no_obj_embed_spatial is not None:
|
| 867 |
is_obj_appearing = (object_score_logits > 0).float()
|
| 868 |
-
maskmem_features += (
|
| 869 |
-
1 - is_obj_appearing[..., None, None]
|
| 870 |
-
) * self.no_obj_embed_spatial[..., None, None].expand(
|
| 871 |
*maskmem_features.shape
|
| 872 |
)
|
| 873 |
|
|
@@ -891,8 +724,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 891 |
# High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
|
| 892 |
if len(current_vision_feats) > 1:
|
| 893 |
high_res_features = [
|
| 894 |
-
x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
|
| 895 |
-
for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
|
| 896 |
]
|
| 897 |
else:
|
| 898 |
high_res_features = None
|
|
@@ -901,9 +733,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 901 |
# (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
|
| 902 |
pix_feat = current_vision_feats[-1].permute(1, 2, 0)
|
| 903 |
pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
|
| 904 |
-
sam_outputs = self._use_mask_as_output(
|
| 905 |
-
pix_feat, high_res_features, mask_inputs
|
| 906 |
-
)
|
| 907 |
else:
|
| 908 |
# fused the visual feature with previous memory features in the memory bank
|
| 909 |
pix_feat = self._prepare_memory_conditioned_features(
|
|
@@ -1002,15 +832,11 @@ class SAM2Base(torch.nn.Module):
|
|
| 1002 |
high_res_masks,
|
| 1003 |
obj_ptr,
|
| 1004 |
object_score_logits,
|
| 1005 |
-
best_iou_score,
|
| 1006 |
-
kf_ious
|
| 1007 |
) = sam_outputs
|
| 1008 |
|
| 1009 |
current_out["pred_masks"] = low_res_masks
|
| 1010 |
current_out["pred_masks_high_res"] = high_res_masks
|
| 1011 |
current_out["obj_ptr"] = obj_ptr
|
| 1012 |
-
current_out["best_iou_score"] = best_iou_score
|
| 1013 |
-
current_out["kf_ious"] = kf_ious
|
| 1014 |
if not self.training:
|
| 1015 |
# Only add this in inference (to avoid unused param in activation checkpointing;
|
| 1016 |
# it's mainly used in the demo to encode spatial memories w/ consolidated masks)
|
|
|
|
| 4 |
# This source code is licensed under the license found in the
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
|
|
|
|
|
|
| 7 |
import torch
|
| 8 |
import torch.distributed
|
| 9 |
import torch.nn.functional as F
|
|
|
|
| 10 |
from torch.nn.init import trunc_normal_
|
| 11 |
|
| 12 |
+
from bboxmaskpose.sam2.modeling.sam2_utils import MLP, get_1d_sine_pe, select_closest_cond_frames
|
| 13 |
+
from bboxmaskpose.sam2.modeling.sam.mask_decoder import MaskDecoder
|
| 14 |
+
from bboxmaskpose.sam2.modeling.sam.prompt_encoder import PromptEncoder
|
| 15 |
+
from bboxmaskpose.sam2.modeling.sam.transformer import TwoWayTransformer
|
|
|
|
|
|
|
| 16 |
|
| 17 |
# a large negative value as a placeholder score for missing objects
|
| 18 |
NO_OBJ_SCORE = -1024.0
|
|
|
|
| 92 |
# extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
|
| 93 |
sam_mask_decoder_extra_args=None,
|
| 94 |
compile_image_encoder: bool = False,
|
| 95 |
+
n_kpts_encoder: int = -1,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
):
|
| 97 |
super().__init__()
|
| 98 |
+
self.n_kpts_encoder = n_kpts_encoder
|
| 99 |
# Part 1: the image backbone
|
| 100 |
self.image_encoder = image_encoder
|
| 101 |
# Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
|
|
|
|
| 123 |
# Part 3: memory encoder for the previous frame's outputs
|
| 124 |
self.memory_encoder = memory_encoder
|
| 125 |
self.mem_dim = self.hidden_dim
|
| 126 |
+
if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
|
|
|
|
|
|
|
| 127 |
# if there is compression of memories along channel dim
|
| 128 |
self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
|
| 129 |
self.num_maskmem = num_maskmem # Number of memories accessible
|
| 130 |
# Temporal encoding of the memories
|
| 131 |
+
self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
|
|
|
|
|
|
|
| 132 |
trunc_normal_(self.maskmem_tpos_enc, std=0.02)
|
| 133 |
# a single token to indicate no memory embedding from previous frames
|
| 134 |
self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
|
|
|
|
| 176 |
|
| 177 |
self._build_sam_heads()
|
| 178 |
self.max_cond_frames_in_attn = max_cond_frames_in_attn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
# Model compilation
|
| 180 |
if compile_image_encoder:
|
| 181 |
# Compile the forward function (not the full module) to allow loading checkpoints.
|
| 182 |
+
print("Image encoder compilation is enabled. First forward pass will be slow.")
|
|
|
|
|
|
|
| 183 |
self.image_encoder.forward = torch.compile(
|
| 184 |
self.image_encoder.forward,
|
| 185 |
mode="max-autotune",
|
|
|
|
| 187 |
dynamic=False,
|
| 188 |
)
|
| 189 |
|
| 190 |
+
freeze_prompt_encoder = False
|
| 191 |
+
freeze_mask_decoder = False
|
| 192 |
+
if freeze_prompt_encoder:
|
| 193 |
+
for p in self.sam_prompt_encoder.parameters():
|
| 194 |
+
p.requires_grad = False
|
| 195 |
+
if freeze_mask_decoder:
|
| 196 |
+
for p in self.sam_mask_decoder.parameters():
|
| 197 |
+
p.requires_grad = False
|
| 198 |
+
|
| 199 |
@property
|
| 200 |
def device(self):
|
| 201 |
return next(self.parameters()).device
|
|
|
|
| 221 |
),
|
| 222 |
input_image_size=(self.image_size, self.image_size),
|
| 223 |
mask_in_chans=16,
|
| 224 |
+
n_kpts_encoder=self.n_kpts_encoder,
|
| 225 |
)
|
| 226 |
+
|
| 227 |
self.sam_mask_decoder = MaskDecoder(
|
| 228 |
num_multimask_outputs=3,
|
| 229 |
transformer=TwoWayTransformer(
|
|
|
|
| 242 |
use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
|
| 243 |
**(self.sam_mask_decoder_extra_args or {}),
|
| 244 |
)
|
| 245 |
+
for p in self.sam_prompt_encoder.parameters():
|
| 246 |
+
p.requires_grad = True
|
| 247 |
+
for p in self.sam_mask_decoder.parameters():
|
| 248 |
+
p.requires_grad = True
|
| 249 |
+
|
| 250 |
if self.use_obj_ptrs_in_encoder:
|
| 251 |
# a linear projection on SAM output tokens to turn them into object pointers
|
| 252 |
self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
|
| 253 |
if self.use_mlp_for_obj_ptr_proj:
|
| 254 |
+
self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
|
|
|
|
|
|
|
| 255 |
else:
|
| 256 |
self.obj_ptr_proj = torch.nn.Identity()
|
| 257 |
if self.proj_tpos_enc_in_obj_ptrs:
|
|
|
|
| 364 |
high_res_features=high_res_features,
|
| 365 |
)
|
| 366 |
if self.pred_obj_scores:
|
| 367 |
+
is_obj_appearing = object_score_logits > 0
|
| 368 |
|
| 369 |
# Mask used for spatial memories is always a *hard* choice between obj and no obj,
|
| 370 |
# consistent with the actual mask prediction
|
|
|
|
| 385 |
)
|
| 386 |
|
| 387 |
sam_output_token = sam_output_tokens[:, 0]
|
| 388 |
+
if multimask_output:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
# take the best mask prediction (with the highest IoU estimation)
|
| 390 |
best_iou_inds = torch.argmax(ious, dim=-1)
|
| 391 |
batch_inds = torch.arange(B, device=device)
|
|
|
|
| 394 |
if sam_output_tokens.size(1) > 1:
|
| 395 |
sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
|
| 396 |
else:
|
|
|
|
| 397 |
low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
|
| 398 |
|
| 399 |
# Extract object pointer from the SAM output token (with occlusion handling)
|
|
|
|
| 417 |
high_res_masks,
|
| 418 |
obj_ptr,
|
| 419 |
object_score_logits,
|
|
|
|
|
|
|
| 420 |
)
|
| 421 |
|
| 422 |
def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
|
|
|
|
| 439 |
ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
|
| 440 |
if not self.use_obj_ptrs_in_encoder:
|
| 441 |
# all zeros as a dummy object pointer (of shape [B, C])
|
| 442 |
+
obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
|
|
|
|
|
|
|
| 443 |
else:
|
| 444 |
# produce an object pointer using the SAM decoder from the mask input
|
| 445 |
+
_, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
|
| 446 |
backbone_features=backbone_features,
|
| 447 |
mask_inputs=self.mask_downsample(mask_inputs_float),
|
| 448 |
high_res_features=high_res_features,
|
|
|
|
| 475 |
if self.use_high_res_features_in_sam:
|
| 476 |
# precompute projected level 0 and level 1 features in SAM decoder
|
| 477 |
# to avoid running it again on every SAM click
|
| 478 |
+
backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
|
| 479 |
+
backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 480 |
return backbone_out
|
| 481 |
|
| 482 |
def _prepare_backbone_features(self, backbone_out):
|
|
|
|
| 537 |
# We also allow taking the memory frame non-consecutively (with stride>1), in which case
|
| 538 |
# we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame.
|
| 539 |
stride = 1 if self.training else self.memory_temporal_stride_for_eval
|
| 540 |
+
for t_pos in range(1, self.num_maskmem):
|
| 541 |
+
t_rel = self.num_maskmem - t_pos # how many frames before current frame
|
| 542 |
+
if t_rel == 1:
|
| 543 |
+
# for t_rel == 1, we take the last frame (regardless of r)
|
| 544 |
+
if not track_in_reverse:
|
| 545 |
+
# the frame immediately before this frame (i.e. frame_idx - 1)
|
| 546 |
+
prev_frame_idx = frame_idx - t_rel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
else:
|
| 548 |
+
# the frame immediately after this frame (i.e. frame_idx + 1)
|
| 549 |
+
prev_frame_idx = frame_idx + t_rel
|
| 550 |
+
else:
|
| 551 |
+
# for t_rel >= 2, we take the memory frame from every r-th frames
|
| 552 |
+
if not track_in_reverse:
|
| 553 |
+
# first find the nearest frame among every r-th frames before this frame
|
| 554 |
+
# for r=1, this would be (frame_idx - 2)
|
| 555 |
+
prev_frame_idx = ((frame_idx - 2) // stride) * stride
|
| 556 |
+
# then seek further among every r-th frames
|
| 557 |
+
prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride
|
| 558 |
+
else:
|
| 559 |
+
# first find the nearest frame among every r-th frames after this frame
|
| 560 |
+
# for r=1, this would be (frame_idx + 2)
|
| 561 |
+
prev_frame_idx = -(-(frame_idx + 2) // stride) * stride
|
| 562 |
+
# then seek further among every r-th frames
|
| 563 |
+
prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride
|
| 564 |
+
out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
|
| 565 |
+
if out is None:
|
| 566 |
+
# If an unselected conditioning frame is among the last (self.num_maskmem - 1)
|
| 567 |
+
# frames, we still attend to it as if it's a non-conditioning frame.
|
| 568 |
+
out = unselected_cond_outputs.get(prev_frame_idx, None)
|
| 569 |
+
t_pos_and_prevs.append((t_pos, out))
|
| 570 |
|
| 571 |
for t_pos, prev in t_pos_and_prevs:
|
| 572 |
if prev is None:
|
|
|
|
| 579 |
maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
|
| 580 |
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
|
| 581 |
# Temporal positional encoding
|
| 582 |
+
maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
|
|
|
|
|
|
|
| 583 |
to_cat_memory_pos_embed.append(maskmem_enc)
|
| 584 |
|
| 585 |
# Construct the list of past object pointers
|
|
|
|
| 589 |
# (optionally, only include object pointers in the past during evaluation)
|
| 590 |
if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
|
| 591 |
ptr_cond_outputs = {
|
| 592 |
+
t: out for t, out in selected_cond_outputs.items() if (t >= frame_idx if track_in_reverse else t <= frame_idx)
|
|
|
|
|
|
|
| 593 |
}
|
| 594 |
else:
|
| 595 |
ptr_cond_outputs = selected_cond_outputs
|
| 596 |
pos_and_ptrs = [
|
| 597 |
# Temporal pos encoding contains how far away each pointer is from current frame
|
| 598 |
(
|
| 599 |
+
((frame_idx - t) * tpos_sign_mul if self.use_signed_tpos_enc_to_obj_ptrs else abs(frame_idx - t)),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 600 |
out["obj_ptr"],
|
| 601 |
)
|
| 602 |
for t, out in ptr_cond_outputs.items()
|
|
|
|
| 606 |
t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
|
| 607 |
if t < 0 or (num_frames is not None and t >= num_frames):
|
| 608 |
break
|
| 609 |
+
out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None))
|
|
|
|
|
|
|
| 610 |
if out is not None:
|
| 611 |
pos_and_ptrs.append((t_diff, out["obj_ptr"]))
|
| 612 |
# If we have at least one object pointer, add them to the across attention
|
|
|
|
| 619 |
if self.add_tpos_enc_to_obj_ptrs:
|
| 620 |
t_diff_max = max_obj_ptrs_in_encoder - 1
|
| 621 |
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
|
| 622 |
+
obj_pos = torch.tensor(pos_list).to(device=device, non_blocking=True)
|
|
|
|
|
|
|
| 623 |
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
|
| 624 |
obj_pos = self.obj_ptr_tpos_proj(obj_pos)
|
| 625 |
obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
|
|
|
|
| 627 |
obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
|
| 628 |
if self.mem_dim < C:
|
| 629 |
# split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
|
| 630 |
+
obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
|
|
|
|
|
|
|
| 631 |
obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
|
| 632 |
obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
|
| 633 |
to_cat_memory.append(obj_ptrs)
|
|
|
|
| 680 |
# optionally, apply non-overlapping constraints to the masks (it's applied
|
| 681 |
# in the batch dimension and should only be used during eval, where all
|
| 682 |
# the objects come from the same video under batch size 1).
|
| 683 |
+
pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)
|
|
|
|
|
|
|
| 684 |
# scale the raw mask logits with a temperature before applying sigmoid
|
| 685 |
binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
|
| 686 |
if binarize and not self.training:
|
|
|
|
| 693 |
mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
|
| 694 |
if self.sigmoid_bias_for_mem_enc != 0.0:
|
| 695 |
mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
|
| 696 |
+
maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True) # sigmoid already applied
|
|
|
|
|
|
|
| 697 |
maskmem_features = maskmem_out["vision_features"]
|
| 698 |
maskmem_pos_enc = maskmem_out["vision_pos_enc"]
|
| 699 |
# add a no-object embedding to the spatial memory to indicate that the frame
|
| 700 |
# is predicted to be occluded (i.e. no object is appearing in the frame)
|
| 701 |
if self.no_obj_embed_spatial is not None:
|
| 702 |
is_obj_appearing = (object_score_logits > 0).float()
|
| 703 |
+
maskmem_features += (1 - is_obj_appearing[..., None, None]) * self.no_obj_embed_spatial[..., None, None].expand(
|
|
|
|
|
|
|
| 704 |
*maskmem_features.shape
|
| 705 |
)
|
| 706 |
|
|
|
|
| 724 |
# High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
|
| 725 |
if len(current_vision_feats) > 1:
|
| 726 |
high_res_features = [
|
| 727 |
+
x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
|
|
|
|
| 728 |
]
|
| 729 |
else:
|
| 730 |
high_res_features = None
|
|
|
|
| 733 |
# (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
|
| 734 |
pix_feat = current_vision_feats[-1].permute(1, 2, 0)
|
| 735 |
pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
|
| 736 |
+
sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
|
|
|
|
|
|
|
| 737 |
else:
|
| 738 |
# fused the visual feature with previous memory features in the memory bank
|
| 739 |
pix_feat = self._prepare_memory_conditioned_features(
|
|
|
|
| 832 |
high_res_masks,
|
| 833 |
obj_ptr,
|
| 834 |
object_score_logits,
|
|
|
|
|
|
|
| 835 |
) = sam_outputs
|
| 836 |
|
| 837 |
current_out["pred_masks"] = low_res_masks
|
| 838 |
current_out["pred_masks_high_res"] = high_res_masks
|
| 839 |
current_out["obj_ptr"] = obj_ptr
|
|
|
|
|
|
|
| 840 |
if not self.training:
|
| 841 |
# Only add this in inference (to avoid unused param in activation checkpointing;
|
| 842 |
# it's mainly used in the demo to encode spatial memories w/ consolidated masks)
|
{sam2 → bboxmaskpose/sam2}/modeling/sam2_base_pose.py
RENAMED
|
@@ -4,20 +4,17 @@
|
|
| 4 |
# This source code is licensed under the license found in the
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
-
from loguru import logger
|
| 8 |
-
|
| 9 |
import torch
|
| 10 |
import torch.distributed
|
| 11 |
import torch.nn.functional as F
|
| 12 |
-
|
| 13 |
from torch.nn.init import trunc_normal_
|
| 14 |
|
| 15 |
-
from sam2.modeling.
|
| 16 |
-
from sam2.modeling.sam.
|
| 17 |
-
from sam2.modeling.sam.
|
| 18 |
-
from sam2.modeling.
|
| 19 |
-
|
| 20 |
-
from
|
| 21 |
|
| 22 |
# a large negative value as a placeholder score for missing objects
|
| 23 |
NO_OBJ_SCORE = -1024.0
|
|
@@ -137,16 +134,12 @@ class SAM2Base(torch.nn.Module):
|
|
| 137 |
# Part 3: memory encoder for the previous frame's outputs
|
| 138 |
self.memory_encoder = memory_encoder
|
| 139 |
self.mem_dim = self.hidden_dim
|
| 140 |
-
if hasattr(self.memory_encoder, "out_proj") and hasattr(
|
| 141 |
-
self.memory_encoder.out_proj, "weight"
|
| 142 |
-
):
|
| 143 |
# if there is compression of memories along channel dim
|
| 144 |
self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
|
| 145 |
self.num_maskmem = num_maskmem # Number of memories accessible
|
| 146 |
# Temporal encoding of the memories
|
| 147 |
-
self.maskmem_tpos_enc = torch.nn.Parameter(
|
| 148 |
-
torch.zeros(num_maskmem, 1, 1, self.mem_dim)
|
| 149 |
-
)
|
| 150 |
trunc_normal_(self.maskmem_tpos_enc, std=0.02)
|
| 151 |
# a single token to indicate no memory embedding from previous frames
|
| 152 |
self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
|
|
@@ -205,8 +198,8 @@ class SAM2Base(torch.nn.Module):
|
|
| 205 |
self.stable_frames = 0
|
| 206 |
|
| 207 |
# Debug purpose
|
| 208 |
-
self.history = {}
|
| 209 |
-
self.frame_cnt = 0
|
| 210 |
|
| 211 |
# Hyperparameters for SAMURAI
|
| 212 |
self.stable_frames_threshold = stable_frames_threshold
|
|
@@ -222,9 +215,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 222 |
# Model compilation
|
| 223 |
if compile_image_encoder:
|
| 224 |
# Compile the forward function (not the full module) to allow loading checkpoints.
|
| 225 |
-
print(
|
| 226 |
-
"Image encoder compilation is enabled. First forward pass will be slow."
|
| 227 |
-
)
|
| 228 |
self.image_encoder.forward = torch.compile(
|
| 229 |
self.image_encoder.forward,
|
| 230 |
mode="max-autotune",
|
|
@@ -280,9 +271,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 280 |
# a linear projection on SAM output tokens to turn them into object pointers
|
| 281 |
self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
|
| 282 |
if self.use_mlp_for_obj_ptr_proj:
|
| 283 |
-
self.obj_ptr_proj = MLP(
|
| 284 |
-
self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
|
| 285 |
-
)
|
| 286 |
else:
|
| 287 |
self.obj_ptr_proj = torch.nn.Identity()
|
| 288 |
if self.proj_tpos_enc_in_obj_ptrs:
|
|
@@ -480,7 +469,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 480 |
sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
|
| 481 |
|
| 482 |
if False:
|
| 483 |
-
# make all these on cpu
|
| 484 |
self.history[self.frame_cnt] = {
|
| 485 |
"kf_predicted_bbox": self.kf.xyah_to_xyxy(self.kf_mean[:4]),
|
| 486 |
# "multi_masks": high_res_multimasks.cpu(),
|
|
@@ -495,7 +484,9 @@ class SAM2Base(torch.nn.Module):
|
|
| 495 |
if ious[0][best_iou_inds] < self.stable_ious_threshold:
|
| 496 |
self.stable_frames = 0
|
| 497 |
else:
|
| 498 |
-
self.kf_mean, self.kf_covariance = self.kf.update(
|
|
|
|
|
|
|
| 499 |
elif multimask_output and not self.samurai_mode:
|
| 500 |
# take the best mask prediction (with the highest IoU estimation)
|
| 501 |
best_iou_inds = torch.argmax(ious, dim=-1)
|
|
@@ -553,9 +544,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 553 |
ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
|
| 554 |
if not self.use_obj_ptrs_in_encoder:
|
| 555 |
# all zeros as a dummy object pointer (of shape [B, C])
|
| 556 |
-
obj_ptr = torch.zeros(
|
| 557 |
-
mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device
|
| 558 |
-
)
|
| 559 |
else:
|
| 560 |
# produce an object pointer using the SAM decoder from the mask input
|
| 561 |
_, _, _, _, _, obj_ptr, _, _, _ = self._forward_sam_heads(
|
|
@@ -591,12 +580,8 @@ class SAM2Base(torch.nn.Module):
|
|
| 591 |
if self.use_high_res_features_in_sam:
|
| 592 |
# precompute projected level 0 and level 1 features in SAM decoder
|
| 593 |
# to avoid running it again on every SAM click
|
| 594 |
-
backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
|
| 595 |
-
|
| 596 |
-
)
|
| 597 |
-
backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
|
| 598 |
-
backbone_out["backbone_fpn"][1]
|
| 599 |
-
)
|
| 600 |
return backbone_out
|
| 601 |
|
| 602 |
def _prepare_backbone_features(self, backbone_out):
|
|
@@ -659,21 +644,27 @@ class SAM2Base(torch.nn.Module):
|
|
| 659 |
stride = 1 if self.training else self.memory_temporal_stride_for_eval
|
| 660 |
|
| 661 |
if self.samurai_mode:
|
| 662 |
-
valid_indices = []
|
| 663 |
if frame_idx > 1: # Ensure we have previous frames to evaluate
|
| 664 |
for i in range(frame_idx - 1, 1, -1): # Iterate backwards through previous frames
|
| 665 |
iou_score = output_dict["non_cond_frame_outputs"][i]["best_iou_score"] # Get mask affinity score
|
| 666 |
obj_score = output_dict["non_cond_frame_outputs"][i]["object_score_logits"] # Get object score
|
| 667 |
-
kf_score =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 668 |
# Check if the scores meet the criteria for being a valid index
|
| 669 |
-
if
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
|
|
|
|
|
|
| 673 |
# Check the number of valid indices
|
| 674 |
-
if len(valid_indices) >= self.max_obj_ptrs_in_encoder - 1:
|
| 675 |
break
|
| 676 |
-
if frame_idx - 1 not in valid_indices:
|
| 677 |
valid_indices.append(frame_idx - 1)
|
| 678 |
for t_pos in range(1, self.num_maskmem): # Iterate over the number of mask memories
|
| 679 |
idx = t_pos - self.num_maskmem # Calculate the index for valid indices
|
|
@@ -726,9 +717,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 726 |
maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
|
| 727 |
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
|
| 728 |
# Temporal positional encoding
|
| 729 |
-
maskmem_enc =
|
| 730 |
-
maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
|
| 731 |
-
)
|
| 732 |
to_cat_memory_pos_embed.append(maskmem_enc)
|
| 733 |
|
| 734 |
# Construct the list of past object pointers
|
|
@@ -738,20 +727,14 @@ class SAM2Base(torch.nn.Module):
|
|
| 738 |
# (optionally, only include object pointers in the past during evaluation)
|
| 739 |
if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
|
| 740 |
ptr_cond_outputs = {
|
| 741 |
-
t: out
|
| 742 |
-
for t, out in selected_cond_outputs.items()
|
| 743 |
-
if (t >= frame_idx if track_in_reverse else t <= frame_idx)
|
| 744 |
}
|
| 745 |
else:
|
| 746 |
ptr_cond_outputs = selected_cond_outputs
|
| 747 |
pos_and_ptrs = [
|
| 748 |
# Temporal pos encoding contains how far away each pointer is from current frame
|
| 749 |
(
|
| 750 |
-
(
|
| 751 |
-
(frame_idx - t) * tpos_sign_mul
|
| 752 |
-
if self.use_signed_tpos_enc_to_obj_ptrs
|
| 753 |
-
else abs(frame_idx - t)
|
| 754 |
-
),
|
| 755 |
out["obj_ptr"],
|
| 756 |
)
|
| 757 |
for t, out in ptr_cond_outputs.items()
|
|
@@ -761,9 +744,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 761 |
t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
|
| 762 |
if t < 0 or (num_frames is not None and t >= num_frames):
|
| 763 |
break
|
| 764 |
-
out = output_dict["non_cond_frame_outputs"].get(
|
| 765 |
-
t, unselected_cond_outputs.get(t, None)
|
| 766 |
-
)
|
| 767 |
if out is not None:
|
| 768 |
pos_and_ptrs.append((t_diff, out["obj_ptr"]))
|
| 769 |
# If we have at least one object pointer, add them to the across attention
|
|
@@ -776,9 +757,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 776 |
if self.add_tpos_enc_to_obj_ptrs:
|
| 777 |
t_diff_max = max_obj_ptrs_in_encoder - 1
|
| 778 |
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
|
| 779 |
-
obj_pos = torch.tensor(pos_list).to(
|
| 780 |
-
device=device, non_blocking=True
|
| 781 |
-
)
|
| 782 |
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
|
| 783 |
obj_pos = self.obj_ptr_tpos_proj(obj_pos)
|
| 784 |
obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
|
|
@@ -786,9 +765,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 786 |
obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
|
| 787 |
if self.mem_dim < C:
|
| 788 |
# split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
|
| 789 |
-
obj_ptrs = obj_ptrs.reshape(
|
| 790 |
-
-1, B, C // self.mem_dim, self.mem_dim
|
| 791 |
-
)
|
| 792 |
obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
|
| 793 |
obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
|
| 794 |
to_cat_memory.append(obj_ptrs)
|
|
@@ -841,9 +818,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 841 |
# optionally, apply non-overlapping constraints to the masks (it's applied
|
| 842 |
# in the batch dimension and should only be used during eval, where all
|
| 843 |
# the objects come from the same video under batch size 1).
|
| 844 |
-
pred_masks_high_res = self._apply_non_overlapping_constraints(
|
| 845 |
-
pred_masks_high_res
|
| 846 |
-
)
|
| 847 |
# scale the raw mask logits with a temperature before applying sigmoid
|
| 848 |
binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
|
| 849 |
if binarize and not self.training:
|
|
@@ -856,18 +831,14 @@ class SAM2Base(torch.nn.Module):
|
|
| 856 |
mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
|
| 857 |
if self.sigmoid_bias_for_mem_enc != 0.0:
|
| 858 |
mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
|
| 859 |
-
maskmem_out = self.memory_encoder(
|
| 860 |
-
pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied
|
| 861 |
-
)
|
| 862 |
maskmem_features = maskmem_out["vision_features"]
|
| 863 |
maskmem_pos_enc = maskmem_out["vision_pos_enc"]
|
| 864 |
# add a no-object embedding to the spatial memory to indicate that the frame
|
| 865 |
# is predicted to be occluded (i.e. no object is appearing in the frame)
|
| 866 |
if self.no_obj_embed_spatial is not None:
|
| 867 |
is_obj_appearing = (object_score_logits > 0).float()
|
| 868 |
-
maskmem_features += (
|
| 869 |
-
1 - is_obj_appearing[..., None, None]
|
| 870 |
-
) * self.no_obj_embed_spatial[..., None, None].expand(
|
| 871 |
*maskmem_features.shape
|
| 872 |
)
|
| 873 |
|
|
@@ -891,8 +862,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 891 |
# High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
|
| 892 |
if len(current_vision_feats) > 1:
|
| 893 |
high_res_features = [
|
| 894 |
-
x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
|
| 895 |
-
for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
|
| 896 |
]
|
| 897 |
else:
|
| 898 |
high_res_features = None
|
|
@@ -901,9 +871,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 901 |
# (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
|
| 902 |
pix_feat = current_vision_feats[-1].permute(1, 2, 0)
|
| 903 |
pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
|
| 904 |
-
sam_outputs = self._use_mask_as_output(
|
| 905 |
-
pix_feat, high_res_features, mask_inputs
|
| 906 |
-
)
|
| 907 |
else:
|
| 908 |
# fused the visual feature with previous memory features in the memory bank
|
| 909 |
pix_feat = self._prepare_memory_conditioned_features(
|
|
@@ -994,17 +962,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 994 |
prev_sam_mask_logits,
|
| 995 |
)
|
| 996 |
|
| 997 |
-
|
| 998 |
-
_,
|
| 999 |
-
_,
|
| 1000 |
-
_,
|
| 1001 |
-
low_res_masks,
|
| 1002 |
-
high_res_masks,
|
| 1003 |
-
obj_ptr,
|
| 1004 |
-
object_score_logits,
|
| 1005 |
-
best_iou_score,
|
| 1006 |
-
kf_ious
|
| 1007 |
-
) = sam_outputs
|
| 1008 |
|
| 1009 |
current_out["pred_masks"] = low_res_masks
|
| 1010 |
current_out["pred_masks_high_res"] = high_res_masks
|
|
|
|
| 4 |
# This source code is licensed under the license found in the
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
|
|
|
|
|
|
| 7 |
import torch
|
| 8 |
import torch.distributed
|
| 9 |
import torch.nn.functional as F
|
|
|
|
| 10 |
from torch.nn.init import trunc_normal_
|
| 11 |
|
| 12 |
+
from bboxmaskpose.sam2.modeling.sam2_utils import MLP, get_1d_sine_pe, select_closest_cond_frames
|
| 13 |
+
from bboxmaskpose.sam2.modeling.sam.mask_decoder import MaskDecoder
|
| 14 |
+
from bboxmaskpose.sam2.modeling.sam.pose_encoder import PoseEncoder
|
| 15 |
+
from bboxmaskpose.sam2.modeling.sam.transformer import TwoWayTransformer
|
| 16 |
+
from bboxmaskpose.sam2.utils.kalman_filter import KalmanFilter
|
| 17 |
+
from loguru import logger
|
| 18 |
|
| 19 |
# a large negative value as a placeholder score for missing objects
|
| 20 |
NO_OBJ_SCORE = -1024.0
|
|
|
|
| 134 |
# Part 3: memory encoder for the previous frame's outputs
|
| 135 |
self.memory_encoder = memory_encoder
|
| 136 |
self.mem_dim = self.hidden_dim
|
| 137 |
+
if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
|
|
|
|
|
|
|
| 138 |
# if there is compression of memories along channel dim
|
| 139 |
self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
|
| 140 |
self.num_maskmem = num_maskmem # Number of memories accessible
|
| 141 |
# Temporal encoding of the memories
|
| 142 |
+
self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
|
|
|
|
|
|
|
| 143 |
trunc_normal_(self.maskmem_tpos_enc, std=0.02)
|
| 144 |
# a single token to indicate no memory embedding from previous frames
|
| 145 |
self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
|
|
|
|
| 198 |
self.stable_frames = 0
|
| 199 |
|
| 200 |
# Debug purpose
|
| 201 |
+
self.history = {} # debug
|
| 202 |
+
self.frame_cnt = 0 # debug
|
| 203 |
|
| 204 |
# Hyperparameters for SAMURAI
|
| 205 |
self.stable_frames_threshold = stable_frames_threshold
|
|
|
|
| 215 |
# Model compilation
|
| 216 |
if compile_image_encoder:
|
| 217 |
# Compile the forward function (not the full module) to allow loading checkpoints.
|
| 218 |
+
print("Image encoder compilation is enabled. First forward pass will be slow.")
|
|
|
|
|
|
|
| 219 |
self.image_encoder.forward = torch.compile(
|
| 220 |
self.image_encoder.forward,
|
| 221 |
mode="max-autotune",
|
|
|
|
| 271 |
# a linear projection on SAM output tokens to turn them into object pointers
|
| 272 |
self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
|
| 273 |
if self.use_mlp_for_obj_ptr_proj:
|
| 274 |
+
self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
|
|
|
|
|
|
|
| 275 |
else:
|
| 276 |
self.obj_ptr_proj = torch.nn.Identity()
|
| 277 |
if self.proj_tpos_enc_in_obj_ptrs:
|
|
|
|
| 469 |
sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
|
| 470 |
|
| 471 |
if False:
|
| 472 |
+
# make all these on cpu
|
| 473 |
self.history[self.frame_cnt] = {
|
| 474 |
"kf_predicted_bbox": self.kf.xyah_to_xyxy(self.kf_mean[:4]),
|
| 475 |
# "multi_masks": high_res_multimasks.cpu(),
|
|
|
|
| 484 |
if ious[0][best_iou_inds] < self.stable_ious_threshold:
|
| 485 |
self.stable_frames = 0
|
| 486 |
else:
|
| 487 |
+
self.kf_mean, self.kf_covariance = self.kf.update(
|
| 488 |
+
self.kf_mean, self.kf_covariance, self.kf.xyxy_to_xyah(high_res_multibboxes[best_iou_inds])
|
| 489 |
+
)
|
| 490 |
elif multimask_output and not self.samurai_mode:
|
| 491 |
# take the best mask prediction (with the highest IoU estimation)
|
| 492 |
best_iou_inds = torch.argmax(ious, dim=-1)
|
|
|
|
| 544 |
ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
|
| 545 |
if not self.use_obj_ptrs_in_encoder:
|
| 546 |
# all zeros as a dummy object pointer (of shape [B, C])
|
| 547 |
+
obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
|
|
|
|
|
|
|
| 548 |
else:
|
| 549 |
# produce an object pointer using the SAM decoder from the mask input
|
| 550 |
_, _, _, _, _, obj_ptr, _, _, _ = self._forward_sam_heads(
|
|
|
|
| 580 |
if self.use_high_res_features_in_sam:
|
| 581 |
# precompute projected level 0 and level 1 features in SAM decoder
|
| 582 |
# to avoid running it again on every SAM click
|
| 583 |
+
backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
|
| 584 |
+
backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 585 |
return backbone_out
|
| 586 |
|
| 587 |
def _prepare_backbone_features(self, backbone_out):
|
|
|
|
| 644 |
stride = 1 if self.training else self.memory_temporal_stride_for_eval
|
| 645 |
|
| 646 |
if self.samurai_mode:
|
| 647 |
+
valid_indices = []
|
| 648 |
if frame_idx > 1: # Ensure we have previous frames to evaluate
|
| 649 |
for i in range(frame_idx - 1, 1, -1): # Iterate backwards through previous frames
|
| 650 |
iou_score = output_dict["non_cond_frame_outputs"][i]["best_iou_score"] # Get mask affinity score
|
| 651 |
obj_score = output_dict["non_cond_frame_outputs"][i]["object_score_logits"] # Get object score
|
| 652 |
+
kf_score = (
|
| 653 |
+
output_dict["non_cond_frame_outputs"][i]["kf_score"]
|
| 654 |
+
if "kf_score" in output_dict["non_cond_frame_outputs"][i]
|
| 655 |
+
else None
|
| 656 |
+
) # Get motion score if available
|
| 657 |
# Check if the scores meet the criteria for being a valid index
|
| 658 |
+
if (
|
| 659 |
+
iou_score.item() > self.memory_bank_iou_threshold
|
| 660 |
+
and obj_score.item() > self.memory_bank_obj_score_threshold
|
| 661 |
+
and (kf_score is None or kf_score.item() > self.memory_bank_kf_score_threshold)
|
| 662 |
+
):
|
| 663 |
+
valid_indices.insert(0, i)
|
| 664 |
# Check the number of valid indices
|
| 665 |
+
if len(valid_indices) >= self.max_obj_ptrs_in_encoder - 1:
|
| 666 |
break
|
| 667 |
+
if frame_idx - 1 not in valid_indices:
|
| 668 |
valid_indices.append(frame_idx - 1)
|
| 669 |
for t_pos in range(1, self.num_maskmem): # Iterate over the number of mask memories
|
| 670 |
idx = t_pos - self.num_maskmem # Calculate the index for valid indices
|
|
|
|
| 717 |
maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
|
| 718 |
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
|
| 719 |
# Temporal positional encoding
|
| 720 |
+
maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
|
|
|
|
|
|
|
| 721 |
to_cat_memory_pos_embed.append(maskmem_enc)
|
| 722 |
|
| 723 |
# Construct the list of past object pointers
|
|
|
|
| 727 |
# (optionally, only include object pointers in the past during evaluation)
|
| 728 |
if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
|
| 729 |
ptr_cond_outputs = {
|
| 730 |
+
t: out for t, out in selected_cond_outputs.items() if (t >= frame_idx if track_in_reverse else t <= frame_idx)
|
|
|
|
|
|
|
| 731 |
}
|
| 732 |
else:
|
| 733 |
ptr_cond_outputs = selected_cond_outputs
|
| 734 |
pos_and_ptrs = [
|
| 735 |
# Temporal pos encoding contains how far away each pointer is from current frame
|
| 736 |
(
|
| 737 |
+
((frame_idx - t) * tpos_sign_mul if self.use_signed_tpos_enc_to_obj_ptrs else abs(frame_idx - t)),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 738 |
out["obj_ptr"],
|
| 739 |
)
|
| 740 |
for t, out in ptr_cond_outputs.items()
|
|
|
|
| 744 |
t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
|
| 745 |
if t < 0 or (num_frames is not None and t >= num_frames):
|
| 746 |
break
|
| 747 |
+
out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None))
|
|
|
|
|
|
|
| 748 |
if out is not None:
|
| 749 |
pos_and_ptrs.append((t_diff, out["obj_ptr"]))
|
| 750 |
# If we have at least one object pointer, add them to the across attention
|
|
|
|
| 757 |
if self.add_tpos_enc_to_obj_ptrs:
|
| 758 |
t_diff_max = max_obj_ptrs_in_encoder - 1
|
| 759 |
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
|
| 760 |
+
obj_pos = torch.tensor(pos_list).to(device=device, non_blocking=True)
|
|
|
|
|
|
|
| 761 |
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
|
| 762 |
obj_pos = self.obj_ptr_tpos_proj(obj_pos)
|
| 763 |
obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
|
|
|
|
| 765 |
obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
|
| 766 |
if self.mem_dim < C:
|
| 767 |
# split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
|
| 768 |
+
obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
|
|
|
|
|
|
|
| 769 |
obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
|
| 770 |
obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
|
| 771 |
to_cat_memory.append(obj_ptrs)
|
|
|
|
| 818 |
# optionally, apply non-overlapping constraints to the masks (it's applied
|
| 819 |
# in the batch dimension and should only be used during eval, where all
|
| 820 |
# the objects come from the same video under batch size 1).
|
| 821 |
+
pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)
|
|
|
|
|
|
|
| 822 |
# scale the raw mask logits with a temperature before applying sigmoid
|
| 823 |
binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
|
| 824 |
if binarize and not self.training:
|
|
|
|
| 831 |
mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
|
| 832 |
if self.sigmoid_bias_for_mem_enc != 0.0:
|
| 833 |
mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
|
| 834 |
+
maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True) # sigmoid already applied
|
|
|
|
|
|
|
| 835 |
maskmem_features = maskmem_out["vision_features"]
|
| 836 |
maskmem_pos_enc = maskmem_out["vision_pos_enc"]
|
| 837 |
# add a no-object embedding to the spatial memory to indicate that the frame
|
| 838 |
# is predicted to be occluded (i.e. no object is appearing in the frame)
|
| 839 |
if self.no_obj_embed_spatial is not None:
|
| 840 |
is_obj_appearing = (object_score_logits > 0).float()
|
| 841 |
+
maskmem_features += (1 - is_obj_appearing[..., None, None]) * self.no_obj_embed_spatial[..., None, None].expand(
|
|
|
|
|
|
|
| 842 |
*maskmem_features.shape
|
| 843 |
)
|
| 844 |
|
|
|
|
| 862 |
# High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
|
| 863 |
if len(current_vision_feats) > 1:
|
| 864 |
high_res_features = [
|
| 865 |
+
x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
|
|
|
|
| 866 |
]
|
| 867 |
else:
|
| 868 |
high_res_features = None
|
|
|
|
| 871 |
# (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
|
| 872 |
pix_feat = current_vision_feats[-1].permute(1, 2, 0)
|
| 873 |
pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
|
| 874 |
+
sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
|
|
|
|
|
|
|
| 875 |
else:
|
| 876 |
# fused the visual feature with previous memory features in the memory bank
|
| 877 |
pix_feat = self._prepare_memory_conditioned_features(
|
|
|
|
| 962 |
prev_sam_mask_logits,
|
| 963 |
)
|
| 964 |
|
| 965 |
+
_, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits, best_iou_score, kf_ious = sam_outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 966 |
|
| 967 |
current_out["pred_masks"] = low_res_masks
|
| 968 |
current_out["pred_masks_high_res"] = high_res_masks
|
{sam2 → bboxmaskpose/sam2}/modeling/sam2_utils.py
RENAMED
|
@@ -13,7 +13,7 @@ import torch
|
|
| 13 |
import torch.nn as nn
|
| 14 |
import torch.nn.functional as F
|
| 15 |
|
| 16 |
-
from sam2.utils.misc import mask_to_box
|
| 17 |
|
| 18 |
|
| 19 |
def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
|
|
@@ -54,9 +54,7 @@ def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num
|
|
| 54 |
key=lambda x: abs(x - frame_idx),
|
| 55 |
)[:num_remain]
|
| 56 |
selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
|
| 57 |
-
unselected_outputs = {
|
| 58 |
-
t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
|
| 59 |
-
}
|
| 60 |
|
| 61 |
return selected_outputs, unselected_outputs
|
| 62 |
|
|
@@ -122,9 +120,7 @@ class MLP(nn.Module):
|
|
| 122 |
super().__init__()
|
| 123 |
self.num_layers = num_layers
|
| 124 |
h = [hidden_dim] * (num_layers - 1)
|
| 125 |
-
self.layers = nn.ModuleList(
|
| 126 |
-
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
|
| 127 |
-
)
|
| 128 |
self.sigmoid_output = sigmoid_output
|
| 129 |
self.act = activation()
|
| 130 |
|
|
@@ -175,9 +171,7 @@ def sample_box_points(
|
|
| 175 |
device = masks.device
|
| 176 |
box_coords = mask_to_box(masks)
|
| 177 |
B, _, H, W = masks.shape
|
| 178 |
-
box_labels = torch.tensor(
|
| 179 |
-
[top_left_label, bottom_right_label], dtype=torch.int, device=device
|
| 180 |
-
).repeat(B)
|
| 181 |
if noise > 0.0:
|
| 182 |
if not isinstance(noise_bound, torch.Tensor):
|
| 183 |
noise_bound = torch.tensor(noise_bound, device=device)
|
|
@@ -189,9 +183,7 @@ def sample_box_points(
|
|
| 189 |
box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1)
|
| 190 |
|
| 191 |
box_coords = box_coords + box_noise
|
| 192 |
-
img_bounds = (
|
| 193 |
-
torch.tensor([W, H, W, H], device=device) - 1
|
| 194 |
-
) # uncentered pixel coords
|
| 195 |
box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping
|
| 196 |
|
| 197 |
box_coords = box_coords.reshape(-1, 2, 2) # always 2 points
|
|
|
|
| 13 |
import torch.nn as nn
|
| 14 |
import torch.nn.functional as F
|
| 15 |
|
| 16 |
+
from bboxmaskpose.sam2.utils.misc import mask_to_box
|
| 17 |
|
| 18 |
|
| 19 |
def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
|
|
|
|
| 54 |
key=lambda x: abs(x - frame_idx),
|
| 55 |
)[:num_remain]
|
| 56 |
selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
|
| 57 |
+
unselected_outputs = {t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs}
|
|
|
|
|
|
|
| 58 |
|
| 59 |
return selected_outputs, unselected_outputs
|
| 60 |
|
|
|
|
| 120 |
super().__init__()
|
| 121 |
self.num_layers = num_layers
|
| 122 |
h = [hidden_dim] * (num_layers - 1)
|
| 123 |
+
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
|
|
|
|
|
|
| 124 |
self.sigmoid_output = sigmoid_output
|
| 125 |
self.act = activation()
|
| 126 |
|
|
|
|
| 171 |
device = masks.device
|
| 172 |
box_coords = mask_to_box(masks)
|
| 173 |
B, _, H, W = masks.shape
|
| 174 |
+
box_labels = torch.tensor([top_left_label, bottom_right_label], dtype=torch.int, device=device).repeat(B)
|
|
|
|
|
|
|
| 175 |
if noise > 0.0:
|
| 176 |
if not isinstance(noise_bound, torch.Tensor):
|
| 177 |
noise_bound = torch.tensor(noise_bound, device=device)
|
|
|
|
| 183 |
box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1)
|
| 184 |
|
| 185 |
box_coords = box_coords + box_noise
|
| 186 |
+
img_bounds = torch.tensor([W, H, W, H], device=device) - 1 # uncentered pixel coords
|
|
|
|
|
|
|
| 187 |
box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping
|
| 188 |
|
| 189 |
box_coords = box_coords.reshape(-1, 2, 2) # always 2 points
|
{sam2 → bboxmaskpose/sam2}/sam2_image_predictor.py
RENAMED
|
@@ -5,16 +5,14 @@
|
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
import logging
|
| 8 |
-
|
| 9 |
from typing import List, Optional, Tuple, Union
|
| 10 |
|
| 11 |
import numpy as np
|
| 12 |
import torch
|
| 13 |
from PIL.Image import Image
|
| 14 |
|
| 15 |
-
from sam2.modeling.sam2_base import SAM2Base
|
| 16 |
-
|
| 17 |
-
from sam2.utils.transforms import SAM2Transforms
|
| 18 |
|
| 19 |
|
| 20 |
class SAM2ImagePredictor:
|
|
@@ -61,9 +59,9 @@ class SAM2ImagePredictor:
|
|
| 61 |
# Spatial dim for backbone feature maps
|
| 62 |
isize = self.model.image_size
|
| 63 |
self._bb_feat_sizes = [
|
| 64 |
-
(isize//4, isize//4),
|
| 65 |
-
(isize//8, isize//8),
|
| 66 |
-
(isize//16, isize//16),
|
| 67 |
]
|
| 68 |
|
| 69 |
@classmethod
|
|
@@ -78,7 +76,7 @@ class SAM2ImagePredictor:
|
|
| 78 |
Returns:
|
| 79 |
(SAM2ImagePredictor): The loaded model.
|
| 80 |
"""
|
| 81 |
-
from sam2.build_sam import build_sam2_hf
|
| 82 |
|
| 83 |
sam_model = build_sam2_hf(model_id, **kwargs)
|
| 84 |
return cls(sam_model, **kwargs)
|
|
@@ -111,9 +109,7 @@ class SAM2ImagePredictor:
|
|
| 111 |
input_image = self._transforms(image)
|
| 112 |
input_image = input_image[None, ...].to(self.device)
|
| 113 |
|
| 114 |
-
assert (
|
| 115 |
-
len(input_image.shape) == 4 and input_image.shape[1] == 3
|
| 116 |
-
), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
|
| 117 |
logging.info("Computing image embeddings for the provided image...")
|
| 118 |
backbone_out = self.model.forward_image(input_image)
|
| 119 |
_, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
|
|
@@ -122,10 +118,9 @@ class SAM2ImagePredictor:
|
|
| 122 |
vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
|
| 123 |
|
| 124 |
# breakpoint()
|
| 125 |
-
feats = [
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
][::-1]
|
| 129 |
self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
|
| 130 |
self._is_image_set = True
|
| 131 |
logging.info("Image embeddings computed.")
|
|
@@ -148,17 +143,13 @@ class SAM2ImagePredictor:
|
|
| 148 |
assert isinstance(image_list, list)
|
| 149 |
self._orig_hw = []
|
| 150 |
for image in image_list:
|
| 151 |
-
assert isinstance(
|
| 152 |
-
image, np.ndarray
|
| 153 |
-
), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
|
| 154 |
self._orig_hw.append(image.shape[:2])
|
| 155 |
# Transform the image to the form expected by the model
|
| 156 |
img_batch = self._transforms.forward_batch(image_list)
|
| 157 |
img_batch = img_batch.to(self.device)
|
| 158 |
batch_size = img_batch.shape[0]
|
| 159 |
-
assert (
|
| 160 |
-
len(img_batch.shape) == 4 and img_batch.shape[1] == 3
|
| 161 |
-
), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
|
| 162 |
logging.info("Computing image embeddings for the provided images...")
|
| 163 |
backbone_out = self.model.forward_image(img_batch)
|
| 164 |
_, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
|
|
@@ -167,8 +158,7 @@ class SAM2ImagePredictor:
|
|
| 167 |
vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
|
| 168 |
|
| 169 |
feats = [
|
| 170 |
-
feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
|
| 171 |
-
for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
|
| 172 |
][::-1]
|
| 173 |
self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
|
| 174 |
self._is_image_set = True
|
|
@@ -190,25 +180,17 @@ class SAM2ImagePredictor:
|
|
| 190 |
"""
|
| 191 |
assert self._is_batch, "This function should only be used when in batched mode"
|
| 192 |
if not self._is_image_set:
|
| 193 |
-
raise RuntimeError(
|
| 194 |
-
"An image must be set with .set_image_batch(...) before mask prediction."
|
| 195 |
-
)
|
| 196 |
num_images = len(self._features["image_embed"])
|
| 197 |
all_masks = []
|
| 198 |
all_ious = []
|
| 199 |
all_low_res_masks = []
|
| 200 |
for img_idx in range(num_images):
|
| 201 |
# Transform input prompts
|
| 202 |
-
point_coords =
|
| 203 |
-
|
| 204 |
-
)
|
| 205 |
-
point_labels = (
|
| 206 |
-
point_labels_batch[img_idx] if point_labels_batch is not None else None
|
| 207 |
-
)
|
| 208 |
box = box_batch[img_idx] if box_batch is not None else None
|
| 209 |
-
mask_input =
|
| 210 |
-
mask_input_batch[img_idx] if mask_input_batch is not None else None
|
| 211 |
-
)
|
| 212 |
mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
|
| 213 |
point_coords,
|
| 214 |
point_labels,
|
|
@@ -227,9 +209,7 @@ class SAM2ImagePredictor:
|
|
| 227 |
img_idx=img_idx,
|
| 228 |
)
|
| 229 |
masks_np = masks.squeeze(0).float().detach().cpu().numpy()
|
| 230 |
-
iou_predictions_np = (
|
| 231 |
-
iou_predictions.squeeze(0).float().detach().cpu().numpy()
|
| 232 |
-
)
|
| 233 |
low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
|
| 234 |
all_masks.append(masks_np)
|
| 235 |
all_ious.append(iou_predictions_np)
|
|
@@ -281,15 +261,11 @@ class SAM2ImagePredictor:
|
|
| 281 |
a subsequent iteration as mask input.
|
| 282 |
"""
|
| 283 |
if not self._is_image_set:
|
| 284 |
-
raise RuntimeError(
|
| 285 |
-
"An image must be set with .set_image(...) before mask prediction."
|
| 286 |
-
)
|
| 287 |
|
| 288 |
# Transform input prompts
|
| 289 |
|
| 290 |
-
mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
|
| 291 |
-
point_coords, point_labels, box, mask_input, normalize_coords
|
| 292 |
-
)
|
| 293 |
|
| 294 |
masks, iou_predictions, low_res_masks = self._predict(
|
| 295 |
unnorm_coords,
|
|
@@ -305,33 +281,21 @@ class SAM2ImagePredictor:
|
|
| 305 |
low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
|
| 306 |
return masks_np, iou_predictions_np, low_res_masks_np
|
| 307 |
|
| 308 |
-
def _prep_prompts(
|
| 309 |
-
self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
|
| 310 |
-
):
|
| 311 |
|
| 312 |
unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
|
| 313 |
if point_coords is not None:
|
| 314 |
-
assert
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
point_coords = torch.as_tensor(
|
| 318 |
-
point_coords, dtype=torch.float, device=self.device
|
| 319 |
-
)
|
| 320 |
-
unnorm_coords = self._transforms.transform_coords(
|
| 321 |
-
point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
|
| 322 |
-
)
|
| 323 |
labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
|
| 324 |
if len(unnorm_coords.shape) == 2:
|
| 325 |
unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
|
| 326 |
if box is not None:
|
| 327 |
box = torch.as_tensor(box, dtype=torch.float, device=self.device)
|
| 328 |
-
unnorm_box = self._transforms.transform_boxes(
|
| 329 |
-
box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
|
| 330 |
-
) # Bx2x2
|
| 331 |
if mask_logits is not None:
|
| 332 |
-
mask_input = torch.as_tensor(
|
| 333 |
-
mask_logits, dtype=torch.float, device=self.device
|
| 334 |
-
)
|
| 335 |
if len(mask_input.shape) == 3:
|
| 336 |
mask_input = mask_input[None, :, :, :]
|
| 337 |
return mask_input, unnorm_coords, labels, unnorm_box
|
|
@@ -383,9 +347,7 @@ class SAM2ImagePredictor:
|
|
| 383 |
a subsequent iteration as mask input.
|
| 384 |
"""
|
| 385 |
if not self._is_image_set:
|
| 386 |
-
raise RuntimeError(
|
| 387 |
-
"An image must be set with .set_image(...) before mask prediction."
|
| 388 |
-
)
|
| 389 |
|
| 390 |
if point_coords is not None:
|
| 391 |
concat_points = (point_coords, point_labels)
|
|
@@ -413,13 +375,8 @@ class SAM2ImagePredictor:
|
|
| 413 |
)
|
| 414 |
|
| 415 |
# Predict masks
|
| 416 |
-
batched_mode =
|
| 417 |
-
|
| 418 |
-
) # multi object prediction
|
| 419 |
-
high_res_features = [
|
| 420 |
-
feat_level[img_idx].unsqueeze(0)
|
| 421 |
-
for feat_level in self._features["high_res_feats"]
|
| 422 |
-
]
|
| 423 |
low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
|
| 424 |
image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
|
| 425 |
image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
|
|
@@ -431,9 +388,7 @@ class SAM2ImagePredictor:
|
|
| 431 |
)
|
| 432 |
|
| 433 |
# Upscale the masks to the original image resolution
|
| 434 |
-
masks = self._transforms.postprocess_masks(
|
| 435 |
-
low_res_masks, self._orig_hw[img_idx]
|
| 436 |
-
)
|
| 437 |
low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
|
| 438 |
if not return_logits:
|
| 439 |
masks = masks > self.mask_threshold
|
|
@@ -447,12 +402,8 @@ class SAM2ImagePredictor:
|
|
| 447 |
the embedding spatial dimension of SAM (typically C=256, H=W=64).
|
| 448 |
"""
|
| 449 |
if not self._is_image_set:
|
| 450 |
-
raise RuntimeError(
|
| 451 |
-
|
| 452 |
-
)
|
| 453 |
-
assert (
|
| 454 |
-
self._features is not None
|
| 455 |
-
), "Features must exist if an image has been set."
|
| 456 |
return self._features["image_embed"]
|
| 457 |
|
| 458 |
@property
|
|
|
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
import logging
|
|
|
|
| 8 |
from typing import List, Optional, Tuple, Union
|
| 9 |
|
| 10 |
import numpy as np
|
| 11 |
import torch
|
| 12 |
from PIL.Image import Image
|
| 13 |
|
| 14 |
+
from bboxmaskpose.sam2.modeling.sam2_base import SAM2Base
|
| 15 |
+
from bboxmaskpose.sam2.utils.transforms import SAM2Transforms
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
class SAM2ImagePredictor:
|
|
|
|
| 59 |
# Spatial dim for backbone feature maps
|
| 60 |
isize = self.model.image_size
|
| 61 |
self._bb_feat_sizes = [
|
| 62 |
+
(isize // 4, isize // 4),
|
| 63 |
+
(isize // 8, isize // 8),
|
| 64 |
+
(isize // 16, isize // 16),
|
| 65 |
]
|
| 66 |
|
| 67 |
@classmethod
|
|
|
|
| 76 |
Returns:
|
| 77 |
(SAM2ImagePredictor): The loaded model.
|
| 78 |
"""
|
| 79 |
+
from bboxmaskpose.sam2.build_sam import build_sam2_hf
|
| 80 |
|
| 81 |
sam_model = build_sam2_hf(model_id, **kwargs)
|
| 82 |
return cls(sam_model, **kwargs)
|
|
|
|
| 109 |
input_image = self._transforms(image)
|
| 110 |
input_image = input_image[None, ...].to(self.device)
|
| 111 |
|
| 112 |
+
assert len(input_image.shape) == 4 and input_image.shape[1] == 3, f"input_image must be of size 1x3xHxW, got {input_image.shape}"
|
|
|
|
|
|
|
| 113 |
logging.info("Computing image embeddings for the provided image...")
|
| 114 |
backbone_out = self.model.forward_image(input_image)
|
| 115 |
_, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
|
|
|
|
| 118 |
vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
|
| 119 |
|
| 120 |
# breakpoint()
|
| 121 |
+
feats = [feat.permute(1, 2, 0).view(1, -1, *feat_size) for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])][
|
| 122 |
+
::-1
|
| 123 |
+
]
|
|
|
|
| 124 |
self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
|
| 125 |
self._is_image_set = True
|
| 126 |
logging.info("Image embeddings computed.")
|
|
|
|
| 143 |
assert isinstance(image_list, list)
|
| 144 |
self._orig_hw = []
|
| 145 |
for image in image_list:
|
| 146 |
+
assert isinstance(image, np.ndarray), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
|
|
|
|
|
|
|
| 147 |
self._orig_hw.append(image.shape[:2])
|
| 148 |
# Transform the image to the form expected by the model
|
| 149 |
img_batch = self._transforms.forward_batch(image_list)
|
| 150 |
img_batch = img_batch.to(self.device)
|
| 151 |
batch_size = img_batch.shape[0]
|
| 152 |
+
assert len(img_batch.shape) == 4 and img_batch.shape[1] == 3, f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
|
|
|
|
|
|
|
| 153 |
logging.info("Computing image embeddings for the provided images...")
|
| 154 |
backbone_out = self.model.forward_image(img_batch)
|
| 155 |
_, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
|
|
|
|
| 158 |
vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
|
| 159 |
|
| 160 |
feats = [
|
| 161 |
+
feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
|
|
|
|
| 162 |
][::-1]
|
| 163 |
self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
|
| 164 |
self._is_image_set = True
|
|
|
|
| 180 |
"""
|
| 181 |
assert self._is_batch, "This function should only be used when in batched mode"
|
| 182 |
if not self._is_image_set:
|
| 183 |
+
raise RuntimeError("An image must be set with .set_image_batch(...) before mask prediction.")
|
|
|
|
|
|
|
| 184 |
num_images = len(self._features["image_embed"])
|
| 185 |
all_masks = []
|
| 186 |
all_ious = []
|
| 187 |
all_low_res_masks = []
|
| 188 |
for img_idx in range(num_images):
|
| 189 |
# Transform input prompts
|
| 190 |
+
point_coords = point_coords_batch[img_idx] if point_coords_batch is not None else None
|
| 191 |
+
point_labels = point_labels_batch[img_idx] if point_labels_batch is not None else None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
box = box_batch[img_idx] if box_batch is not None else None
|
| 193 |
+
mask_input = mask_input_batch[img_idx] if mask_input_batch is not None else None
|
|
|
|
|
|
|
| 194 |
mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
|
| 195 |
point_coords,
|
| 196 |
point_labels,
|
|
|
|
| 209 |
img_idx=img_idx,
|
| 210 |
)
|
| 211 |
masks_np = masks.squeeze(0).float().detach().cpu().numpy()
|
| 212 |
+
iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy()
|
|
|
|
|
|
|
| 213 |
low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
|
| 214 |
all_masks.append(masks_np)
|
| 215 |
all_ious.append(iou_predictions_np)
|
|
|
|
| 261 |
a subsequent iteration as mask input.
|
| 262 |
"""
|
| 263 |
if not self._is_image_set:
|
| 264 |
+
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
|
|
|
|
|
|
|
| 265 |
|
| 266 |
# Transform input prompts
|
| 267 |
|
| 268 |
+
mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(point_coords, point_labels, box, mask_input, normalize_coords)
|
|
|
|
|
|
|
| 269 |
|
| 270 |
masks, iou_predictions, low_res_masks = self._predict(
|
| 271 |
unnorm_coords,
|
|
|
|
| 281 |
low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
|
| 282 |
return masks_np, iou_predictions_np, low_res_masks_np
|
| 283 |
|
| 284 |
+
def _prep_prompts(self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1):
|
|
|
|
|
|
|
| 285 |
|
| 286 |
unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
|
| 287 |
if point_coords is not None:
|
| 288 |
+
assert point_labels is not None, "point_labels must be supplied if point_coords is supplied."
|
| 289 |
+
point_coords = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
|
| 290 |
+
unnorm_coords = self._transforms.transform_coords(point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
|
| 292 |
if len(unnorm_coords.shape) == 2:
|
| 293 |
unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
|
| 294 |
if box is not None:
|
| 295 |
box = torch.as_tensor(box, dtype=torch.float, device=self.device)
|
| 296 |
+
unnorm_box = self._transforms.transform_boxes(box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]) # Bx2x2
|
|
|
|
|
|
|
| 297 |
if mask_logits is not None:
|
| 298 |
+
mask_input = torch.as_tensor(mask_logits, dtype=torch.float, device=self.device)
|
|
|
|
|
|
|
| 299 |
if len(mask_input.shape) == 3:
|
| 300 |
mask_input = mask_input[None, :, :, :]
|
| 301 |
return mask_input, unnorm_coords, labels, unnorm_box
|
|
|
|
| 347 |
a subsequent iteration as mask input.
|
| 348 |
"""
|
| 349 |
if not self._is_image_set:
|
| 350 |
+
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
|
|
|
|
|
|
|
| 351 |
|
| 352 |
if point_coords is not None:
|
| 353 |
concat_points = (point_coords, point_labels)
|
|
|
|
| 375 |
)
|
| 376 |
|
| 377 |
# Predict masks
|
| 378 |
+
batched_mode = concat_points is not None and concat_points[0].shape[0] > 1 # multi object prediction
|
| 379 |
+
high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in self._features["high_res_feats"]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
|
| 381 |
image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
|
| 382 |
image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
|
|
|
|
| 388 |
)
|
| 389 |
|
| 390 |
# Upscale the masks to the original image resolution
|
| 391 |
+
masks = self._transforms.postprocess_masks(low_res_masks, self._orig_hw[img_idx])
|
|
|
|
|
|
|
| 392 |
low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
|
| 393 |
if not return_logits:
|
| 394 |
masks = masks > self.mask_threshold
|
|
|
|
| 402 |
the embedding spatial dimension of SAM (typically C=256, H=W=64).
|
| 403 |
"""
|
| 404 |
if not self._is_image_set:
|
| 405 |
+
raise RuntimeError("An image must be set with .set_image(...) to generate an embedding.")
|
| 406 |
+
assert self._features is not None, "Features must exist if an image has been set."
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
return self._features["image_embed"]
|
| 408 |
|
| 409 |
@property
|
{sam2 → bboxmaskpose/sam2}/sam2_video_predictor.py
RENAMED
|
@@ -9,7 +9,6 @@ from collections import OrderedDict
|
|
| 9 |
|
| 10 |
import torch
|
| 11 |
import torch.nn.functional as F
|
| 12 |
-
|
| 13 |
from tqdm import tqdm
|
| 14 |
|
| 15 |
from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
|
|
@@ -27,11 +26,6 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 27 |
# whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
|
| 28 |
# note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True)
|
| 29 |
clear_non_cond_mem_around_input=False,
|
| 30 |
-
<<<<<<< HEAD
|
| 31 |
-
# whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
|
| 32 |
-
clear_non_cond_mem_for_multi_obj=False,
|
| 33 |
-
=======
|
| 34 |
-
>>>>>>> 2b90b9f5ceec907a1c18123530e92e794ad901a4
|
| 35 |
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
|
| 36 |
# if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
|
| 37 |
add_all_frames_to_correct_as_cond=False,
|
|
@@ -41,10 +35,6 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 41 |
self.fill_hole_area = fill_hole_area
|
| 42 |
self.non_overlap_masks = non_overlap_masks
|
| 43 |
self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
|
| 44 |
-
<<<<<<< HEAD
|
| 45 |
-
self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
|
| 46 |
-
=======
|
| 47 |
-
>>>>>>> 2b90b9f5ceec907a1c18123530e92e794ad901a4
|
| 48 |
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
|
| 49 |
|
| 50 |
@torch.inference_mode()
|
|
@@ -296,9 +286,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 296 |
is_cond=is_cond,
|
| 297 |
consolidate_at_video_res=True,
|
| 298 |
)
|
| 299 |
-
_, video_res_masks = self._get_orig_video_res_output(
|
| 300 |
-
inference_state, consolidated_out["pred_masks_video_res"]
|
| 301 |
-
)
|
| 302 |
return frame_idx, obj_ids, video_res_masks
|
| 303 |
|
| 304 |
def add_new_points(self, *args, **kwargs):
|
|
@@ -384,9 +372,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 384 |
is_cond=is_cond,
|
| 385 |
consolidate_at_video_res=True,
|
| 386 |
)
|
| 387 |
-
_, video_res_masks = self._get_orig_video_res_output(
|
| 388 |
-
inference_state, consolidated_out["pred_masks_video_res"]
|
| 389 |
-
)
|
| 390 |
return frame_idx, obj_ids, video_res_masks
|
| 391 |
|
| 392 |
def _get_orig_video_res_output(self, inference_state, any_res_masks):
|
|
@@ -450,23 +436,6 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 450 |
dtype=torch.float32,
|
| 451 |
device=inference_state["storage_device"],
|
| 452 |
),
|
| 453 |
-
<<<<<<< HEAD
|
| 454 |
-
"obj_ptr": torch.full(
|
| 455 |
-
size=(batch_size, self.hidden_dim),
|
| 456 |
-
fill_value=NO_OBJ_SCORE,
|
| 457 |
-
dtype=torch.float32,
|
| 458 |
-
device=inference_state["device"],
|
| 459 |
-
),
|
| 460 |
-
"object_score_logits": torch.full(
|
| 461 |
-
size=(batch_size, 1),
|
| 462 |
-
# default to 10.0 for object_score_logits, i.e. assuming the object is
|
| 463 |
-
# present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
|
| 464 |
-
fill_value=10.0,
|
| 465 |
-
dtype=torch.float32,
|
| 466 |
-
device=inference_state["device"],
|
| 467 |
-
),
|
| 468 |
-
=======
|
| 469 |
-
>>>>>>> 2b90b9f5ceec907a1c18123530e92e794ad901a4
|
| 470 |
}
|
| 471 |
for obj_idx in range(batch_size):
|
| 472 |
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
|
@@ -499,36 +468,6 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 499 |
align_corners=False,
|
| 500 |
)
|
| 501 |
consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
|
| 502 |
-
<<<<<<< HEAD
|
| 503 |
-
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
|
| 504 |
-
consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
|
| 505 |
-
"object_score_logits"
|
| 506 |
-
]
|
| 507 |
-
|
| 508 |
-
# Optionally, apply non-overlapping constraints on the consolidated scores
|
| 509 |
-
# and rerun the memory encoder
|
| 510 |
-
if run_mem_encoder:
|
| 511 |
-
device = inference_state["device"]
|
| 512 |
-
high_res_masks = torch.nn.functional.interpolate(
|
| 513 |
-
consolidated_out["pred_masks"].to(device, non_blocking=True),
|
| 514 |
-
size=(self.image_size, self.image_size),
|
| 515 |
-
mode="bilinear",
|
| 516 |
-
align_corners=False,
|
| 517 |
-
)
|
| 518 |
-
if self.non_overlap_masks_for_mem_enc:
|
| 519 |
-
high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
|
| 520 |
-
maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
|
| 521 |
-
inference_state=inference_state,
|
| 522 |
-
frame_idx=frame_idx,
|
| 523 |
-
batch_size=batch_size,
|
| 524 |
-
high_res_masks=high_res_masks,
|
| 525 |
-
object_score_logits=consolidated_out["object_score_logits"],
|
| 526 |
-
is_mask_from_pts=True, # these frames are what the user interacted with
|
| 527 |
-
)
|
| 528 |
-
consolidated_out["maskmem_features"] = maskmem_features
|
| 529 |
-
consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
|
| 530 |
-
=======
|
| 531 |
-
>>>>>>> 2b90b9f5ceec907a1c18123530e92e794ad901a4
|
| 532 |
|
| 533 |
return consolidated_out
|
| 534 |
|
|
@@ -538,9 +477,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 538 |
# Check and make sure that every object has received input points or masks.
|
| 539 |
batch_size = self._get_obj_num(inference_state)
|
| 540 |
if batch_size == 0:
|
| 541 |
-
raise RuntimeError(
|
| 542 |
-
"No input points or masks are provided for any object; please add inputs first."
|
| 543 |
-
)
|
| 544 |
|
| 545 |
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
|
| 546 |
# add them into "output_dict".
|
|
@@ -549,9 +486,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 549 |
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
| 550 |
for is_cond in [False, True]:
|
| 551 |
# Separately consolidate conditioning and non-conditioning temp outputs
|
| 552 |
-
storage_key =
|
| 553 |
-
"cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
| 554 |
-
)
|
| 555 |
# Find all the frames that contain temporary outputs for any objects
|
| 556 |
# (these should be the frames that have just received clicks for mask inputs
|
| 557 |
# via `add_new_points_or_box` or `add_new_mask`)
|
|
@@ -579,9 +514,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 579 |
obj_output_dict[storage_key][frame_idx] = out
|
| 580 |
if self.clear_non_cond_mem_around_input:
|
| 581 |
# clear non-conditioning memory of the surrounding frames
|
| 582 |
-
self._clear_obj_non_cond_mem_around_input(
|
| 583 |
-
inference_state, frame_idx, obj_idx
|
| 584 |
-
)
|
| 585 |
|
| 586 |
# clear temporary outputs in `temp_output_dict_per_obj`
|
| 587 |
obj_temp_output_dict[storage_key].clear()
|
|
@@ -590,9 +523,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 590 |
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
| 591 |
if len(obj_output_dict["cond_frame_outputs"]) == 0:
|
| 592 |
obj_id = self._obj_idx_to_id(inference_state, obj_idx)
|
| 593 |
-
raise RuntimeError(
|
| 594 |
-
f"No input points or masks are provided for object id {obj_id}; please add inputs first."
|
| 595 |
-
)
|
| 596 |
# edge case: if an output is added to "cond_frame_outputs", we remove any prior
|
| 597 |
# output on the same frame in "non_cond_frame_outputs"
|
| 598 |
for frame_idx in obj_output_dict["cond_frame_outputs"]:
|
|
@@ -617,9 +548,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 617 |
if start_frame_idx is None:
|
| 618 |
# default: start from the earliest frame with input points
|
| 619 |
start_frame_idx = min(
|
| 620 |
-
t
|
| 621 |
-
for obj_output_dict in inference_state["output_dict_per_obj"].values()
|
| 622 |
-
for t in obj_output_dict["cond_frame_outputs"]
|
| 623 |
)
|
| 624 |
if max_frame_num_to_track is None:
|
| 625 |
# default: track all the frames in the video
|
|
@@ -631,9 +560,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 631 |
else:
|
| 632 |
processing_order = [] # skip reverse tracking if starting from frame 0
|
| 633 |
else:
|
| 634 |
-
end_frame_idx = min(
|
| 635 |
-
start_frame_idx + max_frame_num_to_track, num_frames - 1
|
| 636 |
-
)
|
| 637 |
processing_order = range(start_frame_idx, end_frame_idx + 1)
|
| 638 |
|
| 639 |
for frame_idx in tqdm(processing_order, desc="propagate in video"):
|
|
@@ -651,9 +578,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 651 |
pred_masks = current_out["pred_masks"].to(device, non_blocking=True)
|
| 652 |
if self.clear_non_cond_mem_around_input:
|
| 653 |
# clear non-conditioning memory of the surrounding frames
|
| 654 |
-
self._clear_obj_non_cond_mem_around_input(
|
| 655 |
-
inference_state, frame_idx, obj_idx
|
| 656 |
-
)
|
| 657 |
else:
|
| 658 |
storage_key = "non_cond_frame_outputs"
|
| 659 |
current_out, pred_masks = self._run_single_frame_inference(
|
|
@@ -669,9 +594,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 669 |
)
|
| 670 |
obj_output_dict[storage_key][frame_idx] = current_out
|
| 671 |
|
| 672 |
-
inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = {
|
| 673 |
-
"reverse": reverse
|
| 674 |
-
}
|
| 675 |
pred_masks_per_obj[obj_idx] = pred_masks
|
| 676 |
|
| 677 |
# Resize the output mask to the original video resolution (we directly use
|
|
@@ -680,42 +603,11 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 680 |
all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
|
| 681 |
else:
|
| 682 |
all_pred_masks = pred_masks_per_obj[0]
|
| 683 |
-
_, video_res_masks = self._get_orig_video_res_output(
|
| 684 |
-
inference_state, all_pred_masks
|
| 685 |
-
)
|
| 686 |
yield frame_idx, obj_ids, video_res_masks
|
| 687 |
|
| 688 |
@torch.inference_mode()
|
| 689 |
-
def clear_all_prompts_in_frame(
|
| 690 |
-
self, inference_state, frame_idx, obj_id, need_output=True
|
| 691 |
-
):
|
| 692 |
-
<<<<<<< HEAD
|
| 693 |
-
"""
|
| 694 |
-
Split a multi-object output into per-object output slices and add them into
|
| 695 |
-
`output_dict_per_obj`. The resulting slices share the same tensor storage.
|
| 696 |
-
"""
|
| 697 |
-
maskmem_features = current_out["maskmem_features"]
|
| 698 |
-
assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
|
| 699 |
-
|
| 700 |
-
maskmem_pos_enc = current_out["maskmem_pos_enc"]
|
| 701 |
-
assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
|
| 702 |
-
|
| 703 |
-
output_dict_per_obj = inference_state["output_dict_per_obj"]
|
| 704 |
-
for obj_idx, obj_output_dict in output_dict_per_obj.items():
|
| 705 |
-
obj_slice = slice(obj_idx, obj_idx + 1)
|
| 706 |
-
obj_out = {
|
| 707 |
-
"maskmem_features": None,
|
| 708 |
-
"maskmem_pos_enc": None,
|
| 709 |
-
"pred_masks": current_out["pred_masks"][obj_slice],
|
| 710 |
-
"obj_ptr": current_out["obj_ptr"][obj_slice],
|
| 711 |
-
"object_score_logits": current_out["object_score_logits"][obj_slice],
|
| 712 |
-
}
|
| 713 |
-
if maskmem_features is not None:
|
| 714 |
-
obj_out["maskmem_features"] = maskmem_features[obj_slice]
|
| 715 |
-
if maskmem_pos_enc is not None:
|
| 716 |
-
obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
|
| 717 |
-
obj_output_dict[storage_key][frame_idx] = obj_out
|
| 718 |
-
=======
|
| 719 |
"""Remove all input points or mask in a specific frame for a given object."""
|
| 720 |
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
|
| 721 |
|
|
@@ -740,91 +632,14 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 740 |
return
|
| 741 |
# Finally, output updated masks per object (after removing the inputs above)
|
| 742 |
obj_ids = inference_state["obj_ids"]
|
| 743 |
-
is_cond = any(
|
| 744 |
-
frame_idx in obj_temp_output_dict["cond_frame_outputs"]
|
| 745 |
-
for obj_temp_output_dict in temp_output_dict_per_obj.values()
|
| 746 |
-
)
|
| 747 |
consolidated_out = self._consolidate_temp_output_across_obj(
|
| 748 |
inference_state,
|
| 749 |
frame_idx,
|
| 750 |
is_cond=is_cond,
|
| 751 |
consolidate_at_video_res=True,
|
| 752 |
)
|
| 753 |
-
_, video_res_masks = self._get_orig_video_res_output(
|
| 754 |
-
inference_state, consolidated_out["pred_masks_video_res"]
|
| 755 |
-
)
|
| 756 |
-
return frame_idx, obj_ids, video_res_masks
|
| 757 |
-
>>>>>>> 2b90b9f5ceec907a1c18123530e92e794ad901a4
|
| 758 |
-
|
| 759 |
-
@torch.inference_mode()
|
| 760 |
-
def clear_all_prompts_in_frame(
|
| 761 |
-
self, inference_state, frame_idx, obj_id, need_output=True
|
| 762 |
-
):
|
| 763 |
-
"""Remove all input points or mask in a specific frame for a given object."""
|
| 764 |
-
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
|
| 765 |
-
|
| 766 |
-
# Clear the conditioning information on the given frame
|
| 767 |
-
inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)
|
| 768 |
-
inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)
|
| 769 |
-
|
| 770 |
-
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
|
| 771 |
-
temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
|
| 772 |
-
temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
|
| 773 |
-
|
| 774 |
-
# Check and see if there are still any inputs left on this frame
|
| 775 |
-
batch_size = self._get_obj_num(inference_state)
|
| 776 |
-
frame_has_input = False
|
| 777 |
-
for obj_idx2 in range(batch_size):
|
| 778 |
-
if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
|
| 779 |
-
frame_has_input = True
|
| 780 |
-
break
|
| 781 |
-
if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
|
| 782 |
-
frame_has_input = True
|
| 783 |
-
break
|
| 784 |
-
|
| 785 |
-
# If this frame has no remaining inputs for any objects, we further clear its
|
| 786 |
-
# conditioning frame status
|
| 787 |
-
if not frame_has_input:
|
| 788 |
-
output_dict = inference_state["output_dict"]
|
| 789 |
-
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
| 790 |
-
consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
|
| 791 |
-
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
|
| 792 |
-
# Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
|
| 793 |
-
out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
|
| 794 |
-
if out is not None:
|
| 795 |
-
# The frame is not a conditioning frame anymore since it's not receiving inputs,
|
| 796 |
-
# so we "downgrade" its output (if exists) to a non-conditioning frame output.
|
| 797 |
-
output_dict["non_cond_frame_outputs"][frame_idx] = out
|
| 798 |
-
inference_state["frames_already_tracked"].pop(frame_idx, None)
|
| 799 |
-
# Similarly, do it for the sliced output on each object.
|
| 800 |
-
for obj_idx2 in range(batch_size):
|
| 801 |
-
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
|
| 802 |
-
obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
|
| 803 |
-
if obj_out is not None:
|
| 804 |
-
obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out
|
| 805 |
-
|
| 806 |
-
# If all the conditioning frames have been removed, we also clear the tracking outputs
|
| 807 |
-
if len(output_dict["cond_frame_outputs"]) == 0:
|
| 808 |
-
self._reset_tracking_results(inference_state)
|
| 809 |
-
|
| 810 |
-
if not need_output:
|
| 811 |
-
return
|
| 812 |
-
# Finally, output updated masks per object (after removing the inputs above)
|
| 813 |
-
obj_ids = inference_state["obj_ids"]
|
| 814 |
-
is_cond = any(
|
| 815 |
-
frame_idx in obj_temp_output_dict["cond_frame_outputs"]
|
| 816 |
-
for obj_temp_output_dict in temp_output_dict_per_obj.values()
|
| 817 |
-
)
|
| 818 |
-
consolidated_out = self._consolidate_temp_output_across_obj(
|
| 819 |
-
inference_state,
|
| 820 |
-
frame_idx,
|
| 821 |
-
is_cond=is_cond,
|
| 822 |
-
run_mem_encoder=False,
|
| 823 |
-
consolidate_at_video_res=True,
|
| 824 |
-
)
|
| 825 |
-
_, video_res_masks = self._get_orig_video_res_output(
|
| 826 |
-
inference_state, consolidated_out["pred_masks_video_res"]
|
| 827 |
-
)
|
| 828 |
return frame_idx, obj_ids, video_res_masks
|
| 829 |
|
| 830 |
@torch.inference_mode()
|
|
@@ -859,9 +674,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 859 |
def _get_image_feature(self, inference_state, frame_idx, batch_size):
|
| 860 |
"""Compute the image features on a given frame."""
|
| 861 |
# Look up in the cache first
|
| 862 |
-
image, backbone_out = inference_state["cached_features"].get(
|
| 863 |
-
frame_idx, (None, None)
|
| 864 |
-
)
|
| 865 |
if backbone_out is None:
|
| 866 |
# Cache miss -- we will run inference on a single image
|
| 867 |
device = inference_state["device"]
|
|
@@ -878,9 +691,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 878 |
"vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
|
| 879 |
}
|
| 880 |
for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
|
| 881 |
-
expanded_backbone_out["backbone_fpn"][i] = feat.expand(
|
| 882 |
-
batch_size, -1, -1, -1
|
| 883 |
-
)
|
| 884 |
for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
|
| 885 |
pos = pos.expand(batch_size, -1, -1, -1)
|
| 886 |
expanded_backbone_out["vision_pos_enc"][i] = pos
|
|
@@ -935,33 +746,23 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 935 |
if maskmem_features is not None:
|
| 936 |
maskmem_features = maskmem_features.to(torch.bfloat16)
|
| 937 |
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
|
| 938 |
-
pred_masks_gpu = current_out["pred_masks"]
|
| 939 |
# potentially fill holes in the predicted masks
|
| 940 |
if self.fill_hole_area > 0:
|
| 941 |
-
pred_masks_gpu = fill_holes_in_mask_scores(
|
| 942 |
-
pred_masks_gpu, self.fill_hole_area
|
| 943 |
-
)
|
| 944 |
pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
|
| 945 |
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
| 946 |
maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
|
| 947 |
# object pointer is a small tensor, so we always keep it on GPU memory for fast access
|
| 948 |
obj_ptr = current_out["obj_ptr"]
|
| 949 |
object_score_logits = current_out["object_score_logits"]
|
| 950 |
-
<<<<<<< HEAD
|
| 951 |
-
best_iou_score = current_out["best_iou_score"]
|
| 952 |
-
=======
|
| 953 |
-
>>>>>>> 2b90b9f5ceec907a1c18123530e92e794ad901a4
|
| 954 |
# make a compact version of this frame's output to reduce the state size
|
| 955 |
compact_current_out = {
|
| 956 |
-
"maskmem_features": maskmem_features,
|
| 957 |
-
"maskmem_pos_enc": maskmem_pos_enc,
|
| 958 |
"pred_masks": pred_masks,
|
| 959 |
"obj_ptr": obj_ptr,
|
| 960 |
"object_score_logits": object_score_logits,
|
| 961 |
-
<<<<<<< HEAD
|
| 962 |
-
"best_iou_score": best_iou_score,
|
| 963 |
-
=======
|
| 964 |
-
>>>>>>> 2b90b9f5ceec907a1c18123530e92e794ad901a4
|
| 965 |
}
|
| 966 |
return compact_current_out, pred_masks_gpu
|
| 967 |
|
|
@@ -980,9 +781,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 980 |
memory also need to be computed again with the memory encoder.
|
| 981 |
"""
|
| 982 |
# Retrieve correct image features
|
| 983 |
-
_, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
|
| 984 |
-
inference_state, frame_idx, batch_size
|
| 985 |
-
)
|
| 986 |
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
|
| 987 |
current_vision_feats=current_vision_feats,
|
| 988 |
feat_sizes=feat_sizes,
|
|
@@ -996,9 +795,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 996 |
maskmem_features = maskmem_features.to(torch.bfloat16)
|
| 997 |
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
|
| 998 |
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
| 999 |
-
maskmem_pos_enc = self._get_maskmem_pos_enc(
|
| 1000 |
-
inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
|
| 1001 |
-
)
|
| 1002 |
return maskmem_features, maskmem_pos_enc
|
| 1003 |
|
| 1004 |
def _get_maskmem_pos_enc(self, inference_state, current_out):
|
|
@@ -1019,9 +816,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 1019 |
maskmem_pos_enc = model_constants["maskmem_pos_enc"]
|
| 1020 |
# expand the cached maskmem_pos_enc to the actual batch size
|
| 1021 |
batch_size = out_maskmem_pos_enc[0].size(0)
|
| 1022 |
-
expanded_maskmem_pos_enc = [
|
| 1023 |
-
x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
|
| 1024 |
-
]
|
| 1025 |
else:
|
| 1026 |
expanded_maskmem_pos_enc = None
|
| 1027 |
return expanded_maskmem_pos_enc
|
|
@@ -1039,8 +834,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 1039 |
if not strict:
|
| 1040 |
return inference_state["obj_ids"], updated_frames
|
| 1041 |
raise RuntimeError(
|
| 1042 |
-
f"Cannot remove object id {obj_id} as it doesn't exist. "
|
| 1043 |
-
f"All existing object ids: {inference_state['obj_ids']}."
|
| 1044 |
)
|
| 1045 |
|
| 1046 |
# If this is the only remaining object id, we simply reset the state.
|
|
@@ -1054,16 +848,10 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 1054 |
# (note that this step is required as it might downgrade conditioning frames to
|
| 1055 |
# non-conditioning ones)
|
| 1056 |
obj_input_frames_inds = set()
|
| 1057 |
-
obj_input_frames_inds.update(
|
| 1058 |
-
|
| 1059 |
-
)
|
| 1060 |
-
obj_input_frames_inds.update(
|
| 1061 |
-
inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
|
| 1062 |
-
)
|
| 1063 |
for frame_idx in obj_input_frames_inds:
|
| 1064 |
-
self.clear_all_prompts_in_frame(
|
| 1065 |
-
inference_state, frame_idx, obj_id, need_output=False
|
| 1066 |
-
)
|
| 1067 |
|
| 1068 |
# Step 1: Update the object id mapping (note that it must be done after Step 0,
|
| 1069 |
# since Step 0 still requires the old object id mappings in inference_state)
|
|
@@ -1080,11 +868,6 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 1080 |
inference_state["obj_ids"] = new_obj_ids
|
| 1081 |
|
| 1082 |
# Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
|
| 1083 |
-
<<<<<<< HEAD
|
| 1084 |
-
# (note that "consolidated_frame_inds" doesn't need to be updated in this step as
|
| 1085 |
-
# it's already handled in Step 0)
|
| 1086 |
-
=======
|
| 1087 |
-
>>>>>>> 2b90b9f5ceec907a1c18123530e92e794ad901a4
|
| 1088 |
def _map_keys(container):
|
| 1089 |
new_kvs = []
|
| 1090 |
for k in old_obj_inds:
|
|
@@ -1097,57 +880,23 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 1097 |
_map_keys(inference_state["mask_inputs_per_obj"])
|
| 1098 |
_map_keys(inference_state["output_dict_per_obj"])
|
| 1099 |
_map_keys(inference_state["temp_output_dict_per_obj"])
|
| 1100 |
-
<<<<<<< HEAD
|
| 1101 |
-
|
| 1102 |
-
# Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices.
|
| 1103 |
-
def _slice_state(output_dict, storage_key):
|
| 1104 |
-
for frame_idx, out in output_dict[storage_key].items():
|
| 1105 |
-
out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
|
| 1106 |
-
out["maskmem_pos_enc"] = [
|
| 1107 |
-
x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
|
| 1108 |
-
]
|
| 1109 |
-
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
| 1110 |
-
out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
|
| 1111 |
-
out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
|
| 1112 |
-
out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
|
| 1113 |
-
out["object_score_logits"] = out["object_score_logits"][
|
| 1114 |
-
remain_old_obj_inds
|
| 1115 |
-
]
|
| 1116 |
-
# also update the per-object slices
|
| 1117 |
-
self._add_output_per_object(
|
| 1118 |
-
inference_state, frame_idx, out, storage_key
|
| 1119 |
-
)
|
| 1120 |
-
|
| 1121 |
-
_slice_state(inference_state["output_dict"], "cond_frame_outputs")
|
| 1122 |
-
_slice_state(inference_state["output_dict"], "non_cond_frame_outputs")
|
| 1123 |
-
|
| 1124 |
-
# Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which
|
| 1125 |
-
=======
|
| 1126 |
_map_keys(inference_state["frames_tracked_per_obj"])
|
| 1127 |
|
| 1128 |
# Step 3: Further collect the outputs on those frames in `obj_input_frames_inds`, which
|
| 1129 |
-
>>>>>>> 2b90b9f5ceec907a1c18123530e92e794ad901a4
|
| 1130 |
# could show an updated mask for objects previously occluded by the object being removed
|
| 1131 |
if need_output:
|
| 1132 |
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
|
| 1133 |
for frame_idx in obj_input_frames_inds:
|
| 1134 |
is_cond = any(
|
| 1135 |
-
frame_idx in obj_temp_output_dict["cond_frame_outputs"]
|
| 1136 |
-
for obj_temp_output_dict in temp_output_dict_per_obj.values()
|
| 1137 |
)
|
| 1138 |
consolidated_out = self._consolidate_temp_output_across_obj(
|
| 1139 |
inference_state,
|
| 1140 |
frame_idx,
|
| 1141 |
is_cond=is_cond,
|
| 1142 |
-
<<<<<<< HEAD
|
| 1143 |
-
run_mem_encoder=False,
|
| 1144 |
-
=======
|
| 1145 |
-
>>>>>>> 2b90b9f5ceec907a1c18123530e92e794ad901a4
|
| 1146 |
consolidate_at_video_res=True,
|
| 1147 |
)
|
| 1148 |
-
_, video_res_masks = self._get_orig_video_res_output(
|
| 1149 |
-
inference_state, consolidated_out["pred_masks_video_res"]
|
| 1150 |
-
)
|
| 1151 |
updated_frames.append((frame_idx, video_res_masks))
|
| 1152 |
|
| 1153 |
return inference_state["obj_ids"], updated_frames
|
|
@@ -1218,18 +967,12 @@ class SAM2VideoPredictorVOS(SAM2VideoPredictor):
|
|
| 1218 |
if self.use_high_res_features_in_sam:
|
| 1219 |
# precompute projected level 0 and level 1 features in SAM decoder
|
| 1220 |
# to avoid running it again on every SAM click
|
| 1221 |
-
backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
|
| 1222 |
-
|
| 1223 |
-
)
|
| 1224 |
-
backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
|
| 1225 |
-
backbone_out["backbone_fpn"][1]
|
| 1226 |
-
)
|
| 1227 |
# Clone to help torch.compile
|
| 1228 |
for i in range(len(backbone_out["backbone_fpn"])):
|
| 1229 |
backbone_out["backbone_fpn"][i] = backbone_out["backbone_fpn"][i].clone()
|
| 1230 |
-
backbone_out["vision_pos_enc"][i] = backbone_out["vision_pos_enc"][
|
| 1231 |
-
i
|
| 1232 |
-
].clone()
|
| 1233 |
return backbone_out
|
| 1234 |
|
| 1235 |
def _forward_sam_heads(
|
|
@@ -1388,9 +1131,7 @@ class SAM2VideoPredictorVOS(SAM2VideoPredictor):
|
|
| 1388 |
# optionally, apply non-overlapping constraints to the masks (it's applied
|
| 1389 |
# in the batch dimension and should only be used during eval, where all
|
| 1390 |
# the objects come from the same video under batch size 1).
|
| 1391 |
-
pred_masks_high_res = self._apply_non_overlapping_constraints(
|
| 1392 |
-
pred_masks_high_res
|
| 1393 |
-
)
|
| 1394 |
# scale the raw mask logits with a temperature before applying sigmoid
|
| 1395 |
binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
|
| 1396 |
if binarize and not self.training:
|
|
@@ -1403,9 +1144,7 @@ class SAM2VideoPredictorVOS(SAM2VideoPredictor):
|
|
| 1403 |
mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
|
| 1404 |
if self.sigmoid_bias_for_mem_enc != 0.0:
|
| 1405 |
mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
|
| 1406 |
-
maskmem_out = self.memory_encoder(
|
| 1407 |
-
pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied
|
| 1408 |
-
)
|
| 1409 |
# Clone the feats and pos_enc to enable compilation
|
| 1410 |
maskmem_features = maskmem_out["vision_features"].clone()
|
| 1411 |
maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]]
|
|
@@ -1413,9 +1152,7 @@ class SAM2VideoPredictorVOS(SAM2VideoPredictor):
|
|
| 1413 |
# is predicted to be occluded (i.e. no object is appearing in the frame)
|
| 1414 |
if self.no_obj_embed_spatial is not None:
|
| 1415 |
is_obj_appearing = (object_score_logits > 0).float()
|
| 1416 |
-
maskmem_features += (
|
| 1417 |
-
1 - is_obj_appearing[..., None, None]
|
| 1418 |
-
) * self.no_obj_embed_spatial[..., None, None].expand(
|
| 1419 |
*maskmem_features.shape
|
| 1420 |
)
|
| 1421 |
|
|
|
|
| 9 |
|
| 10 |
import torch
|
| 11 |
import torch.nn.functional as F
|
|
|
|
| 12 |
from tqdm import tqdm
|
| 13 |
|
| 14 |
from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
|
|
|
|
| 26 |
# whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
|
| 27 |
# note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True)
|
| 28 |
clear_non_cond_mem_around_input=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
|
| 30 |
# if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
|
| 31 |
add_all_frames_to_correct_as_cond=False,
|
|
|
|
| 35 |
self.fill_hole_area = fill_hole_area
|
| 36 |
self.non_overlap_masks = non_overlap_masks
|
| 37 |
self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
|
| 39 |
|
| 40 |
@torch.inference_mode()
|
|
|
|
| 286 |
is_cond=is_cond,
|
| 287 |
consolidate_at_video_res=True,
|
| 288 |
)
|
| 289 |
+
_, video_res_masks = self._get_orig_video_res_output(inference_state, consolidated_out["pred_masks_video_res"])
|
|
|
|
|
|
|
| 290 |
return frame_idx, obj_ids, video_res_masks
|
| 291 |
|
| 292 |
def add_new_points(self, *args, **kwargs):
|
|
|
|
| 372 |
is_cond=is_cond,
|
| 373 |
consolidate_at_video_res=True,
|
| 374 |
)
|
| 375 |
+
_, video_res_masks = self._get_orig_video_res_output(inference_state, consolidated_out["pred_masks_video_res"])
|
|
|
|
|
|
|
| 376 |
return frame_idx, obj_ids, video_res_masks
|
| 377 |
|
| 378 |
def _get_orig_video_res_output(self, inference_state, any_res_masks):
|
|
|
|
| 436 |
dtype=torch.float32,
|
| 437 |
device=inference_state["storage_device"],
|
| 438 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
}
|
| 440 |
for obj_idx in range(batch_size):
|
| 441 |
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
|
|
|
| 468 |
align_corners=False,
|
| 469 |
)
|
| 470 |
consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
|
| 472 |
return consolidated_out
|
| 473 |
|
|
|
|
| 477 |
# Check and make sure that every object has received input points or masks.
|
| 478 |
batch_size = self._get_obj_num(inference_state)
|
| 479 |
if batch_size == 0:
|
| 480 |
+
raise RuntimeError("No input points or masks are provided for any object; please add inputs first.")
|
|
|
|
|
|
|
| 481 |
|
| 482 |
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
|
| 483 |
# add them into "output_dict".
|
|
|
|
| 486 |
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
| 487 |
for is_cond in [False, True]:
|
| 488 |
# Separately consolidate conditioning and non-conditioning temp outputs
|
| 489 |
+
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
|
|
|
|
|
|
| 490 |
# Find all the frames that contain temporary outputs for any objects
|
| 491 |
# (these should be the frames that have just received clicks for mask inputs
|
| 492 |
# via `add_new_points_or_box` or `add_new_mask`)
|
|
|
|
| 514 |
obj_output_dict[storage_key][frame_idx] = out
|
| 515 |
if self.clear_non_cond_mem_around_input:
|
| 516 |
# clear non-conditioning memory of the surrounding frames
|
| 517 |
+
self._clear_obj_non_cond_mem_around_input(inference_state, frame_idx, obj_idx)
|
|
|
|
|
|
|
| 518 |
|
| 519 |
# clear temporary outputs in `temp_output_dict_per_obj`
|
| 520 |
obj_temp_output_dict[storage_key].clear()
|
|
|
|
| 523 |
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
| 524 |
if len(obj_output_dict["cond_frame_outputs"]) == 0:
|
| 525 |
obj_id = self._obj_idx_to_id(inference_state, obj_idx)
|
| 526 |
+
raise RuntimeError(f"No input points or masks are provided for object id {obj_id}; please add inputs first.")
|
|
|
|
|
|
|
| 527 |
# edge case: if an output is added to "cond_frame_outputs", we remove any prior
|
| 528 |
# output on the same frame in "non_cond_frame_outputs"
|
| 529 |
for frame_idx in obj_output_dict["cond_frame_outputs"]:
|
|
|
|
| 548 |
if start_frame_idx is None:
|
| 549 |
# default: start from the earliest frame with input points
|
| 550 |
start_frame_idx = min(
|
| 551 |
+
t for obj_output_dict in inference_state["output_dict_per_obj"].values() for t in obj_output_dict["cond_frame_outputs"]
|
|
|
|
|
|
|
| 552 |
)
|
| 553 |
if max_frame_num_to_track is None:
|
| 554 |
# default: track all the frames in the video
|
|
|
|
| 560 |
else:
|
| 561 |
processing_order = [] # skip reverse tracking if starting from frame 0
|
| 562 |
else:
|
| 563 |
+
end_frame_idx = min(start_frame_idx + max_frame_num_to_track, num_frames - 1)
|
|
|
|
|
|
|
| 564 |
processing_order = range(start_frame_idx, end_frame_idx + 1)
|
| 565 |
|
| 566 |
for frame_idx in tqdm(processing_order, desc="propagate in video"):
|
|
|
|
| 578 |
pred_masks = current_out["pred_masks"].to(device, non_blocking=True)
|
| 579 |
if self.clear_non_cond_mem_around_input:
|
| 580 |
# clear non-conditioning memory of the surrounding frames
|
| 581 |
+
self._clear_obj_non_cond_mem_around_input(inference_state, frame_idx, obj_idx)
|
|
|
|
|
|
|
| 582 |
else:
|
| 583 |
storage_key = "non_cond_frame_outputs"
|
| 584 |
current_out, pred_masks = self._run_single_frame_inference(
|
|
|
|
| 594 |
)
|
| 595 |
obj_output_dict[storage_key][frame_idx] = current_out
|
| 596 |
|
| 597 |
+
inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = {"reverse": reverse}
|
|
|
|
|
|
|
| 598 |
pred_masks_per_obj[obj_idx] = pred_masks
|
| 599 |
|
| 600 |
# Resize the output mask to the original video resolution (we directly use
|
|
|
|
| 603 |
all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
|
| 604 |
else:
|
| 605 |
all_pred_masks = pred_masks_per_obj[0]
|
| 606 |
+
_, video_res_masks = self._get_orig_video_res_output(inference_state, all_pred_masks)
|
|
|
|
|
|
|
| 607 |
yield frame_idx, obj_ids, video_res_masks
|
| 608 |
|
| 609 |
@torch.inference_mode()
|
| 610 |
+
def clear_all_prompts_in_frame(self, inference_state, frame_idx, obj_id, need_output=True):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 611 |
"""Remove all input points or mask in a specific frame for a given object."""
|
| 612 |
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
|
| 613 |
|
|
|
|
| 632 |
return
|
| 633 |
# Finally, output updated masks per object (after removing the inputs above)
|
| 634 |
obj_ids = inference_state["obj_ids"]
|
| 635 |
+
is_cond = any(frame_idx in obj_temp_output_dict["cond_frame_outputs"] for obj_temp_output_dict in temp_output_dict_per_obj.values())
|
|
|
|
|
|
|
|
|
|
| 636 |
consolidated_out = self._consolidate_temp_output_across_obj(
|
| 637 |
inference_state,
|
| 638 |
frame_idx,
|
| 639 |
is_cond=is_cond,
|
| 640 |
consolidate_at_video_res=True,
|
| 641 |
)
|
| 642 |
+
_, video_res_masks = self._get_orig_video_res_output(inference_state, consolidated_out["pred_masks_video_res"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 643 |
return frame_idx, obj_ids, video_res_masks
|
| 644 |
|
| 645 |
@torch.inference_mode()
|
|
|
|
| 674 |
def _get_image_feature(self, inference_state, frame_idx, batch_size):
|
| 675 |
"""Compute the image features on a given frame."""
|
| 676 |
# Look up in the cache first
|
| 677 |
+
image, backbone_out = inference_state["cached_features"].get(frame_idx, (None, None))
|
|
|
|
|
|
|
| 678 |
if backbone_out is None:
|
| 679 |
# Cache miss -- we will run inference on a single image
|
| 680 |
device = inference_state["device"]
|
|
|
|
| 691 |
"vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
|
| 692 |
}
|
| 693 |
for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
|
| 694 |
+
expanded_backbone_out["backbone_fpn"][i] = feat.expand(batch_size, -1, -1, -1)
|
|
|
|
|
|
|
| 695 |
for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
|
| 696 |
pos = pos.expand(batch_size, -1, -1, -1)
|
| 697 |
expanded_backbone_out["vision_pos_enc"][i] = pos
|
|
|
|
| 746 |
if maskmem_features is not None:
|
| 747 |
maskmem_features = maskmem_features.to(torch.bfloat16)
|
| 748 |
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
|
| 749 |
+
pred_masks_gpu = current_out["pred_masks"]
|
| 750 |
# potentially fill holes in the predicted masks
|
| 751 |
if self.fill_hole_area > 0:
|
| 752 |
+
pred_masks_gpu = fill_holes_in_mask_scores(pred_masks_gpu, self.fill_hole_area)
|
|
|
|
|
|
|
| 753 |
pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
|
| 754 |
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
| 755 |
maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
|
| 756 |
# object pointer is a small tensor, so we always keep it on GPU memory for fast access
|
| 757 |
obj_ptr = current_out["obj_ptr"]
|
| 758 |
object_score_logits = current_out["object_score_logits"]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 759 |
# make a compact version of this frame's output to reduce the state size
|
| 760 |
compact_current_out = {
|
| 761 |
+
"maskmem_features": maskmem_features,
|
| 762 |
+
"maskmem_pos_enc": maskmem_pos_enc,
|
| 763 |
"pred_masks": pred_masks,
|
| 764 |
"obj_ptr": obj_ptr,
|
| 765 |
"object_score_logits": object_score_logits,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 766 |
}
|
| 767 |
return compact_current_out, pred_masks_gpu
|
| 768 |
|
|
|
|
| 781 |
memory also need to be computed again with the memory encoder.
|
| 782 |
"""
|
| 783 |
# Retrieve correct image features
|
| 784 |
+
_, _, current_vision_feats, _, feat_sizes = self._get_image_feature(inference_state, frame_idx, batch_size)
|
|
|
|
|
|
|
| 785 |
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
|
| 786 |
current_vision_feats=current_vision_feats,
|
| 787 |
feat_sizes=feat_sizes,
|
|
|
|
| 795 |
maskmem_features = maskmem_features.to(torch.bfloat16)
|
| 796 |
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
|
| 797 |
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
| 798 |
+
maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, {"maskmem_pos_enc": maskmem_pos_enc})
|
|
|
|
|
|
|
| 799 |
return maskmem_features, maskmem_pos_enc
|
| 800 |
|
| 801 |
def _get_maskmem_pos_enc(self, inference_state, current_out):
|
|
|
|
| 816 |
maskmem_pos_enc = model_constants["maskmem_pos_enc"]
|
| 817 |
# expand the cached maskmem_pos_enc to the actual batch size
|
| 818 |
batch_size = out_maskmem_pos_enc[0].size(0)
|
| 819 |
+
expanded_maskmem_pos_enc = [x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc]
|
|
|
|
|
|
|
| 820 |
else:
|
| 821 |
expanded_maskmem_pos_enc = None
|
| 822 |
return expanded_maskmem_pos_enc
|
|
|
|
| 834 |
if not strict:
|
| 835 |
return inference_state["obj_ids"], updated_frames
|
| 836 |
raise RuntimeError(
|
| 837 |
+
f"Cannot remove object id {obj_id} as it doesn't exist. " f"All existing object ids: {inference_state['obj_ids']}."
|
|
|
|
| 838 |
)
|
| 839 |
|
| 840 |
# If this is the only remaining object id, we simply reset the state.
|
|
|
|
| 848 |
# (note that this step is required as it might downgrade conditioning frames to
|
| 849 |
# non-conditioning ones)
|
| 850 |
obj_input_frames_inds = set()
|
| 851 |
+
obj_input_frames_inds.update(inference_state["point_inputs_per_obj"][old_obj_idx_to_rm])
|
| 852 |
+
obj_input_frames_inds.update(inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 853 |
for frame_idx in obj_input_frames_inds:
|
| 854 |
+
self.clear_all_prompts_in_frame(inference_state, frame_idx, obj_id, need_output=False)
|
|
|
|
|
|
|
| 855 |
|
| 856 |
# Step 1: Update the object id mapping (note that it must be done after Step 0,
|
| 857 |
# since Step 0 still requires the old object id mappings in inference_state)
|
|
|
|
| 868 |
inference_state["obj_ids"] = new_obj_ids
|
| 869 |
|
| 870 |
# Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 871 |
def _map_keys(container):
|
| 872 |
new_kvs = []
|
| 873 |
for k in old_obj_inds:
|
|
|
|
| 880 |
_map_keys(inference_state["mask_inputs_per_obj"])
|
| 881 |
_map_keys(inference_state["output_dict_per_obj"])
|
| 882 |
_map_keys(inference_state["temp_output_dict_per_obj"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 883 |
_map_keys(inference_state["frames_tracked_per_obj"])
|
| 884 |
|
| 885 |
# Step 3: Further collect the outputs on those frames in `obj_input_frames_inds`, which
|
|
|
|
| 886 |
# could show an updated mask for objects previously occluded by the object being removed
|
| 887 |
if need_output:
|
| 888 |
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
|
| 889 |
for frame_idx in obj_input_frames_inds:
|
| 890 |
is_cond = any(
|
| 891 |
+
frame_idx in obj_temp_output_dict["cond_frame_outputs"] for obj_temp_output_dict in temp_output_dict_per_obj.values()
|
|
|
|
| 892 |
)
|
| 893 |
consolidated_out = self._consolidate_temp_output_across_obj(
|
| 894 |
inference_state,
|
| 895 |
frame_idx,
|
| 896 |
is_cond=is_cond,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 897 |
consolidate_at_video_res=True,
|
| 898 |
)
|
| 899 |
+
_, video_res_masks = self._get_orig_video_res_output(inference_state, consolidated_out["pred_masks_video_res"])
|
|
|
|
|
|
|
| 900 |
updated_frames.append((frame_idx, video_res_masks))
|
| 901 |
|
| 902 |
return inference_state["obj_ids"], updated_frames
|
|
|
|
| 967 |
if self.use_high_res_features_in_sam:
|
| 968 |
# precompute projected level 0 and level 1 features in SAM decoder
|
| 969 |
# to avoid running it again on every SAM click
|
| 970 |
+
backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
|
| 971 |
+
backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 972 |
# Clone to help torch.compile
|
| 973 |
for i in range(len(backbone_out["backbone_fpn"])):
|
| 974 |
backbone_out["backbone_fpn"][i] = backbone_out["backbone_fpn"][i].clone()
|
| 975 |
+
backbone_out["vision_pos_enc"][i] = backbone_out["vision_pos_enc"][i].clone()
|
|
|
|
|
|
|
| 976 |
return backbone_out
|
| 977 |
|
| 978 |
def _forward_sam_heads(
|
|
|
|
| 1131 |
# optionally, apply non-overlapping constraints to the masks (it's applied
|
| 1132 |
# in the batch dimension and should only be used during eval, where all
|
| 1133 |
# the objects come from the same video under batch size 1).
|
| 1134 |
+
pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)
|
|
|
|
|
|
|
| 1135 |
# scale the raw mask logits with a temperature before applying sigmoid
|
| 1136 |
binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
|
| 1137 |
if binarize and not self.training:
|
|
|
|
| 1144 |
mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
|
| 1145 |
if self.sigmoid_bias_for_mem_enc != 0.0:
|
| 1146 |
mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
|
| 1147 |
+
maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True) # sigmoid already applied
|
|
|
|
|
|
|
| 1148 |
# Clone the feats and pos_enc to enable compilation
|
| 1149 |
maskmem_features = maskmem_out["vision_features"].clone()
|
| 1150 |
maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]]
|
|
|
|
| 1152 |
# is predicted to be occluded (i.e. no object is appearing in the frame)
|
| 1153 |
if self.no_obj_embed_spatial is not None:
|
| 1154 |
is_obj_appearing = (object_score_logits > 0).float()
|
| 1155 |
+
maskmem_features += (1 - is_obj_appearing[..., None, None]) * self.no_obj_embed_spatial[..., None, None].expand(
|
|
|
|
|
|
|
| 1156 |
*maskmem_features.shape
|
| 1157 |
)
|
| 1158 |
|