Miroslav Purkrabek commited on
Commit
322535b
·
1 Parent(s): e5057aa

add BMPv2 code

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. BMPv2_README.md +311 -0
  2. SAM3D_INTEGRATION.md +302 -0
  3. app.py +188 -172
  4. bboxmaskpose/__init__.py +10 -0
  5. bboxmaskpose/api.py +515 -0
  6. {configs → bboxmaskpose/configs}/README.md +0 -0
  7. {configs → bboxmaskpose/configs}/bmp_D3.yaml +9 -2
  8. {configs → bboxmaskpose/configs}/bmp_J1.yaml +5 -0
  9. bboxmaskpose/configs/bmp_v2.yaml +34 -0
  10. {demo → bboxmaskpose}/demo_utils.py +30 -110
  11. {demo → bboxmaskpose}/posevis_lite.py +12 -12
  12. {sam2 → bboxmaskpose/sam2}/__init__.py +1 -1
  13. {sam2 → bboxmaskpose/sam2}/automatic_mask_generator.py +20 -50
  14. {sam2 → bboxmaskpose/sam2}/benchmark.py +3 -9
  15. {sam2 → bboxmaskpose/sam2}/build_sam.py +34 -9
  16. {sam2 → bboxmaskpose/sam2}/colorblind.py +8 -16
  17. bboxmaskpose/sam2/configs/sam-pose2seg/sam-pose2seg_hiera_b+.yaml +118 -0
  18. {sam2 → bboxmaskpose/sam2}/configs/sam2.1/sam2.1_hiera_b+.yaml +14 -14
  19. {sam2 → bboxmaskpose/sam2}/configs/sam2.1/sam2.1_hiera_l.yaml +14 -14
  20. {sam2 → bboxmaskpose/sam2}/configs/sam2.1/sam2.1_hiera_s.yaml +14 -14
  21. {sam2 → bboxmaskpose/sam2}/configs/sam2.1/sam2.1_hiera_t.yaml +14 -14
  22. bboxmaskpose/sam2/configs/sam2.1_training/sam2.1_hiera_b+_COCO+CIHP_finetune_sam-pose2seg.yaml +343 -0
  23. {sam2 → bboxmaskpose/sam2}/configs/sam2.1_training/sam2.1_hiera_b+_COCO_1024_prompt.yaml +15 -23
  24. {sam2 → bboxmaskpose/sam2}/configs/sam2.1_training/sam2.1_hiera_b+_COCO_finetune.yaml +15 -24
  25. {sam2 → bboxmaskpose/sam2}/configs/sam2.1_training/sam2.1_hiera_b+_COCO_finetune_prompt+decoder.yaml +15 -24
  26. {sam2 → bboxmaskpose/sam2}/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +15 -21
  27. {sam2 → bboxmaskpose/sam2}/configs/sam2/sam2_hiera_b+.yaml +14 -14
  28. {sam2 → bboxmaskpose/sam2}/configs/sam2/sam2_hiera_l.yaml +14 -14
  29. {sam2 → bboxmaskpose/sam2}/configs/sam2/sam2_hiera_s.yaml +14 -14
  30. {sam2 → bboxmaskpose/sam2}/configs/sam2/sam2_hiera_t.yaml +14 -14
  31. {sam2 → bboxmaskpose/sam2}/csrc/connected_components.cu +0 -0
  32. {sam2 → bboxmaskpose/sam2}/distinctipy.py +7 -14
  33. {sam2 → bboxmaskpose/sam2}/modeling/__init__.py +0 -0
  34. {sam2 → bboxmaskpose/sam2}/modeling/backbones/__init__.py +0 -0
  35. {sam2 → bboxmaskpose/sam2}/modeling/backbones/hieradet.py +10 -31
  36. {sam2 → bboxmaskpose/sam2}/modeling/backbones/image_encoder.py +1 -3
  37. {sam2 → bboxmaskpose/sam2}/modeling/backbones/utils.py +2 -6
  38. {sam2 → bboxmaskpose/sam2}/modeling/memory_attention.py +4 -7
  39. {sam2 → bboxmaskpose/sam2}/modeling/memory_encoder.py +3 -9
  40. {sam2 → bboxmaskpose/sam2}/modeling/position_encoding.py +8 -31
  41. {sam2 → bboxmaskpose/sam2}/modeling/sam/__init__.py +0 -0
  42. {sam2 → bboxmaskpose/sam2}/modeling/sam/mask_decoder.py +11 -32
  43. {sam2 → bboxmaskpose/sam2}/modeling/sam/pose_encoder.py +7 -19
  44. {sam2 → bboxmaskpose/sam2}/modeling/sam/prompt_encoder.py +21 -26
  45. {sam2 → bboxmaskpose/sam2}/modeling/sam/transformer.py +12 -30
  46. {sam2 → bboxmaskpose/sam2}/modeling/sam2_base.py +72 -246
  47. {sam2 → bboxmaskpose/sam2}/modeling/sam2_base_pose.py +45 -87
  48. {sam2 → bboxmaskpose/sam2}/modeling/sam2_utils.py +5 -13
  49. {sam2 → bboxmaskpose/sam2}/sam2_image_predictor.py +32 -81
  50. {sam2 → bboxmaskpose/sam2}/sam2_video_predictor.py +35 -298
BMPv2_README.md ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ </h1><div id="toc">
2
+ <ul align="center" style="list-style: none; padding: 0; margin: 0;">
3
+ <summary>
4
+ <h1 style="margin-bottom: 0.0em;">
5
+ BBoxMaskPose v2
6
+ </h1>
7
+ </summary>
8
+ </ul>
9
+ </div>
10
+ </h1><div id="toc">
11
+ <ul align="center" style="list-style: none; padding: 0; margin: 0;">
12
+ <summary>
13
+ <h2 style="margin-bottom: 0.2em;">
14
+ CVPR 2025 + ICCV 2025
15
+ </h2>
16
+ </summary>
17
+ </ul>
18
+ </div>
19
+
20
+ <div align="center">
21
+ <img src="data/assets/BMP_043+076+174.gif" alt="BBoxMaskPose v2 loop" height="500px">
22
+
23
+ [![Website](https://img.shields.io/badge/Website-BBoxMaskPose-green)](https://mirapurkrabek.github.io/BBox-Mask-Pose/) &nbsp;&nbsp;&nbsp;
24
+ [![License](https://img.shields.io/badge/License-GPL%203.0-orange.svg)](LICENSE) &nbsp;&nbsp;&nbsp;
25
+ [![Video](https://img.shields.io/badge/Video-YouTube-red?logo=youtube)](https://youtu.be/U05yUP4b2LQ)
26
+
27
+ [![Paper](https://img.shields.io/badge/ProbPose-CVPR%202025-blue)](https://arxiv.org/abs/2412.02254) &nbsp;&nbsp;&nbsp;
28
+ [![Paper](https://img.shields.io/badge/BMPv1-ICCV%202025-blue)](https://arxiv.org/abs/2412.01562) &nbsp;&nbsp;&nbsp;
29
+ [![Paper](https://img.shields.io/badge/SAMpose2seg-CVWW%202026-blue)](https://arxiv.org/abs/2601.08982) &nbsp;&nbsp;&nbsp;
30
+ [![Paper](https://img.shields.io/badge/BMPv2-arXiv-blue)](https://arxiv.org/abs/2601.15200) &nbsp;&nbsp;&nbsp;
31
+
32
+
33
+
34
+ <!-- Papers with code:
35
+ [![2D Pose AP on OCHuman: 42.5](https://img.shields.io/badge/OCHuman-2D_Pose:_49.2_AP-blue)](https://paperswithcode.com/sota/2d-human-pose-estimation-on-ochuman?p=detection-pose-estimation-and-segmentation-1) &nbsp;&nbsp;
36
+ [![Human Instance Segmentation AP on OCHuman: 34.0](https://img.shields.io/badge/OCHuman-Human_Instance_Segmentation:_34.0_AP-blue)](https://paperswithcode.com/sota/human-instance-segmentation-on-ochuman?p=detection-pose-estimation-and-segmentation-1) -->
37
+
38
+ </div>
39
+
40
+ > [!CAUTION]
41
+ > This branch is a **work in progress**!
42
+ >
43
+ > Until merged with <code>main</code>, use at your own discretion. For the stable version, please refer to the <code>main</code> branch with BMPv1.
44
+
45
+ ## 📢 News
46
+
47
+ - **Feb 2026**: Version 2.0 with improved (1) pose and (2) SAM and (3) wiring to 3D prediction released.
48
+ - **Feb 2026**: SAM-pose2seg won a Best Paper Award at CVWW 2026 🎉
49
+ - **Jan 2026**: [BMPv2 paper](https://arxiv.org/abs/2601.15200) is available on arXiv
50
+ - **Aug 2025**: [HuggingFace Image Demo](https://huggingface.co/spaces/purkrmir/BBoxMaskPose-demo) is out! 🎮
51
+ - **Jul 2025**: Version 1.1 with easy-to-run image demo released
52
+ - **Jun 2025**: BMPv1 paper accepted to ICCV 2025! 🎉
53
+ - **Dec 2024**: BMPv1 code is available
54
+ - **Nov 2024**: The [project website](https://MiraPurkrabek.github.io/BBox-Mask-Pose) is online
55
+
56
+ ## 📑 Table of Contents
57
+
58
+ - [Installation](#-installation)
59
+ - [Demo](#-demo)
60
+ - [API Examples](#api-examples)
61
+ - [Pre-trained Models](#-pre-trained-models)
62
+ - [Acknowledgments](#-acknowledgments)
63
+ - [Citation](#-citation)
64
+
65
+
66
+ ## 📋 Project Overview
67
+
68
+ Bounding boxes, masks, and poses capture complementary aspects of the human body. BBoxMaskPose links detection, segmentation, and pose estimation iteratively, where each prediction refines the others. PMPose combines probabilistic modeling with mask conditioning for robust pose estimation in crowds. Together, these components achieve state-of-the-art results on COCO and OCHuman, being the first method to exceed 50 AP on OCHuman.
69
+
70
+
71
+ ### Repository Structure
72
+
73
+ The repository is organized into two main packages with stable public APIs:
74
+
75
+ ```
76
+ BBoxMaskPose/
77
+ ├── pmpose/ # PMPose package (pose estimation)
78
+ │ └── pmpose/
79
+ │ ├── api.py # PUBLIC API: PMPose class
80
+ │ ├── mm_utils.py # Internal utilities
81
+ │ └── posevis_lite.py # Visualization
82
+ ├── mmpose/ # MMPose fork with our edits
83
+ ├── bboxmaskpose/ # BBoxMaskPose package (full pipeline)
84
+ │ └── bboxmaskpose/
85
+ │ ├── api.py # PUBLIC API: BBoxMaskPose class
86
+ │ ├── sam2/ # SAM2 implementation
87
+ │ ├── configs/ # BMP configurations
88
+ │ └── *_utils.py # Internal utilities
89
+ ├── demos/ # Public API demos
90
+ │ ├── PMPose_demo.py # PMPose usage example
91
+ │ ├── BMP_demo.py # BBoxMaskPose usage example
92
+ │ └── quickstart.ipynb # Interactive notebook
93
+ └── demo/ # Legacy demo (still functional)
94
+ ```
95
+
96
+ Key contributions:
97
+ 1. **MaskPose**: a pose estimation model conditioned by segmentation masks instead of bounding boxes, boosting performance in dense scenes without adding parameters
98
+ - Download pre-trained weights below
99
+ 2. **BBox-MaskPose (BMP)**: method linking bounding boxes, segmentation masks, and poses to simultaneously address multi-body detection, segmentation and pose estimation
100
+ - Try the demo!
101
+ 3. Fine-tuned RTMDet adapted for iterative detection (ignoring 'holes')
102
+ - Download pre-trained weights below
103
+ 4. Support for multi-dataset training of ViTPose, previously implemented in the official ViTPose repository but absent in MMPose.
104
+
105
+ For more details, please visit our [project website](https://mirapurkrabek.github.io/BBox-Mask-Pose/).
106
+
107
+
108
+
109
+ ## 🚀 Installation
110
+
111
+ ### Docker Installation (Recommended)
112
+
113
+ The fastest way to get started with GPU support:
114
+
115
+ ```bash
116
+ # Clone and build
117
+ git clone https://github.com/mirapurkrabek/BBoxMaskPose.git
118
+ cd BBoxMaskPose
119
+ docker-compose build
120
+
121
+ # Run the demo
122
+ docker-compose up
123
+ ```
124
+
125
+ Requires: Docker Engine 19.03+, [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html), NVIDIA GPU with CUDA 12.1 support.
126
+
127
+ ### Manual Installation
128
+
129
+ This project is built on top of [MMPose](https://github.com/open-mmlab/mmpose) and [SAM 2.1](https://github.com/facebookresearch/sam2).
130
+ Please refer to the [MMPose installation guide](https://mmpose.readthedocs.io/en/latest/installation.html) or [SAM installation guide](https://github.com/facebookresearch/sam2/blob/main/INSTALL.md) for detailed setup instructions.
131
+
132
+ Basic installation steps:
133
+ ```bash
134
+ # Clone the repository
135
+ git clone https://github.com/mirapurkrabek/BBoxMaskPose.git BBoxMaskPose/
136
+ cd BBoxMaskPose
137
+
138
+ # Install your version of torch, torchvision, OpenCV and NumPy
139
+ pip install torch==2.1.2+cu121 torchvision==0.16.2+cu121 --extra-index-url https://download.pytorch.org/whl/cu121
140
+ pip install numpy==1.25.1 opencv-python==4.9.0.80
141
+
142
+ # Install MMLibrary
143
+ pip install -U openmim
144
+ mim install mmengine "mmcv==2.1.0" "mmdet==3.3.0" "mmpretrain==1.2.0"
145
+
146
+ # Install dependencies
147
+ pip install -r requirements.txt
148
+ pip install -e .
149
+ ```
150
+
151
+ ## 🎮 Demo
152
+
153
+ #### PMPose Demo (Pose Estimation Only)
154
+ ```bash
155
+ python demos/PMPose_demo.py --image data/004806.jpg --device cuda
156
+ ```
157
+
158
+ #### BBoxMaskPose Demo (Full Pipeline)
159
+ ```bash
160
+ python demos/BMP_demo.py --image data/004806.jpg --device cuda
161
+ ```
162
+
163
+ After running the demo, outputs are in `outputs/004806/`. The expected output should look like this:
164
+ <div align="center">
165
+ <a href="data/assets/004806_mask.jpg" target="_blank">
166
+ <img src="data/assets/004806_mask.jpg" alt="Detection results" width="200" />
167
+ </a>
168
+ &nbsp;&nbsp;&nbsp;&nbsp;
169
+ <a href="data/assets/004806_pose.jpg" target="_blank">
170
+ <img src="data/assets/004806_pose.jpg" alt="Pose results" width="200" style="margin-right:10px;" />
171
+ </a>
172
+ </div>
173
+
174
+ #### BBoxMaskPose v2 Demo (Full Pipeline + 3D Mesh Recovery)
175
+ This demo extends BMP with [SAM-3D-Body](https://github.com/facebookresearch/sam-3d-body) for 3D human mesh recovery:
176
+ ```bash
177
+ # Basic usage (auto-downloads checkpoint from HuggingFace)
178
+ python demos/BMPv2_demo.py --image data/004806.jpg --device cuda
179
+
180
+ # With local checkpoint
181
+ python demos/BMPv2_demo.py --image data/004806.jpg --device cuda \
182
+ --sam3d_checkpoint checkpoints/sam-3d-body-dinov3/model.ckpt \
183
+ --mhr_path checkpoints/sam-3d-body-dinov3/assets/mhr_model.pt
184
+ ```
185
+
186
+ **SAM-3D-Body Installation (Optional):**
187
+ BMPv2 requires SAM-3D-Body for 3D mesh recovery. Install it separately:
188
+ ```bash
189
+ # 1. Install dependencies
190
+ pip install -r requirements/sam3d.txt
191
+
192
+ # 2. Install detectron2
193
+ pip install 'git+https://github.com/facebookresearch/detectron2.git@a1ce2f9' --no-build-isolation --no-deps
194
+
195
+ # 3. Install MoGe (optional, for FOV estimation)
196
+ pip install git+https://github.com/microsoft/MoGe.git
197
+
198
+ # 4. Install adapted SAM-3D-Body repository
199
+ pip install git+https://github.com/MiraPurkrabek/sam-3d-body.git
200
+
201
+ # 5. Request access to checkpoints at https://huggingface.co/facebook/sam-3d-body-dinov3
202
+ ```
203
+
204
+ For more details, see [SAM-3D-Body installation guide](https://github.com/facebookresearch/sam-3d-body/blob/main/INSTALL.md).
205
+
206
+ #### Jupyter Notebook
207
+ Interactive demo with both PMPose and BBoxMaskPose:
208
+ ```bash
209
+ jupyter notebook demos/quickstart.ipynb
210
+ ```
211
+
212
+ ## API Examples
213
+
214
+ **PMPose API** - Pose estimation with bounding boxes:
215
+ ```python
216
+ from pmpose import PMPose
217
+
218
+ # Initialize model
219
+ pose_model = PMPose(device="cuda", from_pretrained=True)
220
+
221
+ # Run inference
222
+ keypoints, presence, visibility, heatmaps = pose_model.predict(
223
+ image="demo/data/004806.jpg",
224
+ bboxes=[[100, 100, 300, 400]], # [x1, y1, x2, y2]
225
+ )
226
+
227
+ # Visualize
228
+ vis_img = pose_model.visualize(image="demo/data/004806.jpg", keypoints=keypoints)
229
+ ```
230
+
231
+ **BBoxMaskPose API** - Full detection + pose + segmentation:
232
+
233
+ ```python
234
+ from pmpose import PMPose
235
+ from bboxmaskpose import BBoxMaskPose
236
+
237
+ # Create pose model
238
+ pose_model = PMPose(device="cuda", from_pretrained=True)
239
+
240
+ # Inject into BMP
241
+ bmp_model = BBoxMaskPose(config="BMP_D3", device="cuda", pose_model=pose_model)
242
+ result = bmp_model.predict(image="demo/data/004806.jpg")
243
+
244
+ # Visualize
245
+ vis_img = bmp_model.visualize(image="demo/data/004806.jpg", result=result)
246
+ ```
247
+
248
+
249
+ ## 📦 Pre-trained Models
250
+
251
+ Pre-trained models are available on [VRG Hugging Face 🤗](https://huggingface.co/vrg-prague/BBoxMaskPose/).
252
+ To run the demo, you only need to download SAM weights with the [enclosed script](models/SAM/download_ckpts.sh).
253
+ Our detector and pose estimator will be downloaded during the runtime.
254
+
255
+ If you want to download our weights yourself, here are the links to our HuggingFace:
256
+ - ViTPose-b trained on COCO+MPII+AIC -- [download weights](https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/ViTPose-b-multi_mmpose20.pth)
257
+ - MaskPose-b -- [download weights](https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/MaskPose-b.pth)
258
+ - Fine-tuned RTMDet-L -- [download weights](https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/rtmdet-ins-l-mask.pth)
259
+
260
+ ## 🙏 Acknowledgments
261
+
262
+ The code combines [MMDetection](https://github.com/open-mmlab/mmdetection), [MMPose 2.0](https://github.com/open-mmlab/mmpose), [ViTPose](https://github.com/ViTAE-Transformer/ViTPose), [SAM 2.1](https://github.com/facebookresearch/sam2) and [SAM-3D-Body](https://github.com/facebookresearch/sam-3d-body).
263
+
264
+ Our visualizations integrate [Distinctipy](https://github.com/alan-turing-institute/distinctipy) for automatic color selection.
265
+
266
+ This repository combines our work on BBoxMaskPose project with our previous work on [probabilistic 2D human pose estimation modelling](https://mirapurkrabek.github.io/ProbPose/).
267
+
268
+ ## 📝 Citation
269
+
270
+ The code was implemented by [Miroslav Purkrábek](https://mirapurkrabek.github.io/) and Constantin Kolomiiets.
271
+ If you use this work, kindly cite it using the references provided below.
272
+
273
+ For questions, please use the Issues or Discussions.
274
+
275
+ ```
276
+ @InProceedings{Purkrabek2025BMPv1,
277
+ author = {Purkrabek, Miroslav and Matas, Jiri},
278
+ title = {Detection, Pose Estimation and Segmentation for Multiple Bodies: Closing the Virtuous Circle},
279
+ booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
280
+ month = {October},
281
+ year = {2025},
282
+ pages = {9004-9013}
283
+ }
284
+ ```
285
+
286
+ ```
287
+ @InProceedings{Purkrabek2026BMPv2,
288
+ author = {Purkrabek, Miroslav and Kolomiiets, Constantin and Matas, Jiri},
289
+ title = {BBoxMaskPose v2: Expanding Mutual Conditioning to 3D},
290
+ booktitle = {arXiv preprint arXiv:2601.15200},
291
+ year = {2026}
292
+ }
293
+ ```
294
+
295
+ ```
296
+ @article{yang2025sam3dbody,
297
+ title={SAM 3D Body: Robust Full-Body Human Mesh Recovery},
298
+ author={Yang, Xitong and Kukreja, Devansh and Pinkus, Don and Sagar, Anushka and Fan, Taosha and Park, Jinhyung and Shin, Soyong and Cao, Jinkun and Liu, Jiawei and Ugrinovic, Nicolas and Feiszli, Matt and Malik, Jitendra and Dollar, Piotr and Kitani, Kris},
299
+ journal={arXiv preprint; identifier to be added},
300
+ year={2025}
301
+ }
302
+ ```
303
+
304
+ ```
305
+ @InProceedings{Kolomiiets2026CVWW,
306
+ author = {Kolomiiets, Constantin and Purkrabek, Miroslav and Matas, Jiri},
307
+ title = {SAM-pose2seg: Pose-Guided Human Instance Segmentation in Crowds},
308
+ booktitle = {Computer Vision Winter Workshop (CVWW)},
309
+ year = {2026}
310
+ }
311
+ ```
SAM3D_INTEGRATION.md ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SAM-3D-Body Integration Guide
2
+
3
+ This guide explains how to integrate and use SAM-3D-Body for 3D human mesh recovery within the BBoxMaskPose pipeline.
4
+
5
+ ## Overview
6
+
7
+ BBoxMaskPose v2 extends the original BMP pipeline with [SAM-3D-Body](https://github.com/facebookresearch/sam-3d-body) from Meta AI, enabling full 3D human mesh recovery from single images. The integration leverages BMP's high-quality 2D pose estimates and segmentation masks as prompts to SAM-3D-Body, resulting in accurate 3D reconstructions even in crowded scenes.
8
+
9
+ **Pipeline Flow:**
10
+ ```
11
+ Input Image
12
+
13
+ BBoxMaskPose (Detection + 2D Pose + Segmentation)
14
+
15
+ 2D Bboxes + Masks + Poses
16
+
17
+ SAM-3D-Body (3D Mesh Recovery)
18
+
19
+ 3D Human Meshes (vertices, joints, faces)
20
+ ```
21
+
22
+ ## Installation
23
+
24
+ ### Prerequisites
25
+
26
+ - BBoxMaskPose must be already installed and working
27
+ - CUDA-capable GPU recommended (CPU inference is very slow)
28
+ - Python 3.8+ (Python 3.11 recommended for SAM-3D-Body)
29
+
30
+ ### Step 1: Install SAM-3D-Body Dependencies
31
+
32
+ ```bash
33
+ # Navigate to BBoxMaskPose root directory
34
+ cd /path/to/BBoxMaskPose
35
+
36
+ # Install SAM-3D-Body dependencies
37
+ pip install -r requirements/sam3d.txt
38
+ ```
39
+
40
+ ### Step 2: Install Detectron2
41
+
42
+ SAM-3D-Body requires a specific version of Detectron2:
43
+
44
+ ```bash
45
+ pip install 'git+https://github.com/facebookresearch/detectron2.git@a1ce2f9' \
46
+ --no-build-isolation --no-deps
47
+ ```
48
+
49
+ ### Step 3: Install MoGe (Optional but Recommended)
50
+
51
+ MoGe provides FOV (field-of-view) estimation for better camera calibration:
52
+
53
+ ```bash
54
+ pip install git+https://github.com/microsoft/MoGe.git
55
+ ```
56
+
57
+ ### Step 4: Install SAM-3D-Body
58
+
59
+ ```bash
60
+ # Install adapted SAM-3D-Body repository
61
+ pip install git+https://github.com/MiraPurkrabek/sam-3d-body.git
62
+ ```
63
+
64
+ ### Step 5: Get Model Checkpoints
65
+
66
+ SAM-3D-Body checkpoints are hosted on HuggingFace. You need to:
67
+
68
+ 1. **Request access** at [facebook/sam-3d-body-dinov3](https://huggingface.co/facebook/sam-3d-body-dinov3)
69
+ 2. **Wait for approval** (usually within 24 hours)
70
+ 3. **Authenticate** with HuggingFace:
71
+ ```bash
72
+ pip install huggingface_hub
73
+ huggingface-cli login
74
+ ```
75
+
76
+ The BMPv2 demo will auto-download the checkpoint on first use, or you can download manually to the default location for auto-detection:
77
+
78
+ ```bash
79
+ # Download checkpoint manually to default location (will be auto-detected)
80
+ mkdir -p checkpoints
81
+ huggingface-cli download facebook/sam-3d-body-dinov3 \
82
+ --local-dir checkpoints/sam-3d-body-dinov3
83
+ ```
84
+
85
+ ## Usage
86
+
87
+ ### Basic Usage
88
+
89
+ Run the BMPv2 demo with automatic checkpoint handling:
90
+
91
+ ```bash
92
+ python demos/BMPv2_demo.py --image data/004806.jpg --device cuda
93
+ ```
94
+
95
+ **The demo will:**
96
+ 1. Auto-detect checkpoint in `checkpoints/sam-3d-body-dinov3/` OR download from HuggingFace (~3.5 GB)
97
+ 2. Run BMP pipeline to get 2D detections, poses, and masks
98
+ 3. Run SAM-3D-Body to recover 3D meshes
99
+ 4. Save visualizations to `demos/outputs/bboxmaskpose_v2/`
100
+
101
+ ### Advanced Usage
102
+
103
+ #### Use Local Checkpoint (Auto-Detection)
104
+
105
+ Download checkpoint to the default location for automatic detection:
106
+
107
+ ```bash
108
+ # The demo automatically detects checkpoints in this location
109
+ huggingface-cli download facebook/sam-3d-body-dinov3 \
110
+ --local-dir checkpoints/sam-3d-body-dinov3
111
+
112
+ # Then just run the demo - no checkpoint arguments needed!
113
+ python demos/BMPv2_demo.py --image data/004806.jpg --device cuda
114
+ ```
115
+
116
+ #### Use Custom Checkpoint Path
117
+
118
+ If your checkpoint is in a different location:
119
+
120
+ ```bash
121
+ python demos/BMPv2_demo.py \
122
+ --image data/004806.jpg \
123
+ --device cuda \
124
+ --sam3d_checkpoint /path/to/model.ckpt \
125
+ --mhr_path /path/to/mhr_model.pt
126
+ ```
127
+
128
+ #### Speed vs Quality Trade-offs
129
+
130
+ ```bash
131
+ # Fastest: body-only inference without mask conditioning
132
+ python demos/BMPv2_demo.py --image data/004806.jpg \
133
+ --inference_type body --no_mask_conditioning
134
+
135
+ # Balanced: body-only with mask conditioning
136
+ python demos/BMPv2_demo.py --image data/004806.jpg \
137
+ --inference_type body
138
+
139
+ # Best quality: full inference with mask conditioning (default)
140
+ python demos/BMPv2_demo.py --image data/004806.jpg
141
+ ```
142
+
143
+ #### Disable Mask Conditioning
144
+
145
+ Faster but less accurate (doesn't use segmentation masks as prompts):
146
+
147
+ ```bash
148
+ python demos/BMPv2_demo.py \
149
+ --image data/004806.jpg \
150
+ --no_mask_conditioning
151
+ ```
152
+
153
+ #### Skip 3D Recovery
154
+
155
+ Run only BMP pipeline (useful for testing BMP without SAM-3D-Body):
156
+
157
+ ```bash
158
+ python demos/BMPv2_demo.py \
159
+ --image data/004806.jpg \
160
+ --skip_3d
161
+ ```
162
+
163
+ ### Output Files
164
+
165
+ The demo saves the following visualizations:
166
+
167
+ - `{image_name}_bmp_pose.jpg` - 2D pose estimation results
168
+ - `{image_name}_bmp_mask.jpg` - Segmentation mask results
169
+ - `{image_name}_3d_mesh.jpg` - 3D mesh overlay on image
170
+ - `{image_name}_combined.jpg` - Side-by-side comparison of all results
171
+
172
+ ## Programmatic API
173
+
174
+ You can also use SAM-3D-Body programmatically:
175
+
176
+ ```python
177
+ from bboxmaskpose import BBoxMaskPose
178
+ from bboxmaskpose.sam3d_utils import SAM3DBodyWrapper, visualize_3d_meshes
179
+
180
+ # Step 1: Run BMP pipeline
181
+ bmp = BBoxMaskPose(config="bmp_D3", device="cuda")
182
+ result = bmp.predict(image="path/to/image.jpg")
183
+
184
+ # Step 2: Initialize SAM-3D-Body
185
+ sam3d = SAM3DBodyWrapper(device="cuda")
186
+
187
+ # Step 3: Predict 3D meshes from BMP outputs
188
+ outputs_3d = sam3d.predict(
189
+ image="path/to/image.jpg",
190
+ bboxes=result['bboxes'],
191
+ masks=result['masks'],
192
+ use_mask=True,
193
+ inference_type="full", # Options: "full", "body", "hand"
194
+ )
195
+
196
+ # Step 4: Visualize results
197
+ import cv2
198
+ img = cv2.imread("path/to/image.jpg")
199
+ vis = visualize_3d_meshes(img, outputs_3d, sam3d.faces)
200
+ cv2.imwrite("output_3d.jpg", vis)
201
+ ```
202
+
203
+ ### Access 3D Mesh Data
204
+
205
+ Each element in `outputs_3d` is a dictionary containing:
206
+
207
+ ```python
208
+ output_3d[0].keys()
209
+ # dict_keys(['vertices', 'joints', 'bbox', 'mask', ...])
210
+
211
+ # 3D mesh vertices in camera coordinates (V, 3)
212
+ vertices = outputs_3d[0]['vertices']
213
+
214
+ # 3D joint locations (J, 3)
215
+ joints_3d = outputs_3d[0]['joints']
216
+
217
+ # Mesh faces (shared across all people)
218
+ faces = sam3d.faces # (F, 3)
219
+ ```
220
+
221
+ ## Integration Architecture
222
+
223
+ ### Wrapper Design
224
+
225
+ The integration follows BBoxMaskPose's modular design pattern:
226
+
227
+ ```
228
+ bboxmaskpose/
229
+ ├── sam3d_utils.py # SAM-3D-Body wrapper (new)
230
+ │ ├── SAM3DBodyWrapper # Main wrapper class
231
+ │ ├── visualize_3d_meshes # Visualization helper
232
+ │ └── check_sam3d_available
233
+
234
+ demos/
235
+ ├── BMP_demo.py # Original BMP demo
236
+ └── BMPv2_demo.py # New demo with 3D (new)
237
+ ```
238
+
239
+ ### Why a Wrapper?
240
+
241
+ The `SAM3DBodyWrapper` class:
242
+ - **Simplifies** SAM-3D-Body's complex initialization
243
+ - **Adapts** BMP outputs (bboxes, masks) to SAM-3D-Body inputs
244
+ - **Handles** optional dependencies gracefully (no hard requirement)
245
+ - **Follows** BMP's design patterns (similar to PMPose wrapper)
246
+
247
+ ### Key Design Decisions
248
+
249
+ 1. **Optional Dependency**: SAM-3D-Body is not required for core BMP functionality
250
+ 2. **No Code Duplication**: Reuses SAM-3D-Body's existing code via wrapper
251
+ 3. **Mask Conditioning**: Leverages BMP's high-quality masks as prompts
252
+ 4. **No Internal Detector**: Disables SAM-3D-Body's detector (BMP already detects)
253
+
254
+ ## Troubleshooting
255
+
256
+ ### Import Error: `sam_3d_body` not found
257
+
258
+ **Solution**: Install SAM-3D-Body following Step 4 above.
259
+
260
+ ### HuggingFace Authentication Error
261
+
262
+ **Solution**:
263
+ 1. Request access at https://huggingface.co/facebook/sam-3d-body-dinov3
264
+ 2. Login: `huggingface-cli login`
265
+
266
+ ### MoGe Import Error (FOV Estimator)
267
+
268
+ **Solution**: Either:
269
+ - Install MoGe: `pip install git+https://github.com/microsoft/MoGe.git`
270
+ - Or disable FOV estimation (uses default FOV instead)
271
+
272
+ ### Detectron2 Build Errors
273
+
274
+ **Solution**: Make sure you have:
275
+ - CUDA toolkit installed and matching PyTorch CUDA version
276
+ - GCC/G++ compiler available
277
+ - Use the exact commit hash: `@a1ce2f9`
278
+
279
+ ## References
280
+
281
+ - **SAM-3D-Body**: [GitHub](https://github.com/facebookresearch/sam-3d-body) | [Paper](https://ai.meta.com/research/publications/sam-3d-body-robust-full-body-human-mesh-recovery/)
282
+ - **BBoxMaskPose**: [GitHub](https://github.com/MiraPurkrabek/BBoxMaskPose) | [Paper](https://arxiv.org/abs/2601.15200)
283
+
284
+ ## Citation
285
+
286
+ If you use this integration, please cite both works:
287
+
288
+ ```bibtex
289
+ @article{yang2025sam3dbody,
290
+ title={SAM 3D Body: Robust Full-Body Human Mesh Recovery},
291
+ author={Yang, Xitong and Kukreja, Devansh and Pinkus, Don and Sagar, Anushka and Fan, Taosha and Park, Jinhyung and Shin, Soyong and Cao, Jinkun and Liu, Jiawei and Ugrinovic, Nicolas and Feiszli, Matt and Malik, Jitendra and Dollar, Piotr and Kitani, Kris},
292
+ journal={arXiv preprint; identifier to be added},
293
+ year={2025}
294
+ }
295
+
296
+ @InProceedings{Purkrabek2026BMPv2,
297
+ author = {Purkrabek, Miroslav and Kolomiiets, Constantin and Matas, Jiri},
298
+ title = {BBoxMaskPose v2: Expanding Mutual Conditioning to 3D},
299
+ booktitle = {arXiv preprint arXiv:2601.15200},
300
+ year = {2026}
301
+ }
302
+ ```
app.py CHANGED
@@ -1,188 +1,204 @@
1
- import gradio as gr
2
- import spaces
3
-
4
- from pathlib import Path
5
 
 
 
6
  import numpy as np
7
- import yaml
8
- from demo.demo_utils import DotDict, concat_instances, filter_instances, pose_nms, visualize_demo
9
- from demo.mm_utils import run_MMDetector, run_MMPose
10
- from mmdet.apis import init_detector
11
- from demo.sam2_utils import prepare_model as prepare_sam2_model
12
- from demo.sam2_utils import process_image_with_SAM
13
-
14
- from mmpose.apis import init_model as init_pose_estimator
15
- from mmpose.utils import adapt_mmdet_pipeline
16
-
17
- # Default thresholds
18
- DEFAULT_CAT_ID: int = 0
19
-
20
- DEFAULT_BBOX_THR: float = 0.3
21
- DEFAULT_NMS_THR: float = 0.3
22
- DEFAULT_KPT_THR: float = 0.3
23
-
24
- # Global models variable
25
- det_model = None
26
- pose_model = None
27
- sam2_model = None
28
-
29
- def _parse_yaml_config(yaml_path: Path) -> DotDict:
30
- """
31
- Load BMP configuration from a YAML file.
32
-
33
- Args:
34
- yaml_path (Path): Path to YAML config.
35
- Returns:
36
- DotDict: Nested config dictionary.
37
- """
38
- with open(yaml_path, "r") as f:
39
- cfg = yaml.safe_load(f)
40
- return DotDict(cfg)
41
-
42
- def load_models(bmp_config):
43
- device = 'cuda:0'
44
-
45
- global det_model, pose_model, sam2_model
46
-
47
- # build detectors
48
- det_model = init_detector(bmp_config.detector.det_config, bmp_config.detector.det_checkpoint, device='cpu') # Detect with CPU because of installation issues on HF
49
- det_model.cfg = adapt_mmdet_pipeline(det_model.cfg)
50
-
51
-
52
- # build pose estimator
53
- pose_model = init_pose_estimator(
54
- bmp_config.pose_estimator.pose_config,
55
- bmp_config.pose_estimator.pose_checkpoint,
56
- device=device,
57
- cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=False))),
58
- )
59
 
60
- sam2_model = prepare_sam2_model(
61
- model_cfg=bmp_config.sam2.sam2_config,
62
- model_checkpoint=bmp_config.sam2.sam2_checkpoint,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  )
64
-
65
- return det_model, pose_model, sam2_model
 
 
 
 
 
 
 
 
 
66
 
67
  @spaces.GPU(duration=60)
68
  def process_image_with_BMP(
69
  img: np.ndarray
70
  ) -> tuple[np.ndarray, np.ndarray]:
71
  """
72
- Run the full BMP pipeline on a single image: detection, pose, SAM mask refinement, and visualization.
73
-
74
- Args:
75
- args (Namespace): Parsed CLI arguments.
76
- bmp_config (DotDict): Configuration parameters.
77
- img_path (Path): Path to the input image.
78
- detector: Primary MMDetection model.
79
- detector_prime: Secondary MMDetection model for iterations.
80
- pose_estimator: MMPose model for keypoint estimation.
81
- sam2_model: SAM model for mask refinement.
82
- Returns:
83
- InstanceData: Final merged detections and refined masks.
84
  """
85
- bmp_config = _parse_yaml_config(Path("configs/bmp_D3.yaml"))
86
- load_models(bmp_config)
87
-
88
- # img: RGB -> BGR
89
- img = img[..., ::-1]
90
-
91
- img_for_detection = img.copy()
92
- rtmdet_result = None
93
- all_detections = None
94
- for iteration in range(bmp_config.num_bmp_iters):
95
-
96
- # Step 1: Detection
97
- det_instances = run_MMDetector(
98
- det_model,
99
- img_for_detection,
100
- det_cat_id=DEFAULT_CAT_ID,
101
- bbox_thr=DEFAULT_BBOX_THR,
102
- nms_thr=DEFAULT_NMS_THR,
103
- )
104
- if len(det_instances.bboxes) == 0:
105
- continue
106
-
107
- # Step 2: Pose estimation
108
- pose_instances = run_MMPose(
109
- pose_model,
110
- img.copy(),
111
- detections=det_instances,
112
- kpt_thr=DEFAULT_KPT_THR,
113
- )
114
-
115
- # Restrict to first 17 COCO keypoints
116
- pose_instances.keypoints = pose_instances.keypoints[:, :17, :]
117
- pose_instances.keypoint_scores = pose_instances.keypoint_scores[:, :17]
118
- pose_instances.keypoints = np.concatenate(
119
- [pose_instances.keypoints, pose_instances.keypoint_scores[:, :, None]], axis=-1
120
- )
121
-
122
- # Step 3: Pose-NMS and SAM refinement
123
- all_keypoints = (
124
- pose_instances.keypoints
125
- if all_detections is None
126
- else np.concatenate([all_detections.keypoints, pose_instances.keypoints], axis=0)
127
- )
128
- all_bboxes = (
129
- pose_instances.bboxes
130
- if all_detections is None
131
- else np.concatenate([all_detections.bboxes, pose_instances.bboxes], axis=0)
132
- )
133
- num_valid_kpts = np.sum(all_keypoints[:, :, 2] > bmp_config.sam2.prompting.confidence_thr, axis=1)
134
- keep_indices = pose_nms(
135
- DotDict({"confidence_thr": bmp_config.sam2.prompting.confidence_thr, "oks_thr": bmp_config.oks_nms_thr}),
136
- image_kpts=all_keypoints,
137
- image_bboxes=all_bboxes,
138
- num_valid_kpts=num_valid_kpts,
139
- )
140
- keep_indices = sorted(keep_indices) # Sort by original index
141
- num_old_detections = 0 if all_detections is None else len(all_detections.bboxes)
142
- keep_new_indices = [i - num_old_detections for i in keep_indices if i >= num_old_detections]
143
- keep_old_indices = [i for i in keep_indices if i < num_old_detections]
144
- if len(keep_new_indices) == 0:
145
- continue
146
- # filter new detections and compute scores
147
- new_dets = filter_instances(pose_instances, keep_new_indices)
148
- new_dets.scores = pose_instances.keypoint_scores[keep_new_indices].mean(axis=-1)
149
- old_dets = None
150
- if len(keep_old_indices) > 0:
151
- old_dets = filter_instances(all_detections, keep_old_indices)
152
-
153
- new_detections = process_image_with_SAM(
154
- DotDict(bmp_config.sam2.prompting),
155
- img.copy(),
156
- sam2_model,
157
- new_dets,
158
- old_dets if old_dets is not None else None,
159
- )
160
-
161
- # Merge detections
162
- if all_detections is None:
163
- all_detections = new_detections
164
- else:
165
- all_detections = concat_instances(all_detections, new_dets)
166
-
167
- # Step 4: Visualization
168
- img_for_detection, rtmdet_r, _ = visualize_demo(
169
- img.copy(),
170
- all_detections,
171
- )
172
-
173
- if iteration == 0:
174
- rtmdet_result = rtmdet_r
175
-
176
- _, _, bmp_result = visualize_demo(
177
- img.copy(),
178
- all_detections,
179
  )
 
 
 
 
 
 
 
 
 
180
 
181
- # img: BGR -> RGB
182
- rtmdet_result = rtmdet_result[..., ::-1]
183
- bmp_result = bmp_result[..., ::-1]
184
 
185
- return rtmdet_result, bmp_result
 
186
 
187
 
188
  with gr.Blocks() as app:
@@ -281,4 +297,4 @@ with gr.Blocks() as app:
281
  )
282
 
283
  # Launch the demo
284
- app.launch()
 
1
+ from typing import Any
 
 
 
2
 
3
+ import cv2
4
+ import gradio as gr
5
  import numpy as np
6
+ import spaces
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ from bboxmaskpose import BBoxMaskPose
9
+
10
# Global BMP model singleton, lazily created and cached by _get_bmp_model().
bmp_model = None
# Config alias and device the cached model was built with; a mismatch with the
# requested values triggers a rebuild on the next _get_bmp_model() call.
bmp_model_config = "bmp_v2"
bmp_model_device = "cuda:0"
14
+
15
+
16
+ def _to_numpy(value: Any):
17
+ """Convert model outputs to numpy arrays when needed."""
18
+ if value is None:
19
+ return None
20
+ if isinstance(value, np.ndarray):
21
+ return value
22
+ if hasattr(value, "detach"):
23
+ return value.detach().cpu().numpy()
24
+ if hasattr(value, "cpu") and hasattr(value, "numpy"):
25
+ return value.cpu().numpy()
26
+ return np.asarray(value)
27
+
28
+
29
+ def _empty_result(height: int, width: int) -> dict[str, np.ndarray]:
30
+ """Create an empty result dictionary compatible with BBoxMaskPose.visualize()."""
31
+ return {
32
+ "bboxes": np.zeros((0, 4), dtype=np.float32),
33
+ "masks": np.zeros((0, height, width), dtype=np.uint8),
34
+ "keypoints": np.zeros((0, 17, 3), dtype=np.float32),
35
+ "presence": np.zeros((0, 17), dtype=np.float32),
36
+ "visibility": np.zeros((0, 17), dtype=np.float32),
37
+ }
38
+
39
+
40
def _normalize_result(result: dict, height: int, width: int) -> dict[str, np.ndarray]:
    """Normalize a prediction dictionary into a robust shape for visualization.

    Guarantees float32 (N, 4) bboxes, uint8 (N, H, W) masks, float32
    (N, 17, 3) keypoints, and float32 (N, 17) presence/visibility arrays,
    repairing missing or oddly-shaped entries along the way.

    Args:
        result: Prediction dict from BBoxMaskPose.predict() (keys may be
            missing, None, or tensors).
        height: Image height used for fallback mask shapes.
        width: Image width used for fallback mask shapes.

    Returns:
        Dict with normalized 'bboxes', 'masks', 'keypoints', 'presence',
        and 'visibility' numpy arrays.
    """
    bboxes = _to_numpy(result.get("bboxes"))
    keypoints = _to_numpy(result.get("keypoints"))
    masks = _to_numpy(result.get("masks"))
    presence = _to_numpy(result.get("presence"))
    visibility = _to_numpy(result.get("visibility"))

    if bboxes is None:
        bboxes = np.zeros((0, 4), dtype=np.float32)
    bboxes = np.asarray(bboxes, dtype=np.float32).reshape(-1, 4)
    num_instances = bboxes.shape[0]

    # Fix: with zero instances, `reshape(0, -1)` on the presence/visibility
    # arrays below raises ValueError (numpy cannot infer -1 from a size-0
    # array), so short-circuit to a fully-empty result.
    if num_instances == 0:
        return {
            "bboxes": bboxes,
            "masks": np.zeros((0, height, width), dtype=np.uint8),
            "keypoints": np.zeros((0, 17, 3), dtype=np.float32),
            "presence": np.zeros((0, 17), dtype=np.float32),
            "visibility": np.zeros((0, 17), dtype=np.float32),
        }

    if keypoints is None:
        keypoints = np.zeros((num_instances, 17, 3), dtype=np.float32)
    keypoints = np.asarray(keypoints, dtype=np.float32)
    if keypoints.ndim == 2:
        # Single instance provided without the leading batch axis.
        keypoints = keypoints[None, ...]
    if keypoints.shape[0] != num_instances:
        # Instance-count mismatch: discard rather than misalign with bboxes.
        keypoints = np.zeros((num_instances, 17, 3), dtype=np.float32)
    if keypoints.shape[1] > 17:
        keypoints = keypoints[:, :17, :]
    if keypoints.shape[1] < 17:
        padded = np.zeros((num_instances, 17, 3), dtype=np.float32)
        padded[:, : keypoints.shape[1], : min(keypoints.shape[2], 3)] = keypoints[:, :, :3]
        keypoints = padded
    if keypoints.shape[2] == 2:
        # (x, y) only: assume full confidence for the missing score channel.
        scores = np.ones((num_instances, 17, 1), dtype=np.float32)
        keypoints = np.concatenate([keypoints, scores], axis=-1)
    elif keypoints.shape[2] > 3:
        keypoints = keypoints[:, :, :3]

    if masks is None:
        masks = np.zeros((num_instances, height, width), dtype=np.uint8)
    masks = np.asarray(masks)
    if masks.ndim == 2:
        masks = masks[None, ...]
    if masks.ndim == 4 and masks.shape[-1] == 1:
        masks = masks.squeeze(-1)
    if masks.shape[0] != num_instances:
        masks = np.zeros((num_instances, height, width), dtype=np.uint8)
    masks = masks.astype(np.uint8)

    if presence is None:
        # Fall back to the keypoint confidence channel.
        presence = keypoints[:, :, 2]
    presence = np.asarray(presence, dtype=np.float32).reshape(num_instances, -1)
    if presence.shape[1] > 17:
        presence = presence[:, :17]
    if presence.shape[1] < 17:
        padded_presence = np.zeros((num_instances, 17), dtype=np.float32)
        padded_presence[:, : presence.shape[1]] = presence
        presence = padded_presence

    if visibility is None:
        # Fall back to the keypoint confidence channel.
        visibility = keypoints[:, :, 2]
    visibility = np.asarray(visibility, dtype=np.float32).reshape(num_instances, -1)
    if visibility.shape[1] > 17:
        visibility = visibility[:, :17]
    if visibility.shape[1] < 17:
        padded_visibility = np.zeros((num_instances, 17), dtype=np.float32)
        padded_visibility[:, : visibility.shape[1]] = visibility
        visibility = padded_visibility

    return {
        "bboxes": bboxes,
        "masks": masks,
        "keypoints": keypoints,
        "presence": presence,
        "visibility": visibility,
    }
110
+
111
+
112
def _extract_baseline_result(intermediates: Any, fallback_result: dict, height: int, width: int) -> dict[str, np.ndarray]:
    """Build the baseline result from the first intermediate pose output.

    Falls back to *fallback_result* whenever the first BMP iteration did not
    record usable pose instances.
    """
    first = intermediates[0] if intermediates else None
    poses = first.get("poses") if first is not None else None
    if poses is None:
        return _normalize_result(fallback_result, height, width)

    raw = {
        "bboxes": _to_numpy(getattr(poses, "bboxes", None)),
        "keypoints": _to_numpy(getattr(poses, "keypoints", None)),
        "masks": _to_numpy(getattr(poses, "masks", None)),
        "presence": _to_numpy(getattr(poses, "keypoint_prob", None)),
        "visibility": _to_numpy(getattr(poses, "keypoint_vis", None)),
    }
    return _normalize_result(raw, height, width)
133
+
134
+
135
def _blend_pose_and_mask(model: BBoxMaskPose, image_bgr: np.ndarray, result: dict[str, np.ndarray]) -> np.ndarray:
    """Render pose and mask overlays for *result* and blend them 50/50."""
    overlays = [
        model.visualize(image=image_bgr, result=result, vis_type=kind)
        for kind in ("pose", "mask")
    ]
    return cv2.addWeighted(overlays[0], 0.5, overlays[1], 0.5, 0.0)
140
+
141
+
142
def _get_bmp_model(config_name: str = "bmp_v2", device: str = "cuda:0") -> BBoxMaskPose:
    """Return the cached BBoxMaskPose instance, rebuilding it on config/device change."""
    global bmp_model, bmp_model_config, bmp_model_device

    cache_is_valid = (
        bmp_model is not None
        and bmp_model_config == config_name
        and bmp_model_device == device
    )
    if not cache_is_valid:
        try:
            bmp_model = BBoxMaskPose(config=config_name, device=device)
        except Exception as exc:
            raise RuntimeError(
                f"Failed to initialize BBoxMaskPose with config='{config_name}', device='{device}'."
            ) from exc
        # Record what the cached model was built with (constructor succeeded).
        bmp_model_config = config_name
        bmp_model_device = device

    return bmp_model
162
 
163
@spaces.GPU(duration=60)
def process_image_with_BMP(
    img: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """
    Run BMP inference using the public BBoxMaskPose API.

    The function keeps the original Gradio interface:
    - output 1: baseline-style result from first intermediate pass
    - output 2: final BMP-refined result

    Args:
        img: RGB image from the Gradio input component.

    Returns:
        Tuple of two RGB images (baseline visualization, BMP visualization).
        When nothing is detected, the unmodified input is returned twice.

    Raises:
        ValueError: If *img* is None (no image was provided).
    """
    if img is None:
        raise ValueError("Input image is None.")

    # Gradio image is RGB; BMP API expects BGR.
    img_bgr = img[..., ::-1].copy()
    height, width = img_bgr.shape[:2]

    model = _get_bmp_model(config_name=bmp_model_config, device=bmp_model_device)
    final_result = model.predict(
        image=img_bgr,
        bboxes=None,
        return_intermediates=True,
    )
    normalized_final = _normalize_result(final_result, height, width)

    # No-detection robustness: return original image for both outputs.
    if normalized_final["bboxes"].shape[0] == 0:
        original_rgb = img_bgr[..., ::-1]
        return original_rgb, original_rgb

    # Baseline output is derived from the first BMP iteration's poses.
    intermediates = final_result.get("intermediates", [])
    baseline_result = _extract_baseline_result(intermediates, normalized_final, height, width)

    baseline_vis = _blend_pose_and_mask(model, img_bgr, baseline_result)
    bmp_vis = _blend_pose_and_mask(model, img_bgr, normalized_final)

    # BGR -> RGB for Gradio
    return baseline_vis[..., ::-1], bmp_vis[..., ::-1]
202
 
203
 
204
  with gr.Blocks() as app:
 
297
  )
298
 
299
  # Launch the demo
300
+ app.launch()
bboxmaskpose/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) authors of BBoxMaskPose (BMPv2). All rights reserved.
2
+ """
3
+ BBoxMaskPose package - Public API for detection, pose estimation, and segmentation.
4
+
5
+ This package provides a stable wrapper for the full BBoxMaskPose pipeline.
6
+ """
7
+
8
+ from .api import BBoxMaskPose
9
+
10
+ __all__ = ["BBoxMaskPose"]
bboxmaskpose/api.py ADDED
@@ -0,0 +1,515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) authors of BBoxMaskPose (BMPv2). All rights reserved.
2
+
3
+ """
4
+ Public API for BBoxMaskPose wrapper.
5
+
6
+ This module provides a stable, user-friendly interface for the full
7
+ BBoxMaskPose pipeline: detection, pose estimation, and mask refinement.
8
+ """
9
+
10
+ import glob
11
+ import os
12
+ from pathlib import Path
13
+ from typing import Dict, List, Optional, Union
14
+
15
+ import cv2
16
+ import mmengine
17
+ import numpy as np
18
+ import yaml
19
+ from mmdet.apis import inference_detector, init_detector
20
+ from mmengine.structures import InstanceData
21
+
22
+ from .demo_utils import DotDict, _visualize_predictions, concat_instances, filter_instances, pose_nms
23
+ from .posevis_lite import pose_visualization
24
+
25
+ # Import from BBoxMaskPose package
26
+ from .sam2_utils import prepare_model as prepare_sam2_model, process_image_with_SAM
27
+
28
+ BMP_ROOT = Path(__file__).parent.parent
29
+
30
+ # Note: PMPose will be imported when needed to avoid circular imports
31
+ # from pmpose import PMPose
32
+
33
# Default detector and pose config
DEFAULT_DET_CAT_ID: int = 0  # detector class id kept by _run_detector (0 — presumably 'person'; verify against detector config)
DEFAULT_BBOX_THR: float = 0.3  # minimum detection score kept by _run_detector
DEFAULT_NMS_THR: float = 0.3  # IoU threshold for box NMS in _run_detector
DEFAULT_KPT_THR: float = 0.3  # keypoint confidence threshold (not referenced in this module; NOTE(review): confirm intended consumer)

# Pretrained config URLs (for future use)
PRETRAINED_CONFIGS = {
    "bmp-d3": "BMP_D3",
    "bmp-j1": "BMP_J1",
}
44
+
45
+
46
+ class BBoxMaskPose:
47
+ """
48
+ Public wrapper API for BBoxMaskPose pipeline.
49
+
50
+ This class provides a complete pipeline for detection, pose estimation,
51
+ and mask refinement using SAM2.
52
+
53
+ Example:
54
+ >>> bmp_model = BBoxMaskPose(config="BMP_D3", device="cuda")
55
+ >>> result = bmp_model.predict(
56
+ ... image="path/to/image.jpg",
57
+ ... return_intermediates=True
58
+ ... )
59
+ """
60
+
61
+ def __init__(
62
+ self,
63
+ config: str = "BMP_D3",
64
+ device: str = "cuda",
65
+ config_path: Optional[str] = None,
66
+ pose_model=None, # Type hint removed to avoid import at module level
67
+ pretrained_id: Optional[str] = None,
68
+ n_kpts_to_work_with: Optional[int] = 17,
69
+ ):
70
+ """
71
+ Initialize BBoxMaskPose model.
72
+
73
+ Args:
74
+ config (str): Config alias ('BMP_D3', 'BMP_J1'). Defaults to 'BMP_D3'.
75
+ device (str): Device for inference. Defaults to 'cuda'.
76
+ config_path (str, optional): Path to custom YAML config file.
77
+ pose_model (PMPose, optional): Pre-initialized PMPose instance.
78
+ If None, will create internal pose model.
79
+ pretrained_id (str, optional): Alias for pretrained config.
80
+ n_kpts_to_work_with (int, optional): Number of keypoints to work with.
81
+ Defaults to 17 (COCO keypoints).
82
+ """
83
+ self.device = device
84
+ self.config_name = config
85
+
86
+ self.n_kpts_to_work_with = 17 # Hard-code 17 as no experiments were done with other values. Keep the argument for future flexibility, but ignore it for now.
87
+ if n_kpts_to_work_with != 17:
88
+ print(
89
+ f"Warning: n_kpts_to_work_with is set to {n_kpts_to_work_with}, but currently only 17 keypoints are supported. Ignoring this argument for now."
90
+ )
91
+
92
+ # Determine config path
93
+ if config_path is not None:
94
+ self.config_path = config_path
95
+ else:
96
+ bmp_configs_root = os.path.join(BMP_ROOT, "bboxmaskpose", "configs")
97
+ config_file = f"{config}.yaml"
98
+ self.config_path = os.path.join(bmp_configs_root, config_file)
99
+
100
+ if not os.path.exists(self.config_path):
101
+ available_configs = glob.glob(os.path.join(bmp_configs_root, "*.yaml"))
102
+ available_configs = [os.path.basename(f).replace(".yaml", "") for f in available_configs]
103
+ raise FileNotFoundError(f"Config file not found: {self.config_path}. " f"Available configs: {', '.join(available_configs)}")
104
+
105
+ # Load config
106
+ self.config = self._load_config(self.config_path)
107
+
108
+ # Initialize or use provided pose model
109
+ if pose_model is not None:
110
+ self.pose_model = pose_model
111
+ self._owns_pose_model = False
112
+ else:
113
+ # Create internal PMPose instance
114
+ self.pose_model = self._create_pose_model()
115
+ self._owns_pose_model = True
116
+
117
+ # Initialize detector and SAM2
118
+ self.detector = None
119
+ self.detector_prime = None
120
+ self.sam2_model = None
121
+ self._initialize_models()
122
+
123
+ def _load_config(self, config_path: str) -> DotDict:
124
+ """Load BMP configuration from YAML file."""
125
+ with open(config_path, "r") as f:
126
+ cfg = yaml.safe_load(f)
127
+ return DotDict(cfg)
128
+
129
+ def _create_pose_model(self):
130
+ """Create internal PMPose model from config."""
131
+ # Import PMPose here to avoid circular imports
132
+ from pmpose import PMPose
133
+
134
+ # Extract pose config from BMP config
135
+ pose_config = self.config.pose_estimator.pose_config
136
+ pose_checkpoint = self.config.pose_estimator.pose_checkpoint
137
+
138
+ # Create PMPose instance with custom config
139
+ full_pose_config = str(BMP_ROOT / pose_config)
140
+
141
+ pose_model = PMPose(
142
+ device=self.device,
143
+ config_path=full_pose_config,
144
+ from_pretrained=True,
145
+ )
146
+
147
+ # Load checkpoint if it's a local path
148
+ if not pose_checkpoint.startswith("http"):
149
+ pose_model.load_from_file(pose_checkpoint)
150
+
151
+ return pose_model
152
+
153
    def _initialize_models(self):
        """Initialize the detector(s) and the SAM2 segmentation model.

        Sets self.detector, self.detector_prime (shared with self.detector
        when no separate D' is configured) and self.sam2_model.
        """
        # Initialize detector
        self.detector = init_detector(self.config.detector.det_config, self.config.detector.det_checkpoint, device=self.device)

        # Adapt detector pipeline for use with mmpose-style inference.
        from mmpose.utils import adapt_mmdet_pipeline

        self.detector.cfg = adapt_mmdet_pipeline(self.detector.cfg)

        # Initialize detector prime (may be same as detector). D' is reused
        # when its config/checkpoint equal D's, or when either is unset.
        if (
            self.config.detector.det_config == self.config.detector.det_prime_config
            and self.config.detector.det_checkpoint == self.config.detector.det_prime_checkpoint
        ) or (self.config.detector.det_prime_config is None or self.config.detector.det_prime_checkpoint is None):
            self.detector_prime = self.detector
        else:
            self.detector_prime = init_detector(
                self.config.detector.det_prime_config, self.config.detector.det_prime_checkpoint, device=self.device
            )
            self.detector_prime.cfg = adapt_mmdet_pipeline(self.detector_prime.cfg)

        # Initialize SAM2 from a config file bundled inside this package.
        sam2_config_path = os.path.join(BMP_ROOT, "bboxmaskpose", "sam2", self.config.sam2.sam2_config)
        self.sam2_model = prepare_sam2_model(
            model_cfg=sam2_config_path,
            model_checkpoint=self.config.sam2.sam2_checkpoint,
        )
181
+
182
    def predict(
        self,
        image: Union[str, np.ndarray],
        bboxes: Optional[np.ndarray] = None,
        return_intermediates: bool = False,
        return_probmaps: bool = False,
    ) -> Dict:
        """
        Run full BBoxMaskPose pipeline on image.

        Args:
            image: Image path (str) or BGR numpy array.
            bboxes: Optional (N, 4) bboxes in [x1, y1, x2, y2] format.
                If None, run detector.
            return_intermediates: If True, return intermediate outputs.
            return_probmaps: If True, request heatmaps from pose model.

        Returns:
            Dict with keys:
                - 'bboxes': (N, 4) final bounding boxes
                - 'masks': (N, H, W) refined binary masks
                - 'keypoints': (N, K, 3) keypoints with scores
                - 'presence': (N, K) presence probabilities
                - 'visibility': (N, K) visibility flags
                - 'intermediates': (optional) per-iteration dicts with
                  'iteration', 'detections', 'poses', 'refined' when
                  return_intermediates is True

        Raises:
            ValueError: If *image* is a path that cannot be read.
        """
        # Load image
        if isinstance(image, str):
            img = cv2.imread(image)
            if img is None:
                raise ValueError(f"Failed to load image from {image}")
        else:
            # Copy so the caller's array is never mutated by the pipeline.
            img = image.copy()

        # Run BMP iterations
        all_detections = None
        intermediate_results = [] if return_intermediates else None

        for iteration in range(self.config.num_bmp_iters):
            # Step 1: Detection
            if iteration == 0 and bboxes is not None:
                # Use provided bboxes for first iteration
                det_instances = InstanceData(bboxes=bboxes, bbox_scores=np.ones(len(bboxes)), masks=None)
            else:
                # Run detector. From the second iteration on, previously
                # segmented instances are blacked out so D' can find new ones.
                det_instances = self._run_detector(
                    self.detector if iteration == 0 else self.detector_prime,
                    img if all_detections is None else self._mask_out_image(img, all_detections),
                )

            if len(det_instances.bboxes) == 0:
                continue

            # Step 2: Pose estimation using PMPose wrapper
            pose_results = self._run_pose_estimation(img, det_instances, return_probmaps=return_probmaps)

            # Step 3: Pose NMS and SAM refinement
            new_detections, old_detections = self._refine_with_sam(
                img,
                pose_results,
                all_detections,
            )

            # Merge detections. _refine_with_sam may return
            # new_detections=None (all new poses suppressed by pose-NMS);
            # concat_instances tolerates None operands.
            if all_detections is None:
                all_detections = new_detections
            else:
                all_detections = concat_instances(old_detections, new_detections)

            # Store intermediates if requested
            if return_intermediates:
                intermediate_results.append(
                    {
                        "iteration": iteration,
                        "detections": det_instances,
                        "poses": pose_results,
                        "refined": new_detections,
                    }
                )

        # Prepare final result
        result = self._format_result(all_detections, img.shape[:2])

        if return_intermediates:
            result["intermediates"] = intermediate_results

        return result
270
+
271
    def _run_detector(
        self,
        detector,
        img: np.ndarray,
    ) -> InstanceData:
        """Run MMDetection detector and return filtered person candidates.

        Keeps predictions of category DEFAULT_DET_CAT_ID with score above
        DEFAULT_BBOX_THR, sorts by score, and applies box NMS with
        DEFAULT_NMS_THR. Instance masks (when the detector provides them)
        are kept index-aligned with the surviving boxes.

        Args:
            detector: An initialized MMDetection model.
            img: BGR image (possibly with previous detections masked out).

        Returns:
            InstanceData with 'bboxes' (N, 4), 'bbox_scores' (N,) and
            'masks' (None when the detector produces no masks).
        """
        from mmpose.evaluation.functional import nms

        # Run detection
        det_result = inference_detector(detector, img)
        pred_instances = det_result.pred_instances.cpu().numpy()

        # Aggregate bboxes and scores into (N, 5) [x1, y1, x2, y2, score]
        bboxes_all = np.concatenate((pred_instances.bboxes, pred_instances.scores[:, None]), axis=1)

        # Filter by category and score
        keep_mask = np.logical_and(pred_instances.labels == DEFAULT_DET_CAT_ID, pred_instances.scores > DEFAULT_BBOX_THR)

        if not np.any(keep_mask):
            return InstanceData(bboxes=np.zeros((0, 4)), bbox_scores=np.zeros((0,)), masks=np.zeros((0, 1, 1)))

        bboxes = bboxes_all[keep_mask]
        masks = getattr(pred_instances, "masks", None)
        if masks is not None:
            masks = masks[keep_mask]

        # Sort by score (descending), keeping masks index-aligned
        order = np.argsort(bboxes[:, 4])[::-1]
        bboxes = bboxes[order]
        if masks is not None:
            masks = masks[order]

        # Apply NMS
        keep_indices = nms(bboxes, DEFAULT_NMS_THR)
        bboxes = bboxes[keep_indices]
        if masks is not None:
            masks = masks[keep_indices]

        return InstanceData(bboxes=bboxes[:, :4], bbox_scores=bboxes[:, 4], masks=masks)
310
+
311
    def _run_pose_estimation(
        self,
        img: np.ndarray,
        det_instances: InstanceData,
        return_probmaps: bool = False,
    ) -> InstanceData:
        """Run top-down pose estimation with the PMPose wrapper.

        Args:
            img: BGR image.
            det_instances: Detections with 'bboxes', 'bbox_scores' and
                optional 'masks' conditioning the pose model.
            return_probmaps: If True, also attach per-keypoint heatmaps.

        Returns:
            InstanceData with 'keypoints' (N, 17, 3), 'keypoint_scores',
            'keypoint_vis', 'keypoint_prob', plus the input boxes/masks.
            'heatmaps' is attached only when requested and available.
        """
        bboxes = det_instances.bboxes
        masks = det_instances.masks

        # No detections: return an empty, correctly-shaped result.
        if len(bboxes) == 0:
            return InstanceData(
                keypoints=np.zeros((0, self.n_kpts_to_work_with, 3)),
                keypoint_scores=np.zeros((0, self.n_kpts_to_work_with)),
                bboxes=bboxes,
                bbox_scores=det_instances.bbox_scores,
                masks=masks,
            )

        # Call PMPose public API
        keypoints, probabilities, visibilities, heatmaps = self.pose_model.predict(
            img,
            bboxes,
            masks=masks,
            return_probmaps=return_probmaps,
        )

        # Restrict to first 17 COCO keypoints
        keypoints = keypoints[:, : self.n_kpts_to_work_with, :]
        probabilities = probabilities[:, : self.n_kpts_to_work_with]
        visibilities = visibilities[:, : self.n_kpts_to_work_with]

        if heatmaps is not None:
            heatmaps = heatmaps[:, : self.n_kpts_to_work_with, :, :]

        # Create InstanceData with results; keypoint_scores mirrors the
        # confidence channel of the keypoints array.
        result = InstanceData(
            keypoints=keypoints,
            keypoint_scores=keypoints[:, :, 2],
            bboxes=bboxes,
            bbox_scores=det_instances.bbox_scores,
            masks=masks,
            keypoint_vis=visibilities,
            keypoint_prob=probabilities,
        )

        if return_probmaps and heatmaps is not None:
            result.heatmaps = heatmaps

        return result
361
+
362
    def _refine_with_sam(
        self,
        img: np.ndarray,
        pose_instances: InstanceData,
        all_detections: Optional[InstanceData],
    ) -> tuple:
        """Perform Pose-NMS against previous detections, then SAM refinement.

        Args:
            img: BGR image.
            pose_instances: Poses estimated in the current BMP iteration.
            all_detections: Instances accumulated in earlier iterations
                (None on the first iteration).

        Returns:
            Tuple (new_detections, old_detections):
                new_detections: SAM-refined instances surviving pose-NMS,
                    or None when every new pose was suppressed.
                old_detections: previously accumulated instances surviving
                    pose-NMS (None when none survived; the unfiltered
                    'all_detections' when no new pose survived).
        """
        # Combine keypoints with scores
        keypoints_with_scores = pose_instances.keypoints

        # Perform Pose-NMS jointly over old + new instances; indices below
        # num_old_detections refer to old instances, the rest to new ones.
        all_keypoints = (
            keypoints_with_scores if all_detections is None else np.concatenate([all_detections.keypoints, keypoints_with_scores], axis=0)
        )
        all_bboxes = (
            pose_instances.bboxes if all_detections is None else np.concatenate([all_detections.bboxes, pose_instances.bboxes], axis=0)
        )

        num_valid_kpts = np.sum(all_keypoints[:, :, 2] > self.config.sam2.prompting.confidence_thr, axis=1)

        keep_indices = pose_nms(
            DotDict({"confidence_thr": self.config.sam2.prompting.confidence_thr, "oks_thr": self.config.oks_nms_thr}),
            image_kpts=all_keypoints,
            image_bboxes=all_bboxes,
            num_valid_kpts=num_valid_kpts,
        )

        # Sort by original index, then split back into old/new index spaces.
        keep_indices = sorted(keep_indices)
        num_old_detections = 0 if all_detections is None else len(all_detections.bboxes)
        keep_new_indices = [i - num_old_detections for i in keep_indices if i >= num_old_detections]
        keep_old_indices = [i for i in keep_indices if i < num_old_detections]

        if len(keep_new_indices) == 0:
            return None, all_detections

        # Filter new detections; instance score = mean keypoint score.
        new_dets = filter_instances(pose_instances, keep_new_indices)
        new_dets.scores = pose_instances.keypoint_scores[keep_new_indices].mean(axis=-1)

        old_dets = None
        if len(keep_old_indices) > 0:
            old_dets = filter_instances(all_detections, keep_old_indices)

        # Run SAM refinement prompted by the surviving new poses.
        new_detections = process_image_with_SAM(
            DotDict(self.config.sam2.prompting),
            img.copy(),
            self.sam2_model,
            new_dets,
            old_dets if old_dets is not None else None,
        )

        return new_detections, old_dets
415
+
416
+ def _mask_out_image(
417
+ self,
418
+ img: np.ndarray,
419
+ detections: InstanceData,
420
+ ) -> np.ndarray:
421
+ """Mask out detected instances from image for next iteration."""
422
+ masked_img = img.copy()
423
+ if hasattr(detections, "refined_masks") and detections.refined_masks is not None:
424
+ for mask in detections.refined_masks:
425
+ if mask is not None:
426
+ masked_img[mask.astype(bool)] = 0
427
+ return masked_img
428
+
429
+ def _format_result(
430
+ self,
431
+ detections: Optional[InstanceData],
432
+ img_shape: tuple,
433
+ ) -> Dict:
434
+ """Format detection results into standard output dict."""
435
+ if detections is None or len(detections.bboxes) == 0:
436
+ return {
437
+ "bboxes": np.zeros((0, 4)),
438
+ "masks": np.zeros((0, img_shape[0], img_shape[1])),
439
+ "keypoints": np.zeros((0, 17, 3)),
440
+ "presence": np.zeros((0, 17)),
441
+ "visibility": np.zeros((0, 17)),
442
+ }
443
+
444
+ # Extract refined masks if available
445
+ if hasattr(detections, "refined_masks") and detections.refined_masks is not None:
446
+ masks = detections.refined_masks
447
+ elif hasattr(detections, "pred_masks") and detections.pred_masks is not None:
448
+ masks = detections.pred_masks
449
+ elif hasattr(detections, "masks") and detections.masks is not None:
450
+ masks = detections.masks
451
+ else:
452
+ masks = np.zeros((len(detections.bboxes), img_shape[0], img_shape[1]))
453
+
454
+ return {
455
+ "bboxes": detections.bboxes,
456
+ "masks": masks,
457
+ "keypoints": detections.keypoints,
458
+ "presence": detections.keypoint_prob,
459
+ "visibility": detections.keypoint_vis,
460
+ }
461
+
462
    def visualize(
        self,
        image: Union[str, np.ndarray],
        result: Dict,
        save_path: Optional[str] = None,
        vis_type: str = "pose",
    ) -> np.ndarray:
        """
        Visualize BBoxMaskPose results on image.

        Args:
            image: Image path (str) or BGR numpy array.
            result: Result dict from predict().
            save_path: Optional path to save visualization.
            vis_type: Type of visualization ("pose" or "mask"; anything
                other than "mask" falls back to the pose rendering).

        Returns:
            np.ndarray: Visualization image (BGR).

        Raises:
            ValueError: If *image* is a path that cannot be read.
        """
        # Load image
        if isinstance(image, str):
            img = cv2.imread(image)
            if img is None:
                raise ValueError(f"Failed to load image from {image}")
        else:
            # Work on a copy so the caller's image is left untouched.
            img = image.copy()

        if vis_type == "mask":
            # Mask overlay; all-ones scores since only masks are drawn here.
            vis_img, _ = _visualize_predictions(
                img,
                bboxes=result["bboxes"],
                scores=np.ones(len(result["bboxes"])),
                masks=result["masks"],
                poses=result["keypoints"],
                vis_type="mask",
                mask_is_binary=True,
            )
            img = vis_img
        else:
            # Visualize using posevis_lite
            keypoints = result["keypoints"]
            keypoints = keypoints[:, :17, :]  # Use first 17 COCO keypoints
            img = pose_visualization(
                img,
                keypoints,
                width_multiplier=8,
                differ_individuals=True,
                keep_image_size=True,
            )

        # Save if requested
        if save_path is not None:
            cv2.imwrite(save_path, img)

        return img
{configs → bboxmaskpose/configs}/README.md RENAMED
File without changes
{configs → bboxmaskpose/configs}/bmp_D3.yaml RENAMED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  # BBoxMaskPose Hyperparameters from Experiment D3.
2
  # For details, see the paper: https://arxiv.org/abs/2412.01562, Tab 8. in the supplementary.
3
 
@@ -11,8 +16,10 @@ detector:
11
  det_prime_checkpoint: null
12
 
13
  pose_estimator:
14
- pose_config: 'mmpose/configs/MaskPose/ViTb-multi_mask.py'
15
- pose_checkpoint: 'https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/MaskPose-b.pth'
 
 
16
 
17
  sam2:
18
  sam2_config: 'configs/samurai/sam2.1_hiera_b+.yaml' # Use SAMURAI as it has img_size 1024 (SAM-2.1 has 512)
 
1
+ ######################################################################################
2
+ ### THIS CONFIG IS DEPRECATED AND KEPT ONLY FOR REPRODUCTION OF BMPv1 EXPERIMENTS. ###
3
+ ### FOR BMPv2 EXPERIMENTS, PLEASE USE THE bmp_v2.yaml CONFIG. ###
4
+ ######################################################################################
5
+
6
  # BBoxMaskPose Hyperparameters from Experiment D3.
7
  # For details, see the paper: https://arxiv.org/abs/2412.01562, Tab 8. in the supplementary.
8
 
 
16
  det_prime_checkpoint: null
17
 
18
  pose_estimator:
19
+ pose_config: 'mmpose/configs/ProbMaskPose/PMPose-b-1.0.0.py'
20
+ pose_checkpoint: 'models/pose_estimators/PMPose-b-1.0.0.pth'
21
+ # pose_config: 'mmpose/configs/MaskPose/ViTb-multi_mask.py'
22
+ # pose_checkpoint: 'https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/MaskPose-b.pth'
23
 
24
  sam2:
25
  sam2_config: 'configs/samurai/sam2.1_hiera_b+.yaml' # Use SAMURAI as it has img_size 1024 (SAM-2.1 has 512)
{configs → bboxmaskpose/configs}/bmp_J1.yaml RENAMED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  # BBoxMaskPose Hyperparameters from Experiment J1.
2
  # For details, see the paper: https://arxiv.org/abs/2412.01562, Tab 8. in the supplementary.
3
 
 
1
+ ######################################################################################
2
+ ### THIS CONFIG IS DEPRECATED AND KEPT ONLY FOR REPRODUCTION OF BMPv1 EXPERIMENTS. ###
3
+ ### FOR BMPv2 EXPERIMENTS, PLEASE USE THE bmp_v2.yaml CONFIG. ###
4
+ ######################################################################################
5
+
6
  # BBoxMaskPose Hyperparameters from Experiment J1.
7
  # For details, see the paper: https://arxiv.org/abs/2412.01562, Tab 8. in the supplementary.
8
 
bboxmaskpose/configs/bmp_v2.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This configuration is good for the BMP loop as was used for most of the experiments.
2
+ detector:
3
+ det_config: 'mmpose/configs/mmdet/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py'
4
+ det_checkpoint: 'https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/rtmdet-ins-l-mask.pth'
5
+
6
+ # Detectors D and D' could be different.
7
+ det_prime_config: null
8
+ det_prime_checkpoint: null
9
+
10
+ pose_estimator:
11
+ pose_config: 'mmpose/configs/ProbMaskPose/PMPose-b-1.0.0.py'
12
+ pose_checkpoint: 'https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/PMPose/PMPose-b-1.0.0.pth'
13
+
14
+ sam2:
15
+ sam2_config: 'configs/sam-pose2seg/sam-pose2seg_hiera_b+.yaml'
16
+ sam2_checkpoint: 'https://huggingface.co/vrg-prague/BBoxMaskPose/resolve/main/SAM-pose2seg_hiera_b%2B.pt'
17
+ prompting:
18
+ batch: False
19
+ use_bbox: False
20
+ num_pos_keypoints: 3
21
+ num_pos_keypoints_if_crowd: 3
22
+ num_neg_keypoints: 0
23
+ confidence_thr: 0.5 # not used
24
+ visibility_thr: 0.5 # not used
25
+ selection_method: 'k_most_visible'
26
+ extend_bbox: False
27
+ pose_mask_consistency: False
28
+ crowd_by_max_iou: False # Determine if the instance is in the multi-body scenario. If yes, use different amount of keypoints and NO BBOX. If no, use bbox according to 'use_bbox' argument.
29
+ crop: False
30
+ exclusive_masks: True
31
+ ignore_small_bboxes: False
32
+
33
+ num_bmp_iters: 2
34
+ oks_nms_thr: 0.8
{demo → bboxmaskpose}/demo_utils.py RENAMED
@@ -1,3 +1,4 @@
 
1
  """
2
  Utilities for the BMP demo:
3
  - Visualization of detections, masks, and poses
@@ -18,9 +19,10 @@ import numpy as np
18
  from mmengine.logging import print_log
19
  from mmengine.structures import InstanceData
20
  from pycocotools import mask as Mask
21
- from sam2.distinctipy import get_colors
22
  from tqdm import tqdm
23
 
 
 
24
  ### Visualization hyperparameters
25
  MIN_CONTOUR_AREA: int = 50
26
  BBOX_WEIGHT: float = 0.9
@@ -38,6 +40,21 @@ except ImportError:
38
  from .posevis_lite import pose_visualization
39
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  class DotDict(dict):
42
  """Dictionary with attribute access and nested dict wrapping."""
43
 
@@ -68,17 +85,7 @@ def filter_instances(instances: InstanceData, indices):
68
  return None
69
  data = {}
70
  # Attributes to filter
71
- for attr in [
72
- "bboxes",
73
- "bbox_scores",
74
- "keypoints",
75
- "keypoint_scores",
76
- "scores",
77
- "pred_masks",
78
- "refined_masks",
79
- "sam_scores",
80
- "sam_kpts",
81
- ]:
82
  if hasattr(instances, attr):
83
  arr = getattr(instances, attr)
84
  data[attr] = arr[indices] if arr is not None else None
@@ -95,17 +102,7 @@ def concat_instances(instances1: InstanceData, instances2: InstanceData):
95
  if instances2 is None:
96
  return instances1
97
  data = {}
98
- for attr in [
99
- "bboxes",
100
- "bbox_scores",
101
- "keypoints",
102
- "keypoint_scores",
103
- "scores",
104
- "pred_masks",
105
- "refined_masks",
106
- "sam_scores",
107
- "sam_kpts",
108
- ]:
109
  arr1 = getattr(instances1, attr, None)
110
  arr2 = getattr(instances2, attr, None)
111
  if arr1 is None and arr2 is None:
@@ -145,43 +142,20 @@ def _visualize_predictions(
145
  """
146
  vis_types = vis_type.split("+")
147
 
148
- # # Filter-out small detections to make the visualization more clear
149
- # new_bboxes = []
150
- # new_scores = []
151
- # new_masks = []
152
- # new_poses = []
153
- # size_thr = img.shape[0] * img.shape[1] * 0.01
154
- # for bbox, score, mask, pose in zip(bboxes, scores, masks, poses):
155
- # area = mask.sum() # Assume binary mask. OK for demo purposes
156
- # if area > size_thr:
157
- # new_bboxes.append(bbox)
158
- # new_scores.append(score)
159
- # new_masks.append(mask)
160
- # new_poses.append(pose)
161
- # bboxes = np.array(new_bboxes)
162
- # scores = np.array(new_scores)
163
- # masks = new_masks
164
- # poses = new_poses
165
-
166
  if mask_is_binary:
167
  poly_masks: List[Optional[List[np.ndarray]]] = []
168
  for binary_mask in masks:
169
  if binary_mask is not None:
170
- contours, _ = cv2.findContours(
171
- (binary_mask * 255).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
172
- )
173
  polys = [cnt.flatten() for cnt in contours if cv2.contourArea(cnt) >= MIN_CONTOUR_AREA]
174
  else:
175
  polys = None
176
  poly_masks.append(polys)
177
  masks = poly_masks # type: ignore
178
 
179
- # Exclude white, black, and green colors from the palette as they are not distinctive
180
- colors = (np.array(get_colors(len(bboxes), exclude_colors=[(0, 1, 0), (.5, .5, .5), (0, 0, 0), (1, 1, 1)], rng=0)) * 255).astype(
181
- int
182
- )
183
-
184
-
185
  if "inv-mask" in vis_types:
186
  stencil = np.zeros_like(img)
187
 
@@ -272,9 +246,7 @@ def visualize_itteration(
272
  label = "BMP {:d}x: {}".format(iteration_idx + 1, vis_def["label"])
273
  cv2.putText(vis_img, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 3)
274
  cv2.putText(vis_img, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
275
- out_path = os.path.join(
276
- output_root, "{}_iter{}_{}.jpg".format(img_name, iteration_idx + 1, vis_def["label"].replace(" ", "_"))
277
- )
278
  cv2.imwrite(str(out_path), vis_img)
279
 
280
  # Show prompting keypoints
@@ -311,43 +283,6 @@ def visualize_itteration(
311
  return masked_out
312
 
313
 
314
- def visualize_demo(
315
- img: np.ndarray, detections: Any,
316
- ) -> Optional[np.ndarray]:
317
- """
318
- Generate and save visualization images for each BMP iteration.
319
-
320
- Args:
321
- img (np.ndarray): Original input image.
322
- detections: InstanceData containing bboxes, scores, masks, keypoints.
323
- iteration_idx (int): Current iteration index (0-based).
324
- output_root (Path): Directory to save output images.
325
- img_name (str): Base name of the image without extension.
326
- with_text (bool): Whether to overlay text labels.
327
-
328
- Returns:
329
- Optional[np.ndarray]: The masked-out image if generated, else None.
330
- """
331
- bboxes = detections.bboxes
332
- scores = detections.scores
333
- pred_masks = detections.pred_masks
334
- refined_masks = detections.refined_masks
335
- keypoints = detections.keypoints
336
-
337
- returns = []
338
- for vis_def in [
339
- {"type": "mask-out", "masks": refined_masks, "label": ""},
340
- {"type": "mask+pose", "masks": pred_masks, "label": "RTMDet-L"},
341
- {"type": "mask+pose", "masks": refined_masks, "label": "BMP"},
342
- ]:
343
- vis_img, colors = _visualize_predictions(
344
- img.copy(), bboxes, scores, vis_def["masks"], keypoints, vis_type=vis_def["type"], mask_is_binary=True
345
- )
346
- returns.append(vis_img)
347
-
348
- return returns
349
-
350
-
351
  def create_GIF(
352
  img_path: Path,
353
  output_root: Path,
@@ -419,7 +354,6 @@ def create_GIF(
419
  # Add 'before' and 'after' images
420
  after1_img = os.path.join(dirname, "{}_iter{}_Final_Poses.jpg".format(img_name_wo_ext, bmp_x))
421
  after2_img = os.path.join(dirname, "{}_iter{}_SAM_Masks.jpg".format(img_name_wo_ext, bmp_x))
422
- # gif_images.append(os.path.join(dirname, "black_image.jpg")) # Add black image at the end
423
  gif_images.append(after1_img)
424
  gif_images.append(after2_img)
425
  gif_images.append(os.path.join(dirname, "black_image.jpg")) # Add black image at the end
@@ -457,10 +391,7 @@ def create_GIF(
457
  right = "[{}:v]".format(i)
458
  out = "[v{}]".format(i)
459
  offset = (i - 1) * (display_dur + fade_dur) + display_dur
460
- parts.append(
461
- "{}{}xfade=transition=fade:".format(left, right)
462
- + "duration={}:offset={:.3f}{}".format(fade_dur, offset, out)
463
- )
464
  filter_complex = ";".join(parts)
465
 
466
  # 3. make MP4 slideshow
@@ -544,9 +475,7 @@ def create_GIF(
544
  print_log(f"GIF saved as '{gif_output_path}'", logger="current")
545
 
546
 
547
- def _update_bbox_by_mask(
548
- bbox: List[int], mask_poly: Optional[List[List[int]]], image_shape: Tuple[int, int, int]
549
- ) -> List[int]:
550
  """
551
  Adjust bounding box to tightly fit mask polygon.
552
 
@@ -591,11 +520,6 @@ def pose_nms(config: Any, image_kpts: np.ndarray, image_bboxes: np.ndarray, num_
591
  Returns:
592
  np.ndarray: Indices of kept instances.
593
  """
594
- # Sort image kpts by average score - lowest first
595
- # scores = image_kpts[:, :, 2].mean(axis=1)
596
- # sort_idx = np.argsort(scores)
597
- # image_kpts = image_kpts[sort_idx, :, :]
598
-
599
  # Compute OKS between all pairs of poses
600
  oks_matrix = np.zeros((image_kpts.shape[0], image_kpts.shape[0]))
601
  for i in range(image_kpts.shape[0]):
@@ -611,8 +535,7 @@ def pose_nms(config: Any, image_kpts: np.ndarray, image_bboxes: np.ndarray, num_
611
  dt = {"keypoints": image_kpts[j].copy(), "bbox": gt_bbox_xyxy}
612
  gt["keypoints"][:, 2] = (gt["keypoints"][:, 2] > config.confidence_thr) * 2
613
  oks = compute_oks(gt, dt)
614
- if oks > 1:
615
- breakpoint()
616
  oks_matrix[i, j] = oks
617
 
618
  np.fill_diagonal(oks_matrix, -1)
@@ -653,13 +576,10 @@ def compute_oks(gt: Dict[str, Any], dt: Dict[str, Any], use_area: bool = True, p
653
  Returns:
654
  float: OKS score or mean OKS.
655
  """
656
- sigmas = (
657
- np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89])
658
- / 10.0
659
- )
660
  vars = (sigmas * 2) ** 2
661
  k = len(sigmas)
662
- visibility_condition = lambda x: x > 0
663
  g = np.array(gt["keypoints"]).reshape(k, 3)
664
  xg = g[:, 0]
665
  yg = g[:, 1]
 
1
+ # Copyright (c) authors of BBoxMaskPose (BMPv2). All rights reserved.
2
  """
3
  Utilities for the BMP demo:
4
  - Visualization of detections, masks, and poses
 
19
  from mmengine.logging import print_log
20
  from mmengine.structures import InstanceData
21
  from pycocotools import mask as Mask
 
22
  from tqdm import tqdm
23
 
24
+ from bboxmaskpose.sam2.distinctipy import get_colors
25
+
26
  ### Visualization hyperparameters
27
  MIN_CONTOUR_AREA: int = 50
28
  BBOX_WEIGHT: float = 0.9
 
40
  from .posevis_lite import pose_visualization
41
 
42
 
43
+ WHITELIST_ATTRIBUTES = [
44
+ "bboxes",
45
+ "bbox_scores",
46
+ "keypoints",
47
+ "keypoint_scores",
48
+ "scores",
49
+ "pred_masks",
50
+ "refined_masks",
51
+ "sam_scores",
52
+ "sam_kpts",
53
+ "keypoint_vis",
54
+ "keypoint_prob",
55
+ ]
56
+
57
+
58
  class DotDict(dict):
59
  """Dictionary with attribute access and nested dict wrapping."""
60
 
 
85
  return None
86
  data = {}
87
  # Attributes to filter
88
+ for attr in WHITELIST_ATTRIBUTES:
 
 
 
 
 
 
 
 
 
 
89
  if hasattr(instances, attr):
90
  arr = getattr(instances, attr)
91
  data[attr] = arr[indices] if arr is not None else None
 
102
  if instances2 is None:
103
  return instances1
104
  data = {}
105
+ for attr in WHITELIST_ATTRIBUTES:
 
 
 
 
 
 
 
 
 
 
106
  arr1 = getattr(instances1, attr, None)
107
  arr2 = getattr(instances2, attr, None)
108
  if arr1 is None and arr2 is None:
 
142
  """
143
  vis_types = vis_type.split("+")
144
 
145
+ # Exclude white, black, and green colors from the palette as they are not distinctive
146
+ colors = (np.array(get_colors(len(bboxes), exclude_colors=[(0, 1, 0), (0, 0, 0), (1, 1, 1)], rng=0)) * 255).astype(int)
147
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  if mask_is_binary:
149
  poly_masks: List[Optional[List[np.ndarray]]] = []
150
  for binary_mask in masks:
151
  if binary_mask is not None:
152
+ contours, _ = cv2.findContours((binary_mask * 255).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
 
 
153
  polys = [cnt.flatten() for cnt in contours if cv2.contourArea(cnt) >= MIN_CONTOUR_AREA]
154
  else:
155
  polys = None
156
  poly_masks.append(polys)
157
  masks = poly_masks # type: ignore
158
 
 
 
 
 
 
 
159
  if "inv-mask" in vis_types:
160
  stencil = np.zeros_like(img)
161
 
 
246
  label = "BMP {:d}x: {}".format(iteration_idx + 1, vis_def["label"])
247
  cv2.putText(vis_img, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 3)
248
  cv2.putText(vis_img, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
249
+ out_path = os.path.join(output_root, "{}_iter{}_{}.jpg".format(img_name, iteration_idx + 1, vis_def["label"].replace(" ", "_")))
 
 
250
  cv2.imwrite(str(out_path), vis_img)
251
 
252
  # Show prompting keypoints
 
283
  return masked_out
284
 
285
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
  def create_GIF(
287
  img_path: Path,
288
  output_root: Path,
 
354
  # Add 'before' and 'after' images
355
  after1_img = os.path.join(dirname, "{}_iter{}_Final_Poses.jpg".format(img_name_wo_ext, bmp_x))
356
  after2_img = os.path.join(dirname, "{}_iter{}_SAM_Masks.jpg".format(img_name_wo_ext, bmp_x))
 
357
  gif_images.append(after1_img)
358
  gif_images.append(after2_img)
359
  gif_images.append(os.path.join(dirname, "black_image.jpg")) # Add black image at the end
 
391
  right = "[{}:v]".format(i)
392
  out = "[v{}]".format(i)
393
  offset = (i - 1) * (display_dur + fade_dur) + display_dur
394
+ parts.append("{}{}xfade=transition=fade:".format(left, right) + "duration={}:offset={:.3f}{}".format(fade_dur, offset, out))
 
 
 
395
  filter_complex = ";".join(parts)
396
 
397
  # 3. make MP4 slideshow
 
475
  print_log(f"GIF saved as '{gif_output_path}'", logger="current")
476
 
477
 
478
+ def _update_bbox_by_mask(bbox: List[int], mask_poly: Optional[List[List[int]]], image_shape: Tuple[int, int, int]) -> List[int]:
 
 
479
  """
480
  Adjust bounding box to tightly fit mask polygon.
481
 
 
520
  Returns:
521
  np.ndarray: Indices of kept instances.
522
  """
 
 
 
 
 
523
  # Compute OKS between all pairs of poses
524
  oks_matrix = np.zeros((image_kpts.shape[0], image_kpts.shape[0]))
525
  for i in range(image_kpts.shape[0]):
 
535
  dt = {"keypoints": image_kpts[j].copy(), "bbox": gt_bbox_xyxy}
536
  gt["keypoints"][:, 2] = (gt["keypoints"][:, 2] > config.confidence_thr) * 2
537
  oks = compute_oks(gt, dt)
538
+ assert oks <= 1.0, f"OKS value {oks} exceeds 1.0, which indicates a bug in compute_oks"
 
539
  oks_matrix[i, j] = oks
540
 
541
  np.fill_diagonal(oks_matrix, -1)
 
576
  Returns:
577
  float: OKS score or mean OKS.
578
  """
579
+ sigmas = np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89]) / 10.0
 
 
 
580
  vars = (sigmas * 2) ** 2
581
  k = len(sigmas)
582
+ visibility_condition = lambda x: x > 0.3
583
  g = np.array(gt["keypoints"]).reshape(k, 3)
584
  xg = g[:, 0]
585
  yg = g[:, 1]
{demo → bboxmaskpose}/posevis_lite.py RENAMED
@@ -1,9 +1,13 @@
 
 
1
  import os
2
  from typing import Any, Dict, List, Optional, Tuple, Union
3
 
4
  import cv2
5
  import numpy as np
6
 
 
 
7
  NEUTRAL_COLOR = (52, 235, 107)
8
 
9
  LEFT_ARM_COLOR = (216, 235, 52)
@@ -85,14 +89,6 @@ def _draw_line(
85
  start = np.array(start)[:2]
86
  stop = np.array(stop)[:2]
87
  if line_type.lower() == "solid":
88
- img = cv2.line(
89
- img,
90
- (int(start[0]), int(start[1])),
91
- (int(stop[0]), int(stop[1])),
92
- color=(0, 0, 0),
93
- thickness=thickness+1,
94
- lineType=cv2.LINE_AA,
95
- )
96
  img = cv2.line(
97
  img,
98
  (int(start[0]), int(start[1])),
@@ -193,7 +189,14 @@ def pose_visualization(
193
  if not isinstance(color, (list, tuple)):
194
  color = [color for keypoint in keypoints]
195
  else:
196
- color = [None for keypoint in keypoints]
 
 
 
 
 
 
 
197
 
198
  max_padding = [0, 0, 0, 0]
199
  for keypoint, clr in zip(keypoints, color):
@@ -243,12 +246,9 @@ def pose_visualization(
243
  # If conf >= confidence_thr: conf = 2
244
  vis_is_float = np.any(np.logical_and(keypoints[:, -1] > 0, keypoints[:, -1] < 1))
245
  if keypoints.shape[1] == 3 and vis_is_float:
246
- # print("before", keypoints[:, -1])
247
  lower_idx = keypoints[:, -1] < confidence_thr
248
  keypoints[lower_idx, -1] = 1
249
  keypoints[~lower_idx, -1] = 2
250
- # print("after", keypoints[:, -1])
251
- # print("-"*20)
252
 
253
  # All visibility values should be ints
254
  keypoints[:, -1] = keypoints[:, -1].astype(int)
 
1
+ # Copyright (c) authors of BBoxMaskPose (BMPv2). All rights reserved.
2
+
3
  import os
4
  from typing import Any, Dict, List, Optional, Tuple, Union
5
 
6
  import cv2
7
  import numpy as np
8
 
9
+ from bboxmaskpose.sam2.distinctipy import get_colors
10
+
11
  NEUTRAL_COLOR = (52, 235, 107)
12
 
13
  LEFT_ARM_COLOR = (216, 235, 52)
 
89
  start = np.array(start)[:2]
90
  stop = np.array(stop)[:2]
91
  if line_type.lower() == "solid":
 
 
 
 
 
 
 
 
92
  img = cv2.line(
93
  img,
94
  (int(start[0]), int(start[1])),
 
189
  if not isinstance(color, (list, tuple)):
190
  color = [color for keypoint in keypoints]
191
  else:
192
+ if differ_individuals:
193
+ color = (
194
+ (np.array(get_colors(len(keypoints), exclude_colors=[(0, 1, 0), (0, 0, 0), (1, 1, 1)], rng=0)) * 255)
195
+ .astype(int)
196
+ .tolist()
197
+ )
198
+ else:
199
+ color = [None for keypoint in keypoints]
200
 
201
  max_padding = [0, 0, 0, 0]
202
  for keypoint, clr in zip(keypoints, color):
 
246
  # If conf >= confidence_thr: conf = 2
247
  vis_is_float = np.any(np.logical_and(keypoints[:, -1] > 0, keypoints[:, -1] < 1))
248
  if keypoints.shape[1] == 3 and vis_is_float:
 
249
  lower_idx = keypoints[:, -1] < confidence_thr
250
  keypoints[lower_idx, -1] = 1
251
  keypoints[~lower_idx, -1] = 2
 
 
252
 
253
  # All visibility values should be ints
254
  keypoints[:, -1] = keypoints[:, -1].astype(int)
{sam2 → bboxmaskpose/sam2}/__init__.py RENAMED
@@ -8,4 +8,4 @@ from hydra import initialize_config_module
8
  from hydra.core.global_hydra import GlobalHydra
9
 
10
  if not GlobalHydra.instance().is_initialized():
11
- initialize_config_module("sam2", version_base="1.2")
 
8
  from hydra.core.global_hydra import GlobalHydra
9
 
10
  if not GlobalHydra.instance().is_initialized():
11
+ initialize_config_module("bboxmaskpose.sam2", version_base="1.2")
{sam2 → bboxmaskpose/sam2}/automatic_mask_generator.py RENAMED
@@ -11,9 +11,10 @@ import numpy as np
11
  import torch
12
  from torchvision.ops.boxes import batched_nms, box_area # type: ignore
13
 
14
- from sam2.modeling.sam2_base import SAM2Base
15
- from sam2.sam2_image_predictor import SAM2ImagePredictor
16
- from sam2.utils.amg import (
 
17
  area_from_rle,
18
  batch_iterator,
19
  batched_mask_to_box,
@@ -24,7 +25,6 @@ from sam2.utils.amg import (
24
  generate_crop_boxes,
25
  is_box_near_crop_edge,
26
  mask_to_rle_pytorch,
27
- MaskData,
28
  remove_small_regions,
29
  rle_to_mask,
30
  uncrop_boxes_xyxy,
@@ -103,9 +103,7 @@ class SAM2AutomaticMaskGenerator:
103
  multimask_output (bool): Whether to output multimask at each point of the grid.
104
  """
105
 
106
- assert (points_per_side is None) != (
107
- point_grids is None
108
- ), "Exactly one of points_per_side or point_grid must be provided."
109
  if points_per_side is not None:
110
  self.point_grids = build_all_layer_point_grids(
111
  points_per_side,
@@ -161,7 +159,7 @@ class SAM2AutomaticMaskGenerator:
161
  Returns:
162
  (SAM2AutomaticMaskGenerator): The loaded model.
163
  """
164
- from sam2.build_sam import build_sam2_hf
165
 
166
  sam_model = build_sam2_hf(model_id, **kwargs)
167
  return cls(sam_model, **kwargs)
@@ -197,9 +195,7 @@ class SAM2AutomaticMaskGenerator:
197
 
198
  # Encode masks
199
  if self.output_mode == "coco_rle":
200
- mask_data["segmentations"] = [
201
- coco_encode_rle(rle) for rle in mask_data["rles"]
202
- ]
203
  elif self.output_mode == "binary_mask":
204
  mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
205
  else:
@@ -223,9 +219,7 @@ class SAM2AutomaticMaskGenerator:
223
 
224
  def _generate_masks(self, image: np.ndarray) -> MaskData:
225
  orig_size = image.shape[:2]
226
- crop_boxes, layer_idxs = generate_crop_boxes(
227
- orig_size, self.crop_n_layers, self.crop_overlap_ratio
228
- )
229
 
230
  # Iterate over image crops
231
  data = MaskData()
@@ -268,9 +262,7 @@ class SAM2AutomaticMaskGenerator:
268
  # Generate masks for this crop in batches
269
  data = MaskData()
270
  for (points,) in batch_iterator(self.points_per_batch, points_for_image):
271
- batch_data = self._process_batch(
272
- points, cropped_im_size, crop_box, orig_size, normalize=True
273
- )
274
  data.cat(batch_data)
275
  del batch_data
276
  self.predictor.reset_predictor()
@@ -302,15 +294,9 @@ class SAM2AutomaticMaskGenerator:
302
  orig_h, orig_w = orig_size
303
 
304
  # Run model on this batch
305
- points = torch.as_tensor(
306
- points, dtype=torch.float32, device=self.predictor.device
307
- )
308
- in_points = self.predictor._transforms.transform_coords(
309
- points, normalize=normalize, orig_hw=im_size
310
- )
311
- in_labels = torch.ones(
312
- in_points.shape[0], dtype=torch.int, device=in_points.device
313
- )
314
  masks, iou_preds, low_res_masks = self.predictor._predict(
315
  in_points[:, None, :],
316
  in_labels[:, None],
@@ -334,23 +320,15 @@ class SAM2AutomaticMaskGenerator:
334
  data.filter(keep_mask)
335
 
336
  # Calculate and filter by stability score
337
- data["stability_score"] = calculate_stability_score(
338
- data["masks"], self.mask_threshold, self.stability_score_offset
339
- )
340
  if self.stability_score_thresh > 0.0:
341
  keep_mask = data["stability_score"] >= self.stability_score_thresh
342
  data.filter(keep_mask)
343
  else:
344
  # One step refinement using previous mask predictions
345
- in_points = self.predictor._transforms.transform_coords(
346
- data["points"], normalize=normalize, orig_hw=im_size
347
- )
348
- labels = torch.ones(
349
- in_points.shape[0], dtype=torch.int, device=in_points.device
350
- )
351
- masks, ious = self.refine_with_m2m(
352
- in_points, labels, data["low_res_masks"], self.points_per_batch
353
- )
354
  data["masks"] = masks.squeeze(1)
355
  data["iou_preds"] = ious.squeeze(1)
356
 
@@ -358,9 +336,7 @@ class SAM2AutomaticMaskGenerator:
358
  keep_mask = data["iou_preds"] > self.pred_iou_thresh
359
  data.filter(keep_mask)
360
 
361
- data["stability_score"] = calculate_stability_score(
362
- data["masks"], self.mask_threshold, self.stability_score_offset
363
- )
364
  if self.stability_score_thresh > 0.0:
365
  keep_mask = data["stability_score"] >= self.stability_score_thresh
366
  data.filter(keep_mask)
@@ -370,9 +346,7 @@ class SAM2AutomaticMaskGenerator:
370
  data["boxes"] = batched_mask_to_box(data["masks"])
371
 
372
  # Filter boxes that touch crop boundaries
373
- keep_mask = ~is_box_near_crop_edge(
374
- data["boxes"], crop_box, [0, 0, orig_w, orig_h]
375
- )
376
  if not torch.all(keep_mask):
377
  data.filter(keep_mask)
378
 
@@ -384,9 +358,7 @@ class SAM2AutomaticMaskGenerator:
384
  return data
385
 
386
  @staticmethod
387
- def postprocess_small_regions(
388
- mask_data: MaskData, min_area: int, nms_thresh: float
389
- ) -> MaskData:
390
  """
391
  Removes small disconnected regions and holes in masks, then reruns
392
  box NMS to remove any new duplicates.
@@ -438,9 +410,7 @@ class SAM2AutomaticMaskGenerator:
438
  new_masks = []
439
  new_iou_preds = []
440
 
441
- for cur_points, cur_point_labels, low_res_mask in batch_iterator(
442
- points_per_batch, points, point_labels, low_res_masks
443
- ):
444
  best_masks, best_iou_preds, _ = self.predictor._predict(
445
  cur_points[:, None, :],
446
  cur_point_labels[:, None],
 
11
  import torch
12
  from torchvision.ops.boxes import batched_nms, box_area # type: ignore
13
 
14
+ from bboxmaskpose.sam2.modeling.sam2_base import SAM2Base
15
+ from bboxmaskpose.sam2.sam2_image_predictor import SAM2ImagePredictor
16
+ from bboxmaskpose.sam2.utils.amg import (
17
+ MaskData,
18
  area_from_rle,
19
  batch_iterator,
20
  batched_mask_to_box,
 
25
  generate_crop_boxes,
26
  is_box_near_crop_edge,
27
  mask_to_rle_pytorch,
 
28
  remove_small_regions,
29
  rle_to_mask,
30
  uncrop_boxes_xyxy,
 
103
  multimask_output (bool): Whether to output multimask at each point of the grid.
104
  """
105
 
106
+ assert (points_per_side is None) != (point_grids is None), "Exactly one of points_per_side or point_grid must be provided."
 
 
107
  if points_per_side is not None:
108
  self.point_grids = build_all_layer_point_grids(
109
  points_per_side,
 
159
  Returns:
160
  (SAM2AutomaticMaskGenerator): The loaded model.
161
  """
162
+ from bboxmaskpose.sam2.build_sam import build_sam2_hf
163
 
164
  sam_model = build_sam2_hf(model_id, **kwargs)
165
  return cls(sam_model, **kwargs)
 
195
 
196
  # Encode masks
197
  if self.output_mode == "coco_rle":
198
+ mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
 
 
199
  elif self.output_mode == "binary_mask":
200
  mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
201
  else:
 
219
 
220
  def _generate_masks(self, image: np.ndarray) -> MaskData:
221
  orig_size = image.shape[:2]
222
+ crop_boxes, layer_idxs = generate_crop_boxes(orig_size, self.crop_n_layers, self.crop_overlap_ratio)
 
 
223
 
224
  # Iterate over image crops
225
  data = MaskData()
 
262
  # Generate masks for this crop in batches
263
  data = MaskData()
264
  for (points,) in batch_iterator(self.points_per_batch, points_for_image):
265
+ batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size, normalize=True)
 
 
266
  data.cat(batch_data)
267
  del batch_data
268
  self.predictor.reset_predictor()
 
294
  orig_h, orig_w = orig_size
295
 
296
  # Run model on this batch
297
+ points = torch.as_tensor(points, dtype=torch.float32, device=self.predictor.device)
298
+ in_points = self.predictor._transforms.transform_coords(points, normalize=normalize, orig_hw=im_size)
299
+ in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
 
 
 
 
 
 
300
  masks, iou_preds, low_res_masks = self.predictor._predict(
301
  in_points[:, None, :],
302
  in_labels[:, None],
 
320
  data.filter(keep_mask)
321
 
322
  # Calculate and filter by stability score
323
+ data["stability_score"] = calculate_stability_score(data["masks"], self.mask_threshold, self.stability_score_offset)
 
 
324
  if self.stability_score_thresh > 0.0:
325
  keep_mask = data["stability_score"] >= self.stability_score_thresh
326
  data.filter(keep_mask)
327
  else:
328
  # One step refinement using previous mask predictions
329
+ in_points = self.predictor._transforms.transform_coords(data["points"], normalize=normalize, orig_hw=im_size)
330
+ labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
331
+ masks, ious = self.refine_with_m2m(in_points, labels, data["low_res_masks"], self.points_per_batch)
 
 
 
 
 
 
332
  data["masks"] = masks.squeeze(1)
333
  data["iou_preds"] = ious.squeeze(1)
334
 
 
336
  keep_mask = data["iou_preds"] > self.pred_iou_thresh
337
  data.filter(keep_mask)
338
 
339
+ data["stability_score"] = calculate_stability_score(data["masks"], self.mask_threshold, self.stability_score_offset)
 
 
340
  if self.stability_score_thresh > 0.0:
341
  keep_mask = data["stability_score"] >= self.stability_score_thresh
342
  data.filter(keep_mask)
 
346
  data["boxes"] = batched_mask_to_box(data["masks"])
347
 
348
  # Filter boxes that touch crop boundaries
349
+ keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
 
 
350
  if not torch.all(keep_mask):
351
  data.filter(keep_mask)
352
 
 
358
  return data
359
 
360
  @staticmethod
361
+ def postprocess_small_regions(mask_data: MaskData, min_area: int, nms_thresh: float) -> MaskData:
 
 
362
  """
363
  Removes small disconnected regions and holes in masks, then reruns
364
  box NMS to remove any new duplicates.
 
410
  new_masks = []
411
  new_iou_preds = []
412
 
413
+ for cur_points, cur_point_labels, low_res_mask in batch_iterator(points_per_batch, points, point_labels, low_res_masks):
 
 
414
  best_masks, best_iou_preds, _ = self.predictor._predict(
415
  cur_points[:, None, :],
416
  cur_point_labels[:, None],
{sam2 → bboxmaskpose/sam2}/benchmark.py RENAMED
@@ -11,7 +11,7 @@ import numpy as np
11
  import torch
12
  from tqdm import tqdm
13
 
14
- from sam2.build_sam import build_sam2_video_predictor
15
 
16
  # Only cuda supported
17
  assert torch.cuda.is_available()
@@ -28,19 +28,13 @@ sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
28
  model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
29
 
30
  # Build video predictor with vos_optimized=True setting
31
- predictor = build_sam2_video_predictor(
32
- model_cfg, sam2_checkpoint, device=device, vos_optimized=True
33
- )
34
 
35
 
36
  # Initialize with video
37
  video_dir = "notebooks/videos/bedroom"
38
  # scan all the JPEG frame names in this directory
39
- frame_names = [
40
- p
41
- for p in os.listdir(video_dir)
42
- if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
43
- ]
44
  frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
45
  inference_state = predictor.init_state(video_path=video_dir)
46
 
 
11
  import torch
12
  from tqdm import tqdm
13
 
14
+ from bboxmaskpose.sam2.build_sam import build_sam2_video_predictor
15
 
16
  # Only cuda supported
17
  assert torch.cuda.is_available()
 
28
  model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
29
 
30
  # Build video predictor with vos_optimized=True setting
31
+ predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device, vos_optimized=True)
 
 
32
 
33
 
34
  # Initialize with video
35
  video_dir = "notebooks/videos/bedroom"
36
  # scan all the JPEG frame names in this directory
37
+ frame_names = [p for p in os.listdir(video_dir) if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]]
 
 
 
 
38
  frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
39
  inference_state = predictor.init_state(video_path=video_dir)
40
 
{sam2 → bboxmaskpose/sam2}/build_sam.py RENAMED
@@ -6,14 +6,16 @@
6
 
7
  import logging
8
  import os
 
9
 
10
  import torch
11
- from hydra import compose
 
 
 
12
  from hydra.utils import instantiate
13
  from omegaconf import OmegaConf
14
 
15
- import sam2
16
-
17
  # Check if the user is running Python from the parent directory of the sam2 repo
18
  # (i.e. the directory where this repo is cloned into) -- this is not supported since
19
  # it could shadow the sam2 package and cause issues.
@@ -86,13 +88,26 @@ def build_sam2(
86
  "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
87
  "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
88
  ]
 
 
 
 
 
 
 
 
 
 
 
89
  # Read config and init model
90
  try:
91
- cfg = compose(config_name=config_file)
 
 
92
  except Exception as e:
93
  logging.error(f"Error loading config: {e}")
94
  cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
95
-
96
  OmegaConf.resolve(cfg)
97
  model = instantiate(cfg.model, _recursive_=True)
98
  _load_checkpoint(model, ckpt_path)
@@ -161,14 +176,23 @@ def build_sam2_hf(model_id, **kwargs):
161
 
162
  def build_sam2_video_predictor_hf(model_id, **kwargs):
163
  config_name, ckpt_path = _hf_download(model_id)
164
- return build_sam2_video_predictor(
165
- config_file=config_name, ckpt_path=ckpt_path, **kwargs
166
- )
 
 
167
 
168
 
169
  def _load_checkpoint(model, ckpt_path):
170
  if ckpt_path is not None:
171
- sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
 
 
 
 
 
 
 
172
  missing_keys, unexpected_keys = model.load_state_dict(sd)
173
  if missing_keys:
174
  logging.error(missing_keys)
@@ -176,4 +200,5 @@ def _load_checkpoint(model, ckpt_path):
176
  if unexpected_keys:
177
  logging.error(unexpected_keys)
178
  raise RuntimeError()
 
179
  logging.info("Loaded checkpoint sucessfully")
 
6
 
7
  import logging
8
  import os
9
+ import urllib.parse as urlparse
10
 
11
  import torch
12
+
13
+ import bboxmaskpose.sam2 as sam2
14
+ from hydra import compose, initialize_config_dir
15
+ from hydra.core.global_hydra import GlobalHydra
16
  from hydra.utils import instantiate
17
  from omegaconf import OmegaConf
18
 
 
 
19
  # Check if the user is running Python from the parent directory of the sam2 repo
20
  # (i.e. the directory where this repo is cloned into) -- this is not supported since
21
  # it could shadow the sam2 package and cause issues.
 
88
  "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
89
  "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
90
  ]
91
+
92
+ # IMPORTANT: compose() requires Hydra to be initialized with a config source.
93
+ # Also important if build_sam2() can be called multiple times in one process.
94
+ GlobalHydra.instance().clear()
95
+
96
+ # Point Hydra at the directory that contains the SAM2 yaml configs
97
+ config_dir = os.path.dirname(config_file)
98
+
99
+ # Hydra expects config_name WITHOUT .yaml
100
+ config_name = os.path.basename(config_file).replace(".yaml", "")
101
+
102
  # Read config and init model
103
  try:
104
+ with initialize_config_dir(version_base=None, config_dir=str(config_dir)):
105
+ cfg = compose(config_name=config_name, overrides=hydra_overrides_extra)
106
+ # cfg = compose(config_name=config_file)
107
  except Exception as e:
108
  logging.error(f"Error loading config: {e}")
109
  cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
110
+
111
  OmegaConf.resolve(cfg)
112
  model = instantiate(cfg.model, _recursive_=True)
113
  _load_checkpoint(model, ckpt_path)
 
176
 
177
  def build_sam2_video_predictor_hf(model_id, **kwargs):
178
  config_name, ckpt_path = _hf_download(model_id)
179
+ return build_sam2_video_predictor(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
180
+
181
+
182
+ def _is_url(path: str) -> bool:
183
+ return urlparse.urlparse(path).scheme != ""
184
 
185
 
186
  def _load_checkpoint(model, ckpt_path):
187
  if ckpt_path is not None:
188
+
189
+ if _is_url(ckpt_path):
190
+ sd = torch.hub.load_state_dict_from_url(ckpt_path, map_location="cpu", weights_only=True)["model"]
191
+ elif os.path.exists(ckpt_path):
192
+ sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
193
+ else:
194
+ raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
195
+
196
  missing_keys, unexpected_keys = model.load_state_dict(sd)
197
  if missing_keys:
198
  logging.error(missing_keys)
 
200
  if unexpected_keys:
201
  logging.error(unexpected_keys)
202
  raise RuntimeError()
203
+
204
  logging.info("Loaded checkpoint sucessfully")
{sam2 → bboxmaskpose/sam2}/colorblind.py RENAMED
@@ -1,7 +1,10 @@
 
 
1
  """
2
  Adapted from "The Color Blind Simulation function" by Matthew Wickline
3
  and the Human - Computer Interaction Resource Network (http://hcirn.com/), 2000 - 2001.
4
  """
 
5
  import numpy as np
6
 
7
  rBlind = {
@@ -261,16 +264,13 @@ def simulate_colors(colors, colorblind_type="Deuteranomaly", one_row=None, show=
261
  :return:
262
  """
263
  import matplotlib.pyplot as plt
264
-
265
  from distinctipy import distinctipy
266
 
267
  filtered_colors = [colorblind_filter(color, colorblind_type) for color in colors]
268
 
269
  fig, axes = plt.subplots(1, 2, figsize=(8, 4))
270
 
271
- distinctipy.color_swatch(
272
- colors, ax=axes[0], one_row=one_row, title="Viewed with Normal Sight"
273
- )
274
 
275
  distinctipy.color_swatch(
276
  filtered_colors,
@@ -324,30 +324,22 @@ def simulate_clusters(
324
  """
325
  import matplotlib.pyplot as plt
326
  import pandas as pd
327
-
328
  from distinctipy import distinctipy
329
 
330
  if dataset not in ("s1", "s2", "s3", "s4", "a1", "a2", "a3", "b1"):
331
  raise ValueError("dataset must be s1, s2, s3, s4, a1, a2, a3 or b1")
332
 
333
- URL = (
334
- "https://raw.githubusercontent.com/alan-turing-institute/distinctipy/"
335
- "main/distinctipy/datasets/"
336
- )
337
  df = pd.read_csv(URL + dataset + ".csv")
338
 
339
  if colorblind_distinct:
340
- orig_colors = distinctipy.get_colors(
341
- df["cluster"].nunique(), colorblind_type=colorblind_type
342
- )
343
  else:
344
  orig_colors = distinctipy.get_colors(df["cluster"].nunique())
345
 
346
  orig_cmap = distinctipy.get_colormap(orig_colors)
347
 
348
- filtered_colors = [
349
- colorblind_filter(color, colorblind_type) for color in orig_colors
350
- ]
351
  filtered_cmap = distinctipy.get_colormap(filtered_colors)
352
 
353
  fig, axes = plt.subplots(1, 2, figsize=(10, 5))
@@ -376,4 +368,4 @@ def _main():
376
 
377
 
378
  if __name__ == "__main__":
379
- _main()
 
1
+ # Adapted from the distinctipy repository (https://github.com/alan-turing-institute/distinctipy).
2
+ # Original authors: distinctipy contributors. Included with minor modifications.
3
  """
4
  Adapted from "The Color Blind Simulation function" by Matthew Wickline
5
  and the Human - Computer Interaction Resource Network (http://hcirn.com/), 2000 - 2001.
6
  """
7
+
8
  import numpy as np
9
 
10
  rBlind = {
 
264
  :return:
265
  """
266
  import matplotlib.pyplot as plt
 
267
  from distinctipy import distinctipy
268
 
269
  filtered_colors = [colorblind_filter(color, colorblind_type) for color in colors]
270
 
271
  fig, axes = plt.subplots(1, 2, figsize=(8, 4))
272
 
273
+ distinctipy.color_swatch(colors, ax=axes[0], one_row=one_row, title="Viewed with Normal Sight")
 
 
274
 
275
  distinctipy.color_swatch(
276
  filtered_colors,
 
324
  """
325
  import matplotlib.pyplot as plt
326
  import pandas as pd
 
327
  from distinctipy import distinctipy
328
 
329
  if dataset not in ("s1", "s2", "s3", "s4", "a1", "a2", "a3", "b1"):
330
  raise ValueError("dataset must be s1, s2, s3, s4, a1, a2, a3 or b1")
331
 
332
+ URL = "https://raw.githubusercontent.com/alan-turing-institute/distinctipy/" "main/distinctipy/datasets/"
 
 
 
333
  df = pd.read_csv(URL + dataset + ".csv")
334
 
335
  if colorblind_distinct:
336
+ orig_colors = distinctipy.get_colors(df["cluster"].nunique(), colorblind_type=colorblind_type)
 
 
337
  else:
338
  orig_colors = distinctipy.get_colors(df["cluster"].nunique())
339
 
340
  orig_cmap = distinctipy.get_colormap(orig_colors)
341
 
342
+ filtered_colors = [colorblind_filter(color, colorblind_type) for color in orig_colors]
 
 
343
  filtered_cmap = distinctipy.get_colormap(filtered_colors)
344
 
345
  fig, axes = plt.subplots(1, 2, figsize=(10, 5))
 
368
 
369
 
370
  if __name__ == "__main__":
371
+ _main()
bboxmaskpose/sam2/configs/sam-pose2seg/sam-pose2seg_hiera_b+.yaml ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: bboxmaskpose.sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 112
12
+ num_heads: 2
13
+ neck:
14
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
15
+ position_encoding:
16
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
17
+ num_pos_feats: 256
18
+ normalize: true
19
+ scale: null
20
+ temperature: 10000
21
+ d_model: 256
22
+ backbone_channel_list: [896, 448, 224, 112]
23
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
24
+ fpn_interp_model: nearest
25
+
26
+ memory_attention:
27
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
28
+ d_model: 256
29
+ pos_enc_at_input: true
30
+ layer:
31
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
32
+ activation: relu
33
+ dim_feedforward: 2048
34
+ dropout: 0.1
35
+ pos_enc_at_attn: false
36
+ self_attention:
37
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
38
+ rope_theta: 10000.0
39
+ feat_sizes: [64, 64]
40
+ embedding_dim: 256
41
+ num_heads: 1
42
+ downsample_rate: 1
43
+ dropout: 0.1
44
+ d_model: 256
45
+ pos_enc_at_cross_attn_keys: true
46
+ pos_enc_at_cross_attn_queries: false
47
+ cross_attention:
48
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
49
+ rope_theta: 10000.0
50
+ feat_sizes: [64, 64]
51
+ rope_k_repeat: True
52
+ embedding_dim: 256
53
+ num_heads: 1
54
+ downsample_rate: 1
55
+ dropout: 0.1
56
+ kv_in_dim: 64
57
+ num_layers: 4
58
+
59
+ memory_encoder:
60
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
61
+ out_dim: 64
62
+ position_encoding:
63
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
64
+ num_pos_feats: 64
65
+ normalize: true
66
+ scale: null
67
+ temperature: 10000
68
+ mask_downsampler:
69
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
70
+ kernel_size: 3
71
+ stride: 2
72
+ padding: 1
73
+ fuser:
74
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
75
+ layer:
76
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
77
+ dim: 256
78
+ kernel_size: 7
79
+ padding: 3
80
+ layer_scale_init_value: 1e-6
81
+ use_dwconv: True # depth-wise convs
82
+ num_layers: 2
83
+
84
+ num_maskmem: 7
85
+ image_size: 1024
86
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87
+ sigmoid_scale_for_mem_enc: 20.0
88
+ sigmoid_bias_for_mem_enc: -10.0
89
+ use_mask_input_as_output_without_sam: true
90
+ # Memory
91
+ directly_add_no_mem_embed: true
92
+ no_obj_embed_spatial: true
93
+ # use high-resolution feature map in the SAM mask decoder
94
+ use_high_res_features_in_sam: true
95
+ # output 3 masks on the first click on initial conditioning frames
96
+ multimask_output_in_sam: true
97
+ # SAM heads
98
+ iou_prediction_use_sigmoid: True
99
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
100
+ use_obj_ptrs_in_encoder: true
101
+ add_tpos_enc_to_obj_ptrs: true
102
+ proj_tpos_enc_in_obj_ptrs: true
103
+ use_signed_tpos_enc_to_obj_ptrs: true
104
+ only_obj_ptrs_in_the_past_for_eval: true
105
+ # object occlusion prediction
106
+ pred_obj_scores: true
107
+ pred_obj_scores_mlp: true
108
+ fixed_no_obj_ptr: true
109
+ # multimask tracking settings
110
+ multimask_output_for_tracking: true
111
+ use_multimask_token_for_obj_ptr: true
112
+ multimask_min_pt_num: 0
113
+ multimask_max_pt_num: 1
114
+ use_mlp_for_obj_ptr_proj: true
115
+
116
+ n_kpts_encoder: 8
117
+ # Compilation flag
118
+ # compile_image_encoder: False
{sam2 → bboxmaskpose/sam2}/configs/sam2.1/sam2.1_hiera_b+.yaml RENAMED
@@ -2,18 +2,18 @@
2
 
3
  # Model
4
  model:
5
- _target_: sam2.modeling.sam2_base.SAM2Base
6
  image_encoder:
7
- _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
  scalp: 1
9
  trunk:
10
- _target_: sam2.modeling.backbones.hieradet.Hiera
11
  embed_dim: 112
12
  num_heads: 2
13
  neck:
14
- _target_: sam2.modeling.backbones.image_encoder.FpnNeck
15
  position_encoding:
16
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
17
  num_pos_feats: 256
18
  normalize: true
19
  scale: null
@@ -24,17 +24,17 @@ model:
24
  fpn_interp_model: nearest
25
 
26
  memory_attention:
27
- _target_: sam2.modeling.memory_attention.MemoryAttention
28
  d_model: 256
29
  pos_enc_at_input: true
30
  layer:
31
- _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
32
  activation: relu
33
  dim_feedforward: 2048
34
  dropout: 0.1
35
  pos_enc_at_attn: false
36
  self_attention:
37
- _target_: sam2.modeling.sam.transformer.RoPEAttention
38
  rope_theta: 10000.0
39
  feat_sizes: [64, 64]
40
  embedding_dim: 256
@@ -45,7 +45,7 @@ model:
45
  pos_enc_at_cross_attn_keys: true
46
  pos_enc_at_cross_attn_queries: false
47
  cross_attention:
48
- _target_: sam2.modeling.sam.transformer.RoPEAttention
49
  rope_theta: 10000.0
50
  feat_sizes: [64, 64]
51
  rope_k_repeat: True
@@ -57,23 +57,23 @@ model:
57
  num_layers: 4
58
 
59
  memory_encoder:
60
- _target_: sam2.modeling.memory_encoder.MemoryEncoder
61
  out_dim: 64
62
  position_encoding:
63
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
64
  num_pos_feats: 64
65
  normalize: true
66
  scale: null
67
  temperature: 10000
68
  mask_downsampler:
69
- _target_: sam2.modeling.memory_encoder.MaskDownSampler
70
  kernel_size: 3
71
  stride: 2
72
  padding: 1
73
  fuser:
74
- _target_: sam2.modeling.memory_encoder.Fuser
75
  layer:
76
- _target_: sam2.modeling.memory_encoder.CXBlock
77
  dim: 256
78
  kernel_size: 7
79
  padding: 3
 
2
 
3
  # Model
4
  model:
5
+ _target_: bboxmaskpose.sam2.modeling.sam2_base.SAM2Base
6
  image_encoder:
7
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
8
  scalp: 1
9
  trunk:
10
+ _target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
11
  embed_dim: 112
12
  num_heads: 2
13
  neck:
14
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
15
  position_encoding:
16
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
17
  num_pos_feats: 256
18
  normalize: true
19
  scale: null
 
24
  fpn_interp_model: nearest
25
 
26
  memory_attention:
27
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
28
  d_model: 256
29
  pos_enc_at_input: true
30
  layer:
31
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
32
  activation: relu
33
  dim_feedforward: 2048
34
  dropout: 0.1
35
  pos_enc_at_attn: false
36
  self_attention:
37
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
38
  rope_theta: 10000.0
39
  feat_sizes: [64, 64]
40
  embedding_dim: 256
 
45
  pos_enc_at_cross_attn_keys: true
46
  pos_enc_at_cross_attn_queries: false
47
  cross_attention:
48
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
49
  rope_theta: 10000.0
50
  feat_sizes: [64, 64]
51
  rope_k_repeat: True
 
57
  num_layers: 4
58
 
59
  memory_encoder:
60
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
61
  out_dim: 64
62
  position_encoding:
63
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
64
  num_pos_feats: 64
65
  normalize: true
66
  scale: null
67
  temperature: 10000
68
  mask_downsampler:
69
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
70
  kernel_size: 3
71
  stride: 2
72
  padding: 1
73
  fuser:
74
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
75
  layer:
76
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
77
  dim: 256
78
  kernel_size: 7
79
  padding: 3
{sam2 → bboxmaskpose/sam2}/configs/sam2.1/sam2.1_hiera_l.yaml RENAMED
@@ -2,12 +2,12 @@
2
 
3
  # Model
4
  model:
5
- _target_: sam2.modeling.sam2_base.SAM2Base
6
  image_encoder:
7
- _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
  scalp: 1
9
  trunk:
10
- _target_: sam2.modeling.backbones.hieradet.Hiera
11
  embed_dim: 144
12
  num_heads: 2
13
  stages: [2, 6, 36, 4]
@@ -15,9 +15,9 @@ model:
15
  window_pos_embed_bkg_spatial_size: [7, 7]
16
  window_spec: [8, 4, 16, 8]
17
  neck:
18
- _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19
  position_encoding:
20
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21
  num_pos_feats: 256
22
  normalize: true
23
  scale: null
@@ -28,17 +28,17 @@ model:
28
  fpn_interp_model: nearest
29
 
30
  memory_attention:
31
- _target_: sam2.modeling.memory_attention.MemoryAttention
32
  d_model: 256
33
  pos_enc_at_input: true
34
  layer:
35
- _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36
  activation: relu
37
  dim_feedforward: 2048
38
  dropout: 0.1
39
  pos_enc_at_attn: false
40
  self_attention:
41
- _target_: sam2.modeling.sam.transformer.RoPEAttention
42
  rope_theta: 10000.0
43
  feat_sizes: [64, 64]
44
  embedding_dim: 256
@@ -49,7 +49,7 @@ model:
49
  pos_enc_at_cross_attn_keys: true
50
  pos_enc_at_cross_attn_queries: false
51
  cross_attention:
52
- _target_: sam2.modeling.sam.transformer.RoPEAttention
53
  rope_theta: 10000.0
54
  feat_sizes: [64, 64]
55
  rope_k_repeat: True
@@ -61,23 +61,23 @@ model:
61
  num_layers: 4
62
 
63
  memory_encoder:
64
- _target_: sam2.modeling.memory_encoder.MemoryEncoder
65
  out_dim: 64
66
  position_encoding:
67
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68
  num_pos_feats: 64
69
  normalize: true
70
  scale: null
71
  temperature: 10000
72
  mask_downsampler:
73
- _target_: sam2.modeling.memory_encoder.MaskDownSampler
74
  kernel_size: 3
75
  stride: 2
76
  padding: 1
77
  fuser:
78
- _target_: sam2.modeling.memory_encoder.Fuser
79
  layer:
80
- _target_: sam2.modeling.memory_encoder.CXBlock
81
  dim: 256
82
  kernel_size: 7
83
  padding: 3
 
2
 
3
  # Model
4
  model:
5
+ _target_: bboxmaskpose.sam2.modeling.sam2_base.SAM2Base
6
  image_encoder:
7
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
8
  scalp: 1
9
  trunk:
10
+ _target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
11
  embed_dim: 144
12
  num_heads: 2
13
  stages: [2, 6, 36, 4]
 
15
  window_pos_embed_bkg_spatial_size: [7, 7]
16
  window_spec: [8, 4, 16, 8]
17
  neck:
18
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
19
  position_encoding:
20
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
21
  num_pos_feats: 256
22
  normalize: true
23
  scale: null
 
28
  fpn_interp_model: nearest
29
 
30
  memory_attention:
31
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
32
  d_model: 256
33
  pos_enc_at_input: true
34
  layer:
35
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
36
  activation: relu
37
  dim_feedforward: 2048
38
  dropout: 0.1
39
  pos_enc_at_attn: false
40
  self_attention:
41
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
42
  rope_theta: 10000.0
43
  feat_sizes: [64, 64]
44
  embedding_dim: 256
 
49
  pos_enc_at_cross_attn_keys: true
50
  pos_enc_at_cross_attn_queries: false
51
  cross_attention:
52
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
53
  rope_theta: 10000.0
54
  feat_sizes: [64, 64]
55
  rope_k_repeat: True
 
61
  num_layers: 4
62
 
63
  memory_encoder:
64
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
65
  out_dim: 64
66
  position_encoding:
67
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
68
  num_pos_feats: 64
69
  normalize: true
70
  scale: null
71
  temperature: 10000
72
  mask_downsampler:
73
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
74
  kernel_size: 3
75
  stride: 2
76
  padding: 1
77
  fuser:
78
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
79
  layer:
80
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
81
  dim: 256
82
  kernel_size: 7
83
  padding: 3
{sam2 → bboxmaskpose/sam2}/configs/sam2.1/sam2.1_hiera_s.yaml RENAMED
@@ -2,21 +2,21 @@
2
 
3
  # Model
4
  model:
5
- _target_: sam2.modeling.sam2_base.SAM2Base
6
  image_encoder:
7
- _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
  scalp: 1
9
  trunk:
10
- _target_: sam2.modeling.backbones.hieradet.Hiera
11
  embed_dim: 96
12
  num_heads: 1
13
  stages: [1, 2, 11, 2]
14
  global_att_blocks: [7, 10, 13]
15
  window_pos_embed_bkg_spatial_size: [7, 7]
16
  neck:
17
- _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
  position_encoding:
19
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
  num_pos_feats: 256
21
  normalize: true
22
  scale: null
@@ -27,17 +27,17 @@ model:
27
  fpn_interp_model: nearest
28
 
29
  memory_attention:
30
- _target_: sam2.modeling.memory_attention.MemoryAttention
31
  d_model: 256
32
  pos_enc_at_input: true
33
  layer:
34
- _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
  activation: relu
36
  dim_feedforward: 2048
37
  dropout: 0.1
38
  pos_enc_at_attn: false
39
  self_attention:
40
- _target_: sam2.modeling.sam.transformer.RoPEAttention
41
  rope_theta: 10000.0
42
  feat_sizes: [64, 64]
43
  embedding_dim: 256
@@ -48,7 +48,7 @@ model:
48
  pos_enc_at_cross_attn_keys: true
49
  pos_enc_at_cross_attn_queries: false
50
  cross_attention:
51
- _target_: sam2.modeling.sam.transformer.RoPEAttention
52
  rope_theta: 10000.0
53
  feat_sizes: [64, 64]
54
  rope_k_repeat: True
@@ -60,23 +60,23 @@ model:
60
  num_layers: 4
61
 
62
  memory_encoder:
63
- _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
  out_dim: 64
65
  position_encoding:
66
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
  num_pos_feats: 64
68
  normalize: true
69
  scale: null
70
  temperature: 10000
71
  mask_downsampler:
72
- _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
  kernel_size: 3
74
  stride: 2
75
  padding: 1
76
  fuser:
77
- _target_: sam2.modeling.memory_encoder.Fuser
78
  layer:
79
- _target_: sam2.modeling.memory_encoder.CXBlock
80
  dim: 256
81
  kernel_size: 7
82
  padding: 3
 
2
 
3
  # Model
4
  model:
5
+ _target_: bboxmaskpose.sam2.modeling.sam2_base.SAM2Base
6
  image_encoder:
7
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
8
  scalp: 1
9
  trunk:
10
+ _target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
11
  embed_dim: 96
12
  num_heads: 1
13
  stages: [1, 2, 11, 2]
14
  global_att_blocks: [7, 10, 13]
15
  window_pos_embed_bkg_spatial_size: [7, 7]
16
  neck:
17
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
18
  position_encoding:
19
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
20
  num_pos_feats: 256
21
  normalize: true
22
  scale: null
 
27
  fpn_interp_model: nearest
28
 
29
  memory_attention:
30
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
31
  d_model: 256
32
  pos_enc_at_input: true
33
  layer:
34
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
35
  activation: relu
36
  dim_feedforward: 2048
37
  dropout: 0.1
38
  pos_enc_at_attn: false
39
  self_attention:
40
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
41
  rope_theta: 10000.0
42
  feat_sizes: [64, 64]
43
  embedding_dim: 256
 
48
  pos_enc_at_cross_attn_keys: true
49
  pos_enc_at_cross_attn_queries: false
50
  cross_attention:
51
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
52
  rope_theta: 10000.0
53
  feat_sizes: [64, 64]
54
  rope_k_repeat: True
 
60
  num_layers: 4
61
 
62
  memory_encoder:
63
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
64
  out_dim: 64
65
  position_encoding:
66
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
67
  num_pos_feats: 64
68
  normalize: true
69
  scale: null
70
  temperature: 10000
71
  mask_downsampler:
72
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
73
  kernel_size: 3
74
  stride: 2
75
  padding: 1
76
  fuser:
77
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
78
  layer:
79
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
80
  dim: 256
81
  kernel_size: 7
82
  padding: 3
{sam2 → bboxmaskpose/sam2}/configs/sam2.1/sam2.1_hiera_t.yaml RENAMED
@@ -2,21 +2,21 @@
2
 
3
  # Model
4
  model:
5
- _target_: sam2.modeling.sam2_base.SAM2Base
6
  image_encoder:
7
- _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
  scalp: 1
9
  trunk:
10
- _target_: sam2.modeling.backbones.hieradet.Hiera
11
  embed_dim: 96
12
  num_heads: 1
13
  stages: [1, 2, 7, 2]
14
  global_att_blocks: [5, 7, 9]
15
  window_pos_embed_bkg_spatial_size: [7, 7]
16
  neck:
17
- _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
  position_encoding:
19
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
  num_pos_feats: 256
21
  normalize: true
22
  scale: null
@@ -27,17 +27,17 @@ model:
27
  fpn_interp_model: nearest
28
 
29
  memory_attention:
30
- _target_: sam2.modeling.memory_attention.MemoryAttention
31
  d_model: 256
32
  pos_enc_at_input: true
33
  layer:
34
- _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
  activation: relu
36
  dim_feedforward: 2048
37
  dropout: 0.1
38
  pos_enc_at_attn: false
39
  self_attention:
40
- _target_: sam2.modeling.sam.transformer.RoPEAttention
41
  rope_theta: 10000.0
42
  feat_sizes: [64, 64]
43
  embedding_dim: 256
@@ -48,7 +48,7 @@ model:
48
  pos_enc_at_cross_attn_keys: true
49
  pos_enc_at_cross_attn_queries: false
50
  cross_attention:
51
- _target_: sam2.modeling.sam.transformer.RoPEAttention
52
  rope_theta: 10000.0
53
  feat_sizes: [64, 64]
54
  rope_k_repeat: True
@@ -60,23 +60,23 @@ model:
60
  num_layers: 4
61
 
62
  memory_encoder:
63
- _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
  out_dim: 64
65
  position_encoding:
66
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
  num_pos_feats: 64
68
  normalize: true
69
  scale: null
70
  temperature: 10000
71
  mask_downsampler:
72
- _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
  kernel_size: 3
74
  stride: 2
75
  padding: 1
76
  fuser:
77
- _target_: sam2.modeling.memory_encoder.Fuser
78
  layer:
79
- _target_: sam2.modeling.memory_encoder.CXBlock
80
  dim: 256
81
  kernel_size: 7
82
  padding: 3
 
2
 
3
  # Model
4
  model:
5
+ _target_: bboxmaskpose.sam2.modeling.sam2_base.SAM2Base
6
  image_encoder:
7
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
8
  scalp: 1
9
  trunk:
10
+ _target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
11
  embed_dim: 96
12
  num_heads: 1
13
  stages: [1, 2, 7, 2]
14
  global_att_blocks: [5, 7, 9]
15
  window_pos_embed_bkg_spatial_size: [7, 7]
16
  neck:
17
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
18
  position_encoding:
19
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
20
  num_pos_feats: 256
21
  normalize: true
22
  scale: null
 
27
  fpn_interp_model: nearest
28
 
29
  memory_attention:
30
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
31
  d_model: 256
32
  pos_enc_at_input: true
33
  layer:
34
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
35
  activation: relu
36
  dim_feedforward: 2048
37
  dropout: 0.1
38
  pos_enc_at_attn: false
39
  self_attention:
40
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
41
  rope_theta: 10000.0
42
  feat_sizes: [64, 64]
43
  embedding_dim: 256
 
48
  pos_enc_at_cross_attn_keys: true
49
  pos_enc_at_cross_attn_queries: false
50
  cross_attention:
51
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
52
  rope_theta: 10000.0
53
  feat_sizes: [64, 64]
54
  rope_k_repeat: True
 
60
  num_layers: 4
61
 
62
  memory_encoder:
63
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
64
  out_dim: 64
65
  position_encoding:
66
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
67
  num_pos_feats: 64
68
  normalize: true
69
  scale: null
70
  temperature: 10000
71
  mask_downsampler:
72
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
73
  kernel_size: 3
74
  stride: 2
75
  padding: 1
76
  fuser:
77
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
78
  layer:
79
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
80
  dim: 256
81
  kernel_size: 7
82
  padding: 3
bboxmaskpose/sam2/configs/sam2.1_training/sam2.1_hiera_b+_COCO+CIHP_finetune_sam-pose2seg.yaml ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ scratch:
4
+ resolution: 1024
5
+ train_batch_size: 1
6
+ num_train_workers: 10
7
+ num_frames: 1
8
+ max_num_objects: 1
9
+ base_lr: 5.0e-6
10
+ vision_lr: 3.0e-06
11
+ phases_per_epoch: 1
12
+ num_epochs: 15
13
+
14
+ dataset:
15
+ # PATHS to Dataset
16
+ img_folder: path/to/datasett
17
+ gt_folder: path/to/dataset
18
+ multiplier: 2
19
+
20
+ # Video transforms
21
+ vos:
22
+ train_transforms:
23
+ - _target_: training.dataset.transforms.ComposeAPI
24
+ transforms:
25
+ - _target_: training.dataset.transforms.RandomHorizontalFlip
26
+ consistent_transform: True
27
+ - _target_: training.dataset.transforms.RandomAffine
28
+ degrees: 25
29
+ shear: 20
30
+ image_interpolation: bilinear
31
+ consistent_transform: True
32
+ - _target_: training.dataset.transforms.RandomResizeAPI
33
+ sizes: ${scratch.resolution}
34
+ square: true
35
+ consistent_transform: True
36
+ - _target_: training.dataset.transforms.ColorJitter
37
+ consistent_transform: True
38
+ brightness: 0.1
39
+ contrast: 0.03
40
+ saturation: 0.03
41
+ hue: null
42
+ - _target_: training.dataset.transforms.RandomGrayscale
43
+ p: 0.05
44
+ consistent_transform: True
45
+ - _target_: training.dataset.transforms.ColorJitter
46
+ consistent_transform: False
47
+ brightness: 0.1
48
+ contrast: 0.05
49
+ saturation: 0.05
50
+ hue: null
51
+ - _target_: training.dataset.transforms.ToTensorAPI
52
+ - _target_: training.dataset.transforms.NormalizeAPI
53
+ mean: [0.485, 0.456, 0.406]
54
+ std: [0.229, 0.224, 0.225]
55
+
56
+ trainer:
57
+ _target_: training.trainer.Trainer
58
+ mode: train_only # change to train ? (a.k.a. train + val)
59
+ max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
60
+ accelerator: cuda
61
+ seed_value: 123
62
+
63
+ model:
64
+ _target_: training.model.sam2.SAM2Train
65
+ image_encoder:
66
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
67
+ scalp: 1
68
+ trunk:
69
+ _target_: sam2.modeling.backbones.hieradet.Hiera
70
+ embed_dim: 112
71
+ num_heads: 2
72
+ drop_path_rate: 0.1
73
+ neck:
74
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
75
+ position_encoding:
76
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
77
+ num_pos_feats: 256
78
+ normalize: true
79
+ scale: null
80
+ temperature: 10000
81
+ d_model: 256
82
+ backbone_channel_list: [896, 448, 224, 112]
83
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
84
+ fpn_interp_model: nearest
85
+
86
+ memory_attention:
87
+ _target_: sam2.modeling.memory_attention.MemoryAttention
88
+ d_model: 256
89
+ pos_enc_at_input: true
90
+ layer:
91
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
92
+ activation: relu
93
+ dim_feedforward: 2048
94
+ dropout: 0.1
95
+ pos_enc_at_attn: false
96
+ self_attention:
97
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
98
+ rope_theta: 10000.0
99
+ feat_sizes: [64, 64]
100
+ embedding_dim: 256
101
+ num_heads: 1
102
+ downsample_rate: 1
103
+ dropout: 0.1
104
+ d_model: 256
105
+ pos_enc_at_cross_attn_keys: true
106
+ pos_enc_at_cross_attn_queries: false
107
+ cross_attention:
108
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
109
+ rope_theta: 10000.0
110
+ feat_sizes: [64, 64]
111
+ rope_k_repeat: True
112
+ embedding_dim: 256
113
+ num_heads: 1
114
+ downsample_rate: 1
115
+ dropout: 0.1
116
+ kv_in_dim: 64
117
+ num_layers: 4
118
+
119
+ memory_encoder:
120
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
121
+ out_dim: 64
122
+ position_encoding:
123
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
124
+ num_pos_feats: 64
125
+ normalize: true
126
+ scale: null
127
+ temperature: 10000
128
+ mask_downsampler:
129
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
130
+ kernel_size: 3
131
+ stride: 2
132
+ padding: 1
133
+ fuser:
134
+ _target_: sam2.modeling.memory_encoder.Fuser
135
+ layer:
136
+ _target_: sam2.modeling.memory_encoder.CXBlock
137
+ dim: 256
138
+ kernel_size: 7
139
+ padding: 3
140
+ layer_scale_init_value: 1e-6
141
+ use_dwconv: True # depth-wise convs
142
+ num_layers: 2
143
+
144
+ num_maskmem: 7
145
+ image_size: ${scratch.resolution}
146
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
147
+ sigmoid_scale_for_mem_enc: 20.0
148
+ sigmoid_bias_for_mem_enc: -10.0
149
+ use_mask_input_as_output_without_sam: true
150
+ # Memory
151
+ directly_add_no_mem_embed: true
152
+ no_obj_embed_spatial: true
153
+ # use high-resolution feature map in the SAM mask decoder
154
+ use_high_res_features_in_sam: true
155
+ # output 3 masks on the first click on initial conditioning frames
156
+ multimask_output_in_sam: true
157
+ # SAM heads
158
+ iou_prediction_use_sigmoid: True
159
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
160
+ use_obj_ptrs_in_encoder: true
161
+ add_tpos_enc_to_obj_ptrs: true
162
+ proj_tpos_enc_in_obj_ptrs: true
163
+ use_signed_tpos_enc_to_obj_ptrs: true
164
+ only_obj_ptrs_in_the_past_for_eval: true
165
+ # object occlusion prediction
166
+ pred_obj_scores: true
167
+ pred_obj_scores_mlp: true
168
+ fixed_no_obj_ptr: true
169
+ # multimask tracking settings
170
+ multimask_output_for_tracking: true
171
+ use_multimask_token_for_obj_ptr: false
172
+ multimask_min_pt_num: 0
173
+ multimask_max_pt_num: 1
174
+ use_mlp_for_obj_ptr_proj: true
175
+
176
+ n_kpts_encoder: 8
177
+ # Compilation flag
178
+ # compile_image_encoder: False
179
+
180
+ ####### Training specific params #######
181
+ # box/point input and corrections
182
+ prob_to_use_pt_input_for_train: 1.0
183
+ prob_to_use_pt_input_for_eval: 0.0
184
+ prob_to_use_box_input_for_train: 0.0 # 0.5*0.5 = 0.25 prob to use box instead of points
185
+ prob_to_use_box_input_for_eval: 0.0
186
+ prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors
187
+ num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame)
188
+ num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame
189
+ rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2
190
+ add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
191
+ # maximum 2 initial conditioning frames
192
+ num_init_cond_frames_for_train: 2
193
+ rand_init_cond_frames_for_train: True # random 1~2
194
+ num_correction_pt_per_frame: 7 ## CHANGED
195
+ use_act_ckpt_iterative_pt_sampling: false
196
+
197
+
198
+
199
+ num_init_cond_frames_for_eval: 1 # only mask on the first frame
200
+ forward_backbone_per_frame_for_eval: True
201
+
202
+
203
+ data:
204
+ train:
205
+ _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
206
+ phases_per_epoch: ${scratch.phases_per_epoch}
207
+ batch_sizes:
208
+ - ${scratch.train_batch_size}
209
+
210
+ datasets:
211
+ - _target_: training.dataset.utils.RepeatFactorWrapper
212
+ dataset:
213
+ _target_: training.dataset.utils.ConcatDataset
214
+ datasets:
215
+ - _target_: training.dataset.vos_dataset.VOSDataset
216
+ transforms: ${vos.train_transforms}
217
+ training: true
218
+ video_dataset:
219
+ _target_: training.dataset.vos_raw_dataset.SA1BRawDataset
220
+ img_folder: ${dataset.img_folder}
221
+ gt_folder: ${dataset.gt_folder}
222
+ # file_list_txt: ${dataset.file_list_txt}
223
+ sampler:
224
+ _target_: training.dataset.vos_sampler.RandomUniformSampler
225
+ num_frames: ${scratch.num_frames}
226
+ max_num_objects: ${scratch.max_num_objects}
227
+ multiplier: ${dataset.multiplier}
228
+ shuffle: True
229
+ num_workers: ${scratch.num_train_workers}
230
+ pin_memory: True
231
+ drop_last: True
232
+ collate_fn:
233
+ _target_: training.utils.data_utils.collate_fn
234
+ _partial_: true
235
+ dict_key: all
236
+
237
+ # val:
238
+
239
+
240
+ optim:
241
+ amp:
242
+ enabled: True
243
+ amp_dtype: bfloat16
244
+
245
+ optimizer:
246
+ _target_: torch.optim.AdamW
247
+
248
+ gradient_clip:
249
+ _target_: training.optimizer.GradientClipper
250
+ max_norm: 0.1
251
+ norm_type: 2
252
+
253
+ param_group_modifiers:
254
+ - _target_: training.optimizer.layer_decay_param_modifier
255
+ _partial_: True
256
+ layer_decay_value: 0.9
257
+ apply_to: 'image_encoder.trunk'
258
+ overrides:
259
+ - pattern: '*pos_embed*'
260
+ value: 1.0
261
+
262
+ options:
263
+ lr:
264
+ - scheduler:
265
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
266
+ start_value: ${scratch.base_lr}
267
+ end_value: ${divide:${scratch.base_lr},10}
268
+ - scheduler:
269
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
270
+ start_value: ${scratch.vision_lr}
271
+ end_value: ${divide:${scratch.vision_lr},10}
272
+ param_names:
273
+ - 'image_encoder.*'
274
+ weight_decay:
275
+ - scheduler:
276
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
277
+ value: 0.1
278
+ - scheduler:
279
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
280
+ value: 0.0
281
+ param_names:
282
+ - '*bias*'
283
+ module_cls_names: ['torch.nn.LayerNorm']
284
+
285
+ loss:
286
+ all:
287
+ _target_: training.loss_fns.MultiStepMultiMasksAndIous
288
+ weight_dict:
289
+ loss_mask: 20
290
+ loss_dice: 1
291
+ loss_iou: 1
292
+ loss_class: 1
293
+ supervise_all_iou: true
294
+ iou_use_l1_loss: true
295
+ pred_obj_scores: true
296
+ focal_gamma_obj_score: 0.0
297
+ focal_alpha_obj_score: -1.0
298
+
299
+ distributed:
300
+ backend: nccl
301
+ find_unused_parameters: True
302
+
303
+ logging:
304
+ tensorboard_writer:
305
+ _target_: training.utils.logger.make_tensorboard_logger
306
+ log_dir: ${launcher.experiment_log_dir}/tensorboard
307
+ flush_secs: 120
308
+ should_log: True
309
+ log_dir: ${launcher.experiment_log_dir}/logs
310
+ log_freq: 10
311
+
312
+ # initialize from a SAM 2 checkpoint
313
+ checkpoint:
314
+ save_dir: ${launcher.experiment_log_dir}/checkpoints
315
+ save_freq: 0 # 0 only last checkpoint is saved.
316
+ model_weight_initializer:
317
+ _partial_: True
318
+ _target_: training.utils.checkpoint_utils.load_state_dict_into_model
319
+ strict: True
320
+ ignore_unexpected_keys: null
321
+ ignore_missing_keys: null
322
+
323
+ state_dict:
324
+ _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
325
+ checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt ## CHANGED - PATH to SAM 2.1 checkpoint
326
+ ckpt_state_dict_keys: ['model']
327
+
328
+ launcher:
329
+ num_nodes: 1
330
+ gpus_per_node: 8
331
+ experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
332
+
333
+ # SLURM args if running on a cluster
334
+ submitit:
335
+ partition: null
336
+ account: null
337
+ qos: null
338
+ cpus_per_task: 10
339
+ use_cluster: false
340
+ timeout_hour: 24
341
+ name: null
342
+ port_range: [10000, 65000]
343
+
{sam2 → bboxmaskpose/sam2}/configs/sam2.1_training/sam2.1_hiera_b+_COCO_1024_prompt.yaml RENAMED
@@ -11,14 +11,6 @@ scratch:
11
  phases_per_epoch: 1
12
  num_epochs: 40
13
 
14
- dataset:
15
- # PATHS to Dataset
16
- img_folder: /mnt/personal/purkrmir/data/COCO/original/train2017/ # PATH to MOSE JPEGImages folder
17
- gt_folder: /mnt/personal/purkrmir/data/COCO/original/annotations/ # PATH to MOSE Annotations folder
18
- # img_folder: /datagrid/personal/purkrmir/data/COCO/original/val2017/ # PATH to MOSE JPEGImages folder
19
- # gt_folder: /datagrid/personal/purkrmir/data/COCO/original/annotations/ # PATH to MOSE Annotations folder
20
- file_list_txt: null # Optional PATH to filelist containing a subset of videos to be used for training
21
- multiplier: 2
22
 
23
  # Video transforms
24
  vos:
@@ -69,19 +61,19 @@ trainer:
69
  unfreeze_decoder: False
70
 
71
  model:
72
- _target_: training.model.sam2.SAM2Train
73
  image_encoder:
74
- _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
75
  scalp: 1
76
  trunk:
77
- _target_: sam2.modeling.backbones.hieradet.Hiera
78
  embed_dim: 112
79
  num_heads: 2
80
  drop_path_rate: 0.1
81
  neck:
82
- _target_: sam2.modeling.backbones.image_encoder.FpnNeck
83
  position_encoding:
84
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
85
  num_pos_feats: 256
86
  normalize: true
87
  scale: null
@@ -92,17 +84,17 @@ trainer:
92
  fpn_interp_model: nearest
93
 
94
  memory_attention:
95
- _target_: sam2.modeling.memory_attention.MemoryAttention
96
  d_model: 256
97
  pos_enc_at_input: true
98
  layer:
99
- _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
100
  activation: relu
101
  dim_feedforward: 2048
102
  dropout: 0.1
103
  pos_enc_at_attn: false
104
  self_attention:
105
- _target_: sam2.modeling.sam.transformer.RoPEAttention
106
  rope_theta: 10000.0
107
  feat_sizes: [64, 64]
108
  embedding_dim: 256
@@ -113,7 +105,7 @@ trainer:
113
  pos_enc_at_cross_attn_keys: true
114
  pos_enc_at_cross_attn_queries: false
115
  cross_attention:
116
- _target_: sam2.modeling.sam.transformer.RoPEAttention
117
  rope_theta: 10000.0
118
  feat_sizes: [64, 64]
119
  rope_k_repeat: True
@@ -125,23 +117,23 @@ trainer:
125
  num_layers: 4
126
 
127
  memory_encoder:
128
- _target_: sam2.modeling.memory_encoder.MemoryEncoder
129
  out_dim: 64
130
  position_encoding:
131
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
132
  num_pos_feats: 64
133
  normalize: true
134
  scale: null
135
  temperature: 10000
136
  mask_downsampler:
137
- _target_: sam2.modeling.memory_encoder.MaskDownSampler
138
  kernel_size: 3
139
  stride: 2
140
  padding: 1
141
  fuser:
142
- _target_: sam2.modeling.memory_encoder.Fuser
143
  layer:
144
- _target_: sam2.modeling.memory_encoder.CXBlock
145
  dim: 256
146
  kernel_size: 7
147
  padding: 3
@@ -325,7 +317,7 @@ trainer:
325
 
326
  state_dict:
327
  _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
328
- checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
329
  ckpt_state_dict_keys: ['model']
330
 
331
  launcher:
 
11
  phases_per_epoch: 1
12
  num_epochs: 40
13
 
 
 
 
 
 
 
 
 
14
 
15
  # Video transforms
16
  vos:
 
61
  unfreeze_decoder: False
62
 
63
  model:
64
+ _target_: training.model.bboxmaskpose.sam2.sam2Train
65
  image_encoder:
66
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
67
  scalp: 1
68
  trunk:
69
+ _target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
70
  embed_dim: 112
71
  num_heads: 2
72
  drop_path_rate: 0.1
73
  neck:
74
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
75
  position_encoding:
76
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
77
  num_pos_feats: 256
78
  normalize: true
79
  scale: null
 
84
  fpn_interp_model: nearest
85
 
86
  memory_attention:
87
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
88
  d_model: 256
89
  pos_enc_at_input: true
90
  layer:
91
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
92
  activation: relu
93
  dim_feedforward: 2048
94
  dropout: 0.1
95
  pos_enc_at_attn: false
96
  self_attention:
97
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
98
  rope_theta: 10000.0
99
  feat_sizes: [64, 64]
100
  embedding_dim: 256
 
105
  pos_enc_at_cross_attn_keys: true
106
  pos_enc_at_cross_attn_queries: false
107
  cross_attention:
108
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
109
  rope_theta: 10000.0
110
  feat_sizes: [64, 64]
111
  rope_k_repeat: True
 
117
  num_layers: 4
118
 
119
  memory_encoder:
120
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
121
  out_dim: 64
122
  position_encoding:
123
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
124
  num_pos_feats: 64
125
  normalize: true
126
  scale: null
127
  temperature: 10000
128
  mask_downsampler:
129
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
130
  kernel_size: 3
131
  stride: 2
132
  padding: 1
133
  fuser:
134
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
135
  layer:
136
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
137
  dim: 256
138
  kernel_size: 7
139
  padding: 3
 
317
 
318
  state_dict:
319
  _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
320
+ checkpoint_path: ./checkpoints/bboxmaskpose.sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
321
  ckpt_state_dict_keys: ['model']
322
 
323
  launcher:
{sam2 → bboxmaskpose/sam2}/configs/sam2.1_training/sam2.1_hiera_b+_COCO_finetune.yaml RENAMED
@@ -11,15 +11,6 @@ scratch:
11
  phases_per_epoch: 1
12
  num_epochs: 40
13
 
14
- dataset:
15
- # PATHS to Dataset
16
- img_folder: /mnt/personal/purkrmir/data/COCO/original/train2017/ # PATH to MOSE JPEGImages folder
17
- gt_folder: /mnt/personal/purkrmir/data/COCO/original/annotations/ # PATH to MOSE Annotations folder
18
- # img_folder: /datagrid/personal/purkrmir/data/COCO/original/val2017/ # PATH to MOSE JPEGImages folder
19
- # gt_folder: /datagrid/personal/purkrmir/data/COCO/original/annotations/ # PATH to MOSE Annotations folder
20
- file_list_txt: null # Optional PATH to filelist containing a subset of videos to be used for training
21
- multiplier: 2
22
-
23
  # Video transforms
24
  vos:
25
  train_transforms:
@@ -69,19 +60,19 @@ trainer:
69
  unfreeze_decoder: False
70
 
71
  model:
72
- _target_: training.model.sam2.SAM2Train
73
  image_encoder:
74
- _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
75
  scalp: 1
76
  trunk:
77
- _target_: sam2.modeling.backbones.hieradet.Hiera
78
  embed_dim: 112
79
  num_heads: 2
80
  drop_path_rate: 0.1
81
  neck:
82
- _target_: sam2.modeling.backbones.image_encoder.FpnNeck
83
  position_encoding:
84
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
85
  num_pos_feats: 256
86
  normalize: true
87
  scale: null
@@ -92,17 +83,17 @@ trainer:
92
  fpn_interp_model: nearest
93
 
94
  memory_attention:
95
- _target_: sam2.modeling.memory_attention.MemoryAttention
96
  d_model: 256
97
  pos_enc_at_input: true
98
  layer:
99
- _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
100
  activation: relu
101
  dim_feedforward: 2048
102
  dropout: 0.1
103
  pos_enc_at_attn: false
104
  self_attention:
105
- _target_: sam2.modeling.sam.transformer.RoPEAttention
106
  rope_theta: 10000.0
107
  feat_sizes: [64, 64]
108
  embedding_dim: 256
@@ -113,7 +104,7 @@ trainer:
113
  pos_enc_at_cross_attn_keys: true
114
  pos_enc_at_cross_attn_queries: false
115
  cross_attention:
116
- _target_: sam2.modeling.sam.transformer.RoPEAttention
117
  rope_theta: 10000.0
118
  feat_sizes: [64, 64]
119
  rope_k_repeat: True
@@ -125,23 +116,23 @@ trainer:
125
  num_layers: 4
126
 
127
  memory_encoder:
128
- _target_: sam2.modeling.memory_encoder.MemoryEncoder
129
  out_dim: 64
130
  position_encoding:
131
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
132
  num_pos_feats: 64
133
  normalize: true
134
  scale: null
135
  temperature: 10000
136
  mask_downsampler:
137
- _target_: sam2.modeling.memory_encoder.MaskDownSampler
138
  kernel_size: 3
139
  stride: 2
140
  padding: 1
141
  fuser:
142
- _target_: sam2.modeling.memory_encoder.Fuser
143
  layer:
144
- _target_: sam2.modeling.memory_encoder.CXBlock
145
  dim: 256
146
  kernel_size: 7
147
  padding: 3
@@ -325,7 +316,7 @@ trainer:
325
 
326
  state_dict:
327
  _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
328
- checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
329
  ckpt_state_dict_keys: ['model']
330
 
331
  launcher:
 
11
  phases_per_epoch: 1
12
  num_epochs: 40
13
 
 
 
 
 
 
 
 
 
 
14
  # Video transforms
15
  vos:
16
  train_transforms:
 
60
  unfreeze_decoder: False
61
 
62
  model:
63
+ _target_: training.model.bboxmaskpose.sam2.sam2Train
64
  image_encoder:
65
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
66
  scalp: 1
67
  trunk:
68
+ _target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
69
  embed_dim: 112
70
  num_heads: 2
71
  drop_path_rate: 0.1
72
  neck:
73
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
74
  position_encoding:
75
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
76
  num_pos_feats: 256
77
  normalize: true
78
  scale: null
 
83
  fpn_interp_model: nearest
84
 
85
  memory_attention:
86
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
87
  d_model: 256
88
  pos_enc_at_input: true
89
  layer:
90
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
91
  activation: relu
92
  dim_feedforward: 2048
93
  dropout: 0.1
94
  pos_enc_at_attn: false
95
  self_attention:
96
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
97
  rope_theta: 10000.0
98
  feat_sizes: [64, 64]
99
  embedding_dim: 256
 
104
  pos_enc_at_cross_attn_keys: true
105
  pos_enc_at_cross_attn_queries: false
106
  cross_attention:
107
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
108
  rope_theta: 10000.0
109
  feat_sizes: [64, 64]
110
  rope_k_repeat: True
 
116
  num_layers: 4
117
 
118
  memory_encoder:
119
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
120
  out_dim: 64
121
  position_encoding:
122
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
123
  num_pos_feats: 64
124
  normalize: true
125
  scale: null
126
  temperature: 10000
127
  mask_downsampler:
128
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
129
  kernel_size: 3
130
  stride: 2
131
  padding: 1
132
  fuser:
133
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
134
  layer:
135
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
136
  dim: 256
137
  kernel_size: 7
138
  padding: 3
 
316
 
317
  state_dict:
318
  _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
319
+ checkpoint_path: ./checkpoints/bboxmaskpose.sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
320
  ckpt_state_dict_keys: ['model']
321
 
322
  launcher:
{sam2 → bboxmaskpose/sam2}/configs/sam2.1_training/sam2.1_hiera_b+_COCO_finetune_prompt+decoder.yaml RENAMED
@@ -11,15 +11,6 @@ scratch:
11
  phases_per_epoch: 1
12
  num_epochs: 40
13
 
14
- dataset:
15
- # PATHS to Dataset
16
- img_folder: /mnt/personal/purkrmir/data/COCO/original/train2017/ # PATH to MOSE JPEGImages folder
17
- gt_folder: /mnt/personal/purkrmir/data/COCO/original/annotations/ # PATH to MOSE Annotations folder
18
- # img_folder: /datagrid/personal/purkrmir/data/COCO/original/train2017/ # PATH to MOSE JPEGImages folder
19
- # gt_folder: /datagrid/personal/purkrmir/data/COCO/original/annotations/ # PATH to MOSE Annotations folder
20
- file_list_txt: null # Optional PATH to filelist containing a subset of videos to be used for training
21
- multiplier: 2
22
-
23
  # Video transforms
24
  vos:
25
  train_transforms:
@@ -69,19 +60,19 @@ trainer:
69
  unfreeze_decoder: True
70
 
71
  model:
72
- _target_: training.model.sam2.SAM2Train
73
  image_encoder:
74
- _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
75
  scalp: 1
76
  trunk:
77
- _target_: sam2.modeling.backbones.hieradet.Hiera
78
  embed_dim: 112
79
  num_heads: 2
80
  drop_path_rate: 0.1
81
  neck:
82
- _target_: sam2.modeling.backbones.image_encoder.FpnNeck
83
  position_encoding:
84
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
85
  num_pos_feats: 256
86
  normalize: true
87
  scale: null
@@ -92,17 +83,17 @@ trainer:
92
  fpn_interp_model: nearest
93
 
94
  memory_attention:
95
- _target_: sam2.modeling.memory_attention.MemoryAttention
96
  d_model: 256
97
  pos_enc_at_input: true
98
  layer:
99
- _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
100
  activation: relu
101
  dim_feedforward: 2048
102
  dropout: 0.1
103
  pos_enc_at_attn: false
104
  self_attention:
105
- _target_: sam2.modeling.sam.transformer.RoPEAttention
106
  rope_theta: 10000.0
107
  feat_sizes: [64, 64]
108
  embedding_dim: 256
@@ -113,7 +104,7 @@ trainer:
113
  pos_enc_at_cross_attn_keys: true
114
  pos_enc_at_cross_attn_queries: false
115
  cross_attention:
116
- _target_: sam2.modeling.sam.transformer.RoPEAttention
117
  rope_theta: 10000.0
118
  feat_sizes: [64, 64]
119
  rope_k_repeat: True
@@ -125,23 +116,23 @@ trainer:
125
  num_layers: 4
126
 
127
  memory_encoder:
128
- _target_: sam2.modeling.memory_encoder.MemoryEncoder
129
  out_dim: 64
130
  position_encoding:
131
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
132
  num_pos_feats: 64
133
  normalize: true
134
  scale: null
135
  temperature: 10000
136
  mask_downsampler:
137
- _target_: sam2.modeling.memory_encoder.MaskDownSampler
138
  kernel_size: 3
139
  stride: 2
140
  padding: 1
141
  fuser:
142
- _target_: sam2.modeling.memory_encoder.Fuser
143
  layer:
144
- _target_: sam2.modeling.memory_encoder.CXBlock
145
  dim: 256
146
  kernel_size: 7
147
  padding: 3
@@ -325,7 +316,7 @@ trainer:
325
 
326
  state_dict:
327
  _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
328
- checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
329
  ckpt_state_dict_keys: ['model']
330
 
331
  launcher:
 
11
  phases_per_epoch: 1
12
  num_epochs: 40
13
 
 
 
 
 
 
 
 
 
 
14
  # Video transforms
15
  vos:
16
  train_transforms:
 
60
  unfreeze_decoder: True
61
 
62
  model:
63
+ _target_: training.model.bboxmaskpose.sam2.sam2Train
64
  image_encoder:
65
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
66
  scalp: 1
67
  trunk:
68
+ _target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
69
  embed_dim: 112
70
  num_heads: 2
71
  drop_path_rate: 0.1
72
  neck:
73
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
74
  position_encoding:
75
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
76
  num_pos_feats: 256
77
  normalize: true
78
  scale: null
 
83
  fpn_interp_model: nearest
84
 
85
  memory_attention:
86
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
87
  d_model: 256
88
  pos_enc_at_input: true
89
  layer:
90
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
91
  activation: relu
92
  dim_feedforward: 2048
93
  dropout: 0.1
94
  pos_enc_at_attn: false
95
  self_attention:
96
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
97
  rope_theta: 10000.0
98
  feat_sizes: [64, 64]
99
  embedding_dim: 256
 
104
  pos_enc_at_cross_attn_keys: true
105
  pos_enc_at_cross_attn_queries: false
106
  cross_attention:
107
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
108
  rope_theta: 10000.0
109
  feat_sizes: [64, 64]
110
  rope_k_repeat: True
 
116
  num_layers: 4
117
 
118
  memory_encoder:
119
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
120
  out_dim: 64
121
  position_encoding:
122
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
123
  num_pos_feats: 64
124
  normalize: true
125
  scale: null
126
  temperature: 10000
127
  mask_downsampler:
128
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
129
  kernel_size: 3
130
  stride: 2
131
  padding: 1
132
  fuser:
133
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
134
  layer:
135
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
136
  dim: 256
137
  kernel_size: 7
138
  padding: 3
 
316
 
317
  state_dict:
318
  _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
319
+ checkpoint_path: ./checkpoints/bboxmaskpose.sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
320
  ckpt_state_dict_keys: ['model']
321
 
322
  launcher:
{sam2 → bboxmaskpose/sam2}/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml RENAMED
@@ -11,12 +11,6 @@ scratch:
11
  phases_per_epoch: 1
12
  num_epochs: 40
13
 
14
- dataset:
15
- # PATHS to Dataset
16
- img_folder: /datagrid/personal/purkrmir/data/MOSE/train/JPEGImages/ # PATH to MOSE JPEGImages folder
17
- gt_folder: /datagrid/personal/purkrmir/data/MOSE/train/Annotations/ # PATH to MOSE Annotations folder
18
- file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training
19
- multiplier: 2
20
 
21
  # Video transforms
22
  vos:
@@ -62,19 +56,19 @@ trainer:
62
  seed_value: 123
63
 
64
  model:
65
- _target_: training.model.sam2.SAM2Train
66
  image_encoder:
67
- _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
68
  scalp: 1
69
  trunk:
70
- _target_: sam2.modeling.backbones.hieradet.Hiera
71
  embed_dim: 112
72
  num_heads: 2
73
  drop_path_rate: 0.1
74
  neck:
75
- _target_: sam2.modeling.backbones.image_encoder.FpnNeck
76
  position_encoding:
77
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
78
  num_pos_feats: 256
79
  normalize: true
80
  scale: null
@@ -85,17 +79,17 @@ trainer:
85
  fpn_interp_model: nearest
86
 
87
  memory_attention:
88
- _target_: sam2.modeling.memory_attention.MemoryAttention
89
  d_model: 256
90
  pos_enc_at_input: true
91
  layer:
92
- _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
93
  activation: relu
94
  dim_feedforward: 2048
95
  dropout: 0.1
96
  pos_enc_at_attn: false
97
  self_attention:
98
- _target_: sam2.modeling.sam.transformer.RoPEAttention
99
  rope_theta: 10000.0
100
  feat_sizes: [64, 64]
101
  embedding_dim: 256
@@ -106,7 +100,7 @@ trainer:
106
  pos_enc_at_cross_attn_keys: true
107
  pos_enc_at_cross_attn_queries: false
108
  cross_attention:
109
- _target_: sam2.modeling.sam.transformer.RoPEAttention
110
  rope_theta: 10000.0
111
  feat_sizes: [64, 64]
112
  rope_k_repeat: True
@@ -118,23 +112,23 @@ trainer:
118
  num_layers: 4
119
 
120
  memory_encoder:
121
- _target_: sam2.modeling.memory_encoder.MemoryEncoder
122
  out_dim: 64
123
  position_encoding:
124
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
125
  num_pos_feats: 64
126
  normalize: true
127
  scale: null
128
  temperature: 10000
129
  mask_downsampler:
130
- _target_: sam2.modeling.memory_encoder.MaskDownSampler
131
  kernel_size: 3
132
  stride: 2
133
  padding: 1
134
  fuser:
135
- _target_: sam2.modeling.memory_encoder.Fuser
136
  layer:
137
- _target_: sam2.modeling.memory_encoder.CXBlock
138
  dim: 256
139
  kernel_size: 7
140
  padding: 3
@@ -318,7 +312,7 @@ trainer:
318
 
319
  state_dict:
320
  _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
321
- checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
322
  ckpt_state_dict_keys: ['model']
323
 
324
  launcher:
 
11
  phases_per_epoch: 1
12
  num_epochs: 40
13
 
 
 
 
 
 
 
14
 
15
  # Video transforms
16
  vos:
 
56
  seed_value: 123
57
 
58
  model:
59
+ _target_: training.model.bboxmaskpose.sam2.sam2Train
60
  image_encoder:
61
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
62
  scalp: 1
63
  trunk:
64
+ _target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
65
  embed_dim: 112
66
  num_heads: 2
67
  drop_path_rate: 0.1
68
  neck:
69
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
70
  position_encoding:
71
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
72
  num_pos_feats: 256
73
  normalize: true
74
  scale: null
 
79
  fpn_interp_model: nearest
80
 
81
  memory_attention:
82
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
83
  d_model: 256
84
  pos_enc_at_input: true
85
  layer:
86
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
87
  activation: relu
88
  dim_feedforward: 2048
89
  dropout: 0.1
90
  pos_enc_at_attn: false
91
  self_attention:
92
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
93
  rope_theta: 10000.0
94
  feat_sizes: [64, 64]
95
  embedding_dim: 256
 
100
  pos_enc_at_cross_attn_keys: true
101
  pos_enc_at_cross_attn_queries: false
102
  cross_attention:
103
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
104
  rope_theta: 10000.0
105
  feat_sizes: [64, 64]
106
  rope_k_repeat: True
 
112
  num_layers: 4
113
 
114
  memory_encoder:
115
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
116
  out_dim: 64
117
  position_encoding:
118
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
119
  num_pos_feats: 64
120
  normalize: true
121
  scale: null
122
  temperature: 10000
123
  mask_downsampler:
124
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
125
  kernel_size: 3
126
  stride: 2
127
  padding: 1
128
  fuser:
129
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
130
  layer:
131
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
132
  dim: 256
133
  kernel_size: 7
134
  padding: 3
 
312
 
313
  state_dict:
314
  _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
315
+ checkpoint_path: ./checkpoints/bboxmaskpose.sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
316
  ckpt_state_dict_keys: ['model']
317
 
318
  launcher:
{sam2 → bboxmaskpose/sam2}/configs/sam2/sam2_hiera_b+.yaml RENAMED
@@ -2,18 +2,18 @@
2
 
3
  # Model
4
  model:
5
- _target_: sam2.modeling.sam2_base.SAM2Base
6
  image_encoder:
7
- _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
  scalp: 1
9
  trunk:
10
- _target_: sam2.modeling.backbones.hieradet.Hiera
11
  embed_dim: 112
12
  num_heads: 2
13
  neck:
14
- _target_: sam2.modeling.backbones.image_encoder.FpnNeck
15
  position_encoding:
16
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
17
  num_pos_feats: 256
18
  normalize: true
19
  scale: null
@@ -24,17 +24,17 @@ model:
24
  fpn_interp_model: nearest
25
 
26
  memory_attention:
27
- _target_: sam2.modeling.memory_attention.MemoryAttention
28
  d_model: 256
29
  pos_enc_at_input: true
30
  layer:
31
- _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
32
  activation: relu
33
  dim_feedforward: 2048
34
  dropout: 0.1
35
  pos_enc_at_attn: false
36
  self_attention:
37
- _target_: sam2.modeling.sam.transformer.RoPEAttention
38
  rope_theta: 10000.0
39
  feat_sizes: [32, 32]
40
  embedding_dim: 256
@@ -45,7 +45,7 @@ model:
45
  pos_enc_at_cross_attn_keys: true
46
  pos_enc_at_cross_attn_queries: false
47
  cross_attention:
48
- _target_: sam2.modeling.sam.transformer.RoPEAttention
49
  rope_theta: 10000.0
50
  feat_sizes: [32, 32]
51
  rope_k_repeat: True
@@ -57,23 +57,23 @@ model:
57
  num_layers: 4
58
 
59
  memory_encoder:
60
- _target_: sam2.modeling.memory_encoder.MemoryEncoder
61
  out_dim: 64
62
  position_encoding:
63
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
64
  num_pos_feats: 64
65
  normalize: true
66
  scale: null
67
  temperature: 10000
68
  mask_downsampler:
69
- _target_: sam2.modeling.memory_encoder.MaskDownSampler
70
  kernel_size: 3
71
  stride: 2
72
  padding: 1
73
  fuser:
74
- _target_: sam2.modeling.memory_encoder.Fuser
75
  layer:
76
- _target_: sam2.modeling.memory_encoder.CXBlock
77
  dim: 256
78
  kernel_size: 7
79
  padding: 3
 
2
 
3
  # Model
4
  model:
5
+ _target_: bboxmaskpose.sam2.modeling.sam2_base.SAM2Base
6
  image_encoder:
7
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
8
  scalp: 1
9
  trunk:
10
+ _target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
11
  embed_dim: 112
12
  num_heads: 2
13
  neck:
14
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
15
  position_encoding:
16
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
17
  num_pos_feats: 256
18
  normalize: true
19
  scale: null
 
24
  fpn_interp_model: nearest
25
 
26
  memory_attention:
27
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
28
  d_model: 256
29
  pos_enc_at_input: true
30
  layer:
31
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
32
  activation: relu
33
  dim_feedforward: 2048
34
  dropout: 0.1
35
  pos_enc_at_attn: false
36
  self_attention:
37
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
38
  rope_theta: 10000.0
39
  feat_sizes: [32, 32]
40
  embedding_dim: 256
 
45
  pos_enc_at_cross_attn_keys: true
46
  pos_enc_at_cross_attn_queries: false
47
  cross_attention:
48
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
49
  rope_theta: 10000.0
50
  feat_sizes: [32, 32]
51
  rope_k_repeat: True
 
57
  num_layers: 4
58
 
59
  memory_encoder:
60
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
61
  out_dim: 64
62
  position_encoding:
63
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
64
  num_pos_feats: 64
65
  normalize: true
66
  scale: null
67
  temperature: 10000
68
  mask_downsampler:
69
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
70
  kernel_size: 3
71
  stride: 2
72
  padding: 1
73
  fuser:
74
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
75
  layer:
76
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
77
  dim: 256
78
  kernel_size: 7
79
  padding: 3
{sam2 → bboxmaskpose/sam2}/configs/sam2/sam2_hiera_l.yaml RENAMED
@@ -2,12 +2,12 @@
2
 
3
  # Model
4
  model:
5
- _target_: sam2.modeling.sam2_base.SAM2Base
6
  image_encoder:
7
- _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
  scalp: 1
9
  trunk:
10
- _target_: sam2.modeling.backbones.hieradet.Hiera
11
  embed_dim: 144
12
  num_heads: 2
13
  stages: [2, 6, 36, 4]
@@ -15,9 +15,9 @@ model:
15
  window_pos_embed_bkg_spatial_size: [7, 7]
16
  window_spec: [8, 4, 16, 8]
17
  neck:
18
- _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19
  position_encoding:
20
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21
  num_pos_feats: 256
22
  normalize: true
23
  scale: null
@@ -28,17 +28,17 @@ model:
28
  fpn_interp_model: nearest
29
 
30
  memory_attention:
31
- _target_: sam2.modeling.memory_attention.MemoryAttention
32
  d_model: 256
33
  pos_enc_at_input: true
34
  layer:
35
- _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36
  activation: relu
37
  dim_feedforward: 2048
38
  dropout: 0.1
39
  pos_enc_at_attn: false
40
  self_attention:
41
- _target_: sam2.modeling.sam.transformer.RoPEAttention
42
  rope_theta: 10000.0
43
  feat_sizes: [32, 32]
44
  embedding_dim: 256
@@ -49,7 +49,7 @@ model:
49
  pos_enc_at_cross_attn_keys: true
50
  pos_enc_at_cross_attn_queries: false
51
  cross_attention:
52
- _target_: sam2.modeling.sam.transformer.RoPEAttention
53
  rope_theta: 10000.0
54
  feat_sizes: [32, 32]
55
  rope_k_repeat: True
@@ -61,23 +61,23 @@ model:
61
  num_layers: 4
62
 
63
  memory_encoder:
64
- _target_: sam2.modeling.memory_encoder.MemoryEncoder
65
  out_dim: 64
66
  position_encoding:
67
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68
  num_pos_feats: 64
69
  normalize: true
70
  scale: null
71
  temperature: 10000
72
  mask_downsampler:
73
- _target_: sam2.modeling.memory_encoder.MaskDownSampler
74
  kernel_size: 3
75
  stride: 2
76
  padding: 1
77
  fuser:
78
- _target_: sam2.modeling.memory_encoder.Fuser
79
  layer:
80
- _target_: sam2.modeling.memory_encoder.CXBlock
81
  dim: 256
82
  kernel_size: 7
83
  padding: 3
 
2
 
3
  # Model
4
  model:
5
+ _target_: bboxmaskpose.sam2.modeling.sam2_base.SAM2Base
6
  image_encoder:
7
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
8
  scalp: 1
9
  trunk:
10
+ _target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
11
  embed_dim: 144
12
  num_heads: 2
13
  stages: [2, 6, 36, 4]
 
15
  window_pos_embed_bkg_spatial_size: [7, 7]
16
  window_spec: [8, 4, 16, 8]
17
  neck:
18
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
19
  position_encoding:
20
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
21
  num_pos_feats: 256
22
  normalize: true
23
  scale: null
 
28
  fpn_interp_model: nearest
29
 
30
  memory_attention:
31
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
32
  d_model: 256
33
  pos_enc_at_input: true
34
  layer:
35
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
36
  activation: relu
37
  dim_feedforward: 2048
38
  dropout: 0.1
39
  pos_enc_at_attn: false
40
  self_attention:
41
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
42
  rope_theta: 10000.0
43
  feat_sizes: [32, 32]
44
  embedding_dim: 256
 
49
  pos_enc_at_cross_attn_keys: true
50
  pos_enc_at_cross_attn_queries: false
51
  cross_attention:
52
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
53
  rope_theta: 10000.0
54
  feat_sizes: [32, 32]
55
  rope_k_repeat: True
 
61
  num_layers: 4
62
 
63
  memory_encoder:
64
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
65
  out_dim: 64
66
  position_encoding:
67
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
68
  num_pos_feats: 64
69
  normalize: true
70
  scale: null
71
  temperature: 10000
72
  mask_downsampler:
73
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
74
  kernel_size: 3
75
  stride: 2
76
  padding: 1
77
  fuser:
78
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
79
  layer:
80
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
81
  dim: 256
82
  kernel_size: 7
83
  padding: 3
{sam2 → bboxmaskpose/sam2}/configs/sam2/sam2_hiera_s.yaml RENAMED
@@ -2,21 +2,21 @@
2
 
3
  # Model
4
  model:
5
- _target_: sam2.modeling.sam2_base.SAM2Base
6
  image_encoder:
7
- _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
  scalp: 1
9
  trunk:
10
- _target_: sam2.modeling.backbones.hieradet.Hiera
11
  embed_dim: 96
12
  num_heads: 1
13
  stages: [1, 2, 11, 2]
14
  global_att_blocks: [7, 10, 13]
15
  window_pos_embed_bkg_spatial_size: [7, 7]
16
  neck:
17
- _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
  position_encoding:
19
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
  num_pos_feats: 256
21
  normalize: true
22
  scale: null
@@ -27,17 +27,17 @@ model:
27
  fpn_interp_model: nearest
28
 
29
  memory_attention:
30
- _target_: sam2.modeling.memory_attention.MemoryAttention
31
  d_model: 256
32
  pos_enc_at_input: true
33
  layer:
34
- _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
  activation: relu
36
  dim_feedforward: 2048
37
  dropout: 0.1
38
  pos_enc_at_attn: false
39
  self_attention:
40
- _target_: sam2.modeling.sam.transformer.RoPEAttention
41
  rope_theta: 10000.0
42
  feat_sizes: [32, 32]
43
  embedding_dim: 256
@@ -48,7 +48,7 @@ model:
48
  pos_enc_at_cross_attn_keys: true
49
  pos_enc_at_cross_attn_queries: false
50
  cross_attention:
51
- _target_: sam2.modeling.sam.transformer.RoPEAttention
52
  rope_theta: 10000.0
53
  feat_sizes: [32, 32]
54
  rope_k_repeat: True
@@ -60,23 +60,23 @@ model:
60
  num_layers: 4
61
 
62
  memory_encoder:
63
- _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
  out_dim: 64
65
  position_encoding:
66
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
  num_pos_feats: 64
68
  normalize: true
69
  scale: null
70
  temperature: 10000
71
  mask_downsampler:
72
- _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
  kernel_size: 3
74
  stride: 2
75
  padding: 1
76
  fuser:
77
- _target_: sam2.modeling.memory_encoder.Fuser
78
  layer:
79
- _target_: sam2.modeling.memory_encoder.CXBlock
80
  dim: 256
81
  kernel_size: 7
82
  padding: 3
 
2
 
3
  # Model
4
  model:
5
+ _target_: bboxmaskpose.sam2.modeling.sam2_base.SAM2Base
6
  image_encoder:
7
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
8
  scalp: 1
9
  trunk:
10
+ _target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
11
  embed_dim: 96
12
  num_heads: 1
13
  stages: [1, 2, 11, 2]
14
  global_att_blocks: [7, 10, 13]
15
  window_pos_embed_bkg_spatial_size: [7, 7]
16
  neck:
17
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
18
  position_encoding:
19
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
20
  num_pos_feats: 256
21
  normalize: true
22
  scale: null
 
27
  fpn_interp_model: nearest
28
 
29
  memory_attention:
30
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
31
  d_model: 256
32
  pos_enc_at_input: true
33
  layer:
34
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
35
  activation: relu
36
  dim_feedforward: 2048
37
  dropout: 0.1
38
  pos_enc_at_attn: false
39
  self_attention:
40
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
41
  rope_theta: 10000.0
42
  feat_sizes: [32, 32]
43
  embedding_dim: 256
 
48
  pos_enc_at_cross_attn_keys: true
49
  pos_enc_at_cross_attn_queries: false
50
  cross_attention:
51
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
52
  rope_theta: 10000.0
53
  feat_sizes: [32, 32]
54
  rope_k_repeat: True
 
60
  num_layers: 4
61
 
62
  memory_encoder:
63
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
64
  out_dim: 64
65
  position_encoding:
66
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
67
  num_pos_feats: 64
68
  normalize: true
69
  scale: null
70
  temperature: 10000
71
  mask_downsampler:
72
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
73
  kernel_size: 3
74
  stride: 2
75
  padding: 1
76
  fuser:
77
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
78
  layer:
79
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
80
  dim: 256
81
  kernel_size: 7
82
  padding: 3
{sam2 → bboxmaskpose/sam2}/configs/sam2/sam2_hiera_t.yaml RENAMED
@@ -2,21 +2,21 @@
2
 
3
  # Model
4
  model:
5
- _target_: sam2.modeling.sam2_base.SAM2Base
6
  image_encoder:
7
- _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
  scalp: 1
9
  trunk:
10
- _target_: sam2.modeling.backbones.hieradet.Hiera
11
  embed_dim: 96
12
  num_heads: 1
13
  stages: [1, 2, 7, 2]
14
  global_att_blocks: [5, 7, 9]
15
  window_pos_embed_bkg_spatial_size: [7, 7]
16
  neck:
17
- _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
  position_encoding:
19
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
  num_pos_feats: 256
21
  normalize: true
22
  scale: null
@@ -27,17 +27,17 @@ model:
27
  fpn_interp_model: nearest
28
 
29
  memory_attention:
30
- _target_: sam2.modeling.memory_attention.MemoryAttention
31
  d_model: 256
32
  pos_enc_at_input: true
33
  layer:
34
- _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
  activation: relu
36
  dim_feedforward: 2048
37
  dropout: 0.1
38
  pos_enc_at_attn: false
39
  self_attention:
40
- _target_: sam2.modeling.sam.transformer.RoPEAttention
41
  rope_theta: 10000.0
42
  feat_sizes: [32, 32]
43
  embedding_dim: 256
@@ -48,7 +48,7 @@ model:
48
  pos_enc_at_cross_attn_keys: true
49
  pos_enc_at_cross_attn_queries: false
50
  cross_attention:
51
- _target_: sam2.modeling.sam.transformer.RoPEAttention
52
  rope_theta: 10000.0
53
  feat_sizes: [32, 32]
54
  rope_k_repeat: True
@@ -60,23 +60,23 @@ model:
60
  num_layers: 4
61
 
62
  memory_encoder:
63
- _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
  out_dim: 64
65
  position_encoding:
66
- _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
  num_pos_feats: 64
68
  normalize: true
69
  scale: null
70
  temperature: 10000
71
  mask_downsampler:
72
- _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
  kernel_size: 3
74
  stride: 2
75
  padding: 1
76
  fuser:
77
- _target_: sam2.modeling.memory_encoder.Fuser
78
  layer:
79
- _target_: sam2.modeling.memory_encoder.CXBlock
80
  dim: 256
81
  kernel_size: 7
82
  padding: 3
 
2
 
3
  # Model
4
  model:
5
+ _target_: bboxmaskpose.sam2.modeling.sam2_base.SAM2Base
6
  image_encoder:
7
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.ImageEncoder
8
  scalp: 1
9
  trunk:
10
+ _target_: bboxmaskpose.sam2.modeling.backbones.hieradet.Hiera
11
  embed_dim: 96
12
  num_heads: 1
13
  stages: [1, 2, 7, 2]
14
  global_att_blocks: [5, 7, 9]
15
  window_pos_embed_bkg_spatial_size: [7, 7]
16
  neck:
17
+ _target_: bboxmaskpose.sam2.modeling.backbones.image_encoder.FpnNeck
18
  position_encoding:
19
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
20
  num_pos_feats: 256
21
  normalize: true
22
  scale: null
 
27
  fpn_interp_model: nearest
28
 
29
  memory_attention:
30
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttention
31
  d_model: 256
32
  pos_enc_at_input: true
33
  layer:
34
+ _target_: bboxmaskpose.sam2.modeling.memory_attention.MemoryAttentionLayer
35
  activation: relu
36
  dim_feedforward: 2048
37
  dropout: 0.1
38
  pos_enc_at_attn: false
39
  self_attention:
40
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
41
  rope_theta: 10000.0
42
  feat_sizes: [32, 32]
43
  embedding_dim: 256
 
48
  pos_enc_at_cross_attn_keys: true
49
  pos_enc_at_cross_attn_queries: false
50
  cross_attention:
51
+ _target_: bboxmaskpose.sam2.modeling.sam.transformer.RoPEAttention
52
  rope_theta: 10000.0
53
  feat_sizes: [32, 32]
54
  rope_k_repeat: True
 
60
  num_layers: 4
61
 
62
  memory_encoder:
63
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MemoryEncoder
64
  out_dim: 64
65
  position_encoding:
66
+ _target_: bboxmaskpose.sam2.modeling.position_encoding.PositionEmbeddingSine
67
  num_pos_feats: 64
68
  normalize: true
69
  scale: null
70
  temperature: 10000
71
  mask_downsampler:
72
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.MaskDownSampler
73
  kernel_size: 3
74
  stride: 2
75
  padding: 1
76
  fuser:
77
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.Fuser
78
  layer:
79
+ _target_: bboxmaskpose.sam2.modeling.memory_encoder.CXBlock
80
  dim: 256
81
  kernel_size: 7
82
  padding: 3
{sam2 → bboxmaskpose/sam2}/csrc/connected_components.cu RENAMED
File without changes
{sam2 → bboxmaskpose/sam2}/distinctipy.py RENAMED
@@ -1,3 +1,5 @@
 
 
1
  import math
2
  import random
3
 
@@ -125,9 +127,7 @@ def color_distance(c1, c2):
125
  return distance
126
 
127
 
128
- def distinct_color(
129
- exclude_colors, pastel_factor=0.0, n_attempts=1000, colorblind_type=None, rng=None
130
- ):
131
  """
132
  Generate a colour as distinct as possible from the colours defined in exclude_colors
133
  Inspired by: https://gist.github.com/adewes/5884820
@@ -164,10 +164,7 @@ def distinct_color(
164
  return get_random_color(pastel_factor=pastel_factor, rng=rng)
165
 
166
  if colorblind_type:
167
- exclude_colors = [
168
- colorblind.colorblind_filter(color, colorblind_type)
169
- for color in exclude_colors
170
- ]
171
 
172
  max_distance = None
173
  best_color = None
@@ -181,9 +178,7 @@ def distinct_color(
181
  else:
182
  compare_color = color
183
 
184
- distance_to_nearest = min(
185
- [color_distance(compare_color, c) for c in exclude_colors]
186
- )
187
 
188
  if (not max_distance) or (distance_to_nearest > max_distance):
189
  max_distance = distance_to_nearest
@@ -202,9 +197,7 @@ def distinct_color(
202
  else:
203
  compare_color = color
204
 
205
- distance_to_nearest = min(
206
- [color_distance(compare_color, c) for c in exclude_colors]
207
- )
208
 
209
  if (not max_distance) or (distance_to_nearest > max_distance):
210
  max_distance = distance_to_nearest
@@ -500,4 +493,4 @@ def get_colormap(list_of_colors, name="distinctipy"):
500
 
501
  cmap = matplotlib.colors.ListedColormap(list_of_colors, name=name)
502
 
503
- return cmap
 
1
+ # Adapted from the distinctipy repository (https://github.com/alan-turing-institute/distinctipy).
2
+ # Original authors: distinctipy contributors. Included with minor modifications.
3
  import math
4
  import random
5
 
 
127
  return distance
128
 
129
 
130
+ def distinct_color(exclude_colors, pastel_factor=0.0, n_attempts=1000, colorblind_type=None, rng=None):
 
 
131
  """
132
  Generate a colour as distinct as possible from the colours defined in exclude_colors
133
  Inspired by: https://gist.github.com/adewes/5884820
 
164
  return get_random_color(pastel_factor=pastel_factor, rng=rng)
165
 
166
  if colorblind_type:
167
+ exclude_colors = [colorblind.colorblind_filter(color, colorblind_type) for color in exclude_colors]
 
 
 
168
 
169
  max_distance = None
170
  best_color = None
 
178
  else:
179
  compare_color = color
180
 
181
+ distance_to_nearest = min([color_distance(compare_color, c) for c in exclude_colors])
 
 
182
 
183
  if (not max_distance) or (distance_to_nearest > max_distance):
184
  max_distance = distance_to_nearest
 
197
  else:
198
  compare_color = color
199
 
200
+ distance_to_nearest = min([color_distance(compare_color, c) for c in exclude_colors])
 
 
201
 
202
  if (not max_distance) or (distance_to_nearest > max_distance):
203
  max_distance = distance_to_nearest
 
493
 
494
  cmap = matplotlib.colors.ListedColormap(list_of_colors, name=name)
495
 
496
+ return cmap
{sam2 → bboxmaskpose/sam2}/modeling/__init__.py RENAMED
File without changes
{sam2 → bboxmaskpose/sam2}/modeling/backbones/__init__.py RENAMED
File without changes
{sam2 → bboxmaskpose/sam2}/modeling/backbones/hieradet.py RENAMED
@@ -11,15 +11,10 @@ from typing import List, Tuple, Union
11
  import torch
12
  import torch.nn as nn
13
  import torch.nn.functional as F
14
- from iopath.common.file_io import g_pathmgr
15
-
16
- from sam2.modeling.backbones.utils import (
17
- PatchEmbed,
18
- window_partition,
19
- window_unpartition,
20
- )
21
 
22
- from sam2.modeling.sam2_utils import DropPath, MLP
 
 
23
 
24
 
25
  def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
@@ -107,9 +102,7 @@ class MultiScaleBlock(nn.Module):
107
 
108
  self.pool, self.q_stride = None, q_stride
109
  if self.q_stride:
110
- self.pool = nn.MaxPool2d(
111
- kernel_size=q_stride, stride=q_stride, ceil_mode=False
112
- )
113
 
114
  self.attn = MultiScaleAttention(
115
  dim,
@@ -218,16 +211,10 @@ class Hiera(nn.Module):
218
 
219
  # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
220
  self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
221
- self.pos_embed = nn.Parameter(
222
- torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
223
- )
224
- self.pos_embed_window = nn.Parameter(
225
- torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
226
- )
227
 
228
- dpr = [
229
- x.item() for x in torch.linspace(0, drop_path_rate, depth)
230
- ] # stochastic depth decay rule
231
 
232
  cur_stage = 1
233
  self.blocks = nn.ModuleList()
@@ -259,11 +246,7 @@ class Hiera(nn.Module):
259
  embed_dim = dim_out
260
  self.blocks.append(block)
261
 
262
- self.channel_list = (
263
- [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
264
- if return_interm_layers
265
- else [self.blocks[-1].dim_out]
266
- )
267
 
268
  if weights_path is not None:
269
  with g_pathmgr.open(weights_path, "rb") as f:
@@ -274,9 +257,7 @@ class Hiera(nn.Module):
274
  h, w = hw
275
  window_embed = self.pos_embed_window
276
  pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
277
- pos_embed = pos_embed + window_embed.tile(
278
- [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
279
- )
280
  pos_embed = pos_embed.permute(0, 2, 3, 1)
281
  return pos_embed
282
 
@@ -290,9 +271,7 @@ class Hiera(nn.Module):
290
  outputs = []
291
  for i, blk in enumerate(self.blocks):
292
  x = blk(x)
293
- if (i == self.stage_ends[-1]) or (
294
- i in self.stage_ends and self.return_interm_layers
295
- ):
296
  feats = x.permute(0, 3, 1, 2)
297
  outputs.append(feats)
298
 
 
11
  import torch
12
  import torch.nn as nn
13
  import torch.nn.functional as F
 
 
 
 
 
 
 
14
 
15
+ from bboxmaskpose.sam2.modeling.backbones.utils import PatchEmbed, window_partition, window_unpartition
16
+ from bboxmaskpose.sam2.modeling.sam2_utils import MLP, DropPath
17
+ from iopath.common.file_io import g_pathmgr
18
 
19
 
20
  def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
 
102
 
103
  self.pool, self.q_stride = None, q_stride
104
  if self.q_stride:
105
+ self.pool = nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False)
 
 
106
 
107
  self.attn = MultiScaleAttention(
108
  dim,
 
211
 
212
  # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
213
  self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
214
+ self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size))
215
+ self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))
 
 
 
 
216
 
217
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
 
 
218
 
219
  cur_stage = 1
220
  self.blocks = nn.ModuleList()
 
246
  embed_dim = dim_out
247
  self.blocks.append(block)
248
 
249
+ self.channel_list = [self.blocks[i].dim_out for i in self.stage_ends[::-1]] if return_interm_layers else [self.blocks[-1].dim_out]
 
 
 
 
250
 
251
  if weights_path is not None:
252
  with g_pathmgr.open(weights_path, "rb") as f:
 
257
  h, w = hw
258
  window_embed = self.pos_embed_window
259
  pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
260
+ pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
 
 
261
  pos_embed = pos_embed.permute(0, 2, 3, 1)
262
  return pos_embed
263
 
 
271
  outputs = []
272
  for i, blk in enumerate(self.blocks):
273
  x = blk(x)
274
+ if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers):
 
 
275
  feats = x.permute(0, 3, 1, 2)
276
  outputs.append(feats)
277
 
{sam2 → bboxmaskpose/sam2}/modeling/backbones/image_encoder.py RENAMED
@@ -117,9 +117,7 @@ class FpnNeck(nn.Module):
117
  prev_features.to(dtype=torch.float32),
118
  scale_factor=2.0,
119
  mode=self.fpn_interp_model,
120
- align_corners=(
121
- None if self.fpn_interp_model == "nearest" else False
122
- ),
123
  antialias=False,
124
  )
125
  prev_features = lateral_features + top_down_features
 
117
  prev_features.to(dtype=torch.float32),
118
  scale_factor=2.0,
119
  mode=self.fpn_interp_model,
120
+ align_corners=(None if self.fpn_interp_model == "nearest" else False),
 
 
121
  antialias=False,
122
  )
123
  prev_features = lateral_features + top_down_features
{sam2 → bboxmaskpose/sam2}/modeling/backbones/utils.py RENAMED
@@ -50,9 +50,7 @@ def window_unpartition(windows, window_size, pad_hw, hw):
50
  Hp, Wp = pad_hw
51
  H, W = hw
52
  B = windows.shape[0] // (Hp * Wp // window_size // window_size)
53
- x = windows.reshape(
54
- B, Hp // window_size, Wp // window_size, window_size, window_size, -1
55
- )
56
  x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
57
 
58
  if Hp > H or Wp > W:
@@ -82,9 +80,7 @@ class PatchEmbed(nn.Module):
82
  embed_dim (int): embed_dim (int): Patch embedding dimension.
83
  """
84
  super().__init__()
85
- self.proj = nn.Conv2d(
86
- in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
87
- )
88
 
89
  def forward(self, x: torch.Tensor) -> torch.Tensor:
90
  x = self.proj(x)
 
50
  Hp, Wp = pad_hw
51
  H, W = hw
52
  B = windows.shape[0] // (Hp * Wp // window_size // window_size)
53
+ x = windows.reshape(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
 
 
54
  x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
55
 
56
  if Hp > H or Wp > W:
 
80
  embed_dim (int): embed_dim (int): Patch embedding dimension.
81
  """
82
  super().__init__()
83
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
 
 
84
 
85
  def forward(self, x: torch.Tensor) -> torch.Tensor:
86
  x = self.proj(x)
{sam2 → bboxmaskpose/sam2}/modeling/memory_attention.py RENAMED
@@ -7,11 +7,10 @@
7
  from typing import Optional
8
 
9
  import torch
10
- from torch import nn, Tensor
11
 
12
- from sam2.modeling.sam.transformer import RoPEAttention
13
-
14
- from sam2.modeling.sam2_utils import get_activation_fn, get_clones
15
 
16
 
17
  class MemoryAttentionLayer(nn.Module):
@@ -132,9 +131,7 @@ class MemoryAttention(nn.Module):
132
  curr_pos[0],
133
  )
134
 
135
- assert (
136
- curr.shape[1] == memory.shape[1]
137
- ), "Batch size must be the same for curr and memory"
138
 
139
  output = curr
140
  if self.pos_enc_at_input and curr_pos is not None:
 
7
  from typing import Optional
8
 
9
  import torch
10
+ from torch import Tensor, nn
11
 
12
+ from bboxmaskpose.sam2.modeling.sam2_utils import get_activation_fn, get_clones
13
+ from bboxmaskpose.sam2.modeling.sam.transformer import RoPEAttention
 
14
 
15
 
16
  class MemoryAttentionLayer(nn.Module):
 
131
  curr_pos[0],
132
  )
133
 
134
+ assert curr.shape[1] == memory.shape[1], "Batch size must be the same for curr and memory"
 
 
135
 
136
  output = curr
137
  if self.pos_enc_at_input and curr_pos is not None:
{sam2 → bboxmaskpose/sam2}/modeling/memory_encoder.py RENAMED
@@ -11,7 +11,7 @@ import torch
11
  import torch.nn as nn
12
  import torch.nn.functional as F
13
 
14
- from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
15
 
16
 
17
  class MaskDownSampler(nn.Module):
@@ -89,16 +89,10 @@ class CXBlock(nn.Module):
89
  groups=dim if use_dwconv else 1,
90
  ) # depthwise conv
91
  self.norm = LayerNorm2d(dim, eps=1e-6)
92
- self.pwconv1 = nn.Linear(
93
- dim, 4 * dim
94
- ) # pointwise/1x1 convs, implemented with linear layers
95
  self.act = nn.GELU()
96
  self.pwconv2 = nn.Linear(4 * dim, dim)
97
- self.gamma = (
98
- nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
99
- if layer_scale_init_value > 0
100
- else None
101
- )
102
  self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
103
 
104
  def forward(self, x):
 
11
  import torch.nn as nn
12
  import torch.nn.functional as F
13
 
14
+ from bboxmaskpose.sam2.modeling.sam2_utils import DropPath, LayerNorm2d, get_clones
15
 
16
 
17
  class MaskDownSampler(nn.Module):
 
89
  groups=dim if use_dwconv else 1,
90
  ) # depthwise conv
91
  self.norm = LayerNorm2d(dim, eps=1e-6)
92
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
 
 
93
  self.act = nn.GELU()
94
  self.pwconv2 = nn.Linear(4 * dim, dim)
95
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) if layer_scale_init_value > 0 else None
 
 
 
 
96
  self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
97
 
98
  def forward(self, x):
{sam2 → bboxmaskpose/sam2}/modeling/position_encoding.py RENAMED
@@ -8,7 +8,6 @@ import math
8
  from typing import Any, Optional, Tuple
9
 
10
  import numpy as np
11
-
12
  import torch
13
  from torch import nn
14
 
@@ -61,12 +60,8 @@ class PositionEmbeddingSine(nn.Module):
61
 
62
  pos_x = x_embed[:, None] / dim_t
63
  pos_y = y_embed[:, None] / dim_t
64
- pos_x = torch.stack(
65
- (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
66
- ).flatten(1)
67
- pos_y = torch.stack(
68
- (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
69
- ).flatten(1)
70
  return pos_x, pos_y
71
 
72
  @torch.no_grad()
@@ -92,16 +87,8 @@ class PositionEmbeddingSine(nn.Module):
92
  if cache_key in self.cache:
93
  return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
94
 
95
- y_embed = (
96
- torch.arange(1, H + 1, dtype=torch.float32, device=device)
97
- .view(1, -1, 1)
98
- .repeat(B, 1, W)
99
- )
100
- x_embed = (
101
- torch.arange(1, W + 1, dtype=torch.float32, device=device)
102
- .view(1, 1, -1)
103
- .repeat(B, H, 1)
104
- )
105
 
106
  if self.normalize:
107
  eps = 1e-6
@@ -113,12 +100,8 @@ class PositionEmbeddingSine(nn.Module):
113
 
114
  pos_x = x_embed[:, :, :, None] / dim_t
115
  pos_y = y_embed[:, :, :, None] / dim_t
116
- pos_x = torch.stack(
117
- (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
118
- ).flatten(3)
119
- pos_y = torch.stack(
120
- (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
121
- ).flatten(3)
122
  pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
123
  self.cache[cache_key] = pos[0]
124
  return pos
@@ -166,9 +149,7 @@ class PositionEmbeddingRandom(nn.Module):
166
  pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
167
  return pe.permute(2, 0, 1) # C x H x W
168
 
169
- def forward_with_coords(
170
- self, coords_input: torch.Tensor, image_size: Tuple[int, int]
171
- ) -> torch.Tensor:
172
  """Positionally encode points that are not normalized to [0,1]."""
173
  coords = coords_input.clone()
174
  coords[:, :, 0] = coords[:, :, 0] / image_size[1]
@@ -216,11 +197,7 @@ def apply_rotary_enc(
216
  repeat_freqs_k: bool = False,
217
  ):
218
  xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
219
- xk_ = (
220
- torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
221
- if xk.shape[-2] != 0
222
- else None
223
- )
224
  freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
225
  xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
226
  if xk_ is None:
 
8
  from typing import Any, Optional, Tuple
9
 
10
  import numpy as np
 
11
  import torch
12
  from torch import nn
13
 
 
60
 
61
  pos_x = x_embed[:, None] / dim_t
62
  pos_y = y_embed[:, None] / dim_t
63
+ pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
64
+ pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
 
 
 
 
65
  return pos_x, pos_y
66
 
67
  @torch.no_grad()
 
87
  if cache_key in self.cache:
88
  return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
89
 
90
+ y_embed = torch.arange(1, H + 1, dtype=torch.float32, device=device).view(1, -1, 1).repeat(B, 1, W)
91
+ x_embed = torch.arange(1, W + 1, dtype=torch.float32, device=device).view(1, 1, -1).repeat(B, H, 1)
 
 
 
 
 
 
 
 
92
 
93
  if self.normalize:
94
  eps = 1e-6
 
100
 
101
  pos_x = x_embed[:, :, :, None] / dim_t
102
  pos_y = y_embed[:, :, :, None] / dim_t
103
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
104
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
 
 
 
 
105
  pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
106
  self.cache[cache_key] = pos[0]
107
  return pos
 
149
  pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
150
  return pe.permute(2, 0, 1) # C x H x W
151
 
152
+ def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
 
 
153
  """Positionally encode points that are not normalized to [0,1]."""
154
  coords = coords_input.clone()
155
  coords[:, :, 0] = coords[:, :, 0] / image_size[1]
 
197
  repeat_freqs_k: bool = False,
198
  ):
199
  xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
200
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) if xk.shape[-2] != 0 else None
 
 
 
 
201
  freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
202
  xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
203
  if xk_ is None:
{sam2 → bboxmaskpose/sam2}/modeling/sam/__init__.py RENAMED
File without changes
{sam2 → bboxmaskpose/sam2}/modeling/sam/mask_decoder.py RENAMED
@@ -9,7 +9,7 @@ from typing import List, Optional, Tuple, Type
9
  import torch
10
  from torch import nn
11
 
12
- from sam2.modeling.sam2_utils import LayerNorm2d, MLP
13
 
14
 
15
  class MaskDecoder(nn.Module):
@@ -63,30 +63,19 @@ class MaskDecoder(nn.Module):
63
  self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
64
 
65
  self.output_upscaling = nn.Sequential(
66
- nn.ConvTranspose2d(
67
- transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
68
- ),
69
  LayerNorm2d(transformer_dim // 4),
70
  activation(),
71
- nn.ConvTranspose2d(
72
- transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
73
- ),
74
  activation(),
75
  )
76
  self.use_high_res_features = use_high_res_features
77
  if use_high_res_features:
78
- self.conv_s0 = nn.Conv2d(
79
- transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
80
- )
81
- self.conv_s1 = nn.Conv2d(
82
- transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
83
- )
84
 
85
  self.output_hypernetworks_mlps = nn.ModuleList(
86
- [
87
- MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
88
- for i in range(self.num_mask_tokens)
89
- ]
90
  )
91
 
92
  self.iou_prediction_head = MLP(
@@ -188,12 +177,8 @@ class MaskDecoder(nn.Module):
188
  )
189
  s = 1
190
  else:
191
- output_tokens = torch.cat(
192
- [self.iou_token.weight, self.mask_tokens.weight], dim=0
193
- )
194
- output_tokens = output_tokens.unsqueeze(0).expand(
195
- sparse_prompt_embeddings.size(0), -1, -1
196
- )
197
  tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
198
 
199
  # Expand per-image data in batch direction to be per-mask
@@ -203,9 +188,7 @@ class MaskDecoder(nn.Module):
203
  assert image_embeddings.shape[0] == tokens.shape[0]
204
  src = image_embeddings
205
  src = src + dense_prompt_embeddings
206
- assert (
207
- image_pe.size(0) == 1
208
- ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
209
  pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
210
  b, c, h, w = src.shape
211
 
@@ -226,9 +209,7 @@ class MaskDecoder(nn.Module):
226
 
227
  hyper_in_list: List[torch.Tensor] = []
228
  for i in range(self.num_mask_tokens):
229
- hyper_in_list.append(
230
- self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
231
- )
232
  hyper_in = torch.stack(hyper_in_list, dim=1)
233
  b, c, h, w = upscaled_embedding.shape
234
  masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
@@ -267,9 +248,7 @@ class MaskDecoder(nn.Module):
267
  multimask_logits = all_mask_logits[:, 1:, :, :]
268
  multimask_iou_scores = all_iou_scores[:, 1:]
269
  best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
270
- batch_inds = torch.arange(
271
- multimask_iou_scores.size(0), device=all_iou_scores.device
272
- )
273
  best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
274
  best_multimask_logits = best_multimask_logits.unsqueeze(1)
275
  best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
 
9
  import torch
10
  from torch import nn
11
 
12
+ from bboxmaskpose.sam2.modeling.sam2_utils import MLP, LayerNorm2d
13
 
14
 
15
  class MaskDecoder(nn.Module):
 
63
  self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
64
 
65
  self.output_upscaling = nn.Sequential(
66
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
 
 
67
  LayerNorm2d(transformer_dim // 4),
68
  activation(),
69
+ nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
 
 
70
  activation(),
71
  )
72
  self.use_high_res_features = use_high_res_features
73
  if use_high_res_features:
74
+ self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1)
75
+ self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1)
 
 
 
 
76
 
77
  self.output_hypernetworks_mlps = nn.ModuleList(
78
+ [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for i in range(self.num_mask_tokens)]
 
 
 
79
  )
80
 
81
  self.iou_prediction_head = MLP(
 
177
  )
178
  s = 1
179
  else:
180
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
181
+ output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
 
 
 
 
182
  tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
183
 
184
  # Expand per-image data in batch direction to be per-mask
 
188
  assert image_embeddings.shape[0] == tokens.shape[0]
189
  src = image_embeddings
190
  src = src + dense_prompt_embeddings
191
+ assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
 
 
192
  pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
193
  b, c, h, w = src.shape
194
 
 
209
 
210
  hyper_in_list: List[torch.Tensor] = []
211
  for i in range(self.num_mask_tokens):
212
+ hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
 
 
213
  hyper_in = torch.stack(hyper_in_list, dim=1)
214
  b, c, h, w = upscaled_embedding.shape
215
  masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
 
248
  multimask_logits = all_mask_logits[:, 1:, :, :]
249
  multimask_iou_scores = all_iou_scores[:, 1:]
250
  best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
251
+ batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
 
 
252
  best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
253
  best_multimask_logits = best_multimask_logits.unsqueeze(1)
254
  best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
{sam2 → bboxmaskpose/sam2}/modeling/sam/pose_encoder.py RENAMED
@@ -9,9 +9,8 @@ from typing import Optional, Tuple, Type
9
  import torch
10
  from torch import nn
11
 
12
- from sam2.modeling.position_encoding import PositionEmbeddingRandom
13
-
14
- from sam2.modeling.sam2_utils import LayerNorm2d
15
 
16
 
17
  class PoseEncoder(nn.Module):
@@ -44,9 +43,7 @@ class PoseEncoder(nn.Module):
44
  self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
45
 
46
  self.num_point_embeddings: int = 17 # 17 COCO keypoints
47
- point_embeddings = [
48
- nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
49
- ]
50
  self.point_embeddings = nn.ModuleList(point_embeddings)
51
  self.not_a_point_embed = nn.Embedding(1, embed_dim)
52
 
@@ -89,17 +86,12 @@ class PoseEncoder(nn.Module):
89
  padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
90
  points = torch.cat([points, padding_point], dim=1)
91
  labels = torch.cat([labels, padding_label], dim=1)
92
- point_embedding = self.pe_layer.forward_with_coords(
93
- points, self.input_image_size
94
- )
95
 
96
  kpt_embeddings = torch.cat([self.point_embeddings[i].weight for i in range(self.num_point_embeddings)], dim=0)
97
  negative_embedding = torch.zeros_like(point_embedding) + self.not_a_point_embed.weight
98
  positive_embedding = point_embedding + kpt_embeddings
99
- weighted_embedding = (
100
- positive_embedding * labels.unsqueeze(-1).float() +
101
- negative_embedding * (1 - labels.unsqueeze(-1).float())
102
- )
103
 
104
  point_embedding = torch.where(
105
  (labels == 0).unsqueeze(-1),
@@ -112,9 +104,7 @@ class PoseEncoder(nn.Module):
112
  """Embeds box prompts."""
113
  boxes = boxes + 0.5 # Shift to center of pixel
114
  coords = boxes.reshape(-1, 2, 2)
115
- corner_embedding = self.pe_layer.forward_with_coords(
116
- coords, self.input_image_size
117
- )
118
  corner_embedding[:, 0, :] += self.point_embeddings[2].weight
119
  corner_embedding[:, 1, :] += self.point_embeddings[3].weight
120
  return corner_embedding
@@ -170,9 +160,7 @@ class PoseEncoder(nn.Module):
170
  Bx(embed_dim)x(embed_H)x(embed_W)
171
  """
172
  bs = self._get_batch_size(points, boxes, masks)
173
- sparse_embeddings = torch.empty(
174
- (bs, 0, self.embed_dim), device=self._get_device()
175
- )
176
  if points is not None:
177
  coords, labels = points
178
  point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
 
9
  import torch
10
  from torch import nn
11
 
12
+ from bboxmaskpose.sam2.modeling.position_encoding import PositionEmbeddingRandom
13
+ from bboxmaskpose.sam2.modeling.sam2_utils import LayerNorm2d
 
14
 
15
 
16
  class PoseEncoder(nn.Module):
 
43
  self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
44
 
45
  self.num_point_embeddings: int = 17 # 17 COCO keypoints
46
+ point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
 
 
47
  self.point_embeddings = nn.ModuleList(point_embeddings)
48
  self.not_a_point_embed = nn.Embedding(1, embed_dim)
49
 
 
86
  padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
87
  points = torch.cat([points, padding_point], dim=1)
88
  labels = torch.cat([labels, padding_label], dim=1)
89
+ point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
 
 
90
 
91
  kpt_embeddings = torch.cat([self.point_embeddings[i].weight for i in range(self.num_point_embeddings)], dim=0)
92
  negative_embedding = torch.zeros_like(point_embedding) + self.not_a_point_embed.weight
93
  positive_embedding = point_embedding + kpt_embeddings
94
+ weighted_embedding = positive_embedding * labels.unsqueeze(-1).float() + negative_embedding * (1 - labels.unsqueeze(-1).float())
 
 
 
95
 
96
  point_embedding = torch.where(
97
  (labels == 0).unsqueeze(-1),
 
104
  """Embeds box prompts."""
105
  boxes = boxes + 0.5 # Shift to center of pixel
106
  coords = boxes.reshape(-1, 2, 2)
107
+ corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
 
 
108
  corner_embedding[:, 0, :] += self.point_embeddings[2].weight
109
  corner_embedding[:, 1, :] += self.point_embeddings[3].weight
110
  return corner_embedding
 
160
  Bx(embed_dim)x(embed_H)x(embed_W)
161
  """
162
  bs = self._get_batch_size(points, boxes, masks)
163
+ sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
 
 
164
  if points is not None:
165
  coords, labels = points
166
  point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
{sam2 → bboxmaskpose/sam2}/modeling/sam/prompt_encoder.py RENAMED
@@ -6,12 +6,12 @@
6
 
7
  from typing import Optional, Tuple, Type
8
 
 
9
  import torch
10
  from torch import nn
11
 
12
- from sam2.modeling.position_encoding import PositionEmbeddingRandom
13
-
14
- from sam2.modeling.sam2_utils import LayerNorm2d
15
 
16
 
17
  class PromptEncoder(nn.Module):
@@ -22,6 +22,7 @@ class PromptEncoder(nn.Module):
22
  input_image_size: Tuple[int, int],
23
  mask_in_chans: int,
24
  activation: Type[nn.Module] = nn.GELU,
 
25
  ) -> None:
26
  """
27
  Encodes prompts for input to SAM's mask decoder.
@@ -44,9 +45,7 @@ class PromptEncoder(nn.Module):
44
  self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
45
 
46
  self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
47
- point_embeddings = [
48
- nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
49
- ]
50
  self.point_embeddings = nn.ModuleList(point_embeddings)
51
  self.not_a_point_embed = nn.Embedding(1, embed_dim)
52
 
@@ -63,6 +62,7 @@ class PromptEncoder(nn.Module):
63
  activation(),
64
  nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
65
  )
 
66
  self.no_mask_embed = nn.Embedding(1, embed_dim)
67
 
68
  def get_dense_pe(self) -> torch.Tensor:
@@ -76,45 +76,41 @@ class PromptEncoder(nn.Module):
76
  """
77
  return self.pe_layer(self.image_embedding_size).unsqueeze(0)
78
 
79
- def _embed_points(
80
- self,
81
- points: torch.Tensor,
82
- labels: torch.Tensor,
83
- pad: bool,
84
  ) -> torch.Tensor:
85
  """Embeds point prompts."""
 
86
  points = points + 0.5 # Shift to center of pixel
87
  if pad:
88
  padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
89
  padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
90
  points = torch.cat([points, padding_point], dim=1)
91
  labels = torch.cat([labels, padding_label], dim=1)
92
- point_embedding = self.pe_layer.forward_with_coords(
93
- points, self.input_image_size
94
- )
95
 
 
96
  point_embedding = torch.where(
97
  (labels == -1).unsqueeze(-1),
98
  torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
99
  point_embedding,
100
  )
101
- point_embedding = torch.where(
102
  (labels == 0).unsqueeze(-1),
103
  point_embedding + self.point_embeddings[0].weight,
104
  point_embedding,
105
  )
106
  point_embedding = torch.where(
107
- (labels == 1).unsqueeze(-1),
108
  point_embedding + self.point_embeddings[1].weight,
109
  point_embedding,
110
  )
111
  point_embedding = torch.where(
112
- (labels == 2).unsqueeze(-1),
113
  point_embedding + self.point_embeddings[2].weight,
114
  point_embedding,
115
  )
116
  point_embedding = torch.where(
117
- (labels == 3).unsqueeze(-1),
118
  point_embedding + self.point_embeddings[3].weight,
119
  point_embedding,
120
  )
@@ -124,9 +120,7 @@ class PromptEncoder(nn.Module):
124
  """Embeds box prompts."""
125
  boxes = boxes + 0.5 # Shift to center of pixel
126
  coords = boxes.reshape(-1, 2, 2)
127
- corner_embedding = self.pe_layer.forward_with_coords(
128
- coords, self.input_image_size
129
- )
130
  corner_embedding[:, 0, :] += self.point_embeddings[2].weight
131
  corner_embedding[:, 1, :] += self.point_embeddings[3].weight
132
  return corner_embedding
@@ -160,9 +154,9 @@ class PromptEncoder(nn.Module):
160
  def forward(
161
  self,
162
  points: Optional[Tuple[torch.Tensor, torch.Tensor]],
163
- # skeletons: Optional[Tuple[torch.Tensor, torch.Tensor]],
164
  boxes: Optional[torch.Tensor],
165
  masks: Optional[torch.Tensor],
 
166
  ) -> Tuple[torch.Tensor, torch.Tensor]:
167
  """
168
  Embeds different types of prompts, returning both sparse and dense
@@ -182,12 +176,13 @@ class PromptEncoder(nn.Module):
182
  Bx(embed_dim)x(embed_H)x(embed_W)
183
  """
184
  bs = self._get_batch_size(points, boxes, masks)
185
- sparse_embeddings = torch.empty(
186
- (bs, 0, self.embed_dim), device=self._get_device()
187
- )
188
  if points is not None:
189
  coords, labels = points
190
- point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
 
 
 
191
  sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
192
  if boxes is not None:
193
  box_embeddings = self._embed_boxes(boxes)
 
6
 
7
  from typing import Optional, Tuple, Type
8
 
9
+ import numpy as np
10
  import torch
11
  from torch import nn
12
 
13
+ from bboxmaskpose.sam2.modeling.position_encoding import PositionEmbeddingRandom
14
+ from bboxmaskpose.sam2.modeling.sam2_utils import LayerNorm2d
 
15
 
16
 
17
  class PromptEncoder(nn.Module):
 
22
  input_image_size: Tuple[int, int],
23
  mask_in_chans: int,
24
  activation: Type[nn.Module] = nn.GELU,
25
+ n_kpts_encoder: int = -1,
26
  ) -> None:
27
  """
28
  Encodes prompts for input to SAM's mask decoder.
 
45
  self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
46
 
47
  self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
48
+ point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
 
 
49
  self.point_embeddings = nn.ModuleList(point_embeddings)
50
  self.not_a_point_embed = nn.Embedding(1, embed_dim)
51
 
 
62
  activation(),
63
  nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
64
  )
65
+ self.n_kpts_encoder = n_kpts_encoder
66
  self.no_mask_embed = nn.Embedding(1, embed_dim)
67
 
68
  def get_dense_pe(self) -> torch.Tensor:
 
76
  """
77
  return self.pe_layer(self.image_embedding_size).unsqueeze(0)
78
 
79
+ def _embed_points( ## embeds the points into a high-dimensional space (e.g., 256-dim) using learned embeddings
80
+ self, points: torch.Tensor, labels: torch.Tensor, pad: bool, normalize: bool
 
 
 
81
  ) -> torch.Tensor:
82
  """Embeds point prompts."""
83
+ # print("EMBED points ", points) # KPTS OUTPUT
84
  points = points + 0.5 # Shift to center of pixel
85
  if pad:
86
  padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
87
  padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
88
  points = torch.cat([points, padding_point], dim=1)
89
  labels = torch.cat([labels, padding_label], dim=1)
 
 
 
90
 
91
+ point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
92
  point_embedding = torch.where(
93
  (labels == -1).unsqueeze(-1),
94
  torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
95
  point_embedding,
96
  )
97
+ point_embedding = torch.where( ## negative pts
98
  (labels == 0).unsqueeze(-1),
99
  point_embedding + self.point_embeddings[0].weight,
100
  point_embedding,
101
  )
102
  point_embedding = torch.where(
103
+ (labels == 1).unsqueeze(-1), ## positive pts
104
  point_embedding + self.point_embeddings[1].weight,
105
  point_embedding,
106
  )
107
  point_embedding = torch.where(
108
+ (labels == 2).unsqueeze(-1), ## bbox top left
109
  point_embedding + self.point_embeddings[2].weight,
110
  point_embedding,
111
  )
112
  point_embedding = torch.where(
113
+ (labels == 3).unsqueeze(-1), ## bbox bottom right
114
  point_embedding + self.point_embeddings[3].weight,
115
  point_embedding,
116
  )
 
120
  """Embeds box prompts."""
121
  boxes = boxes + 0.5 # Shift to center of pixel
122
  coords = boxes.reshape(-1, 2, 2)
123
+ corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
 
 
124
  corner_embedding[:, 0, :] += self.point_embeddings[2].weight
125
  corner_embedding[:, 1, :] += self.point_embeddings[3].weight
126
  return corner_embedding
 
154
  def forward(
155
  self,
156
  points: Optional[Tuple[torch.Tensor, torch.Tensor]],
 
157
  boxes: Optional[torch.Tensor],
158
  masks: Optional[torch.Tensor],
159
+ normalize: bool = True,
160
  ) -> Tuple[torch.Tensor, torch.Tensor]:
161
  """
162
  Embeds different types of prompts, returning both sparse and dense
 
176
  Bx(embed_dim)x(embed_H)x(embed_W)
177
  """
178
  bs = self._get_batch_size(points, boxes, masks)
179
+ sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
 
 
180
  if points is not None:
181
  coords, labels = points
182
+ coords = coords.to(self._get_device())
183
+ labels = labels.to(self._get_device())
184
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None), normalize=normalize)
185
+ # point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
186
  sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
187
  if boxes is not None:
188
  box_embeddings = self._embed_boxes(boxes)
{sam2 → bboxmaskpose/sam2}/modeling/sam/transformer.py RENAMED
@@ -10,10 +10,10 @@ from typing import Tuple, Type
10
 
11
  import torch
12
  import torch.nn.functional as F
13
- from torch import nn, Tensor
14
 
15
- from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
16
- from sam2.modeling.sam2_utils import MLP
17
 
18
 
19
  class TwoWayTransformer(nn.Module):
@@ -57,9 +57,7 @@ class TwoWayTransformer(nn.Module):
57
  )
58
  )
59
 
60
- self.final_attn_token_to_image = Attention(
61
- embedding_dim, num_heads, downsample_rate=attention_downsample_rate
62
- )
63
  self.norm_final_attn = nn.LayerNorm(embedding_dim)
64
 
65
  def forward(
@@ -136,26 +134,18 @@ class TwoWayAttentionBlock(nn.Module):
136
  self.self_attn = Attention(embedding_dim, num_heads)
137
  self.norm1 = nn.LayerNorm(embedding_dim)
138
 
139
- self.cross_attn_token_to_image = Attention(
140
- embedding_dim, num_heads, downsample_rate=attention_downsample_rate
141
- )
142
  self.norm2 = nn.LayerNorm(embedding_dim)
143
 
144
- self.mlp = MLP(
145
- embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
146
- )
147
  self.norm3 = nn.LayerNorm(embedding_dim)
148
 
149
  self.norm4 = nn.LayerNorm(embedding_dim)
150
- self.cross_attn_image_to_token = Attention(
151
- embedding_dim, num_heads, downsample_rate=attention_downsample_rate
152
- )
153
 
154
  self.skip_first_layer_pe = skip_first_layer_pe
155
 
156
- def forward(
157
- self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
158
- ) -> Tuple[Tensor, Tensor]:
159
  # Self attention block
160
  if self.skip_first_layer_pe:
161
  queries = self.self_attn(q=queries, k=queries, v=queries)
@@ -206,9 +196,7 @@ class Attention(nn.Module):
206
  self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
207
  self.internal_dim = embedding_dim // downsample_rate
208
  self.num_heads = num_heads
209
- assert (
210
- self.internal_dim % num_heads == 0
211
- ), "num_heads must divide embedding_dim."
212
 
213
  self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
214
  self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
@@ -263,18 +251,12 @@ class RoPEAttention(Attention):
263
  ):
264
  super().__init__(*args, **kwargs)
265
 
266
- self.compute_cis = partial(
267
- compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
268
- )
269
  freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
270
- self.freqs_cis = (
271
- freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis
272
- )
273
  self.rope_k_repeat = rope_k_repeat
274
 
275
- def forward(
276
- self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
277
- ) -> Tensor:
278
  # Input projections
279
  q = self.q_proj(q)
280
  k = self.k_proj(k)
 
10
 
11
  import torch
12
  import torch.nn.functional as F
13
+ from torch import Tensor, nn
14
 
15
+ from bboxmaskpose.sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
16
+ from bboxmaskpose.sam2.modeling.sam2_utils import MLP
17
 
18
 
19
  class TwoWayTransformer(nn.Module):
 
57
  )
58
  )
59
 
60
+ self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
 
 
61
  self.norm_final_attn = nn.LayerNorm(embedding_dim)
62
 
63
  def forward(
 
134
  self.self_attn = Attention(embedding_dim, num_heads)
135
  self.norm1 = nn.LayerNorm(embedding_dim)
136
 
137
+ self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
 
 
138
  self.norm2 = nn.LayerNorm(embedding_dim)
139
 
140
+ self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation)
 
 
141
  self.norm3 = nn.LayerNorm(embedding_dim)
142
 
143
  self.norm4 = nn.LayerNorm(embedding_dim)
144
+ self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
 
 
145
 
146
  self.skip_first_layer_pe = skip_first_layer_pe
147
 
148
+ def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
 
 
149
  # Self attention block
150
  if self.skip_first_layer_pe:
151
  queries = self.self_attn(q=queries, k=queries, v=queries)
 
196
  self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
197
  self.internal_dim = embedding_dim // downsample_rate
198
  self.num_heads = num_heads
199
+ assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
 
 
200
 
201
  self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
202
  self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
 
251
  ):
252
  super().__init__(*args, **kwargs)
253
 
254
+ self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta)
 
 
255
  freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
256
+ self.freqs_cis = freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis
 
 
257
  self.rope_k_repeat = rope_k_repeat
258
 
259
+ def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor:
 
 
260
  # Input projections
261
  q = self.q_proj(q)
262
  k = self.k_proj(k)
{sam2 → bboxmaskpose/sam2}/modeling/sam2_base.py RENAMED
@@ -4,20 +4,15 @@
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
- from loguru import logger
8
-
9
  import torch
10
  import torch.distributed
11
  import torch.nn.functional as F
12
-
13
  from torch.nn.init import trunc_normal_
14
 
15
- from sam2.modeling.sam.mask_decoder import MaskDecoder
16
- from sam2.modeling.sam.prompt_encoder import PromptEncoder
17
- from sam2.modeling.sam.transformer import TwoWayTransformer
18
- from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames
19
-
20
- from sam2.utils.kalman_filter import KalmanFilter
21
 
22
  # a large negative value as a placeholder score for missing objects
23
  NO_OBJ_SCORE = -1024.0
@@ -97,19 +92,10 @@ class SAM2Base(torch.nn.Module):
97
  # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
98
  sam_mask_decoder_extra_args=None,
99
  compile_image_encoder: bool = False,
100
- # Whether to use SAMURAI or original SAM 2
101
- samurai_mode: bool = False,
102
- # Hyperparameters for SAMURAI
103
- stable_frames_threshold: int = 15,
104
- stable_ious_threshold: float = 0.3,
105
- min_obj_score_logits: float = -1,
106
- kf_score_weight: float = 0.15,
107
- memory_bank_iou_threshold: float = 0.5,
108
- memory_bank_obj_score_threshold: float = 0.0,
109
- memory_bank_kf_score_threshold: float = 0.0,
110
  ):
111
  super().__init__()
112
-
113
  # Part 1: the image backbone
114
  self.image_encoder = image_encoder
115
  # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
@@ -137,16 +123,12 @@ class SAM2Base(torch.nn.Module):
137
  # Part 3: memory encoder for the previous frame's outputs
138
  self.memory_encoder = memory_encoder
139
  self.mem_dim = self.hidden_dim
140
- if hasattr(self.memory_encoder, "out_proj") and hasattr(
141
- self.memory_encoder.out_proj, "weight"
142
- ):
143
  # if there is compression of memories along channel dim
144
  self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
145
  self.num_maskmem = num_maskmem # Number of memories accessible
146
  # Temporal encoding of the memories
147
- self.maskmem_tpos_enc = torch.nn.Parameter(
148
- torch.zeros(num_maskmem, 1, 1, self.mem_dim)
149
- )
150
  trunc_normal_(self.maskmem_tpos_enc, std=0.02)
151
  # a single token to indicate no memory embedding from previous frames
152
  self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
@@ -194,37 +176,10 @@ class SAM2Base(torch.nn.Module):
194
 
195
  self._build_sam_heads()
196
  self.max_cond_frames_in_attn = max_cond_frames_in_attn
197
-
198
- # Whether to use SAMURAI or original SAM 2
199
- self.samurai_mode = samurai_mode
200
-
201
- # Init Kalman Filter
202
- self.kf = KalmanFilter()
203
- self.kf_mean = None
204
- self.kf_covariance = None
205
- self.stable_frames = 0
206
-
207
- # Debug purpose
208
- self.history = {} # debug
209
- self.frame_cnt = 0 # debug
210
-
211
- # Hyperparameters for SAMURAI
212
- self.stable_frames_threshold = stable_frames_threshold
213
- self.stable_ious_threshold = stable_ious_threshold
214
- self.min_obj_score_logits = min_obj_score_logits
215
- self.kf_score_weight = kf_score_weight
216
- self.memory_bank_iou_threshold = memory_bank_iou_threshold
217
- self.memory_bank_obj_score_threshold = memory_bank_obj_score_threshold
218
- self.memory_bank_kf_score_threshold = memory_bank_kf_score_threshold
219
-
220
- print(f"\033[93mSAMURAI mode: {self.samurai_mode}\033[0m")
221
-
222
  # Model compilation
223
  if compile_image_encoder:
224
  # Compile the forward function (not the full module) to allow loading checkpoints.
225
- print(
226
- "Image encoder compilation is enabled. First forward pass will be slow."
227
- )
228
  self.image_encoder.forward = torch.compile(
229
  self.image_encoder.forward,
230
  mode="max-autotune",
@@ -232,6 +187,15 @@ class SAM2Base(torch.nn.Module):
232
  dynamic=False,
233
  )
234
 
 
 
 
 
 
 
 
 
 
235
  @property
236
  def device(self):
237
  return next(self.parameters()).device
@@ -257,7 +221,9 @@ class SAM2Base(torch.nn.Module):
257
  ),
258
  input_image_size=(self.image_size, self.image_size),
259
  mask_in_chans=16,
 
260
  )
 
261
  self.sam_mask_decoder = MaskDecoder(
262
  num_multimask_outputs=3,
263
  transformer=TwoWayTransformer(
@@ -276,13 +242,16 @@ class SAM2Base(torch.nn.Module):
276
  use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
277
  **(self.sam_mask_decoder_extra_args or {}),
278
  )
 
 
 
 
 
279
  if self.use_obj_ptrs_in_encoder:
280
  # a linear projection on SAM output tokens to turn them into object pointers
281
  self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
282
  if self.use_mlp_for_obj_ptr_proj:
283
- self.obj_ptr_proj = MLP(
284
- self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
285
- )
286
  else:
287
  self.obj_ptr_proj = torch.nn.Identity()
288
  if self.proj_tpos_enc_in_obj_ptrs:
@@ -395,7 +364,7 @@ class SAM2Base(torch.nn.Module):
395
  high_res_features=high_res_features,
396
  )
397
  if self.pred_obj_scores:
398
- is_obj_appearing = object_score_logits > self.min_obj_score_logits
399
 
400
  # Mask used for spatial memories is always a *hard* choice between obj and no obj,
401
  # consistent with the actual mask prediction
@@ -416,87 +385,7 @@ class SAM2Base(torch.nn.Module):
416
  )
417
 
418
  sam_output_token = sam_output_tokens[:, 0]
419
- kf_ious = None
420
- if multimask_output and self.samurai_mode:
421
- if self.kf_mean is None and self.kf_covariance is None or self.stable_frames == 0:
422
- best_iou_inds = torch.argmax(ious, dim=-1)
423
- batch_inds = torch.arange(B, device=device)
424
- low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
425
- high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
426
- non_zero_indices = torch.argwhere(high_res_masks[0][0] > 0.0)
427
- if len(non_zero_indices) == 0:
428
- high_res_bbox = [0, 0, 0, 0]
429
- else:
430
- y_min, x_min = non_zero_indices.min(dim=0).values
431
- y_max, x_max = non_zero_indices.max(dim=0).values
432
- high_res_bbox = [x_min.item(), y_min.item(), x_max.item(), y_max.item()]
433
- self.kf_mean, self.kf_covariance = self.kf.initiate(self.kf.xyxy_to_xyah(high_res_bbox))
434
- if sam_output_tokens.size(1) > 1:
435
- sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
436
- self.frame_cnt += 1
437
- self.stable_frames += 1
438
- elif self.stable_frames < self.stable_frames_threshold:
439
- self.kf_mean, self.kf_covariance = self.kf.predict(self.kf_mean, self.kf_covariance)
440
- best_iou_inds = torch.argmax(ious, dim=-1)
441
- batch_inds = torch.arange(B, device=device)
442
- low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
443
- high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
444
- non_zero_indices = torch.argwhere(high_res_masks[0][0] > 0.0)
445
- if len(non_zero_indices) == 0:
446
- high_res_bbox = [0, 0, 0, 0]
447
- else:
448
- y_min, x_min = non_zero_indices.min(dim=0).values
449
- y_max, x_max = non_zero_indices.max(dim=0).values
450
- high_res_bbox = [x_min.item(), y_min.item(), x_max.item(), y_max.item()]
451
- if ious[0][best_iou_inds] > self.stable_ious_threshold:
452
- self.kf_mean, self.kf_covariance = self.kf.update(self.kf_mean, self.kf_covariance, self.kf.xyxy_to_xyah(high_res_bbox))
453
- self.stable_frames += 1
454
- else:
455
- self.stable_frames = 0
456
- if sam_output_tokens.size(1) > 1:
457
- sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
458
- self.frame_cnt += 1
459
- else:
460
- self.kf_mean, self.kf_covariance = self.kf.predict(self.kf_mean, self.kf_covariance)
461
- high_res_multibboxes = []
462
- batch_inds = torch.arange(B, device=device)
463
- for i in range(ious.shape[1]):
464
- non_zero_indices = torch.argwhere(high_res_multimasks[batch_inds, i].unsqueeze(1)[0][0] > 0.0)
465
- if len(non_zero_indices) == 0:
466
- high_res_multibboxes.append([0, 0, 0, 0])
467
- else:
468
- y_min, x_min = non_zero_indices.min(dim=0).values
469
- y_max, x_max = non_zero_indices.max(dim=0).values
470
- high_res_multibboxes.append([x_min.item(), y_min.item(), x_max.item(), y_max.item()])
471
- # compute the IoU between the predicted bbox and the high_res_multibboxes
472
- kf_ious = torch.tensor(self.kf.compute_iou(self.kf_mean[:4], high_res_multibboxes), device=device)
473
- # weighted iou
474
- weighted_ious = self.kf_score_weight * kf_ious + (1 - self.kf_score_weight) * ious
475
- best_iou_inds = torch.argmax(weighted_ious, dim=-1)
476
- batch_inds = torch.arange(B, device=device)
477
- low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
478
- high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
479
- if sam_output_tokens.size(1) > 1:
480
- sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
481
-
482
- if False:
483
- # make all these on cpu
484
- self.history[self.frame_cnt] = {
485
- "kf_predicted_bbox": self.kf.xyah_to_xyxy(self.kf_mean[:4]),
486
- # "multi_masks": high_res_multimasks.cpu(),
487
- "ious": ious.cpu(),
488
- "multi_bboxes": high_res_multibboxes,
489
- "kf_ious": kf_ious,
490
- "weighted_ious": weighted_ious.cpu(),
491
- "final_selection": best_iou_inds.cpu(),
492
- }
493
- self.frame_cnt += 1
494
-
495
- if ious[0][best_iou_inds] < self.stable_ious_threshold:
496
- self.stable_frames = 0
497
- else:
498
- self.kf_mean, self.kf_covariance = self.kf.update(self.kf_mean, self.kf_covariance, self.kf.xyxy_to_xyah(high_res_multibboxes[best_iou_inds]))
499
- elif multimask_output and not self.samurai_mode:
500
  # take the best mask prediction (with the highest IoU estimation)
501
  best_iou_inds = torch.argmax(ious, dim=-1)
502
  batch_inds = torch.arange(B, device=device)
@@ -505,7 +394,6 @@ class SAM2Base(torch.nn.Module):
505
  if sam_output_tokens.size(1) > 1:
506
  sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
507
  else:
508
- best_iou_inds = 0
509
  low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
510
 
511
  # Extract object pointer from the SAM output token (with occlusion handling)
@@ -529,8 +417,6 @@ class SAM2Base(torch.nn.Module):
529
  high_res_masks,
530
  obj_ptr,
531
  object_score_logits,
532
- ious[0][best_iou_inds],
533
- kf_ious[best_iou_inds] if kf_ious is not None else None,
534
  )
535
 
536
  def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
@@ -553,12 +439,10 @@ class SAM2Base(torch.nn.Module):
553
  ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
554
  if not self.use_obj_ptrs_in_encoder:
555
  # all zeros as a dummy object pointer (of shape [B, C])
556
- obj_ptr = torch.zeros(
557
- mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device
558
- )
559
  else:
560
  # produce an object pointer using the SAM decoder from the mask input
561
- _, _, _, _, _, obj_ptr, _, _, _ = self._forward_sam_heads(
562
  backbone_features=backbone_features,
563
  mask_inputs=self.mask_downsample(mask_inputs_float),
564
  high_res_features=high_res_features,
@@ -591,12 +475,8 @@ class SAM2Base(torch.nn.Module):
591
  if self.use_high_res_features_in_sam:
592
  # precompute projected level 0 and level 1 features in SAM decoder
593
  # to avoid running it again on every SAM click
594
- backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
595
- backbone_out["backbone_fpn"][0]
596
- )
597
- backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
598
- backbone_out["backbone_fpn"][1]
599
- )
600
  return backbone_out
601
 
602
  def _prepare_backbone_features(self, backbone_out):
@@ -657,63 +537,36 @@ class SAM2Base(torch.nn.Module):
657
  # We also allow taking the memory frame non-consecutively (with stride>1), in which case
658
  # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame.
659
  stride = 1 if self.training else self.memory_temporal_stride_for_eval
660
-
661
- if self.samurai_mode:
662
- valid_indices = []
663
- if frame_idx > 1: # Ensure we have previous frames to evaluate
664
- for i in range(frame_idx - 1, 1, -1): # Iterate backwards through previous frames
665
- iou_score = output_dict["non_cond_frame_outputs"][i]["best_iou_score"] # Get mask affinity score
666
- obj_score = output_dict["non_cond_frame_outputs"][i]["object_score_logits"] # Get object score
667
- kf_score = output_dict["non_cond_frame_outputs"][i]["kf_score"] if "kf_score" in output_dict["non_cond_frame_outputs"][i] else None # Get motion score if available
668
- # Check if the scores meet the criteria for being a valid index
669
- if iou_score.item() > self.memory_bank_iou_threshold and \
670
- obj_score.item() > self.memory_bank_obj_score_threshold and \
671
- (kf_score is None or kf_score.item() > self.memory_bank_kf_score_threshold):
672
- valid_indices.insert(0, i)
673
- # Check the number of valid indices
674
- if len(valid_indices) >= self.max_obj_ptrs_in_encoder - 1:
675
- break
676
- if frame_idx - 1 not in valid_indices:
677
- valid_indices.append(frame_idx - 1)
678
- for t_pos in range(1, self.num_maskmem): # Iterate over the number of mask memories
679
- idx = t_pos - self.num_maskmem # Calculate the index for valid indices
680
- if idx < -len(valid_indices): # Skip if index is out of bounds
681
- continue
682
- out = output_dict["non_cond_frame_outputs"].get(valid_indices[idx], None) # Get output for the valid index
683
- if out is None: # If not found, check unselected outputs
684
- out = unselected_cond_outputs.get(valid_indices[idx], None)
685
- t_pos_and_prevs.append((t_pos, out)) # Append the temporal position and output to the list
686
- else:
687
- for t_pos in range(1, self.num_maskmem):
688
- t_rel = self.num_maskmem - t_pos # how many frames before current frame
689
- if t_rel == 1:
690
- # for t_rel == 1, we take the last frame (regardless of r)
691
- if not track_in_reverse:
692
- # the frame immediately before this frame (i.e. frame_idx - 1)
693
- prev_frame_idx = frame_idx - t_rel
694
- else:
695
- # the frame immediately after this frame (i.e. frame_idx + 1)
696
- prev_frame_idx = frame_idx + t_rel
697
  else:
698
- # for t_rel >= 2, we take the memory frame from every r-th frames
699
- if not track_in_reverse:
700
- # first find the nearest frame among every r-th frames before this frame
701
- # for r=1, this would be (frame_idx - 2)
702
- prev_frame_idx = ((frame_idx - 2) // stride) * stride
703
- # then seek further among every r-th frames
704
- prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride
705
- else:
706
- # first find the nearest frame among every r-th frames after this frame
707
- # for r=1, this would be (frame_idx + 2)
708
- prev_frame_idx = -(-(frame_idx + 2) // stride) * stride
709
- # then seek further among every r-th frames
710
- prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride
711
- out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
712
- if out is None:
713
- # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
714
- # frames, we still attend to it as if it's a non-conditioning frame.
715
- out = unselected_cond_outputs.get(prev_frame_idx, None)
716
- t_pos_and_prevs.append((t_pos, out))
 
 
 
717
 
718
  for t_pos, prev in t_pos_and_prevs:
719
  if prev is None:
@@ -726,9 +579,7 @@ class SAM2Base(torch.nn.Module):
726
  maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
727
  maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
728
  # Temporal positional encoding
729
- maskmem_enc = (
730
- maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
731
- )
732
  to_cat_memory_pos_embed.append(maskmem_enc)
733
 
734
  # Construct the list of past object pointers
@@ -738,20 +589,14 @@ class SAM2Base(torch.nn.Module):
738
  # (optionally, only include object pointers in the past during evaluation)
739
  if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
740
  ptr_cond_outputs = {
741
- t: out
742
- for t, out in selected_cond_outputs.items()
743
- if (t >= frame_idx if track_in_reverse else t <= frame_idx)
744
  }
745
  else:
746
  ptr_cond_outputs = selected_cond_outputs
747
  pos_and_ptrs = [
748
  # Temporal pos encoding contains how far away each pointer is from current frame
749
  (
750
- (
751
- (frame_idx - t) * tpos_sign_mul
752
- if self.use_signed_tpos_enc_to_obj_ptrs
753
- else abs(frame_idx - t)
754
- ),
755
  out["obj_ptr"],
756
  )
757
  for t, out in ptr_cond_outputs.items()
@@ -761,9 +606,7 @@ class SAM2Base(torch.nn.Module):
761
  t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
762
  if t < 0 or (num_frames is not None and t >= num_frames):
763
  break
764
- out = output_dict["non_cond_frame_outputs"].get(
765
- t, unselected_cond_outputs.get(t, None)
766
- )
767
  if out is not None:
768
  pos_and_ptrs.append((t_diff, out["obj_ptr"]))
769
  # If we have at least one object pointer, add them to the across attention
@@ -776,9 +619,7 @@ class SAM2Base(torch.nn.Module):
776
  if self.add_tpos_enc_to_obj_ptrs:
777
  t_diff_max = max_obj_ptrs_in_encoder - 1
778
  tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
779
- obj_pos = torch.tensor(pos_list).to(
780
- device=device, non_blocking=True
781
- )
782
  obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
783
  obj_pos = self.obj_ptr_tpos_proj(obj_pos)
784
  obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
@@ -786,9 +627,7 @@ class SAM2Base(torch.nn.Module):
786
  obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
787
  if self.mem_dim < C:
788
  # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
789
- obj_ptrs = obj_ptrs.reshape(
790
- -1, B, C // self.mem_dim, self.mem_dim
791
- )
792
  obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
793
  obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
794
  to_cat_memory.append(obj_ptrs)
@@ -841,9 +680,7 @@ class SAM2Base(torch.nn.Module):
841
  # optionally, apply non-overlapping constraints to the masks (it's applied
842
  # in the batch dimension and should only be used during eval, where all
843
  # the objects come from the same video under batch size 1).
844
- pred_masks_high_res = self._apply_non_overlapping_constraints(
845
- pred_masks_high_res
846
- )
847
  # scale the raw mask logits with a temperature before applying sigmoid
848
  binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
849
  if binarize and not self.training:
@@ -856,18 +693,14 @@ class SAM2Base(torch.nn.Module):
856
  mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
857
  if self.sigmoid_bias_for_mem_enc != 0.0:
858
  mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
859
- maskmem_out = self.memory_encoder(
860
- pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied
861
- )
862
  maskmem_features = maskmem_out["vision_features"]
863
  maskmem_pos_enc = maskmem_out["vision_pos_enc"]
864
  # add a no-object embedding to the spatial memory to indicate that the frame
865
  # is predicted to be occluded (i.e. no object is appearing in the frame)
866
  if self.no_obj_embed_spatial is not None:
867
  is_obj_appearing = (object_score_logits > 0).float()
868
- maskmem_features += (
869
- 1 - is_obj_appearing[..., None, None]
870
- ) * self.no_obj_embed_spatial[..., None, None].expand(
871
  *maskmem_features.shape
872
  )
873
 
@@ -891,8 +724,7 @@ class SAM2Base(torch.nn.Module):
891
  # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
892
  if len(current_vision_feats) > 1:
893
  high_res_features = [
894
- x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
895
- for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
896
  ]
897
  else:
898
  high_res_features = None
@@ -901,9 +733,7 @@ class SAM2Base(torch.nn.Module):
901
  # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
902
  pix_feat = current_vision_feats[-1].permute(1, 2, 0)
903
  pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
904
- sam_outputs = self._use_mask_as_output(
905
- pix_feat, high_res_features, mask_inputs
906
- )
907
  else:
908
  # fused the visual feature with previous memory features in the memory bank
909
  pix_feat = self._prepare_memory_conditioned_features(
@@ -1002,15 +832,11 @@ class SAM2Base(torch.nn.Module):
1002
  high_res_masks,
1003
  obj_ptr,
1004
  object_score_logits,
1005
- best_iou_score,
1006
- kf_ious
1007
  ) = sam_outputs
1008
 
1009
  current_out["pred_masks"] = low_res_masks
1010
  current_out["pred_masks_high_res"] = high_res_masks
1011
  current_out["obj_ptr"] = obj_ptr
1012
- current_out["best_iou_score"] = best_iou_score
1013
- current_out["kf_ious"] = kf_ious
1014
  if not self.training:
1015
  # Only add this in inference (to avoid unused param in activation checkpointing;
1016
  # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
 
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
 
 
7
  import torch
8
  import torch.distributed
9
  import torch.nn.functional as F
 
10
  from torch.nn.init import trunc_normal_
11
 
12
+ from bboxmaskpose.sam2.modeling.sam2_utils import MLP, get_1d_sine_pe, select_closest_cond_frames
13
+ from bboxmaskpose.sam2.modeling.sam.mask_decoder import MaskDecoder
14
+ from bboxmaskpose.sam2.modeling.sam.prompt_encoder import PromptEncoder
15
+ from bboxmaskpose.sam2.modeling.sam.transformer import TwoWayTransformer
 
 
16
 
17
  # a large negative value as a placeholder score for missing objects
18
  NO_OBJ_SCORE = -1024.0
 
92
  # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
93
  sam_mask_decoder_extra_args=None,
94
  compile_image_encoder: bool = False,
95
+ n_kpts_encoder: int = -1,
 
 
 
 
 
 
 
 
 
96
  ):
97
  super().__init__()
98
+ self.n_kpts_encoder = n_kpts_encoder
99
  # Part 1: the image backbone
100
  self.image_encoder = image_encoder
101
  # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
 
123
  # Part 3: memory encoder for the previous frame's outputs
124
  self.memory_encoder = memory_encoder
125
  self.mem_dim = self.hidden_dim
126
+ if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
 
 
127
  # if there is compression of memories along channel dim
128
  self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
129
  self.num_maskmem = num_maskmem # Number of memories accessible
130
  # Temporal encoding of the memories
131
+ self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
 
 
132
  trunc_normal_(self.maskmem_tpos_enc, std=0.02)
133
  # a single token to indicate no memory embedding from previous frames
134
  self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
 
176
 
177
  self._build_sam_heads()
178
  self.max_cond_frames_in_attn = max_cond_frames_in_attn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  # Model compilation
180
  if compile_image_encoder:
181
  # Compile the forward function (not the full module) to allow loading checkpoints.
182
+ print("Image encoder compilation is enabled. First forward pass will be slow.")
 
 
183
  self.image_encoder.forward = torch.compile(
184
  self.image_encoder.forward,
185
  mode="max-autotune",
 
187
  dynamic=False,
188
  )
189
 
190
+ freeze_prompt_encoder = False
191
+ freeze_mask_decoder = False
192
+ if freeze_prompt_encoder:
193
+ for p in self.sam_prompt_encoder.parameters():
194
+ p.requires_grad = False
195
+ if freeze_mask_decoder:
196
+ for p in self.sam_mask_decoder.parameters():
197
+ p.requires_grad = False
198
+
199
  @property
200
  def device(self):
201
  return next(self.parameters()).device
 
221
  ),
222
  input_image_size=(self.image_size, self.image_size),
223
  mask_in_chans=16,
224
+ n_kpts_encoder=self.n_kpts_encoder,
225
  )
226
+
227
  self.sam_mask_decoder = MaskDecoder(
228
  num_multimask_outputs=3,
229
  transformer=TwoWayTransformer(
 
242
  use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
243
  **(self.sam_mask_decoder_extra_args or {}),
244
  )
245
+ for p in self.sam_prompt_encoder.parameters():
246
+ p.requires_grad = True
247
+ for p in self.sam_mask_decoder.parameters():
248
+ p.requires_grad = True
249
+
250
  if self.use_obj_ptrs_in_encoder:
251
  # a linear projection on SAM output tokens to turn them into object pointers
252
  self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
253
  if self.use_mlp_for_obj_ptr_proj:
254
+ self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
 
 
255
  else:
256
  self.obj_ptr_proj = torch.nn.Identity()
257
  if self.proj_tpos_enc_in_obj_ptrs:
 
364
  high_res_features=high_res_features,
365
  )
366
  if self.pred_obj_scores:
367
+ is_obj_appearing = object_score_logits > 0
368
 
369
  # Mask used for spatial memories is always a *hard* choice between obj and no obj,
370
  # consistent with the actual mask prediction
 
385
  )
386
 
387
  sam_output_token = sam_output_tokens[:, 0]
388
+ if multimask_output:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
  # take the best mask prediction (with the highest IoU estimation)
390
  best_iou_inds = torch.argmax(ious, dim=-1)
391
  batch_inds = torch.arange(B, device=device)
 
394
  if sam_output_tokens.size(1) > 1:
395
  sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
396
  else:
 
397
  low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
398
 
399
  # Extract object pointer from the SAM output token (with occlusion handling)
 
417
  high_res_masks,
418
  obj_ptr,
419
  object_score_logits,
 
 
420
  )
421
 
422
  def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
 
439
  ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
440
  if not self.use_obj_ptrs_in_encoder:
441
  # all zeros as a dummy object pointer (of shape [B, C])
442
+ obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
 
 
443
  else:
444
  # produce an object pointer using the SAM decoder from the mask input
445
+ _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
446
  backbone_features=backbone_features,
447
  mask_inputs=self.mask_downsample(mask_inputs_float),
448
  high_res_features=high_res_features,
 
475
  if self.use_high_res_features_in_sam:
476
  # precompute projected level 0 and level 1 features in SAM decoder
477
  # to avoid running it again on every SAM click
478
+ backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
479
+ backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
 
 
 
 
480
  return backbone_out
481
 
482
  def _prepare_backbone_features(self, backbone_out):
 
537
  # We also allow taking the memory frame non-consecutively (with stride>1), in which case
538
  # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame.
539
  stride = 1 if self.training else self.memory_temporal_stride_for_eval
540
+ for t_pos in range(1, self.num_maskmem):
541
+ t_rel = self.num_maskmem - t_pos # how many frames before current frame
542
+ if t_rel == 1:
543
+ # for t_rel == 1, we take the last frame (regardless of r)
544
+ if not track_in_reverse:
545
+ # the frame immediately before this frame (i.e. frame_idx - 1)
546
+ prev_frame_idx = frame_idx - t_rel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
547
  else:
548
+ # the frame immediately after this frame (i.e. frame_idx + 1)
549
+ prev_frame_idx = frame_idx + t_rel
550
+ else:
551
+ # for t_rel >= 2, we take the memory frame from every r-th frames
552
+ if not track_in_reverse:
553
+ # first find the nearest frame among every r-th frames before this frame
554
+ # for r=1, this would be (frame_idx - 2)
555
+ prev_frame_idx = ((frame_idx - 2) // stride) * stride
556
+ # then seek further among every r-th frames
557
+ prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride
558
+ else:
559
+ # first find the nearest frame among every r-th frames after this frame
560
+ # for r=1, this would be (frame_idx + 2)
561
+ prev_frame_idx = -(-(frame_idx + 2) // stride) * stride
562
+ # then seek further among every r-th frames
563
+ prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride
564
+ out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
565
+ if out is None:
566
+ # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
567
+ # frames, we still attend to it as if it's a non-conditioning frame.
568
+ out = unselected_cond_outputs.get(prev_frame_idx, None)
569
+ t_pos_and_prevs.append((t_pos, out))
570
 
571
  for t_pos, prev in t_pos_and_prevs:
572
  if prev is None:
 
579
  maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
580
  maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
581
  # Temporal positional encoding
582
+ maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
 
 
583
  to_cat_memory_pos_embed.append(maskmem_enc)
584
 
585
  # Construct the list of past object pointers
 
589
  # (optionally, only include object pointers in the past during evaluation)
590
  if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
591
  ptr_cond_outputs = {
592
+ t: out for t, out in selected_cond_outputs.items() if (t >= frame_idx if track_in_reverse else t <= frame_idx)
 
 
593
  }
594
  else:
595
  ptr_cond_outputs = selected_cond_outputs
596
  pos_and_ptrs = [
597
  # Temporal pos encoding contains how far away each pointer is from current frame
598
  (
599
+ ((frame_idx - t) * tpos_sign_mul if self.use_signed_tpos_enc_to_obj_ptrs else abs(frame_idx - t)),
 
 
 
 
600
  out["obj_ptr"],
601
  )
602
  for t, out in ptr_cond_outputs.items()
 
606
  t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
607
  if t < 0 or (num_frames is not None and t >= num_frames):
608
  break
609
+ out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None))
 
 
610
  if out is not None:
611
  pos_and_ptrs.append((t_diff, out["obj_ptr"]))
612
  # If we have at least one object pointer, add them to the across attention
 
619
  if self.add_tpos_enc_to_obj_ptrs:
620
  t_diff_max = max_obj_ptrs_in_encoder - 1
621
  tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
622
+ obj_pos = torch.tensor(pos_list).to(device=device, non_blocking=True)
 
 
623
  obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
624
  obj_pos = self.obj_ptr_tpos_proj(obj_pos)
625
  obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
 
627
  obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
628
  if self.mem_dim < C:
629
  # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
630
+ obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
 
 
631
  obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
632
  obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
633
  to_cat_memory.append(obj_ptrs)
 
680
  # optionally, apply non-overlapping constraints to the masks (it's applied
681
  # in the batch dimension and should only be used during eval, where all
682
  # the objects come from the same video under batch size 1).
683
+ pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)
 
 
684
  # scale the raw mask logits with a temperature before applying sigmoid
685
  binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
686
  if binarize and not self.training:
 
693
  mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
694
  if self.sigmoid_bias_for_mem_enc != 0.0:
695
  mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
696
+ maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True) # sigmoid already applied
 
 
697
  maskmem_features = maskmem_out["vision_features"]
698
  maskmem_pos_enc = maskmem_out["vision_pos_enc"]
699
  # add a no-object embedding to the spatial memory to indicate that the frame
700
  # is predicted to be occluded (i.e. no object is appearing in the frame)
701
  if self.no_obj_embed_spatial is not None:
702
  is_obj_appearing = (object_score_logits > 0).float()
703
+ maskmem_features += (1 - is_obj_appearing[..., None, None]) * self.no_obj_embed_spatial[..., None, None].expand(
 
 
704
  *maskmem_features.shape
705
  )
706
 
 
724
  # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
725
  if len(current_vision_feats) > 1:
726
  high_res_features = [
727
+ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
 
728
  ]
729
  else:
730
  high_res_features = None
 
733
  # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
734
  pix_feat = current_vision_feats[-1].permute(1, 2, 0)
735
  pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
736
+ sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
 
 
737
  else:
738
  # fused the visual feature with previous memory features in the memory bank
739
  pix_feat = self._prepare_memory_conditioned_features(
 
832
  high_res_masks,
833
  obj_ptr,
834
  object_score_logits,
 
 
835
  ) = sam_outputs
836
 
837
  current_out["pred_masks"] = low_res_masks
838
  current_out["pred_masks_high_res"] = high_res_masks
839
  current_out["obj_ptr"] = obj_ptr
 
 
840
  if not self.training:
841
  # Only add this in inference (to avoid unused param in activation checkpointing;
842
  # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
{sam2 → bboxmaskpose/sam2}/modeling/sam2_base_pose.py RENAMED
@@ -4,20 +4,17 @@
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
- from loguru import logger
8
-
9
  import torch
10
  import torch.distributed
11
  import torch.nn.functional as F
12
-
13
  from torch.nn.init import trunc_normal_
14
 
15
- from sam2.modeling.sam.mask_decoder import MaskDecoder
16
- from sam2.modeling.sam.pose_encoder import PoseEncoder
17
- from sam2.modeling.sam.transformer import TwoWayTransformer
18
- from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames
19
-
20
- from sam2.utils.kalman_filter import KalmanFilter
21
 
22
  # a large negative value as a placeholder score for missing objects
23
  NO_OBJ_SCORE = -1024.0
@@ -137,16 +134,12 @@ class SAM2Base(torch.nn.Module):
137
  # Part 3: memory encoder for the previous frame's outputs
138
  self.memory_encoder = memory_encoder
139
  self.mem_dim = self.hidden_dim
140
- if hasattr(self.memory_encoder, "out_proj") and hasattr(
141
- self.memory_encoder.out_proj, "weight"
142
- ):
143
  # if there is compression of memories along channel dim
144
  self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
145
  self.num_maskmem = num_maskmem # Number of memories accessible
146
  # Temporal encoding of the memories
147
- self.maskmem_tpos_enc = torch.nn.Parameter(
148
- torch.zeros(num_maskmem, 1, 1, self.mem_dim)
149
- )
150
  trunc_normal_(self.maskmem_tpos_enc, std=0.02)
151
  # a single token to indicate no memory embedding from previous frames
152
  self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
@@ -205,8 +198,8 @@ class SAM2Base(torch.nn.Module):
205
  self.stable_frames = 0
206
 
207
  # Debug purpose
208
- self.history = {} # debug
209
- self.frame_cnt = 0 # debug
210
 
211
  # Hyperparameters for SAMURAI
212
  self.stable_frames_threshold = stable_frames_threshold
@@ -222,9 +215,7 @@ class SAM2Base(torch.nn.Module):
222
  # Model compilation
223
  if compile_image_encoder:
224
  # Compile the forward function (not the full module) to allow loading checkpoints.
225
- print(
226
- "Image encoder compilation is enabled. First forward pass will be slow."
227
- )
228
  self.image_encoder.forward = torch.compile(
229
  self.image_encoder.forward,
230
  mode="max-autotune",
@@ -280,9 +271,7 @@ class SAM2Base(torch.nn.Module):
280
  # a linear projection on SAM output tokens to turn them into object pointers
281
  self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
282
  if self.use_mlp_for_obj_ptr_proj:
283
- self.obj_ptr_proj = MLP(
284
- self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
285
- )
286
  else:
287
  self.obj_ptr_proj = torch.nn.Identity()
288
  if self.proj_tpos_enc_in_obj_ptrs:
@@ -480,7 +469,7 @@ class SAM2Base(torch.nn.Module):
480
  sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
481
 
482
  if False:
483
- # make all these on cpu
484
  self.history[self.frame_cnt] = {
485
  "kf_predicted_bbox": self.kf.xyah_to_xyxy(self.kf_mean[:4]),
486
  # "multi_masks": high_res_multimasks.cpu(),
@@ -495,7 +484,9 @@ class SAM2Base(torch.nn.Module):
495
  if ious[0][best_iou_inds] < self.stable_ious_threshold:
496
  self.stable_frames = 0
497
  else:
498
- self.kf_mean, self.kf_covariance = self.kf.update(self.kf_mean, self.kf_covariance, self.kf.xyxy_to_xyah(high_res_multibboxes[best_iou_inds]))
 
 
499
  elif multimask_output and not self.samurai_mode:
500
  # take the best mask prediction (with the highest IoU estimation)
501
  best_iou_inds = torch.argmax(ious, dim=-1)
@@ -553,9 +544,7 @@ class SAM2Base(torch.nn.Module):
553
  ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
554
  if not self.use_obj_ptrs_in_encoder:
555
  # all zeros as a dummy object pointer (of shape [B, C])
556
- obj_ptr = torch.zeros(
557
- mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device
558
- )
559
  else:
560
  # produce an object pointer using the SAM decoder from the mask input
561
  _, _, _, _, _, obj_ptr, _, _, _ = self._forward_sam_heads(
@@ -591,12 +580,8 @@ class SAM2Base(torch.nn.Module):
591
  if self.use_high_res_features_in_sam:
592
  # precompute projected level 0 and level 1 features in SAM decoder
593
  # to avoid running it again on every SAM click
594
- backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
595
- backbone_out["backbone_fpn"][0]
596
- )
597
- backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
598
- backbone_out["backbone_fpn"][1]
599
- )
600
  return backbone_out
601
 
602
  def _prepare_backbone_features(self, backbone_out):
@@ -659,21 +644,27 @@ class SAM2Base(torch.nn.Module):
659
  stride = 1 if self.training else self.memory_temporal_stride_for_eval
660
 
661
  if self.samurai_mode:
662
- valid_indices = []
663
  if frame_idx > 1: # Ensure we have previous frames to evaluate
664
  for i in range(frame_idx - 1, 1, -1): # Iterate backwards through previous frames
665
  iou_score = output_dict["non_cond_frame_outputs"][i]["best_iou_score"] # Get mask affinity score
666
  obj_score = output_dict["non_cond_frame_outputs"][i]["object_score_logits"] # Get object score
667
- kf_score = output_dict["non_cond_frame_outputs"][i]["kf_score"] if "kf_score" in output_dict["non_cond_frame_outputs"][i] else None # Get motion score if available
 
 
 
 
668
  # Check if the scores meet the criteria for being a valid index
669
- if iou_score.item() > self.memory_bank_iou_threshold and \
670
- obj_score.item() > self.memory_bank_obj_score_threshold and \
671
- (kf_score is None or kf_score.item() > self.memory_bank_kf_score_threshold):
672
- valid_indices.insert(0, i)
 
 
673
  # Check the number of valid indices
674
- if len(valid_indices) >= self.max_obj_ptrs_in_encoder - 1:
675
  break
676
- if frame_idx - 1 not in valid_indices:
677
  valid_indices.append(frame_idx - 1)
678
  for t_pos in range(1, self.num_maskmem): # Iterate over the number of mask memories
679
  idx = t_pos - self.num_maskmem # Calculate the index for valid indices
@@ -726,9 +717,7 @@ class SAM2Base(torch.nn.Module):
726
  maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
727
  maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
728
  # Temporal positional encoding
729
- maskmem_enc = (
730
- maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
731
- )
732
  to_cat_memory_pos_embed.append(maskmem_enc)
733
 
734
  # Construct the list of past object pointers
@@ -738,20 +727,14 @@ class SAM2Base(torch.nn.Module):
738
  # (optionally, only include object pointers in the past during evaluation)
739
  if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
740
  ptr_cond_outputs = {
741
- t: out
742
- for t, out in selected_cond_outputs.items()
743
- if (t >= frame_idx if track_in_reverse else t <= frame_idx)
744
  }
745
  else:
746
  ptr_cond_outputs = selected_cond_outputs
747
  pos_and_ptrs = [
748
  # Temporal pos encoding contains how far away each pointer is from current frame
749
  (
750
- (
751
- (frame_idx - t) * tpos_sign_mul
752
- if self.use_signed_tpos_enc_to_obj_ptrs
753
- else abs(frame_idx - t)
754
- ),
755
  out["obj_ptr"],
756
  )
757
  for t, out in ptr_cond_outputs.items()
@@ -761,9 +744,7 @@ class SAM2Base(torch.nn.Module):
761
  t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
762
  if t < 0 or (num_frames is not None and t >= num_frames):
763
  break
764
- out = output_dict["non_cond_frame_outputs"].get(
765
- t, unselected_cond_outputs.get(t, None)
766
- )
767
  if out is not None:
768
  pos_and_ptrs.append((t_diff, out["obj_ptr"]))
769
  # If we have at least one object pointer, add them to the across attention
@@ -776,9 +757,7 @@ class SAM2Base(torch.nn.Module):
776
  if self.add_tpos_enc_to_obj_ptrs:
777
  t_diff_max = max_obj_ptrs_in_encoder - 1
778
  tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
779
- obj_pos = torch.tensor(pos_list).to(
780
- device=device, non_blocking=True
781
- )
782
  obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
783
  obj_pos = self.obj_ptr_tpos_proj(obj_pos)
784
  obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
@@ -786,9 +765,7 @@ class SAM2Base(torch.nn.Module):
786
  obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
787
  if self.mem_dim < C:
788
  # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
789
- obj_ptrs = obj_ptrs.reshape(
790
- -1, B, C // self.mem_dim, self.mem_dim
791
- )
792
  obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
793
  obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
794
  to_cat_memory.append(obj_ptrs)
@@ -841,9 +818,7 @@ class SAM2Base(torch.nn.Module):
841
  # optionally, apply non-overlapping constraints to the masks (it's applied
842
  # in the batch dimension and should only be used during eval, where all
843
  # the objects come from the same video under batch size 1).
844
- pred_masks_high_res = self._apply_non_overlapping_constraints(
845
- pred_masks_high_res
846
- )
847
  # scale the raw mask logits with a temperature before applying sigmoid
848
  binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
849
  if binarize and not self.training:
@@ -856,18 +831,14 @@ class SAM2Base(torch.nn.Module):
856
  mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
857
  if self.sigmoid_bias_for_mem_enc != 0.0:
858
  mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
859
- maskmem_out = self.memory_encoder(
860
- pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied
861
- )
862
  maskmem_features = maskmem_out["vision_features"]
863
  maskmem_pos_enc = maskmem_out["vision_pos_enc"]
864
  # add a no-object embedding to the spatial memory to indicate that the frame
865
  # is predicted to be occluded (i.e. no object is appearing in the frame)
866
  if self.no_obj_embed_spatial is not None:
867
  is_obj_appearing = (object_score_logits > 0).float()
868
- maskmem_features += (
869
- 1 - is_obj_appearing[..., None, None]
870
- ) * self.no_obj_embed_spatial[..., None, None].expand(
871
  *maskmem_features.shape
872
  )
873
 
@@ -891,8 +862,7 @@ class SAM2Base(torch.nn.Module):
891
  # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
892
  if len(current_vision_feats) > 1:
893
  high_res_features = [
894
- x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
895
- for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
896
  ]
897
  else:
898
  high_res_features = None
@@ -901,9 +871,7 @@ class SAM2Base(torch.nn.Module):
901
  # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
902
  pix_feat = current_vision_feats[-1].permute(1, 2, 0)
903
  pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
904
- sam_outputs = self._use_mask_as_output(
905
- pix_feat, high_res_features, mask_inputs
906
- )
907
  else:
908
  # fused the visual feature with previous memory features in the memory bank
909
  pix_feat = self._prepare_memory_conditioned_features(
@@ -994,17 +962,7 @@ class SAM2Base(torch.nn.Module):
994
  prev_sam_mask_logits,
995
  )
996
 
997
- (
998
- _,
999
- _,
1000
- _,
1001
- low_res_masks,
1002
- high_res_masks,
1003
- obj_ptr,
1004
- object_score_logits,
1005
- best_iou_score,
1006
- kf_ious
1007
- ) = sam_outputs
1008
 
1009
  current_out["pred_masks"] = low_res_masks
1010
  current_out["pred_masks_high_res"] = high_res_masks
 
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
 
 
7
  import torch
8
  import torch.distributed
9
  import torch.nn.functional as F
 
10
  from torch.nn.init import trunc_normal_
11
 
12
+ from bboxmaskpose.sam2.modeling.sam2_utils import MLP, get_1d_sine_pe, select_closest_cond_frames
13
+ from bboxmaskpose.sam2.modeling.sam.mask_decoder import MaskDecoder
14
+ from bboxmaskpose.sam2.modeling.sam.pose_encoder import PoseEncoder
15
+ from bboxmaskpose.sam2.modeling.sam.transformer import TwoWayTransformer
16
+ from bboxmaskpose.sam2.utils.kalman_filter import KalmanFilter
17
+ from loguru import logger
18
 
19
  # a large negative value as a placeholder score for missing objects
20
  NO_OBJ_SCORE = -1024.0
 
134
  # Part 3: memory encoder for the previous frame's outputs
135
  self.memory_encoder = memory_encoder
136
  self.mem_dim = self.hidden_dim
137
+ if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
 
 
138
  # if there is compression of memories along channel dim
139
  self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
140
  self.num_maskmem = num_maskmem # Number of memories accessible
141
  # Temporal encoding of the memories
142
+ self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
 
 
143
  trunc_normal_(self.maskmem_tpos_enc, std=0.02)
144
  # a single token to indicate no memory embedding from previous frames
145
  self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
 
198
  self.stable_frames = 0
199
 
200
  # Debug purpose
201
+ self.history = {} # debug
202
+ self.frame_cnt = 0 # debug
203
 
204
  # Hyperparameters for SAMURAI
205
  self.stable_frames_threshold = stable_frames_threshold
 
215
  # Model compilation
216
  if compile_image_encoder:
217
  # Compile the forward function (not the full module) to allow loading checkpoints.
218
+ print("Image encoder compilation is enabled. First forward pass will be slow.")
 
 
219
  self.image_encoder.forward = torch.compile(
220
  self.image_encoder.forward,
221
  mode="max-autotune",
 
271
  # a linear projection on SAM output tokens to turn them into object pointers
272
  self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
273
  if self.use_mlp_for_obj_ptr_proj:
274
+ self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
 
 
275
  else:
276
  self.obj_ptr_proj = torch.nn.Identity()
277
  if self.proj_tpos_enc_in_obj_ptrs:
 
469
  sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
470
 
471
  if False:
472
+ # make all these on cpu
473
  self.history[self.frame_cnt] = {
474
  "kf_predicted_bbox": self.kf.xyah_to_xyxy(self.kf_mean[:4]),
475
  # "multi_masks": high_res_multimasks.cpu(),
 
484
  if ious[0][best_iou_inds] < self.stable_ious_threshold:
485
  self.stable_frames = 0
486
  else:
487
+ self.kf_mean, self.kf_covariance = self.kf.update(
488
+ self.kf_mean, self.kf_covariance, self.kf.xyxy_to_xyah(high_res_multibboxes[best_iou_inds])
489
+ )
490
  elif multimask_output and not self.samurai_mode:
491
  # take the best mask prediction (with the highest IoU estimation)
492
  best_iou_inds = torch.argmax(ious, dim=-1)
 
544
  ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
545
  if not self.use_obj_ptrs_in_encoder:
546
  # all zeros as a dummy object pointer (of shape [B, C])
547
+ obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
 
 
548
  else:
549
  # produce an object pointer using the SAM decoder from the mask input
550
  _, _, _, _, _, obj_ptr, _, _, _ = self._forward_sam_heads(
 
580
  if self.use_high_res_features_in_sam:
581
  # precompute projected level 0 and level 1 features in SAM decoder
582
  # to avoid running it again on every SAM click
583
+ backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
584
+ backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
 
 
 
 
585
  return backbone_out
586
 
587
  def _prepare_backbone_features(self, backbone_out):
 
644
  stride = 1 if self.training else self.memory_temporal_stride_for_eval
645
 
646
  if self.samurai_mode:
647
+ valid_indices = []
648
  if frame_idx > 1: # Ensure we have previous frames to evaluate
649
  for i in range(frame_idx - 1, 1, -1): # Iterate backwards through previous frames
650
  iou_score = output_dict["non_cond_frame_outputs"][i]["best_iou_score"] # Get mask affinity score
651
  obj_score = output_dict["non_cond_frame_outputs"][i]["object_score_logits"] # Get object score
652
+ kf_score = (
653
+ output_dict["non_cond_frame_outputs"][i]["kf_score"]
654
+ if "kf_score" in output_dict["non_cond_frame_outputs"][i]
655
+ else None
656
+ ) # Get motion score if available
657
  # Check if the scores meet the criteria for being a valid index
658
+ if (
659
+ iou_score.item() > self.memory_bank_iou_threshold
660
+ and obj_score.item() > self.memory_bank_obj_score_threshold
661
+ and (kf_score is None or kf_score.item() > self.memory_bank_kf_score_threshold)
662
+ ):
663
+ valid_indices.insert(0, i)
664
  # Check the number of valid indices
665
+ if len(valid_indices) >= self.max_obj_ptrs_in_encoder - 1:
666
  break
667
+ if frame_idx - 1 not in valid_indices:
668
  valid_indices.append(frame_idx - 1)
669
  for t_pos in range(1, self.num_maskmem): # Iterate over the number of mask memories
670
  idx = t_pos - self.num_maskmem # Calculate the index for valid indices
 
717
  maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
718
  maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
719
  # Temporal positional encoding
720
+ maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
 
 
721
  to_cat_memory_pos_embed.append(maskmem_enc)
722
 
723
  # Construct the list of past object pointers
 
727
  # (optionally, only include object pointers in the past during evaluation)
728
  if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
729
  ptr_cond_outputs = {
730
+ t: out for t, out in selected_cond_outputs.items() if (t >= frame_idx if track_in_reverse else t <= frame_idx)
 
 
731
  }
732
  else:
733
  ptr_cond_outputs = selected_cond_outputs
734
  pos_and_ptrs = [
735
  # Temporal pos encoding contains how far away each pointer is from current frame
736
  (
737
+ ((frame_idx - t) * tpos_sign_mul if self.use_signed_tpos_enc_to_obj_ptrs else abs(frame_idx - t)),
 
 
 
 
738
  out["obj_ptr"],
739
  )
740
  for t, out in ptr_cond_outputs.items()
 
744
  t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
745
  if t < 0 or (num_frames is not None and t >= num_frames):
746
  break
747
+ out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None))
 
 
748
  if out is not None:
749
  pos_and_ptrs.append((t_diff, out["obj_ptr"]))
750
  # If we have at least one object pointer, add them to the across attention
 
757
  if self.add_tpos_enc_to_obj_ptrs:
758
  t_diff_max = max_obj_ptrs_in_encoder - 1
759
  tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
760
+ obj_pos = torch.tensor(pos_list).to(device=device, non_blocking=True)
 
 
761
  obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
762
  obj_pos = self.obj_ptr_tpos_proj(obj_pos)
763
  obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
 
765
  obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
766
  if self.mem_dim < C:
767
  # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
768
+ obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
 
 
769
  obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
770
  obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
771
  to_cat_memory.append(obj_ptrs)
 
818
  # optionally, apply non-overlapping constraints to the masks (it's applied
819
  # in the batch dimension and should only be used during eval, where all
820
  # the objects come from the same video under batch size 1).
821
+ pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)
 
 
822
  # scale the raw mask logits with a temperature before applying sigmoid
823
  binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
824
  if binarize and not self.training:
 
831
  mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
832
  if self.sigmoid_bias_for_mem_enc != 0.0:
833
  mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
834
+ maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True) # sigmoid already applied
 
 
835
  maskmem_features = maskmem_out["vision_features"]
836
  maskmem_pos_enc = maskmem_out["vision_pos_enc"]
837
  # add a no-object embedding to the spatial memory to indicate that the frame
838
  # is predicted to be occluded (i.e. no object is appearing in the frame)
839
  if self.no_obj_embed_spatial is not None:
840
  is_obj_appearing = (object_score_logits > 0).float()
841
+ maskmem_features += (1 - is_obj_appearing[..., None, None]) * self.no_obj_embed_spatial[..., None, None].expand(
 
 
842
  *maskmem_features.shape
843
  )
844
 
 
862
  # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
863
  if len(current_vision_feats) > 1:
864
  high_res_features = [
865
+ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
 
866
  ]
867
  else:
868
  high_res_features = None
 
871
  # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
872
  pix_feat = current_vision_feats[-1].permute(1, 2, 0)
873
  pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
874
+ sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
 
 
875
  else:
876
  # fused the visual feature with previous memory features in the memory bank
877
  pix_feat = self._prepare_memory_conditioned_features(
 
962
  prev_sam_mask_logits,
963
  )
964
 
965
+ _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits, best_iou_score, kf_ious = sam_outputs
 
 
 
 
 
 
 
 
 
 
966
 
967
  current_out["pred_masks"] = low_res_masks
968
  current_out["pred_masks_high_res"] = high_res_masks
{sam2 → bboxmaskpose/sam2}/modeling/sam2_utils.py RENAMED
@@ -13,7 +13,7 @@ import torch
13
  import torch.nn as nn
14
  import torch.nn.functional as F
15
 
16
- from sam2.utils.misc import mask_to_box
17
 
18
 
19
  def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
@@ -54,9 +54,7 @@ def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num
54
  key=lambda x: abs(x - frame_idx),
55
  )[:num_remain]
56
  selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
57
- unselected_outputs = {
58
- t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
59
- }
60
 
61
  return selected_outputs, unselected_outputs
62
 
@@ -122,9 +120,7 @@ class MLP(nn.Module):
122
  super().__init__()
123
  self.num_layers = num_layers
124
  h = [hidden_dim] * (num_layers - 1)
125
- self.layers = nn.ModuleList(
126
- nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
127
- )
128
  self.sigmoid_output = sigmoid_output
129
  self.act = activation()
130
 
@@ -175,9 +171,7 @@ def sample_box_points(
175
  device = masks.device
176
  box_coords = mask_to_box(masks)
177
  B, _, H, W = masks.shape
178
- box_labels = torch.tensor(
179
- [top_left_label, bottom_right_label], dtype=torch.int, device=device
180
- ).repeat(B)
181
  if noise > 0.0:
182
  if not isinstance(noise_bound, torch.Tensor):
183
  noise_bound = torch.tensor(noise_bound, device=device)
@@ -189,9 +183,7 @@ def sample_box_points(
189
  box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1)
190
 
191
  box_coords = box_coords + box_noise
192
- img_bounds = (
193
- torch.tensor([W, H, W, H], device=device) - 1
194
- ) # uncentered pixel coords
195
  box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping
196
 
197
  box_coords = box_coords.reshape(-1, 2, 2) # always 2 points
 
13
  import torch.nn as nn
14
  import torch.nn.functional as F
15
 
16
+ from bboxmaskpose.sam2.utils.misc import mask_to_box
17
 
18
 
19
  def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
 
54
  key=lambda x: abs(x - frame_idx),
55
  )[:num_remain]
56
  selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
57
+ unselected_outputs = {t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs}
 
 
58
 
59
  return selected_outputs, unselected_outputs
60
 
 
120
  super().__init__()
121
  self.num_layers = num_layers
122
  h = [hidden_dim] * (num_layers - 1)
123
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
 
 
124
  self.sigmoid_output = sigmoid_output
125
  self.act = activation()
126
 
 
171
  device = masks.device
172
  box_coords = mask_to_box(masks)
173
  B, _, H, W = masks.shape
174
+ box_labels = torch.tensor([top_left_label, bottom_right_label], dtype=torch.int, device=device).repeat(B)
 
 
175
  if noise > 0.0:
176
  if not isinstance(noise_bound, torch.Tensor):
177
  noise_bound = torch.tensor(noise_bound, device=device)
 
183
  box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1)
184
 
185
  box_coords = box_coords + box_noise
186
+ img_bounds = torch.tensor([W, H, W, H], device=device) - 1 # uncentered pixel coords
 
 
187
  box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping
188
 
189
  box_coords = box_coords.reshape(-1, 2, 2) # always 2 points
{sam2 → bboxmaskpose/sam2}/sam2_image_predictor.py RENAMED
@@ -5,16 +5,14 @@
5
  # LICENSE file in the root directory of this source tree.
6
 
7
  import logging
8
-
9
  from typing import List, Optional, Tuple, Union
10
 
11
  import numpy as np
12
  import torch
13
  from PIL.Image import Image
14
 
15
- from sam2.modeling.sam2_base import SAM2Base
16
-
17
- from sam2.utils.transforms import SAM2Transforms
18
 
19
 
20
  class SAM2ImagePredictor:
@@ -61,9 +59,9 @@ class SAM2ImagePredictor:
61
  # Spatial dim for backbone feature maps
62
  isize = self.model.image_size
63
  self._bb_feat_sizes = [
64
- (isize//4, isize//4),
65
- (isize//8, isize//8),
66
- (isize//16, isize//16),
67
  ]
68
 
69
  @classmethod
@@ -78,7 +76,7 @@ class SAM2ImagePredictor:
78
  Returns:
79
  (SAM2ImagePredictor): The loaded model.
80
  """
81
- from sam2.build_sam import build_sam2_hf
82
 
83
  sam_model = build_sam2_hf(model_id, **kwargs)
84
  return cls(sam_model, **kwargs)
@@ -111,9 +109,7 @@ class SAM2ImagePredictor:
111
  input_image = self._transforms(image)
112
  input_image = input_image[None, ...].to(self.device)
113
 
114
- assert (
115
- len(input_image.shape) == 4 and input_image.shape[1] == 3
116
- ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
117
  logging.info("Computing image embeddings for the provided image...")
118
  backbone_out = self.model.forward_image(input_image)
119
  _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
@@ -122,10 +118,9 @@ class SAM2ImagePredictor:
122
  vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
123
 
124
  # breakpoint()
125
- feats = [
126
- feat.permute(1, 2, 0).view(1, -1, *feat_size)
127
- for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
128
- ][::-1]
129
  self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
130
  self._is_image_set = True
131
  logging.info("Image embeddings computed.")
@@ -148,17 +143,13 @@ class SAM2ImagePredictor:
148
  assert isinstance(image_list, list)
149
  self._orig_hw = []
150
  for image in image_list:
151
- assert isinstance(
152
- image, np.ndarray
153
- ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
154
  self._orig_hw.append(image.shape[:2])
155
  # Transform the image to the form expected by the model
156
  img_batch = self._transforms.forward_batch(image_list)
157
  img_batch = img_batch.to(self.device)
158
  batch_size = img_batch.shape[0]
159
- assert (
160
- len(img_batch.shape) == 4 and img_batch.shape[1] == 3
161
- ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
162
  logging.info("Computing image embeddings for the provided images...")
163
  backbone_out = self.model.forward_image(img_batch)
164
  _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
@@ -167,8 +158,7 @@ class SAM2ImagePredictor:
167
  vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
168
 
169
  feats = [
170
- feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
171
- for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
172
  ][::-1]
173
  self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
174
  self._is_image_set = True
@@ -190,25 +180,17 @@ class SAM2ImagePredictor:
190
  """
191
  assert self._is_batch, "This function should only be used when in batched mode"
192
  if not self._is_image_set:
193
- raise RuntimeError(
194
- "An image must be set with .set_image_batch(...) before mask prediction."
195
- )
196
  num_images = len(self._features["image_embed"])
197
  all_masks = []
198
  all_ious = []
199
  all_low_res_masks = []
200
  for img_idx in range(num_images):
201
  # Transform input prompts
202
- point_coords = (
203
- point_coords_batch[img_idx] if point_coords_batch is not None else None
204
- )
205
- point_labels = (
206
- point_labels_batch[img_idx] if point_labels_batch is not None else None
207
- )
208
  box = box_batch[img_idx] if box_batch is not None else None
209
- mask_input = (
210
- mask_input_batch[img_idx] if mask_input_batch is not None else None
211
- )
212
  mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
213
  point_coords,
214
  point_labels,
@@ -227,9 +209,7 @@ class SAM2ImagePredictor:
227
  img_idx=img_idx,
228
  )
229
  masks_np = masks.squeeze(0).float().detach().cpu().numpy()
230
- iou_predictions_np = (
231
- iou_predictions.squeeze(0).float().detach().cpu().numpy()
232
- )
233
  low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
234
  all_masks.append(masks_np)
235
  all_ious.append(iou_predictions_np)
@@ -281,15 +261,11 @@ class SAM2ImagePredictor:
281
  a subsequent iteration as mask input.
282
  """
283
  if not self._is_image_set:
284
- raise RuntimeError(
285
- "An image must be set with .set_image(...) before mask prediction."
286
- )
287
 
288
  # Transform input prompts
289
 
290
- mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
291
- point_coords, point_labels, box, mask_input, normalize_coords
292
- )
293
 
294
  masks, iou_predictions, low_res_masks = self._predict(
295
  unnorm_coords,
@@ -305,33 +281,21 @@ class SAM2ImagePredictor:
305
  low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
306
  return masks_np, iou_predictions_np, low_res_masks_np
307
 
308
- def _prep_prompts(
309
- self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
310
- ):
311
 
312
  unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
313
  if point_coords is not None:
314
- assert (
315
- point_labels is not None
316
- ), "point_labels must be supplied if point_coords is supplied."
317
- point_coords = torch.as_tensor(
318
- point_coords, dtype=torch.float, device=self.device
319
- )
320
- unnorm_coords = self._transforms.transform_coords(
321
- point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
322
- )
323
  labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
324
  if len(unnorm_coords.shape) == 2:
325
  unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
326
  if box is not None:
327
  box = torch.as_tensor(box, dtype=torch.float, device=self.device)
328
- unnorm_box = self._transforms.transform_boxes(
329
- box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
330
- ) # Bx2x2
331
  if mask_logits is not None:
332
- mask_input = torch.as_tensor(
333
- mask_logits, dtype=torch.float, device=self.device
334
- )
335
  if len(mask_input.shape) == 3:
336
  mask_input = mask_input[None, :, :, :]
337
  return mask_input, unnorm_coords, labels, unnorm_box
@@ -383,9 +347,7 @@ class SAM2ImagePredictor:
383
  a subsequent iteration as mask input.
384
  """
385
  if not self._is_image_set:
386
- raise RuntimeError(
387
- "An image must be set with .set_image(...) before mask prediction."
388
- )
389
 
390
  if point_coords is not None:
391
  concat_points = (point_coords, point_labels)
@@ -413,13 +375,8 @@ class SAM2ImagePredictor:
413
  )
414
 
415
  # Predict masks
416
- batched_mode = (
417
- concat_points is not None and concat_points[0].shape[0] > 1
418
- ) # multi object prediction
419
- high_res_features = [
420
- feat_level[img_idx].unsqueeze(0)
421
- for feat_level in self._features["high_res_feats"]
422
- ]
423
  low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
424
  image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
425
  image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
@@ -431,9 +388,7 @@ class SAM2ImagePredictor:
431
  )
432
 
433
  # Upscale the masks to the original image resolution
434
- masks = self._transforms.postprocess_masks(
435
- low_res_masks, self._orig_hw[img_idx]
436
- )
437
  low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
438
  if not return_logits:
439
  masks = masks > self.mask_threshold
@@ -447,12 +402,8 @@ class SAM2ImagePredictor:
447
  the embedding spatial dimension of SAM (typically C=256, H=W=64).
448
  """
449
  if not self._is_image_set:
450
- raise RuntimeError(
451
- "An image must be set with .set_image(...) to generate an embedding."
452
- )
453
- assert (
454
- self._features is not None
455
- ), "Features must exist if an image has been set."
456
  return self._features["image_embed"]
457
 
458
  @property
 
5
  # LICENSE file in the root directory of this source tree.
6
 
7
  import logging
 
8
  from typing import List, Optional, Tuple, Union
9
 
10
  import numpy as np
11
  import torch
12
  from PIL.Image import Image
13
 
14
+ from bboxmaskpose.sam2.modeling.sam2_base import SAM2Base
15
+ from bboxmaskpose.sam2.utils.transforms import SAM2Transforms
 
16
 
17
 
18
  class SAM2ImagePredictor:
 
59
  # Spatial dim for backbone feature maps
60
  isize = self.model.image_size
61
  self._bb_feat_sizes = [
62
+ (isize // 4, isize // 4),
63
+ (isize // 8, isize // 8),
64
+ (isize // 16, isize // 16),
65
  ]
66
 
67
  @classmethod
 
76
  Returns:
77
  (SAM2ImagePredictor): The loaded model.
78
  """
79
+ from bboxmaskpose.sam2.build_sam import build_sam2_hf
80
 
81
  sam_model = build_sam2_hf(model_id, **kwargs)
82
  return cls(sam_model, **kwargs)
 
109
  input_image = self._transforms(image)
110
  input_image = input_image[None, ...].to(self.device)
111
 
112
+ assert len(input_image.shape) == 4 and input_image.shape[1] == 3, f"input_image must be of size 1x3xHxW, got {input_image.shape}"
 
 
113
  logging.info("Computing image embeddings for the provided image...")
114
  backbone_out = self.model.forward_image(input_image)
115
  _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
 
118
  vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
119
 
120
  # breakpoint()
121
+ feats = [feat.permute(1, 2, 0).view(1, -1, *feat_size) for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])][
122
+ ::-1
123
+ ]
 
124
  self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
125
  self._is_image_set = True
126
  logging.info("Image embeddings computed.")
 
143
  assert isinstance(image_list, list)
144
  self._orig_hw = []
145
  for image in image_list:
146
+ assert isinstance(image, np.ndarray), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
 
 
147
  self._orig_hw.append(image.shape[:2])
148
  # Transform the image to the form expected by the model
149
  img_batch = self._transforms.forward_batch(image_list)
150
  img_batch = img_batch.to(self.device)
151
  batch_size = img_batch.shape[0]
152
+ assert len(img_batch.shape) == 4 and img_batch.shape[1] == 3, f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
 
 
153
  logging.info("Computing image embeddings for the provided images...")
154
  backbone_out = self.model.forward_image(img_batch)
155
  _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
 
158
  vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
159
 
160
  feats = [
161
+ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
 
162
  ][::-1]
163
  self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
164
  self._is_image_set = True
 
180
  """
181
  assert self._is_batch, "This function should only be used when in batched mode"
182
  if not self._is_image_set:
183
+ raise RuntimeError("An image must be set with .set_image_batch(...) before mask prediction.")
 
 
184
  num_images = len(self._features["image_embed"])
185
  all_masks = []
186
  all_ious = []
187
  all_low_res_masks = []
188
  for img_idx in range(num_images):
189
  # Transform input prompts
190
+ point_coords = point_coords_batch[img_idx] if point_coords_batch is not None else None
191
+ point_labels = point_labels_batch[img_idx] if point_labels_batch is not None else None
 
 
 
 
192
  box = box_batch[img_idx] if box_batch is not None else None
193
+ mask_input = mask_input_batch[img_idx] if mask_input_batch is not None else None
 
 
194
  mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
195
  point_coords,
196
  point_labels,
 
209
  img_idx=img_idx,
210
  )
211
  masks_np = masks.squeeze(0).float().detach().cpu().numpy()
212
+ iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy()
 
 
213
  low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
214
  all_masks.append(masks_np)
215
  all_ious.append(iou_predictions_np)
 
261
  a subsequent iteration as mask input.
262
  """
263
  if not self._is_image_set:
264
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
 
 
265
 
266
  # Transform input prompts
267
 
268
+ mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(point_coords, point_labels, box, mask_input, normalize_coords)
 
 
269
 
270
  masks, iou_predictions, low_res_masks = self._predict(
271
  unnorm_coords,
 
281
  low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
282
  return masks_np, iou_predictions_np, low_res_masks_np
283
 
284
+ def _prep_prompts(self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1):
 
 
285
 
286
  unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
287
  if point_coords is not None:
288
+ assert point_labels is not None, "point_labels must be supplied if point_coords is supplied."
289
+ point_coords = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
290
+ unnorm_coords = self._transforms.transform_coords(point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx])
 
 
 
 
 
 
291
  labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
292
  if len(unnorm_coords.shape) == 2:
293
  unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
294
  if box is not None:
295
  box = torch.as_tensor(box, dtype=torch.float, device=self.device)
296
+ unnorm_box = self._transforms.transform_boxes(box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]) # Bx2x2
 
 
297
  if mask_logits is not None:
298
+ mask_input = torch.as_tensor(mask_logits, dtype=torch.float, device=self.device)
 
 
299
  if len(mask_input.shape) == 3:
300
  mask_input = mask_input[None, :, :, :]
301
  return mask_input, unnorm_coords, labels, unnorm_box
 
347
  a subsequent iteration as mask input.
348
  """
349
  if not self._is_image_set:
350
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
 
 
351
 
352
  if point_coords is not None:
353
  concat_points = (point_coords, point_labels)
 
375
  )
376
 
377
  # Predict masks
378
+ batched_mode = concat_points is not None and concat_points[0].shape[0] > 1 # multi object prediction
379
+ high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in self._features["high_res_feats"]]
 
 
 
 
 
380
  low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
381
  image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
382
  image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
 
388
  )
389
 
390
  # Upscale the masks to the original image resolution
391
+ masks = self._transforms.postprocess_masks(low_res_masks, self._orig_hw[img_idx])
 
 
392
  low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
393
  if not return_logits:
394
  masks = masks > self.mask_threshold
 
402
  the embedding spatial dimension of SAM (typically C=256, H=W=64).
403
  """
404
  if not self._is_image_set:
405
+ raise RuntimeError("An image must be set with .set_image(...) to generate an embedding.")
406
+ assert self._features is not None, "Features must exist if an image has been set."
 
 
 
 
407
  return self._features["image_embed"]
408
 
409
  @property
{sam2 → bboxmaskpose/sam2}/sam2_video_predictor.py RENAMED
@@ -9,7 +9,6 @@ from collections import OrderedDict
9
 
10
  import torch
11
  import torch.nn.functional as F
12
-
13
  from tqdm import tqdm
14
 
15
  from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
@@ -27,11 +26,6 @@ class SAM2VideoPredictor(SAM2Base):
27
  # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
28
  # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True)
29
  clear_non_cond_mem_around_input=False,
30
- <<<<<<< HEAD
31
- # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
32
- clear_non_cond_mem_for_multi_obj=False,
33
- =======
34
- >>>>>>> 2b90b9f5ceec907a1c18123530e92e794ad901a4
35
  # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
36
  # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
37
  add_all_frames_to_correct_as_cond=False,
@@ -41,10 +35,6 @@ class SAM2VideoPredictor(SAM2Base):
41
  self.fill_hole_area = fill_hole_area
42
  self.non_overlap_masks = non_overlap_masks
43
  self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
44
- <<<<<<< HEAD
45
- self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
46
- =======
47
- >>>>>>> 2b90b9f5ceec907a1c18123530e92e794ad901a4
48
  self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
49
 
50
  @torch.inference_mode()
@@ -296,9 +286,7 @@ class SAM2VideoPredictor(SAM2Base):
296
  is_cond=is_cond,
297
  consolidate_at_video_res=True,
298
  )
299
- _, video_res_masks = self._get_orig_video_res_output(
300
- inference_state, consolidated_out["pred_masks_video_res"]
301
- )
302
  return frame_idx, obj_ids, video_res_masks
303
 
304
  def add_new_points(self, *args, **kwargs):
@@ -384,9 +372,7 @@ class SAM2VideoPredictor(SAM2Base):
384
  is_cond=is_cond,
385
  consolidate_at_video_res=True,
386
  )
387
- _, video_res_masks = self._get_orig_video_res_output(
388
- inference_state, consolidated_out["pred_masks_video_res"]
389
- )
390
  return frame_idx, obj_ids, video_res_masks
391
 
392
  def _get_orig_video_res_output(self, inference_state, any_res_masks):
@@ -450,23 +436,6 @@ class SAM2VideoPredictor(SAM2Base):
450
  dtype=torch.float32,
451
  device=inference_state["storage_device"],
452
  ),
453
- <<<<<<< HEAD
454
- "obj_ptr": torch.full(
455
- size=(batch_size, self.hidden_dim),
456
- fill_value=NO_OBJ_SCORE,
457
- dtype=torch.float32,
458
- device=inference_state["device"],
459
- ),
460
- "object_score_logits": torch.full(
461
- size=(batch_size, 1),
462
- # default to 10.0 for object_score_logits, i.e. assuming the object is
463
- # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
464
- fill_value=10.0,
465
- dtype=torch.float32,
466
- device=inference_state["device"],
467
- ),
468
- =======
469
- >>>>>>> 2b90b9f5ceec907a1c18123530e92e794ad901a4
470
  }
471
  for obj_idx in range(batch_size):
472
  obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
@@ -499,36 +468,6 @@ class SAM2VideoPredictor(SAM2Base):
499
  align_corners=False,
500
  )
501
  consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
502
- <<<<<<< HEAD
503
- consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
504
- consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
505
- "object_score_logits"
506
- ]
507
-
508
- # Optionally, apply non-overlapping constraints on the consolidated scores
509
- # and rerun the memory encoder
510
- if run_mem_encoder:
511
- device = inference_state["device"]
512
- high_res_masks = torch.nn.functional.interpolate(
513
- consolidated_out["pred_masks"].to(device, non_blocking=True),
514
- size=(self.image_size, self.image_size),
515
- mode="bilinear",
516
- align_corners=False,
517
- )
518
- if self.non_overlap_masks_for_mem_enc:
519
- high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
520
- maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
521
- inference_state=inference_state,
522
- frame_idx=frame_idx,
523
- batch_size=batch_size,
524
- high_res_masks=high_res_masks,
525
- object_score_logits=consolidated_out["object_score_logits"],
526
- is_mask_from_pts=True, # these frames are what the user interacted with
527
- )
528
- consolidated_out["maskmem_features"] = maskmem_features
529
- consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
530
- =======
531
- >>>>>>> 2b90b9f5ceec907a1c18123530e92e794ad901a4
532
 
533
  return consolidated_out
534
 
@@ -538,9 +477,7 @@ class SAM2VideoPredictor(SAM2Base):
538
  # Check and make sure that every object has received input points or masks.
539
  batch_size = self._get_obj_num(inference_state)
540
  if batch_size == 0:
541
- raise RuntimeError(
542
- "No input points or masks are provided for any object; please add inputs first."
543
- )
544
 
545
  # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
546
  # add them into "output_dict".
@@ -549,9 +486,7 @@ class SAM2VideoPredictor(SAM2Base):
549
  obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
550
  for is_cond in [False, True]:
551
  # Separately consolidate conditioning and non-conditioning temp outputs
552
- storage_key = (
553
- "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
554
- )
555
  # Find all the frames that contain temporary outputs for any objects
556
  # (these should be the frames that have just received clicks for mask inputs
557
  # via `add_new_points_or_box` or `add_new_mask`)
@@ -579,9 +514,7 @@ class SAM2VideoPredictor(SAM2Base):
579
  obj_output_dict[storage_key][frame_idx] = out
580
  if self.clear_non_cond_mem_around_input:
581
  # clear non-conditioning memory of the surrounding frames
582
- self._clear_obj_non_cond_mem_around_input(
583
- inference_state, frame_idx, obj_idx
584
- )
585
 
586
  # clear temporary outputs in `temp_output_dict_per_obj`
587
  obj_temp_output_dict[storage_key].clear()
@@ -590,9 +523,7 @@ class SAM2VideoPredictor(SAM2Base):
590
  obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
591
  if len(obj_output_dict["cond_frame_outputs"]) == 0:
592
  obj_id = self._obj_idx_to_id(inference_state, obj_idx)
593
- raise RuntimeError(
594
- f"No input points or masks are provided for object id {obj_id}; please add inputs first."
595
- )
596
  # edge case: if an output is added to "cond_frame_outputs", we remove any prior
597
  # output on the same frame in "non_cond_frame_outputs"
598
  for frame_idx in obj_output_dict["cond_frame_outputs"]:
@@ -617,9 +548,7 @@ class SAM2VideoPredictor(SAM2Base):
617
  if start_frame_idx is None:
618
  # default: start from the earliest frame with input points
619
  start_frame_idx = min(
620
- t
621
- for obj_output_dict in inference_state["output_dict_per_obj"].values()
622
- for t in obj_output_dict["cond_frame_outputs"]
623
  )
624
  if max_frame_num_to_track is None:
625
  # default: track all the frames in the video
@@ -631,9 +560,7 @@ class SAM2VideoPredictor(SAM2Base):
631
  else:
632
  processing_order = [] # skip reverse tracking if starting from frame 0
633
  else:
634
- end_frame_idx = min(
635
- start_frame_idx + max_frame_num_to_track, num_frames - 1
636
- )
637
  processing_order = range(start_frame_idx, end_frame_idx + 1)
638
 
639
  for frame_idx in tqdm(processing_order, desc="propagate in video"):
@@ -651,9 +578,7 @@ class SAM2VideoPredictor(SAM2Base):
651
  pred_masks = current_out["pred_masks"].to(device, non_blocking=True)
652
  if self.clear_non_cond_mem_around_input:
653
  # clear non-conditioning memory of the surrounding frames
654
- self._clear_obj_non_cond_mem_around_input(
655
- inference_state, frame_idx, obj_idx
656
- )
657
  else:
658
  storage_key = "non_cond_frame_outputs"
659
  current_out, pred_masks = self._run_single_frame_inference(
@@ -669,9 +594,7 @@ class SAM2VideoPredictor(SAM2Base):
669
  )
670
  obj_output_dict[storage_key][frame_idx] = current_out
671
 
672
- inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = {
673
- "reverse": reverse
674
- }
675
  pred_masks_per_obj[obj_idx] = pred_masks
676
 
677
  # Resize the output mask to the original video resolution (we directly use
@@ -680,42 +603,11 @@ class SAM2VideoPredictor(SAM2Base):
680
  all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
681
  else:
682
  all_pred_masks = pred_masks_per_obj[0]
683
- _, video_res_masks = self._get_orig_video_res_output(
684
- inference_state, all_pred_masks
685
- )
686
  yield frame_idx, obj_ids, video_res_masks
687
 
688
  @torch.inference_mode()
689
- def clear_all_prompts_in_frame(
690
- self, inference_state, frame_idx, obj_id, need_output=True
691
- ):
692
- <<<<<<< HEAD
693
- """
694
- Split a multi-object output into per-object output slices and add them into
695
- `output_dict_per_obj`. The resulting slices share the same tensor storage.
696
- """
697
- maskmem_features = current_out["maskmem_features"]
698
- assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
699
-
700
- maskmem_pos_enc = current_out["maskmem_pos_enc"]
701
- assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
702
-
703
- output_dict_per_obj = inference_state["output_dict_per_obj"]
704
- for obj_idx, obj_output_dict in output_dict_per_obj.items():
705
- obj_slice = slice(obj_idx, obj_idx + 1)
706
- obj_out = {
707
- "maskmem_features": None,
708
- "maskmem_pos_enc": None,
709
- "pred_masks": current_out["pred_masks"][obj_slice],
710
- "obj_ptr": current_out["obj_ptr"][obj_slice],
711
- "object_score_logits": current_out["object_score_logits"][obj_slice],
712
- }
713
- if maskmem_features is not None:
714
- obj_out["maskmem_features"] = maskmem_features[obj_slice]
715
- if maskmem_pos_enc is not None:
716
- obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
717
- obj_output_dict[storage_key][frame_idx] = obj_out
718
- =======
719
  """Remove all input points or mask in a specific frame for a given object."""
720
  obj_idx = self._obj_id_to_idx(inference_state, obj_id)
721
 
@@ -740,91 +632,14 @@ class SAM2VideoPredictor(SAM2Base):
740
  return
741
  # Finally, output updated masks per object (after removing the inputs above)
742
  obj_ids = inference_state["obj_ids"]
743
- is_cond = any(
744
- frame_idx in obj_temp_output_dict["cond_frame_outputs"]
745
- for obj_temp_output_dict in temp_output_dict_per_obj.values()
746
- )
747
  consolidated_out = self._consolidate_temp_output_across_obj(
748
  inference_state,
749
  frame_idx,
750
  is_cond=is_cond,
751
  consolidate_at_video_res=True,
752
  )
753
- _, video_res_masks = self._get_orig_video_res_output(
754
- inference_state, consolidated_out["pred_masks_video_res"]
755
- )
756
- return frame_idx, obj_ids, video_res_masks
757
- >>>>>>> 2b90b9f5ceec907a1c18123530e92e794ad901a4
758
-
759
- @torch.inference_mode()
760
- def clear_all_prompts_in_frame(
761
- self, inference_state, frame_idx, obj_id, need_output=True
762
- ):
763
- """Remove all input points or mask in a specific frame for a given object."""
764
- obj_idx = self._obj_id_to_idx(inference_state, obj_id)
765
-
766
- # Clear the conditioning information on the given frame
767
- inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)
768
- inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)
769
-
770
- temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
771
- temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
772
- temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
773
-
774
- # Check and see if there are still any inputs left on this frame
775
- batch_size = self._get_obj_num(inference_state)
776
- frame_has_input = False
777
- for obj_idx2 in range(batch_size):
778
- if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
779
- frame_has_input = True
780
- break
781
- if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
782
- frame_has_input = True
783
- break
784
-
785
- # If this frame has no remaining inputs for any objects, we further clear its
786
- # conditioning frame status
787
- if not frame_has_input:
788
- output_dict = inference_state["output_dict"]
789
- consolidated_frame_inds = inference_state["consolidated_frame_inds"]
790
- consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
791
- consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
792
- # Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
793
- out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
794
- if out is not None:
795
- # The frame is not a conditioning frame anymore since it's not receiving inputs,
796
- # so we "downgrade" its output (if exists) to a non-conditioning frame output.
797
- output_dict["non_cond_frame_outputs"][frame_idx] = out
798
- inference_state["frames_already_tracked"].pop(frame_idx, None)
799
- # Similarly, do it for the sliced output on each object.
800
- for obj_idx2 in range(batch_size):
801
- obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
802
- obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
803
- if obj_out is not None:
804
- obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out
805
-
806
- # If all the conditioning frames have been removed, we also clear the tracking outputs
807
- if len(output_dict["cond_frame_outputs"]) == 0:
808
- self._reset_tracking_results(inference_state)
809
-
810
- if not need_output:
811
- return
812
- # Finally, output updated masks per object (after removing the inputs above)
813
- obj_ids = inference_state["obj_ids"]
814
- is_cond = any(
815
- frame_idx in obj_temp_output_dict["cond_frame_outputs"]
816
- for obj_temp_output_dict in temp_output_dict_per_obj.values()
817
- )
818
- consolidated_out = self._consolidate_temp_output_across_obj(
819
- inference_state,
820
- frame_idx,
821
- is_cond=is_cond,
822
- run_mem_encoder=False,
823
- consolidate_at_video_res=True,
824
- )
825
- _, video_res_masks = self._get_orig_video_res_output(
826
- inference_state, consolidated_out["pred_masks_video_res"]
827
- )
828
  return frame_idx, obj_ids, video_res_masks
829
 
830
  @torch.inference_mode()
@@ -859,9 +674,7 @@ class SAM2VideoPredictor(SAM2Base):
859
  def _get_image_feature(self, inference_state, frame_idx, batch_size):
860
  """Compute the image features on a given frame."""
861
  # Look up in the cache first
862
- image, backbone_out = inference_state["cached_features"].get(
863
- frame_idx, (None, None)
864
- )
865
  if backbone_out is None:
866
  # Cache miss -- we will run inference on a single image
867
  device = inference_state["device"]
@@ -878,9 +691,7 @@ class SAM2VideoPredictor(SAM2Base):
878
  "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
879
  }
880
  for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
881
- expanded_backbone_out["backbone_fpn"][i] = feat.expand(
882
- batch_size, -1, -1, -1
883
- )
884
  for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
885
  pos = pos.expand(batch_size, -1, -1, -1)
886
  expanded_backbone_out["vision_pos_enc"][i] = pos
@@ -935,33 +746,23 @@ class SAM2VideoPredictor(SAM2Base):
935
  if maskmem_features is not None:
936
  maskmem_features = maskmem_features.to(torch.bfloat16)
937
  maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
938
- pred_masks_gpu = current_out["pred_masks"] # (B, 1, H, W)
939
  # potentially fill holes in the predicted masks
940
  if self.fill_hole_area > 0:
941
- pred_masks_gpu = fill_holes_in_mask_scores(
942
- pred_masks_gpu, self.fill_hole_area
943
- )
944
  pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
945
  # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
946
  maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
947
  # object pointer is a small tensor, so we always keep it on GPU memory for fast access
948
  obj_ptr = current_out["obj_ptr"]
949
  object_score_logits = current_out["object_score_logits"]
950
- <<<<<<< HEAD
951
- best_iou_score = current_out["best_iou_score"]
952
- =======
953
- >>>>>>> 2b90b9f5ceec907a1c18123530e92e794ad901a4
954
  # make a compact version of this frame's output to reduce the state size
955
  compact_current_out = {
956
- "maskmem_features": maskmem_features, # (B, C, H, W)
957
- "maskmem_pos_enc": maskmem_pos_enc,
958
  "pred_masks": pred_masks,
959
  "obj_ptr": obj_ptr,
960
  "object_score_logits": object_score_logits,
961
- <<<<<<< HEAD
962
- "best_iou_score": best_iou_score,
963
- =======
964
- >>>>>>> 2b90b9f5ceec907a1c18123530e92e794ad901a4
965
  }
966
  return compact_current_out, pred_masks_gpu
967
 
@@ -980,9 +781,7 @@ class SAM2VideoPredictor(SAM2Base):
980
  memory also need to be computed again with the memory encoder.
981
  """
982
  # Retrieve correct image features
983
- _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
984
- inference_state, frame_idx, batch_size
985
- )
986
  maskmem_features, maskmem_pos_enc = self._encode_new_memory(
987
  current_vision_feats=current_vision_feats,
988
  feat_sizes=feat_sizes,
@@ -996,9 +795,7 @@ class SAM2VideoPredictor(SAM2Base):
996
  maskmem_features = maskmem_features.to(torch.bfloat16)
997
  maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
998
  # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
999
- maskmem_pos_enc = self._get_maskmem_pos_enc(
1000
- inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
1001
- )
1002
  return maskmem_features, maskmem_pos_enc
1003
 
1004
  def _get_maskmem_pos_enc(self, inference_state, current_out):
@@ -1019,9 +816,7 @@ class SAM2VideoPredictor(SAM2Base):
1019
  maskmem_pos_enc = model_constants["maskmem_pos_enc"]
1020
  # expand the cached maskmem_pos_enc to the actual batch size
1021
  batch_size = out_maskmem_pos_enc[0].size(0)
1022
- expanded_maskmem_pos_enc = [
1023
- x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
1024
- ]
1025
  else:
1026
  expanded_maskmem_pos_enc = None
1027
  return expanded_maskmem_pos_enc
@@ -1039,8 +834,7 @@ class SAM2VideoPredictor(SAM2Base):
1039
  if not strict:
1040
  return inference_state["obj_ids"], updated_frames
1041
  raise RuntimeError(
1042
- f"Cannot remove object id {obj_id} as it doesn't exist. "
1043
- f"All existing object ids: {inference_state['obj_ids']}."
1044
  )
1045
 
1046
  # If this is the only remaining object id, we simply reset the state.
@@ -1054,16 +848,10 @@ class SAM2VideoPredictor(SAM2Base):
1054
  # (note that this step is required as it might downgrade conditioning frames to
1055
  # non-conditioning ones)
1056
  obj_input_frames_inds = set()
1057
- obj_input_frames_inds.update(
1058
- inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]
1059
- )
1060
- obj_input_frames_inds.update(
1061
- inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
1062
- )
1063
  for frame_idx in obj_input_frames_inds:
1064
- self.clear_all_prompts_in_frame(
1065
- inference_state, frame_idx, obj_id, need_output=False
1066
- )
1067
 
1068
  # Step 1: Update the object id mapping (note that it must be done after Step 0,
1069
  # since Step 0 still requires the old object id mappings in inference_state)
@@ -1080,11 +868,6 @@ class SAM2VideoPredictor(SAM2Base):
1080
  inference_state["obj_ids"] = new_obj_ids
1081
 
1082
  # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
1083
- <<<<<<< HEAD
1084
- # (note that "consolidated_frame_inds" doesn't need to be updated in this step as
1085
- # it's already handled in Step 0)
1086
- =======
1087
- >>>>>>> 2b90b9f5ceec907a1c18123530e92e794ad901a4
1088
  def _map_keys(container):
1089
  new_kvs = []
1090
  for k in old_obj_inds:
@@ -1097,57 +880,23 @@ class SAM2VideoPredictor(SAM2Base):
1097
  _map_keys(inference_state["mask_inputs_per_obj"])
1098
  _map_keys(inference_state["output_dict_per_obj"])
1099
  _map_keys(inference_state["temp_output_dict_per_obj"])
1100
- <<<<<<< HEAD
1101
-
1102
- # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices.
1103
- def _slice_state(output_dict, storage_key):
1104
- for frame_idx, out in output_dict[storage_key].items():
1105
- out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
1106
- out["maskmem_pos_enc"] = [
1107
- x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
1108
- ]
1109
- # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
1110
- out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
1111
- out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
1112
- out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
1113
- out["object_score_logits"] = out["object_score_logits"][
1114
- remain_old_obj_inds
1115
- ]
1116
- # also update the per-object slices
1117
- self._add_output_per_object(
1118
- inference_state, frame_idx, out, storage_key
1119
- )
1120
-
1121
- _slice_state(inference_state["output_dict"], "cond_frame_outputs")
1122
- _slice_state(inference_state["output_dict"], "non_cond_frame_outputs")
1123
-
1124
- # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which
1125
- =======
1126
  _map_keys(inference_state["frames_tracked_per_obj"])
1127
 
1128
  # Step 3: Further collect the outputs on those frames in `obj_input_frames_inds`, which
1129
- >>>>>>> 2b90b9f5ceec907a1c18123530e92e794ad901a4
1130
  # could show an updated mask for objects previously occluded by the object being removed
1131
  if need_output:
1132
  temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
1133
  for frame_idx in obj_input_frames_inds:
1134
  is_cond = any(
1135
- frame_idx in obj_temp_output_dict["cond_frame_outputs"]
1136
- for obj_temp_output_dict in temp_output_dict_per_obj.values()
1137
  )
1138
  consolidated_out = self._consolidate_temp_output_across_obj(
1139
  inference_state,
1140
  frame_idx,
1141
  is_cond=is_cond,
1142
- <<<<<<< HEAD
1143
- run_mem_encoder=False,
1144
- =======
1145
- >>>>>>> 2b90b9f5ceec907a1c18123530e92e794ad901a4
1146
  consolidate_at_video_res=True,
1147
  )
1148
- _, video_res_masks = self._get_orig_video_res_output(
1149
- inference_state, consolidated_out["pred_masks_video_res"]
1150
- )
1151
  updated_frames.append((frame_idx, video_res_masks))
1152
 
1153
  return inference_state["obj_ids"], updated_frames
@@ -1218,18 +967,12 @@ class SAM2VideoPredictorVOS(SAM2VideoPredictor):
1218
  if self.use_high_res_features_in_sam:
1219
  # precompute projected level 0 and level 1 features in SAM decoder
1220
  # to avoid running it again on every SAM click
1221
- backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
1222
- backbone_out["backbone_fpn"][0]
1223
- )
1224
- backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
1225
- backbone_out["backbone_fpn"][1]
1226
- )
1227
  # Clone to help torch.compile
1228
  for i in range(len(backbone_out["backbone_fpn"])):
1229
  backbone_out["backbone_fpn"][i] = backbone_out["backbone_fpn"][i].clone()
1230
- backbone_out["vision_pos_enc"][i] = backbone_out["vision_pos_enc"][
1231
- i
1232
- ].clone()
1233
  return backbone_out
1234
 
1235
  def _forward_sam_heads(
@@ -1388,9 +1131,7 @@ class SAM2VideoPredictorVOS(SAM2VideoPredictor):
1388
  # optionally, apply non-overlapping constraints to the masks (it's applied
1389
  # in the batch dimension and should only be used during eval, where all
1390
  # the objects come from the same video under batch size 1).
1391
- pred_masks_high_res = self._apply_non_overlapping_constraints(
1392
- pred_masks_high_res
1393
- )
1394
  # scale the raw mask logits with a temperature before applying sigmoid
1395
  binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
1396
  if binarize and not self.training:
@@ -1403,9 +1144,7 @@ class SAM2VideoPredictorVOS(SAM2VideoPredictor):
1403
  mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
1404
  if self.sigmoid_bias_for_mem_enc != 0.0:
1405
  mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
1406
- maskmem_out = self.memory_encoder(
1407
- pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied
1408
- )
1409
  # Clone the feats and pos_enc to enable compilation
1410
  maskmem_features = maskmem_out["vision_features"].clone()
1411
  maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]]
@@ -1413,9 +1152,7 @@ class SAM2VideoPredictorVOS(SAM2VideoPredictor):
1413
  # is predicted to be occluded (i.e. no object is appearing in the frame)
1414
  if self.no_obj_embed_spatial is not None:
1415
  is_obj_appearing = (object_score_logits > 0).float()
1416
- maskmem_features += (
1417
- 1 - is_obj_appearing[..., None, None]
1418
- ) * self.no_obj_embed_spatial[..., None, None].expand(
1419
  *maskmem_features.shape
1420
  )
1421
 
 
9
 
10
  import torch
11
  import torch.nn.functional as F
 
12
  from tqdm import tqdm
13
 
14
  from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
 
26
  # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
27
  # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True)
28
  clear_non_cond_mem_around_input=False,
 
 
 
 
 
29
  # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
30
  # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
31
  add_all_frames_to_correct_as_cond=False,
 
35
  self.fill_hole_area = fill_hole_area
36
  self.non_overlap_masks = non_overlap_masks
37
  self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
 
 
 
 
38
  self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
39
 
40
  @torch.inference_mode()
 
286
  is_cond=is_cond,
287
  consolidate_at_video_res=True,
288
  )
289
+ _, video_res_masks = self._get_orig_video_res_output(inference_state, consolidated_out["pred_masks_video_res"])
 
 
290
  return frame_idx, obj_ids, video_res_masks
291
 
292
  def add_new_points(self, *args, **kwargs):
 
372
  is_cond=is_cond,
373
  consolidate_at_video_res=True,
374
  )
375
+ _, video_res_masks = self._get_orig_video_res_output(inference_state, consolidated_out["pred_masks_video_res"])
 
 
376
  return frame_idx, obj_ids, video_res_masks
377
 
378
  def _get_orig_video_res_output(self, inference_state, any_res_masks):
 
436
  dtype=torch.float32,
437
  device=inference_state["storage_device"],
438
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  }
440
  for obj_idx in range(batch_size):
441
  obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
 
468
  align_corners=False,
469
  )
470
  consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
 
472
  return consolidated_out
473
 
 
477
  # Check and make sure that every object has received input points or masks.
478
  batch_size = self._get_obj_num(inference_state)
479
  if batch_size == 0:
480
+ raise RuntimeError("No input points or masks are provided for any object; please add inputs first.")
 
 
481
 
482
  # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
483
  # add them into "output_dict".
 
486
  obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
487
  for is_cond in [False, True]:
488
  # Separately consolidate conditioning and non-conditioning temp outputs
489
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
 
 
490
  # Find all the frames that contain temporary outputs for any objects
491
  # (these should be the frames that have just received clicks for mask inputs
492
  # via `add_new_points_or_box` or `add_new_mask`)
 
514
  obj_output_dict[storage_key][frame_idx] = out
515
  if self.clear_non_cond_mem_around_input:
516
  # clear non-conditioning memory of the surrounding frames
517
+ self._clear_obj_non_cond_mem_around_input(inference_state, frame_idx, obj_idx)
 
 
518
 
519
  # clear temporary outputs in `temp_output_dict_per_obj`
520
  obj_temp_output_dict[storage_key].clear()
 
523
  obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
524
  if len(obj_output_dict["cond_frame_outputs"]) == 0:
525
  obj_id = self._obj_idx_to_id(inference_state, obj_idx)
526
+ raise RuntimeError(f"No input points or masks are provided for object id {obj_id}; please add inputs first.")
 
 
527
  # edge case: if an output is added to "cond_frame_outputs", we remove any prior
528
  # output on the same frame in "non_cond_frame_outputs"
529
  for frame_idx in obj_output_dict["cond_frame_outputs"]:
 
548
  if start_frame_idx is None:
549
  # default: start from the earliest frame with input points
550
  start_frame_idx = min(
551
+ t for obj_output_dict in inference_state["output_dict_per_obj"].values() for t in obj_output_dict["cond_frame_outputs"]
 
 
552
  )
553
  if max_frame_num_to_track is None:
554
  # default: track all the frames in the video
 
560
  else:
561
  processing_order = [] # skip reverse tracking if starting from frame 0
562
  else:
563
+ end_frame_idx = min(start_frame_idx + max_frame_num_to_track, num_frames - 1)
 
 
564
  processing_order = range(start_frame_idx, end_frame_idx + 1)
565
 
566
  for frame_idx in tqdm(processing_order, desc="propagate in video"):
 
578
  pred_masks = current_out["pred_masks"].to(device, non_blocking=True)
579
  if self.clear_non_cond_mem_around_input:
580
  # clear non-conditioning memory of the surrounding frames
581
+ self._clear_obj_non_cond_mem_around_input(inference_state, frame_idx, obj_idx)
 
 
582
  else:
583
  storage_key = "non_cond_frame_outputs"
584
  current_out, pred_masks = self._run_single_frame_inference(
 
594
  )
595
  obj_output_dict[storage_key][frame_idx] = current_out
596
 
597
+ inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = {"reverse": reverse}
 
 
598
  pred_masks_per_obj[obj_idx] = pred_masks
599
 
600
  # Resize the output mask to the original video resolution (we directly use
 
603
  all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
604
  else:
605
  all_pred_masks = pred_masks_per_obj[0]
606
+ _, video_res_masks = self._get_orig_video_res_output(inference_state, all_pred_masks)
 
 
607
  yield frame_idx, obj_ids, video_res_masks
608
 
609
  @torch.inference_mode()
610
+ def clear_all_prompts_in_frame(self, inference_state, frame_idx, obj_id, need_output=True):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
611
  """Remove all input points or mask in a specific frame for a given object."""
612
  obj_idx = self._obj_id_to_idx(inference_state, obj_id)
613
 
 
632
  return
633
  # Finally, output updated masks per object (after removing the inputs above)
634
  obj_ids = inference_state["obj_ids"]
635
+ is_cond = any(frame_idx in obj_temp_output_dict["cond_frame_outputs"] for obj_temp_output_dict in temp_output_dict_per_obj.values())
 
 
 
636
  consolidated_out = self._consolidate_temp_output_across_obj(
637
  inference_state,
638
  frame_idx,
639
  is_cond=is_cond,
640
  consolidate_at_video_res=True,
641
  )
642
+ _, video_res_masks = self._get_orig_video_res_output(inference_state, consolidated_out["pred_masks_video_res"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
643
  return frame_idx, obj_ids, video_res_masks
644
 
645
  @torch.inference_mode()
 
674
  def _get_image_feature(self, inference_state, frame_idx, batch_size):
675
  """Compute the image features on a given frame."""
676
  # Look up in the cache first
677
+ image, backbone_out = inference_state["cached_features"].get(frame_idx, (None, None))
 
 
678
  if backbone_out is None:
679
  # Cache miss -- we will run inference on a single image
680
  device = inference_state["device"]
 
691
  "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
692
  }
693
  for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
694
+ expanded_backbone_out["backbone_fpn"][i] = feat.expand(batch_size, -1, -1, -1)
 
 
695
  for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
696
  pos = pos.expand(batch_size, -1, -1, -1)
697
  expanded_backbone_out["vision_pos_enc"][i] = pos
 
746
  if maskmem_features is not None:
747
  maskmem_features = maskmem_features.to(torch.bfloat16)
748
  maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
749
+ pred_masks_gpu = current_out["pred_masks"]
750
  # potentially fill holes in the predicted masks
751
  if self.fill_hole_area > 0:
752
+ pred_masks_gpu = fill_holes_in_mask_scores(pred_masks_gpu, self.fill_hole_area)
 
 
753
  pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
754
  # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
755
  maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
756
  # object pointer is a small tensor, so we always keep it on GPU memory for fast access
757
  obj_ptr = current_out["obj_ptr"]
758
  object_score_logits = current_out["object_score_logits"]
 
 
 
 
759
  # make a compact version of this frame's output to reduce the state size
760
  compact_current_out = {
761
+ "maskmem_features": maskmem_features,
762
+ "maskmem_pos_enc": maskmem_pos_enc,
763
  "pred_masks": pred_masks,
764
  "obj_ptr": obj_ptr,
765
  "object_score_logits": object_score_logits,
 
 
 
 
766
  }
767
  return compact_current_out, pred_masks_gpu
768
 
 
781
  memory also need to be computed again with the memory encoder.
782
  """
783
  # Retrieve correct image features
784
+ _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(inference_state, frame_idx, batch_size)
 
 
785
  maskmem_features, maskmem_pos_enc = self._encode_new_memory(
786
  current_vision_feats=current_vision_feats,
787
  feat_sizes=feat_sizes,
 
795
  maskmem_features = maskmem_features.to(torch.bfloat16)
796
  maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
797
  # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
798
+ maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, {"maskmem_pos_enc": maskmem_pos_enc})
 
 
799
  return maskmem_features, maskmem_pos_enc
800
 
801
  def _get_maskmem_pos_enc(self, inference_state, current_out):
 
816
  maskmem_pos_enc = model_constants["maskmem_pos_enc"]
817
  # expand the cached maskmem_pos_enc to the actual batch size
818
  batch_size = out_maskmem_pos_enc[0].size(0)
819
+ expanded_maskmem_pos_enc = [x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc]
 
 
820
  else:
821
  expanded_maskmem_pos_enc = None
822
  return expanded_maskmem_pos_enc
 
834
  if not strict:
835
  return inference_state["obj_ids"], updated_frames
836
  raise RuntimeError(
837
+ f"Cannot remove object id {obj_id} as it doesn't exist. " f"All existing object ids: {inference_state['obj_ids']}."
 
838
  )
839
 
840
  # If this is the only remaining object id, we simply reset the state.
 
848
  # (note that this step is required as it might downgrade conditioning frames to
849
  # non-conditioning ones)
850
  obj_input_frames_inds = set()
851
+ obj_input_frames_inds.update(inference_state["point_inputs_per_obj"][old_obj_idx_to_rm])
852
+ obj_input_frames_inds.update(inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm])
 
 
 
 
853
  for frame_idx in obj_input_frames_inds:
854
+ self.clear_all_prompts_in_frame(inference_state, frame_idx, obj_id, need_output=False)
 
 
855
 
856
  # Step 1: Update the object id mapping (note that it must be done after Step 0,
857
  # since Step 0 still requires the old object id mappings in inference_state)
 
868
  inference_state["obj_ids"] = new_obj_ids
869
 
870
  # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
 
 
 
 
 
871
  def _map_keys(container):
872
  new_kvs = []
873
  for k in old_obj_inds:
 
880
  _map_keys(inference_state["mask_inputs_per_obj"])
881
  _map_keys(inference_state["output_dict_per_obj"])
882
  _map_keys(inference_state["temp_output_dict_per_obj"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
883
  _map_keys(inference_state["frames_tracked_per_obj"])
884
 
885
  # Step 3: Further collect the outputs on those frames in `obj_input_frames_inds`, which
 
886
  # could show an updated mask for objects previously occluded by the object being removed
887
  if need_output:
888
  temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
889
  for frame_idx in obj_input_frames_inds:
890
  is_cond = any(
891
+ frame_idx in obj_temp_output_dict["cond_frame_outputs"] for obj_temp_output_dict in temp_output_dict_per_obj.values()
 
892
  )
893
  consolidated_out = self._consolidate_temp_output_across_obj(
894
  inference_state,
895
  frame_idx,
896
  is_cond=is_cond,
 
 
 
 
897
  consolidate_at_video_res=True,
898
  )
899
+ _, video_res_masks = self._get_orig_video_res_output(inference_state, consolidated_out["pred_masks_video_res"])
 
 
900
  updated_frames.append((frame_idx, video_res_masks))
901
 
902
  return inference_state["obj_ids"], updated_frames
 
967
  if self.use_high_res_features_in_sam:
968
  # precompute projected level 0 and level 1 features in SAM decoder
969
  # to avoid running it again on every SAM click
970
+ backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
971
+ backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
 
 
 
 
972
  # Clone to help torch.compile
973
  for i in range(len(backbone_out["backbone_fpn"])):
974
  backbone_out["backbone_fpn"][i] = backbone_out["backbone_fpn"][i].clone()
975
+ backbone_out["vision_pos_enc"][i] = backbone_out["vision_pos_enc"][i].clone()
 
 
976
  return backbone_out
977
 
978
  def _forward_sam_heads(
 
1131
  # optionally, apply non-overlapping constraints to the masks (it's applied
1132
  # in the batch dimension and should only be used during eval, where all
1133
  # the objects come from the same video under batch size 1).
1134
+ pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)
 
 
1135
  # scale the raw mask logits with a temperature before applying sigmoid
1136
  binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
1137
  if binarize and not self.training:
 
1144
  mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
1145
  if self.sigmoid_bias_for_mem_enc != 0.0:
1146
  mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
1147
+ maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True) # sigmoid already applied
 
 
1148
  # Clone the feats and pos_enc to enable compilation
1149
  maskmem_features = maskmem_out["vision_features"].clone()
1150
  maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]]
 
1152
  # is predicted to be occluded (i.e. no object is appearing in the frame)
1153
  if self.no_obj_embed_spatial is not None:
1154
  is_obj_appearing = (object_score_logits > 0).float()
1155
+ maskmem_features += (1 - is_obj_appearing[..., None, None]) * self.no_obj_embed_spatial[..., None, None].expand(
 
 
1156
  *maskmem_features.shape
1157
  )
1158