Spaces:
Running on Zero
Running on Zero
xiaoyuxi commited on
Commit ·
e43b66a
1
Parent(s): 4d35051
vggt_da
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitignore +69 -0
- _viz/viz_template.html +1769 -0
- app.py +508 -645
- app_3rd/README.md +12 -0
- app_3rd/sam_utils/hf_sam_predictor.py +129 -0
- app_3rd/sam_utils/inference.py +123 -0
- app_3rd/spatrack_utils/infer_track.py +195 -0
- app_release.py +1278 -0
- config/__init__.py +0 -0
- config/magic_infer_moge.yaml +48 -0
- frontend_app_local.py +1036 -0
- models/SpaTrackV2/models/SpaTrack.py +758 -0
- models/SpaTrackV2/models/__init__.py +0 -0
- models/SpaTrackV2/models/blocks.py +519 -0
- models/SpaTrackV2/models/camera_transform.py +248 -0
- models/SpaTrackV2/models/depth_refiner/backbone.py +472 -0
- models/SpaTrackV2/models/depth_refiner/decode_head.py +619 -0
- models/SpaTrackV2/models/depth_refiner/depth_refiner.py +115 -0
- models/SpaTrackV2/models/depth_refiner/network.py +429 -0
- models/SpaTrackV2/models/depth_refiner/stablilization_attention.py +1187 -0
- models/SpaTrackV2/models/depth_refiner/stablizer.py +342 -0
- models/SpaTrackV2/models/predictor.py +153 -0
- models/SpaTrackV2/models/tracker3D/TrackRefiner.py +1478 -0
- models/SpaTrackV2/models/tracker3D/co_tracker/cotracker_base.py +418 -0
- models/SpaTrackV2/models/tracker3D/co_tracker/utils.py +929 -0
- models/SpaTrackV2/models/tracker3D/delta_utils/__init__.py +0 -0
- models/SpaTrackV2/models/tracker3D/delta_utils/blocks.py +842 -0
- models/SpaTrackV2/models/tracker3D/delta_utils/upsample_transformer.py +438 -0
- models/SpaTrackV2/models/tracker3D/spatrack_modules/alignment.py +471 -0
- models/SpaTrackV2/models/tracker3D/spatrack_modules/ba.py +538 -0
- models/SpaTrackV2/models/tracker3D/spatrack_modules/blocks.py +15 -0
- models/SpaTrackV2/models/tracker3D/spatrack_modules/dynamic_point_refine.py +0 -0
- models/SpaTrackV2/models/tracker3D/spatrack_modules/geometry_numpy.py +401 -0
- models/SpaTrackV2/models/tracker3D/spatrack_modules/geometry_torch.py +323 -0
- models/SpaTrackV2/models/tracker3D/spatrack_modules/pointmap_updator.py +104 -0
- models/SpaTrackV2/models/tracker3D/spatrack_modules/simple_vit_1d.py +125 -0
- models/SpaTrackV2/models/tracker3D/spatrack_modules/tools.py +289 -0
- models/SpaTrackV2/models/tracker3D/spatrack_modules/utils.py +1006 -0
- models/SpaTrackV2/models/utils.py +1221 -0
- models/SpaTrackV2/utils/embeddings.py +247 -0
- models/SpaTrackV2/utils/model_utils.py +444 -0
- models/SpaTrackV2/utils/visualizer.py +352 -0
- models/moge/__init__.py +0 -0
- models/moge/model/__init__.py +18 -0
- models/moge/model/dinov2/__init__.py +6 -0
- models/moge/model/dinov2/hub/__init__.py +4 -0
- models/moge/model/dinov2/hub/backbones.py +156 -0
- models/moge/model/dinov2/hub/utils.py +39 -0
- models/moge/model/dinov2/layers/__init__.py +11 -0
- models/moge/model/dinov2/layers/attention.py +89 -0
.gitignore
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ignore the multi media
|
| 2 |
+
checkpoints
|
| 3 |
+
**/checkpoints/
|
| 4 |
+
**/temp/
|
| 5 |
+
temp
|
| 6 |
+
assets_dev
|
| 7 |
+
assets/example0/results
|
| 8 |
+
assets/example0/snowboard.npz
|
| 9 |
+
assets/example1/results
|
| 10 |
+
assets/davis_eval
|
| 11 |
+
assets/*/results
|
| 12 |
+
*gradio*
|
| 13 |
+
#
|
| 14 |
+
models/monoD/zoeDepth/ckpts/*
|
| 15 |
+
models/monoD/depth_anything/ckpts/*
|
| 16 |
+
vis_results
|
| 17 |
+
dist_encrypted
|
| 18 |
+
# remove the dependencies
|
| 19 |
+
deps
|
| 20 |
+
|
| 21 |
+
# filter the __pycache__ files
|
| 22 |
+
__pycache__/
|
| 23 |
+
/**/**/__pycache__
|
| 24 |
+
/**/__pycache__
|
| 25 |
+
|
| 26 |
+
outputs
|
| 27 |
+
scripts/lauch_exp/config
|
| 28 |
+
scripts/lauch_exp/submit_job.log
|
| 29 |
+
scripts/lauch_exp/hydra_output
|
| 30 |
+
scripts/lauch_wulan
|
| 31 |
+
scripts/custom_video
|
| 32 |
+
# ignore the visualizer
|
| 33 |
+
viser
|
| 34 |
+
viser_result
|
| 35 |
+
benchmark/results
|
| 36 |
+
benchmark
|
| 37 |
+
|
| 38 |
+
ossutil_output
|
| 39 |
+
|
| 40 |
+
prev_version
|
| 41 |
+
spat_ceres
|
| 42 |
+
wandb
|
| 43 |
+
*.log
|
| 44 |
+
seg_target.py
|
| 45 |
+
|
| 46 |
+
eval_davis.py
|
| 47 |
+
eval_multiple_gpu.py
|
| 48 |
+
eval_pose_scan.py
|
| 49 |
+
eval_single_gpu.py
|
| 50 |
+
|
| 51 |
+
infer_cam.py
|
| 52 |
+
infer_stream.py
|
| 53 |
+
|
| 54 |
+
*.egg-info/
|
| 55 |
+
**/*.egg-info
|
| 56 |
+
|
| 57 |
+
eval_kinectics.py
|
| 58 |
+
models/SpaTrackV2/datasets
|
| 59 |
+
|
| 60 |
+
scripts
|
| 61 |
+
config/fix_2d.yaml
|
| 62 |
+
|
| 63 |
+
models/SpaTrackV2/datasets
|
| 64 |
+
scripts/
|
| 65 |
+
|
| 66 |
+
models/**/build
|
| 67 |
+
models/**/dist
|
| 68 |
+
|
| 69 |
+
temp_local
|
_viz/viz_template.html
ADDED
|
@@ -0,0 +1,1769 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>3D Point Cloud Visualizer</title>
|
| 7 |
+
<style>
|
| 8 |
+
:root {
|
| 9 |
+
--primary: #9b59b6; /* Brighter purple for dark mode */
|
| 10 |
+
--primary-light: #3a2e4a;
|
| 11 |
+
--secondary: #a86add;
|
| 12 |
+
--accent: #ff6e6e;
|
| 13 |
+
--bg: #1a1a1a;
|
| 14 |
+
--surface: #2c2c2c;
|
| 15 |
+
--text: #e0e0e0;
|
| 16 |
+
--text-secondary: #a0a0a0;
|
| 17 |
+
--border: #444444;
|
| 18 |
+
--shadow: rgba(0, 0, 0, 0.2);
|
| 19 |
+
--shadow-hover: rgba(0, 0, 0, 0.3);
|
| 20 |
+
|
| 21 |
+
--space-sm: 16px;
|
| 22 |
+
--space-md: 24px;
|
| 23 |
+
--space-lg: 32px;
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
body {
|
| 27 |
+
margin: 0;
|
| 28 |
+
overflow: hidden;
|
| 29 |
+
background: var(--bg);
|
| 30 |
+
color: var(--text);
|
| 31 |
+
font-family: 'Inter', sans-serif;
|
| 32 |
+
-webkit-font-smoothing: antialiased;
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
#canvas-container {
|
| 36 |
+
position: absolute;
|
| 37 |
+
width: 100%;
|
| 38 |
+
height: 100%;
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
#ui-container {
|
| 42 |
+
position: absolute;
|
| 43 |
+
top: 0;
|
| 44 |
+
left: 0;
|
| 45 |
+
width: 100%;
|
| 46 |
+
height: 100%;
|
| 47 |
+
pointer-events: none;
|
| 48 |
+
z-index: 10;
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
#status-bar {
|
| 52 |
+
position: absolute;
|
| 53 |
+
top: 16px;
|
| 54 |
+
left: 16px;
|
| 55 |
+
background: rgba(30, 30, 30, 0.9);
|
| 56 |
+
padding: 8px 16px;
|
| 57 |
+
border-radius: 8px;
|
| 58 |
+
pointer-events: auto;
|
| 59 |
+
box-shadow: 0 4px 6px var(--shadow);
|
| 60 |
+
backdrop-filter: blur(4px);
|
| 61 |
+
border: 1px solid var(--border);
|
| 62 |
+
color: var(--text);
|
| 63 |
+
transition: opacity 0.5s ease, transform 0.5s ease;
|
| 64 |
+
font-weight: 500;
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
#status-bar.hidden {
|
| 68 |
+
opacity: 0;
|
| 69 |
+
transform: translateY(-20px);
|
| 70 |
+
pointer-events: none;
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
#control-panel {
|
| 74 |
+
position: absolute;
|
| 75 |
+
bottom: 16px;
|
| 76 |
+
left: 50%;
|
| 77 |
+
transform: translateX(-50%);
|
| 78 |
+
background: rgba(44, 44, 44, 0.95);
|
| 79 |
+
padding: 12px 16px;
|
| 80 |
+
border-radius: 12px;
|
| 81 |
+
display: flex;
|
| 82 |
+
gap: 16px;
|
| 83 |
+
align-items: center;
|
| 84 |
+
pointer-events: auto;
|
| 85 |
+
box-shadow: 0 4px 10px var(--shadow);
|
| 86 |
+
backdrop-filter: blur(4px);
|
| 87 |
+
border: 1px solid var(--border);
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
#timeline {
|
| 91 |
+
width: 400px;
|
| 92 |
+
height: 8px;
|
| 93 |
+
background: rgba(255, 255, 255, 0.1);
|
| 94 |
+
border-radius: 4px;
|
| 95 |
+
position: relative;
|
| 96 |
+
cursor: pointer;
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
#progress {
|
| 100 |
+
position: absolute;
|
| 101 |
+
height: 100%;
|
| 102 |
+
background: var(--primary);
|
| 103 |
+
border-radius: 4px;
|
| 104 |
+
width: 0%;
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
#playback-controls {
|
| 108 |
+
display: flex;
|
| 109 |
+
gap: 8px;
|
| 110 |
+
align-items: center;
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
button {
|
| 114 |
+
background: rgba(255, 255, 255, 0.08);
|
| 115 |
+
border: 1px solid var(--border);
|
| 116 |
+
color: var(--text);
|
| 117 |
+
padding: 8px 12px;
|
| 118 |
+
border-radius: 6px;
|
| 119 |
+
cursor: pointer;
|
| 120 |
+
display: flex;
|
| 121 |
+
align-items: center;
|
| 122 |
+
justify-content: center;
|
| 123 |
+
transition: background 0.2s, transform 0.2s;
|
| 124 |
+
font-family: 'Inter', sans-serif;
|
| 125 |
+
font-weight: 500;
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
button:hover {
|
| 129 |
+
background: rgba(255, 255, 255, 0.15);
|
| 130 |
+
transform: translateY(-1px);
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
button.active {
|
| 134 |
+
background: var(--primary);
|
| 135 |
+
color: white;
|
| 136 |
+
box-shadow: 0 2px 8px rgba(155, 89, 182, 0.4);
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
select, input {
|
| 140 |
+
background: rgba(255, 255, 255, 0.08);
|
| 141 |
+
border: 1px solid var(--border);
|
| 142 |
+
color: var(--text);
|
| 143 |
+
padding: 8px 12px;
|
| 144 |
+
border-radius: 6px;
|
| 145 |
+
cursor: pointer;
|
| 146 |
+
font-family: 'Inter', sans-serif;
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
.icon {
|
| 150 |
+
width: 20px;
|
| 151 |
+
height: 20px;
|
| 152 |
+
fill: currentColor;
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
.tooltip {
|
| 156 |
+
position: absolute;
|
| 157 |
+
bottom: 100%;
|
| 158 |
+
left: 50%;
|
| 159 |
+
transform: translateX(-50%);
|
| 160 |
+
background: var(--surface);
|
| 161 |
+
color: var(--text);
|
| 162 |
+
padding: 6px 12px;
|
| 163 |
+
border-radius: 6px;
|
| 164 |
+
font-size: 14px;
|
| 165 |
+
white-space: nowrap;
|
| 166 |
+
margin-bottom: 8px;
|
| 167 |
+
opacity: 0;
|
| 168 |
+
transition: opacity 0.2s;
|
| 169 |
+
pointer-events: none;
|
| 170 |
+
box-shadow: 0 2px 4px var(--shadow);
|
| 171 |
+
border: 1px solid var(--border);
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
button:hover .tooltip {
|
| 175 |
+
opacity: 1;
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
#settings-panel {
|
| 179 |
+
position: absolute;
|
| 180 |
+
top: 16px;
|
| 181 |
+
right: 16px;
|
| 182 |
+
background: rgba(44, 44, 44, 0.98);
|
| 183 |
+
padding: 20px;
|
| 184 |
+
border-radius: 12px;
|
| 185 |
+
width: 300px;
|
| 186 |
+
max-height: calc(100vh - 40px);
|
| 187 |
+
overflow-y: auto;
|
| 188 |
+
pointer-events: auto;
|
| 189 |
+
box-shadow: 0 4px 15px var(--shadow);
|
| 190 |
+
backdrop-filter: blur(4px);
|
| 191 |
+
border: 1px solid var(--border);
|
| 192 |
+
display: block;
|
| 193 |
+
opacity: 1;
|
| 194 |
+
scrollbar-width: thin;
|
| 195 |
+
scrollbar-color: var(--primary-light) transparent;
|
| 196 |
+
transition: transform 0.35s ease-in-out, opacity 0.3s ease-in-out;
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
#settings-panel.is-hidden {
|
| 200 |
+
transform: translateX(calc(100% + 20px));
|
| 201 |
+
opacity: 0;
|
| 202 |
+
pointer-events: none;
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
#settings-panel::-webkit-scrollbar {
|
| 206 |
+
width: 6px;
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
#settings-panel::-webkit-scrollbar-track {
|
| 210 |
+
background: transparent;
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
#settings-panel::-webkit-scrollbar-thumb {
|
| 214 |
+
background-color: var(--primary-light);
|
| 215 |
+
border-radius: 6px;
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
@media (max-height: 700px) {
|
| 219 |
+
#settings-panel {
|
| 220 |
+
max-height: calc(100vh - 40px);
|
| 221 |
+
}
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
@media (max-width: 768px) {
|
| 225 |
+
#control-panel {
|
| 226 |
+
width: 90%;
|
| 227 |
+
flex-wrap: wrap;
|
| 228 |
+
justify-content: center;
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
#timeline {
|
| 232 |
+
width: 100%;
|
| 233 |
+
order: 3;
|
| 234 |
+
margin-top: 10px;
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
#settings-panel {
|
| 238 |
+
width: 280px;
|
| 239 |
+
right: 10px;
|
| 240 |
+
top: 10px;
|
| 241 |
+
max-height: calc(100vh - 20px);
|
| 242 |
+
}
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
.settings-group {
|
| 246 |
+
margin-bottom: 16px;
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
.settings-group h3 {
|
| 250 |
+
margin: 0 0 8px 0;
|
| 251 |
+
font-size: 14px;
|
| 252 |
+
font-weight: 500;
|
| 253 |
+
color: var(--text-secondary);
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
.slider-container {
|
| 257 |
+
display: flex;
|
| 258 |
+
align-items: center;
|
| 259 |
+
gap: 12px;
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
.slider-container label {
|
| 263 |
+
min-width: 80px;
|
| 264 |
+
font-size: 14px;
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
input[type="range"] {
|
| 268 |
+
flex-grow: 1;
|
| 269 |
+
height: 4px;
|
| 270 |
+
-webkit-appearance: none;
|
| 271 |
+
background: rgba(255, 255, 255, 0.1);
|
| 272 |
+
border-radius: 2px;
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
input[type="range"]::-webkit-slider-thumb {
|
| 276 |
+
-webkit-appearance: none;
|
| 277 |
+
width: 16px;
|
| 278 |
+
height: 16px;
|
| 279 |
+
border-radius: 50%;
|
| 280 |
+
background: var(--primary);
|
| 281 |
+
cursor: pointer;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
.toggle-switch {
|
| 285 |
+
position: relative;
|
| 286 |
+
display: inline-block;
|
| 287 |
+
width: 40px;
|
| 288 |
+
height: 20px;
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
.toggle-switch input {
|
| 292 |
+
opacity: 0;
|
| 293 |
+
width: 0;
|
| 294 |
+
height: 0;
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
.toggle-slider {
|
| 298 |
+
position: absolute;
|
| 299 |
+
cursor: pointer;
|
| 300 |
+
top: 0;
|
| 301 |
+
left: 0;
|
| 302 |
+
right: 0;
|
| 303 |
+
bottom: 0;
|
| 304 |
+
background: rgba(255, 255, 255, 0.1);
|
| 305 |
+
transition: .4s;
|
| 306 |
+
border-radius: 20px;
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
.toggle-slider:before {
|
| 310 |
+
position: absolute;
|
| 311 |
+
content: "";
|
| 312 |
+
height: 16px;
|
| 313 |
+
width: 16px;
|
| 314 |
+
left: 2px;
|
| 315 |
+
bottom: 2px;
|
| 316 |
+
background: var(--surface);
|
| 317 |
+
border: 1px solid var(--border);
|
| 318 |
+
transition: .4s;
|
| 319 |
+
border-radius: 50%;
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
input:checked + .toggle-slider {
|
| 323 |
+
background: var(--primary);
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
input:checked + .toggle-slider:before {
|
| 327 |
+
transform: translateX(20px);
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
.checkbox-container {
|
| 331 |
+
display: flex;
|
| 332 |
+
align-items: center;
|
| 333 |
+
gap: 8px;
|
| 334 |
+
margin-bottom: 8px;
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
.checkbox-container label {
|
| 338 |
+
font-size: 14px;
|
| 339 |
+
cursor: pointer;
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
#loading-overlay {
|
| 343 |
+
position: absolute;
|
| 344 |
+
top: 0;
|
| 345 |
+
left: 0;
|
| 346 |
+
width: 100%;
|
| 347 |
+
height: 100%;
|
| 348 |
+
background: var(--bg);
|
| 349 |
+
display: flex;
|
| 350 |
+
flex-direction: column;
|
| 351 |
+
align-items: center;
|
| 352 |
+
justify-content: center;
|
| 353 |
+
z-index: 100;
|
| 354 |
+
transition: opacity 0.5s;
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
#loading-overlay.fade-out {
|
| 358 |
+
opacity: 0;
|
| 359 |
+
pointer-events: none;
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
.spinner {
|
| 363 |
+
width: 50px;
|
| 364 |
+
height: 50px;
|
| 365 |
+
border: 5px solid rgba(155, 89, 182, 0.2);
|
| 366 |
+
border-radius: 50%;
|
| 367 |
+
border-top-color: var(--primary);
|
| 368 |
+
animation: spin 1s ease-in-out infinite;
|
| 369 |
+
margin-bottom: 16px;
|
| 370 |
+
}
|
| 371 |
+
|
| 372 |
+
@keyframes spin {
|
| 373 |
+
to { transform: rotate(360deg); }
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
#loading-text {
|
| 377 |
+
margin-top: 16px;
|
| 378 |
+
font-size: 18px;
|
| 379 |
+
color: var(--text);
|
| 380 |
+
font-weight: 500;
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
#frame-counter {
|
| 384 |
+
color: var(--text-secondary);
|
| 385 |
+
font-size: 14px;
|
| 386 |
+
font-weight: 500;
|
| 387 |
+
min-width: 120px;
|
| 388 |
+
text-align: center;
|
| 389 |
+
padding: 0 8px;
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
.control-btn {
|
| 393 |
+
background: rgba(255, 255, 255, 0.08);
|
| 394 |
+
border: 1px solid var(--border);
|
| 395 |
+
padding: 8px 12px;
|
| 396 |
+
border-radius: 6px;
|
| 397 |
+
cursor: pointer;
|
| 398 |
+
display: flex;
|
| 399 |
+
align-items: center;
|
| 400 |
+
justify-content: center;
|
| 401 |
+
transition: all 0.2s ease;
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
.control-btn:hover {
|
| 405 |
+
background: rgba(255, 255, 255, 0.15);
|
| 406 |
+
transform: translateY(-1px);
|
| 407 |
+
}
|
| 408 |
+
|
| 409 |
+
.control-btn.active {
|
| 410 |
+
background: var(--primary);
|
| 411 |
+
color: white;
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
.control-btn.active:hover {
|
| 415 |
+
background: var(--primary);
|
| 416 |
+
box-shadow: 0 2px 8px rgba(155, 89, 182, 0.4);
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
#settings-toggle-btn {
|
| 420 |
+
position: relative;
|
| 421 |
+
border-radius: 6px;
|
| 422 |
+
z-index: 20;
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
#settings-toggle-btn.active {
|
| 426 |
+
background: var(--primary);
|
| 427 |
+
color: white;
|
| 428 |
+
}
|
| 429 |
+
|
| 430 |
+
#status-bar,
|
| 431 |
+
#control-panel,
|
| 432 |
+
#settings-panel,
|
| 433 |
+
button,
|
| 434 |
+
input,
|
| 435 |
+
select,
|
| 436 |
+
.toggle-switch {
|
| 437 |
+
pointer-events: auto;
|
| 438 |
+
}
|
| 439 |
+
|
| 440 |
+
h2 {
|
| 441 |
+
font-size: 1.2rem;
|
| 442 |
+
font-weight: 600;
|
| 443 |
+
margin-top: 0;
|
| 444 |
+
margin-bottom: var(--space-md);
|
| 445 |
+
color: var(--primary);
|
| 446 |
+
cursor: move;
|
| 447 |
+
user-select: none;
|
| 448 |
+
display: flex;
|
| 449 |
+
align-items: center;
|
| 450 |
+
}
|
| 451 |
+
|
| 452 |
+
.drag-handle {
|
| 453 |
+
font-size: 14px;
|
| 454 |
+
margin-right: 8px;
|
| 455 |
+
opacity: 0.6;
|
| 456 |
+
}
|
| 457 |
+
|
| 458 |
+
h2:hover .drag-handle {
|
| 459 |
+
opacity: 1;
|
| 460 |
+
}
|
| 461 |
+
|
| 462 |
+
.loading-subtitle {
|
| 463 |
+
font-size: 14px;
|
| 464 |
+
color: var(--text-secondary);
|
| 465 |
+
margin-top: 8px;
|
| 466 |
+
}
|
| 467 |
+
|
| 468 |
+
#reset-view-btn {
|
| 469 |
+
background: var(--primary-light);
|
| 470 |
+
color: var(--primary);
|
| 471 |
+
border: 1px solid rgba(155, 89, 182, 0.2);
|
| 472 |
+
font-weight: 600;
|
| 473 |
+
transition: all 0.2s;
|
| 474 |
+
}
|
| 475 |
+
|
| 476 |
+
#reset-view-btn:hover {
|
| 477 |
+
background: var(--primary);
|
| 478 |
+
color: white;
|
| 479 |
+
transform: translateY(-2px);
|
| 480 |
+
box-shadow: 0 4px 8px rgba(155, 89, 182, 0.3);
|
| 481 |
+
}
|
| 482 |
+
|
| 483 |
+
#settings-panel.visible {
|
| 484 |
+
display: block;
|
| 485 |
+
opacity: 1;
|
| 486 |
+
animation: slideIn 0.3s ease forwards;
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
@keyframes slideIn {
|
| 490 |
+
from {
|
| 491 |
+
transform: translateY(20px);
|
| 492 |
+
opacity: 0;
|
| 493 |
+
}
|
| 494 |
+
to {
|
| 495 |
+
transform: translateY(0);
|
| 496 |
+
opacity: 1;
|
| 497 |
+
}
|
| 498 |
+
}
|
| 499 |
+
|
| 500 |
+
.dragging {
|
| 501 |
+
opacity: 0.9;
|
| 502 |
+
box-shadow: 0 8px 20px rgba(0, 0, 0, 0.15) !important;
|
| 503 |
+
transition: none !important;
|
| 504 |
+
}
|
| 505 |
+
|
| 506 |
+
/* Tooltip for draggable element */
|
| 507 |
+
.tooltip-drag {
|
| 508 |
+
position: absolute;
|
| 509 |
+
left: 50%;
|
| 510 |
+
transform: translateX(-50%);
|
| 511 |
+
background: var(--primary);
|
| 512 |
+
color: white;
|
| 513 |
+
font-size: 12px;
|
| 514 |
+
padding: 4px 8px;
|
| 515 |
+
border-radius: 4px;
|
| 516 |
+
opacity: 0;
|
| 517 |
+
pointer-events: none;
|
| 518 |
+
transition: opacity 0.3s;
|
| 519 |
+
white-space: nowrap;
|
| 520 |
+
bottom: 100%;
|
| 521 |
+
margin-bottom: 8px;
|
| 522 |
+
}
|
| 523 |
+
|
| 524 |
+
h2:hover .tooltip-drag {
|
| 525 |
+
opacity: 1;
|
| 526 |
+
}
|
| 527 |
+
|
| 528 |
+
.btn-group {
|
| 529 |
+
display: flex;
|
| 530 |
+
margin-top: 16px;
|
| 531 |
+
}
|
| 532 |
+
|
| 533 |
+
#reset-view-btn, #reset-settings-btn {
|
| 534 |
+
background: var(--primary-light);
|
| 535 |
+
color: var(--primary);
|
| 536 |
+
border: 1px solid rgba(155, 89, 182, 0.2);
|
| 537 |
+
font-weight: 600;
|
| 538 |
+
transition: all 0.2s;
|
| 539 |
+
}
|
| 540 |
+
|
| 541 |
+
#reset-view-btn:hover, #reset-settings-btn:hover {
|
| 542 |
+
background: var(--primary);
|
| 543 |
+
color: white;
|
| 544 |
+
transform: translateY(-2px);
|
| 545 |
+
box-shadow: 0 4px 8px rgba(155, 89, 182, 0.3);
|
| 546 |
+
}
|
| 547 |
+
|
| 548 |
+
#show-settings-btn {
|
| 549 |
+
position: absolute;
|
| 550 |
+
top: 16px;
|
| 551 |
+
right: 16px;
|
| 552 |
+
z-index: 15;
|
| 553 |
+
display: none;
|
| 554 |
+
}
|
| 555 |
+
</style>
|
| 556 |
+
</head>
|
| 557 |
+
<body>
|
| 558 |
+
<link rel="preconnect" href="https://fonts.googleapis.com">
|
| 559 |
+
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
| 560 |
+
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap" rel="stylesheet">
|
| 561 |
+
|
| 562 |
+
<div id="canvas-container"></div>
|
| 563 |
+
|
| 564 |
+
<div id="ui-container">
|
| 565 |
+
<div id="status-bar">Initializing...</div>
|
| 566 |
+
|
| 567 |
+
<div id="control-panel">
|
| 568 |
+
<button id="play-pause-btn" class="control-btn">
|
| 569 |
+
<svg class="icon" viewBox="0 0 24 24">
|
| 570 |
+
<path id="play-icon" d="M8 5v14l11-7z"/>
|
| 571 |
+
<path id="pause-icon" d="M6 19h4V5H6v14zm8-14v14h4V5h-4z" style="display: none;"/>
|
| 572 |
+
</svg>
|
| 573 |
+
<span class="tooltip">Play/Pause</span>
|
| 574 |
+
</button>
|
| 575 |
+
|
| 576 |
+
<div id="timeline">
|
| 577 |
+
<div id="progress"></div>
|
| 578 |
+
</div>
|
| 579 |
+
|
| 580 |
+
<div id="frame-counter">Frame 0 / 0</div>
|
| 581 |
+
|
| 582 |
+
<div id="playback-controls">
|
| 583 |
+
<button id="speed-btn" class="control-btn">1x</button>
|
| 584 |
+
</div>
|
| 585 |
+
</div>
|
| 586 |
+
|
| 587 |
+
<div id="settings-panel">
|
| 588 |
+
<h2>
|
| 589 |
+
<span class="drag-handle">☰</span>
|
| 590 |
+
Visualization Settings
|
| 591 |
+
<button id="hide-settings-btn" class="control-btn" style="margin-left: auto; padding: 4px;" title="Hide Panel">
|
| 592 |
+
<svg class="icon" viewBox="0 0 24 24" style="width: 18px; height: 18px;">
|
| 593 |
+
<path d="M14.59 7.41L18.17 11H4v2h14.17l-3.58 3.59L16 18l6-6-6-6-1.41 1.41z"/>
|
| 594 |
+
</svg>
|
| 595 |
+
</button>
|
| 596 |
+
</h2>
|
| 597 |
+
|
| 598 |
+
<div class="settings-group">
|
| 599 |
+
<h3>Point Cloud</h3>
|
| 600 |
+
<div class="slider-container">
|
| 601 |
+
<label for="point-size">Size</label>
|
| 602 |
+
<input type="range" id="point-size" min="0.005" max="0.1" step="0.005" value="0.03">
|
| 603 |
+
</div>
|
| 604 |
+
<div class="slider-container">
|
| 605 |
+
<label for="point-opacity">Opacity</label>
|
| 606 |
+
<input type="range" id="point-opacity" min="0.1" max="1" step="0.05" value="1">
|
| 607 |
+
</div>
|
| 608 |
+
<div class="slider-container">
|
| 609 |
+
<label for="max-depth">Max Depth</label>
|
| 610 |
+
<input type="range" id="max-depth" min="0.1" max="10" step="0.2" value="100">
|
| 611 |
+
</div>
|
| 612 |
+
</div>
|
| 613 |
+
|
| 614 |
+
<div class="settings-group">
|
| 615 |
+
<h3>Trajectory</h3>
|
| 616 |
+
<div class="checkbox-container">
|
| 617 |
+
<label class="toggle-switch">
|
| 618 |
+
<input type="checkbox" id="show-trajectory" checked>
|
| 619 |
+
<span class="toggle-slider"></span>
|
| 620 |
+
</label>
|
| 621 |
+
<label for="show-trajectory">Show Trajectory</label>
|
| 622 |
+
</div>
|
| 623 |
+
<div class="checkbox-container">
|
| 624 |
+
<label class="toggle-switch">
|
| 625 |
+
<input type="checkbox" id="enable-rich-trail">
|
| 626 |
+
<span class="toggle-slider"></span>
|
| 627 |
+
</label>
|
| 628 |
+
<label for="enable-rich-trail">Visual-Rich Trail</label>
|
| 629 |
+
</div>
|
| 630 |
+
<div class="slider-container">
|
| 631 |
+
<label for="trajectory-line-width">Line Width</label>
|
| 632 |
+
<input type="range" id="trajectory-line-width" min="0.5" max="5" step="0.5" value="1.5">
|
| 633 |
+
</div>
|
| 634 |
+
<div class="slider-container">
|
| 635 |
+
<label for="trajectory-ball-size">Ball Size</label>
|
| 636 |
+
<input type="range" id="trajectory-ball-size" min="0.005" max="0.05" step="0.001" value="0.02">
|
| 637 |
+
</div>
|
| 638 |
+
<div class="slider-container">
|
| 639 |
+
<label for="trajectory-history">History Frames</label>
|
| 640 |
+
<input type="range" id="trajectory-history" min="1" max="500" step="1" value="30">
|
| 641 |
+
</div>
|
| 642 |
+
<div class="slider-container" id="tail-opacity-container" style="display: none;">
|
| 643 |
+
<label for="trajectory-fade">Tail Opacity</label>
|
| 644 |
+
<input type="range" id="trajectory-fade" min="0" max="1" step="0.05" value="0.0">
|
| 645 |
+
</div>
|
| 646 |
+
</div>
|
| 647 |
+
|
| 648 |
+
<div class="settings-group">
|
| 649 |
+
<h3>Camera</h3>
|
| 650 |
+
<div class="checkbox-container">
|
| 651 |
+
<label class="toggle-switch">
|
| 652 |
+
<input type="checkbox" id="show-camera-frustum" checked>
|
| 653 |
+
<span class="toggle-slider"></span>
|
| 654 |
+
</label>
|
| 655 |
+
<label for="show-camera-frustum">Show Camera Frustum</label>
|
| 656 |
+
</div>
|
| 657 |
+
<div class="slider-container">
|
| 658 |
+
<label for="frustum-size">Size</label>
|
| 659 |
+
<input type="range" id="frustum-size" min="0.02" max="0.5" step="0.01" value="0.2">
|
| 660 |
+
</div>
|
| 661 |
+
</div>
|
| 662 |
+
|
| 663 |
+
<div class="settings-group">
|
| 664 |
+
<div class="btn-group">
|
| 665 |
+
<button id="reset-view-btn" style="flex: 1; margin-right: 5px;">Reset View</button>
|
| 666 |
+
<button id="reset-settings-btn" style="flex: 1; margin-left: 5px;">Reset Settings</button>
|
| 667 |
+
</div>
|
| 668 |
+
</div>
|
| 669 |
+
</div>
|
| 670 |
+
|
| 671 |
+
<button id="show-settings-btn" class="control-btn" title="Show Settings">
|
| 672 |
+
<svg class="icon" viewBox="0 0 24 24">
|
| 673 |
+
<path d="M19.14,12.94c0.04-0.3,0.06-0.61,0.06-0.94c0-0.32-0.02-0.64-0.07-0.94l2.03-1.58c0.18-0.14,0.23-0.41,0.12-0.61 l-1.92-3.32c-0.12-0.22-0.37-0.29-0.59-0.22l-2.39,0.96c-0.5-0.38-1.03-0.7-1.62-0.94L14.4,2.81c-0.04-0.24-0.24-0.41-0.48-0.41 h-3.84c-0.24,0-0.43,0.17-0.47,0.41L9.25,5.35C8.66,5.59,8.12,5.92,7.63,6.29L5.24,5.33c-0.22-0.08-0.47,0-0.59,0.22L2.74,8.87 C2.62,9.08,2.66,9.34,2.86,9.48l2.03,1.58C4.84,11.36,4.8,11.69,4.8,12s0.02,0.64,0.07,0.94l-2.03,1.58 c-0.18,0.14-0.23,0.41-0.12,0.61l1.92,3.32c0.12,0.22,0.37,0.29,0.59,0.22l2.39-0.96c0.5,0.38,1.03,0.7,1.62,0.94l0.36,2.54 c0.04,0.24,0.24,0.41,0.48,0.41h3.84c0.24,0,0.44-0.17,0.47-0.41l0.36-2.54c0.59-0.24,1.13-0.56,1.62-0.94l2.39,0.96 c0.22,0.08,0.47,0,0.59-0.22l1.92-3.32c0.12-0.22,0.07-0.47-0.12-0.61L19.14,12.94z M12,15.6c-1.98,0-3.6-1.62-3.6-3.6 s1.62-3.6,3.6-3.6s3.6,1.62,3.6,3.6S13.98,15.6,12,15.6z"/>
|
| 674 |
+
</svg>
|
| 675 |
+
</button>
|
| 676 |
+
</div>
|
| 677 |
+
|
| 678 |
+
<div id="loading-overlay">
|
| 679 |
+
<!-- <div class="spinner"></div> -->
|
| 680 |
+
<div id="loading-text"></div>
|
| 681 |
+
<div class="loading-subtitle" style="font-size: xx-large;">Interactive Viewer of 3D Tracking</div>
|
| 682 |
+
</div>
|
| 683 |
+
|
| 684 |
+
<!-- Libraries -->
|
| 685 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/pako/2.1.0/pako.min.js"></script>
|
| 686 |
+
<script src="https://cdn.jsdelivr.net/npm/three@0.132.2/build/three.min.js"></script>
|
| 687 |
+
<script src="https://cdn.jsdelivr.net/npm/three@0.132.2/examples/js/controls/OrbitControls.js"></script>
|
| 688 |
+
<script src="https://cdn.jsdelivr.net/npm/dat.gui@0.7.7/build/dat.gui.min.js"></script>
|
| 689 |
+
<script src="https://cdn.jsdelivr.net/npm/three@0.132.2/examples/js/lines/LineSegmentsGeometry.js"></script>
|
| 690 |
+
<script src="https://cdn.jsdelivr.net/npm/three@0.132.2/examples/js/lines/LineGeometry.js"></script>
|
| 691 |
+
<script src="https://cdn.jsdelivr.net/npm/three@0.132.2/examples/js/lines/LineMaterial.js"></script>
|
| 692 |
+
<script src="https://cdn.jsdelivr.net/npm/three@0.132.2/examples/js/lines/LineSegments2.js"></script>
|
| 693 |
+
<script src="https://cdn.jsdelivr.net/npm/three@0.132.2/examples/js/lines/Line2.js"></script>
|
| 694 |
+
|
| 695 |
+
<script>
|
| 696 |
+
class PointCloudVisualizer {
|
| 697 |
+
constructor() {
|
| 698 |
+
this.data = null;
|
| 699 |
+
this.config = {};
|
| 700 |
+
this.currentFrame = 0;
|
| 701 |
+
this.isPlaying = false;
|
| 702 |
+
this.playbackSpeed = 1;
|
| 703 |
+
this.lastFrameTime = 0;
|
| 704 |
+
this.defaultSettings = null;
|
| 705 |
+
|
| 706 |
+
this.ui = {
|
| 707 |
+
statusBar: document.getElementById('status-bar'),
|
| 708 |
+
playPauseBtn: document.getElementById('play-pause-btn'),
|
| 709 |
+
speedBtn: document.getElementById('speed-btn'),
|
| 710 |
+
timeline: document.getElementById('timeline'),
|
| 711 |
+
progress: document.getElementById('progress'),
|
| 712 |
+
settingsPanel: document.getElementById('settings-panel'),
|
| 713 |
+
loadingOverlay: document.getElementById('loading-overlay'),
|
| 714 |
+
loadingText: document.getElementById('loading-text'),
|
| 715 |
+
settingsToggleBtn: document.getElementById('settings-toggle-btn'),
|
| 716 |
+
frameCounter: document.getElementById('frame-counter'),
|
| 717 |
+
pointSize: document.getElementById('point-size'),
|
| 718 |
+
pointOpacity: document.getElementById('point-opacity'),
|
| 719 |
+
maxDepth: document.getElementById('max-depth'),
|
| 720 |
+
showTrajectory: document.getElementById('show-trajectory'),
|
| 721 |
+
enableRichTrail: document.getElementById('enable-rich-trail'),
|
| 722 |
+
trajectoryLineWidth: document.getElementById('trajectory-line-width'),
|
| 723 |
+
trajectoryBallSize: document.getElementById('trajectory-ball-size'),
|
| 724 |
+
trajectoryHistory: document.getElementById('trajectory-history'),
|
| 725 |
+
trajectoryFade: document.getElementById('trajectory-fade'),
|
| 726 |
+
tailOpacityContainer: document.getElementById('tail-opacity-container'),
|
| 727 |
+
resetViewBtn: document.getElementById('reset-view-btn'),
|
| 728 |
+
showCameraFrustum: document.getElementById('show-camera-frustum'),
|
| 729 |
+
frustumSize: document.getElementById('frustum-size'),
|
| 730 |
+
hideSettingsBtn: document.getElementById('hide-settings-btn'),
|
| 731 |
+
showSettingsBtn: document.getElementById('show-settings-btn')
|
| 732 |
+
};
|
| 733 |
+
|
| 734 |
+
this.scene = null;
|
| 735 |
+
this.camera = null;
|
| 736 |
+
this.renderer = null;
|
| 737 |
+
this.controls = null;
|
| 738 |
+
this.pointCloud = null;
|
| 739 |
+
this.trajectories = [];
|
| 740 |
+
this.cameraFrustum = null;
|
| 741 |
+
|
| 742 |
+
this.initThreeJS();
|
| 743 |
+
this.loadDefaultSettings().then(() => {
|
| 744 |
+
this.initEventListeners();
|
| 745 |
+
this.loadData();
|
| 746 |
+
});
|
| 747 |
+
}
|
| 748 |
+
|
| 749 |
+
async loadDefaultSettings() {
|
| 750 |
+
try {
|
| 751 |
+
const urlParams = new URLSearchParams(window.location.search);
|
| 752 |
+
const dataPath = urlParams.get('data') || '';
|
| 753 |
+
|
| 754 |
+
const defaultSettings = {
|
| 755 |
+
pointSize: 0.03,
|
| 756 |
+
pointOpacity: 1.0,
|
| 757 |
+
showTrajectory: true,
|
| 758 |
+
trajectoryLineWidth: 2.5,
|
| 759 |
+
trajectoryBallSize: 0.015,
|
| 760 |
+
trajectoryHistory: 0,
|
| 761 |
+
showCameraFrustum: true,
|
| 762 |
+
frustumSize: 0.2
|
| 763 |
+
};
|
| 764 |
+
|
| 765 |
+
if (!dataPath) {
|
| 766 |
+
this.defaultSettings = defaultSettings;
|
| 767 |
+
this.applyDefaultSettings();
|
| 768 |
+
return;
|
| 769 |
+
}
|
| 770 |
+
|
| 771 |
+
// Try to extract dataset and videoId from the data path
|
| 772 |
+
// Expected format: demos/datasetname/videoid.bin
|
| 773 |
+
const pathParts = dataPath.split('/');
|
| 774 |
+
if (pathParts.length < 3) {
|
| 775 |
+
this.defaultSettings = defaultSettings;
|
| 776 |
+
this.applyDefaultSettings();
|
| 777 |
+
return;
|
| 778 |
+
}
|
| 779 |
+
|
| 780 |
+
const datasetName = pathParts[pathParts.length - 2];
|
| 781 |
+
let videoId = pathParts[pathParts.length - 1].replace('.bin', '');
|
| 782 |
+
|
| 783 |
+
// Load settings from data.json
|
| 784 |
+
const response = await fetch('./data.json');
|
| 785 |
+
if (!response.ok) {
|
| 786 |
+
this.defaultSettings = defaultSettings;
|
| 787 |
+
this.applyDefaultSettings();
|
| 788 |
+
return;
|
| 789 |
+
}
|
| 790 |
+
|
| 791 |
+
const settingsData = await response.json();
|
| 792 |
+
|
| 793 |
+
// Check if this dataset and video exist
|
| 794 |
+
if (settingsData[datasetName] && settingsData[datasetName][videoId]) {
|
| 795 |
+
this.defaultSettings = settingsData[datasetName][videoId];
|
| 796 |
+
} else {
|
| 797 |
+
this.defaultSettings = defaultSettings;
|
| 798 |
+
}
|
| 799 |
+
|
| 800 |
+
this.applyDefaultSettings();
|
| 801 |
+
} catch (error) {
|
| 802 |
+
console.error("Error loading default settings:", error);
|
| 803 |
+
|
| 804 |
+
this.defaultSettings = {
|
| 805 |
+
pointSize: 0.03,
|
| 806 |
+
pointOpacity: 1.0,
|
| 807 |
+
showTrajectory: true,
|
| 808 |
+
trajectoryLineWidth: 2.5,
|
| 809 |
+
trajectoryBallSize: 0.015,
|
| 810 |
+
trajectoryHistory: 0,
|
| 811 |
+
showCameraFrustum: true,
|
| 812 |
+
frustumSize: 0.2
|
| 813 |
+
};
|
| 814 |
+
|
| 815 |
+
this.applyDefaultSettings();
|
| 816 |
+
}
|
| 817 |
+
}
|
| 818 |
+
|
| 819 |
+
applyDefaultSettings() {
|
| 820 |
+
if (!this.defaultSettings) return;
|
| 821 |
+
|
| 822 |
+
if (this.ui.pointSize) {
|
| 823 |
+
this.ui.pointSize.value = this.defaultSettings.pointSize;
|
| 824 |
+
}
|
| 825 |
+
|
| 826 |
+
if (this.ui.pointOpacity) {
|
| 827 |
+
this.ui.pointOpacity.value = this.defaultSettings.pointOpacity;
|
| 828 |
+
}
|
| 829 |
+
|
| 830 |
+
if (this.ui.maxDepth) {
|
| 831 |
+
this.ui.maxDepth.value = this.defaultSettings.maxDepth || 100.0;
|
| 832 |
+
}
|
| 833 |
+
|
| 834 |
+
if (this.ui.showTrajectory) {
|
| 835 |
+
this.ui.showTrajectory.checked = this.defaultSettings.showTrajectory;
|
| 836 |
+
}
|
| 837 |
+
|
| 838 |
+
if (this.ui.trajectoryLineWidth) {
|
| 839 |
+
this.ui.trajectoryLineWidth.value = this.defaultSettings.trajectoryLineWidth;
|
| 840 |
+
}
|
| 841 |
+
|
| 842 |
+
if (this.ui.trajectoryBallSize) {
|
| 843 |
+
this.ui.trajectoryBallSize.value = this.defaultSettings.trajectoryBallSize;
|
| 844 |
+
}
|
| 845 |
+
|
| 846 |
+
if (this.ui.trajectoryHistory) {
|
| 847 |
+
this.ui.trajectoryHistory.value = this.defaultSettings.trajectoryHistory;
|
| 848 |
+
}
|
| 849 |
+
|
| 850 |
+
if (this.ui.showCameraFrustum) {
|
| 851 |
+
this.ui.showCameraFrustum.checked = this.defaultSettings.showCameraFrustum;
|
| 852 |
+
}
|
| 853 |
+
|
| 854 |
+
if (this.ui.frustumSize) {
|
| 855 |
+
this.ui.frustumSize.value = this.defaultSettings.frustumSize;
|
| 856 |
+
}
|
| 857 |
+
}
|
| 858 |
+
|
| 859 |
+
initThreeJS() {
|
| 860 |
+
this.scene = new THREE.Scene();
|
| 861 |
+
this.scene.background = new THREE.Color(0x1a1a1a);
|
| 862 |
+
|
| 863 |
+
this.camera = new THREE.PerspectiveCamera(60, window.innerWidth / window.innerHeight, 0.1, 10000);
|
| 864 |
+
this.camera.position.set(0, 0, 0);
|
| 865 |
+
|
| 866 |
+
this.renderer = new THREE.WebGLRenderer({ antialias: true });
|
| 867 |
+
this.renderer.setPixelRatio(window.devicePixelRatio);
|
| 868 |
+
this.renderer.setSize(window.innerWidth, window.innerHeight);
|
| 869 |
+
document.getElementById('canvas-container').appendChild(this.renderer.domElement);
|
| 870 |
+
|
| 871 |
+
this.controls = new THREE.OrbitControls(this.camera, this.renderer.domElement);
|
| 872 |
+
this.controls.enableDamping = true;
|
| 873 |
+
this.controls.dampingFactor = 0.05;
|
| 874 |
+
this.controls.target.set(0, 0, 0);
|
| 875 |
+
this.controls.minDistance = 0.1;
|
| 876 |
+
this.controls.maxDistance = 1000;
|
| 877 |
+
this.controls.update();
|
| 878 |
+
|
| 879 |
+
const ambientLight = new THREE.AmbientLight(0xffffff, 0.5);
|
| 880 |
+
this.scene.add(ambientLight);
|
| 881 |
+
|
| 882 |
+
const directionalLight = new THREE.DirectionalLight(0xffffff, 0.8);
|
| 883 |
+
directionalLight.position.set(1, 1, 1);
|
| 884 |
+
this.scene.add(directionalLight);
|
| 885 |
+
}
|
| 886 |
+
|
| 887 |
+
initEventListeners() {
|
| 888 |
+
window.addEventListener('resize', () => this.onWindowResize());
|
| 889 |
+
|
| 890 |
+
this.ui.playPauseBtn.addEventListener('click', () => this.togglePlayback());
|
| 891 |
+
|
| 892 |
+
this.ui.timeline.addEventListener('click', (e) => {
|
| 893 |
+
const rect = this.ui.timeline.getBoundingClientRect();
|
| 894 |
+
const pos = (e.clientX - rect.left) / rect.width;
|
| 895 |
+
this.seekTo(pos);
|
| 896 |
+
});
|
| 897 |
+
|
| 898 |
+
this.ui.speedBtn.addEventListener('click', () => this.cyclePlaybackSpeed());
|
| 899 |
+
|
| 900 |
+
this.ui.pointSize.addEventListener('input', () => this.updatePointCloudSettings());
|
| 901 |
+
this.ui.pointOpacity.addEventListener('input', () => this.updatePointCloudSettings());
|
| 902 |
+
this.ui.maxDepth.addEventListener('input', () => this.updatePointCloudSettings());
|
| 903 |
+
this.ui.showTrajectory.addEventListener('change', () => {
|
| 904 |
+
this.trajectories.forEach(trajectory => {
|
| 905 |
+
trajectory.visible = this.ui.showTrajectory.checked;
|
| 906 |
+
});
|
| 907 |
+
});
|
| 908 |
+
|
| 909 |
+
this.ui.enableRichTrail.addEventListener('change', () => {
|
| 910 |
+
this.ui.tailOpacityContainer.style.display = this.ui.enableRichTrail.checked ? 'flex' : 'none';
|
| 911 |
+
this.updateTrajectories(this.currentFrame);
|
| 912 |
+
});
|
| 913 |
+
|
| 914 |
+
this.ui.trajectoryLineWidth.addEventListener('input', () => this.updateTrajectorySettings());
|
| 915 |
+
this.ui.trajectoryBallSize.addEventListener('input', () => this.updateTrajectorySettings());
|
| 916 |
+
this.ui.trajectoryHistory.addEventListener('input', () => {
|
| 917 |
+
this.updateTrajectories(this.currentFrame);
|
| 918 |
+
});
|
| 919 |
+
this.ui.trajectoryFade.addEventListener('input', () => {
|
| 920 |
+
this.updateTrajectories(this.currentFrame);
|
| 921 |
+
});
|
| 922 |
+
|
| 923 |
+
this.ui.resetViewBtn.addEventListener('click', () => this.resetView());
|
| 924 |
+
|
| 925 |
+
const resetSettingsBtn = document.getElementById('reset-settings-btn');
|
| 926 |
+
if (resetSettingsBtn) {
|
| 927 |
+
resetSettingsBtn.addEventListener('click', () => this.resetSettings());
|
| 928 |
+
}
|
| 929 |
+
|
| 930 |
+
document.addEventListener('keydown', (e) => {
|
| 931 |
+
if (e.key === 'Escape' && this.ui.settingsPanel.classList.contains('visible')) {
|
| 932 |
+
this.ui.settingsPanel.classList.remove('visible');
|
| 933 |
+
this.ui.settingsToggleBtn.classList.remove('active');
|
| 934 |
+
}
|
| 935 |
+
});
|
| 936 |
+
|
| 937 |
+
if (this.ui.settingsToggleBtn) {
|
| 938 |
+
this.ui.settingsToggleBtn.addEventListener('click', () => {
|
| 939 |
+
const isVisible = this.ui.settingsPanel.classList.toggle('visible');
|
| 940 |
+
this.ui.settingsToggleBtn.classList.toggle('active', isVisible);
|
| 941 |
+
|
| 942 |
+
if (isVisible) {
|
| 943 |
+
const panelRect = this.ui.settingsPanel.getBoundingClientRect();
|
| 944 |
+
const viewportHeight = window.innerHeight;
|
| 945 |
+
|
| 946 |
+
if (panelRect.bottom > viewportHeight) {
|
| 947 |
+
this.ui.settingsPanel.style.bottom = 'auto';
|
| 948 |
+
this.ui.settingsPanel.style.top = '80px';
|
| 949 |
+
}
|
| 950 |
+
}
|
| 951 |
+
});
|
| 952 |
+
}
|
| 953 |
+
|
| 954 |
+
if (this.ui.frustumSize) {
|
| 955 |
+
this.ui.frustumSize.addEventListener('input', () => this.updateFrustumDimensions());
|
| 956 |
+
}
|
| 957 |
+
|
| 958 |
+
this.makeElementDraggable(this.ui.settingsPanel);
|
| 959 |
+
|
| 960 |
+
if (this.ui.hideSettingsBtn && this.ui.showSettingsBtn && this.ui.settingsPanel) {
|
| 961 |
+
this.ui.hideSettingsBtn.addEventListener('click', () => {
|
| 962 |
+
this.ui.settingsPanel.classList.add('is-hidden');
|
| 963 |
+
this.ui.showSettingsBtn.style.display = 'flex';
|
| 964 |
+
});
|
| 965 |
+
|
| 966 |
+
this.ui.showSettingsBtn.addEventListener('click', () => {
|
| 967 |
+
this.ui.settingsPanel.classList.remove('is-hidden');
|
| 968 |
+
this.ui.showSettingsBtn.style.display = 'none';
|
| 969 |
+
});
|
| 970 |
+
}
|
| 971 |
+
}
|
| 972 |
+
|
| 973 |
+
makeElementDraggable(element) {
|
| 974 |
+
let pos1 = 0, pos2 = 0, pos3 = 0, pos4 = 0;
|
| 975 |
+
|
| 976 |
+
const dragHandle = element.querySelector('h2');
|
| 977 |
+
|
| 978 |
+
if (dragHandle) {
|
| 979 |
+
dragHandle.onmousedown = dragMouseDown;
|
| 980 |
+
dragHandle.title = "Drag to move panel";
|
| 981 |
+
} else {
|
| 982 |
+
element.onmousedown = dragMouseDown;
|
| 983 |
+
}
|
| 984 |
+
|
| 985 |
+
function dragMouseDown(e) {
|
| 986 |
+
e = e || window.event;
|
| 987 |
+
e.preventDefault();
|
| 988 |
+
pos3 = e.clientX;
|
| 989 |
+
pos4 = e.clientY;
|
| 990 |
+
document.onmouseup = closeDragElement;
|
| 991 |
+
document.onmousemove = elementDrag;
|
| 992 |
+
|
| 993 |
+
element.classList.add('dragging');
|
| 994 |
+
}
|
| 995 |
+
|
| 996 |
+
function elementDrag(e) {
|
| 997 |
+
e = e || window.event;
|
| 998 |
+
e.preventDefault();
|
| 999 |
+
pos1 = pos3 - e.clientX;
|
| 1000 |
+
pos2 = pos4 - e.clientY;
|
| 1001 |
+
pos3 = e.clientX;
|
| 1002 |
+
pos4 = e.clientY;
|
| 1003 |
+
|
| 1004 |
+
const newTop = element.offsetTop - pos2;
|
| 1005 |
+
const newLeft = element.offsetLeft - pos1;
|
| 1006 |
+
|
| 1007 |
+
const viewportWidth = window.innerWidth;
|
| 1008 |
+
const viewportHeight = window.innerHeight;
|
| 1009 |
+
|
| 1010 |
+
const panelRect = element.getBoundingClientRect();
|
| 1011 |
+
|
| 1012 |
+
const maxTop = viewportHeight - 50;
|
| 1013 |
+
const maxLeft = viewportWidth - 50;
|
| 1014 |
+
|
| 1015 |
+
element.style.top = Math.min(Math.max(newTop, 0), maxTop) + "px";
|
| 1016 |
+
element.style.left = Math.min(Math.max(newLeft, 0), maxLeft) + "px";
|
| 1017 |
+
|
| 1018 |
+
// Remove bottom/right settings when dragging
|
| 1019 |
+
element.style.bottom = 'auto';
|
| 1020 |
+
element.style.right = 'auto';
|
| 1021 |
+
}
|
| 1022 |
+
|
| 1023 |
+
function closeDragElement() {
|
| 1024 |
+
document.onmouseup = null;
|
| 1025 |
+
document.onmousemove = null;
|
| 1026 |
+
|
| 1027 |
+
element.classList.remove('dragging');
|
| 1028 |
+
}
|
| 1029 |
+
}
|
| 1030 |
+
|
| 1031 |
+
async loadData() {
|
| 1032 |
+
try {
|
| 1033 |
+
// this.ui.loadingText.textContent = "Loading binary data...";
|
| 1034 |
+
|
| 1035 |
+
let arrayBuffer;
|
| 1036 |
+
|
| 1037 |
+
if (window.embeddedBase64) {
|
| 1038 |
+
// Base64 embedded path
|
| 1039 |
+
const binaryString = atob(window.embeddedBase64);
|
| 1040 |
+
const len = binaryString.length;
|
| 1041 |
+
const bytes = new Uint8Array(len);
|
| 1042 |
+
for (let i = 0; i < len; i++) {
|
| 1043 |
+
bytes[i] = binaryString.charCodeAt(i);
|
| 1044 |
+
}
|
| 1045 |
+
arrayBuffer = bytes.buffer;
|
| 1046 |
+
} else {
|
| 1047 |
+
// Default fetch path (fallback)
|
| 1048 |
+
const urlParams = new URLSearchParams(window.location.search);
|
| 1049 |
+
const dataPath = urlParams.get('data') || 'data.bin';
|
| 1050 |
+
|
| 1051 |
+
const response = await fetch(dataPath);
|
| 1052 |
+
if (!response.ok) throw new Error(`Failed to load ${dataPath}`);
|
| 1053 |
+
arrayBuffer = await response.arrayBuffer();
|
| 1054 |
+
}
|
| 1055 |
+
|
| 1056 |
+
const dataView = new DataView(arrayBuffer);
|
| 1057 |
+
const headerLen = dataView.getUint32(0, true);
|
| 1058 |
+
|
| 1059 |
+
const headerText = new TextDecoder("utf-8").decode(arrayBuffer.slice(4, 4 + headerLen));
|
| 1060 |
+
const header = JSON.parse(headerText);
|
| 1061 |
+
|
| 1062 |
+
const compressedBlob = new Uint8Array(arrayBuffer, 4 + headerLen);
|
| 1063 |
+
const decompressed = pako.inflate(compressedBlob).buffer;
|
| 1064 |
+
|
| 1065 |
+
const arrays = {};
|
| 1066 |
+
for (const key in header) {
|
| 1067 |
+
if (key === "meta") continue;
|
| 1068 |
+
|
| 1069 |
+
const meta = header[key];
|
| 1070 |
+
const { dtype, shape, offset, length } = meta;
|
| 1071 |
+
const slice = decompressed.slice(offset, offset + length);
|
| 1072 |
+
|
| 1073 |
+
let typedArray;
|
| 1074 |
+
switch (dtype) {
|
| 1075 |
+
case "uint8": typedArray = new Uint8Array(slice); break;
|
| 1076 |
+
case "uint16": typedArray = new Uint16Array(slice); break;
|
| 1077 |
+
case "float32": typedArray = new Float32Array(slice); break;
|
| 1078 |
+
case "float64": typedArray = new Float64Array(slice); break;
|
| 1079 |
+
default: throw new Error(`Unknown dtype: ${dtype}`);
|
| 1080 |
+
}
|
| 1081 |
+
|
| 1082 |
+
arrays[key] = { data: typedArray, shape: shape };
|
| 1083 |
+
}
|
| 1084 |
+
|
| 1085 |
+
this.data = arrays;
|
| 1086 |
+
this.config = header.meta;
|
| 1087 |
+
|
| 1088 |
+
this.initCameraWithCorrectFOV();
|
| 1089 |
+
this.ui.loadingText.textContent = "Creating point cloud...";
|
| 1090 |
+
|
| 1091 |
+
this.initPointCloud();
|
| 1092 |
+
this.initTrajectories();
|
| 1093 |
+
|
| 1094 |
+
setTimeout(() => {
|
| 1095 |
+
this.ui.loadingOverlay.classList.add('fade-out');
|
| 1096 |
+
this.ui.statusBar.classList.add('hidden');
|
| 1097 |
+
this.startAnimation();
|
| 1098 |
+
}, 500);
|
| 1099 |
+
} catch (error) {
|
| 1100 |
+
console.error("Error loading data:", error);
|
| 1101 |
+
this.ui.statusBar.textContent = `Error: ${error.message}`;
|
| 1102 |
+
// this.ui.loadingText.textContent = `Error loading data: ${error.message}`;
|
| 1103 |
+
}
|
| 1104 |
+
}
|
| 1105 |
+
|
| 1106 |
+
initPointCloud() {
|
| 1107 |
+
const numPoints = this.config.resolution[0] * this.config.resolution[1];
|
| 1108 |
+
const positions = new Float32Array(numPoints * 3);
|
| 1109 |
+
const colors = new Float32Array(numPoints * 3);
|
| 1110 |
+
|
| 1111 |
+
const geometry = new THREE.BufferGeometry();
|
| 1112 |
+
geometry.setAttribute('position', new THREE.BufferAttribute(positions, 3).setUsage(THREE.DynamicDrawUsage));
|
| 1113 |
+
geometry.setAttribute('color', new THREE.BufferAttribute(colors, 3).setUsage(THREE.DynamicDrawUsage));
|
| 1114 |
+
|
| 1115 |
+
const pointSize = parseFloat(this.ui.pointSize.value) || this.defaultSettings.pointSize;
|
| 1116 |
+
const pointOpacity = parseFloat(this.ui.pointOpacity.value) || this.defaultSettings.pointOpacity;
|
| 1117 |
+
|
| 1118 |
+
const material = new THREE.PointsMaterial({
|
| 1119 |
+
size: pointSize,
|
| 1120 |
+
vertexColors: true,
|
| 1121 |
+
transparent: true,
|
| 1122 |
+
opacity: pointOpacity,
|
| 1123 |
+
sizeAttenuation: true
|
| 1124 |
+
});
|
| 1125 |
+
|
| 1126 |
+
this.pointCloud = new THREE.Points(geometry, material);
|
| 1127 |
+
this.scene.add(this.pointCloud);
|
| 1128 |
+
}
|
| 1129 |
+
|
| 1130 |
+
initTrajectories() {
|
| 1131 |
+
if (!this.data.trajectories) return;
|
| 1132 |
+
|
| 1133 |
+
this.trajectories.forEach(trajectory => {
|
| 1134 |
+
if (trajectory.userData.lineSegments) {
|
| 1135 |
+
trajectory.userData.lineSegments.forEach(segment => {
|
| 1136 |
+
segment.geometry.dispose();
|
| 1137 |
+
segment.material.dispose();
|
| 1138 |
+
});
|
| 1139 |
+
}
|
| 1140 |
+
this.scene.remove(trajectory);
|
| 1141 |
+
});
|
| 1142 |
+
this.trajectories = [];
|
| 1143 |
+
|
| 1144 |
+
const shape = this.data.trajectories.shape;
|
| 1145 |
+
if (!shape || shape.length < 2) return;
|
| 1146 |
+
|
| 1147 |
+
const [totalFrames, numTrajectories] = shape;
|
| 1148 |
+
const palette = this.createColorPalette(numTrajectories);
|
| 1149 |
+
const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
|
| 1150 |
+
const maxHistory = 500; // Max value of the history slider, for the object pool
|
| 1151 |
+
|
| 1152 |
+
for (let i = 0; i < numTrajectories; i++) {
|
| 1153 |
+
const trajectoryGroup = new THREE.Group();
|
| 1154 |
+
|
| 1155 |
+
const ballSize = parseFloat(this.ui.trajectoryBallSize.value);
|
| 1156 |
+
const sphereGeometry = new THREE.SphereGeometry(ballSize, 16, 16);
|
| 1157 |
+
const sphereMaterial = new THREE.MeshBasicMaterial({ color: palette[i], transparent: true });
|
| 1158 |
+
const positionMarker = new THREE.Mesh(sphereGeometry, sphereMaterial);
|
| 1159 |
+
trajectoryGroup.add(positionMarker);
|
| 1160 |
+
|
| 1161 |
+
// High-Performance Line (default)
|
| 1162 |
+
const simpleLineGeometry = new THREE.BufferGeometry();
|
| 1163 |
+
const simpleLinePositions = new Float32Array(maxHistory * 3);
|
| 1164 |
+
simpleLineGeometry.setAttribute('position', new THREE.BufferAttribute(simpleLinePositions, 3).setUsage(THREE.DynamicDrawUsage));
|
| 1165 |
+
const simpleLine = new THREE.Line(simpleLineGeometry, new THREE.LineBasicMaterial({ color: palette[i] }));
|
| 1166 |
+
simpleLine.frustumCulled = false;
|
| 1167 |
+
trajectoryGroup.add(simpleLine);
|
| 1168 |
+
|
| 1169 |
+
// High-Quality Line Segments (for rich trail)
|
| 1170 |
+
const lineSegments = [];
|
| 1171 |
+
const lineWidth = parseFloat(this.ui.trajectoryLineWidth.value);
|
| 1172 |
+
|
| 1173 |
+
// Create a pool of line segment objects
|
| 1174 |
+
for (let j = 0; j < maxHistory - 1; j++) {
|
| 1175 |
+
const lineGeometry = new THREE.LineGeometry();
|
| 1176 |
+
lineGeometry.setPositions([0, 0, 0, 0, 0, 0]);
|
| 1177 |
+
const lineMaterial = new THREE.LineMaterial({
|
| 1178 |
+
color: palette[i],
|
| 1179 |
+
linewidth: lineWidth,
|
| 1180 |
+
resolution: resolution,
|
| 1181 |
+
transparent: true,
|
| 1182 |
+
depthWrite: false, // Correctly handle transparency
|
| 1183 |
+
opacity: 0
|
| 1184 |
+
});
|
| 1185 |
+
const segment = new THREE.Line2(lineGeometry, lineMaterial);
|
| 1186 |
+
segment.frustumCulled = false;
|
| 1187 |
+
segment.visible = false; // Start with all segments hidden
|
| 1188 |
+
trajectoryGroup.add(segment);
|
| 1189 |
+
lineSegments.push(segment);
|
| 1190 |
+
}
|
| 1191 |
+
|
| 1192 |
+
trajectoryGroup.userData = {
|
| 1193 |
+
marker: positionMarker,
|
| 1194 |
+
simpleLine: simpleLine,
|
| 1195 |
+
lineSegments: lineSegments,
|
| 1196 |
+
color: palette[i]
|
| 1197 |
+
};
|
| 1198 |
+
|
| 1199 |
+
this.scene.add(trajectoryGroup);
|
| 1200 |
+
this.trajectories.push(trajectoryGroup);
|
| 1201 |
+
}
|
| 1202 |
+
|
| 1203 |
+
const showTrajectory = this.ui.showTrajectory.checked;
|
| 1204 |
+
this.trajectories.forEach(trajectory => trajectory.visible = showTrajectory);
|
| 1205 |
+
}
|
| 1206 |
+
|
| 1207 |
+
createColorPalette(count) {
|
| 1208 |
+
const colors = [];
|
| 1209 |
+
const hueStep = 360 / count;
|
| 1210 |
+
|
| 1211 |
+
for (let i = 0; i < count; i++) {
|
| 1212 |
+
const hue = (i * hueStep) % 360;
|
| 1213 |
+
const color = new THREE.Color().setHSL(hue / 360, 0.8, 0.6);
|
| 1214 |
+
colors.push(color);
|
| 1215 |
+
}
|
| 1216 |
+
|
| 1217 |
+
return colors;
|
| 1218 |
+
}
|
| 1219 |
+
|
| 1220 |
+
updatePointCloud(frameIndex) {
|
| 1221 |
+
if (!this.data || !this.pointCloud) return;
|
| 1222 |
+
|
| 1223 |
+
const positions = this.pointCloud.geometry.attributes.position.array;
|
| 1224 |
+
const colors = this.pointCloud.geometry.attributes.color.array;
|
| 1225 |
+
|
| 1226 |
+
const rgbVideo = this.data.rgb_video;
|
| 1227 |
+
const depthsRgb = this.data.depths_rgb;
|
| 1228 |
+
const intrinsics = this.data.intrinsics;
|
| 1229 |
+
const invExtrinsics = this.data.inv_extrinsics;
|
| 1230 |
+
|
| 1231 |
+
const width = this.config.resolution[0];
|
| 1232 |
+
const height = this.config.resolution[1];
|
| 1233 |
+
const numPoints = width * height;
|
| 1234 |
+
|
| 1235 |
+
const K = this.get3x3Matrix(intrinsics.data, intrinsics.shape, frameIndex);
|
| 1236 |
+
const fx = K[0][0], fy = K[1][1], cx = K[0][2], cy = K[1][2];
|
| 1237 |
+
|
| 1238 |
+
const invExtrMat = this.get4x4Matrix(invExtrinsics.data, invExtrinsics.shape, frameIndex);
|
| 1239 |
+
const transform = this.getTransformElements(invExtrMat);
|
| 1240 |
+
|
| 1241 |
+
const rgbFrame = this.getFrame(rgbVideo.data, rgbVideo.shape, frameIndex);
|
| 1242 |
+
const depthFrame = this.getFrame(depthsRgb.data, depthsRgb.shape, frameIndex);
|
| 1243 |
+
|
| 1244 |
+
const maxDepth = parseFloat(this.ui.maxDepth.value) || 10.0;
|
| 1245 |
+
|
| 1246 |
+
let validPointCount = 0;
|
| 1247 |
+
|
| 1248 |
+
for (let i = 0; i < numPoints; i++) {
|
| 1249 |
+
const xPix = i % width;
|
| 1250 |
+
const yPix = Math.floor(i / width);
|
| 1251 |
+
|
| 1252 |
+
const d0 = depthFrame[i * 3];
|
| 1253 |
+
const d1 = depthFrame[i * 3 + 1];
|
| 1254 |
+
const depthEncoded = d0 | (d1 << 8);
|
| 1255 |
+
const depthValue = (depthEncoded / ((1 << 16) - 1)) *
|
| 1256 |
+
(this.config.depthRange[1] - this.config.depthRange[0]) +
|
| 1257 |
+
this.config.depthRange[0];
|
| 1258 |
+
|
| 1259 |
+
if (depthValue === 0 || depthValue > maxDepth) {
|
| 1260 |
+
continue;
|
| 1261 |
+
}
|
| 1262 |
+
|
| 1263 |
+
const X = ((xPix - cx) * depthValue) / fx;
|
| 1264 |
+
const Y = ((yPix - cy) * depthValue) / fy;
|
| 1265 |
+
const Z = depthValue;
|
| 1266 |
+
|
| 1267 |
+
const tx = transform.m11 * X + transform.m12 * Y + transform.m13 * Z + transform.m14;
|
| 1268 |
+
const ty = transform.m21 * X + transform.m22 * Y + transform.m23 * Z + transform.m24;
|
| 1269 |
+
const tz = transform.m31 * X + transform.m32 * Y + transform.m33 * Z + transform.m34;
|
| 1270 |
+
|
| 1271 |
+
const index = validPointCount * 3;
|
| 1272 |
+
positions[index] = tx;
|
| 1273 |
+
positions[index + 1] = -ty;
|
| 1274 |
+
positions[index + 2] = -tz;
|
| 1275 |
+
|
| 1276 |
+
colors[index] = rgbFrame[i * 3] / 255;
|
| 1277 |
+
colors[index + 1] = rgbFrame[i * 3 + 1] / 255;
|
| 1278 |
+
colors[index + 2] = rgbFrame[i * 3 + 2] / 255;
|
| 1279 |
+
|
| 1280 |
+
validPointCount++;
|
| 1281 |
+
}
|
| 1282 |
+
|
| 1283 |
+
this.pointCloud.geometry.setDrawRange(0, validPointCount);
|
| 1284 |
+
this.pointCloud.geometry.attributes.position.needsUpdate = true;
|
| 1285 |
+
this.pointCloud.geometry.attributes.color.needsUpdate = true;
|
| 1286 |
+
this.pointCloud.geometry.computeBoundingSphere(); // Important for camera culling
|
| 1287 |
+
|
| 1288 |
+
this.updateTrajectories(frameIndex);
|
| 1289 |
+
|
| 1290 |
+
const progress = (frameIndex + 1) / this.config.totalFrames;
|
| 1291 |
+
this.ui.progress.style.width = `${progress * 100}%`;
|
| 1292 |
+
|
| 1293 |
+
if (this.ui.frameCounter && this.config.totalFrames) {
|
| 1294 |
+
this.ui.frameCounter.textContent = `Frame ${frameIndex} / ${this.config.totalFrames - 1}`;
|
| 1295 |
+
}
|
| 1296 |
+
|
| 1297 |
+
this.updateCameraFrustum(frameIndex);
|
| 1298 |
+
}
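A note on the geometry above: each depth pixel arrives as a 16-bit code packed into the first two channels of `depths_rgb`, is mapped back to metric depth through `depthRange`, unprojected with the per-frame intrinsics, moved to world space with `inv_extrinsics`, and finally mirrored in Y and Z for the three.js scene. A NumPy sketch of the same per-pixel math follows; the helper name `unproject_frame` is hypothetical, not code from this repo.

```python
# NumPy mirror of the per-pixel math in updatePointCloud(), for sanity-checking
# exported data; assumes depth_rgb packs a 16-bit depth code in channels 0 and 1.
import numpy as np

def unproject_frame(depth_rgb, K, inv_extrinsic, depth_range, max_depth=10.0):
    h, w, _ = depth_rgb.shape
    code = depth_rgb[..., 0].astype(np.uint16) | (depth_rgb[..., 1].astype(np.uint16) << 8)
    depth = code / 65535.0 * (depth_range[1] - depth_range[0]) + depth_range[0]

    fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
    xs, ys = np.meshgrid(np.arange(w), np.arange(h))
    X = (xs - cx) * depth / fx
    Y = (ys - cy) * depth / fy
    cam = np.stack([X, Y, depth, np.ones_like(depth)], axis=-1).reshape(-1, 4)

    world = cam @ inv_extrinsic.T          # camera -> world, rows of the 4x4 matrix
    pts = world[:, :3].copy()
    pts[:, 1] *= -1                        # same Y/Z flip as the viewer applies
    pts[:, 2] *= -1

    d = depth.reshape(-1)
    return pts[(d > 0) & (d <= max_depth)]
```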
|
| 1299 |
+
|
| 1300 |
+
updateTrajectories(frameIndex) {
|
| 1301 |
+
if (!this.data.trajectories || this.trajectories.length === 0) return;
|
| 1302 |
+
|
| 1303 |
+
const trajectoryData = this.data.trajectories.data;
|
| 1304 |
+
const [totalFrames, numTrajectories] = this.data.trajectories.shape;
|
| 1305 |
+
const historyFrames = parseInt(this.ui.trajectoryHistory.value);
|
| 1306 |
+
const tailOpacity = parseFloat(this.ui.trajectoryFade.value);
|
| 1307 |
+
|
| 1308 |
+
const isRichMode = this.ui.enableRichTrail.checked;
|
| 1309 |
+
|
| 1310 |
+
for (let i = 0; i < numTrajectories; i++) {
|
| 1311 |
+
const trajectoryGroup = this.trajectories[i];
|
| 1312 |
+
const { marker, simpleLine, lineSegments } = trajectoryGroup.userData;
|
| 1313 |
+
|
| 1314 |
+
const currentPos = new THREE.Vector3();
|
| 1315 |
+
const currentOffset = (frameIndex * numTrajectories + i) * 3;
|
| 1316 |
+
|
| 1317 |
+
currentPos.x = trajectoryData[currentOffset];
|
| 1318 |
+
currentPos.y = -trajectoryData[currentOffset + 1];
|
| 1319 |
+
currentPos.z = -trajectoryData[currentOffset + 2];
|
| 1320 |
+
|
| 1321 |
+
marker.position.copy(currentPos);
|
| 1322 |
+
marker.material.opacity = 1.0;
|
| 1323 |
+
|
| 1324 |
+
const historyToShow = Math.min(historyFrames, frameIndex + 1);
|
| 1325 |
+
|
| 1326 |
+
if (isRichMode) {
|
| 1327 |
+
// --- High-Quality Mode ---
|
| 1328 |
+
simpleLine.visible = false;
|
| 1329 |
+
|
| 1330 |
+
for (let j = 0; j < lineSegments.length; j++) {
|
| 1331 |
+
const segment = lineSegments[j];
|
| 1332 |
+
if (j < historyToShow - 1) {
|
| 1333 |
+
const headFrame = frameIndex - j;
|
| 1334 |
+
const tailFrame = frameIndex - j - 1;
|
| 1335 |
+
const headOffset = (headFrame * numTrajectories + i) * 3;
|
| 1336 |
+
const tailOffset = (tailFrame * numTrajectories + i) * 3;
|
| 1337 |
+
const positions = [
|
| 1338 |
+
trajectoryData[headOffset], -trajectoryData[headOffset + 1], -trajectoryData[headOffset + 2],
|
| 1339 |
+
trajectoryData[tailOffset], -trajectoryData[tailOffset + 1], -trajectoryData[tailOffset + 2]
|
| 1340 |
+
];
|
| 1341 |
+
segment.geometry.setPositions(positions);
|
| 1342 |
+
const headOpacity = 1.0;
|
| 1343 |
+
const normalizedAge = j / Math.max(1, historyToShow - 2);
|
| 1344 |
+
const alpha = headOpacity - (headOpacity - tailOpacity) * normalizedAge;
|
| 1345 |
+
segment.material.opacity = Math.max(0, alpha);
|
| 1346 |
+
segment.visible = true;
|
| 1347 |
+
} else {
|
| 1348 |
+
segment.visible = false;
|
| 1349 |
+
}
|
| 1350 |
+
}
|
| 1351 |
+
} else {
|
| 1352 |
+
// --- Performance Mode ---
|
| 1353 |
+
lineSegments.forEach(s => s.visible = false);
|
| 1354 |
+
simpleLine.visible = true;
|
| 1355 |
+
|
| 1356 |
+
const positions = simpleLine.geometry.attributes.position.array;
|
| 1357 |
+
for (let j = 0; j < historyToShow; j++) {
|
| 1358 |
+
const historyFrame = Math.max(0, frameIndex - j);
|
| 1359 |
+
const offset = (historyFrame * numTrajectories + i) * 3;
|
| 1360 |
+
positions[j * 3] = trajectoryData[offset];
|
| 1361 |
+
positions[j * 3 + 1] = -trajectoryData[offset + 1];
|
| 1362 |
+
positions[j * 3 + 2] = -trajectoryData[offset + 2];
|
| 1363 |
+
}
|
| 1364 |
+
simpleLine.geometry.setDrawRange(0, historyToShow);
|
| 1365 |
+
simpleLine.geometry.attributes.position.needsUpdate = true;
|
| 1366 |
+
}
|
| 1367 |
+
}
|
| 1368 |
+
}
|
| 1369 |
+
|
| 1370 |
+
updateTrajectorySettings() {
|
| 1371 |
+
if (!this.trajectories || this.trajectories.length === 0) return;
|
| 1372 |
+
|
| 1373 |
+
const ballSize = parseFloat(this.ui.trajectoryBallSize.value);
|
| 1374 |
+
const lineWidth = parseFloat(this.ui.trajectoryLineWidth.value);
|
| 1375 |
+
|
| 1376 |
+
this.trajectories.forEach(trajectoryGroup => {
|
| 1377 |
+
const { marker, lineSegments } = trajectoryGroup.userData;
|
| 1378 |
+
|
| 1379 |
+
marker.geometry.dispose();
|
| 1380 |
+
marker.geometry = new THREE.SphereGeometry(ballSize, 16, 16);
|
| 1381 |
+
|
| 1382 |
+
// Line width only affects rich mode
|
| 1383 |
+
lineSegments.forEach(segment => {
|
| 1384 |
+
if (segment.material) {
|
| 1385 |
+
segment.material.linewidth = lineWidth;
|
| 1386 |
+
}
|
| 1387 |
+
});
|
| 1388 |
+
});
|
| 1389 |
+
|
| 1390 |
+
this.updateTrajectories(this.currentFrame);
|
| 1391 |
+
}
|
| 1392 |
+
|
| 1393 |
+
getDepthColor(normalizedDepth) {
|
| 1394 |
+
const hue = (1 - normalizedDepth) * 240 / 360;
|
| 1395 |
+
const color = new THREE.Color().setHSL(hue, 1.0, 0.5);
|
| 1396 |
+
return color;
|
| 1397 |
+
}
|
| 1398 |
+
|
| 1399 |
+
getFrame(typedArray, shape, frameIndex) {
|
| 1400 |
+
const [T, H, W, C] = shape;
|
| 1401 |
+
const frameSize = H * W * C;
|
| 1402 |
+
const offset = frameIndex * frameSize;
|
| 1403 |
+
return typedArray.subarray(offset, offset + frameSize);
|
| 1404 |
+
}
|
| 1405 |
+
|
| 1406 |
+
get3x3Matrix(typedArray, shape, frameIndex) {
|
| 1407 |
+
const frameSize = 9;
|
| 1408 |
+
const offset = frameIndex * frameSize;
|
| 1409 |
+
const K = [];
|
| 1410 |
+
for (let i = 0; i < 3; i++) {
|
| 1411 |
+
const row = [];
|
| 1412 |
+
for (let j = 0; j < 3; j++) {
|
| 1413 |
+
row.push(typedArray[offset + i * 3 + j]);
|
| 1414 |
+
}
|
| 1415 |
+
K.push(row);
|
| 1416 |
+
}
|
| 1417 |
+
return K;
|
| 1418 |
+
}
|
| 1419 |
+
|
| 1420 |
+
get4x4Matrix(typedArray, shape, frameIndex) {
|
| 1421 |
+
const frameSize = 16;
|
| 1422 |
+
const offset = frameIndex * frameSize;
|
| 1423 |
+
const M = [];
|
| 1424 |
+
for (let i = 0; i < 4; i++) {
|
| 1425 |
+
const row = [];
|
| 1426 |
+
for (let j = 0; j < 4; j++) {
|
| 1427 |
+
row.push(typedArray[offset + i * 4 + j]);
|
| 1428 |
+
}
|
| 1429 |
+
M.push(row);
|
| 1430 |
+
}
|
| 1431 |
+
return M;
|
| 1432 |
+
}
|
| 1433 |
+
|
| 1434 |
+
getTransformElements(matrix) {
|
| 1435 |
+
return {
|
| 1436 |
+
m11: matrix[0][0], m12: matrix[0][1], m13: matrix[0][2], m14: matrix[0][3],
|
| 1437 |
+
m21: matrix[1][0], m22: matrix[1][1], m23: matrix[1][2], m24: matrix[1][3],
|
| 1438 |
+
m31: matrix[2][0], m32: matrix[2][1], m33: matrix[2][2], m34: matrix[2][3]
|
| 1439 |
+
};
|
| 1440 |
+
}
|
| 1441 |
+
|
| 1442 |
+
togglePlayback() {
|
| 1443 |
+
this.isPlaying = !this.isPlaying;
|
| 1444 |
+
|
| 1445 |
+
const playIcon = document.getElementById('play-icon');
|
| 1446 |
+
const pauseIcon = document.getElementById('pause-icon');
|
| 1447 |
+
|
| 1448 |
+
if (this.isPlaying) {
|
| 1449 |
+
playIcon.style.display = 'none';
|
| 1450 |
+
pauseIcon.style.display = 'block';
|
| 1451 |
+
this.lastFrameTime = performance.now();
|
| 1452 |
+
} else {
|
| 1453 |
+
playIcon.style.display = 'block';
|
| 1454 |
+
pauseIcon.style.display = 'none';
|
| 1455 |
+
}
|
| 1456 |
+
}
|
| 1457 |
+
|
| 1458 |
+
cyclePlaybackSpeed() {
|
| 1459 |
+
const speeds = [0.5, 1, 2, 4, 8];
|
| 1460 |
+
const speedRates = speeds.map(s => s * this.config.baseFrameRate);
|
| 1461 |
+
|
| 1462 |
+
let currentIndex = 0;
|
| 1463 |
+
const normalizedSpeed = this.playbackSpeed / this.config.baseFrameRate;
|
| 1464 |
+
|
| 1465 |
+
for (let i = 0; i < speeds.length; i++) {
|
| 1466 |
+
if (Math.abs(normalizedSpeed - speeds[i]) < Math.abs(normalizedSpeed - speeds[currentIndex])) {
|
| 1467 |
+
currentIndex = i;
|
| 1468 |
+
}
|
| 1469 |
+
}
|
| 1470 |
+
|
| 1471 |
+
const nextIndex = (currentIndex + 1) % speeds.length;
|
| 1472 |
+
this.playbackSpeed = speedRates[nextIndex];
|
| 1473 |
+
this.ui.speedBtn.textContent = `${speeds[nextIndex]}x`;
|
| 1474 |
+
|
| 1475 |
+
if (speeds[nextIndex] === 1) {
|
| 1476 |
+
this.ui.speedBtn.classList.remove('active');
|
| 1477 |
+
} else {
|
| 1478 |
+
this.ui.speedBtn.classList.add('active');
|
| 1479 |
+
}
|
| 1480 |
+
}
|
| 1481 |
+
|
| 1482 |
+
seekTo(position) {
|
| 1483 |
+
const frameIndex = Math.floor(position * this.config.totalFrames);
|
| 1484 |
+
this.currentFrame = Math.max(0, Math.min(frameIndex, this.config.totalFrames - 1));
|
| 1485 |
+
this.updatePointCloud(this.currentFrame);
|
| 1486 |
+
}
|
| 1487 |
+
|
| 1488 |
+
updatePointCloudSettings() {
|
| 1489 |
+
if (!this.pointCloud) return;
|
| 1490 |
+
|
| 1491 |
+
const size = parseFloat(this.ui.pointSize.value);
|
| 1492 |
+
const opacity = parseFloat(this.ui.pointOpacity.value);
|
| 1493 |
+
|
| 1494 |
+
this.pointCloud.material.size = size;
|
| 1495 |
+
this.pointCloud.material.opacity = opacity;
|
| 1496 |
+
this.pointCloud.material.needsUpdate = true;
|
| 1497 |
+
|
| 1498 |
+
this.updatePointCloud(this.currentFrame);
|
| 1499 |
+
}
|
| 1500 |
+
|
| 1501 |
+
updateControls() {
|
| 1502 |
+
if (!this.controls) return;
|
| 1503 |
+
this.controls.update();
|
| 1504 |
+
}
|
| 1505 |
+
|
| 1506 |
+
resetView() {
|
| 1507 |
+
if (!this.camera || !this.controls) return;
|
| 1508 |
+
|
| 1509 |
+
// Reset camera position
|
| 1510 |
+
this.camera.position.set(0, 0, this.config.cameraZ || 0);
|
| 1511 |
+
|
| 1512 |
+
// Reset controls
|
| 1513 |
+
this.controls.reset();
|
| 1514 |
+
|
| 1515 |
+
// Set target slightly in front of camera
|
| 1516 |
+
this.controls.target.set(0, 0, -1);
|
| 1517 |
+
this.controls.update();
|
| 1518 |
+
|
| 1519 |
+
// Show status message
|
| 1520 |
+
this.ui.statusBar.textContent = "View reset";
|
| 1521 |
+
this.ui.statusBar.classList.remove('hidden');
|
| 1522 |
+
|
| 1523 |
+
// Hide status message after a few seconds
|
| 1524 |
+
setTimeout(() => {
|
| 1525 |
+
this.ui.statusBar.classList.add('hidden');
|
| 1526 |
+
}, 3000);
|
| 1527 |
+
}
|
| 1528 |
+
|
| 1529 |
+
onWindowResize() {
|
| 1530 |
+
if (!this.camera || !this.renderer) return;
|
| 1531 |
+
|
| 1532 |
+
const windowAspect = window.innerWidth / window.innerHeight;
|
| 1533 |
+
this.camera.aspect = windowAspect;
|
| 1534 |
+
this.camera.updateProjectionMatrix();
|
| 1535 |
+
this.renderer.setSize(window.innerWidth, window.innerHeight);
|
| 1536 |
+
|
| 1537 |
+
if (this.trajectories && this.trajectories.length > 0) {
|
| 1538 |
+
const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
|
| 1539 |
+
this.trajectories.forEach(trajectory => {
|
| 1540 |
+
const { lineSegments } = trajectory.userData;
|
| 1541 |
+
if (lineSegments && lineSegments.length > 0) {
|
| 1542 |
+
lineSegments.forEach(segment => {
|
| 1543 |
+
if (segment.material && segment.material.resolution) {
|
| 1544 |
+
segment.material.resolution.copy(resolution);
|
| 1545 |
+
}
|
| 1546 |
+
});
|
| 1547 |
+
}
|
| 1548 |
+
});
|
| 1549 |
+
}
|
| 1550 |
+
|
| 1551 |
+
if (this.cameraFrustum) {
|
| 1552 |
+
const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
|
| 1553 |
+
this.cameraFrustum.children.forEach(line => {
|
| 1554 |
+
if (line.material && line.material.resolution) {
|
| 1555 |
+
line.material.resolution.copy(resolution);
|
| 1556 |
+
}
|
| 1557 |
+
});
|
| 1558 |
+
}
|
| 1559 |
+
}
|
| 1560 |
+
|
| 1561 |
+
startAnimation() {
|
| 1562 |
+
this.isPlaying = true;
|
| 1563 |
+
this.lastFrameTime = performance.now();
|
| 1564 |
+
|
| 1565 |
+
this.camera.position.set(0, 0, this.config.cameraZ || 0);
|
| 1566 |
+
this.controls.target.set(0, 0, -1);
|
| 1567 |
+
this.controls.update();
|
| 1568 |
+
|
| 1569 |
+
this.playbackSpeed = this.config.baseFrameRate;
|
| 1570 |
+
|
| 1571 |
+
document.getElementById('play-icon').style.display = 'none';
|
| 1572 |
+
document.getElementById('pause-icon').style.display = 'block';
|
| 1573 |
+
|
| 1574 |
+
this.animate();
|
| 1575 |
+
}
|
| 1576 |
+
|
| 1577 |
+
animate() {
|
| 1578 |
+
requestAnimationFrame(() => this.animate());
|
| 1579 |
+
|
| 1580 |
+
if (this.controls) {
|
| 1581 |
+
this.controls.update();
|
| 1582 |
+
}
|
| 1583 |
+
|
| 1584 |
+
if (this.isPlaying && this.data) {
|
| 1585 |
+
const now = performance.now();
|
| 1586 |
+
const delta = (now - this.lastFrameTime) / 1000;
|
| 1587 |
+
|
| 1588 |
+
const framesToAdvance = Math.floor(delta * this.config.baseFrameRate * this.playbackSpeed);
|
| 1589 |
+
if (framesToAdvance > 0) {
|
| 1590 |
+
this.currentFrame = (this.currentFrame + framesToAdvance) % this.config.totalFrames;
|
| 1591 |
+
this.lastFrameTime = now;
|
| 1592 |
+
this.updatePointCloud(this.currentFrame);
|
| 1593 |
+
}
|
| 1594 |
+
}
|
| 1595 |
+
|
| 1596 |
+
if (this.renderer && this.scene && this.camera) {
|
| 1597 |
+
this.renderer.render(this.scene, this.camera);
|
| 1598 |
+
}
|
| 1599 |
+
}
|
| 1600 |
+
|
| 1601 |
+
initCameraWithCorrectFOV() {
|
| 1602 |
+
const fov = this.config.fov || 60;
|
| 1603 |
+
|
| 1604 |
+
const windowAspect = window.innerWidth / window.innerHeight;
|
| 1605 |
+
|
| 1606 |
+
this.camera = new THREE.PerspectiveCamera(
|
| 1607 |
+
fov,
|
| 1608 |
+
windowAspect,
|
| 1609 |
+
0.1,
|
| 1610 |
+
10000
|
| 1611 |
+
);
|
| 1612 |
+
|
| 1613 |
+
this.controls.object = this.camera;
|
| 1614 |
+
this.controls.update();
|
| 1615 |
+
|
| 1616 |
+
this.initCameraFrustum();
|
| 1617 |
+
}
|
| 1618 |
+
|
| 1619 |
+
initCameraFrustum() {
|
| 1620 |
+
this.cameraFrustum = new THREE.Group();
|
| 1621 |
+
|
| 1622 |
+
this.scene.add(this.cameraFrustum);
|
| 1623 |
+
|
| 1624 |
+
this.initCameraFrustumGeometry();
|
| 1625 |
+
|
| 1626 |
+
const showCameraFrustum = this.ui.showCameraFrustum ? this.ui.showCameraFrustum.checked : (this.defaultSettings ? this.defaultSettings.showCameraFrustum : false);
|
| 1627 |
+
|
| 1628 |
+
this.cameraFrustum.visible = showCameraFrustum;
|
| 1629 |
+
}
|
| 1630 |
+
|
| 1631 |
+
initCameraFrustumGeometry() {
|
| 1632 |
+
const fov = this.config.fov || 60;
|
| 1633 |
+
const originalAspect = this.config.original_aspect_ratio || 1.33;
|
| 1634 |
+
|
| 1635 |
+
const size = parseFloat(this.ui.frustumSize.value) || this.defaultSettings.frustumSize;
|
| 1636 |
+
|
| 1637 |
+
const halfHeight = Math.tan(THREE.MathUtils.degToRad(fov / 2)) * size;
|
| 1638 |
+
const halfWidth = halfHeight * originalAspect;
|
| 1639 |
+
|
| 1640 |
+
const vertices = [
|
| 1641 |
+
new THREE.Vector3(0, 0, 0),
|
| 1642 |
+
new THREE.Vector3(-halfWidth, -halfHeight, size),
|
| 1643 |
+
new THREE.Vector3(halfWidth, -halfHeight, size),
|
| 1644 |
+
new THREE.Vector3(halfWidth, halfHeight, size),
|
| 1645 |
+
new THREE.Vector3(-halfWidth, halfHeight, size)
|
| 1646 |
+
];
|
| 1647 |
+
|
| 1648 |
+
const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
|
| 1649 |
+
|
| 1650 |
+
const linePairs = [
|
| 1651 |
+
[1, 2], [2, 3], [3, 4], [4, 1],
|
| 1652 |
+
[0, 1], [0, 2], [0, 3], [0, 4]
|
| 1653 |
+
];
|
| 1654 |
+
|
| 1655 |
+
const colors = {
|
| 1656 |
+
edge: new THREE.Color(0x3366ff),
|
| 1657 |
+
ray: new THREE.Color(0x33cc66)
|
| 1658 |
+
};
|
| 1659 |
+
|
| 1660 |
+
linePairs.forEach((pair, index) => {
|
| 1661 |
+
const positions = [
|
| 1662 |
+
vertices[pair[0]].x, vertices[pair[0]].y, vertices[pair[0]].z,
|
| 1663 |
+
vertices[pair[1]].x, vertices[pair[1]].y, vertices[pair[1]].z
|
| 1664 |
+
];
|
| 1665 |
+
|
| 1666 |
+
const lineGeometry = new THREE.LineGeometry();
|
| 1667 |
+
lineGeometry.setPositions(positions);
|
| 1668 |
+
|
| 1669 |
+
let color = index < 4 ? colors.edge : colors.ray;
|
| 1670 |
+
|
| 1671 |
+
const lineMaterial = new THREE.LineMaterial({
|
| 1672 |
+
color: color,
|
| 1673 |
+
linewidth: 2,
|
| 1674 |
+
resolution: resolution,
|
| 1675 |
+
dashed: false
|
| 1676 |
+
});
|
| 1677 |
+
|
| 1678 |
+
const line = new THREE.Line2(lineGeometry, lineMaterial);
|
| 1679 |
+
this.cameraFrustum.add(line);
|
| 1680 |
+
});
|
| 1681 |
+
}
|
| 1682 |
+
|
| 1683 |
+
updateCameraFrustum(frameIndex) {
|
| 1684 |
+
if (!this.cameraFrustum || !this.data) return;
|
| 1685 |
+
|
| 1686 |
+
const invExtrinsics = this.data.inv_extrinsics;
|
| 1687 |
+
if (!invExtrinsics) return;
|
| 1688 |
+
|
| 1689 |
+
const invExtrMat = this.get4x4Matrix(invExtrinsics.data, invExtrinsics.shape, frameIndex);
|
| 1690 |
+
|
| 1691 |
+
const matrix = new THREE.Matrix4();
|
| 1692 |
+
matrix.set(
|
| 1693 |
+
invExtrMat[0][0], invExtrMat[0][1], invExtrMat[0][2], invExtrMat[0][3],
|
| 1694 |
+
invExtrMat[1][0], invExtrMat[1][1], invExtrMat[1][2], invExtrMat[1][3],
|
| 1695 |
+
invExtrMat[2][0], invExtrMat[2][1], invExtrMat[2][2], invExtrMat[2][3],
|
| 1696 |
+
invExtrMat[3][0], invExtrMat[3][1], invExtrMat[3][2], invExtrMat[3][3]
|
| 1697 |
+
);
|
| 1698 |
+
|
| 1699 |
+
const position = new THREE.Vector3();
|
| 1700 |
+
position.setFromMatrixPosition(matrix);
|
| 1701 |
+
|
| 1702 |
+
const rotMatrix = new THREE.Matrix4().extractRotation(matrix);
|
| 1703 |
+
|
| 1704 |
+
const coordinateCorrection = new THREE.Matrix4().makeRotationX(Math.PI);
|
| 1705 |
+
|
| 1706 |
+
const finalRotation = new THREE.Matrix4().multiplyMatrices(coordinateCorrection, rotMatrix);
|
| 1707 |
+
|
| 1708 |
+
const quaternion = new THREE.Quaternion();
|
| 1709 |
+
quaternion.setFromRotationMatrix(finalRotation);
|
| 1710 |
+
|
| 1711 |
+
position.y = -position.y;
|
| 1712 |
+
position.z = -position.z;
|
| 1713 |
+
|
| 1714 |
+
this.cameraFrustum.position.copy(position);
|
| 1715 |
+
this.cameraFrustum.quaternion.copy(quaternion);
|
| 1716 |
+
|
| 1717 |
+
const showCameraFrustum = this.ui.showCameraFrustum ? this.ui.showCameraFrustum.checked : this.defaultSettings.showCameraFrustum;
|
| 1718 |
+
|
| 1719 |
+
if (this.cameraFrustum.visible !== showCameraFrustum) {
|
| 1720 |
+
this.cameraFrustum.visible = showCameraFrustum;
|
| 1721 |
+
}
|
| 1722 |
+
|
| 1723 |
+
const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
|
| 1724 |
+
this.cameraFrustum.children.forEach(line => {
|
| 1725 |
+
if (line.material && line.material.resolution) {
|
| 1726 |
+
line.material.resolution.copy(resolution);
|
| 1727 |
+
}
|
| 1728 |
+
});
|
| 1729 |
+
}
|
| 1730 |
+
|
| 1731 |
+
updateFrustumDimensions() {
|
| 1732 |
+
if (!this.cameraFrustum) return;
|
| 1733 |
+
|
| 1734 |
+
while(this.cameraFrustum.children.length > 0) {
|
| 1735 |
+
const child = this.cameraFrustum.children[0];
|
| 1736 |
+
if (child.geometry) child.geometry.dispose();
|
| 1737 |
+
if (child.material) child.material.dispose();
|
| 1738 |
+
this.cameraFrustum.remove(child);
|
| 1739 |
+
}
|
| 1740 |
+
|
| 1741 |
+
this.initCameraFrustumGeometry();
|
| 1742 |
+
|
| 1743 |
+
this.updateCameraFrustum(this.currentFrame);
|
| 1744 |
+
}
|
| 1745 |
+
|
| 1746 |
+
resetSettings() {
|
| 1747 |
+
if (!this.defaultSettings) return;
|
| 1748 |
+
|
| 1749 |
+
this.applyDefaultSettings();
|
| 1750 |
+
|
| 1751 |
+
this.updatePointCloudSettings();
|
| 1752 |
+
this.updateTrajectorySettings();
|
| 1753 |
+
this.updateFrustumDimensions();
|
| 1754 |
+
|
| 1755 |
+
this.ui.statusBar.textContent = "Settings reset to defaults";
|
| 1756 |
+
this.ui.statusBar.classList.remove('hidden');
|
| 1757 |
+
|
| 1758 |
+
setTimeout(() => {
|
| 1759 |
+
this.ui.statusBar.classList.add('hidden');
|
| 1760 |
+
}, 3000);
|
| 1761 |
+
}
|
| 1762 |
+
}
|
| 1763 |
+
|
| 1764 |
+
window.addEventListener('DOMContentLoaded', () => {
    new PointCloudVisualizer();
});
</script>
</body>
</html>
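Since the template resolves its input from the `data` query parameter (falling back to `data.bin`), a rendered scene can be previewed with any static file server. A minimal launcher sketch, assuming the exported binary is saved as `_viz/scene.bin` (both the port and the file name are assumptions):

```python
# Serve the _viz folder and open the viewer with a ?data= query string,
# which loadData() reads through URLSearchParams.
import functools
import http.server
import socketserver
import webbrowser

PORT = 8000  # assumption: any free local port works
handler = functools.partial(http.server.SimpleHTTPRequestHandler, directory="_viz")
with socketserver.TCPServer(("", PORT), handler) as httpd:
    webbrowser.open(f"http://localhost:{PORT}/viz_template.html?data=scene.bin")
    httpd.serve_forever()  # Ctrl+C to stop
```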
|
app.py
CHANGED
|
@@ -4,160 +4,363 @@ import json
|
|
| 4 |
import numpy as np
|
| 5 |
import cv2
|
| 6 |
import base64
|
| 7 |
-
import requests
|
| 8 |
import time
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
#
|
| 14 |
-
|
| 15 |
-
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
#
|
| 40 |
try:
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
-
if
|
| 45 |
-
|
| 46 |
-
username = user_info.get('name', 'Unknown')
|
| 47 |
-
print(f"✅ Authenticated as: {username}")
|
| 48 |
-
|
| 49 |
-
# Check if user has access to the specific space
|
| 50 |
-
space_url = f"https://huggingface.co/api/spaces/{BACKEND_SPACE_URL}"
|
| 51 |
-
space_response = requests.get(space_url, headers=headers, timeout=5)
|
| 52 |
-
|
| 53 |
-
if space_response.status_code == 200:
|
| 54 |
-
print("✅ You have access to the backend Space")
|
| 55 |
-
return True
|
| 56 |
-
elif space_response.status_code == 401:
|
| 57 |
-
print("❌ You don't have access to the backend Space")
|
| 58 |
-
print("🔧 Solutions:")
|
| 59 |
-
print(" 1. Contact the Space owner to add you as collaborator")
|
| 60 |
-
print(" 2. Ask the owner to make the Space public")
|
| 61 |
-
return False
|
| 62 |
-
elif space_response.status_code == 404:
|
| 63 |
-
print("❌ Backend Space not found")
|
| 64 |
-
print("🔧 Please check if the Space URL is correct")
|
| 65 |
-
return False
|
| 66 |
-
else:
|
| 67 |
-
print(f"⚠️ Unexpected response checking Space access: {space_response.status_code}")
|
| 68 |
-
return False
|
| 69 |
-
|
| 70 |
-
else:
|
| 71 |
-
print(f"❌ Token validation failed: {response.status_code}")
|
| 72 |
-
print("🔧 Your token might be invalid or expired")
|
| 73 |
-
return False
|
| 74 |
|
| 75 |
except Exception as e:
|
| 76 |
-
print(f"
|
| 77 |
-
|
|
|
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
-
#
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
print("✅ Backend space appears to be running")
|
| 110 |
-
return True
|
| 111 |
-
|
| 112 |
-
elif response.status_code == 401:
|
| 113 |
-
print("❌ Authentication failed (HTTP 401)")
|
| 114 |
-
print("🔧 This means:")
|
| 115 |
-
print(" - The backend Space is private")
|
| 116 |
-
print(" - Your HF Token doesn't have access to this Space")
|
| 117 |
-
print(" - You need to be added as a collaborator to the Space")
|
| 118 |
-
print(" - Or the Space owner needs to make it public")
|
| 119 |
-
return False
|
| 120 |
-
|
| 121 |
-
elif response.status_code == 404:
|
| 122 |
-
print("❌ Backend space not found (HTTP 404)")
|
| 123 |
-
print("🔧 Please check if the Space URL is correct:")
|
| 124 |
-
print(f" Current URL: {BACKEND_SPACE_URL}")
|
| 125 |
-
return False
|
| 126 |
-
|
| 127 |
-
else:
|
| 128 |
-
print(f"❌ Backend space not accessible (HTTP {response.status_code})")
|
| 129 |
-
print(f"🔧 Response: {response.text[:200]}...")
|
| 130 |
-
return False
|
| 131 |
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
-
def
|
| 140 |
-
|
| 141 |
-
global backend_client, BACKEND_AVAILABLE
|
| 142 |
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
def numpy_to_base64(arr):
|
| 163 |
"""Convert numpy array to base64 string"""
|
|
@@ -167,24 +370,6 @@ def base64_to_numpy(b64_str, shape, dtype):
|
|
| 167 |
"""Convert base64 string back to numpy array"""
|
| 168 |
return np.frombuffer(base64.b64decode(b64_str), dtype=dtype).reshape(shape)
|
| 169 |
|
| 170 |
-
def base64_to_image(b64_str):
|
| 171 |
-
"""Convert base64 string to numpy image array"""
|
| 172 |
-
if not b64_str:
|
| 173 |
-
return None
|
| 174 |
-
try:
|
| 175 |
-
# Decode base64 to bytes
|
| 176 |
-
img_bytes = base64.b64decode(b64_str)
|
| 177 |
-
# Convert bytes to numpy array
|
| 178 |
-
nparr = np.frombuffer(img_bytes, np.uint8)
|
| 179 |
-
# Decode image
|
| 180 |
-
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
| 181 |
-
# Convert BGR to RGB
|
| 182 |
-
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 183 |
-
return img
|
| 184 |
-
except Exception as e:
|
| 185 |
-
print(f"Error converting base64 to image: {e}")
|
| 186 |
-
return None
|
| 187 |
-
|
| 188 |
def get_video_name(video_path):
|
| 189 |
"""Extract video name without extension"""
|
| 190 |
return os.path.splitext(os.path.basename(video_path))[0]
|
|
@@ -197,7 +382,6 @@ def extract_first_frame(video_path):
|
|
| 197 |
cap.release()
|
| 198 |
|
| 199 |
if ret:
|
| 200 |
-
# Convert BGR to RGB
|
| 201 |
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 202 |
return frame_rgb
|
| 203 |
else:
|
|
@@ -214,116 +398,65 @@ def handle_video_upload(video):
|
|
| 214 |
gr.update(value=756),
|
| 215 |
gr.update(value=3))
|
| 216 |
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
# Call the unified API with upload_video function type - fix: use handle_file wrapper
|
| 224 |
-
result = backend_client.predict(
|
| 225 |
-
"upload_video", # function_type
|
| 226 |
-
handle_file(video), # video file - wrapped with handle_file
|
| 227 |
-
"", # original_image_state (not used for upload)
|
| 228 |
-
[], # selected_points (not used for upload)
|
| 229 |
-
"positive_point", # point_type (not used for upload)
|
| 230 |
-
0, # point_x (not used for upload)
|
| 231 |
-
0, # point_y (not used for upload)
|
| 232 |
-
50, # grid_size (not used for upload)
|
| 233 |
-
756, # vo_points (not used for upload)
|
| 234 |
-
3, # fps (not used for upload)
|
| 235 |
-
api_name="/unified_api"
|
| 236 |
-
)
|
| 237 |
-
|
| 238 |
-
print(f"✅ Backend video upload API call successful!")
|
| 239 |
-
print(f"🔧 Result type: {type(result)}")
|
| 240 |
-
print(f"🔧 Result: {result}")
|
| 241 |
-
|
| 242 |
-
# Parse the result - expect a dict with success status
|
| 243 |
-
if isinstance(result, dict) and result.get("success"):
|
| 244 |
-
# Extract data from backend response
|
| 245 |
-
original_image_state = result.get("original_image_state", "")
|
| 246 |
-
display_image = result.get("display_image", None)
|
| 247 |
-
selected_points = result.get("selected_points", [])
|
| 248 |
-
|
| 249 |
-
# Fix: Convert display_image from list back to numpy array if needed
|
| 250 |
-
if isinstance(display_image, list):
|
| 251 |
-
display_image = np.array(display_image, dtype=np.uint8)
|
| 252 |
-
print(f"🔧 Converted display_image from list to numpy array: {display_image.shape}")
|
| 253 |
-
|
| 254 |
-
# Get video settings based on video name
|
| 255 |
-
video_name = get_video_name(video)
|
| 256 |
-
print(f"🎬 Video path: '{video}' -> Video name: '{video_name}'")
|
| 257 |
-
grid_size_val, vo_points_val, fps_val = get_video_settings(video_name)
|
| 258 |
-
print(f"🎬 Video settings for '{video_name}': grid_size={grid_size_val}, vo_points={vo_points_val}, fps={fps_val}")
|
| 259 |
-
|
| 260 |
-
return (original_image_state, display_image, selected_points,
|
| 261 |
-
gr.update(value=grid_size_val),
|
| 262 |
-
gr.update(value=vo_points_val),
|
| 263 |
-
gr.update(value=fps_val))
|
| 264 |
-
else:
|
| 265 |
-
print("Backend processing failed, using local fallback")
|
| 266 |
-
# Fallback to local processing
|
| 267 |
-
pass
|
| 268 |
-
except Exception as e:
|
| 269 |
-
print(f"Backend API call failed: {e}")
|
| 270 |
-
# Fallback to local processing
|
| 271 |
-
pass
|
| 272 |
-
|
| 273 |
-
# Fallback: local processing
|
| 274 |
-
print("Using local video processing...")
|
| 275 |
-
display_image = extract_first_frame(video)
|
| 276 |
-
|
| 277 |
-
if display_image is not None:
|
| 278 |
-
# Create a state format compatible with backend
|
| 279 |
-
import tempfile
|
| 280 |
-
import shutil
|
| 281 |
-
|
| 282 |
-
# Create a temporary directory for this session
|
| 283 |
-
session_id = str(int(time.time() * 1000)) # Use timestamp as session ID
|
| 284 |
-
temp_dir = os.path.join("temp_frontend", f"session_{session_id}")
|
| 285 |
-
os.makedirs(temp_dir, exist_ok=True)
|
| 286 |
-
|
| 287 |
-
# Copy video to temp directory with standardized name
|
| 288 |
-
video_name = get_video_name(video)
|
| 289 |
-
temp_video_path = os.path.join(temp_dir, f"{video_name}.mp4")
|
| 290 |
-
shutil.copy(video, temp_video_path)
|
| 291 |
-
|
| 292 |
-
# Create state format compatible with backend
|
| 293 |
-
frame_data = {
|
| 294 |
-
'data': numpy_to_base64(display_image),
|
| 295 |
-
'shape': display_image.shape,
|
| 296 |
-
'dtype': str(display_image.dtype),
|
| 297 |
-
'temp_dir': temp_dir,
|
| 298 |
-
'video_name': video_name,
|
| 299 |
-
'video_path': temp_video_path # Keep for backward compatibility
|
| 300 |
-
}
|
| 301 |
-
|
| 302 |
-
original_image_state = json.dumps(frame_data)
|
| 303 |
-
else:
|
| 304 |
-
# Fallback to simple state if frame extraction fails
|
| 305 |
-
original_image_state = json.dumps({
|
| 306 |
-
"video_path": video,
|
| 307 |
-
"frame": "local_processing_failed"
|
| 308 |
-
})
|
| 309 |
-
|
| 310 |
-
# Get video settings
|
| 311 |
video_name = get_video_name(video)
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
|
|
|
|
|
|
| 323 |
return (None, None, [],
|
| 324 |
gr.update(value=50),
|
| 325 |
gr.update(value=756),
|
| 326 |
gr.update(value=3))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
|
| 328 |
def select_point(original_img: str, sel_pix: list, point_type: str, evt: gr.SelectData):
|
| 329 |
"""Handle point selection for SAM"""
|
|
@@ -331,357 +464,142 @@ def select_point(original_img: str, sel_pix: list, point_type: str, evt: gr.Sele
|
|
| 331 |
return None, []
|
| 332 |
|
| 333 |
try:
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
# Call the unified API with select_point function type
|
| 340 |
-
result = backend_client.predict(
|
| 341 |
-
"select_point", # function_type
|
| 342 |
-
None, # video file (not used for select_point)
|
| 343 |
-
original_img, # original_image_state
|
| 344 |
-
sel_pix, # selected_points
|
| 345 |
-
point_type, # point_type
|
| 346 |
-
evt.index[0], # point_x
|
| 347 |
-
evt.index[1], # point_y
|
| 348 |
-
50, # grid_size (not used for select_point)
|
| 349 |
-
756, # vo_points (not used for select_point)
|
| 350 |
-
3, # fps (not used for select_point)
|
| 351 |
-
api_name="/unified_api"
|
| 352 |
-
)
|
| 353 |
-
|
| 354 |
-
print(f"✅ Backend select point API call successful!")
|
| 355 |
-
print(f"🔧 Result type: {type(result)}")
|
| 356 |
-
print(f"🔧 Result: {result}")
|
| 357 |
-
|
| 358 |
-
# Parse the result - expect a dict with success status
|
| 359 |
-
if isinstance(result, dict) and result.get("success"):
|
| 360 |
-
display_image = result.get("display_image", None)
|
| 361 |
-
new_sel_pix = result.get("selected_points", sel_pix)
|
| 362 |
-
|
| 363 |
-
# Fix: Convert display_image from list back to numpy array if needed
|
| 364 |
-
if isinstance(display_image, list):
|
| 365 |
-
display_image = np.array(display_image, dtype=np.uint8)
|
| 366 |
-
print(f"🔧 Converted display_image from list to numpy array: {display_image.shape}")
|
| 367 |
-
|
| 368 |
-
return display_image, new_sel_pix
|
| 369 |
-
else:
|
| 370 |
-
print("Backend processing failed, using local fallback")
|
| 371 |
-
# Fallback to local processing
|
| 372 |
-
pass
|
| 373 |
-
except Exception as e:
|
| 374 |
-
print(f"Backend API call failed: {e}")
|
| 375 |
-
|
| 376 |
-
# Check for specific gradio_client errors
|
| 377 |
-
if "AppError" in str(type(e)):
|
| 378 |
-
print("🔧 Backend Space has internal errors (AppError)")
|
| 379 |
-
print("🔧 The backend Space code has bugs or configuration issues")
|
| 380 |
-
print("🔧 Contact the Space owner to fix the backend implementation")
|
| 381 |
-
elif "Could not fetch config" in str(e):
|
| 382 |
-
print("🔧 Config fetch failed - possible Gradio version mismatch")
|
| 383 |
-
print("🔧 Frontend and backend may be using incompatible Gradio versions")
|
| 384 |
-
elif "timeout" in str(e).lower():
|
| 385 |
-
print("🔧 Backend request timed out - Space might be overloaded")
|
| 386 |
-
else:
|
| 387 |
-
print(f"🔧 Unexpected error type: {type(e).__name__}")
|
| 388 |
-
|
| 389 |
-
print("🔄 Showing error message instead of visualization...")
|
| 390 |
-
# Fallback to local processing
|
| 391 |
-
pass
|
| 392 |
|
| 393 |
-
#
|
| 394 |
-
|
|
|
|
|
|
|
| 395 |
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
video_path = state_data.get("video_path")
|
| 400 |
-
except:
|
| 401 |
-
video_path = None
|
| 402 |
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
if display_image is not None:
|
| 407 |
-
# Add point to the image with enhanced visualization
|
| 408 |
-
x, y = evt.index[0], evt.index[1]
|
| 409 |
-
color = (0, 255, 0) if point_type == 'positive_point' else (255, 0, 0)
|
| 410 |
-
|
| 411 |
-
# Draw a larger, more visible point
|
| 412 |
-
cv2.circle(display_image, (x, y), 8, color, -1)
|
| 413 |
-
cv2.circle(display_image, (x, y), 12, (255, 255, 255), 2)
|
| 414 |
-
|
| 415 |
-
# Add point to selected points list - fix logic to match local version
|
| 416 |
-
new_sel_pix = sel_pix.copy() if sel_pix else []
|
| 417 |
-
new_sel_pix.append([x, y, point_type])
|
| 418 |
-
|
| 419 |
-
return display_image, new_sel_pix
|
| 420 |
|
| 421 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
|
| 423 |
except Exception as e:
|
| 424 |
-
print(f"Error in select_point: {e}")
|
| 425 |
return None, []
|
| 426 |
|
| 427 |
def reset_points(original_img: str, sel_pix):
|
| 428 |
-
"""Reset points and
|
| 429 |
if original_img is None:
|
| 430 |
return None, []
|
| 431 |
|
| 432 |
try:
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
# Call the unified API with reset_points function type
|
| 439 |
-
result = backend_client.predict(
|
| 440 |
-
"reset_points", # function_type
|
| 441 |
-
None, # video file (not used for reset_points)
|
| 442 |
-
original_img, # original_image_state
|
| 443 |
-
sel_pix, # selected_points
|
| 444 |
-
"positive_point", # point_type (not used for reset_points)
|
| 445 |
-
0, # point_x (not used for reset_points)
|
| 446 |
-
0, # point_y (not used for reset_points)
|
| 447 |
-
50, # grid_size (not used for reset_points)
|
| 448 |
-
756, # vo_points (not used for reset_points)
|
| 449 |
-
3, # fps (not used for reset_points)
|
| 450 |
-
api_name="/unified_api"
|
| 451 |
-
)
|
| 452 |
-
|
| 453 |
-
print(f"✅ Backend reset points API call successful!")
|
| 454 |
-
print(f"🔧 Result: {result}")
|
| 455 |
-
|
| 456 |
-
# Parse the result
|
| 457 |
-
if isinstance(result, dict) and result.get("success"):
|
| 458 |
-
display_image = result.get("display_image", None)
|
| 459 |
-
new_sel_pix = result.get("selected_points", [])
|
| 460 |
-
|
| 461 |
-
# Fix: Convert display_image from list back to numpy array if needed
|
| 462 |
-
if isinstance(display_image, list):
|
| 463 |
-
display_image = np.array(display_image, dtype=np.uint8)
|
| 464 |
-
print(f"🔧 Converted display_image from list to numpy array: {display_image.shape}")
|
| 465 |
-
|
| 466 |
-
return display_image, new_sel_pix
|
| 467 |
-
else:
|
| 468 |
-
print("Backend processing failed, using local fallback")
|
| 469 |
-
# Fallback to local processing
|
| 470 |
-
pass
|
| 471 |
-
except Exception as e:
|
| 472 |
-
print(f"Backend API call failed: {e}")
|
| 473 |
-
# Fallback to local processing
|
| 474 |
-
pass
|
| 475 |
|
| 476 |
-
#
|
| 477 |
-
|
| 478 |
|
| 479 |
-
#
|
| 480 |
-
|
| 481 |
-
state_data = json.loads(original_img)
|
| 482 |
-
video_path = state_data.get("video_path")
|
| 483 |
-
except:
|
| 484 |
-
video_path = None
|
| 485 |
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
|
|
|
|
|
|
| 490 |
|
| 491 |
-
|
|
|
|
| 492 |
|
| 493 |
except Exception as e:
|
| 494 |
-
print(f"Error in reset_points: {e}")
|
| 495 |
return None, []
|
| 496 |
|
| 497 |
-
gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
|
| 498 |
-
|
| 499 |
def launch_viz(grid_size, vo_points, fps, original_image_state):
|
| 500 |
"""Launch visualization with user-specific temp directory"""
|
| 501 |
if original_image_state is None:
|
| 502 |
return None, None, None
|
| 503 |
|
| 504 |
try:
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
print(f"🔧 Original image state type: {type(original_image_state)}")
|
| 510 |
-
print(f"🔧 Original image state preview: {str(original_image_state)[:100]}...")
|
| 511 |
-
|
| 512 |
-
# Validate and potentially fix the original_image_state format
|
| 513 |
-
state_to_send = original_image_state
|
| 514 |
-
|
| 515 |
-
# Check if this is a local processing state that needs to be converted
|
| 516 |
-
try:
|
| 517 |
-
if isinstance(original_image_state, str):
|
| 518 |
-
parsed_state = json.loads(original_image_state)
|
| 519 |
-
if "video_path" in parsed_state and "frame" in parsed_state:
|
| 520 |
-
# This is a local processing state, we need to handle differently
|
| 521 |
-
print("🔧 Detected local processing state, cannot use backend for tracking")
|
| 522 |
-
print("🔧 Backend requires proper video upload state from backend API")
|
| 523 |
-
# Fall through to local processing
|
| 524 |
-
raise ValueError("Local state cannot be processed by backend")
|
| 525 |
-
except json.JSONDecodeError:
|
| 526 |
-
print("🔧 Invalid JSON state, cannot send to backend")
|
| 527 |
-
raise ValueError("Invalid state format")
|
| 528 |
-
|
| 529 |
-
# Call the unified API with run_tracker function type
|
| 530 |
-
result = backend_client.predict(
|
| 531 |
-
"run_tracker", # function_type
|
| 532 |
-
None, # video file (not used for run_tracker)
|
| 533 |
-
state_to_send, # original_image_state
|
| 534 |
-
[], # selected_points (not used for run_tracker)
|
| 535 |
-
"positive_point", # point_type (not used for run_tracker)
|
| 536 |
-
0, # point_x (not used for run_tracker)
|
| 537 |
-
0, # point_y (not used for run_tracker)
|
| 538 |
-
grid_size, # grid_size
|
| 539 |
-
vo_points, # vo_points
|
| 540 |
-
fps, # fps
|
| 541 |
-
api_name="/unified_api"
|
| 542 |
-
)
|
| 543 |
-
|
| 544 |
-
print(f"✅ Backend API call successful!")
|
| 545 |
-
print(f"🔧 Result type: {type(result)}")
|
| 546 |
-
print(f"🔧 Result: {result}")
|
| 547 |
-
|
| 548 |
-
# Parse the result
|
| 549 |
-
if isinstance(result, dict) and result.get("success"):
|
| 550 |
-
viz_html = result.get("viz_html", "")
|
| 551 |
-
track_video_path = result.get("track_video_path", "")
|
| 552 |
-
track_video_content = result.get("track_video_content", None)
|
| 553 |
-
track_video_filename = result.get("track_video_filename", "tracked_video.mp4")
|
| 554 |
-
|
| 555 |
-
# Save HTML to _viz directory (like local version)
|
| 556 |
-
viz_dir = './_viz'
|
| 557 |
-
os.makedirs(viz_dir, exist_ok=True)
|
| 558 |
-
random_path = f'./_viz/_{time.time()}.html'
|
| 559 |
-
|
| 560 |
-
with open(random_path, 'w', encoding='utf-8') as f:
|
| 561 |
-
f.write(viz_html)
|
| 562 |
-
|
| 563 |
-
# Create iframe HTML
|
| 564 |
-
iframe_html = f"""
|
| 565 |
-
<div style='border: 3px solid #667eea; border-radius: 10px;
|
| 566 |
-
background: #f8f9ff; height: 650px; width: 100%;
|
| 567 |
-
box-shadow: 0 8px 32px rgba(102, 126, 234, 0.3);
|
| 568 |
-
margin: 0; padding: 0; box-sizing: border-box; overflow: hidden;'>
|
| 569 |
-
<iframe id="viz_iframe" src="/gradio_api/file={random_path}"
|
| 570 |
-
width="100%" height="650" frameborder="0"
|
| 571 |
-
style="border: none; display: block; width: 100%; height: 650px;
|
| 572 |
-
margin: 0; padding: 0; border-radius: 7px;">
|
| 573 |
-
</iframe>
|
| 574 |
-
</div>
|
| 575 |
-
"""
|
| 576 |
-
|
| 577 |
-
print(f"💾 HTML saved to: {random_path}")
|
| 578 |
-
print(f"📊 HTML content preview: {viz_html[:200]}...")
|
| 579 |
-
|
| 580 |
-
# If we have base64 encoded video content, save it as a temporary file
|
| 581 |
-
local_video_path = None
|
| 582 |
-
if track_video_content:
|
| 583 |
-
try:
|
| 584 |
-
# Create a temporary file for the video
|
| 585 |
-
temp_video_dir = "temp_frontend_videos"
|
| 586 |
-
os.makedirs(temp_video_dir, exist_ok=True)
|
| 587 |
-
|
| 588 |
-
# Generate unique filename to avoid conflicts
|
| 589 |
-
timestamp = str(int(time.time() * 1000))
|
| 590 |
-
local_video_path = os.path.join(temp_video_dir, f"{timestamp}_{track_video_filename}")
|
| 591 |
-
|
| 592 |
-
# Decode base64 and save as video file
|
| 593 |
-
video_bytes = base64.b64decode(track_video_content)
|
| 594 |
-
with open(local_video_path, 'wb') as f:
|
| 595 |
-
f.write(video_bytes)
|
| 596 |
-
|
| 597 |
-
print(f"✅ Successfully saved tracking video to: {local_video_path}")
|
| 598 |
-
print(f"�� Video file size: {len(video_bytes)} bytes")
|
| 599 |
-
|
| 600 |
-
except Exception as e:
|
| 601 |
-
print(f"❌ Failed to process tracking video: {e}")
|
| 602 |
-
local_video_path = None
|
| 603 |
-
else:
|
| 604 |
-
print("⚠️ No tracking video content received from backend")
|
| 605 |
-
|
| 606 |
-
# Return the iframe HTML, the video path, and the HTML file path (for download)
|
| 607 |
-
return iframe_html, local_video_path, random_path
|
| 608 |
-
else:
|
| 609 |
-
error_msg = result.get("error", "Unknown error") if isinstance(result, dict) else "Backend processing failed"
|
| 610 |
-
print(f"❌ Backend processing failed: {error_msg}")
|
| 611 |
-
# Fall through to error message
|
| 612 |
-
pass
|
| 613 |
-
except Exception as e:
|
| 614 |
-
print(f"❌ Backend API call failed: {e}")
|
| 615 |
-
print(f"🔧 Error type: {type(e)}")
|
| 616 |
-
print(f"🔧 Error details: {str(e)}")
|
| 617 |
-
|
| 618 |
-
# Check for specific gradio_client errors
|
| 619 |
-
if "AppError" in str(type(e)):
|
| 620 |
-
print("🔧 Backend Space has internal errors (AppError)")
|
| 621 |
-
print("🔧 The backend Space code has bugs or configuration issues")
|
| 622 |
-
print("🔧 Contact the Space owner to fix the backend implementation")
|
| 623 |
-
elif "Could not fetch config" in str(e):
|
| 624 |
-
print("🔧 Config fetch failed - possible Gradio version mismatch")
|
| 625 |
-
print("🔧 Frontend and backend may be using incompatible Gradio versions")
|
| 626 |
-
elif "timeout" in str(e).lower():
|
| 627 |
-
print("🔧 Backend request timed out - Space might be overloaded")
|
| 628 |
-
elif "Expecting value" in str(e):
|
| 629 |
-
print("🔧 JSON parsing error in backend - state format mismatch")
|
| 630 |
-
print("🔧 This happens when using local processing state with backend API")
|
| 631 |
-
print("🔧 Please upload video again to use backend processing")
|
| 632 |
-
else:
|
| 633 |
-
print(f"🔧 Unexpected error type: {type(e).__name__}")
|
| 634 |
-
|
| 635 |
-
print("🔄 Showing error message instead of visualization...")
|
| 636 |
-
# Fall through to error message
|
| 637 |
-
pass
|
| 638 |
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
try:
|
| 642 |
-
if isinstance(original_image_state, str):
|
| 643 |
-
parsed_state = json.loads(original_image_state)
|
| 644 |
-
if "video_path" in parsed_state:
|
| 645 |
-
video_name = os.path.basename(parsed_state["video_path"])
|
| 646 |
-
state_info = f"Video: {video_name}"
|
| 647 |
-
except:
|
| 648 |
-
state_info = "State format unknown"
|
| 649 |
|
| 650 |
-
#
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
<p style='color: #2d3436; font-weight: bold; margin: 0 0 5px 0;'>Debug Information:</p>
|
| 666 |
-
<p style='color: #666; font-size: 12px; margin: 0;'>Backend Available: {BACKEND_AVAILABLE}</p>
|
| 667 |
-
<p style='color: #666; font-size: 12px; margin: 0;'>Backend Client: {backend_client is not None}</p>
|
| 668 |
-
<p style='color: #666; font-size: 12px; margin: 0;'>Backend URL: {BACKEND_SPACE_URL}</p>
|
| 669 |
-
<p style='color: #666; font-size: 12px; margin: 0;'>State Info: {state_info}</p>
|
| 670 |
-
<p style='color: #666; font-size: 12px; margin: 0;'>Processing Mode: {"Backend" if BACKEND_AVAILABLE else "Local (Limited)"}</p>
|
| 671 |
-
</div>
|
| 672 |
-
<div style='background-color: #e3f2fd; border-radius: 5px; padding: 10px; margin-top: 10px; border-left: 4px solid #2196f3;'>
|
| 673 |
-
<p style='color: #1976d2; font-weight: bold; margin: 0 0 5px 0;'>💡 Quick Fix:</p>
|
| 674 |
-
<p style='color: #1976d2; font-size: 13px; margin: 0;'>
|
| 675 |
-
Try uploading your video again - this should properly initialize the backend state for tracking.
|
| 676 |
-
</p>
|
| 677 |
-
</div>
|
| 678 |
-
</div>
|
| 679 |
-
"""
|
| 680 |
-
return error_message, None, None
|
| 681 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 682 |
     except Exception as e:
-        print(f"Error in launch_viz: {e}")
-        return
 
 def clear_all():
     """Clear all buffers and temporary files"""

@@ -699,10 +617,6 @@ def clear_all_with_download():
             None,  # tracking_video_download
             None)  # HTML download component
 
-def update_tracker_model(model_name):
-    """Update tracker model (placeholder function)"""
-    return
-
 def get_video_settings(video_name):
     """Get video-specific settings based on video name"""
     video_settings = {

@@ -726,68 +640,14 @@ def get_video_settings(video_name):
         "cinema_1": (45, 756, 3),
     }
 
-    return video_settings.get(video_name, (50, 756, 3))
-
-def test_backend_connection():
-    """Test if backend is actually working"""
-    global BACKEND_AVAILABLE
-    if not backend_client:
-        return False
-
-    try:
-        print("Testing backend connection with a simple call...")
-        # Check if we have fns available
-        if hasattr(backend_client, 'fns') and backend_client.fns:
-            print("✅ Backend API functions are available")
-            print(f"🔧 Available function indices: {list(backend_client.fns.keys())}")
-            return True
-        else:
-            print("❌ Backend API functions not found")
-            return False
-    except Exception as e:
-        print(f"❌ Backend connection test failed: {e}")
-        return False
-
-def test_backend_api():
-    """Test specific backend API functions"""
-    if not BACKEND_AVAILABLE or not backend_client:
-        print("❌ Backend not available for testing")
-        return False
-
-    try:
-        print("🧪 Testing backend API functions...")
-
-        # Test if fns exist and show available indices
-        if hasattr(backend_client, 'fns') and backend_client.fns:
-            print(f"✅ Backend has {len(backend_client.fns)} functions available")
-            for idx in backend_client.fns.keys():
-                print(f"✅ Function {idx} is available")
-        else:
-            print("❌ No functions found in backend API")
-            return False
-
-        return True
-
-    except Exception as e:
-        print(f"❌ Backend API test failed: {e}")
-        return False
-
-# Initialize the backend connection
-print("🚀 Initializing frontend application...")
-result = initialize_backend()
-
-# Test backend connection if available
-if result and BACKEND_AVAILABLE:
-    print("✅ Backend connection successful!")
-else:
-    print("❌ Backend connection failed!")
 
 # Create the Gradio interface
 print("🎨 Creating Gradio interface...")
 
 with gr.Blocks(
     theme=gr.themes.Soft(),
-    title="🎯 [SpatialTracker V2](https://github.com/henry123-boy/SpaTrackerV2)
     css="""
     .gradio-container {
         max-width: 1200px !important;

@@ -997,7 +857,11 @@ with gr.Blocks(
     """
 ) as demo:
 
     gr.Markdown("""
     Welcome to [SpatialTracker V2](https://github.com/henry123-boy/SpaTrackerV2)! This interface allows you to track any pixels in 3D using our model.
 
     **⚡ Quick Start:** Upload video → Click "Start Tracking Now!"

@@ -1010,9 +874,8 @@ with gr.Blocks(
 
     """)
 
-    # Status indicator
-
-    gr.Markdown(f"**Status:** {status_info} | Backend: {BACKEND_SPACE_URL}")
 
     # Main content area - video upload left, 3D visualization right
     with gr.Row():

@@ -1151,7 +1014,7 @@ with gr.Blocks(
     with gr.Row():
         reset_points_btn = gr.Button("🔄 Reset Points", variant="secondary", size="sm")
 
-    # Downloads section - hidden but still functional for
     with gr.Row(visible=False):
         with gr.Column(scale=1):
             tracking_video_download = gr.File(

@@ -1266,8 +1129,8 @@
 
 # Launch the interface
 if __name__ == "__main__":
-    print("🌟 Launching SpatialTracker V2
-    print(
 
     demo.launch(
         server_name="0.0.0.0",
 import numpy as np
 import cv2
 import base64
 import time
+import tempfile
+import shutil
+import glob
+import threading
+import subprocess
+import struct
+import zlib
 from pathlib import Path
+from einops import rearrange
+from typing import List, Tuple, Union
+import torch
+import logging
+from concurrent.futures import ThreadPoolExecutor
+import atexit
+import uuid
 
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
+# Import custom modules with error handling
+try:
+    from app_3rd.sam_utils.inference import SamPredictor, get_sam_predictor, run_inference
+    from app_3rd.spatrack_utils.infer_track import get_tracker_predictor, run_tracker, get_points_on_a_grid
+except ImportError as e:
+    logger.error(f"Failed to import custom modules: {e}")
+    raise
 
+try:
+    import spaces
+except ImportError:
+    # Fallback for local development
+    def spaces(func):
+        return func
 
+# Constants
+MAX_FRAMES = 80
+COLORS = [(0, 0, 255), (0, 255, 255)]  # BGR: Red for negative, Yellow for positive
+MARKERS = [1, 5]  # Cross for negative, Star for positive
+MARKER_SIZE = 8
+
+# Thread pool for delayed deletion
+thread_pool_executor = ThreadPoolExecutor(max_workers=2)
+
+def delete_later(path: Union[str, os.PathLike], delay: int = 600):
+    """Delete file or directory after specified delay (default 10 minutes)"""
+    def _delete():
+        try:
+            if os.path.isfile(path):
+                os.remove(path)
+            elif os.path.isdir(path):
+                shutil.rmtree(path)
+        except Exception as e:
+            logger.warning(f"Failed to delete {path}: {e}")
+
+    def _wait_and_delete():
+        time.sleep(delay)
+        _delete()
+
+    thread_pool_executor.submit(_wait_and_delete)
+    atexit.register(_delete)
+
+def create_user_temp_dir():
+    """Create a unique temporary directory for each user session"""
+    session_id = str(uuid.uuid4())[:8]  # Short unique ID
+    temp_dir = os.path.join("temp_local", f"session_{session_id}")
+    os.makedirs(temp_dir, exist_ok=True)
 
+    # Schedule deletion after 10 minutes
+    delete_later(temp_dir, delay=600)
+
+    return temp_dir
+
+from huggingface_hub import hf_hub_download
+# init the model
+os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts", "spatrack_front.pth")  #, force_download=True)
+
+if os.environ.get("VGGT_DIR", None) is not None:
+    from models.vggt.vggt.models.vggt_moe import VGGT_MoE
+    from models.vggt.vggt.utils.load_fn import preprocess_image
+    vggt_model = VGGT_MoE()
+    vggt_model.load_state_dict(torch.load(os.environ.get("VGGT_DIR")), strict=False)
+    vggt_model.eval()
+    vggt_model = vggt_model.to("cuda")
+
+# Global model initialization
+print("🚀 Initializing local models...")
+tracker_model, _ = get_tracker_predictor(".", vo_points=756)
+predictor = get_sam_predictor()
+print("✅ Models loaded successfully!")
+
+gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
+
+@spaces.GPU
+def gpu_run_inference(predictor_arg, image, points, boxes):
+    """GPU-accelerated SAM inference"""
+    if predictor_arg is None:
+        print("Initializing SAM predictor inside GPU function...")
+        predictor_arg = get_sam_predictor(predictor=predictor)
 
+    # Ensure predictor is on GPU
     try:
+        if hasattr(predictor_arg, 'model'):
+            predictor_arg.model = predictor_arg.model.cuda()
+        elif hasattr(predictor_arg, 'sam'):
+            predictor_arg.sam = predictor_arg.sam.cuda()
+        elif hasattr(predictor_arg, 'to'):
+            predictor_arg = predictor_arg.to('cuda')
 
+        if hasattr(image, 'cuda'):
+            image = image.cuda()
 
     except Exception as e:
+        print(f"Warning: Could not move predictor to GPU: {e}")
+
+    return run_inference(predictor_arg, image, points, boxes)
 
+@spaces.GPU
+def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name, grid_size, vo_points, fps):
+    """GPU-accelerated tracking"""
+    import torchvision.transforms as T
+    import decord
+
+    if tracker_model_arg is None or tracker_viser_arg is None:
+        print("Initializing tracker models inside GPU function...")
+        out_dir = os.path.join(temp_dir, "results")
+        os.makedirs(out_dir, exist_ok=True)
+        tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points, tracker_model=tracker_model)
+
+    # Setup paths
+    video_path = os.path.join(temp_dir, f"{video_name}.mp4")
+    mask_path = os.path.join(temp_dir, f"{video_name}.png")
+    out_dir = os.path.join(temp_dir, "results")
+    os.makedirs(out_dir, exist_ok=True)
+
+    # Load video using decord
+    video_reader = decord.VideoReader(video_path)
+    video_tensor = torch.from_numpy(video_reader.get_batch(range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2)
+
+    # Resize to ensure minimum side is 336
+    h, w = video_tensor.shape[2:]
+    scale = max(224 / h, 224 / w)
+    if scale < 1:
+        new_h, new_w = int(h * scale), int(w * scale)
+        video_tensor = T.Resize((new_h, new_w))(video_tensor)
+    video_tensor = video_tensor[::fps].float()[:MAX_FRAMES]
+
+    # Move to GPU
+    video_tensor = video_tensor.cuda()
+    print(f"Video tensor shape: {video_tensor.shape}, device: {video_tensor.device}")
+
+    depth_tensor = None
+    intrs = None
+    extrs = None
+    data_npz_load = {}
+
+    # run vggt
+    if os.environ.get("VGGT_DIR", None) is not None:
+        # process the image tensor
+        video_tensor = preprocess_image(video_tensor)[None]
+        with torch.no_grad():
+            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+                # Predict attributes including cameras, depth maps, and point maps.
+                predictions = vggt_model(video_tensor.cuda()/255)
+                extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
+                depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
 
+        depth_tensor = depth_map.squeeze().cpu().numpy()
+        extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
+        extrs = extrinsic.squeeze().cpu().numpy()
+        intrs = intrinsic.squeeze().cpu().numpy()
+        video_tensor = video_tensor.squeeze()
+        #NOTE: 20% of the depth is not reliable
+        # threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
+        unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
+
+    # Load and process mask
+    if os.path.exists(mask_path):
+        mask = cv2.imread(mask_path)
+        mask = cv2.resize(mask, (video_tensor.shape[3], video_tensor.shape[2]))
+        mask = mask.sum(axis=-1)>0
+    else:
+        mask = np.ones_like(video_tensor[0,0].cpu().numpy())>0
+        grid_size = 10
+
+    # Get frame dimensions and create grid points
+    frame_H, frame_W = video_tensor.shape[2:]
+    grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cuda")
+
+    # Sample mask values at grid points and filter
+    if os.path.exists(mask_path):
+        grid_pts_int = grid_pts[0].long()
+        mask_values = mask[grid_pts_int.cpu()[...,1], grid_pts_int.cpu()[...,0]]
+        grid_pts = grid_pts[:, mask_values]
+
+    query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].cpu().numpy()
+    print(f"Query points shape: {query_xyt.shape}")
+
+    # Run model inference
+    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+        (
+            c2w_traj, intrs, point_map, conf_depth,
+            track3d_pred, track2d_pred, vis_pred, conf_pred, video
+        ) = tracker_model_arg.forward(video_tensor, depth=depth_tensor,
+                                      intrs=intrs, extrs=extrs,
+                                      queries=query_xyt,
+                                      fps=1, full_point=False, iters_track=4,
+                                      query_no_BA=True, fixed_cam=False, stage=1,
+                                      support_frame=len(video_tensor)-1, replace_ratio=0.2)
 
+    # Resize results to avoid large I/O
+    max_size = 224
+    h, w = video.shape[2:]
+    scale = min(max_size / h, max_size / w)
+    if scale < 1:
+        new_h, new_w = int(h * scale), int(w * scale)
+        video = T.Resize((new_h, new_w))(video)
+        video_tensor = T.Resize((new_h, new_w))(video_tensor)
+        point_map = T.Resize((new_h, new_w))(point_map)
+        track2d_pred[...,:2] = track2d_pred[...,:2] * scale
+        intrs[:,:2,:] = intrs[:,:2,:] * scale
+        conf_depth = T.Resize((new_h, new_w))(conf_depth)
 
+    # Visualize tracks
+    tracker_viser_arg.visualize(video=video[None],
+                                tracks=track2d_pred[None][...,:2],
+                                visibility=vis_pred[None], filename="test")
+
+    # Save in tapip3d format
+    data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
+    data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
+    data_npz_load["intrinsics"] = intrs.cpu().numpy()
+    data_npz_load["depths"] = point_map[:,2,...].cpu().numpy()
+    data_npz_load["video"] = (video_tensor).cpu().numpy()/255
+    data_npz_load["visibs"] = vis_pred.cpu().numpy()
+    data_npz_load["confs"] = conf_pred.cpu().numpy()
+    data_npz_load["confs_depth"] = conf_depth.cpu().numpy()
+    np.savez(os.path.join(out_dir, f'result.npz'), **data_npz_load)
 
+    return None
+
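The hunk above ends by dumping every tracking output into `results/result.npz`. As a quick orientation for anyone consuming that file outside the app, here is a minimal sketch of reading it back; the key names mirror the `data_npz_load` assignments above, while the session path and the shape comments are assumptions, not guarantees.

```python
# Sketch: load the tapip3d-style result.npz written by gpu_run_tracker.
# The session path below is hypothetical; shapes are expectations from the code above.
import numpy as np

data = np.load("temp_local/session_demo/results/result.npz")

coords = data["coords"]          # world-frame 3D tracks, roughly (T, N, 3)
extrinsics = data["extrinsics"]  # world-to-camera matrices, roughly (T, 4, 4)
intrinsics = data["intrinsics"]  # camera intrinsics, roughly (T, 3, 3)
depths = data["depths"]          # per-frame depth maps, roughly (T, H, W)
video = data["video"]            # RGB frames scaled to [0, 1], roughly (T, 3, H, W)
visibs = data["visibs"]          # per-track visibility flags
confs = data["confs"]            # per-track confidence scores

for key in data.files:
    print(key, data[key].shape)
```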
+def compress_and_write(filename, header, blob):
+    header_bytes = json.dumps(header).encode("utf-8")
+    header_len = struct.pack("<I", len(header_bytes))
+    with open(filename, "wb") as f:
+        f.write(header_len)
+        f.write(header_bytes)
+        f.write(blob)
 
+def process_point_cloud_data(npz_file, width=256, height=192, fps=4):
+    fixed_size = (width, height)
 
+    data = np.load(npz_file)
+    extrinsics = data["extrinsics"]
+    intrinsics = data["intrinsics"]
+    trajs = data["coords"]
+    T, C, H, W = data["video"].shape
+
+    fx = intrinsics[0, 0, 0]
+    fy = intrinsics[0, 1, 1]
+    fov_y = 2 * np.arctan(H / (2 * fy)) * (180 / np.pi)
+    fov_x = 2 * np.arctan(W / (2 * fx)) * (180 / np.pi)
+    original_aspect_ratio = (W / fx) / (H / fy)
+
+    rgb_video = (rearrange(data["video"], "T C H W -> T H W C") * 255).astype(np.uint8)
+    rgb_video = np.stack([cv2.resize(frame, fixed_size, interpolation=cv2.INTER_AREA)
+                          for frame in rgb_video])
+
+    depth_video = data["depths"].astype(np.float32)
+    if "confs_depth" in data.keys():
+        confs = (data["confs_depth"].astype(np.float32) > 0.5).astype(np.float32)
+        depth_video = depth_video * confs
+    depth_video = np.stack([cv2.resize(frame, fixed_size, interpolation=cv2.INTER_NEAREST)
+                            for frame in depth_video])
+
+    scale_x = fixed_size[0] / W
+    scale_y = fixed_size[1] / H
+    intrinsics = intrinsics.copy()
+    intrinsics[:, 0, :] *= scale_x
+    intrinsics[:, 1, :] *= scale_y
+
+    min_depth = float(depth_video.min()) * 0.8
+    max_depth = float(depth_video.max()) * 1.5
+
+    depth_normalized = (depth_video - min_depth) / (max_depth - min_depth)
+    depth_int = (depth_normalized * ((1 << 16) - 1)).astype(np.uint16)
+
+    depths_rgb = np.zeros((T, fixed_size[1], fixed_size[0], 3), dtype=np.uint8)
+    depths_rgb[:, :, :, 0] = (depth_int & 0xFF).astype(np.uint8)
+    depths_rgb[:, :, :, 1] = ((depth_int >> 8) & 0xFF).astype(np.uint8)
+
+    first_frame_inv = np.linalg.inv(extrinsics[0])
+    normalized_extrinsics = np.array([first_frame_inv @ ext for ext in extrinsics])
+
+    normalized_trajs = np.zeros_like(trajs)
+    for t in range(T):
+        homogeneous_trajs = np.concatenate([trajs[t], np.ones((trajs.shape[1], 1))], axis=1)
+        transformed_trajs = (first_frame_inv @ homogeneous_trajs.T).T
+        normalized_trajs[t] = transformed_trajs[:, :3]
+
+    arrays = {
+        "rgb_video": rgb_video,
+        "depths_rgb": depths_rgb,
+        "intrinsics": intrinsics,
+        "extrinsics": normalized_extrinsics,
+        "inv_extrinsics": np.linalg.inv(normalized_extrinsics),
+        "trajectories": normalized_trajs.astype(np.float32),
+        "cameraZ": 0.0
+    }
+
+    header = {}
+    blob_parts = []
+    offset = 0
+    for key, arr in arrays.items():
+        arr = np.ascontiguousarray(arr)
+        arr_bytes = arr.tobytes()
+        header[key] = {
+            "dtype": str(arr.dtype),
+            "shape": arr.shape,
+            "offset": offset,
+            "length": len(arr_bytes)
+        }
+        blob_parts.append(arr_bytes)
+        offset += len(arr_bytes)
+
+    raw_blob = b"".join(blob_parts)
+    compressed_blob = zlib.compress(raw_blob, level=9)
+
+    header["meta"] = {
+        "depthRange": [min_depth, max_depth],
+        "totalFrames": int(T),
+        "resolution": fixed_size,
+        "baseFrameRate": fps,
+        "numTrajectoryPoints": normalized_trajs.shape[1],
+        "fov": float(fov_y),
+        "fov_x": float(fov_x),
+        "original_aspect_ratio": float(original_aspect_ratio),
+        "fixed_aspect_ratio": float(fixed_size[0]/fixed_size[1])
+    }
+
+    compress_and_write('./_viz/data.bin', header, compressed_blob)
+    with open('./_viz/data.bin', "rb") as f:
+        encoded_blob = base64.b64encode(f.read()).decode("ascii")
+    os.unlink('./_viz/data.bin')
+
+    random_path = f'./_viz/_{time.time()}.html'
+    with open('./_viz/viz_template.html') as f:
+        html_template = f.read()
+    html_out = html_template.replace(
+        "<head>",
+        f"<head>\n<script>window.embeddedBase64 = `{encoded_blob}`;</script>"
+    )
+    with open(random_path, 'w') as f:
+        f.write(html_out)
+
+    return random_path
 
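For reference, the packed file produced by `compress_and_write` is a 4-byte little-endian header length, a JSON header describing each array's dtype/shape/offset/length, and a zlib-compressed blob of raw array bytes; depth is stored as a 16-bit value split across the R (low byte) and G (high byte) channels of `depths_rgb` and normalized by `meta.depthRange`. Below is a minimal reader-side sketch; the standalone `data.bin` path is an assumption, since in the app the same bytes are base64-embedded into the generated HTML instead of being kept on disk.

```python
# Sketch: decode the binary layout written by compress_and_write above.
import json, struct, zlib
import numpy as np

with open("data.bin", "rb") as f:  # hypothetical on-disk copy of the packed blob
    (header_len,) = struct.unpack("<I", f.read(4))       # 4-byte little-endian length
    header = json.loads(f.read(header_len).decode("utf-8"))
    blob = zlib.decompress(f.read())                      # remaining bytes: zlib blob

arrays = {}
for key, meta in header.items():
    if key == "meta":
        continue
    raw = blob[meta["offset"]: meta["offset"] + meta["length"]]
    arrays[key] = np.frombuffer(raw, dtype=meta["dtype"]).reshape(meta["shape"])

# Recover metric depth from the 16-bit value packed into the R (low) and G (high) channels.
min_d, max_d = header["meta"]["depthRange"]
rgb = arrays["depths_rgb"].astype(np.uint16)
depth16 = rgb[..., 0] | (rgb[..., 1] << 8)
depth = depth16.astype(np.float32) / ((1 << 16) - 1) * (max_d - min_d) + min_d
print(depth.shape, depth.min(), depth.max())
```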
 def numpy_to_base64(arr):
     """Convert numpy array to base64 string"""
...
     """Convert base64 string back to numpy array"""
     return np.frombuffer(base64.b64decode(b64_str), dtype=dtype).reshape(shape)
 
 def get_video_name(video_path):
     """Extract video name without extension"""
     return os.path.splitext(os.path.basename(video_path))[0]
...
     cap.release()
 
     if ret:
         frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
         return frame_rgb
     else:
...
                 gr.update(value=756),
                 gr.update(value=3))
 
+    # Create user-specific temporary directory
+    user_temp_dir = create_user_temp_dir()
+
+    # Get original video name and copy to temp directory
+    if isinstance(video, str):
         video_name = get_video_name(video)
+        video_path = os.path.join(user_temp_dir, f"{video_name}.mp4")
+        shutil.copy(video, video_path)
+    else:
+        video_name = get_video_name(video.name)
+        video_path = os.path.join(user_temp_dir, f"{video_name}.mp4")
+        with open(video_path, 'wb') as f:
+            f.write(video.read())
+
+    print(f"📁 Video saved to: {video_path}")
+
+    # Extract first frame
+    frame = extract_first_frame(video_path)
+    if frame is None:
         return (None, None, [],
                 gr.update(value=50),
                 gr.update(value=756),
                 gr.update(value=3))
+
+    # Resize frame to have minimum side length of 336
+    h, w = frame.shape[:2]
+    scale = 336 / min(h, w)
+    new_h, new_w = int(h * scale)//2*2, int(w * scale)//2*2
+    frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
+
+    # Store frame data with temp directory info
+    frame_data = {
+        'data': numpy_to_base64(frame),
+        'shape': frame.shape,
+        'dtype': str(frame.dtype),
+        'temp_dir': user_temp_dir,
+        'video_name': video_name,
+        'video_path': video_path
+    }
+
+    # Get video-specific settings
+    print(f"🎬 Video path: '{video}' -> Video name: '{video_name}'")
+    grid_size_val, vo_points_val, fps_val = get_video_settings(video_name)
+    print(f"🎬 Video settings for '{video_name}': grid_size={grid_size_val}, vo_points={vo_points_val}, fps={fps_val}")
+
+    return (json.dumps(frame_data), frame, [],
+            gr.update(value=grid_size_val),
+            gr.update(value=vo_points_val),
+            gr.update(value=fps_val))
+
+def save_masks(o_masks, video_name, temp_dir):
+    """Save binary masks to files in user-specific temp directory"""
+    o_files = []
+    for mask, _ in o_masks:
+        o_mask = np.uint8(mask.squeeze() * 255)
+        o_file = os.path.join(temp_dir, f"{video_name}.png")
+        cv2.imwrite(o_file, o_mask)
+        o_files.append(o_file)
+    return o_files
 
 def select_point(original_img: str, sel_pix: list, point_type: str, evt: gr.SelectData):
     """Handle point selection for SAM"""
...
         return None, []
 
     try:
+        # Convert stored image data back to numpy array
+        frame_data = json.loads(original_img)
+        original_img_array = base64_to_numpy(frame_data['data'], frame_data['shape'], frame_data['dtype'])
+        temp_dir = frame_data.get('temp_dir', 'temp_local')
+        video_name = frame_data.get('video_name', 'video')
 
+        # Create a display image for visualization
+        display_img = original_img_array.copy()
+        new_sel_pix = sel_pix.copy() if sel_pix else []
+        new_sel_pix.append((evt.index, 1 if point_type == 'positive_point' else 0))
 
+        print(f"🎯 Running SAM inference for point: {evt.index}, type: {point_type}")
+        # Run SAM inference
+        o_masks = gpu_run_inference(None, original_img_array, new_sel_pix, [])
 
+        # Draw points on display image
+        for point, label in new_sel_pix:
+            cv2.drawMarker(display_img, point, COLORS[label], markerType=MARKERS[label], markerSize=MARKER_SIZE, thickness=2)
 
+        # Draw mask overlay on display image
+        if o_masks:
+            mask = o_masks[0][0]
+            overlay = display_img.copy()
+            overlay[mask.squeeze()!=0] = [20, 60, 200]  # Light blue
+            display_img = cv2.addWeighted(overlay, 0.6, display_img, 0.4, 0)
+
+            # Save mask for tracking
+            save_masks(o_masks, video_name, temp_dir)
+            print(f"✅ Mask saved for video: {video_name}")
+
+        return display_img, new_sel_pix
 
     except Exception as e:
+        print(f"❌ Error in select_point: {e}")
         return None, []
 
 def reset_points(original_img: str, sel_pix):
+    """Reset all points and clear the mask"""
     if original_img is None:
         return None, []
 
     try:
+        # Convert stored image data back to numpy array
+        frame_data = json.loads(original_img)
+        original_img_array = base64_to_numpy(frame_data['data'], frame_data['shape'], frame_data['dtype'])
+        temp_dir = frame_data.get('temp_dir', 'temp_local')
 
+        # Create a display image (just the original image)
+        display_img = original_img_array.copy()
 
+        # Clear all points
+        new_sel_pix = []
 
+        # Clear any existing masks
+        for mask_file in glob.glob(os.path.join(temp_dir, "*.png")):
+            try:
+                os.remove(mask_file)
+            except Exception as e:
+                logger.warning(f"Failed to remove mask file {mask_file}: {e}")
 
+        print("🔄 Points and masks reset")
+        return display_img, new_sel_pix
 
     except Exception as e:
+        print(f"❌ Error in reset_points: {e}")
         return None, []
 
 def launch_viz(grid_size, vo_points, fps, original_image_state):
     """Launch visualization with user-specific temp directory"""
     if original_image_state is None:
         return None, None, None
 
     try:
+        # Get user's temp directory from stored frame data
+        frame_data = json.loads(original_image_state)
+        temp_dir = frame_data.get('temp_dir', 'temp_local')
+        video_name = frame_data.get('video_name', 'video')
 
+        print(f"🚀 Starting tracking for video: {video_name}")
+        print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}")
 
+        # Check for mask files
+        mask_files = glob.glob(os.path.join(temp_dir, "*.png"))
+        video_files = glob.glob(os.path.join(temp_dir, "*.mp4"))
+
+        if not video_files:
+            print("❌ No video file found")
+            return "❌ Error: No video file found", None, None
+
+        video_path = video_files[0]
+        mask_path = mask_files[0] if mask_files else None
+
+        # Run tracker
+        print("🎯 Running tracker...")
+        out_dir = os.path.join(temp_dir, "results")
+        os.makedirs(out_dir, exist_ok=True)
 
+        gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps)
+
+        # Process results
+        npz_path = os.path.join(out_dir, "result.npz")
+        track2d_video = os.path.join(out_dir, "test_pred_track.mp4")
+
+        if os.path.exists(npz_path):
+            print("📊 Processing 3D visualization...")
+            html_path = process_point_cloud_data(npz_path)
+
+            # Schedule deletion of generated files
+            delete_later(html_path, delay=600)
+            if os.path.exists(track2d_video):
+                delete_later(track2d_video, delay=600)
+            delete_later(npz_path, delay=600)
+
+            # Create iframe HTML
+            iframe_html = f"""
+            <div style='border: 3px solid #667eea; border-radius: 10px;
+                        background: #f8f9ff; height: 650px; width: 100%;
+                        box-shadow: 0 8px 32px rgba(102, 126, 234, 0.3);
+                        margin: 0; padding: 0; box-sizing: border-box; overflow: hidden;'>
+                <iframe id="viz_iframe" src="/gradio_api/file={html_path}"
+                        width="100%" height="650" frameborder="0"
+                        style="border: none; display: block; width: 100%; height: 650px;
+                               margin: 0; padding: 0; border-radius: 7px;">
+                </iframe>
+            </div>
+            """
+
+            print("✅ Tracking completed successfully!")
+            return iframe_html, track2d_video if os.path.exists(track2d_video) else None, html_path
+        else:
+            print("❌ Tracking failed - no results generated")
+            return "❌ Error: Tracking failed to generate results", None, None
+
     except Exception as e:
+        print(f"❌ Error in launch_viz: {e}")
+        return f"❌ Error: {str(e)}", None, None
 
 def clear_all():
     """Clear all buffers and temporary files"""
...
             None,  # tracking_video_download
             None)  # HTML download component
 
 def get_video_settings(video_name):
     """Get video-specific settings based on video name"""
     video_settings = {
...
         "cinema_1": (45, 756, 3),
     }
 
+    return video_settings.get(video_name, (50, 756, 3))
 
 # Create the Gradio interface
 print("🎨 Creating Gradio interface...")
 
 with gr.Blocks(
     theme=gr.themes.Soft(),
+    title="🎯 [SpatialTracker V2](https://github.com/henry123-boy/SpaTrackerV2)",
     css="""
     .gradio-container {
         max-width: 1200px !important;
...
     """
 ) as demo:
 
+    # Add prominent main title
+
     gr.Markdown("""
+    # ✨ SpatialTrackerV2
+
     Welcome to [SpatialTracker V2](https://github.com/henry123-boy/SpaTrackerV2)! This interface allows you to track any pixels in 3D using our model.
 
     **⚡ Quick Start:** Upload video → Click "Start Tracking Now!"
...
 
     """)
 
+    # Status indicator
+    gr.Markdown("**Status:** 🟢 Local Processing Mode")
 
     # Main content area - video upload left, 3D visualization right
     with gr.Row():
...
     with gr.Row():
         reset_points_btn = gr.Button("🔄 Reset Points", variant="secondary", size="sm")
 
+    # Downloads section - hidden but still functional for local processing
     with gr.Row(visible=False):
         with gr.Column(scale=1):
             tracking_video_download = gr.File(
...
 
 # Launch the interface
 if __name__ == "__main__":
+    print("🌟 Launching SpatialTracker V2 Local Version...")
+    print("🔗 Running in Local Processing Mode")
 
     demo.launch(
         server_name="0.0.0.0",
app_3rd/README.md ADDED
@@ -0,0 +1,12 @@
+# 🌟 SpatialTrackerV2 Integrated with SAM 🌟
+SAM receives a point prompt and generates a mask for the target object, facilitating easy interaction to obtain the object's 3D trajectories with SpaTrack2.
+
+## Installation
+```
+
+python -m pip install git+https://github.com/facebookresearch/segment-anything.git
+cd app_3rd/sam_utils
+mkdir checkpoints
+cd checkpoints
+wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
+```
app_3rd/sam_utils/hf_sam_predictor.py ADDED
@@ -0,0 +1,129 @@
+import gc
+import numpy as np
+import torch
+from typing import Optional, Tuple, List, Union
+import warnings
+import cv2
+try:
+    from transformers import SamModel, SamProcessor
+    from huggingface_hub import hf_hub_download
+    HF_AVAILABLE = True
+except ImportError:
+    HF_AVAILABLE = False
+    warnings.warn("transformers or huggingface_hub not available. HF SAM models will not work.")
+
+# Hugging Face model mapping
+HF_MODELS = {
+    'vit_b': 'facebook/sam-vit-base',
+    'vit_l': 'facebook/sam-vit-large',
+    'vit_h': 'facebook/sam-vit-huge'
+}
+
+class HFSamPredictor:
+    """
+    Hugging Face version of SamPredictor that wraps the transformers SAM models.
+    This class provides the same interface as the original SamPredictor for seamless integration.
+    """
+
+    def __init__(self, model: SamModel, processor: SamProcessor, device: Optional[str] = None):
+        """
+        Initialize the HF SAM predictor.
+
+        Args:
+            model: The SAM model from transformers
+            processor: The SAM processor from transformers
+            device: Device to run the model on ('cuda', 'cpu', etc.)
+        """
+        self.model = model
+        self.processor = processor
+        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+        self.model.to(self.device)
+        self.model.eval()
+
+        # Store the current image and its features
+        self.original_size = None
+        self.input_size = None
+        self.features = None
+        self.image = None
+
+    @classmethod
+    def from_pretrained(cls, model_name: str, device: Optional[str] = None) -> 'HFSamPredictor':
+        """
+        Load a SAM model from Hugging Face Hub.
+
+        Args:
+            model_name: Model name from HF_MODELS or direct HF model path
+            device: Device to load the model on
+
+        Returns:
+            HFSamPredictor instance
+        """
+        if not HF_AVAILABLE:
+            raise ImportError("transformers and huggingface_hub are required for HF SAM models")
+
+        # Map model type to HF model name if needed
+        if model_name in HF_MODELS:
+            model_name = HF_MODELS[model_name]
+
+        print(f"Loading SAM model from Hugging Face: {model_name}")
+
+        # Load model and processor
+        model = SamModel.from_pretrained(model_name)
+        processor = SamProcessor.from_pretrained(model_name)
+        return cls(model, processor, device)
+
+    def preprocess(self, image: np.ndarray,
+                   input_points: List[List[float]], input_labels: List[int]) -> None:
+        """
+        Set the image for prediction. This preprocesses the image and extracts features.
+
+        Args:
+            image: Input image as numpy array (H, W, C) in RGB format
+        """
+        if image.dtype != np.uint8:
+            image = (image * 255).astype(np.uint8)
+
+        self.image = image
+        self.original_size = image.shape[:2]
+
+        # Use dummy point to ensure processor returns original_sizes & reshaped_input_sizes
+        inputs = self.processor(
+            images=image,
+            input_points=input_points,
+            input_labels=input_labels,
+            return_tensors="pt"
+        )
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+        self.input_size = inputs['pixel_values'].shape[-2:]
+        self.features = inputs
+        return inputs
+
+
+def get_hf_sam_predictor(model_type: str = 'vit_h', device: Optional[str] = None,
+                         image: Optional[np.ndarray] = None) -> HFSamPredictor:
+    """
+    Get a Hugging Face SAM predictor with the same interface as the original get_sam_predictor.
+
+    Args:
+        model_type: Model type ('vit_b', 'vit_l', 'vit_h')
+        device: Device to run the model on
+        image: Optional image to set immediately
+
+    Returns:
+        HFSamPredictor instance
+    """
+    if not HF_AVAILABLE:
+        raise ImportError("transformers and huggingface_hub are required for HF SAM models")
+
+    if device is None:
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    # Load the predictor
+    predictor = HFSamPredictor.from_pretrained(model_type, device)
+
+    # Set image if provided
+    if image is not None:
+        predictor.set_image(image)
+
+    return predictor
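A hypothetical usage sketch for the wrapper above, mirroring how `_run_hf_inference` in `app_3rd/sam_utils/inference.py` drives it; the image and the single positive click are placeholders, and the nested point/label format follows the calls made in that module.

```python
# Sketch: load the HF port of SAM and run one forward pass for a single click.
import numpy as np
import torch
from app_3rd.sam_utils.hf_sam_predictor import HFSamPredictor

predictor = HFSamPredictor.from_pretrained("vit_b")  # resolves to facebook/sam-vit-base

image = np.zeros((480, 640, 3), dtype=np.uint8)      # placeholder RGB frame
inputs = predictor.preprocess(image, [[[320, 240]]], [[1]])  # one positive point

with torch.no_grad():
    outputs = predictor.model(**inputs)

masks = predictor.processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)
print(masks[0].shape)  # masks resized back to the original image resolution
```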
app_3rd/sam_utils/inference.py ADDED
@@ -0,0 +1,123 @@
+import gc
+
+import numpy as np
+import torch
+from segment_anything import SamPredictor, sam_model_registry
+
+# Try to import HF SAM support
+try:
+    from app_3rd.sam_utils.hf_sam_predictor import get_hf_sam_predictor, HFSamPredictor
+    HF_AVAILABLE = True
+except ImportError:
+    HF_AVAILABLE = False
+
+models = {
+    'vit_b': 'app_3rd/sam_utils/checkpoints/sam_vit_b_01ec64.pth',
+    'vit_l': 'app_3rd/sam_utils/checkpoints/sam_vit_l_0b3195.pth',
+    'vit_h': 'app_3rd/sam_utils/checkpoints/sam_vit_h_4b8939.pth'
+}
+
+
+def get_sam_predictor(model_type='vit_b', device=None, image=None, use_hf=True, predictor=None):
+    """
+    Get SAM predictor with option to use HuggingFace version
+
+    Args:
+        model_type: Model type ('vit_b', 'vit_l', 'vit_h')
+        device: Device to run on
+        image: Optional image to set immediately
+        use_hf: Whether to use HuggingFace SAM instead of original SAM
+    """
+    if predictor is not None:
+        return predictor
+    if use_hf:
+        if not HF_AVAILABLE:
+            raise ImportError("HuggingFace SAM not available. Install transformers and huggingface_hub.")
+        return get_hf_sam_predictor(model_type, device, image)
+
+    # Original SAM logic
+    if device is None and torch.cuda.is_available():
+        device = 'cuda'
+    elif device is None:
+        device = 'cpu'
+    # sam model
+    sam = sam_model_registry[model_type](checkpoint=models[model_type])
+    sam = sam.to(device)
+
+    predictor = SamPredictor(sam)
+    if image is not None:
+        predictor.set_image(image)
+    return predictor
+
+
+def run_inference(predictor, input_x, selected_points, multi_object: bool = False):
+    """
+    Run inference with either original SAM or HF SAM predictor
+
+    Args:
+        predictor: SamPredictor or HFSamPredictor instance
+        input_x: Input image
+        selected_points: List of (point, label) tuples
+        multi_object: Whether to handle multiple objects
+    """
+    if len(selected_points) == 0:
+        return []
+
+    # Check if using HF SAM
+    if isinstance(predictor, HFSamPredictor):
+        return _run_hf_inference(predictor, input_x, selected_points, multi_object)
+    else:
+        return _run_original_inference(predictor, input_x, selected_points, multi_object)
+
+
+def _run_original_inference(predictor: SamPredictor, input_x, selected_points, multi_object: bool = False):
+    """Run inference with original SAM"""
+    points = torch.Tensor(
+        [p for p, _ in selected_points]
+    ).to(predictor.device).unsqueeze(1)
+
+    labels = torch.Tensor(
+        [int(l) for _, l in selected_points]
+    ).to(predictor.device).unsqueeze(1)
+
+    transformed_points = predictor.transform.apply_coords_torch(
+        points, input_x.shape[:2])
+
+    masks, scores, logits = predictor.predict_torch(
+        point_coords=transformed_points[:,0][None],
+        point_labels=labels[:,0][None],
+        multimask_output=False,
+    )
+    masks = masks[0].cpu().numpy()  # N 1 H W, N is the number of points
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    return [(masks, 'final_mask')]
+
+
+def _run_hf_inference(predictor: HFSamPredictor, input_x, selected_points, multi_object: bool = False):
+    """Run inference with HF SAM"""
+    # Prepare points and labels for HF SAM
+    select_pts = [[list(p) for p, _ in selected_points]]
+    select_lbls = [[int(l) for _, l in selected_points]]
+
+    # Preprocess inputs
+    inputs = predictor.preprocess(input_x, select_pts, select_lbls)
+
+    # Run inference
+    with torch.no_grad():
+        outputs = predictor.model(**inputs)
+
+    # Post-process masks
+    masks = predictor.processor.image_processor.post_process_masks(
+        outputs.pred_masks.cpu(),
+        inputs["original_sizes"].cpu(),
+        inputs["reshaped_input_sizes"].cpu(),
+    )
+    masks = masks[0][:,:1,...].cpu().numpy()
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    return [(masks, 'final_mask')]
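For completeness, a hypothetical end-to-end call showing how the Gradio app is expected to drive this module: build a predictor, collect `(point, label)` clicks, and run one inference pass. The frame and the click coordinates below are placeholders.

```python
# Sketch: one SAM segmentation pass with a positive and a negative click.
import numpy as np
from app_3rd.sam_utils.inference import get_sam_predictor, run_inference

predictor = get_sam_predictor(model_type="vit_b", use_hf=True)   # HF backend, as the app uses
frame = np.zeros((480, 640, 3), dtype=np.uint8)                  # placeholder RGB frame

selected_points = [((320, 240), 1), ((100, 100), 0)]             # (x, y) with label 1=pos, 0=neg
results = run_inference(predictor, frame, selected_points)

masks, tag = results[0]    # binary masks for the prompted object, plus the 'final_mask' tag
print(masks.shape, tag)
```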
app_3rd/spatrack_utils/infer_track.py ADDED
@@ -0,0 +1,195 @@
+import pycolmap
+from models.SpaTrackV2.models.predictor import Predictor
+import yaml
+import easydict
+import os
+import numpy as np
+import cv2
+import torch
+import torchvision.transforms as T
+from PIL import Image
+import io
+import moviepy.editor as mp
+from models.SpaTrackV2.utils.visualizer import Visualizer
+import tqdm
+from models.SpaTrackV2.models.utils import get_points_on_a_grid
+import glob
+from rich import print
+import argparse
+import decord
+from huggingface_hub import hf_hub_download
+
+config = {
+    "ckpt_dir": "Yuxihenry/SpatialTrackerCkpts",  # HuggingFace repo ID
+    "cfg_dir": "config/magic_infer_moge.yaml",
+}
+
+def get_tracker_predictor(output_dir: str, vo_points: int = 756, tracker_model=None):
+    """
+    Initialize and return the tracker predictor and visualizer
+    Args:
+        output_dir: Directory to save visualization results
+        vo_points: Number of points for visual odometry
+    Returns:
+        Tuple of (tracker_predictor, visualizer)
+    """
+    viz = True
+    os.makedirs(output_dir, exist_ok=True)
+
+    with open(config["cfg_dir"], "r") as f:
+        cfg = yaml.load(f, Loader=yaml.FullLoader)
+    cfg = easydict.EasyDict(cfg)
+    cfg.out_dir = output_dir
+    cfg.model.track_num = vo_points
+
+    # Check if it's a local path or HuggingFace repo
+    if tracker_model is not None:
+        model = tracker_model
+        model.spatrack.track_num = vo_points
+    else:
+        if os.path.exists(config["ckpt_dir"]):
+            # Local file
+            model = Predictor.from_pretrained(config["ckpt_dir"], model_cfg=cfg["model"])
+        else:
+            # HuggingFace repo - download the model
+            print(f"Downloading model from HuggingFace: {config['ckpt_dir']}")
+            checkpoint_path = hf_hub_download(
+                repo_id=config["ckpt_dir"],
+                repo_type="model",
+                filename="SpaTrack3_offline.pth"
+            )
+            model = Predictor.from_pretrained(checkpoint_path, model_cfg=cfg["model"])
+        model.eval()
+        model.to("cuda")
+
+    viser = Visualizer(save_dir=cfg.out_dir, grayscale=True,
+                       fps=10, pad_value=0, tracks_leave_trace=5)
+
+    return model, viser
+
+def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3):
+    """
+    Run tracking on a video sequence
+    Args:
+        model: Tracker predictor instance
+        viser: Visualizer instance
+        temp_dir: Directory containing temporary files
+        video_name: Name of the video file (without extension)
+        grid_size: Size of the tracking grid
+        vo_points: Number of points for visual odometry
+        fps: Frames per second for visualization
+    """
+    # Setup paths
+    video_path = os.path.join(temp_dir, f"{video_name}.mp4")
+    mask_path = os.path.join(temp_dir, f"{video_name}.png")
+    out_dir = os.path.join(temp_dir, "results")
+    os.makedirs(out_dir, exist_ok=True)
+
+    # Load video using decord
+    video_reader = decord.VideoReader(video_path)
+    video_tensor = torch.from_numpy(video_reader.get_batch(range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2)  # Convert to tensor and permute to (N, C, H, W)
+
+    # resize make sure the shortest side is 336
+    h, w = video_tensor.shape[2:]
+    scale = max(336 / h, 336 / w)
+    if scale < 1:
+        new_h, new_w = int(h * scale), int(w * scale)
+        video_tensor = T.Resize((new_h, new_w))(video_tensor)
+    video_tensor = video_tensor[::fps].float()
+    depth_tensor = None
+    intrs = None
+    extrs = None
+    data_npz_load = {}
+
+    # Load and process mask
+    if os.path.exists(mask_path):
+        mask = cv2.imread(mask_path)
+        mask = cv2.resize(mask, (video_tensor.shape[3], video_tensor.shape[2]))
+        mask = mask.sum(axis=-1)>0
+    else:
+        mask = np.ones_like(video_tensor[0,0].numpy())>0
+
+    # Get frame dimensions and create grid points
+    frame_H, frame_W = video_tensor.shape[2:]
+    grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cpu")
+
+    # Sample mask values at grid points and filter out points where mask=0
+    if os.path.exists(mask_path):
+        grid_pts_int = grid_pts[0].long()
+        mask_values = mask[grid_pts_int[...,1], grid_pts_int[...,0]]
+        grid_pts = grid_pts[:, mask_values]
+
+    query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()
+
+    # run vggt
+    if os.environ.get("VGGT_DIR", None) is not None:
+        vggt_model = VGGT()
+        vggt_model.load_state_dict(torch.load(VGGT_DIR))
+        vggt_model.eval()
+        vggt_model = vggt_model.to("cuda")
+        # process the image tensor
+        video_tensor = preprocess_image(video_tensor)[None]
+        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+            # Predict attributes including cameras, depth maps, and point maps.
+            aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_tensor.cuda()/255)
+            pose_enc = vggt_model.camera_head(aggregated_tokens_list)[-1]
+            # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
+            extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, video_tensor.shape[-2:])
+            # Predict Depth Maps
+            depth_map, depth_conf = vggt_model.depth_head(aggregated_tokens_list, video_tensor.cuda()/255, ps_idx)
+        # clear the cache
+        del vggt_model, aggregated_tokens_list, ps_idx, pose_enc
+        torch.cuda.empty_cache()
+        depth_tensor = depth_map.squeeze().cpu().numpy()
+        extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
+        extrs[:, :3, :4] = extrinsic.squeeze().cpu().numpy()
+        intrs = intrinsic.squeeze().cpu().numpy()
+        video_tensor = video_tensor.squeeze()
+        #NOTE: 20% of the depth is not reliable
+        # threshold = depth_conf.squeeze().view(-1).quantile(0.5)
+        unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
+
+    # Run model inference
+    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+        (
+            c2w_traj, intrs, point_map, conf_depth,
+            track3d_pred, track2d_pred, vis_pred, conf_pred, video
+        ) = model.forward(video_tensor, depth=depth_tensor,
+                          intrs=intrs, extrs=extrs,
+                          queries=query_xyt,
+                          fps=1, full_point=False, iters_track=4,
+                          query_no_BA=True, fixed_cam=False, stage=1,
+                          support_frame=len(video_tensor)-1, replace_ratio=0.2)
+
+    # Resize results to avoid too large I/O Burden
+    max_size = 336
+    h, w = video.shape[2:]
+    scale = min(max_size / h, max_size / w)
+    if scale < 1:
+        new_h, new_w = int(h * scale), int(w * scale)
+        video = T.Resize((new_h, new_w))(video)
+        video_tensor = T.Resize((new_h, new_w))(video_tensor)
+        point_map = T.Resize((new_h, new_w))(point_map)
+        track2d_pred[...,:2] = track2d_pred[...,:2] * scale
+        intrs[:,:2,:] = intrs[:,:2,:] * scale
+        if depth_tensor is not None:
+            depth_tensor = T.Resize((new_h, new_w))(depth_tensor)
+        conf_depth = T.Resize((new_h, new_w))(conf_depth)
+
+    # Visualize tracks
+    viser.visualize(video=video[None],
+                    tracks=track2d_pred[None][...,:2],
+                    visibility=vis_pred[None], filename="test")
+
+    # Save in tapip3d format
+    data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
+    data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
+    data_npz_load["intrinsics"] = intrs.cpu().numpy()
+    data_npz_load["depths"] = point_map[:,2,...].cpu().numpy()
+    data_npz_load["video"] = (video_tensor).cpu().numpy()/255
+    data_npz_load["visibs"] = vis_pred.cpu().numpy()
+    data_npz_load["confs"] = conf_pred.cpu().numpy()
+    data_npz_load["confs_depth"] = conf_depth.cpu().numpy()
+    np.savez(os.path.join(out_dir, f'result.npz'), **data_npz_load)
+
+    print(f"Results saved to {out_dir}.\nTo visualize them with tapip3d, run: [bold yellow]python tapip3d_viz.py {out_dir}/result.npz[/bold yellow]")
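A hypothetical driver mirroring the app's session layout for this module: a temp directory containing `<video_name>.mp4` (and optionally `<video_name>.png` as an object mask), then one call to `run_tracker`. Paths and parameter values below are assumptions, not requirements. Note that the `VGGT_DIR` branch in the file above references names (`VGGT`, `preprocess_image`, `pose_encoding_to_extri_intri`) that are not imported in this module, so the sketch leaves that environment variable unset and lets the tracker estimate geometry itself.

```python
# Sketch: run SpaTrack2 on one video from a prepared session folder.
from app_3rd.spatrack_utils.infer_track import get_tracker_predictor, run_tracker

temp_dir = "temp_local/demo_session"   # assumed layout: temp_dir/demo.mp4 [+ demo.png mask]
model, viser = get_tracker_predictor(output_dir=f"{temp_dir}/results", vo_points=756)

run_tracker(model, viser, temp_dir, video_name="demo",
            grid_size=50, vo_points=756, fps=3)
# -> writes temp_dir/results/result.npz plus a 2D track preview video
```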
app_release.py ADDED
@@ -0,0 +1,1278 @@
import gradio as gr
import os
import json
import numpy as np
import cv2
import base64
import requests
import time
from typing import List, Tuple
from gradio_client.utils import handle_file
from pathlib import Path

# Backend Space URL - replace with your actual backend space URL
BACKEND_SPACE_URL = "Yuxihenry/SpatialTrackerV2_Backend"  # Replace with actual backend space URL
hf_token = os.getenv("HF_TOKEN")  # Replace with your actual Hugging Face token

# Debug information
print(f"🔧 Environment Debug Info:")
print(f"  - Backend URL: {BACKEND_SPACE_URL}")
print(f"  - HF Token available: {'Yes' if hf_token else 'No'}")
print(f"  - HF Token length: {len(hf_token) if hf_token else 0}")

# Flag to track if backend is available
BACKEND_AVAILABLE = False
backend_client = None
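Note: the frontend only reads the token via os.getenv("HF_TOKEN"), so a private backend Space will look unreachable unless the variable is set before this module runs. A minimal sketch of the assumed setup (the token value is a placeholder, not a real credential, and this launcher snippet is not part of the commit):

# Hypothetical launcher: set the secret before app_release.py is imported.
import os
os.environ.setdefault("HF_TOKEN", "hf_xxxxxxxxxxxxxxxx")  # placeholder value
# app_release.py then picks it up through os.getenv("HF_TOKEN") at import time.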
| 27 |
+
def check_user_permissions():
|
| 28 |
+
"""Check if user has necessary permissions"""
|
| 29 |
+
print("🔐 Checking user permissions...")
|
| 30 |
+
|
| 31 |
+
if not hf_token:
|
| 32 |
+
print("❌ No HF Token found")
|
| 33 |
+
print("🔧 To get a token:")
|
| 34 |
+
print(" 1. Go to https://huggingface.co/settings/tokens")
|
| 35 |
+
print(" 2. Create a new token with 'read' permissions")
|
| 36 |
+
print(" 3. Set it as environment variable: export HF_TOKEN='your_token'")
|
| 37 |
+
return False
|
| 38 |
+
|
| 39 |
+
# Try to access user info
|
| 40 |
+
try:
|
| 41 |
+
headers = {'Authorization': f'Bearer {hf_token}'}
|
| 42 |
+
response = requests.get('https://huggingface.co/api/whoami', headers=headers, timeout=5)
|
| 43 |
+
|
| 44 |
+
if response.status_code == 200:
|
| 45 |
+
user_info = response.json()
|
| 46 |
+
username = user_info.get('name', 'Unknown')
|
| 47 |
+
print(f"✅ Authenticated as: {username}")
|
| 48 |
+
|
| 49 |
+
# Check if user has access to the specific space
|
| 50 |
+
space_url = f"https://huggingface.co/api/spaces/{BACKEND_SPACE_URL}"
|
| 51 |
+
space_response = requests.get(space_url, headers=headers, timeout=5)
|
| 52 |
+
|
| 53 |
+
if space_response.status_code == 200:
|
| 54 |
+
print("✅ You have access to the backend Space")
|
| 55 |
+
return True
|
| 56 |
+
elif space_response.status_code == 401:
|
| 57 |
+
print("❌ You don't have access to the backend Space")
|
| 58 |
+
print("🔧 Solutions:")
|
| 59 |
+
print(" 1. Contact the Space owner to add you as collaborator")
|
| 60 |
+
print(" 2. Ask the owner to make the Space public")
|
| 61 |
+
return False
|
| 62 |
+
elif space_response.status_code == 404:
|
| 63 |
+
print("❌ Backend Space not found")
|
| 64 |
+
print("🔧 Please check if the Space URL is correct")
|
| 65 |
+
return False
|
| 66 |
+
else:
|
| 67 |
+
print(f"⚠️ Unexpected response checking Space access: {space_response.status_code}")
|
| 68 |
+
return False
|
| 69 |
+
|
| 70 |
+
else:
|
| 71 |
+
print(f"❌ Token validation failed: {response.status_code}")
|
| 72 |
+
print("🔧 Your token might be invalid or expired")
|
| 73 |
+
return False
|
| 74 |
+
|
| 75 |
+
except Exception as e:
|
| 76 |
+
print(f"❌ Error checking permissions: {e}")
|
| 77 |
+
return False
|
| 78 |
+
|
| 79 |
+
def check_backend_space_status():
|
| 80 |
+
"""Check if backend space is running via HTTP request"""
|
| 81 |
+
try:
|
| 82 |
+
backend_url = f"https://huggingface.co/spaces/{BACKEND_SPACE_URL}"
|
| 83 |
+
print(f"🔍 Checking backend space status: {backend_url}")
|
| 84 |
+
|
| 85 |
+
# Prepare headers with authentication if token is available
|
| 86 |
+
headers = {}
|
| 87 |
+
if hf_token:
|
| 88 |
+
headers['Authorization'] = f'Bearer {hf_token}'
|
| 89 |
+
print(f"🔐 Using HF Token for authentication")
|
| 90 |
+
|
| 91 |
+
# Try to access the space page
|
| 92 |
+
response = requests.get(backend_url, headers=headers, timeout=10)
|
| 93 |
+
|
| 94 |
+
if response.status_code == 200:
|
| 95 |
+
print("✅ Backend space page is accessible")
|
| 96 |
+
|
| 97 |
+
# Check if space is running (look for common indicators)
|
| 98 |
+
page_content = response.text.lower()
|
| 99 |
+
if "runtime error" in page_content:
|
| 100 |
+
print("❌ Backend space has runtime error")
|
| 101 |
+
return False
|
| 102 |
+
elif "building" in page_content:
|
| 103 |
+
print("🔄 Backend space is building...")
|
| 104 |
+
return False
|
| 105 |
+
elif "sleeping" in page_content:
|
| 106 |
+
print("😴 Backend space is sleeping")
|
| 107 |
+
return False
|
| 108 |
+
else:
|
| 109 |
+
print("✅ Backend space appears to be running")
|
| 110 |
+
return True
|
| 111 |
+
|
| 112 |
+
elif response.status_code == 401:
|
| 113 |
+
print("❌ Authentication failed (HTTP 401)")
|
| 114 |
+
print("🔧 This means:")
|
| 115 |
+
print(" - The backend Space is private")
|
| 116 |
+
print(" - Your HF Token doesn't have access to this Space")
|
| 117 |
+
print(" - You need to be added as a collaborator to the Space")
|
| 118 |
+
print(" - Or the Space owner needs to make it public")
|
| 119 |
+
return False
|
| 120 |
+
|
| 121 |
+
elif response.status_code == 404:
|
| 122 |
+
print("❌ Backend space not found (HTTP 404)")
|
| 123 |
+
print("🔧 Please check if the Space URL is correct:")
|
| 124 |
+
print(f" Current URL: {BACKEND_SPACE_URL}")
|
| 125 |
+
return False
|
| 126 |
+
|
| 127 |
+
else:
|
| 128 |
+
print(f"❌ Backend space not accessible (HTTP {response.status_code})")
|
| 129 |
+
print(f"🔧 Response: {response.text[:200]}...")
|
| 130 |
+
return False
|
| 131 |
+
|
| 132 |
+
except requests.RequestException as e:
|
| 133 |
+
print(f"❌ Failed to check backend space status: {e}")
|
| 134 |
+
return False
|
| 135 |
+
except Exception as e:
|
| 136 |
+
print(f"❌ Unexpected error checking backend: {e}")
|
| 137 |
+
return False
|
| 138 |
+
|
def initialize_backend():
    """Initialize backend connection using gradio_client"""
    global backend_client, BACKEND_AVAILABLE

    try:
        from gradio_client import Client

        # Connect to HF Space
        if hf_token:
            backend_client = Client(BACKEND_SPACE_URL, hf_token=hf_token)
        else:
            backend_client = Client(BACKEND_SPACE_URL)

        # Test the connection
        backend_client.view_api()
        BACKEND_AVAILABLE = True
        return True

    except Exception as e:
        print(f"❌ Backend connection failed: {e}")
        BACKEND_AVAILABLE = False
        return False

def numpy_to_base64(arr):
    """Convert numpy array to base64 string"""
    return base64.b64encode(arr.tobytes()).decode('utf-8')

def base64_to_numpy(b64_str, shape, dtype):
    """Convert base64 string back to numpy array"""
    return np.frombuffer(base64.b64decode(b64_str), dtype=dtype).reshape(shape)

def base64_to_image(b64_str):
    """Convert base64 string to numpy image array"""
    if not b64_str:
        return None
    try:
        # Decode base64 to bytes
        img_bytes = base64.b64decode(b64_str)
        # Convert bytes to numpy array
        nparr = np.frombuffer(img_bytes, np.uint8)
        # Decode image
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        # Convert BGR to RGB
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img
    except Exception as e:
        print(f"Error converting base64 to image: {e}")
        return None

def get_video_name(video_path):
    """Extract video name without extension"""
    return os.path.splitext(os.path.basename(video_path))[0]

def extract_first_frame(video_path):
    """Extract first frame from video file"""
    try:
        cap = cv2.VideoCapture(video_path)
        ret, frame = cap.read()
        cap.release()

        if ret:
            # Convert BGR to RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            return frame_rgb
        else:
            return None
    except Exception as e:
        print(f"Error extracting first frame: {e}")
        return None
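These helpers move raw frame data between the frontend and the backend as base64 strings, so the array's shape and dtype must be carried alongside the payload. A small round-trip sketch, assuming the helpers above are in scope and using a dummy frame in place of a real first frame:

import numpy as np

frame = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for a decoded first frame

payload = {
    'data': numpy_to_base64(frame),   # raw bytes, base64-encoded
    'shape': frame.shape,             # needed to undo .tobytes()
    'dtype': str(frame.dtype),
}

restored = base64_to_numpy(payload['data'], payload['shape'], payload['dtype'])
assert restored.shape == frame.shape and str(restored.dtype) == payload['dtype']

This is the same data/shape/dtype convention that handle_video_upload uses below when it builds the local-fallback state string.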
| 209 |
+
def handle_video_upload(video):
|
| 210 |
+
"""Handle video upload and extract first frame"""
|
| 211 |
+
if video is None:
|
| 212 |
+
return (None, None, [],
|
| 213 |
+
gr.update(value=50),
|
| 214 |
+
gr.update(value=756),
|
| 215 |
+
gr.update(value=3))
|
| 216 |
+
|
| 217 |
+
try:
|
| 218 |
+
if BACKEND_AVAILABLE and backend_client:
|
| 219 |
+
# Try to use backend API
|
| 220 |
+
try:
|
| 221 |
+
print("🔧 Calling backend API for video upload...")
|
| 222 |
+
|
| 223 |
+
# Call the unified API with upload_video function type - fix: use handle_file wrapper
|
| 224 |
+
result = backend_client.predict(
|
| 225 |
+
"upload_video", # function_type
|
| 226 |
+
handle_file(video), # video file - wrapped with handle_file
|
| 227 |
+
"", # original_image_state (not used for upload)
|
| 228 |
+
[], # selected_points (not used for upload)
|
| 229 |
+
"positive_point", # point_type (not used for upload)
|
| 230 |
+
0, # point_x (not used for upload)
|
| 231 |
+
0, # point_y (not used for upload)
|
| 232 |
+
50, # grid_size (not used for upload)
|
| 233 |
+
756, # vo_points (not used for upload)
|
| 234 |
+
3, # fps (not used for upload)
|
| 235 |
+
api_name="/unified_api"
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
print(f"✅ Backend video upload API call successful!")
|
| 239 |
+
print(f"🔧 Result type: {type(result)}")
|
| 240 |
+
print(f"🔧 Result: {result}")
|
| 241 |
+
|
| 242 |
+
# Parse the result - expect a dict with success status
|
| 243 |
+
if isinstance(result, dict) and result.get("success"):
|
| 244 |
+
# Extract data from backend response
|
| 245 |
+
original_image_state = result.get("original_image_state", "")
|
| 246 |
+
display_image = result.get("display_image", None)
|
| 247 |
+
selected_points = result.get("selected_points", [])
|
| 248 |
+
|
| 249 |
+
# Fix: Convert display_image from list back to numpy array if needed
|
| 250 |
+
if isinstance(display_image, list):
|
| 251 |
+
display_image = np.array(display_image, dtype=np.uint8)
|
| 252 |
+
print(f"🔧 Converted display_image from list to numpy array: {display_image.shape}")
|
| 253 |
+
|
| 254 |
+
# Get video settings based on video name
|
| 255 |
+
video_name = get_video_name(video)
|
| 256 |
+
print(f"🎬 Video path: '{video}' -> Video name: '{video_name}'")
|
| 257 |
+
grid_size_val, vo_points_val, fps_val = get_video_settings(video_name)
|
| 258 |
+
print(f"🎬 Video settings for '{video_name}': grid_size={grid_size_val}, vo_points={vo_points_val}, fps={fps_val}")
|
| 259 |
+
|
| 260 |
+
return (original_image_state, display_image, selected_points,
|
| 261 |
+
gr.update(value=grid_size_val),
|
| 262 |
+
gr.update(value=vo_points_val),
|
| 263 |
+
gr.update(value=fps_val))
|
| 264 |
+
else:
|
| 265 |
+
print("Backend processing failed, using local fallback")
|
| 266 |
+
# Fallback to local processing
|
| 267 |
+
pass
|
| 268 |
+
except Exception as e:
|
| 269 |
+
print(f"Backend API call failed: {e}")
|
| 270 |
+
# Fallback to local processing
|
| 271 |
+
pass
|
| 272 |
+
|
| 273 |
+
# Fallback: local processing
|
| 274 |
+
print("Using local video processing...")
|
| 275 |
+
display_image = extract_first_frame(video)
|
| 276 |
+
|
| 277 |
+
if display_image is not None:
|
| 278 |
+
# Create a state format compatible with backend
|
| 279 |
+
import tempfile
|
| 280 |
+
import shutil
|
| 281 |
+
|
| 282 |
+
# Create a temporary directory for this session
|
| 283 |
+
session_id = str(int(time.time() * 1000)) # Use timestamp as session ID
|
| 284 |
+
temp_dir = os.path.join("temp_frontend", f"session_{session_id}")
|
| 285 |
+
os.makedirs(temp_dir, exist_ok=True)
|
| 286 |
+
|
| 287 |
+
# Copy video to temp directory with standardized name
|
| 288 |
+
video_name = get_video_name(video)
|
| 289 |
+
temp_video_path = os.path.join(temp_dir, f"{video_name}.mp4")
|
| 290 |
+
shutil.copy(video, temp_video_path)
|
| 291 |
+
|
| 292 |
+
# Create state format compatible with backend
|
| 293 |
+
frame_data = {
|
| 294 |
+
'data': numpy_to_base64(display_image),
|
| 295 |
+
'shape': display_image.shape,
|
| 296 |
+
'dtype': str(display_image.dtype),
|
| 297 |
+
'temp_dir': temp_dir,
|
| 298 |
+
'video_name': video_name,
|
| 299 |
+
'video_path': temp_video_path # Keep for backward compatibility
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
original_image_state = json.dumps(frame_data)
|
| 303 |
+
else:
|
| 304 |
+
# Fallback to simple state if frame extraction fails
|
| 305 |
+
original_image_state = json.dumps({
|
| 306 |
+
"video_path": video,
|
| 307 |
+
"frame": "local_processing_failed"
|
| 308 |
+
})
|
| 309 |
+
|
| 310 |
+
# Get video settings
|
| 311 |
+
video_name = get_video_name(video)
|
| 312 |
+
print(f"🎬 Local fallback - Video path: '{video}' -> Video name: '{video_name}'")
|
| 313 |
+
grid_size_val, vo_points_val, fps_val = get_video_settings(video_name)
|
| 314 |
+
print(f"🎬 Local fallback - Video settings for '{video_name}': grid_size={grid_size_val}, vo_points={vo_points_val}, fps={fps_val}")
|
| 315 |
+
|
| 316 |
+
return (original_image_state, display_image, [],
|
| 317 |
+
gr.update(value=grid_size_val),
|
| 318 |
+
gr.update(value=vo_points_val),
|
| 319 |
+
gr.update(value=fps_val))
|
| 320 |
+
|
| 321 |
+
except Exception as e:
|
| 322 |
+
print(f"Error in handle_video_upload: {e}")
|
| 323 |
+
return (None, None, [],
|
| 324 |
+
gr.update(value=50),
|
| 325 |
+
gr.update(value=756),
|
| 326 |
+
gr.update(value=3))
|
| 327 |
+
|
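Every backend round trip in this file goes through the single /unified_api endpoint with the same ten positional arguments; only function_type and a few slots change between upload, point selection, reset, and tracking. A hedged refactoring sketch (not part of this commit) of a wrapper those call sites could share:

def call_unified_api(function_type, video=None, state="", points=None,
                     point_type="positive_point", x=0, y=0,
                     grid_size=50, vo_points=756, fps=3):
    """Thin wrapper over backend_client.predict for the /unified_api endpoint.

    The argument order mirrors the predict() calls in this file; slots that a
    given function_type does not use simply keep their defaults.
    """
    return backend_client.predict(
        function_type,   # "upload_video" | "select_point" | "reset_points" | "run_tracker"
        video,           # handle_file(path) for uploads, otherwise None
        state,           # original_image_state JSON string
        points or [],    # selected_points
        point_type,
        x, y,
        grid_size, vo_points, fps,
        api_name="/unified_api",
    )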
| 328 |
+
def select_point(original_img: str, sel_pix: list, point_type: str, evt: gr.SelectData):
|
| 329 |
+
"""Handle point selection for SAM"""
|
| 330 |
+
if original_img is None:
|
| 331 |
+
return None, []
|
| 332 |
+
|
| 333 |
+
try:
|
| 334 |
+
if BACKEND_AVAILABLE and backend_client:
|
| 335 |
+
# Try to use backend API
|
| 336 |
+
try:
|
| 337 |
+
print(f"🔧 Calling backend select point API: x={evt.index[0]}, y={evt.index[1]}, type={point_type}")
|
| 338 |
+
|
| 339 |
+
# Call the unified API with select_point function type
|
| 340 |
+
result = backend_client.predict(
|
| 341 |
+
"select_point", # function_type
|
| 342 |
+
None, # video file (not used for select_point)
|
| 343 |
+
original_img, # original_image_state
|
| 344 |
+
sel_pix, # selected_points
|
| 345 |
+
point_type, # point_type
|
| 346 |
+
evt.index[0], # point_x
|
| 347 |
+
evt.index[1], # point_y
|
| 348 |
+
50, # grid_size (not used for select_point)
|
| 349 |
+
756, # vo_points (not used for select_point)
|
| 350 |
+
3, # fps (not used for select_point)
|
| 351 |
+
api_name="/unified_api"
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
print(f"✅ Backend select point API call successful!")
|
| 355 |
+
print(f"🔧 Result type: {type(result)}")
|
| 356 |
+
print(f"🔧 Result: {result}")
|
| 357 |
+
|
| 358 |
+
# Parse the result - expect a dict with success status
|
| 359 |
+
if isinstance(result, dict) and result.get("success"):
|
| 360 |
+
display_image = result.get("display_image", None)
|
| 361 |
+
new_sel_pix = result.get("selected_points", sel_pix)
|
| 362 |
+
|
| 363 |
+
# Fix: Convert display_image from list back to numpy array if needed
|
| 364 |
+
if isinstance(display_image, list):
|
| 365 |
+
display_image = np.array(display_image, dtype=np.uint8)
|
| 366 |
+
print(f"🔧 Converted display_image from list to numpy array: {display_image.shape}")
|
| 367 |
+
|
| 368 |
+
return display_image, new_sel_pix
|
| 369 |
+
else:
|
| 370 |
+
print("Backend processing failed, using local fallback")
|
| 371 |
+
# Fallback to local processing
|
| 372 |
+
pass
|
| 373 |
+
except Exception as e:
|
| 374 |
+
print(f"Backend API call failed: {e}")
|
| 375 |
+
|
| 376 |
+
# Check for specific gradio_client errors
|
| 377 |
+
if "AppError" in str(type(e)):
|
| 378 |
+
print("🔧 Backend Space has internal errors (AppError)")
|
| 379 |
+
print("🔧 The backend Space code has bugs or configuration issues")
|
| 380 |
+
print("🔧 Contact the Space owner to fix the backend implementation")
|
| 381 |
+
elif "Could not fetch config" in str(e):
|
| 382 |
+
print("🔧 Config fetch failed - possible Gradio version mismatch")
|
| 383 |
+
print("🔧 Frontend and backend may be using incompatible Gradio versions")
|
| 384 |
+
elif "timeout" in str(e).lower():
|
| 385 |
+
print("🔧 Backend request timed out - Space might be overloaded")
|
| 386 |
+
else:
|
| 387 |
+
print(f"🔧 Unexpected error type: {type(e).__name__}")
|
| 388 |
+
|
| 389 |
+
print("🔄 Showing error message instead of visualization...")
|
| 390 |
+
# Fallback to local processing
|
| 391 |
+
pass
|
| 392 |
+
|
| 393 |
+
# Fallback: local processing with improved visualization
|
| 394 |
+
print("Using local point selection with enhanced visualization...")
|
| 395 |
+
|
| 396 |
+
# Parse original image state
|
| 397 |
+
try:
|
| 398 |
+
state_data = json.loads(original_img)
|
| 399 |
+
video_path = state_data.get("video_path")
|
| 400 |
+
except:
|
| 401 |
+
video_path = None
|
| 402 |
+
|
| 403 |
+
if video_path:
|
| 404 |
+
# Re-extract frame and add point with mask visualization
|
| 405 |
+
display_image = extract_first_frame(video_path)
|
| 406 |
+
if display_image is not None:
|
| 407 |
+
# Add point to the image with enhanced visualization
|
| 408 |
+
x, y = evt.index[0], evt.index[1]
|
| 409 |
+
color = (0, 255, 0) if point_type == 'positive_point' else (255, 0, 0)
|
| 410 |
+
|
| 411 |
+
# Draw a larger, more visible point
|
| 412 |
+
cv2.circle(display_image, (x, y), 8, color, -1)
|
| 413 |
+
cv2.circle(display_image, (x, y), 12, (255, 255, 255), 2)
|
| 414 |
+
|
| 415 |
+
# Add point to selected points list - fix logic to match local version
|
| 416 |
+
new_sel_pix = sel_pix.copy() if sel_pix else []
|
| 417 |
+
new_sel_pix.append([x, y, point_type])
|
| 418 |
+
|
| 419 |
+
return display_image, new_sel_pix
|
| 420 |
+
|
| 421 |
+
return None, []
|
| 422 |
+
|
| 423 |
+
except Exception as e:
|
| 424 |
+
print(f"Error in select_point: {e}")
|
| 425 |
+
return None, []
|
| 426 |
+
|
| 427 |
+
def reset_points(original_img: str, sel_pix):
|
| 428 |
+
"""Reset points and restore original image"""
|
| 429 |
+
if original_img is None:
|
| 430 |
+
return None, []
|
| 431 |
+
|
| 432 |
+
try:
|
| 433 |
+
if BACKEND_AVAILABLE and backend_client:
|
| 434 |
+
# Try to use backend API
|
| 435 |
+
try:
|
| 436 |
+
print("🔧 Calling backend reset points API...")
|
| 437 |
+
|
| 438 |
+
# Call the unified API with reset_points function type
|
| 439 |
+
result = backend_client.predict(
|
| 440 |
+
"reset_points", # function_type
|
| 441 |
+
None, # video file (not used for reset_points)
|
| 442 |
+
original_img, # original_image_state
|
| 443 |
+
sel_pix, # selected_points
|
| 444 |
+
"positive_point", # point_type (not used for reset_points)
|
| 445 |
+
0, # point_x (not used for reset_points)
|
| 446 |
+
0, # point_y (not used for reset_points)
|
| 447 |
+
50, # grid_size (not used for reset_points)
|
| 448 |
+
756, # vo_points (not used for reset_points)
|
| 449 |
+
3, # fps (not used for reset_points)
|
| 450 |
+
api_name="/unified_api"
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
print(f"✅ Backend reset points API call successful!")
|
| 454 |
+
print(f"🔧 Result: {result}")
|
| 455 |
+
|
| 456 |
+
# Parse the result
|
| 457 |
+
if isinstance(result, dict) and result.get("success"):
|
| 458 |
+
display_image = result.get("display_image", None)
|
| 459 |
+
new_sel_pix = result.get("selected_points", [])
|
| 460 |
+
|
| 461 |
+
# Fix: Convert display_image from list back to numpy array if needed
|
| 462 |
+
if isinstance(display_image, list):
|
| 463 |
+
display_image = np.array(display_image, dtype=np.uint8)
|
| 464 |
+
print(f"🔧 Converted display_image from list to numpy array: {display_image.shape}")
|
| 465 |
+
|
| 466 |
+
return display_image, new_sel_pix
|
| 467 |
+
else:
|
| 468 |
+
print("Backend processing failed, using local fallback")
|
| 469 |
+
# Fallback to local processing
|
| 470 |
+
pass
|
| 471 |
+
except Exception as e:
|
| 472 |
+
print(f"Backend API call failed: {e}")
|
| 473 |
+
# Fallback to local processing
|
| 474 |
+
pass
|
| 475 |
+
|
| 476 |
+
# Fallback: local processing
|
| 477 |
+
print("Using local reset points...")
|
| 478 |
+
|
| 479 |
+
# Parse original image state
|
| 480 |
+
try:
|
| 481 |
+
state_data = json.loads(original_img)
|
| 482 |
+
video_path = state_data.get("video_path")
|
| 483 |
+
except:
|
| 484 |
+
video_path = None
|
| 485 |
+
|
| 486 |
+
if video_path:
|
| 487 |
+
# Re-extract original frame
|
| 488 |
+
display_image = extract_first_frame(video_path)
|
| 489 |
+
return display_image, []
|
| 490 |
+
|
| 491 |
+
return None, []
|
| 492 |
+
|
| 493 |
+
except Exception as e:
|
| 494 |
+
print(f"Error in reset_points: {e}")
|
| 495 |
+
return None, []
|
| 496 |
+
|
| 497 |
+
gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
|
| 498 |
+
|
| 499 |
+
def launch_viz(grid_size, vo_points, fps, original_image_state):
|
| 500 |
+
"""Launch visualization with user-specific temp directory"""
|
| 501 |
+
if original_image_state is None:
|
| 502 |
+
return None, None, None
|
| 503 |
+
|
| 504 |
+
try:
|
| 505 |
+
if BACKEND_AVAILABLE and backend_client:
|
| 506 |
+
# Try to use backend API
|
| 507 |
+
try:
|
| 508 |
+
print(f"🔧 Calling backend API with parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}")
|
| 509 |
+
print(f"🔧 Original image state type: {type(original_image_state)}")
|
| 510 |
+
print(f"🔧 Original image state preview: {str(original_image_state)[:100]}...")
|
| 511 |
+
|
| 512 |
+
# Validate and potentially fix the original_image_state format
|
| 513 |
+
state_to_send = original_image_state
|
| 514 |
+
|
| 515 |
+
# Check if this is a local processing state that needs to be converted
|
| 516 |
+
try:
|
| 517 |
+
if isinstance(original_image_state, str):
|
| 518 |
+
parsed_state = json.loads(original_image_state)
|
| 519 |
+
if "video_path" in parsed_state and "frame" in parsed_state:
|
| 520 |
+
# This is a local processing state, we need to handle differently
|
| 521 |
+
print("🔧 Detected local processing state, cannot use backend for tracking")
|
| 522 |
+
print("🔧 Backend requires proper video upload state from backend API")
|
| 523 |
+
# Fall through to local processing
|
| 524 |
+
raise ValueError("Local state cannot be processed by backend")
|
| 525 |
+
except json.JSONDecodeError:
|
| 526 |
+
print("🔧 Invalid JSON state, cannot send to backend")
|
| 527 |
+
raise ValueError("Invalid state format")
|
| 528 |
+
|
| 529 |
+
# Call the unified API with run_tracker function type
|
| 530 |
+
result = backend_client.predict(
|
| 531 |
+
"run_tracker", # function_type
|
| 532 |
+
None, # video file (not used for run_tracker)
|
| 533 |
+
state_to_send, # original_image_state
|
| 534 |
+
[], # selected_points (not used for run_tracker)
|
| 535 |
+
"positive_point", # point_type (not used for run_tracker)
|
| 536 |
+
0, # point_x (not used for run_tracker)
|
| 537 |
+
0, # point_y (not used for run_tracker)
|
| 538 |
+
grid_size, # grid_size
|
| 539 |
+
vo_points, # vo_points
|
| 540 |
+
fps, # fps
|
| 541 |
+
api_name="/unified_api"
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
print(f"✅ Backend API call successful!")
|
| 545 |
+
print(f"🔧 Result type: {type(result)}")
|
| 546 |
+
print(f"🔧 Result: {result}")
|
| 547 |
+
|
| 548 |
+
# Parse the result
|
| 549 |
+
if isinstance(result, dict) and result.get("success"):
|
| 550 |
+
viz_html = result.get("viz_html", "")
|
| 551 |
+
track_video_path = result.get("track_video_path", "")
|
| 552 |
+
track_video_content = result.get("track_video_content", None)
|
| 553 |
+
track_video_filename = result.get("track_video_filename", "tracked_video.mp4")
|
| 554 |
+
|
| 555 |
+
# Save HTML to _viz directory (like local version)
|
| 556 |
+
viz_dir = './_viz'
|
| 557 |
+
os.makedirs(viz_dir, exist_ok=True)
|
| 558 |
+
random_path = f'./_viz/_{time.time()}.html'
|
| 559 |
+
|
| 560 |
+
with open(random_path, 'w', encoding='utf-8') as f:
|
| 561 |
+
f.write(viz_html)
|
| 562 |
+
|
| 563 |
+
# Create iframe HTML
|
| 564 |
+
iframe_html = f"""
|
| 565 |
+
<div style='border: 3px solid #667eea; border-radius: 10px;
|
| 566 |
+
background: #f8f9ff; height: 650px; width: 100%;
|
| 567 |
+
box-shadow: 0 8px 32px rgba(102, 126, 234, 0.3);
|
| 568 |
+
margin: 0; padding: 0; box-sizing: border-box; overflow: hidden;'>
|
| 569 |
+
<iframe id="viz_iframe" src="/gradio_api/file={random_path}"
|
| 570 |
+
width="100%" height="650" frameborder="0"
|
| 571 |
+
style="border: none; display: block; width: 100%; height: 650px;
|
| 572 |
+
margin: 0; padding: 0; border-radius: 7px;">
|
| 573 |
+
</iframe>
|
| 574 |
+
</div>
|
| 575 |
+
"""
|
| 576 |
+
|
| 577 |
+
print(f"💾 HTML saved to: {random_path}")
|
| 578 |
+
print(f"📊 HTML content preview: {viz_html[:200]}...")
|
| 579 |
+
|
| 580 |
+
# If we have base64 encoded video content, save it as a temporary file
|
| 581 |
+
local_video_path = None
|
| 582 |
+
if track_video_content:
|
| 583 |
+
try:
|
| 584 |
+
# Create a temporary file for the video
|
| 585 |
+
temp_video_dir = "temp_frontend_videos"
|
| 586 |
+
os.makedirs(temp_video_dir, exist_ok=True)
|
| 587 |
+
|
| 588 |
+
# Generate unique filename to avoid conflicts
|
| 589 |
+
timestamp = str(int(time.time() * 1000))
|
| 590 |
+
local_video_path = os.path.join(temp_video_dir, f"{timestamp}_{track_video_filename}")
|
| 591 |
+
|
| 592 |
+
# Decode base64 and save as video file
|
| 593 |
+
video_bytes = base64.b64decode(track_video_content)
|
| 594 |
+
with open(local_video_path, 'wb') as f:
|
| 595 |
+
f.write(video_bytes)
|
| 596 |
+
|
| 597 |
+
print(f"✅ Successfully saved tracking video to: {local_video_path}")
|
| 598 |
+
print(f"🔧 Video file size: {len(video_bytes)} bytes")
|
| 599 |
+
|
| 600 |
+
except Exception as e:
|
| 601 |
+
print(f"❌ Failed to process tracking video: {e}")
|
| 602 |
+
local_video_path = None
|
| 603 |
+
else:
|
| 604 |
+
print("⚠️ No tracking video content received from backend")
|
| 605 |
+
|
| 606 |
+
# 返回iframe HTML、视频路径和HTML文件路径(用于下载)
|
| 607 |
+
return iframe_html, local_video_path, random_path
|
| 608 |
+
else:
|
| 609 |
+
error_msg = result.get("error", "Unknown error") if isinstance(result, dict) else "Backend processing failed"
|
| 610 |
+
print(f"❌ Backend processing failed: {error_msg}")
|
| 611 |
+
# Fall through to error message
|
| 612 |
+
pass
|
| 613 |
+
except Exception as e:
|
| 614 |
+
print(f"❌ Backend API call failed: {e}")
|
| 615 |
+
print(f"🔧 Error type: {type(e)}")
|
| 616 |
+
print(f"🔧 Error details: {str(e)}")
|
| 617 |
+
|
| 618 |
+
# Check for specific gradio_client errors
|
| 619 |
+
if "AppError" in str(type(e)):
|
| 620 |
+
print("🔧 Backend Space has internal errors (AppError)")
|
| 621 |
+
print("🔧 The backend Space code has bugs or configuration issues")
|
| 622 |
+
print("🔧 Contact the Space owner to fix the backend implementation")
|
| 623 |
+
elif "Could not fetch config" in str(e):
|
| 624 |
+
print("🔧 Config fetch failed - possible Gradio version mismatch")
|
| 625 |
+
print("🔧 Frontend and backend may be using incompatible Gradio versions")
|
| 626 |
+
elif "timeout" in str(e).lower():
|
| 627 |
+
print("🔧 Backend request timed out - Space might be overloaded")
|
| 628 |
+
elif "Expecting value" in str(e):
|
| 629 |
+
print("🔧 JSON parsing error in backend - state format mismatch")
|
| 630 |
+
print("🔧 This happens when using local processing state with backend API")
|
| 631 |
+
print("🔧 Please upload video again to use backend processing")
|
| 632 |
+
else:
|
| 633 |
+
print(f"🔧 Unexpected error type: {type(e).__name__}")
|
| 634 |
+
|
| 635 |
+
print("🔄 Showing error message instead of visualization...")
|
| 636 |
+
# Fall through to error message
|
| 637 |
+
pass
|
| 638 |
+
|
| 639 |
+
# Create an informative error message based on the state
|
| 640 |
+
state_info = ""
|
| 641 |
+
try:
|
| 642 |
+
if isinstance(original_image_state, str):
|
| 643 |
+
parsed_state = json.loads(original_image_state)
|
| 644 |
+
if "video_path" in parsed_state:
|
| 645 |
+
video_name = os.path.basename(parsed_state["video_path"])
|
| 646 |
+
state_info = f"Video: {video_name}"
|
| 647 |
+
except:
|
| 648 |
+
state_info = "State format unknown"
|
| 649 |
+
|
| 650 |
+
# Fallback: show message that backend is required
|
| 651 |
+
error_message = f"""
|
| 652 |
+
<div style='border: 3px solid #ff6b6b; border-radius: 10px; padding: 20px; background-color: #fff5f5;'>
|
| 653 |
+
<h3 style='color: #d63031; margin-bottom: 15px;'>⚠️ Backend Processing Required</h3>
|
| 654 |
+
<p style='color: #2d3436; line-height: 1.6;'>
|
| 655 |
+
The tracking and visualization features require backend processing. The current setup is using local processing which is incompatible with the backend API.
|
| 656 |
+
</p>
|
| 657 |
+
<h4 style='color: #d63031; margin: 15px 0 10px 0;'>Solutions:</h4>
|
| 658 |
+
<ul style='color: #2d3436; line-height: 1.6;'>
|
| 659 |
+
<li><strong>Upload video again:</strong> This will properly initialize the backend state</li>
|
| 660 |
+
<li><strong>Select points on the frame:</strong> Ensure you've clicked on the object to track</li>
|
| 661 |
+
<li><strong>Check backend connection:</strong> Ensure the backend Space is running</li>
|
| 662 |
+
<li><strong>Use compatible state:</strong> Avoid local processing mode</li>
|
| 663 |
+
</ul>
|
| 664 |
+
<div style='background-color: #f8f9fa; border-radius: 5px; padding: 10px; margin-top: 15px;'>
|
| 665 |
+
<p style='color: #2d3436; font-weight: bold; margin: 0 0 5px 0;'>Debug Information:</p>
|
| 666 |
+
<p style='color: #666; font-size: 12px; margin: 0;'>Backend Available: {BACKEND_AVAILABLE}</p>
|
| 667 |
+
<p style='color: #666; font-size: 12px; margin: 0;'>Backend Client: {backend_client is not None}</p>
|
| 668 |
+
<p style='color: #666; font-size: 12px; margin: 0;'>Backend URL: {BACKEND_SPACE_URL}</p>
|
| 669 |
+
<p style='color: #666; font-size: 12px; margin: 0;'>State Info: {state_info}</p>
|
| 670 |
+
<p style='color: #666; font-size: 12px; margin: 0;'>Processing Mode: {"Backend" if BACKEND_AVAILABLE else "Local (Limited)"}</p>
|
| 671 |
+
</div>
|
| 672 |
+
<div style='background-color: #e3f2fd; border-radius: 5px; padding: 10px; margin-top: 10px; border-left: 4px solid #2196f3;'>
|
| 673 |
+
<p style='color: #1976d2; font-weight: bold; margin: 0 0 5px 0;'>💡 Quick Fix:</p>
|
| 674 |
+
<p style='color: #1976d2; font-size: 13px; margin: 0;'>
|
| 675 |
+
Try uploading your video again - this should properly initialize the backend state for tracking.
|
| 676 |
+
</p>
|
| 677 |
+
</div>
|
| 678 |
+
</div>
|
| 679 |
+
"""
|
| 680 |
+
return error_message, None, None
|
| 681 |
+
|
| 682 |
+
except Exception as e:
|
| 683 |
+
print(f"Error in launch_viz: {e}")
|
| 684 |
+
return None, None, None
|

def clear_all():
    """Clear all buffers and temporary files"""
    return (None, None, [],
            gr.update(value=50),
            gr.update(value=756),
            gr.update(value=3))

def clear_all_with_download():
    """Clear all buffers including both download components"""
    return (None, None, [],
            gr.update(value=50),
            gr.update(value=756),
            gr.update(value=3),
            None, # tracking_video_download
            None) # HTML download component

def update_tracker_model(model_name):
    """Update tracker model (placeholder function)"""
    return

def get_video_settings(video_name):
    """Get video-specific settings based on video name"""
    video_settings = {
        "kiss": (45, 700, 10),
        "backpack": (40, 600, 2),
        "kitchen": (60, 800, 3),
        "pillow": (35, 500, 2),
        "handwave": (35, 500, 8),
        "hockey": (45, 700, 2),
        "drifting": (35, 1000, 6),
        "basketball": (45, 1500, 5),
        "ken_block_0": (45, 700, 2),
        "ego_kc1": (45, 500, 4),
        "vertical_place": (45, 500, 3),
        "ego_teaser": (45, 1200, 10),
        "robot_unitree": (45, 500, 4),
        "robot_3": (35, 400, 5),
        "teleop2": (45, 256, 7),
        "pusht": (45, 256, 10),
        "cinema_0": (45, 356, 5),
        "cinema_1": (45, 756, 3),
    }

    return video_settings.get(video_name, (50, 756, 3))
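get_video_settings keys the per-example defaults off the bare file name (as produced by get_video_name), falling back to the generic slider defaults for anything unknown. A short usage sketch:

# Known example clip: returns its tuned (grid_size, vo_points, fps) triple.
assert get_video_settings("basketball") == (45, 1500, 5)

# Unrecognized upload: falls back to the generic defaults shown in the sliders.
assert get_video_settings("my_new_clip") == (50, 756, 3)

# Typical call site, mirroring handle_video_upload:
# grid_size_val, vo_points_val, fps_val = get_video_settings(get_video_name(video_path))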
def test_backend_connection():
    """Test if backend is actually working"""
    global BACKEND_AVAILABLE
    if not backend_client:
        return False

    try:
        print("Testing backend connection with a simple call...")
        # Check if we have fns available
        if hasattr(backend_client, 'fns') and backend_client.fns:
            print("✅ Backend API functions are available")
            print(f"🔧 Available function indices: {list(backend_client.fns.keys())}")
            return True
        else:
            print("❌ Backend API functions not found")
            return False
    except Exception as e:
        print(f"❌ Backend connection test failed: {e}")
        return False

def test_backend_api():
    """Test specific backend API functions"""
    if not BACKEND_AVAILABLE or not backend_client:
        print("❌ Backend not available for testing")
        return False

    try:
        print("🧪 Testing backend API functions...")

        # Test if fns exist and show available indices
        if hasattr(backend_client, 'fns') and backend_client.fns:
            print(f"✅ Backend has {len(backend_client.fns)} functions available")
            for idx in backend_client.fns.keys():
                print(f"✅ Function {idx} is available")
        else:
            print("❌ No functions found in backend API")
            return False

        return True

    except Exception as e:
        print(f"❌ Backend API test failed: {e}")
        return False

# Initialize the backend connection
print("🚀 Initializing frontend application...")
result = initialize_backend()

# Test backend connection if available
if result and BACKEND_AVAILABLE:
    print("✅ Backend connection successful!")
else:
    print("❌ Backend connection failed!")

# Create the Gradio interface
print("🎨 Creating Gradio interface...")
| 788 |
+
with gr.Blocks(
|
| 789 |
+
theme=gr.themes.Soft(),
|
| 790 |
+
title="🎯 [SpatialTracker V2](https://github.com/henry123-boy/SpaTrackerV2) - Frontend Interface",
|
| 791 |
+
css="""
|
| 792 |
+
.gradio-container {
|
| 793 |
+
max-width: 1200px !important;
|
| 794 |
+
margin: auto !important;
|
| 795 |
+
}
|
| 796 |
+
.gr-button {
|
| 797 |
+
margin: 5px;
|
| 798 |
+
}
|
| 799 |
+
.gr-form {
|
| 800 |
+
background: white;
|
| 801 |
+
border-radius: 10px;
|
| 802 |
+
padding: 20px;
|
| 803 |
+
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
| 804 |
+
}
|
| 805 |
+
/* 固定3D可视化器尺寸 */
|
| 806 |
+
#viz_container {
|
| 807 |
+
height: 650px !important;
|
| 808 |
+
min-height: 650px !important;
|
| 809 |
+
max-height: 650px !important;
|
| 810 |
+
width: 100% !important;
|
| 811 |
+
margin: 0 !important;
|
| 812 |
+
padding: 0 !important;
|
| 813 |
+
overflow: hidden !important;
|
| 814 |
+
}
|
| 815 |
+
#viz_container > div {
|
| 816 |
+
height: 650px !important;
|
| 817 |
+
min-height: 650px !important;
|
| 818 |
+
max-height: 650px !important;
|
| 819 |
+
width: 100% !important;
|
| 820 |
+
margin: 0 !important;
|
| 821 |
+
padding: 0 !important;
|
| 822 |
+
box-sizing: border-box !important;
|
| 823 |
+
}
|
| 824 |
+
#viz_container iframe {
|
| 825 |
+
height: 650px !important;
|
| 826 |
+
min-height: 650px !important;
|
| 827 |
+
max-height: 650px !important;
|
| 828 |
+
width: 100% !important;
|
| 829 |
+
border: none !important;
|
| 830 |
+
display: block !important;
|
| 831 |
+
margin: 0 !important;
|
| 832 |
+
padding: 0 !important;
|
| 833 |
+
box-sizing: border-box !important;
|
| 834 |
+
}
|
| 835 |
+
/* 固定视频上传组件高度 */
|
| 836 |
+
.gr-video {
|
| 837 |
+
height: 300px !important;
|
| 838 |
+
min-height: 300px !important;
|
| 839 |
+
max-height: 300px !important;
|
| 840 |
+
}
|
| 841 |
+
.gr-video video {
|
| 842 |
+
height: 260px !important;
|
| 843 |
+
max-height: 260px !important;
|
| 844 |
+
object-fit: contain !important;
|
| 845 |
+
background: #f8f9fa;
|
| 846 |
+
}
|
| 847 |
+
.gr-video .gr-video-player {
|
| 848 |
+
height: 260px !important;
|
| 849 |
+
max-height: 260px !important;
|
| 850 |
+
}
|
| 851 |
+
/* 水平滚动的示例视频样式 */
|
| 852 |
+
.example-videos .gr-examples {
|
| 853 |
+
overflow: visible !important;
|
| 854 |
+
}
|
| 855 |
+
.example-videos .gr-examples .gr-table-wrapper {
|
| 856 |
+
overflow-x: auto !important;
|
| 857 |
+
overflow-y: hidden !important;
|
| 858 |
+
scrollbar-width: thin;
|
| 859 |
+
scrollbar-color: #667eea #f1f1f1;
|
| 860 |
+
}
|
| 861 |
+
.example-videos .gr-examples .gr-table-wrapper::-webkit-scrollbar {
|
| 862 |
+
height: 8px;
|
| 863 |
+
}
|
| 864 |
+
.example-videos .gr-examples .gr-table-wrapper::-webkit-scrollbar-track {
|
| 865 |
+
background: #f1f1f1;
|
| 866 |
+
border-radius: 4px;
|
| 867 |
+
}
|
| 868 |
+
.example-videos .gr-examples .gr-table-wrapper::-webkit-scrollbar-thumb {
|
| 869 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 870 |
+
border-radius: 4px;
|
| 871 |
+
}
|
| 872 |
+
.example-videos .gr-examples .gr-table-wrapper::-webkit-scrollbar-thumb:hover {
|
| 873 |
+
background: linear-gradient(135deg, #5a6fd8 0%, #6a4190 100%);
|
| 874 |
+
}
|
| 875 |
+
.example-videos .gr-examples .gr-table {
|
| 876 |
+
display: flex !important;
|
| 877 |
+
flex-wrap: nowrap !important;
|
| 878 |
+
min-width: max-content !important;
|
| 879 |
+
gap: 10px !important;
|
| 880 |
+
}
|
| 881 |
+
.example-videos .gr-examples .gr-table tbody {
|
| 882 |
+
display: flex !important;
|
| 883 |
+
flex-direction: row !important;
|
| 884 |
+
flex-wrap: nowrap !important;
|
| 885 |
+
gap: 10px !important;
|
| 886 |
+
}
|
| 887 |
+
.example-videos .gr-examples .gr-table tbody tr {
|
| 888 |
+
display: flex !important;
|
| 889 |
+
flex-direction: column !important;
|
| 890 |
+
min-width: 120px !important;
|
| 891 |
+
max-width: 120px !important;
|
| 892 |
+
margin: 0 !important;
|
| 893 |
+
background: white;
|
| 894 |
+
border-radius: 8px;
|
| 895 |
+
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
|
| 896 |
+
transition: all 0.3s ease;
|
| 897 |
+
cursor: pointer;
|
| 898 |
+
}
|
| 899 |
+
.example-videos .gr-examples .gr-table tbody tr:hover {
|
| 900 |
+
transform: translateY(-2px);
|
| 901 |
+
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.2);
|
| 902 |
+
}
|
| 903 |
+
.example-videos .gr-examples .gr-table tbody tr td {
|
| 904 |
+
text-align: center !important;
|
| 905 |
+
padding: 8px !important;
|
| 906 |
+
border: none !important;
|
| 907 |
+
}
|
| 908 |
+
.example-videos .gr-examples .gr-table tbody tr td video {
|
| 909 |
+
border-radius: 6px !important;
|
| 910 |
+
width: 100% !important;
|
| 911 |
+
height: auto !important;
|
| 912 |
+
}
|
| 913 |
+
.example-videos .gr-examples .gr-table tbody tr td:last-child {
|
| 914 |
+
font-size: 12px !important;
|
| 915 |
+
font-weight: 500 !important;
|
| 916 |
+
color: #333 !important;
|
| 917 |
+
padding-top: 4px !important;
|
| 918 |
+
}
|
| 919 |
+
|
| 920 |
+
/* 新的水平滚动示例视频样式 */
|
| 921 |
+
.horizontal-examples .gr-examples {
|
| 922 |
+
overflow: visible !important;
|
| 923 |
+
}
|
| 924 |
+
.horizontal-examples .gr-examples .gr-table-wrapper {
|
| 925 |
+
overflow-x: auto !important;
|
| 926 |
+
overflow-y: hidden !important;
|
| 927 |
+
scrollbar-width: thin;
|
| 928 |
+
scrollbar-color: #667eea #f1f1f1;
|
| 929 |
+
padding: 10px 0;
|
| 930 |
+
}
|
| 931 |
+
.horizontal-examples .gr-examples .gr-table-wrapper::-webkit-scrollbar {
|
| 932 |
+
height: 8px;
|
| 933 |
+
}
|
| 934 |
+
.horizontal-examples .gr-examples .gr-table-wrapper::-webkit-scrollbar-track {
|
| 935 |
+
background: #f1f1f1;
|
| 936 |
+
border-radius: 4px;
|
| 937 |
+
}
|
| 938 |
+
.horizontal-examples .gr-examples .gr-table-wrapper::-webkit-scrollbar-thumb {
|
| 939 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 940 |
+
border-radius: 4px;
|
| 941 |
+
}
|
| 942 |
+
.horizontal-examples .gr-examples .gr-table-wrapper::-webkit-scrollbar-thumb:hover {
|
| 943 |
+
background: linear-gradient(135deg, #5a6fd8 0%, #6a4190 100%);
|
| 944 |
+
}
|
| 945 |
+
.horizontal-examples .gr-examples .gr-table {
|
| 946 |
+
display: flex !important;
|
| 947 |
+
flex-wrap: nowrap !important;
|
| 948 |
+
min-width: max-content !important;
|
| 949 |
+
gap: 15px !important;
|
| 950 |
+
padding-bottom: 10px;
|
| 951 |
+
}
|
| 952 |
+
.horizontal-examples .gr-examples .gr-table tbody {
|
| 953 |
+
display: flex !important;
|
| 954 |
+
flex-direction: row !important;
|
| 955 |
+
flex-wrap: nowrap !important;
|
| 956 |
+
gap: 15px !important;
|
| 957 |
+
}
|
| 958 |
+
.horizontal-examples .gr-examples .gr-table tbody tr {
|
| 959 |
+
display: flex !important;
|
| 960 |
+
flex-direction: column !important;
|
| 961 |
+
min-width: 160px !important;
|
| 962 |
+
max-width: 160px !important;
|
| 963 |
+
margin: 0 !important;
|
| 964 |
+
background: white;
|
| 965 |
+
border-radius: 12px;
|
| 966 |
+
box-shadow: 0 3px 12px rgba(0,0,0,0.12);
|
| 967 |
+
transition: all 0.3s ease;
|
| 968 |
+
cursor: pointer;
|
| 969 |
+
overflow: hidden;
|
| 970 |
+
}
|
| 971 |
+
.horizontal-examples .gr-examples .gr-table tbody tr:hover {
|
| 972 |
+
transform: translateY(-4px);
|
| 973 |
+
box-shadow: 0 8px 20px rgba(102, 126, 234, 0.25);
|
| 974 |
+
}
|
| 975 |
+
.horizontal-examples .gr-examples .gr-table tbody tr td {
|
| 976 |
+
text-align: center !important;
|
| 977 |
+
padding: 0 !important;
|
| 978 |
+
border: none !important;
|
| 979 |
+
}
|
| 980 |
+
.horizontal-examples .gr-examples .gr-table tbody tr td:first-child {
|
| 981 |
+
padding: 0 !important;
|
| 982 |
+
}
|
| 983 |
+
.horizontal-examples .gr-examples .gr-table tbody tr td video {
|
| 984 |
+
border-radius: 8px 8px 0 0 !important;
|
| 985 |
+
width: 100% !important;
|
| 986 |
+
height: 90px !important;
|
| 987 |
+
object-fit: cover !important;
|
| 988 |
+
}
|
| 989 |
+
.horizontal-examples .gr-examples .gr-table tbody tr td:last-child {
|
| 990 |
+
font-size: 11px !important;
|
| 991 |
+
font-weight: 600 !important;
|
| 992 |
+
color: #333 !important;
|
| 993 |
+
padding: 8px 12px !important;
|
| 994 |
+
background: linear-gradient(135deg, #f8f9ff 0%, #e6f3ff 100%);
|
| 995 |
+
border-radius: 0 0 8px 8px;
|
| 996 |
+
}
|
| 997 |
+
"""
|
| 998 |
+
) as demo:
|
| 999 |
+
|
| 1000 |
+
gr.Markdown("""
|
| 1001 |
+
Welcome to [SpatialTracker V2](https://github.com/henry123-boy/SpaTrackerV2)! This interface allows you to track any pixels in 3D using our model.
|
| 1002 |
+
|
| 1003 |
+
**⚡ Quick Start:** Upload video → Click "Start Tracking Now!"
|
| 1004 |
+
|
| 1005 |
+
**🔬 Advanced Usage with SAM:**
|
| 1006 |
+
1. Upload a video file or select from examples below
|
| 1007 |
+
2. Expand "Manual Point Selection" to click on specific objects for SAM-guided tracking
|
| 1008 |
+
3. Adjust tracking parameters for optimal performance
|
| 1009 |
+
4. Click "Start Tracking Now!" to begin 3D tracking with SAM guidance
|
| 1010 |
+
|
| 1011 |
+
""")
|
| 1012 |
+
|
| 1013 |
+
# Status indicator - more compact
|
| 1014 |
+
status_info = "🟢 Backend Connected" if BACKEND_AVAILABLE else "🟡 Standalone Mode"
|
| 1015 |
+
gr.Markdown(f"**Status:** {status_info} | Backend: {BACKEND_SPACE_URL}")
|
| 1016 |
+
|
| 1017 |
+
# Main content area - video upload left, 3D visualization right
|
| 1018 |
+
with gr.Row():
|
| 1019 |
+
with gr.Column(scale=1):
|
| 1020 |
+
# Video upload section
|
| 1021 |
+
with gr.Group():
|
| 1022 |
+
gr.Markdown("### 📂 Select Video")
|
| 1023 |
+
|
| 1024 |
+
# Define video_input here so it can be referenced in examples
|
| 1025 |
+
video_input = gr.Video(
|
| 1026 |
+
label="Upload Video or Select Example",
|
| 1027 |
+
format="mp4",
|
| 1028 |
+
height=250 # Matched height with 3D viz
|
| 1029 |
+
)
|
| 1030 |
+
|
| 1031 |
+
# Horizontal video examples with slider
|
| 1032 |
+
gr.Markdown("**Examples:** (scroll horizontally to see all videos)")
|
| 1033 |
+
|
| 1034 |
+
# Traditional examples but with horizontal scroll styling
|
| 1035 |
+
with gr.Row(elem_classes=["horizontal-examples"]):
|
| 1036 |
+
gr.Examples(
|
| 1037 |
+
examples=[
|
| 1038 |
+
["./examples/kiss.mp4"],
|
| 1039 |
+
["./examples/backpack.mp4"],
|
| 1040 |
+
["./examples/kitchen.mp4"],
|
| 1041 |
+
["./examples/pillow.mp4"],
|
| 1042 |
+
["./examples/handwave.mp4"],
|
| 1043 |
+
["./examples/hockey.mp4"],
|
| 1044 |
+
["./examples/drifting.mp4"],
|
| 1045 |
+
["./examples/basketball.mp4"],
|
| 1046 |
+
["./examples/ken_block_0.mp4"],
|
| 1047 |
+
["./examples/ego_kc1.mp4"],
|
| 1048 |
+
["./examples/vertical_place.mp4"],
|
| 1049 |
+
["./examples/ego_teaser.mp4"],
|
| 1050 |
+
["./examples/robot_unitree.mp4"],
|
| 1051 |
+
["./examples/robot_3.mp4"],
|
| 1052 |
+
["./examples/teleop2.mp4"],
|
| 1053 |
+
["./examples/pusht.mp4"],
|
| 1054 |
+
["./examples/cinema_0.mp4"],
|
| 1055 |
+
["./examples/cinema_1.mp4"],
|
| 1056 |
+
],
|
| 1057 |
+
inputs=[video_input],
|
| 1058 |
+
outputs=[video_input],
|
| 1059 |
+
fn=None,
|
| 1060 |
+
cache_examples=False,
|
| 1061 |
+
label="",
|
| 1062 |
+
examples_per_page=6 # Show 6 examples per page so they can wrap to multiple rows
|
| 1063 |
+
)
|
| 1064 |
+
|
| 1065 |
+
with gr.Column(scale=2):
|
| 1066 |
+
# 3D Visualization - wider and taller to match left side
|
| 1067 |
+
with gr.Group():
|
| 1068 |
+
gr.Markdown("### 🌐 3D Trajectory Visualization")
|
| 1069 |
+
viz_html = gr.HTML(
|
| 1070 |
+
label="3D Trajectory Visualization",
|
| 1071 |
+
value="""
|
| 1072 |
+
<div style='border: 3px solid #667eea; border-radius: 10px;
|
| 1073 |
+
background: linear-gradient(135deg, #f8f9ff 0%, #e6f3ff 100%);
|
| 1074 |
+
text-align: center; height: 650px; display: flex;
|
| 1075 |
+
flex-direction: column; justify-content: center; align-items: center;
|
| 1076 |
+
box-shadow: 0 4px 16px rgba(102, 126, 234, 0.15);
|
| 1077 |
+
margin: 0; padding: 20px; box-sizing: border-box;'>
|
| 1078 |
+
<div style='font-size: 56px; margin-bottom: 25px;'>🌐</div>
|
| 1079 |
+
<h3 style='color: #667eea; margin-bottom: 18px; font-size: 28px; font-weight: 600;'>
|
| 1080 |
+
3D Trajectory Visualization
|
| 1081 |
+
</h3>
|
| 1082 |
+
<p style='color: #666; font-size: 18px; line-height: 1.6; max-width: 550px; margin-bottom: 30px;'>
|
| 1083 |
+
Track any pixels in 3D space with camera motion
|
| 1084 |
+
</p>
|
| 1085 |
+
<div style='background: rgba(102, 126, 234, 0.1); border-radius: 30px;
|
| 1086 |
+
padding: 15px 30px; border: 1px solid rgba(102, 126, 234, 0.2);'>
|
| 1087 |
+
<span style='color: #667eea; font-weight: 600; font-size: 16px;'>
|
| 1088 |
+
⚡ Powered by SpatialTracker V2
|
| 1089 |
+
</span>
|
| 1090 |
+
</div>
|
| 1091 |
+
</div>
|
| 1092 |
+
""",
|
| 1093 |
+
elem_id="viz_container"
|
| 1094 |
+
)
|
| 1095 |
+
|
| 1096 |
+
# Start button section - below video area
|
| 1097 |
+
with gr.Row():
|
| 1098 |
+
with gr.Column(scale=3):
|
| 1099 |
+
launch_btn = gr.Button("🚀 Start Tracking Now!", variant="primary", size="lg")
|
| 1100 |
+
with gr.Column(scale=1):
|
| 1101 |
+
clear_all_btn = gr.Button("🗑️ Clear All", variant="secondary", size="sm")
|
| 1102 |
+
|
| 1103 |
+
# Tracking parameters section
|
| 1104 |
+
with gr.Row():
|
| 1105 |
+
gr.Markdown("### ⚙️ Tracking Parameters")
|
| 1106 |
+
with gr.Row():
|
| 1107 |
+
grid_size = gr.Slider(
|
| 1108 |
+
minimum=10, maximum=100, step=10, value=50,
|
| 1109 |
+
label="Grid Size", info="Tracking detail level"
|
| 1110 |
+
)
|
| 1111 |
+
vo_points = gr.Slider(
|
| 1112 |
+
minimum=100, maximum=2000, step=50, value=756,
|
| 1113 |
+
label="VO Points", info="Motion accuracy"
|
| 1114 |
+
)
|
| 1115 |
+
fps = gr.Slider(
|
| 1116 |
+
minimum=1, maximum=30, step=1, value=3,
|
| 1117 |
+
label="FPS", info="Processing speed"
|
| 1118 |
+
)
|
| 1119 |
+
|
| 1120 |
+
# Advanced Point Selection with SAM - Collapsed by default
|
| 1121 |
+
with gr.Row():
gr.Markdown("### 🎯 Advanced: Manual Point Selection with SAM")
with gr.Accordion("🔬 SAM Point Selection Controls", open=False):
gr.HTML("""
<div style='margin-bottom: 15px;'>
<ul style='color: #4a5568; font-size: 14px; line-height: 1.6; margin: 0; padding-left: 20px;'>
<li>Click on target objects in the image for SAM-guided segmentation</li>
<li>Positive points: include these areas | Negative points: exclude these areas</li>
<li>Get more accurate 3D tracking results with SAM's powerful segmentation</li>
</ul>
</div>
""")

with gr.Row():
with gr.Column():
interactive_frame = gr.Image(
label="Click to select tracking points with SAM guidance",
type="numpy",
interactive=True,
height=300
)

with gr.Row():
point_type = gr.Radio(
choices=["positive_point", "negative_point"],
value="positive_point",
label="Point Type",
info="Positive: track these areas | Negative: avoid these areas"
)

with gr.Row():
reset_points_btn = gr.Button("🔄 Reset Points", variant="secondary", size="sm")

# Downloads section - hidden but still functional for backend processing
with gr.Row(visible=False):
with gr.Column(scale=1):
tracking_video_download = gr.File(
label="📹 Download 2D Tracking Video",
interactive=False,
visible=False
)
with gr.Column(scale=1):
html_download = gr.File(
label="📄 Download 3D Visualization HTML",
interactive=False,
visible=False
)

# GitHub Star Section
gr.HTML("""
<div style='background: linear-gradient(135deg, #e8eaff 0%, #f0f2ff 100%);
border-radius: 8px; padding: 20px; margin: 15px 0;
box-shadow: 0 2px 8px rgba(102, 126, 234, 0.1);
border: 1px solid rgba(102, 126, 234, 0.15);'>
<div style='text-align: center;'>
<h3 style='color: #4a5568; margin: 0 0 10px 0; font-size: 18px; font-weight: 600;'>
⭐ Love SpatialTracker? Give us a Star! ⭐
</h3>
<p style='color: #666; margin: 0 0 15px 0; font-size: 14px; line-height: 1.5;'>
Help us grow by starring our repository on GitHub! Your support means a lot to the community. 🚀
</p>
<a href="https://github.com/henry123-boy/SpaTrackerV2" target="_blank"
style='display: inline-flex; align-items: center; gap: 8px;
background: rgba(102, 126, 234, 0.1); color: #4a5568;
padding: 10px 20px; border-radius: 25px; text-decoration: none;
font-weight: bold; font-size: 14px; border: 1px solid rgba(102, 126, 234, 0.2);
transition: all 0.3s ease;'
onmouseover="this.style.background='rgba(102, 126, 234, 0.15)'; this.style.transform='translateY(-2px)'"
onmouseout="this.style.background='rgba(102, 126, 234, 0.1)'; this.style.transform='translateY(0)'">
<span style='font-size: 16px;'>⭐</span>
Star SpatialTracker V2 on GitHub
</a>
</div>
</div>
""")

# Acknowledgments Section
gr.HTML("""
<div style='background: linear-gradient(135deg, #fff8e1 0%, #fffbf0 100%);
border-radius: 8px; padding: 20px; margin: 15px 0;
box-shadow: 0 2px 8px rgba(255, 193, 7, 0.1);
border: 1px solid rgba(255, 193, 7, 0.2);'>
<div style='text-align: center;'>
<h3 style='color: #5d4037; margin: 0 0 10px 0; font-size: 18px; font-weight: 600;'>
📚 Acknowledgments
</h3>
<p style='color: #5d4037; margin: 0 0 15px 0; font-size: 14px; line-height: 1.5;'>
Our 3D visualizer is adapted from <strong>TAPIP3D</strong>. We thank the authors for their excellent work and contribution to the computer vision community!
</p>
<a href="https://github.com/zbw001/TAPIP3D" target="_blank"
style='display: inline-flex; align-items: center; gap: 8px;
background: rgba(255, 193, 7, 0.15); color: #5d4037;
padding: 10px 20px; border-radius: 25px; text-decoration: none;
font-weight: bold; font-size: 14px; border: 1px solid rgba(255, 193, 7, 0.3);
transition: all 0.3s ease;'
onmouseover="this.style.background='rgba(255, 193, 7, 0.25)'; this.style.transform='translateY(-2px)'"
onmouseout="this.style.background='rgba(255, 193, 7, 0.15)'; this.style.transform='translateY(0)'">
📚 Visit TAPIP3D Repository
</a>
</div>
</div>
""")

# Footer
gr.HTML("""
<div style='text-align: center; margin: 20px 0 10px 0;'>
<span style='font-size: 12px; color: #888; font-style: italic;'>
Powered by SpatialTracker V2 | Built with ❤️ for the Computer Vision Community
</span>
</div>
""")

# Hidden state variables
original_image_state = gr.State(None)
selected_points = gr.State([])

# Event handlers
video_input.change(
fn=handle_video_upload,
inputs=[video_input],
outputs=[original_image_state, interactive_frame, selected_points, grid_size, vo_points, fps]
)

interactive_frame.select(
fn=select_point,
inputs=[original_image_state, selected_points, point_type],
outputs=[interactive_frame, selected_points]
)

reset_points_btn.click(
fn=reset_points,
inputs=[original_image_state, selected_points],
outputs=[interactive_frame, selected_points]
)

clear_all_btn.click(
fn=clear_all_with_download,
outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, tracking_video_download, html_download]
)

launch_btn.click(
fn=launch_viz,
inputs=[grid_size, vo_points, fps, original_image_state],
outputs=[viz_html, tracking_video_download, html_download]
)

# Launch the interface
if __name__ == "__main__":
print("🌟 Launching SpatialTracker V2 Frontend...")
print(f"🔗 Backend Status: {'Connected' if BACKEND_AVAILABLE else 'Disconnected'}")

demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=True,
debug=True,
show_error=True
)
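
Note on the point-selection wiring above: Gradio's `.select()` event hands the click position to the handler through a `gr.SelectData` object, whose `.index` field carries the (x, y) pixel coordinates of the click; this is what `select_point` consumes. A minimal, standalone sketch of that pattern (a toy example for orientation, not part of this commit):

```python
import gradio as gr
import numpy as np

def on_click(img: np.ndarray, evt: gr.SelectData):
    # evt.index holds the (x, y) pixel position clicked on the displayed image
    x, y = evt.index
    print(f"clicked at ({x}, {y})")
    return img

with gr.Blocks() as toy_demo:
    frame = gr.Image(type="numpy", interactive=True)
    frame.select(fn=on_click, inputs=[frame], outputs=[frame])

# toy_demo.launch()
```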
|
config/__init__.py
ADDED
|
File without changes
|
config/magic_infer_moge.yaml
ADDED
|
@@ -0,0 +1,48 @@
seed: 0
# config the hydra logger, only in hydra `$` can be decoded as cite
data: ./assets/room
vis_track: false
hydra:
run:
dir: .
output_subdir: null
job_logging: {}
hydra_logging: {}
mixed_precision: bf16
visdom:
viz_ip: "localhost"
port: 6666
relax_load: false
res_all: 336
# config the ckpt path
# ckpts: "/mnt/bn/xyxdata/home/codes/my_projs/SpaTrack2/checkpoints/new_base.pth"
ckpts: "Yuxihenry/SpatialTracker_Files"
batch_size: 1
input:
type: image
fps: 1
model_wind_size: 32
model:
backbone_cfg:
ckpt_dir: "checkpoints/model.pt"
chunk_size: 24 # downsample factor for patchified features
ckpt_fwd: true
ft_cfg:
mode: "fix"
paras_name: []
resolution: 336
max_len: 512
Track_cfg:
base_ckpt: "checkpoints/scaled_offline.pth"
base:
stride: 4
corr_radius: 3
window_len: 60
stablizer: True
mode: "online"
s_wind: 200
overlap: 4
track_num: 0

dist_train:
num_nodes: 1
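
For orientation, this YAML is a Hydra-style config; a minimal sketch of loading it with OmegaConf follows (the exact entry point that consumes it is not part of this diff, so treat the snippet as an assumption):

```python
from omegaconf import OmegaConf

# Load the config added above and read a few of its top-level entries.
cfg = OmegaConf.load("config/magic_infer_moge.yaml")
print(cfg.seed)    # 0
print(cfg.data)    # ./assets/room
print(cfg.ckpts)   # Yuxihenry/SpatialTracker_Files
```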
|
frontend_app_local.py
ADDED
|
@@ -0,0 +1,1036 @@
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import numpy as np
|
| 5 |
+
import cv2
|
| 6 |
+
import base64
|
| 7 |
+
import time
|
| 8 |
+
import tempfile
|
| 9 |
+
import shutil
|
| 10 |
+
import glob
|
| 11 |
+
import threading
|
| 12 |
+
import subprocess
|
| 13 |
+
import struct
|
| 14 |
+
import zlib
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from einops import rearrange
|
| 17 |
+
from typing import List, Tuple, Union
|
| 18 |
+
import torch
|
| 19 |
+
import logging
|
| 20 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 21 |
+
import atexit
|
| 22 |
+
import uuid
|
| 23 |
+
|
| 24 |
+
# Configure logging
|
| 25 |
+
logging.basicConfig(level=logging.INFO)
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
# Import custom modules with error handling
|
| 29 |
+
try:
|
| 30 |
+
from app_3rd.sam_utils.inference import SamPredictor, get_sam_predictor, run_inference
|
| 31 |
+
from app_3rd.spatrack_utils.infer_track import get_tracker_predictor, run_tracker, get_points_on_a_grid
|
| 32 |
+
except ImportError as e:
|
| 33 |
+
logger.error(f"Failed to import custom modules: {e}")
|
| 34 |
+
raise
|
| 35 |
+
|
| 36 |
+
try:
|
| 37 |
+
import spaces
|
| 38 |
+
except ImportError:
|
| 39 |
+
# Fallback for local development
|
| 40 |
+
class spaces:  # local fallback shim: exposes a no-op GPU decorator so @spaces.GPU still works
|
| 41 |
+
GPU = staticmethod(lambda func: func)
|
| 42 |
+
|
| 43 |
+
# Constants
|
| 44 |
+
MAX_FRAMES = 80
|
| 45 |
+
COLORS = [(0, 0, 255), (0, 255, 255)] # BGR: Red for negative, Yellow for positive
|
| 46 |
+
MARKERS = [1, 5] # Cross for negative, Star for positive
|
| 47 |
+
MARKER_SIZE = 8
|
| 48 |
+
|
| 49 |
+
# Thread pool for delayed deletion
|
| 50 |
+
thread_pool_executor = ThreadPoolExecutor(max_workers=2)
|
| 51 |
+
|
| 52 |
+
def delete_later(path: Union[str, os.PathLike], delay: int = 600):
|
| 53 |
+
"""Delete file or directory after specified delay (default 10 minutes)"""
|
| 54 |
+
def _delete():
|
| 55 |
+
try:
|
| 56 |
+
if os.path.isfile(path):
|
| 57 |
+
os.remove(path)
|
| 58 |
+
elif os.path.isdir(path):
|
| 59 |
+
shutil.rmtree(path)
|
| 60 |
+
except Exception as e:
|
| 61 |
+
logger.warning(f"Failed to delete {path}: {e}")
|
| 62 |
+
|
| 63 |
+
def _wait_and_delete():
|
| 64 |
+
time.sleep(delay)
|
| 65 |
+
_delete()
|
| 66 |
+
|
| 67 |
+
thread_pool_executor.submit(_wait_and_delete)
|
| 68 |
+
atexit.register(_delete)
|
| 69 |
+
|
| 70 |
+
def create_user_temp_dir():
|
| 71 |
+
"""Create a unique temporary directory for each user session"""
|
| 72 |
+
session_id = str(uuid.uuid4())[:8] # Short unique ID
|
| 73 |
+
temp_dir = os.path.join("temp_local", f"session_{session_id}")
|
| 74 |
+
os.makedirs(temp_dir, exist_ok=True)
|
| 75 |
+
|
| 76 |
+
# Schedule deletion after 10 minutes
|
| 77 |
+
delete_later(temp_dir, delay=600)
|
| 78 |
+
|
| 79 |
+
return temp_dir
|
| 80 |
+
|
| 81 |
+
# Initialize VGGT model
|
| 82 |
+
try:
|
| 83 |
+
import vggt
|
| 84 |
+
except:
|
| 85 |
+
subprocess.run(["pip", "install", "-e", "./models/vggt"], check=True)
|
| 86 |
+
|
| 87 |
+
from huggingface_hub import hf_hub_download
|
| 88 |
+
os.environ["VGGT_DIR"] = hf_hub_download("facebook/VGGT-1B", "model.pt")
|
| 89 |
+
|
| 90 |
+
if os.environ.get("VGGT_DIR", None) is not None:
|
| 91 |
+
from vggt.models.vggt import VGGT
|
| 92 |
+
from vggt.utils.load_fn import preprocess_image
|
| 93 |
+
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
|
| 94 |
+
vggt_model = VGGT()
|
| 95 |
+
vggt_model.load_state_dict(torch.load(os.environ.get("VGGT_DIR")))
|
| 96 |
+
vggt_model.eval()
|
| 97 |
+
vggt_model = vggt_model.to("cuda")
|
| 98 |
+
|
| 99 |
+
# Global model initialization
|
| 100 |
+
print("🚀 Initializing local models...")
|
| 101 |
+
tracker_model, _ = get_tracker_predictor(".", vo_points=756)
|
| 102 |
+
predictor = get_sam_predictor()
|
| 103 |
+
print("✅ Models loaded successfully!")
|
| 104 |
+
|
| 105 |
+
gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
|
| 106 |
+
|
| 107 |
+
@spaces.GPU
|
| 108 |
+
def gpu_run_inference(predictor_arg, image, points, boxes):
|
| 109 |
+
"""GPU-accelerated SAM inference"""
|
| 110 |
+
if predictor_arg is None:
|
| 111 |
+
print("Initializing SAM predictor inside GPU function...")
|
| 112 |
+
predictor_arg = get_sam_predictor(predictor=predictor)
|
| 113 |
+
|
| 114 |
+
# Ensure predictor is on GPU
|
| 115 |
+
try:
|
| 116 |
+
if hasattr(predictor_arg, 'model'):
|
| 117 |
+
predictor_arg.model = predictor_arg.model.cuda()
|
| 118 |
+
elif hasattr(predictor_arg, 'sam'):
|
| 119 |
+
predictor_arg.sam = predictor_arg.sam.cuda()
|
| 120 |
+
elif hasattr(predictor_arg, 'to'):
|
| 121 |
+
predictor_arg = predictor_arg.to('cuda')
|
| 122 |
+
|
| 123 |
+
if hasattr(image, 'cuda'):
|
| 124 |
+
image = image.cuda()
|
| 125 |
+
|
| 126 |
+
except Exception as e:
|
| 127 |
+
print(f"Warning: Could not move predictor to GPU: {e}")
|
| 128 |
+
|
| 129 |
+
return run_inference(predictor_arg, image, points, boxes)
|
| 130 |
+
|
| 131 |
+
@spaces.GPU
|
| 132 |
+
def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name, grid_size, vo_points, fps):
|
| 133 |
+
"""GPU-accelerated tracking"""
|
| 134 |
+
import torchvision.transforms as T
|
| 135 |
+
import decord
|
| 136 |
+
|
| 137 |
+
if tracker_model_arg is None or tracker_viser_arg is None:
|
| 138 |
+
print("Initializing tracker models inside GPU function...")
|
| 139 |
+
out_dir = os.path.join(temp_dir, "results")
|
| 140 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 141 |
+
tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points, tracker_model=tracker_model)
|
| 142 |
+
|
| 143 |
+
# Setup paths
|
| 144 |
+
video_path = os.path.join(temp_dir, f"{video_name}.mp4")
|
| 145 |
+
mask_path = os.path.join(temp_dir, f"{video_name}.png")
|
| 146 |
+
out_dir = os.path.join(temp_dir, "results")
|
| 147 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 148 |
+
|
| 149 |
+
# Load video using decord
|
| 150 |
+
video_reader = decord.VideoReader(video_path)
|
| 151 |
+
video_tensor = torch.from_numpy(video_reader.get_batch(range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2)
|
| 152 |
+
|
| 153 |
+
# Downscale (if needed) so the shorter side is at most 224
|
| 154 |
+
h, w = video_tensor.shape[2:]
|
| 155 |
+
scale = max(224 / h, 224 / w)
|
| 156 |
+
if scale < 1:
|
| 157 |
+
new_h, new_w = int(h * scale), int(w * scale)
|
| 158 |
+
video_tensor = T.Resize((new_h, new_w))(video_tensor)
|
| 159 |
+
video_tensor = video_tensor[::fps].float()[:MAX_FRAMES]
|
| 160 |
+
|
| 161 |
+
# Move to GPU
|
| 162 |
+
video_tensor = video_tensor.cuda()
|
| 163 |
+
print(f"Video tensor shape: {video_tensor.shape}, device: {video_tensor.device}")
|
| 164 |
+
|
| 165 |
+
depth_tensor = None
|
| 166 |
+
intrs = None
|
| 167 |
+
extrs = None
|
| 168 |
+
data_npz_load = {}
|
| 169 |
+
|
| 170 |
+
# Run VGGT for depth and camera estimation
|
| 171 |
+
if os.environ.get("VGGT_DIR", None) is not None:
|
| 172 |
+
video_tensor = preprocess_image(video_tensor)[None]
|
| 173 |
+
with torch.no_grad():
|
| 174 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 175 |
+
|
| 176 |
+
#TODO: remove this
|
| 177 |
+
single_frame=False
|
| 178 |
+
if single_frame==True:
|
| 179 |
+
video_tensor = rearrange(video_tensor, "b s c h w -> (b s) 1 c h w")
|
| 180 |
+
|
| 181 |
+
aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_tensor.cuda()/255)
|
| 182 |
+
pose_enc = vggt_model.camera_head(aggregated_tokens_list)[-1]
|
| 183 |
+
extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, video_tensor.shape[-2:])
|
| 184 |
+
depth_map, depth_conf = vggt_model.depth_head(aggregated_tokens_list, video_tensor.cuda()/255, ps_idx)
|
| 185 |
+
|
| 186 |
+
#TODO: remove this
|
| 187 |
+
if single_frame==True:
|
| 188 |
+
video_tensor = rearrange(video_tensor, "(b s) 1 c h w -> b s c h w", b=1)
|
| 189 |
+
depth_map = rearrange(depth_map, "(b s) 1 h w c -> b s h w c", b=video_tensor.shape[0])
|
| 190 |
+
depth_conf = rearrange(depth_conf, "(b s) 1 h w -> b s h w", b=video_tensor.shape[0])
|
| 191 |
+
extrinsic = rearrange(extrinsic, "(b s) 1 e f -> b s e f", b=1)
|
| 192 |
+
intrinsic = rearrange(intrinsic, "(b s) 1 e f -> b s e f", b=1)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
depth_tensor = depth_map.squeeze().cpu().numpy()
|
| 196 |
+
extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
|
| 197 |
+
extrs[:, :3, :4] = extrinsic.squeeze().cpu().numpy()
|
| 198 |
+
intrs = intrinsic.squeeze().cpu().numpy()
|
| 199 |
+
video_tensor = video_tensor.squeeze()
|
| 200 |
+
threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
|
| 201 |
+
unc_metric = depth_conf.squeeze().cpu().numpy() > threshold
|
| 202 |
+
|
| 203 |
+
# Load and process mask
|
| 204 |
+
if os.path.exists(mask_path):
|
| 205 |
+
mask = cv2.imread(mask_path)
|
| 206 |
+
mask = cv2.resize(mask, (video_tensor.shape[3], video_tensor.shape[2]))
|
| 207 |
+
mask = mask.sum(axis=-1)>0
|
| 208 |
+
else:
|
| 209 |
+
mask = np.ones_like(video_tensor[0,0].cpu().numpy())>0
|
| 210 |
+
grid_size = 10
|
| 211 |
+
|
| 212 |
+
# Get frame dimensions and create grid points
|
| 213 |
+
frame_H, frame_W = video_tensor.shape[2:]
|
| 214 |
+
grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cuda")
|
| 215 |
+
|
| 216 |
+
# Sample mask values at grid points and filter
|
| 217 |
+
if os.path.exists(mask_path):
|
| 218 |
+
grid_pts_int = grid_pts[0].long()
|
| 219 |
+
mask_values = mask[grid_pts_int.cpu()[...,1], grid_pts_int.cpu()[...,0]]
|
| 220 |
+
grid_pts = grid_pts[:, mask_values]
|
| 221 |
+
|
| 222 |
+
query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].cpu().numpy()
|
| 223 |
+
print(f"Query points shape: {query_xyt.shape}")
|
| 224 |
+
|
| 225 |
+
# Run model inference
|
| 226 |
+
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 227 |
+
(
|
| 228 |
+
c2w_traj, intrs, point_map, conf_depth,
|
| 229 |
+
track3d_pred, track2d_pred, vis_pred, conf_pred, video
|
| 230 |
+
) = tracker_model_arg.forward(video_tensor, depth=depth_tensor,
|
| 231 |
+
intrs=intrs, extrs=extrs,
|
| 232 |
+
queries=query_xyt,
|
| 233 |
+
fps=1, full_point=False, iters_track=4,
|
| 234 |
+
query_no_BA=True, fixed_cam=False, stage=1,
|
| 235 |
+
support_frame=len(video_tensor)-1, replace_ratio=0.2)
|
| 236 |
+
|
| 237 |
+
# Resize results to avoid large I/O
|
| 238 |
+
max_size = 224
|
| 239 |
+
h, w = video.shape[2:]
|
| 240 |
+
scale = min(max_size / h, max_size / w)
|
| 241 |
+
if scale < 1:
|
| 242 |
+
new_h, new_w = int(h * scale), int(w * scale)
|
| 243 |
+
video = T.Resize((new_h, new_w))(video)
|
| 244 |
+
video_tensor = T.Resize((new_h, new_w))(video_tensor)
|
| 245 |
+
point_map = T.Resize((new_h, new_w))(point_map)
|
| 246 |
+
track2d_pred[...,:2] = track2d_pred[...,:2] * scale
|
| 247 |
+
intrs[:,:2,:] = intrs[:,:2,:] * scale
|
| 248 |
+
conf_depth = T.Resize((new_h, new_w))(conf_depth)
|
| 249 |
+
|
| 250 |
+
# Visualize tracks
|
| 251 |
+
tracker_viser_arg.visualize(video=video[None],
|
| 252 |
+
tracks=track2d_pred[None][...,:2],
|
| 253 |
+
visibility=vis_pred[None],filename="test")
|
| 254 |
+
|
| 255 |
+
# Save in tapip3d format
|
| 256 |
+
data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
|
| 257 |
+
data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
|
| 258 |
+
data_npz_load["intrinsics"] = intrs.cpu().numpy()
|
| 259 |
+
data_npz_load["depths"] = point_map[:,2,...].cpu().numpy()
|
| 260 |
+
data_npz_load["video"] = (video_tensor).cpu().numpy()/255
|
| 261 |
+
data_npz_load["visibs"] = vis_pred.cpu().numpy()
|
| 262 |
+
data_npz_load["confs"] = conf_pred.cpu().numpy()
|
| 263 |
+
data_npz_load["confs_depth"] = conf_depth.cpu().numpy()
|
| 264 |
+
np.savez(os.path.join(out_dir, f'result.npz'), **data_npz_load)
|
| 265 |
+
|
| 266 |
+
return None
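
For reference, the `result.npz` written just above can be inspected offline; a small sketch (the path is illustrative, since results land in the per-session temp directory):

```python
import numpy as np

# Path is illustrative: results are written to <session_temp_dir>/results/result.npz
data = np.load("temp_local/session_xxxxxxxx/results/result.npz")
print(list(data.keys()))
# ['coords', 'extrinsics', 'intrinsics', 'depths', 'video', 'visibs', 'confs', 'confs_depth']

coords = data["coords"]          # (T, N, 3) 3D tracks in the world frame
extrinsics = data["extrinsics"]  # (T, 4, 4) world-to-camera matrices
depths = data["depths"]          # (T, H, W) per-frame depth maps
video = data["video"]            # (T, C, H, W) RGB, scaled to [0, 1]
```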
|
| 267 |
+
|
| 268 |
+
def compress_and_write(filename, header, blob):
|
| 269 |
+
header_bytes = json.dumps(header).encode("utf-8")
|
| 270 |
+
header_len = struct.pack("<I", len(header_bytes))
|
| 271 |
+
with open(filename, "wb") as f:
|
| 272 |
+
f.write(header_len)
|
| 273 |
+
f.write(header_bytes)
|
| 274 |
+
f.write(blob)
|
| 275 |
+
|
| 276 |
+
def process_point_cloud_data(npz_file, width=256, height=192, fps=4):
|
| 277 |
+
fixed_size = (width, height)
|
| 278 |
+
|
| 279 |
+
data = np.load(npz_file)
|
| 280 |
+
extrinsics = data["extrinsics"]
|
| 281 |
+
intrinsics = data["intrinsics"]
|
| 282 |
+
trajs = data["coords"]
|
| 283 |
+
T, C, H, W = data["video"].shape
|
| 284 |
+
|
| 285 |
+
fx = intrinsics[0, 0, 0]
|
| 286 |
+
fy = intrinsics[0, 1, 1]
|
| 287 |
+
fov_y = 2 * np.arctan(H / (2 * fy)) * (180 / np.pi)
|
| 288 |
+
fov_x = 2 * np.arctan(W / (2 * fx)) * (180 / np.pi)
|
| 289 |
+
original_aspect_ratio = (W / fx) / (H / fy)
|
| 290 |
+
|
| 291 |
+
rgb_video = (rearrange(data["video"], "T C H W -> T H W C") * 255).astype(np.uint8)
|
| 292 |
+
rgb_video = np.stack([cv2.resize(frame, fixed_size, interpolation=cv2.INTER_AREA)
|
| 293 |
+
for frame in rgb_video])
|
| 294 |
+
|
| 295 |
+
depth_video = data["depths"].astype(np.float32)
|
| 296 |
+
if "confs_depth" in data.keys():
|
| 297 |
+
confs = (data["confs_depth"].astype(np.float32) > 0.5).astype(np.float32)
|
| 298 |
+
depth_video = depth_video * confs
|
| 299 |
+
depth_video = np.stack([cv2.resize(frame, fixed_size, interpolation=cv2.INTER_NEAREST)
|
| 300 |
+
for frame in depth_video])
|
| 301 |
+
|
| 302 |
+
scale_x = fixed_size[0] / W
|
| 303 |
+
scale_y = fixed_size[1] / H
|
| 304 |
+
intrinsics = intrinsics.copy()
|
| 305 |
+
intrinsics[:, 0, :] *= scale_x
|
| 306 |
+
intrinsics[:, 1, :] *= scale_y
|
| 307 |
+
|
| 308 |
+
min_depth = float(depth_video.min()) * 0.8
|
| 309 |
+
max_depth = float(depth_video.max()) * 1.5
|
| 310 |
+
|
| 311 |
+
depth_normalized = (depth_video - min_depth) / (max_depth - min_depth)
|
| 312 |
+
depth_int = (depth_normalized * ((1 << 16) - 1)).astype(np.uint16)
|
| 313 |
+
|
| 314 |
+
depths_rgb = np.zeros((T, fixed_size[1], fixed_size[0], 3), dtype=np.uint8)
|
| 315 |
+
depths_rgb[:, :, :, 0] = (depth_int & 0xFF).astype(np.uint8)
|
| 316 |
+
depths_rgb[:, :, :, 1] = ((depth_int >> 8) & 0xFF).astype(np.uint8)
|
| 317 |
+
|
| 318 |
+
first_frame_inv = np.linalg.inv(extrinsics[0])
|
| 319 |
+
normalized_extrinsics = np.array([first_frame_inv @ ext for ext in extrinsics])
|
| 320 |
+
|
| 321 |
+
normalized_trajs = np.zeros_like(trajs)
|
| 322 |
+
for t in range(T):
|
| 323 |
+
homogeneous_trajs = np.concatenate([trajs[t], np.ones((trajs.shape[1], 1))], axis=1)
|
| 324 |
+
transformed_trajs = (first_frame_inv @ homogeneous_trajs.T).T
|
| 325 |
+
normalized_trajs[t] = transformed_trajs[:, :3]
|
| 326 |
+
|
| 327 |
+
arrays = {
|
| 328 |
+
"rgb_video": rgb_video,
|
| 329 |
+
"depths_rgb": depths_rgb,
|
| 330 |
+
"intrinsics": intrinsics,
|
| 331 |
+
"extrinsics": normalized_extrinsics,
|
| 332 |
+
"inv_extrinsics": np.linalg.inv(normalized_extrinsics),
|
| 333 |
+
"trajectories": normalized_trajs.astype(np.float32),
|
| 334 |
+
"cameraZ": 0.0
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
header = {}
|
| 338 |
+
blob_parts = []
|
| 339 |
+
offset = 0
|
| 340 |
+
for key, arr in arrays.items():
|
| 341 |
+
arr = np.ascontiguousarray(arr)
|
| 342 |
+
arr_bytes = arr.tobytes()
|
| 343 |
+
header[key] = {
|
| 344 |
+
"dtype": str(arr.dtype),
|
| 345 |
+
"shape": arr.shape,
|
| 346 |
+
"offset": offset,
|
| 347 |
+
"length": len(arr_bytes)
|
| 348 |
+
}
|
| 349 |
+
blob_parts.append(arr_bytes)
|
| 350 |
+
offset += len(arr_bytes)
|
| 351 |
+
|
| 352 |
+
raw_blob = b"".join(blob_parts)
|
| 353 |
+
compressed_blob = zlib.compress(raw_blob, level=9)
|
| 354 |
+
|
| 355 |
+
header["meta"] = {
|
| 356 |
+
"depthRange": [min_depth, max_depth],
|
| 357 |
+
"totalFrames": int(T),
|
| 358 |
+
"resolution": fixed_size,
|
| 359 |
+
"baseFrameRate": fps,
|
| 360 |
+
"numTrajectoryPoints": normalized_trajs.shape[1],
|
| 361 |
+
"fov": float(fov_y),
|
| 362 |
+
"fov_x": float(fov_x),
|
| 363 |
+
"original_aspect_ratio": float(original_aspect_ratio),
|
| 364 |
+
"fixed_aspect_ratio": float(fixed_size[0]/fixed_size[1])
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
compress_and_write('./_viz/data.bin', header, compressed_blob)
|
| 368 |
+
with open('./_viz/data.bin', "rb") as f:
|
| 369 |
+
encoded_blob = base64.b64encode(f.read()).decode("ascii")
|
| 370 |
+
os.unlink('./_viz/data.bin')
|
| 371 |
+
|
| 372 |
+
random_path = f'./_viz/_{time.time()}.html'
|
| 373 |
+
with open('./_viz/viz_template.html') as f:
|
| 374 |
+
html_template = f.read()
|
| 375 |
+
html_out = html_template.replace(
|
| 376 |
+
"<head>",
|
| 377 |
+
f"<head>\n<script>window.embeddedBase64 = `{encoded_blob}`;</script>"
|
| 378 |
+
)
|
| 379 |
+
with open(random_path,'w') as f:
|
| 380 |
+
f.write(html_out)
|
| 381 |
+
|
| 382 |
+
return random_path
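
The `data.bin` payload produced above is a simple container: a 4-byte little-endian header length, a JSON header describing each array (dtype, shape, offset, length), and a zlib-compressed blob of the concatenated array bytes, which is then base64-embedded into the viewer HTML. A sketch of a matching reader (reader-side code is not part of this commit):

```python
import json
import struct
import zlib

import numpy as np

def read_packed_arrays(path):
    """Decode a file written by compress_and_write(): length-prefixed JSON header
    followed by a zlib-compressed blob of concatenated array bytes."""
    with open(path, "rb") as f:
        header_len = struct.unpack("<I", f.read(4))[0]
        header = json.loads(f.read(header_len).decode("utf-8"))
        blob = zlib.decompress(f.read())

    arrays = {}
    for key, meta in header.items():
        if key == "meta":  # global metadata (depthRange, fov, ...), not an array
            continue
        chunk = blob[meta["offset"]:meta["offset"] + meta["length"]]
        arrays[key] = np.frombuffer(chunk, dtype=np.dtype(meta["dtype"])).reshape(meta["shape"])
    return arrays, header["meta"]
```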
|
| 383 |
+
|
| 384 |
+
def numpy_to_base64(arr):
|
| 385 |
+
"""Convert numpy array to base64 string"""
|
| 386 |
+
return base64.b64encode(arr.tobytes()).decode('utf-8')
|
| 387 |
+
|
| 388 |
+
def base64_to_numpy(b64_str, shape, dtype):
|
| 389 |
+
"""Convert base64 string back to numpy array"""
|
| 390 |
+
return np.frombuffer(base64.b64decode(b64_str), dtype=dtype).reshape(shape)
|
| 391 |
+
|
| 392 |
+
def get_video_name(video_path):
|
| 393 |
+
"""Extract video name without extension"""
|
| 394 |
+
return os.path.splitext(os.path.basename(video_path))[0]
|
| 395 |
+
|
| 396 |
+
def extract_first_frame(video_path):
|
| 397 |
+
"""Extract first frame from video file"""
|
| 398 |
+
try:
|
| 399 |
+
cap = cv2.VideoCapture(video_path)
|
| 400 |
+
ret, frame = cap.read()
|
| 401 |
+
cap.release()
|
| 402 |
+
|
| 403 |
+
if ret:
|
| 404 |
+
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 405 |
+
return frame_rgb
|
| 406 |
+
else:
|
| 407 |
+
return None
|
| 408 |
+
except Exception as e:
|
| 409 |
+
print(f"Error extracting first frame: {e}")
|
| 410 |
+
return None
|
| 411 |
+
|
| 412 |
+
def handle_video_upload(video):
|
| 413 |
+
"""Handle video upload and extract first frame"""
|
| 414 |
+
if video is None:
|
| 415 |
+
return None, None, [], 50, 756, 3
|
| 416 |
+
|
| 417 |
+
# Create user-specific temporary directory
|
| 418 |
+
user_temp_dir = create_user_temp_dir()
|
| 419 |
+
|
| 420 |
+
# Get original video name and copy to temp directory
|
| 421 |
+
if isinstance(video, str):
|
| 422 |
+
video_name = get_video_name(video)
|
| 423 |
+
video_path = os.path.join(user_temp_dir, f"{video_name}.mp4")
|
| 424 |
+
shutil.copy(video, video_path)
|
| 425 |
+
else:
|
| 426 |
+
video_name = get_video_name(video.name)
|
| 427 |
+
video_path = os.path.join(user_temp_dir, f"{video_name}.mp4")
|
| 428 |
+
with open(video_path, 'wb') as f:
|
| 429 |
+
f.write(video.read())
|
| 430 |
+
|
| 431 |
+
print(f"📁 Video saved to: {video_path}")
|
| 432 |
+
|
| 433 |
+
# Extract first frame
|
| 434 |
+
frame = extract_first_frame(video_path)
|
| 435 |
+
if frame is None:
|
| 436 |
+
return None, None, [], 50, 756, 3
|
| 437 |
+
|
| 438 |
+
# Resize frame to have minimum side length of 336
|
| 439 |
+
h, w = frame.shape[:2]
|
| 440 |
+
scale = 336 / min(h, w)
|
| 441 |
+
new_h, new_w = int(h * scale)//2*2, int(w * scale)//2*2
|
| 442 |
+
frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
|
| 443 |
+
|
| 444 |
+
# Store frame data with temp directory info
|
| 445 |
+
frame_data = {
|
| 446 |
+
'data': numpy_to_base64(frame),
|
| 447 |
+
'shape': frame.shape,
|
| 448 |
+
'dtype': str(frame.dtype),
|
| 449 |
+
'temp_dir': user_temp_dir,
|
| 450 |
+
'video_name': video_name,
|
| 451 |
+
'video_path': video_path
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
# Get video-specific settings
|
| 455 |
+
print(f"🎬 Video path: '{video}' -> Video name: '{video_name}'")
|
| 456 |
+
grid_size_val, vo_points_val, fps_val = get_video_settings(video_name)
|
| 457 |
+
print(f"🎬 Video settings for '{video_name}': grid_size={grid_size_val}, vo_points={vo_points_val}, fps={fps_val}")
|
| 458 |
+
|
| 459 |
+
return (json.dumps(frame_data), frame, [],
|
| 460 |
+
gr.update(value=grid_size_val),
|
| 461 |
+
gr.update(value=vo_points_val),
|
| 462 |
+
gr.update(value=fps_val))
|
| 463 |
+
|
| 464 |
+
def save_masks(o_masks, video_name, temp_dir):
|
| 465 |
+
"""Save binary masks to files in user-specific temp directory"""
|
| 466 |
+
o_files = []
|
| 467 |
+
for mask, _ in o_masks:
|
| 468 |
+
o_mask = np.uint8(mask.squeeze() * 255)
|
| 469 |
+
o_file = os.path.join(temp_dir, f"{video_name}.png")
|
| 470 |
+
cv2.imwrite(o_file, o_mask)
|
| 471 |
+
o_files.append(o_file)
|
| 472 |
+
return o_files
|
| 473 |
+
|
| 474 |
+
def select_point(original_img: str, sel_pix: list, point_type: str, evt: gr.SelectData):
|
| 475 |
+
"""Handle point selection for SAM"""
|
| 476 |
+
if original_img is None:
|
| 477 |
+
return None, []
|
| 478 |
+
|
| 479 |
+
try:
|
| 480 |
+
# Convert stored image data back to numpy array
|
| 481 |
+
frame_data = json.loads(original_img)
|
| 482 |
+
original_img_array = base64_to_numpy(frame_data['data'], frame_data['shape'], frame_data['dtype'])
|
| 483 |
+
temp_dir = frame_data.get('temp_dir', 'temp_local')
|
| 484 |
+
video_name = frame_data.get('video_name', 'video')
|
| 485 |
+
|
| 486 |
+
# Create a display image for visualization
|
| 487 |
+
display_img = original_img_array.copy()
|
| 488 |
+
new_sel_pix = sel_pix.copy() if sel_pix else []
|
| 489 |
+
new_sel_pix.append((evt.index, 1 if point_type == 'positive_point' else 0))
|
| 490 |
+
|
| 491 |
+
print(f"🎯 Running SAM inference for point: {evt.index}, type: {point_type}")
|
| 492 |
+
# Run SAM inference
|
| 493 |
+
o_masks = gpu_run_inference(None, original_img_array, new_sel_pix, [])
|
| 494 |
+
|
| 495 |
+
# Draw points on display image
|
| 496 |
+
for point, label in new_sel_pix:
|
| 497 |
+
cv2.drawMarker(display_img, point, COLORS[label], markerType=MARKERS[label], markerSize=MARKER_SIZE, thickness=2)
|
| 498 |
+
|
| 499 |
+
# Draw mask overlay on display image
|
| 500 |
+
if o_masks:
|
| 501 |
+
mask = o_masks[0][0]
|
| 502 |
+
overlay = display_img.copy()
|
| 503 |
+
overlay[mask.squeeze()!=0] = [20, 60, 200] # Light blue
|
| 504 |
+
display_img = cv2.addWeighted(overlay, 0.6, display_img, 0.4, 0)
|
| 505 |
+
|
| 506 |
+
# Save mask for tracking
|
| 507 |
+
save_masks(o_masks, video_name, temp_dir)
|
| 508 |
+
print(f"✅ Mask saved for video: {video_name}")
|
| 509 |
+
|
| 510 |
+
return display_img, new_sel_pix
|
| 511 |
+
|
| 512 |
+
except Exception as e:
|
| 513 |
+
print(f"❌ Error in select_point: {e}")
|
| 514 |
+
return None, []
|
| 515 |
+
|
| 516 |
+
def reset_points(original_img: str, sel_pix):
|
| 517 |
+
"""Reset all points and clear the mask"""
|
| 518 |
+
if original_img is None:
|
| 519 |
+
return None, []
|
| 520 |
+
|
| 521 |
+
try:
|
| 522 |
+
# Convert stored image data back to numpy array
|
| 523 |
+
frame_data = json.loads(original_img)
|
| 524 |
+
original_img_array = base64_to_numpy(frame_data['data'], frame_data['shape'], frame_data['dtype'])
|
| 525 |
+
temp_dir = frame_data.get('temp_dir', 'temp_local')
|
| 526 |
+
|
| 527 |
+
# Create a display image (just the original image)
|
| 528 |
+
display_img = original_img_array.copy()
|
| 529 |
+
|
| 530 |
+
# Clear all points
|
| 531 |
+
new_sel_pix = []
|
| 532 |
+
|
| 533 |
+
# Clear any existing masks
|
| 534 |
+
for mask_file in glob.glob(os.path.join(temp_dir, "*.png")):
|
| 535 |
+
try:
|
| 536 |
+
os.remove(mask_file)
|
| 537 |
+
except Exception as e:
|
| 538 |
+
logger.warning(f"Failed to remove mask file {mask_file}: {e}")
|
| 539 |
+
|
| 540 |
+
print("🔄 Points and masks reset")
|
| 541 |
+
return display_img, new_sel_pix
|
| 542 |
+
|
| 543 |
+
except Exception as e:
|
| 544 |
+
print(f"❌ Error in reset_points: {e}")
|
| 545 |
+
return None, []
|
| 546 |
+
|
| 547 |
+
def launch_viz(grid_size, vo_points, fps, original_image_state):
|
| 548 |
+
"""Launch visualization with user-specific temp directory"""
|
| 549 |
+
if original_image_state is None:
|
| 550 |
+
return None, None
|
| 551 |
+
|
| 552 |
+
try:
|
| 553 |
+
# Get user's temp directory from stored frame data
|
| 554 |
+
frame_data = json.loads(original_image_state)
|
| 555 |
+
temp_dir = frame_data.get('temp_dir', 'temp_local')
|
| 556 |
+
video_name = frame_data.get('video_name', 'video')
|
| 557 |
+
|
| 558 |
+
print(f"🚀 Starting tracking for video: {video_name}")
|
| 559 |
+
print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}")
|
| 560 |
+
|
| 561 |
+
# Check for mask files
|
| 562 |
+
mask_files = glob.glob(os.path.join(temp_dir, "*.png"))
|
| 563 |
+
video_files = glob.glob(os.path.join(temp_dir, "*.mp4"))
|
| 564 |
+
|
| 565 |
+
if not video_files:
|
| 566 |
+
print("❌ No video file found")
|
| 567 |
+
return "❌ Error: No video file found", None
|
| 568 |
+
|
| 569 |
+
video_path = video_files[0]
|
| 570 |
+
mask_path = mask_files[0] if mask_files else None
|
| 571 |
+
|
| 572 |
+
# Run tracker
|
| 573 |
+
print("🎯 Running tracker...")
|
| 574 |
+
out_dir = os.path.join(temp_dir, "results")
|
| 575 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 576 |
+
|
| 577 |
+
gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps)
|
| 578 |
+
|
| 579 |
+
# Process results
|
| 580 |
+
npz_path = os.path.join(out_dir, "result.npz")
|
| 581 |
+
track2d_video = os.path.join(out_dir, "test_pred_track.mp4")
|
| 582 |
+
|
| 583 |
+
if os.path.exists(npz_path):
|
| 584 |
+
print("📊 Processing 3D visualization...")
|
| 585 |
+
html_path = process_point_cloud_data(npz_path)
|
| 586 |
+
|
| 587 |
+
# Schedule deletion of generated files
|
| 588 |
+
delete_later(html_path, delay=600)
|
| 589 |
+
if os.path.exists(track2d_video):
|
| 590 |
+
delete_later(track2d_video, delay=600)
|
| 591 |
+
delete_later(npz_path, delay=600)
|
| 592 |
+
|
| 593 |
+
# Create iframe HTML
|
| 594 |
+
iframe_html = f"""
|
| 595 |
+
<div style='border: 3px solid #667eea; border-radius: 10px; overflow: hidden; box-shadow: 0 8px 32px rgba(102, 126, 234, 0.3);'>
|
| 596 |
+
<iframe id="viz_iframe" src="/gradio_api/file={html_path}" width="100%" height="950px" style="border:none;"></iframe>
|
| 597 |
+
</div>
|
| 598 |
+
"""
|
| 599 |
+
|
| 600 |
+
print("✅ Tracking completed successfully!")
|
| 601 |
+
return iframe_html, track2d_video if os.path.exists(track2d_video) else None
|
| 602 |
+
else:
|
| 603 |
+
print("❌ Tracking failed - no results generated")
|
| 604 |
+
return "❌ Error: Tracking failed to generate results", None
|
| 605 |
+
|
| 606 |
+
except Exception as e:
|
| 607 |
+
print(f"❌ Error in launch_viz: {e}")
|
| 608 |
+
return f"❌ Error: {str(e)}", None
|
| 609 |
+
|
| 610 |
+
def clear_all():
|
| 611 |
+
"""Clear all buffers and temporary files"""
|
| 612 |
+
return (None, None, [],
|
| 613 |
+
gr.update(value=50),
|
| 614 |
+
gr.update(value=756),
|
| 615 |
+
gr.update(value=3))
|
| 616 |
+
|
| 617 |
+
def get_video_settings(video_name):
|
| 618 |
+
"""Get video-specific settings based on video name"""
|
| 619 |
+
video_settings = {
|
| 620 |
+
"kiss": (45, 700, 10),
|
| 621 |
+
"backpack": (40, 600, 2),
|
| 622 |
+
"kitchen": (60, 800, 3),
|
| 623 |
+
"pillow": (35, 500, 2),
|
| 624 |
+
"handwave": (35, 500, 8),
|
| 625 |
+
"hockey": (45, 700, 2),
|
| 626 |
+
"drifting": (35, 1000, 6),
|
| 627 |
+
"basketball": (45, 1500, 5),
|
| 628 |
+
"ken_block_0": (45, 700, 2),
|
| 629 |
+
"ego_kc1": (45, 500, 4),
|
| 630 |
+
"vertical_place": (45, 500, 3),
|
| 631 |
+
"ego_teaser": (45, 1200, 10),
|
| 632 |
+
"robot_unitree": (45, 500, 4),
|
| 633 |
+
"droid_robot": (35, 400, 5),
|
| 634 |
+
"robot_2": (45, 256, 5),
|
| 635 |
+
"cinema_0": (45, 356, 5),
|
| 636 |
+
"cinema_1": (45, 756, 3),
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
+
return video_settings.get(video_name, (50, 756, 3))
|
| 640 |
+
|
| 641 |
+
# Create the Gradio interface
|
| 642 |
+
print("🎨 Creating Gradio interface...")
|
| 643 |
+
|
| 644 |
+
with gr.Blocks(
|
| 645 |
+
theme=gr.themes.Soft(),
|
| 646 |
+
title="SpatialTracker V2 - Local",
|
| 647 |
+
css="""
|
| 648 |
+
.gradio-container {
|
| 649 |
+
max-width: 1200px !important;
|
| 650 |
+
margin: auto !important;
|
| 651 |
+
}
|
| 652 |
+
.gr-button {
|
| 653 |
+
margin: 5px;
|
| 654 |
+
}
|
| 655 |
+
.gr-form {
|
| 656 |
+
background: white;
|
| 657 |
+
border-radius: 10px;
|
| 658 |
+
padding: 20px;
|
| 659 |
+
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
| 660 |
+
}
|
| 661 |
+
.gr-video {
|
| 662 |
+
height: 300px !important;
|
| 663 |
+
min-height: 300px !important;
|
| 664 |
+
max-height: 300px !important;
|
| 665 |
+
}
|
| 666 |
+
.gr-video video {
|
| 667 |
+
height: 260px !important;
|
| 668 |
+
max-height: 260px !important;
|
| 669 |
+
object-fit: contain !important;
|
| 670 |
+
background: #f8f9fa;
|
| 671 |
+
}
|
| 672 |
+
.horizontal-examples .gr-examples {
|
| 673 |
+
overflow: visible !important;
|
| 674 |
+
}
|
| 675 |
+
.horizontal-examples .gr-examples .gr-table-wrapper {
|
| 676 |
+
overflow-x: auto !important;
|
| 677 |
+
overflow-y: hidden !important;
|
| 678 |
+
scrollbar-width: thin;
|
| 679 |
+
scrollbar-color: #667eea #f1f1f1;
|
| 680 |
+
padding: 10px 0;
|
| 681 |
+
}
|
| 682 |
+
.horizontal-examples .gr-examples .gr-table-wrapper::-webkit-scrollbar {
|
| 683 |
+
height: 8px;
|
| 684 |
+
}
|
| 685 |
+
.horizontal-examples .gr-examples .gr-table-wrapper::-webkit-scrollbar-track {
|
| 686 |
+
background: #f1f1f1;
|
| 687 |
+
border-radius: 4px;
|
| 688 |
+
}
|
| 689 |
+
.horizontal-examples .gr-examples .gr-table-wrapper::-webkit-scrollbar-thumb {
|
| 690 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 691 |
+
border-radius: 4px;
|
| 692 |
+
}
|
| 693 |
+
.horizontal-examples .gr-examples .gr-table {
|
| 694 |
+
display: flex !important;
|
| 695 |
+
flex-wrap: nowrap !important;
|
| 696 |
+
min-width: max-content !important;
|
| 697 |
+
gap: 15px !important;
|
| 698 |
+
padding-bottom: 10px;
|
| 699 |
+
}
|
| 700 |
+
.horizontal-examples .gr-examples .gr-table tbody {
|
| 701 |
+
display: flex !important;
|
| 702 |
+
flex-direction: row !important;
|
| 703 |
+
flex-wrap: nowrap !important;
|
| 704 |
+
gap: 15px !important;
|
| 705 |
+
}
|
| 706 |
+
.horizontal-examples .gr-examples .gr-table tbody tr {
|
| 707 |
+
display: flex !important;
|
| 708 |
+
flex-direction: column !important;
|
| 709 |
+
min-width: 160px !important;
|
| 710 |
+
max-width: 160px !important;
|
| 711 |
+
margin: 0 !important;
|
| 712 |
+
background: white;
|
| 713 |
+
border-radius: 12px;
|
| 714 |
+
box-shadow: 0 3px 12px rgba(0,0,0,0.12);
|
| 715 |
+
transition: all 0.3s ease;
|
| 716 |
+
cursor: pointer;
|
| 717 |
+
overflow: hidden;
|
| 718 |
+
}
|
| 719 |
+
.horizontal-examples .gr-examples .gr-table tbody tr:hover {
|
| 720 |
+
transform: translateY(-4px);
|
| 721 |
+
box-shadow: 0 8px 20px rgba(102, 126, 234, 0.25);
|
| 722 |
+
}
|
| 723 |
+
.horizontal-examples .gr-examples .gr-table tbody tr td {
|
| 724 |
+
text-align: center !important;
|
| 725 |
+
padding: 0 !important;
|
| 726 |
+
border: none !important;
|
| 727 |
+
}
|
| 728 |
+
.horizontal-examples .gr-examples .gr-table tbody tr td:first-child {
|
| 729 |
+
padding: 0 !important;
|
| 730 |
+
}
|
| 731 |
+
.horizontal-examples .gr-examples .gr-table tbody tr td video {
|
| 732 |
+
border-radius: 8px 8px 0 0 !important;
|
| 733 |
+
width: 100% !important;
|
| 734 |
+
height: 90px !important;
|
| 735 |
+
object-fit: cover !important;
|
| 736 |
+
}
|
| 737 |
+
.horizontal-examples .gr-examples .gr-table tbody tr td:last-child {
|
| 738 |
+
font-size: 11px !important;
|
| 739 |
+
font-weight: 600 !important;
|
| 740 |
+
color: #333 !important;
|
| 741 |
+
padding: 8px 12px !important;
|
| 742 |
+
background: linear-gradient(135deg, #f8f9ff 0%, #e6f3ff 100%);
|
| 743 |
+
border-radius: 0 0 8px 8px;
|
| 744 |
+
}
|
| 745 |
+
"""
|
| 746 |
+
) as demo:
|
| 747 |
+
|
| 748 |
+
gr.Markdown("""
|
| 749 |
+
# 🎯 SpatialTracker V2 - Local Version
|
| 750 |
+
|
| 751 |
+
Welcome to SpatialTracker V2! This interface allows you to track any pixels in 3D using our model.
|
| 752 |
+
|
| 753 |
+
**Instructions:**
|
| 754 |
+
1. Upload a video file or select from examples below
|
| 755 |
+
2. Click on the object you want to track in the first frame
|
| 756 |
+
3. Adjust tracking parameters if needed
|
| 757 |
+
4. Click "Launch Visualization" to start tracking
|
| 758 |
+
|
| 759 |
+
""")
|
| 760 |
+
|
| 761 |
+
# Status indicator
|
| 762 |
+
gr.Markdown("**Status:** 🟢 Local Processing Mode")
|
| 763 |
+
gr.Markdown("<small style='color: #666;'>All processing runs locally with GPU acceleration</small>")
|
| 764 |
+
|
| 765 |
+
# GitHub Star Reminder
|
| 766 |
+
gr.HTML("""
|
| 767 |
+
<div style='background: linear-gradient(135deg, #e8eaff 0%, #f0f2ff 100%);
|
| 768 |
+
border-radius: 10px;
|
| 769 |
+
padding: 15px;
|
| 770 |
+
margin: 15px 0;
|
| 771 |
+
box-shadow: 0 2px 8px rgba(102, 126, 234, 0.1);
|
| 772 |
+
border: 1px solid rgba(102, 126, 234, 0.15);'>
|
| 773 |
+
<div style='text-align: center; color: #4a5568;'>
|
| 774 |
+
<h3 style='margin: 0 0 10px 0; font-size: 18px; text-shadow: none; color: #2d3748;'>
|
| 775 |
+
⭐ Love SpatialTracker? Give us a Star! ⭐
|
| 776 |
+
</h3>
|
| 777 |
+
<p style='margin: 0 0 12px 0; font-size: 14px; opacity: 0.8; color: #4a5568;'>
|
| 778 |
+
Help us grow by starring our repository on GitHub! 🚀
|
| 779 |
+
</p>
|
| 780 |
+
<div style='display: flex; justify-content: center;'>
|
| 781 |
+
<a href="https://github.com/henry123-boy/SpaTrackerV2"
|
| 782 |
+
target="_blank"
|
| 783 |
+
style='display: inline-flex;
|
| 784 |
+
align-items: center;
|
| 785 |
+
gap: 6px;
|
| 786 |
+
background: rgba(102, 126, 234, 0.1);
|
| 787 |
+
color: #4a5568;
|
| 788 |
+
padding: 8px 16px;
|
| 789 |
+
border-radius: 20px;
|
| 790 |
+
text-decoration: none;
|
| 791 |
+
font-weight: bold;
|
| 792 |
+
font-size: 14px;
|
| 793 |
+
backdrop-filter: blur(5px);
|
| 794 |
+
border: 1px solid rgba(102, 126, 234, 0.2);
|
| 795 |
+
transition: all 0.3s ease;'
|
| 796 |
+
onmouseover="this.style.background='rgba(102, 126, 234, 0.15)'; this.style.transform='translateY(-1px)'"
|
| 797 |
+
onmouseout="this.style.background='rgba(102, 126, 234, 0.1)'; this.style.transform='translateY(0)'">
|
| 798 |
+
<span style='font-size: 16px;'>⭐</span>
|
| 799 |
+
Star on GitHub
|
| 800 |
+
</a>
|
| 801 |
+
</div>
|
| 802 |
+
</div>
|
| 803 |
+
</div>
|
| 804 |
+
""")
|
| 805 |
+
|
| 806 |
+
# Example videos section
|
| 807 |
+
with gr.Group():
|
| 808 |
+
gr.Markdown("### 📂 Example Videos")
|
| 809 |
+
gr.Markdown("Try these example videos to get started quickly:")
|
| 810 |
+
|
| 811 |
+
gr.HTML("""
|
| 812 |
+
<div style='background-color: #f8f9ff; border-radius: 8px; padding: 10px; margin: 10px 0; border-left: 4px solid #667eea;'>
|
| 813 |
+
<p style='margin: 0; font-size: 13px; color: #666; display: flex; align-items: center; gap: 8px;'>
|
| 814 |
+
<span style='font-size: 16px;'>💡</span>
|
| 815 |
+
<strong>Tip:</strong> Scroll horizontally below to see all example videos
|
| 816 |
+
</p>
|
| 817 |
+
</div>
|
| 818 |
+
""")
|
| 819 |
+
|
| 820 |
+
video_input = gr.Video(
|
| 821 |
+
label="Upload Video or Select Example",
|
| 822 |
+
format="mp4",
|
| 823 |
+
height=300
|
| 824 |
+
)
|
| 825 |
+
|
| 826 |
+
with gr.Group(elem_classes=["horizontal-examples"]):
|
| 827 |
+
gr.Examples(
|
| 828 |
+
examples=[
|
| 829 |
+
["examples/kiss.mp4"],
|
| 830 |
+
["examples/backpack.mp4"],
|
| 831 |
+
["examples/pillow.mp4"],
|
| 832 |
+
["examples/handwave.mp4"],
|
| 833 |
+
["examples/hockey.mp4"],
|
| 834 |
+
["examples/drifting.mp4"],
|
| 835 |
+
["examples/ken_block_0.mp4"],
|
| 836 |
+
["examples/kitchen.mp4"],
|
| 837 |
+
["examples/basketball.mp4"],
|
| 838 |
+
["examples/ego_kc1.mp4"],
|
| 839 |
+
["examples/vertical_place.mp4"],
|
| 840 |
+
["examples/ego_teaser.mp4"],
|
| 841 |
+
["examples/robot_unitree.mp4"],
|
| 842 |
+
["examples/droid_robot.mp4"],
|
| 843 |
+
["examples/robot_2.mp4"],
|
| 844 |
+
["examples/cinema_0.mp4"],
|
| 845 |
+
["examples/cinema_1.mp4"],
|
| 846 |
+
],
|
| 847 |
+
inputs=video_input,
|
| 848 |
+
label="🎬 Click on any example to load it",
|
| 849 |
+
examples_per_page=16
|
| 850 |
+
)
|
| 851 |
+
|
| 852 |
+
with gr.Row():
|
| 853 |
+
with gr.Column(scale=1):
|
| 854 |
+
# Interactive frame display
|
| 855 |
+
with gr.Group():
|
| 856 |
+
gr.Markdown("### 🎯 Point Selection")
|
| 857 |
+
gr.Markdown("Click on the object you want to track in the frame below:")
|
| 858 |
+
|
| 859 |
+
interactive_frame = gr.Image(
|
| 860 |
+
label="Click to select tracking points",
|
| 861 |
+
type="numpy",
|
| 862 |
+
interactive=True
|
| 863 |
+
)
|
| 864 |
+
|
| 865 |
+
with gr.Row():
|
| 866 |
+
point_type = gr.Radio(
|
| 867 |
+
choices=["positive_point", "negative_point"],
|
| 868 |
+
value="positive_point",
|
| 869 |
+
label="Point Type",
|
| 870 |
+
info="Positive points indicate the object to track, negative points indicate areas to avoid"
|
| 871 |
+
)
|
| 872 |
+
|
| 873 |
+
with gr.Row():
|
| 874 |
+
reset_points_btn = gr.Button("🔄 Reset Points", variant="secondary")
|
| 875 |
+
clear_all_btn = gr.Button("🗑️ Clear All", variant="stop")
|
| 876 |
+
|
| 877 |
+
with gr.Column(scale=1):
|
| 878 |
+
# Tracking results
|
| 879 |
+
with gr.Group():
|
| 880 |
+
gr.Markdown("### 🎬 Tracking Results")
|
| 881 |
+
tracking_result_video = gr.Video(
|
| 882 |
+
label="Tracking Result Video",
|
| 883 |
+
interactive=False,
|
| 884 |
+
height=300
|
| 885 |
+
)
|
| 886 |
+
|
| 887 |
+
# 3D Visualization - Make it larger and more prominent
|
| 888 |
+
with gr.Row():
|
| 889 |
+
with gr.Column():
|
| 890 |
+
with gr.Group():
|
| 891 |
+
gr.Markdown("### 🌐 3D Trajectory Visualization")
|
| 892 |
+
gr.Markdown("Interactive 3D visualization of 3D point tracking and camera motion:")
|
| 893 |
+
viz_html = gr.HTML(
|
| 894 |
+
label="3D Trajectory Visualization",
|
| 895 |
+
value="""
|
| 896 |
+
<div style='border: 3px solid #667eea; border-radius: 15px; padding: 40px;
|
| 897 |
+
background: linear-gradient(135deg, #f8f9ff 0%, #e6f3ff 100%);
|
| 898 |
+
text-align: center; min-height: 600px; display: flex;
|
| 899 |
+
flex-direction: column; justify-content: center; align-items: center;
|
| 900 |
+
box-shadow: 0 8px 32px rgba(102, 126, 234, 0.2);'>
|
| 901 |
+
<div style='font-size: 48px; margin-bottom: 20px;'>🌐</div>
|
| 902 |
+
<h2 style='color: #667eea; margin-bottom: 15px; font-size: 28px; font-weight: 600;'>
|
| 903 |
+
3D Trajectory Visualization
|
| 904 |
+
</h2>
|
| 905 |
+
<p style='color: #666; font-size: 16px; line-height: 1.6; max-width: 500px; margin-bottom: 25px;'>
|
| 906 |
+
Perceive the world with Pixel-wise 3D Motions!
|
| 907 |
+
</p>
|
| 908 |
+
<div style='background: rgba(102, 126, 234, 0.1); border-radius: 25px;
|
| 909 |
+
padding: 12px 24px; border: 2px solid rgba(102, 126, 234, 0.2);'>
|
| 910 |
+
<span style='color: #667eea; font-weight: 600; font-size: 14px;'>
|
| 911 |
+
⚡ Powered by SpatialTracker V2
|
| 912 |
+
</span>
|
| 913 |
+
</div>
|
| 914 |
+
</div>
|
| 915 |
+
""",
|
| 916 |
+
elem_id="viz_container"
|
| 917 |
+
)
|
| 918 |
+
|
| 919 |
+
# Advanced settings section
|
| 920 |
+
with gr.Accordion("⚙️ Advanced Settings", open=True):
|
| 921 |
+
gr.Markdown("Adjust these parameters to optimize tracking performance:")
|
| 922 |
+
with gr.Row():
|
| 923 |
+
grid_size = gr.Slider(
|
| 924 |
+
minimum=10,
|
| 925 |
+
maximum=100,
|
| 926 |
+
step=10,
|
| 927 |
+
value=50,
|
| 928 |
+
label="Grid Size",
|
| 929 |
+
info="Size of the tracking grid (larger = more detailed)"
|
| 930 |
+
)
|
| 931 |
+
vo_points = gr.Slider(
|
| 932 |
+
minimum=100,
|
| 933 |
+
maximum=2000,
|
| 934 |
+
step=50,
|
| 935 |
+
value=756,
|
| 936 |
+
label="VO Points",
|
| 937 |
+
info="Number of visual odometry points (more = better accuracy)"
|
| 938 |
+
)
|
| 939 |
+
fps = gr.Slider(
|
| 940 |
+
minimum=1,
|
| 941 |
+
maximum=30,
|
| 942 |
+
step=1,
|
| 943 |
+
value=3,
|
| 944 |
+
label="FPS",
|
| 945 |
+
info="Frames per second for processing (higher = smoother but slower)"
|
| 946 |
+
)
|
| 947 |
+
|
| 948 |
+
# Launch button
|
| 949 |
+
with gr.Row():
|
| 950 |
+
launch_btn = gr.Button("🚀 Launch Visualization", variant="primary", size="lg")
|
| 951 |
+
|
| 952 |
+
# Hidden state variables
|
| 953 |
+
original_image_state = gr.State(None)
|
| 954 |
+
selected_points = gr.State([])
|
| 955 |
+
|
| 956 |
+
# Event handlers
|
| 957 |
+
video_input.change(
|
| 958 |
+
fn=handle_video_upload,
|
| 959 |
+
inputs=[video_input],
|
| 960 |
+
outputs=[original_image_state, interactive_frame, selected_points, grid_size, vo_points, fps]
|
| 961 |
+
)
|
| 962 |
+
|
| 963 |
+
interactive_frame.select(
|
| 964 |
+
fn=select_point,
|
| 965 |
+
inputs=[original_image_state, selected_points, point_type],
|
| 966 |
+
outputs=[interactive_frame, selected_points]
|
| 967 |
+
)
|
| 968 |
+
|
| 969 |
+
reset_points_btn.click(
|
| 970 |
+
fn=reset_points,
|
| 971 |
+
inputs=[original_image_state, selected_points],
|
| 972 |
+
outputs=[interactive_frame, selected_points]
|
| 973 |
+
)
|
| 974 |
+
|
| 975 |
+
clear_all_btn.click(
|
| 976 |
+
fn=clear_all,
|
| 977 |
+
outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps]
|
| 978 |
+
)
|
| 979 |
+
|
| 980 |
+
launch_btn.click(
|
| 981 |
+
fn=launch_viz,
|
| 982 |
+
inputs=[grid_size, vo_points, fps, original_image_state],
|
| 983 |
+
outputs=[viz_html, tracking_result_video]
|
| 984 |
+
)
|
| 985 |
+
|
| 986 |
+
# Acknowledgment section for TAPIP3D - moved to the end
|
| 987 |
+
gr.HTML("""
|
| 988 |
+
<div style='background: linear-gradient(135deg, #fff8e1 0%, #fffbf0 100%);
|
| 989 |
+
border-radius: 8px;
|
| 990 |
+
padding: 12px;
|
| 991 |
+
margin: 15px 0;
|
| 992 |
+
box-shadow: 0 1px 4px rgba(255, 193, 7, 0.1);
|
| 993 |
+
border: 1px solid rgba(255, 193, 7, 0.2);'>
|
| 994 |
+
<div style='text-align: center; color: #5d4037;'>
|
| 995 |
+
<h5 style='margin: 0 0 6px 0; font-size: 14px; color: #5d4037;'>
|
| 996 |
+
Acknowledgments
|
| 997 |
+
</h5>
|
| 998 |
+
<p style='margin: 0; font-size: 12px; opacity: 0.9; color: #5d4037; line-height: 1.3;'>
|
| 999 |
+
Our 3D visualizer is adapted from <strong>TAPIP3D</strong>. We thank the authors for their excellent work!
|
| 1000 |
+
</p>
|
| 1001 |
+
<div style='margin-top: 6px;'>
|
| 1002 |
+
<a href="https://github.com/zbw001/TAPIP3D"
|
| 1003 |
+
target="_blank"
|
| 1004 |
+
style='display: inline-flex;
|
| 1005 |
+
align-items: center;
|
| 1006 |
+
gap: 3px;
|
| 1007 |
+
background: rgba(255, 193, 7, 0.15);
|
| 1008 |
+
color: #5d4037;
|
| 1009 |
+
padding: 3px 10px;
|
| 1010 |
+
border-radius: 12px;
|
| 1011 |
+
text-decoration: none;
|
| 1012 |
+
font-weight: 500;
|
| 1013 |
+
font-size: 11px;
|
| 1014 |
+
border: 1px solid rgba(255, 193, 7, 0.3);
|
| 1015 |
+
transition: all 0.3s ease;'
|
| 1016 |
+
onmouseover="this.style.background='rgba(255, 193, 7, 0.2)'"
|
| 1017 |
+
onmouseout="this.style.background='rgba(255, 193, 7, 0.15)'">
|
| 1018 |
+
📚 TAPIP3D Repository
|
| 1019 |
+
</a>
|
| 1020 |
+
</div>
|
| 1021 |
+
</div>
|
| 1022 |
+
</div>
|
| 1023 |
+
""")
|
| 1024 |
+
|
| 1025 |
+
# Launch the interface
|
| 1026 |
+
if __name__ == "__main__":
|
| 1027 |
+
print("🌟 Launching SpatialTracker V2 Local Version...")
|
| 1028 |
+
print("🔗 Running in Local Processing Mode")
|
| 1029 |
+
|
| 1030 |
+
demo.launch(
|
| 1031 |
+
server_name="0.0.0.0",
|
| 1032 |
+
server_port=7860,
|
| 1033 |
+
share=True,
|
| 1034 |
+
debug=True,
|
| 1035 |
+
show_error=True
|
| 1036 |
+
)
|
models/SpaTrackV2/models/SpaTrack.py
ADDED
|
@@ -0,0 +1,758 @@
|
|
|
|
|
|
#python
"""
SpaTrackerV2, a unified model to estimate 'intrinsic',
'video depth', 'extrinsic' and '3D Tracking' from casual video frames.

Contact: DM yuxixiao@zju.edu.cn
"""

import os
import numpy as np
from typing import Literal, Union, List, Tuple, Dict
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
# from depth anything v2
from einops import rearrange
from models.monoD.depth_anything_v2.dpt import DepthAnythingV2
from models.moge.model.v1 import MoGeModel
import copy
from functools import partial
from models.SpaTrackV2.models.tracker3D.TrackRefiner import TrackRefiner3D
import kornia
from models.SpaTrackV2.utils.model_utils import sample_features5d
import utils3d
from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import depth_to_points_colmap, get_nth_visible_time_index
from models.SpaTrackV2.models.utils import pose_enc2mat, matrix_to_quaternion, get_track_points, normalize_rgb
import random

class SpaTrack2(nn.Module):
    def __init__(
        self,
        loggers: list, # include [ viz, logger_tf, logger]
        backbone_cfg,
        Track_cfg=None,
        chunk_size=24,
        ckpt_fwd: bool = False,
        ft_cfg=None,
        resolution=518,
        max_len=600, # the maximum video length we can preprocess,
        track_num=768,
    ):

        self.chunk_size = chunk_size
        self.max_len = max_len
        self.resolution = resolution
        # config the T-Lora Dinov2
        # NOTE: initialize the base model
        base_cfg = copy.deepcopy(backbone_cfg)
        backbone_ckpt_dir = base_cfg.pop('ckpt_dir', None)

        super(SpaTrack2, self).__init__()
        if os.path.exists(backbone_ckpt_dir)==False:
            base_model = MoGeModel.from_pretrained('Ruicheng/moge-vitl')
        else:
            checkpoint = torch.load(backbone_ckpt_dir, map_location='cpu', weights_only=True)
            base_model = MoGeModel(**checkpoint["model_config"])
            base_model.load_state_dict(checkpoint['model'])
        # avoid registering base_model as a submodule of SpaTrack2
        object.__setattr__(self, 'base_model', base_model)

        # Tracker model
        self.Track3D = TrackRefiner3D(Track_cfg)
        track_base_ckpt_dir = Track_cfg.base_ckpt
        if os.path.exists(track_base_ckpt_dir):
            track_pretrain = torch.load(track_base_ckpt_dir)
            self.Track3D.load_state_dict(track_pretrain, strict=False)

        # wrap the function that makes the lora parameters trainable
        self.make_paras_trainable = partial(self.make_paras_trainable,
                                            mode=ft_cfg.mode,
                                            paras_name=ft_cfg.paras_name)
        self.track_num = track_num

    def make_paras_trainable(self, mode: str = 'fix', paras_name: List[str] = []):
        # gradient required for the lora_experts and gate
        for name, param in self.named_parameters():
            if any(x in name for x in paras_name):
                if mode == 'fix':
                    param.requires_grad = False
                else:
                    param.requires_grad = True
            else:
                if mode == 'fix':
                    param.requires_grad = True
                else:
                    param.requires_grad = False
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total parameters: {total_params}")
        print(f"Trainable parameters: {trainable_params/total_params*100:.2f}%")

    def ProcVid(self,
                x: torch.Tensor):
        """
        split the video into several overlapped windows.

        args:
            x: the input video frames. [B, T, C, H, W]
        outputs:
            patch_size: the patch size of the video features
        raises:
            ValueError: if the input video is longer than `max_len`.

        """
        # normalize the input images
        num_types = x.dtype
        x = normalize_rgb(x, input_size=self.resolution)
        x = x.to(num_types)
        # get the video features
        B, T, C, H, W = x.size()
        if T > self.max_len:
            raise ValueError(f"the video length should be no more than {self.max_len}.")
        # get the video features
        patch_h, patch_w = H // 14, W // 14
        patch_size = (patch_h, patch_w)
        # resize and get the video features
        x = x.view(B * T, C, H, W)
        # operate the temporal encoding
        return patch_size, x

    def forward_stream(
        self,
        video: torch.Tensor,
        queries: torch.Tensor = None,
        T_org: int = None,
        depth: torch.Tensor|np.ndarray|str=None,
        unc_metric_in: torch.Tensor|np.ndarray|str=None,
        intrs: torch.Tensor|np.ndarray|str=None,
        extrs: torch.Tensor|np.ndarray|str=None,
        queries_3d: torch.Tensor = None,
        window_len: int = 16,
        overlap_len: int = 4,
        full_point: bool = False,
        track2d_gt: torch.Tensor = None,
        fixed_cam: bool = False,
        query_no_BA: bool = False,
        stage: int = 0,
        support_frame: int = 0,
        replace_ratio: float = 0.6,
        annots_train: Dict = None,
        iters_track=4,
        **kwargs,
    ):
        # step 1: allocate the query points on the grid
        T, C, H, W = video.shape

        if annots_train is not None:
            vis_gt = annots_train["vis"]
            _, _, N = vis_gt.shape
            number_visible = vis_gt.sum(dim=1)
            ratio_rand = torch.rand(1, N, device=vis_gt.device)
            first_positive_inds = get_nth_visible_time_index(vis_gt, (number_visible*ratio_rand).long().clamp(min=1, max=T))
            assert (torch.gather(vis_gt, 1, first_positive_inds[:, None, :].repeat(1, T, 1)) < 0).sum() == 0

            first_positive_inds = first_positive_inds.long()
            gather = torch.gather(
                annots_train["traj_3d"][...,:2], 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 2)
            )
            xys = torch.diagonal(gather, dim1=1, dim2=2).permute(0, 2, 1)
            queries = torch.cat([first_positive_inds[:, :, None], xys], dim=-1)[0].cpu().numpy()

        # Unfold video into segments of window_len with overlap_len
        step_slide = window_len - overlap_len
        if T < window_len:
            video_unf = video.unsqueeze(0)
            if depth is not None:
                depth_unf = depth.unsqueeze(0)
            else:
                depth_unf = None
            if unc_metric_in is not None:
                unc_metric_unf = unc_metric_in.unsqueeze(0)
            else:
                unc_metric_unf = None
            if intrs is not None:
                intrs_unf = intrs.unsqueeze(0)
            else:
                intrs_unf = None
            if extrs is not None:
                extrs_unf = extrs.unsqueeze(0)
            else:
                extrs_unf = None
        else:
            video_unf = video.unfold(0, window_len, step_slide).permute(0, 4, 1, 2, 3) # [B, S, C, H, W]
            if depth is not None:
                depth_unf = depth.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
                intrs_unf = intrs.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
            else:
                depth_unf = None
                intrs_unf = None
            if extrs is not None:
                extrs_unf = extrs.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
            else:
                extrs_unf = None
            if unc_metric_in is not None:
                unc_metric_unf = unc_metric_in.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
            else:
                unc_metric_unf = None

        # parallel
        # Get number of segments
        B = video_unf.shape[0]
        #TODO: Process each segment in parallel using torch.nn.DataParallel
        c2w_traj = torch.eye(4, 4)[None].repeat(T, 1, 1)
        intrs_out = torch.eye(3, 3)[None].repeat(T, 1, 1)
        point_map = torch.zeros(T, 3, H, W).cuda()
        unc_metric = torch.zeros(T, H, W).cuda()
        # set the queries
        N, _ = queries.shape
        track3d_pred = torch.zeros(T, N, 6).cuda()
        track2d_pred = torch.zeros(T, N, 3).cuda()
        vis_pred = torch.zeros(T, N, 1).cuda()
        conf_pred = torch.zeros(T, N, 1).cuda()
        dyn_preds = torch.zeros(T, N, 1).cuda()
        # sort the queries by time
        sorted_indices = np.argsort(queries[...,0])
        sorted_inv_indices = np.argsort(sorted_indices)
        sort_query = queries[sorted_indices]
        sort_query = torch.from_numpy(sort_query).cuda()
        if queries_3d is not None:
            sort_query_3d = queries_3d[sorted_indices]
            sort_query_3d = torch.from_numpy(sort_query_3d).cuda()

        queries_len = 0
        overlap_d = None
        cache = None
        loss = 0.0

        for i in range(B):
            segment = video_unf[i:i+1].cuda()
            # Forward pass through model
            # detect the key points for each frame

            queries_new_mask = (sort_query[...,0] < i * step_slide + window_len) * (sort_query[...,0] >= (i * step_slide + overlap_len if i > 0 else 0))
            if queries_3d is not None:
                queries_new_3d = sort_query_3d[queries_new_mask]
                queries_new_3d = queries_new_3d.float()
            else:
                queries_new_3d = None
            queries_new = sort_query[queries_new_mask.bool()]
            queries_new = queries_new.float()
            if i > 0:
                overlap2d = track2d_pred[i*step_slide:(i+1)*step_slide, :queries_len, :]
                overlapvis = vis_pred[i*step_slide:(i+1)*step_slide, :queries_len, :]
                overlapconf = conf_pred[i*step_slide:(i+1)*step_slide, :queries_len, :]
                overlap_query = (overlapvis * overlapconf).max(dim=0)[1][None, ...]
                overlap_xy = torch.gather(overlap2d, 0, overlap_query.repeat(1,1,2))
                overlap_d = torch.gather(overlap2d, 0, overlap_query.repeat(1,1,3))[...,2].detach()
                overlap_query = torch.cat([overlap_query[...,:1], overlap_xy], dim=-1)[0]
                queries_new[...,0] -= i*step_slide
                queries_new = torch.cat([overlap_query, queries_new], dim=0).detach()

            if annots_train is None:
                annots = {}
            else:
                annots = copy.deepcopy(annots_train)
                annots["traj_3d"] = annots["traj_3d"][:, i*step_slide:i*step_slide+window_len, sorted_indices,:][...,:len(queries_new),:]
                annots["vis"] = annots["vis"][:, i*step_slide:i*step_slide+window_len, sorted_indices][...,:len(queries_new)]
                annots["poses_gt"] = annots["poses_gt"][:, i*step_slide:i*step_slide+window_len]
                annots["depth_gt"] = annots["depth_gt"][:, i*step_slide:i*step_slide+window_len]
                annots["intrs"] = annots["intrs"][:, i*step_slide:i*step_slide+window_len]
                annots["traj_mat"] = annots["traj_mat"][:,i*step_slide:i*step_slide+window_len]

            if depth is not None:
                annots["depth_gt"] = depth_unf[i:i+1].to(segment.device).to(segment.dtype)
            if unc_metric_in is not None:
                annots["unc_metric"] = unc_metric_unf[i:i+1].to(segment.device).to(segment.dtype)
            if intrs is not None:
                intr_seg = intrs_unf[i:i+1].to(segment.device).to(segment.dtype)[0].clone()
                focal = (intr_seg[:,0,0] / segment.shape[-1] + intr_seg[:,1,1]/segment.shape[-2]) / 2
                pose_fake = torch.zeros(1, 8).to(depth.device).to(depth.dtype).repeat(segment.shape[1], 1)
                pose_fake[:, -1] = focal
                pose_fake[:,3]=1
                annots["intrs_gt"] = intr_seg
            if extrs is not None:
                extrs_unf_norm = extrs_unf[i:i+1][0].clone()
                extrs_unf_norm = torch.inverse(extrs_unf_norm[:1,...]) @ extrs_unf[i:i+1][0]
                rot_vec = matrix_to_quaternion(extrs_unf_norm[:,:3,:3])
                annots["poses_gt"] = torch.zeros(1, rot_vec.shape[0], 7).to(segment.device).to(segment.dtype)
                annots["poses_gt"][:, :, 3:7] = rot_vec.to(segment.device).to(segment.dtype)[None]
                annots["poses_gt"][:, :, :3] = extrs_unf_norm[:,:3,3].to(segment.device).to(segment.dtype)[None]
                annots["use_extr"] = True

            kwargs.update({"stage": stage})

            #TODO: DEBUG
            out = self.forward(segment, pts_q=queries_new,
                               pts_q_3d=queries_new_3d, overlap_d=overlap_d,
                               full_point=full_point,
                               fixed_cam=fixed_cam, query_no_BA=query_no_BA,
                               support_frame=segment.shape[1]-1,
                               cache=cache, replace_ratio=replace_ratio,
                               iters_track=iters_track,
                               **kwargs, annots=annots)
            if self.training:
                loss += out["loss"].squeeze()
            # from models.SpaTrackV2.utils.visualizer import Visualizer
            # vis_track = Visualizer(grayscale=False,
            #                        fps=10, pad_value=50, tracks_leave_trace=0)
            # vis_track.visualize(video=segment,
            #                     tracks=out["traj_est"][...,:2],
            #                     visibility=out["vis_est"],
            #                     save_video=True)
            # # visualize 4d
            # import os, json
            # import os.path as osp
            # viser4d_dir = os.path.join("viser_4d_results")
            # os.makedirs(viser4d_dir, exist_ok=True)
            # depth_est = annots["depth_gt"][0]
            # unc_metric = out["unc_metric"]
            # mask = (unc_metric > 0.5).squeeze(1)
            # # pose_est = out["poses_pred"].squeeze(0)
            # pose_est = annots["traj_mat"][0]
            # rgb_tracks = out["rgb_tracks"].squeeze(0)
            # intrinsics = out["intrs"].squeeze(0)
            # for i_k in range(out["depth"].shape[0]):
            #     img_i = out["imgs_raw"][0][i_k].permute(1, 2, 0).cpu().numpy()
            #     img_i = cv2.cvtColor(img_i, cv2.COLOR_BGR2RGB)
            #     cv2.imwrite(osp.join(viser4d_dir, f'frame_{i_k:04d}.png'), img_i)
            #     if stage == 1:
            #         depth = depth_est[i_k].squeeze().cpu().numpy()
            #         np.save(osp.join(viser4d_dir, f'frame_{i_k:04d}.npy'), depth)
            #     else:
            #         point_map_vis = out["points_map"][i_k].cpu().numpy()
            #         np.save(osp.join(viser4d_dir, f'point_{i_k:04d}.npy'), point_map_vis)
            # np.save(os.path.join(viser4d_dir, f'intrinsics.npy'), intrinsics.cpu().numpy())
            # np.save(os.path.join(viser4d_dir, f'extrinsics.npy'), pose_est.cpu().numpy())
            # np.save(os.path.join(viser4d_dir, f'conf.npy'), mask.float().cpu().numpy())
            # np.save(os.path.join(viser4d_dir, f'colored_track3d.npy'), rgb_tracks.cpu().numpy())

            queries_len = len(queries_new)
            # update the track3d and track2d
            left_len = len(track3d_pred[i*step_slide:i*step_slide+window_len, :queries_len, :])
            track3d_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["rgb_tracks"][0,:left_len,:queries_len,:]
            track2d_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["traj_est"][0,:left_len,:queries_len,:3]
            vis_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["vis_est"][0,:left_len,:queries_len,None]
            conf_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["conf_pred"][0,:left_len,:queries_len,None]
            dyn_preds[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["dyn_preds"][0,:left_len,:queries_len,None]

            # process the output for each segment
            seg_c2w = out["poses_pred"][0]
            seg_intrs = out["intrs"][0]
            seg_point_map = out["points_map"]
            seg_conf_depth = out["unc_metric"]

            # cache management
            cache = out["cache"]
            for k in cache.keys():
                if "_pyramid" in k:
                    for j in range(len(cache[k])):
                        if len(cache[k][j].shape) == 5:
                            cache[k][j] = cache[k][j][:,:,:,:queries_len,:]
                        elif len(cache[k][j].shape) == 4:
                            cache[k][j] = cache[k][j][:,:1,:queries_len,:]
                elif "_pred_cache" in k:
                    cache[k] = cache[k][-overlap_len:,:queries_len,:]
                else:
                    cache[k] = cache[k][-overlap_len:]

            # update the results
            idx_glob = i * step_slide
            # refine part
            # mask_update = sort_query[..., 0] < i * step_slide + window_len
            # sort_query_pick = sort_query[mask_update]
            intrs_out[idx_glob:idx_glob+window_len] = seg_intrs
            point_map[idx_glob:idx_glob+window_len] = seg_point_map
            unc_metric[idx_glob:idx_glob+window_len] = seg_conf_depth
            # update the camera poses

            # if using the ground truth pose
            # if extrs_unf is not None:
            #     c2w_traj[idx_glob:idx_glob+window_len] = extrs_unf[i:i+1][0].to(c2w_traj.device).to(c2w_traj.dtype)
            # else:
            prev_c2w = c2w_traj[idx_glob:idx_glob+window_len][:1]
            c2w_traj[idx_glob:idx_glob+window_len] = prev_c2w@seg_c2w.to(c2w_traj.device).to(c2w_traj.dtype)

        track2d_pred = track2d_pred[:T_org,sorted_inv_indices,:]
        track3d_pred = track3d_pred[:T_org,sorted_inv_indices,:]
        vis_pred = vis_pred[:T_org,sorted_inv_indices,:]
        conf_pred = conf_pred[:T_org,sorted_inv_indices,:]
        dyn_preds = dyn_preds[:T_org,sorted_inv_indices,:]
        unc_metric = unc_metric[:T_org,:]
        point_map = point_map[:T_org,:]
        intrs_out = intrs_out[:T_org,:]
        c2w_traj = c2w_traj[:T_org,:]
        if self.training:
            ret = {
                "loss": loss,
                "depth_loss": 0.0,
                "ab_loss": 0.0,
                "vis_loss": out["vis_loss"],
                "track_loss": out["track_loss"],
                "conf_loss": out["conf_loss"],
                "dyn_loss": out["dyn_loss"],
                "sync_loss": out["sync_loss"],
                "poses_pred": c2w_traj[None],
                "intrs": intrs_out[None],
                "points_map": point_map,
                "track3d_pred": track3d_pred[None],
                "rgb_tracks": track3d_pred[None],
                "track2d_pred": track2d_pred[None],
                "traj_est": track2d_pred[None],
                "vis_est": vis_pred[None], "conf_pred": conf_pred[None],
                "dyn_preds": dyn_preds[None],
                "imgs_raw": video[None],
                "unc_metric": unc_metric,
            }

            return ret
        else:
            return c2w_traj, intrs_out, point_map, unc_metric, track3d_pred, track2d_pred, vis_pred, conf_pred
    def forward(self,
                x: torch.Tensor,
                annots: Dict = {},
                pts_q: torch.Tensor = None,
                pts_q_3d: torch.Tensor = None,
                overlap_d: torch.Tensor = None,
                full_point = False,
                fixed_cam = False,
                support_frame = 0,
                query_no_BA = False,
                cache = None,
                replace_ratio = 0.6,
                iters_track=4,
                **kwargs):
        """
        forward the video camera model, which predicts (
            `intr` `camera poses` `video depth`
        )

        args:
            x: the input video frames. [B, T, C, H, W]
            annots: the annotations for video frames.
                {
                    "poses_gt": the pose encoding for the video frames. [B, T, 7]
                    "depth_gt": the ground truth depth for the video frames. [B, T, 1, H, W],
                    "metric": bool, whether to calculate the metric for the video frames.
                }
        """
        self.support_frame = support_frame

        #TODO: to adjust a little bit
        track_loss=ab_loss=vis_loss=track_loss=conf_loss=dyn_loss=0.0
        B, T, _, H, W = x.shape
        imgs_raw = x.clone()
        # get the video split and features for each segment
        patch_size, x_resize = self.ProcVid(x)
        x_resize = rearrange(x_resize, "(b t) c h w -> b t c h w", b=B)
        H_resize, W_resize = x_resize.shape[-2:]

        prec_fx = W / W_resize
        prec_fy = H / H_resize
        # get patch size
        P_H, P_W = patch_size

        # get the depth, pointmap and mask
        #TODO: Release DepthAnything Version
        points_map_gt = None
        with torch.no_grad():
            if_gt_depth = (("depth_gt" in annots.keys())) and (kwargs.get('stage', 0)==1 or kwargs.get('stage', 0)==3)
            if if_gt_depth==False:
                if cache is not None:
                    T_cache = cache["points_map"].shape[0]
                    T_new = T - T_cache
                    x_resize_new = x_resize[:, T_cache:]
                else:
                    T_new = T
                    x_resize_new = x_resize
                # infer with chunk
                chunk_size = self.chunk_size
                metric_depth = []
                intrs = []
                unc_metric = []
                mask = []
                points_map = []
                normals = []
                normals_mask = []
                for i in range(0, B*T_new, chunk_size):
                    output = self.base_model.infer(x_resize_new.view(B*T_new, -1, H_resize, W_resize)[i:i+chunk_size])
                    metric_depth.append(output['depth'])
                    intrs.append(output['intrinsics'])
                    unc_metric.append(output['mask_prob'])
                    mask.append(output['mask'])
                    points_map.append(output['points'])
                    normals_i, normals_mask_i = utils3d.torch.points_to_normals(output['points'], mask=output['mask'])
                    normals.append(normals_i)
                    normals_mask.append(normals_mask_i)

                metric_depth = torch.cat(metric_depth, dim=0).view(B*T_new, 1, H_resize, W_resize).to(x.dtype)
                intrs = torch.cat(intrs, dim=0).view(B, T_new, 3, 3).to(x.dtype)
                intrs[:,:,0,:] *= W_resize
                intrs[:,:,1,:] *= H_resize
                # points_map = torch.cat(points_map, dim=0)
                mask = torch.cat(mask, dim=0).view(B*T_new, 1, H_resize, W_resize).to(x.dtype)
                # cat the normals
                normals = torch.cat(normals, dim=0)
                normals_mask = torch.cat(normals_mask, dim=0)

                metric_depth = metric_depth.clone()
                metric_depth[metric_depth == torch.inf] = 0
                _depths = metric_depth[metric_depth > 0].reshape(-1)
                q25 = torch.kthvalue(_depths, int(0.25 * len(_depths))).values
                q75 = torch.kthvalue(_depths, int(0.75 * len(_depths))).values
                iqr = q75 - q25
                upper_bound = (q75 + 0.8*iqr).clamp(min=1e-6, max=10*q25)
                _depth_roi = torch.tensor(
                    [1e-1, upper_bound.item()],
                    dtype=metric_depth.dtype,
                    device=metric_depth.device
                )
                mask_roi = (metric_depth > _depth_roi[0]) & (metric_depth < _depth_roi[1])
                mask = mask * mask_roi
                mask = mask * (~(utils3d.torch.depth_edge(metric_depth, rtol=0.03, mask=mask.bool()))) * normals_mask[:,None,...]
                points_map = depth_to_points_colmap(metric_depth.squeeze(1), intrs.view(B*T_new, 3, 3))
                unc_metric = torch.cat(unc_metric, dim=0).view(B*T_new, 1, H_resize, W_resize).to(x.dtype)
                unc_metric *= mask
                if full_point:
                    unc_metric = (~(utils3d.torch.depth_edge(metric_depth, rtol=0.1, mask=torch.ones_like(metric_depth).bool()))).float() * (metric_depth != 0)
                if cache is not None:
                    assert B==1, "only support batch size 1 right now."
                    unc_metric = torch.cat([cache["unc_metric"], unc_metric], dim=0)
                    intrs = torch.cat([cache["intrs"][None], intrs], dim=1)
                    points_map = torch.cat([cache["points_map"].permute(0,2,3,1), points_map], dim=0)
                    metric_depth = torch.cat([cache["metric_depth"], metric_depth], dim=0)

            if "poses_gt" in annots.keys():
                intrs, c2w_traj_gt = pose_enc2mat(annots["poses_gt"],
                                                  H_resize, W_resize, self.resolution)
            else:
                c2w_traj_gt = None

            if "intrs_gt" in annots.keys():
                intrs = annots["intrs_gt"].view(B, T, 3, 3)
                fx_factor = W_resize / W
                fy_factor = H_resize / H
                intrs[:,:,0,:] *= fx_factor
                intrs[:,:,1,:] *= fy_factor

            if "depth_gt" in annots.keys():

                metric_depth_gt = annots['depth_gt'].view(B*T, 1, H, W)
                metric_depth_gt = F.interpolate(metric_depth_gt,
                                                size=(H_resize, W_resize), mode='nearest')

                _depths = metric_depth_gt[metric_depth_gt > 0].reshape(-1)
                q25 = torch.kthvalue(_depths, int(0.25 * len(_depths))).values
                q75 = torch.kthvalue(_depths, int(0.75 * len(_depths))).values
                iqr = q75 - q25
                upper_bound = (q75 + 0.8*iqr).clamp(min=1e-6, max=10*q25)
                _depth_roi = torch.tensor(
                    [1e-1, upper_bound.item()],
                    dtype=metric_depth_gt.dtype,
                    device=metric_depth_gt.device
                )
                mask_roi = (metric_depth_gt > _depth_roi[0]) & (metric_depth_gt < _depth_roi[1])
                # if (upper_bound > 200).any():
                #     import pdb; pdb.set_trace()
                if (kwargs.get('stage', 0) == 2):
                    unc_metric = ((metric_depth_gt > 0)*(mask_roi) * (unc_metric > 0.5)).float()
                    metric_depth_gt[metric_depth_gt > 10*q25] = 0
                else:
                    unc_metric = ((metric_depth_gt > 0)*(mask_roi)).float()
                    unc_metric *= (~(utils3d.torch.depth_edge(metric_depth_gt, rtol=0.03, mask=mask_roi.bool()))).float()
                # filter the sky
                metric_depth_gt[metric_depth_gt > 10*q25] = 0
                if "unc_metric" in annots.keys():
                    unc_metric_ = F.interpolate(annots["unc_metric"].permute(1,0,2,3),
                                                size=(H_resize, W_resize), mode='nearest')
                    unc_metric = unc_metric * unc_metric_
                if if_gt_depth:
                    points_map = depth_to_points_colmap(metric_depth_gt.squeeze(1), intrs.view(B*T, 3, 3))
                    metric_depth = metric_depth_gt
                    points_map_gt = points_map
                else:
                    points_map_gt = depth_to_points_colmap(metric_depth_gt.squeeze(1), intrs.view(B*T, 3, 3))

        # track the 3d points
        ret_track = None
        regular_track = True
        dyn_preds, final_tracks = None, None

        if "use_extr" in annots.keys():
            init_pose = True
        else:
            init_pose = False
        # set the custom vid and valid only
        custom_vid = annots.get("custom_vid", False)
        valid_only = annots.get("data_dir", [""])[0] == "stereo4d"
        if self.training:
            if (annots["vis"].sum() > 0) and (kwargs.get('stage', 0)==1 or kwargs.get('stage', 0)==3):
                traj3d = annots['traj_3d']
                if (kwargs.get('stage', 0)==1) and (annots.get("custom_vid", False)==False):
                    support_pts_q = get_track_points(H_resize, W_resize,
                                                     T, x.device, query_size=self.track_num // 2,
                                                     support_frame=self.support_frame, unc_metric=unc_metric, mode="incremental")[None]
                else:
                    support_pts_q = get_track_points(H_resize, W_resize,
                                                     T, x.device, query_size=random.randint(32, 256),
                                                     support_frame=self.support_frame, unc_metric=unc_metric, mode="incremental")[None]
                if pts_q is not None:
                    pts_q = pts_q[None,None]
                    ret_track, dyn_preds, final_tracks, rgb_tracks, intrs_org, point_map_org_refined, cache = self.Track3D(imgs_raw,
                                        metric_depth,
                                        unc_metric.detach(), points_map, pts_q,
                                        intrs=intrs.clone(), cache=cache,
                                        prec_fx=prec_fx, prec_fy=prec_fy, overlap_d=overlap_d,
                                        vis_gt=annots['vis'], traj3d_gt=traj3d, iters=iters_track,
                                        cam_gt=c2w_traj_gt, support_pts_q=support_pts_q, custom_vid=custom_vid,
                                        init_pose=init_pose, fixed_cam=fixed_cam, stage=kwargs.get('stage', 0),
                                        points_map_gt=points_map_gt, valid_only=valid_only, replace_ratio=replace_ratio)
                else:
                    ret_track, dyn_preds, final_tracks, rgb_tracks, intrs_org, point_map_org_refined, cache = self.Track3D(imgs_raw,
                                        metric_depth,
                                        unc_metric.detach(), points_map, traj3d[..., :2],
                                        intrs=intrs.clone(), cache=cache,
                                        prec_fx=prec_fx, prec_fy=prec_fy, overlap_d=overlap_d,
                                        vis_gt=annots['vis'], traj3d_gt=traj3d, iters=iters_track,
                                        cam_gt=c2w_traj_gt, support_pts_q=support_pts_q, custom_vid=custom_vid,
                                        init_pose=init_pose, fixed_cam=fixed_cam, stage=kwargs.get('stage', 0),
                                        points_map_gt=points_map_gt, valid_only=valid_only, replace_ratio=replace_ratio)
                regular_track = False

        if regular_track:
            if pts_q is None:
                pts_q = get_track_points(H_resize, W_resize,
                                         T, x.device, query_size=self.track_num,
                                         support_frame=self.support_frame, unc_metric=unc_metric, mode="incremental" if self.training else "incremental")[None]
                support_pts_q = None
            else:
                pts_q = pts_q[None,None]
                # resize the query points
                pts_q[...,1] *= W_resize / W
                pts_q[...,2] *= H_resize / H

                if pts_q_3d is not None:
                    pts_q_3d = pts_q_3d[None,None]
                    # resize the query points
                    pts_q_3d[...,1] *= W_resize / W
                    pts_q_3d[...,2] *= H_resize / H
                else:
                    # adjust the query with uncertainty
                    if (full_point==False) and (overlap_d is None):
                        pts_q_unc = sample_features5d(unc_metric[None], pts_q).squeeze()
                        pts_q = pts_q[:,:,pts_q_unc>0.5,:]
                        if (pts_q_unc<0.5).sum() > 0:
                            # pad the query points
                            pad_num = pts_q_unc.shape[0] - pts_q.shape[2]
                            # pick the random indices
                            indices = torch.randint(0, pts_q.shape[2], (pad_num,), device=pts_q.device)
                            pad_pts = indices
                            pts_q = torch.cat([pts_q, pts_q[:,:,pad_pts,:]], dim=-2)

                support_pts_q = get_track_points(H_resize, W_resize,
                                                 T, x.device, query_size=self.track_num,
                                                 support_frame=self.support_frame,
                                                 unc_metric=unc_metric, mode="mixed")[None]

            points_map[points_map>1e3] = 0
            points_map = depth_to_points_colmap(metric_depth.squeeze(1), intrs.view(B*T, 3, 3))
            ret_track, dyn_preds, final_tracks, rgb_tracks, intrs_org, point_map_org_refined, cache = self.Track3D(imgs_raw,
                                metric_depth,
                                unc_metric.detach(), points_map, pts_q,
                                pts_q_3d=pts_q_3d, intrs=intrs.clone(),cache=cache,
                                overlap_d=overlap_d, cam_gt=c2w_traj_gt if kwargs.get('stage', 0)==1 else None,
                                prec_fx=prec_fx, prec_fy=prec_fy, support_pts_q=support_pts_q, custom_vid=custom_vid, valid_only=valid_only,
                                fixed_cam=fixed_cam, query_no_BA=query_no_BA, init_pose=init_pose, iters=iters_track,
                                stage=kwargs.get('stage', 0), points_map_gt=points_map_gt, replace_ratio=replace_ratio)
        intrs = intrs_org
        points_map = point_map_org_refined
        c2w_traj = ret_track["cam_pred"]

        if ret_track is not None:
            if ret_track["loss"] is not None:
                track_loss, conf_loss, dyn_loss, vis_loss, point_map_loss, scale_loss, shift_loss, sync_loss= ret_track["loss"]

        # update the cache
        cache.update({"metric_depth": metric_depth, "unc_metric": unc_metric, "points_map": points_map, "intrs": intrs[0]})
        # output
        depth = F.interpolate(metric_depth,
                              size=(H, W), mode='bilinear', align_corners=True).squeeze(1)
        points_map = F.interpolate(points_map,
                                   size=(H, W), mode='bilinear', align_corners=True).squeeze(1)
        unc_metric = F.interpolate(unc_metric,
                                   size=(H, W), mode='bilinear', align_corners=True).squeeze(1)

        if self.training:

            loss = track_loss + conf_loss + dyn_loss + sync_loss + vis_loss + point_map_loss + (scale_loss + shift_loss)*50
            ret = {"loss": loss,
                   "depth_loss": point_map_loss,
                   "ab_loss": (scale_loss + shift_loss)*50,
                   "vis_loss": vis_loss, "track_loss": track_loss,
                   "poses_pred": c2w_traj, "dyn_preds": dyn_preds, "traj_est": final_tracks, "conf_loss": conf_loss,
                   "imgs_raw": imgs_raw, "rgb_tracks": rgb_tracks, "vis_est": ret_track['vis_pred'],
                   "depth": depth, "points_map": points_map, "unc_metric": unc_metric, "intrs": intrs, "dyn_loss": dyn_loss,
                   "sync_loss": sync_loss, "conf_pred": ret_track['conf_pred'], "cache": cache,
                   }

        else:

            if ret_track is not None:
                traj_est = ret_track['preds']
                traj_est[..., 0] *= W / W_resize
                traj_est[..., 1] *= H / H_resize
                vis_est = ret_track['vis_pred']
            else:
                traj_est = torch.zeros(B, self.track_num // 2, 3).to(x.device)
                vis_est = torch.zeros(B, self.track_num // 2).to(x.device)

            if intrs is not None:
                intrs[..., 0, :] *= W / W_resize
                intrs[..., 1, :] *= H / H_resize
            ret = {"poses_pred": c2w_traj, "dyn_preds": dyn_preds,
                   "depth": depth, "traj_est": traj_est, "vis_est": vis_est, "imgs_raw": imgs_raw,
                   "rgb_tracks": rgb_tracks, "intrs": intrs, "unc_metric": unc_metric, "points_map": points_map,
                   "conf_pred": ret_track['conf_pred'], "cache": cache,
                   }

        return ret




# three stages of training

# stage 1:
# gt depth and intrinsics, synthetic (includes Dynamic Replica, Kubric, Pointodyssey, Vkitti, TartanAir and Indoor() ) Motion Pattern (tapvid3d)
# Tracking and Pose as well -> based on gt depth and intrinsics
# (Finished) -> (megasam + base model) vs. tapip3d. (use depth from megasam or pose, which keeps the same setting as tapip3d.)

# stage 2: fixed 3D tracking
# Joint depth refiner
# input depth from whatever + rgb -> temporal module + scale and shift token -> coarse alignment -> scale and shift
# estimate the 3D tracks -> 3D tracks combine with pointmap -> update for pointmap (iteratively) -> residual map B T 3 H W
# ongoing two days

# stage 3: train multi windows by propagation
# 4 frames overlapped -> train on 64 -> frozen image encoder and finetuning the transformer (learnable parameters pretty small)

# types of scenarios:
# 1. auto driving (waymo open dataset)
# 2. robot
# 3. internet ego video



# Iterative Transformer -- Solver -- General Neural MegaSAM + Tracks
# Update Variables:
# 1. 3D tracks B T N 3 xyz.
# 2. 2D tracks B T N 2 x y.
# 3. Dynamic Mask B T H W.
# 4. Camera Pose B T 4 4.
# 5. Video Depth.

# (RGB, RGBD, RGBD+Pose) x (Static, Dynamic)
# Compatibility by product.
models/SpaTrackV2/models/__init__.py
ADDED
File without changes
models/SpaTrackV2/models/blocks.py
ADDED
@@ -0,0 +1,519 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
from einops import rearrange
import collections
from functools import partial
from itertools import repeat
import torchvision.models as tvm
from torch.utils.checkpoint import checkpoint
from models.monoD.depth_anything.dpt import DPTHeadEnc, DPTHead
from typing import Union, Tuple
from torch import Tensor

# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


to_2tuple = _ntuple(2)

class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma

class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

class Attention(nn.Module):
    def __init__(self, query_dim, context_dim=None,
                 num_heads=8, dim_head=48, qkv_bias=False, flash=False):
        super().__init__()
        inner_dim = self.inner_dim = dim_head * num_heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head**-0.5
        self.heads = num_heads
        self.flash = flash

        self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context=None, attn_bias=None):
        B, N1, _ = x.shape
        C = self.inner_dim
        h = self.heads
        q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim=-1)

        N2 = context.shape[1]
        k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
        v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)

        with torch.autocast("cuda", enabled=True, dtype=torch.bfloat16):
            if self.flash==False:
                sim = (q @ k.transpose(-2, -1)) * self.scale
                if attn_bias is not None:
                    sim = sim + attn_bias
                if sim.abs().max()>1e2:
                    import pdb; pdb.set_trace()
                attn = sim.softmax(dim=-1)
                x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
            else:
                input_args = [x.contiguous() for x in [q, k, v]]
                x = F.scaled_dot_product_attention(*input_args).permute(0,2,1,3).reshape(B,N1,-1) # type: ignore

        if self.to_out.bias.dtype != x.dtype:
            x = x.to(self.to_out.bias.dtype)

        return self.to_out(x)


class VGG19(nn.Module):
    def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
        super().__init__()
        self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
        self.amp = amp
        self.amp_dtype = amp_dtype

    def forward(self, x, **kwargs):
        with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
            feats = {}
            scale = 1
            for layer in self.layers:
                if isinstance(layer, nn.MaxPool2d):
                    feats[scale] = x
                    scale = scale*2
                x = layer(x)
            return feats

class CNNandDinov2(nn.Module):
    def __init__(self, cnn_kwargs = None, amp = True, amp_dtype = torch.float16):
        super().__init__()
        # in case the Internet connection is not stable, please load the DINOv2 locally
        self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main',
                                            'dinov2_{:}14'.format("vitl"), source='local', pretrained=False)

        state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth")
        self.dinov2_vitl14.load_state_dict(state_dict, strict=True)


        cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {}
        self.cnn = VGG19(**cnn_kwargs)
        self.amp = amp
        self.amp_dtype = amp_dtype
        if self.amp:
            dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
            self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP


    def train(self, mode: bool = True):
        return self.cnn.train(mode)

    def forward(self, x, upsample = False):
        B,C,H,W = x.shape
        feature_pyramid = self.cnn(x)

        if not upsample:
            with torch.no_grad():
                if self.dinov2_vitl14[0].device != x.device:
                    self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
                dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype))
                features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14)
                del dinov2_features_16
                feature_pyramid[16] = features_16
        return feature_pyramid

class Dinov2(nn.Module):
    def __init__(self, amp = True, amp_dtype = torch.float16):
        super().__init__()
        # in case the Internet connection is not stable, please load the DINOv2 locally
        self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main',
                                            'dinov2_{:}14'.format("vitl"), source='local', pretrained=False)

        state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth")
        self.dinov2_vitl14.load_state_dict(state_dict, strict=True)

        self.amp = amp
        self.amp_dtype = amp_dtype
        if self.amp:
            self.dinov2_vitl14 = self.dinov2_vitl14.to(self.amp_dtype)

    def forward(self, x, upsample = False):
        B,C,H,W = x.shape
        mean_ = torch.tensor([0.485, 0.456, 0.406],
                             device=x.device).view(1, 3, 1, 1)
        std_ = torch.tensor([0.229, 0.224, 0.225],
                            device=x.device).view(1, 3, 1, 1)
        x = (x+1)/2
        x = (x - mean_)/std_
        h_re, w_re = 560, 560
        x_resize = F.interpolate(x, size=(h_re, w_re),
                                 mode='bilinear', align_corners=True)
        if not upsample:
            with torch.no_grad():
                dinov2_features_16 = self.dinov2_vitl14.forward_features(x_resize.to(self.amp_dtype))
                features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,h_re//14, w_re//14)
                del dinov2_features_16
        features_16 = F.interpolate(features_16, size=(H//8, W//8), mode="bilinear", align_corners=True)
        return features_16

class AttnBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0,
                 flash=False, ckpt_fwd=False, debug=False, **block_kwargs):
        super().__init__()
        self.debug=debug
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.flash=flash

        self.attn = Attention(
            hidden_size, num_heads=num_heads, qkv_bias=True, flash=flash,
            **block_kwargs
        )
        self.ls = LayerScale(hidden_size, init_values=0.005)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden_dim,
            act_layer=approx_gelu,
        )
        self.ckpt_fwd = ckpt_fwd
    def forward(self, x):
        if self.debug:
            print(x.max(), x.min(), x.mean())
        if self.ckpt_fwd:
            x = x + checkpoint(self.attn, self.norm1(x), use_reentrant=False)
        else:
            x = x + self.attn(self.norm1(x))

        x = x + self.ls(self.mlp(self.norm2(x)))
        return x

class CrossAttnBlock(nn.Module):
    def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, head_dim=48,
                 flash=False, ckpt_fwd=False, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.norm_context = nn.LayerNorm(hidden_size)

        self.cross_attn = Attention(
            hidden_size, context_dim=context_dim, dim_head=head_dim,
            num_heads=num_heads, qkv_bias=True, **block_kwargs, flash=flash,
        )

        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden_dim,
            act_layer=approx_gelu,
            drop=0,
        )
        self.ckpt_fwd = ckpt_fwd
    def forward(self, x, context):
        if self.ckpt_fwd:
            with autocast():
                x = x + checkpoint(self.cross_attn,
                                   self.norm1(x), self.norm_context(context), use_reentrant=False)
        else:
            with autocast():
                x = x + self.cross_attn(
                    self.norm1(x), self.norm_context(context)
                )
        x = x + self.mlp(self.norm2(x))
        return x


def bilinear_sampler(img, coords, mode="bilinear", mask=False):
    """Wrapper for grid_sample, uses pixel coordinates"""
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # go to 0,1 then 0,2 then -1,1
    xgrid = 2 * xgrid / (W - 1) - 1
    ygrid = 2 * ygrid / (H - 1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    img = F.grid_sample(img, grid, align_corners=True, mode=mode)

    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img


class CorrBlock:
    def __init__(self, fmaps, num_levels=4, radius=4, depths_dnG=None):
        B, S, C, H_prev, W_prev = fmaps.shape
        self.S, self.C, self.H, self.W = S, C, H_prev, W_prev

        self.num_levels = num_levels
        self.radius = radius
        self.fmaps_pyramid = []
        self.depth_pyramid = []
        self.fmaps_pyramid.append(fmaps)
        if depths_dnG is not None:
            self.depth_pyramid.append(depths_dnG)
        for i in range(self.num_levels - 1):
            if depths_dnG is not None:
                depths_dnG_ = depths_dnG.reshape(B * S, 1, H_prev, W_prev)
                depths_dnG_ = F.avg_pool2d(depths_dnG_, 2, stride=2)
                _, _, H, W = depths_dnG_.shape
                depths_dnG = depths_dnG_.reshape(B, S, 1, H, W)
                self.depth_pyramid.append(depths_dnG)
            fmaps_ = fmaps.reshape(B * S, C, H_prev, W_prev)
            fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
            _, _, H, W = fmaps_.shape
            fmaps = fmaps_.reshape(B, S, C, H, W)
            H_prev = H
            W_prev = W
            self.fmaps_pyramid.append(fmaps)

    def sample(self, coords):
        r = self.radius
        B, S, N, D = coords.shape
        assert D == 2

        H, W = self.H, self.W
        out_pyramid = []
        for i in range(self.num_levels):
            corrs = self.corrs_pyramid[i] # B, S, N, H, W
            _, _, _, H, W = corrs.shape

            dx = torch.linspace(-r, r, 2 * r + 1)
            dy = torch.linspace(-r, r, 2 * r + 1)
            delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
                coords.device
            )
            centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
            coords_lvl = centroid_lvl + delta_lvl
            corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl)
            corrs = corrs.view(B, S, N, -1)
            out_pyramid.append(corrs)

        out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
        return out.contiguous().float()

    def corr(self, targets):
        B, S, N, C = targets.shape
        assert C == self.C
        assert S == self.S

        fmap1 = targets

        self.corrs_pyramid = []
        for fmaps in self.fmaps_pyramid:
            _, _, _, H, W = fmaps.shape
            fmap2s = fmaps.view(B, S, C, H * W)
            corrs = torch.matmul(fmap1, fmap2s)
            corrs = corrs.view(B, S, N, H, W)
            corrs = corrs / torch.sqrt(torch.tensor(C).float())
            self.corrs_pyramid.append(corrs)

    def corr_sample(self, targets, coords, coords_dp=None):
        B, S, N, C = targets.shape
        r = self.radius
        Dim_c = (2*r+1)**2
        assert C == self.C
        assert S == self.S

        out_pyramid = []
        out_pyramid_dp = []
        for i in range(self.num_levels):
            dx = torch.linspace(-r, r, 2 * r + 1)
            dy = torch.linspace(-r, r, 2 * r + 1)
            delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
                coords.device
            )
            centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
            coords_lvl = centroid_lvl + delta_lvl
            fmaps = self.fmaps_pyramid[i]
            _, _, _, H, W = fmaps.shape
            fmap2s = fmaps.view(B*S, C, H, W)
            if len(self.depth_pyramid)>0:
                depths_dnG_i = self.depth_pyramid[i]
                depths_dnG_i = depths_dnG_i.view(B*S, 1, H, W)
                dnG_sample = bilinear_sampler(depths_dnG_i, coords_lvl.view(B*S,1,N*Dim_c,2))
                dp_corrs = (dnG_sample.view(B*S,N,-1) - coords_dp[0]).abs()/coords_dp[0]
                out_pyramid_dp.append(dp_corrs)
            fmap2s_sample = bilinear_sampler(fmap2s, coords_lvl.view(B*S,1,N*Dim_c,2))
            fmap2s_sample = fmap2s_sample.permute(0, 3, 1, 2) # B*S, N*Dim_c, C, -1
            corrs = torch.matmul(targets.reshape(B*S*N, 1, -1), fmap2s_sample.reshape(B*S*N, Dim_c, -1).permute(0, 2, 1))
            corrs = corrs / torch.sqrt(torch.tensor(C).float())
            corrs = corrs.view(B, S, N, -1)
            out_pyramid.append(corrs)
|
| 418 |
+
|
| 419 |
+
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
|
| 420 |
+
if len(self.depth_pyramid)>0:
|
| 421 |
+
out_dp = torch.cat(out_pyramid_dp, dim=-1)
|
| 422 |
+
self.fcorrD = out_dp.contiguous().float()
|
| 423 |
+
else:
|
| 424 |
+
self.fcorrD = torch.zeros_like(out).contiguous().float()
|
| 425 |
+
return out.contiguous().float()
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
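# Illustrative usage sketch (shapes assumed): build the correlation pyramid
# once per window of feature maps, then query local correlation costs around
# the current track estimates.
#
#   fmaps = torch.randn(1, 8, 128, 64, 64)          # (B, S, C, H, W)
#   corr_block = CorrBlock(fmaps, num_levels=4, radius=3)
#   targets = torch.randn(1, 8, 256, 128)           # (B, S, N, C) track features
#   coords = torch.rand(1, 8, 256, 2) * 63          # (B, S, N, 2) pixel positions
#   corr_feat = corr_block.corr_sample(targets, coords)
#   # corr_feat: (B, S, N, num_levels * (2*radius+1)**2) local cost volume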
class EUpdateFormer(nn.Module):
|
| 429 |
+
"""
|
| 430 |
+
Transformer model that updates track estimates.
|
| 431 |
+
"""
|
| 432 |
+
|
| 433 |
+
def __init__(
|
| 434 |
+
self,
|
| 435 |
+
space_depth=12,
|
| 436 |
+
time_depth=12,
|
| 437 |
+
input_dim=320,
|
| 438 |
+
hidden_size=384,
|
| 439 |
+
num_heads=8,
|
| 440 |
+
output_dim=130,
|
| 441 |
+
mlp_ratio=4.0,
|
| 442 |
+
vq_depth=3,
|
| 443 |
+
add_space_attn=True,
|
| 444 |
+
add_time_attn=True,
|
| 445 |
+
flash=True
|
| 446 |
+
):
|
| 447 |
+
super().__init__()
|
| 448 |
+
self.out_channels = 2
|
| 449 |
+
self.num_heads = num_heads
|
| 450 |
+
self.hidden_size = hidden_size
|
| 451 |
+
self.add_space_attn = add_space_attn
|
| 452 |
+
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
|
| 453 |
+
self.flash = flash
|
| 454 |
+
self.flow_head = nn.Sequential(
|
| 455 |
+
nn.Linear(hidden_size, output_dim, bias=True),
|
| 456 |
+
nn.ReLU(inplace=True),
|
| 457 |
+
nn.Linear(output_dim, output_dim, bias=True),
|
| 458 |
+
nn.ReLU(inplace=True),
|
| 459 |
+
nn.Linear(output_dim, output_dim, bias=True)
|
| 460 |
+
)
|
| 461 |
+
self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 462 |
+
cfg = xLSTMBlockStackConfig(
|
| 463 |
+
mlstm_block=mLSTMBlockConfig(
|
| 464 |
+
mlstm=mLSTMLayerConfig(
|
| 465 |
+
conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
|
| 466 |
+
)
|
| 467 |
+
),
|
| 468 |
+
slstm_block=sLSTMBlockConfig(
|
| 469 |
+
slstm=sLSTMLayerConfig(
|
| 470 |
+
backend="cuda",
|
| 471 |
+
num_heads=4,
|
| 472 |
+
conv1d_kernel_size=4,
|
| 473 |
+
bias_init="powerlaw_blockdependent",
|
| 474 |
+
),
|
| 475 |
+
feedforward=FeedForwardConfig(proj_factor=1.3, act_fn="gelu"),
|
| 476 |
+
),
|
| 477 |
+
context_length=50,
|
| 478 |
+
num_blocks=7,
|
| 479 |
+
embedding_dim=384,
|
| 480 |
+
slstm_at=[1],
|
| 481 |
+
|
| 482 |
+
)
|
| 483 |
+
self.xlstm_fwd = xLSTMBlockStack(cfg)
|
| 484 |
+
self.xlstm_bwd = xLSTMBlockStack(cfg)
|
| 485 |
+
|
| 486 |
+
self.initialize_weights()
|
| 487 |
+
|
| 488 |
+
def initialize_weights(self):
|
| 489 |
+
def _basic_init(module):
|
| 490 |
+
if isinstance(module, nn.Linear):
|
| 491 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 492 |
+
if module.bias is not None:
|
| 493 |
+
nn.init.constant_(module.bias, 0)
|
| 494 |
+
|
| 495 |
+
self.apply(_basic_init)
|
| 496 |
+
|
| 497 |
+
def forward(self,
|
| 498 |
+
input_tensor,
|
| 499 |
+
track_mask=None):
|
| 500 |
+
""" Updating with Transformer
|
| 501 |
+
|
| 502 |
+
Args:
|
| 503 |
+
input_tensor: B, N, T, C
|
| 504 |
+
arap_embed: B, N, T, C
|
| 505 |
+
"""
|
| 506 |
+
B, N, T, C = input_tensor.shape
|
| 507 |
+
x = self.input_transform(input_tensor)
|
| 508 |
+
|
| 509 |
+
track_mask = track_mask.permute(0,2,1,3).float()
|
| 510 |
+
fwd_x = x*track_mask
|
| 511 |
+
bwd_x = x.flip(2)*track_mask.flip(2)
|
| 512 |
+
feat_fwd = self.xlstm_fwd(self.norm(fwd_x.view(B*N, T, -1)))
|
| 513 |
+
feat_bwd = self.xlstm_bwd(self.norm(bwd_x.view(B*N, T, -1)))
|
| 514 |
+
feat = (feat_bwd.flip(1) + feat_fwd).view(B, N, T, -1)
|
| 515 |
+
|
| 516 |
+
flow = self.flow_head(feat)
|
| 517 |
+
|
| 518 |
+
return flow[..., :2], flow[..., 2:]
|
| 519 |
+
|
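# Illustrative usage sketch (shapes assumed; instantiating the xLSTM stacks
# requires the `xlstm` package and, for the sLSTM block configured above with
# backend="cuda", a CUDA device). The forward and time-reversed stacks are
# fused by flipping the backward features back before the flow head:
#
#   model = EUpdateFormer(input_dim=320, hidden_size=384, output_dim=130)
#   x = torch.randn(1, 256, 50, 320)       # (B, N_tracks, T, C_in), T = context_length
#   mask = torch.ones(1, 50, 256, 1)       # (B, T, N, 1) track visibility
#   delta_xy, delta_feat = model(x, track_mask=mask)
#   # delta_xy: (1, 256, 50, 2) 2-D updates; delta_feat: the remaining channels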
models/SpaTrackV2/models/camera_transform.py
ADDED
|
@@ -0,0 +1,248 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Adapted from https://github.com/amyxlase/relpose-plus-plus
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import numpy as np
|
| 12 |
+
import math
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def bbox_xyxy_to_xywh(xyxy):
|
| 18 |
+
wh = xyxy[2:] - xyxy[:2]
|
| 19 |
+
xywh = np.concatenate([xyxy[:2], wh])
|
| 20 |
+
return xywh
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def adjust_camera_to_bbox_crop_(fl, pp, image_size_wh: torch.Tensor, clamp_bbox_xywh: torch.Tensor):
|
| 24 |
+
focal_length_px, principal_point_px = _convert_ndc_to_pixels(fl, pp, image_size_wh)
|
| 25 |
+
|
| 26 |
+
principal_point_px_cropped = principal_point_px - clamp_bbox_xywh[:2]
|
| 27 |
+
|
| 28 |
+
focal_length, principal_point_cropped = _convert_pixels_to_ndc(
|
| 29 |
+
focal_length_px, principal_point_px_cropped, clamp_bbox_xywh[2:]
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
return focal_length, principal_point_cropped
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def adjust_camera_to_image_scale_(fl, pp, original_size_wh: torch.Tensor, new_size_wh: torch.LongTensor):
|
| 36 |
+
focal_length_px, principal_point_px = _convert_ndc_to_pixels(fl, pp, original_size_wh)
|
| 37 |
+
|
| 38 |
+
# now scale and convert from pixels to NDC
|
| 39 |
+
image_size_wh_output = new_size_wh.float()
|
| 40 |
+
scale = (image_size_wh_output / original_size_wh).min(dim=-1, keepdim=True).values
|
| 41 |
+
focal_length_px_scaled = focal_length_px * scale
|
| 42 |
+
principal_point_px_scaled = principal_point_px * scale
|
| 43 |
+
|
| 44 |
+
focal_length_scaled, principal_point_scaled = _convert_pixels_to_ndc(
|
| 45 |
+
focal_length_px_scaled, principal_point_px_scaled, image_size_wh_output
|
| 46 |
+
)
|
| 47 |
+
return focal_length_scaled, principal_point_scaled
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _convert_ndc_to_pixels(focal_length: torch.Tensor, principal_point: torch.Tensor, image_size_wh: torch.Tensor):
    half_image_size = image_size_wh / 2
    rescale = half_image_size.min()
    principal_point_px = half_image_size - principal_point * rescale
    focal_length_px = focal_length * rescale
    return focal_length_px, principal_point_px


def _convert_pixels_to_ndc(
    focal_length_px: torch.Tensor, principal_point_px: torch.Tensor, image_size_wh: torch.Tensor
):
    half_image_size = image_size_wh / 2
    rescale = half_image_size.min()
    principal_point = (half_image_size - principal_point_px) / rescale
    focal_length = focal_length_px / rescale
    return focal_length, principal_point

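# Illustrative check (values assumed): the two helpers above are inverses
# under the min-half-size NDC convention used here.
#
#   image_size_wh = torch.tensor([640.0, 480.0])
#   fl_px, pp_px = _convert_ndc_to_pixels(torch.tensor([2.0, 2.0]),
#                                         torch.tensor([0.1, -0.05]),
#                                         image_size_wh)
#   # fl_px = (480, 480), pp_px = (296, 252)
#   fl_ndc, pp_ndc = _convert_pixels_to_ndc(fl_px, pp_px, image_size_wh)
#   # recovers (2.0, 2.0) and (0.1, -0.05)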
def normalize_cameras(
    cameras, compute_optical=True, first_camera=True, normalize_trans=True, scale=1.0, points=None, max_norm=False,
    pose_mode="C2W"
):
    """
    Normalizes cameras such that
    (1) the optical axes point to the origin and the average distance to the origin is 1
    (2) the first camera is the origin
    (3) the translation vector is normalized

    TODO: some transforms overlap with others. no need to do so many transforms
    Args:
        cameras (List[camera]).
    """
    # Let distance from first camera to origin be unit
    new_cameras = cameras.clone()
    scale = 1.0

    if compute_optical:
        new_cameras, points = compute_optical_transform(new_cameras, points=points)
    if first_camera:
        new_cameras, points = first_camera_transform(new_cameras, points=points, pose_mode=pose_mode)
    if normalize_trans:
        new_cameras, points, scale = normalize_translation(new_cameras,
                                                           points=points, max_norm=max_norm)
    return new_cameras, points, scale

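# Illustrative sketch (values assumed) of the first-camera step invoked above:
# for camera-to-world matrices M_i, first_camera_transform applies M_0^{-1} @ M_i,
# so camera 0 becomes the identity and every other pose is relative to it.
#
#   M = torch.eye(4).repeat(3, 1, 1)                # three C2W poses
#   M[1, :3, 3] = torch.tensor([1.0, 0.0, 0.0])
#   M[2, :3, 3] = torch.tensor([0.0, 2.0, 0.0])
#   M_rel = torch.linalg.inv(M[:1]) @ M
#   # M_rel[0] == I; M_rel[1], M_rel[2] keep their translations w.r.t. camera 0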
def compute_optical_transform(new_cameras, points=None):
|
| 97 |
+
"""
|
| 98 |
+
adapted from https://github.com/amyxlase/relpose-plus-plus
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
new_transform = new_cameras.get_world_to_view_transform()
|
| 102 |
+
p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(new_cameras)
|
| 103 |
+
t = Translate(p_intersect)
|
| 104 |
+
scale = dist.squeeze()[0]
|
| 105 |
+
|
| 106 |
+
if points is not None:
|
| 107 |
+
points = t.inverse().transform_points(points)
|
| 108 |
+
points = points / scale
|
| 109 |
+
|
| 110 |
+
# Degenerate case
|
| 111 |
+
if scale == 0:
|
| 112 |
+
scale = torch.norm(new_cameras.T, dim=(0, 1))
|
| 113 |
+
scale = torch.sqrt(scale)
|
| 114 |
+
new_cameras.T = new_cameras.T / scale
|
| 115 |
+
else:
|
| 116 |
+
new_matrix = t.compose(new_transform).get_matrix()
|
| 117 |
+
new_cameras.R = new_matrix[:, :3, :3]
|
| 118 |
+
new_cameras.T = new_matrix[:, 3, :3] / scale
|
| 119 |
+
|
| 120 |
+
return new_cameras, points
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def compute_optical_axis_intersection(cameras):
|
| 124 |
+
centers = cameras.get_camera_center()
|
| 125 |
+
principal_points = cameras.principal_point
|
| 126 |
+
|
| 127 |
+
one_vec = torch.ones((len(cameras), 1))
|
| 128 |
+
optical_axis = torch.cat((principal_points, one_vec), -1)
|
| 129 |
+
|
| 130 |
+
pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True)
|
| 131 |
+
|
| 132 |
+
pp2 = pp[torch.arange(pp.shape[0]), torch.arange(pp.shape[0])]
|
| 133 |
+
|
| 134 |
+
directions = pp2 - centers
|
| 135 |
+
centers = centers.unsqueeze(0).unsqueeze(0)
|
| 136 |
+
directions = directions.unsqueeze(0).unsqueeze(0)
|
| 137 |
+
|
| 138 |
+
p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(p=centers, r=directions, mask=None)
|
| 139 |
+
|
| 140 |
+
p_intersect = p_intersect.squeeze().unsqueeze(0)
|
| 141 |
+
dist = (p_intersect - centers).norm(dim=-1)
|
| 142 |
+
|
| 143 |
+
return p_intersect, dist, p_line_intersect, pp2, r
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def intersect_skew_line_groups(p, r, mask):
|
| 147 |
+
# p, r both of shape (B, N, n_intersected_lines, 3)
|
| 148 |
+
# mask of shape (B, N, n_intersected_lines)
|
| 149 |
+
p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
|
| 150 |
+
_, p_line_intersect = _point_line_distance(p, r, p_intersect[..., None, :].expand_as(p))
|
| 151 |
+
intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(dim=-1)
|
| 152 |
+
return p_intersect, p_line_intersect, intersect_dist_squared, r
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def intersect_skew_lines_high_dim(p, r, mask=None):
|
| 156 |
+
# Implements https://en.wikipedia.org/wiki/Skew_lines In more than two dimensions
|
| 157 |
+
dim = p.shape[-1]
|
| 158 |
+
# make sure the heading vectors are l2-normed
|
| 159 |
+
if mask is None:
|
| 160 |
+
mask = torch.ones_like(p[..., 0])
|
| 161 |
+
r = torch.nn.functional.normalize(r, dim=-1)
|
| 162 |
+
|
| 163 |
+
eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
|
| 164 |
+
I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
|
| 165 |
+
sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
|
| 166 |
+
p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
|
| 167 |
+
|
| 168 |
+
if torch.any(torch.isnan(p_intersect)):
|
| 169 |
+
print(p_intersect)
|
| 170 |
+
raise ValueError(f"p_intersect is NaN")
|
| 171 |
+
|
| 172 |
+
return p_intersect, r
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
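# Illustrative check (values assumed): intersect_skew_lines_high_dim solves the
# normal equations (sum_i (I - r_i r_i^T)) p = sum_i (I - r_i r_i^T) p_i, i.e.
# the least-squares point closest to all lines. For two lines that actually
# intersect, that point is the intersection:
#
#   p = torch.tensor([[[[0.0, 0.0, 0.0], [1.0, 1.0, 0.0]]]])   # (B, N, n_lines, 3)
#   r = torch.tensor([[[[1.0, 1.0, 0.0], [1.0, -1.0, 0.0]]]])  # line directions
#   p_star, _ = intersect_skew_lines_high_dim(p, r)
#   # p_star ~= (1, 1, 0), the common point of the two lines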
def _point_line_distance(p1, r1, p2):
|
| 176 |
+
df = p2 - p1
|
| 177 |
+
proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
|
| 178 |
+
line_pt_nearest = p2 - proj_vector
|
| 179 |
+
d = (proj_vector).norm(dim=-1)
|
| 180 |
+
return d, line_pt_nearest
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def first_camera_transform(cameras, rotation_only=False,
|
| 184 |
+
points=None, pose_mode="C2W"):
|
| 185 |
+
"""
|
| 186 |
+
Transform so that the first camera is the origin
|
| 187 |
+
"""
|
| 188 |
+
|
| 189 |
+
new_cameras = cameras.clone()
|
| 190 |
+
# new_transform = new_cameras.get_world_to_view_transform()
|
| 191 |
+
|
| 192 |
+
R = cameras.R
|
| 193 |
+
T = cameras.T
|
| 194 |
+
Tran_M = torch.cat([R, T.unsqueeze(-1)], dim=-1) # [B, 3, 4]
|
| 195 |
+
Tran_M = torch.cat([Tran_M,
|
| 196 |
+
torch.tensor([[[0, 0, 0, 1]]], device=Tran_M.device).expand(Tran_M.shape[0], -1, -1)], dim=1)
|
| 197 |
+
if pose_mode == "C2W":
|
| 198 |
+
Tran_M_new = (Tran_M[:1,...].inverse())@Tran_M
|
| 199 |
+
elif pose_mode == "W2C":
|
| 200 |
+
Tran_M_new = Tran_M@(Tran_M[:1,...].inverse())
|
| 201 |
+
|
| 202 |
+
if False:
|
| 203 |
+
tR = Rotate(new_cameras.R[0].unsqueeze(0))
|
| 204 |
+
if rotation_only:
|
| 205 |
+
t = tR.inverse()
|
| 206 |
+
else:
|
| 207 |
+
tT = Translate(new_cameras.T[0].unsqueeze(0))
|
| 208 |
+
t = tR.compose(tT).inverse()
|
| 209 |
+
|
| 210 |
+
if points is not None:
|
| 211 |
+
points = t.inverse().transform_points(points)
|
| 212 |
+
|
| 213 |
+
if pose_mode == "C2W":
|
| 214 |
+
new_matrix = new_transform.compose(t).get_matrix()
|
| 215 |
+
else:
|
| 216 |
+
import ipdb; ipdb.set_trace()
|
| 217 |
+
new_matrix = t.compose(new_transform).get_matrix()
|
| 218 |
+
|
| 219 |
+
new_cameras.R = Tran_M_new[:, :3, :3]
|
| 220 |
+
new_cameras.T = Tran_M_new[:, :3, 3]
|
| 221 |
+
|
| 222 |
+
return new_cameras, points
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def normalize_translation(new_cameras, points=None, max_norm=False):
|
| 226 |
+
t_gt = new_cameras.T.clone()
|
| 227 |
+
t_gt = t_gt[1:, :]
|
| 228 |
+
|
| 229 |
+
if max_norm:
|
| 230 |
+
t_gt_norm = torch.norm(t_gt, dim=(-1))
|
| 231 |
+
t_gt_scale = t_gt_norm.max()
|
| 232 |
+
if t_gt_norm.max() < 0.001:
|
| 233 |
+
t_gt_scale = torch.ones_like(t_gt_scale)
|
| 234 |
+
t_gt_scale = t_gt_scale.clamp(min=0.01, max=1e5)
|
| 235 |
+
else:
|
| 236 |
+
t_gt_norm = torch.norm(t_gt, dim=(0, 1))
|
| 237 |
+
t_gt_scale = t_gt_norm / math.sqrt(len(t_gt))
|
| 238 |
+
t_gt_scale = t_gt_scale / 2
|
| 239 |
+
if t_gt_norm.max() < 0.001:
|
| 240 |
+
t_gt_scale = torch.ones_like(t_gt_scale)
|
| 241 |
+
t_gt_scale = t_gt_scale.clamp(min=0.01, max=1e5)
|
| 242 |
+
|
| 243 |
+
new_cameras.T = new_cameras.T / t_gt_scale
|
| 244 |
+
|
| 245 |
+
if points is not None:
|
| 246 |
+
points = points / t_gt_scale
|
| 247 |
+
|
| 248 |
+
return new_cameras, points, t_gt_scale
|
models/SpaTrackV2/models/depth_refiner/backbone.py
ADDED
|
@@ -0,0 +1,472 @@
| 1 |
+
# ---------------------------------------------------------------
|
| 2 |
+
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This work is licensed under the NVIDIA Source Code License
|
| 5 |
+
# ---------------------------------------------------------------
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from functools import partial
|
| 10 |
+
|
| 11 |
+
from timm.layers import DropPath, to_2tuple, trunc_normal_
|
| 12 |
+
from timm.models import register_model
|
| 13 |
+
from timm.models.vision_transformer import _cfg
|
| 14 |
+
import math
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Mlp(nn.Module):
|
| 18 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
| 19 |
+
super().__init__()
|
| 20 |
+
out_features = out_features or in_features
|
| 21 |
+
hidden_features = hidden_features or in_features
|
| 22 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 23 |
+
self.dwconv = DWConv(hidden_features)
|
| 24 |
+
self.act = act_layer()
|
| 25 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 26 |
+
self.drop = nn.Dropout(drop)
|
| 27 |
+
|
| 28 |
+
self.apply(self._init_weights)
|
| 29 |
+
|
| 30 |
+
def _init_weights(self, m):
|
| 31 |
+
if isinstance(m, nn.Linear):
|
| 32 |
+
trunc_normal_(m.weight, std=.02)
|
| 33 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 34 |
+
nn.init.constant_(m.bias, 0)
|
| 35 |
+
elif isinstance(m, nn.LayerNorm):
|
| 36 |
+
nn.init.constant_(m.bias, 0)
|
| 37 |
+
nn.init.constant_(m.weight, 1.0)
|
| 38 |
+
elif isinstance(m, nn.Conv2d):
|
| 39 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 40 |
+
fan_out //= m.groups
|
| 41 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
| 42 |
+
if m.bias is not None:
|
| 43 |
+
m.bias.data.zero_()
|
| 44 |
+
|
| 45 |
+
def forward(self, x, H, W):
|
| 46 |
+
x = self.fc1(x)
|
| 47 |
+
x = self.dwconv(x, H, W)
|
| 48 |
+
x = self.act(x)
|
| 49 |
+
x = self.drop(x)
|
| 50 |
+
x = self.fc2(x)
|
| 51 |
+
x = self.drop(x)
|
| 52 |
+
return x
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Attention(nn.Module):
|
| 56 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
|
| 57 |
+
super().__init__()
|
| 58 |
+
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
|
| 59 |
+
|
| 60 |
+
self.dim = dim
|
| 61 |
+
self.num_heads = num_heads
|
| 62 |
+
head_dim = dim // num_heads
|
| 63 |
+
self.scale = qk_scale or head_dim ** -0.5
|
| 64 |
+
|
| 65 |
+
self.q = nn.Linear(dim, dim, bias=qkv_bias)
|
| 66 |
+
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
|
| 67 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 68 |
+
self.proj = nn.Linear(dim, dim)
|
| 69 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 70 |
+
|
| 71 |
+
self.sr_ratio = sr_ratio
|
| 72 |
+
if sr_ratio > 1:
|
| 73 |
+
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
|
| 74 |
+
self.norm = nn.LayerNorm(dim)
|
| 75 |
+
|
| 76 |
+
self.apply(self._init_weights)
|
| 77 |
+
|
| 78 |
+
def _init_weights(self, m):
|
| 79 |
+
if isinstance(m, nn.Linear):
|
| 80 |
+
trunc_normal_(m.weight, std=.02)
|
| 81 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 82 |
+
nn.init.constant_(m.bias, 0)
|
| 83 |
+
elif isinstance(m, nn.LayerNorm):
|
| 84 |
+
nn.init.constant_(m.bias, 0)
|
| 85 |
+
nn.init.constant_(m.weight, 1.0)
|
| 86 |
+
elif isinstance(m, nn.Conv2d):
|
| 87 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 88 |
+
fan_out //= m.groups
|
| 89 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
| 90 |
+
if m.bias is not None:
|
| 91 |
+
m.bias.data.zero_()
|
| 92 |
+
|
| 93 |
+
def forward(self, x, H, W):
|
| 94 |
+
B, N, C = x.shape
|
| 95 |
+
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
| 96 |
+
|
| 97 |
+
if self.sr_ratio > 1:
|
| 98 |
+
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
|
| 99 |
+
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
|
| 100 |
+
x_ = self.norm(x_)
|
| 101 |
+
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 102 |
+
else:
|
| 103 |
+
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 104 |
+
k, v = kv[0], kv[1]
|
| 105 |
+
|
| 106 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
| 107 |
+
attn = attn.softmax(dim=-1)
|
| 108 |
+
attn = self.attn_drop(attn)
|
| 109 |
+
|
| 110 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 111 |
+
x = self.proj(x)
|
| 112 |
+
x = self.proj_drop(x)
|
| 113 |
+
|
| 114 |
+
return x
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
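# Illustrative note (shapes assumed): with sr_ratio > 1 the keys/values are
# computed from a spatially reduced map, cutting attention cost roughly by
# sr_ratio**2 while queries keep full resolution.
#
#   attn = Attention(dim=64, num_heads=1, sr_ratio=8)
#   x = torch.randn(2, 56 * 56, 64)      # tokens of a 56x56 feature map
#   y = attn(x, H=56, W=56)              # keys/values come from a 7x7 map (49 tokens)
#   # y: (2, 3136, 64)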
class Block(nn.Module):
|
| 118 |
+
|
| 119 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
| 120 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
|
| 121 |
+
super().__init__()
|
| 122 |
+
self.norm1 = norm_layer(dim)
|
| 123 |
+
self.attn = Attention(
|
| 124 |
+
dim,
|
| 125 |
+
num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 126 |
+
attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
|
| 127 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
| 128 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 129 |
+
self.norm2 = norm_layer(dim)
|
| 130 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 131 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 132 |
+
|
| 133 |
+
self.apply(self._init_weights)
|
| 134 |
+
|
| 135 |
+
def _init_weights(self, m):
|
| 136 |
+
if isinstance(m, nn.Linear):
|
| 137 |
+
trunc_normal_(m.weight, std=.02)
|
| 138 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 139 |
+
nn.init.constant_(m.bias, 0)
|
| 140 |
+
elif isinstance(m, nn.LayerNorm):
|
| 141 |
+
nn.init.constant_(m.bias, 0)
|
| 142 |
+
nn.init.constant_(m.weight, 1.0)
|
| 143 |
+
elif isinstance(m, nn.Conv2d):
|
| 144 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 145 |
+
fan_out //= m.groups
|
| 146 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
| 147 |
+
if m.bias is not None:
|
| 148 |
+
m.bias.data.zero_()
|
| 149 |
+
|
| 150 |
+
def forward(self, x, H, W):
|
| 151 |
+
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
|
| 152 |
+
x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
|
| 153 |
+
|
| 154 |
+
return x
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class OverlapPatchEmbed(nn.Module):
|
| 158 |
+
""" Image to Patch Embedding
|
| 159 |
+
"""
|
| 160 |
+
|
| 161 |
+
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
|
| 162 |
+
super().__init__()
|
| 163 |
+
img_size = to_2tuple(img_size)
|
| 164 |
+
patch_size = to_2tuple(patch_size)
|
| 165 |
+
|
| 166 |
+
self.img_size = img_size
|
| 167 |
+
self.patch_size = patch_size
|
| 168 |
+
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
|
| 169 |
+
self.num_patches = self.H * self.W
|
| 170 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
|
| 171 |
+
padding=(patch_size[0] // 2, patch_size[1] // 2))
|
| 172 |
+
self.norm = nn.LayerNorm(embed_dim)
|
| 173 |
+
|
| 174 |
+
self.apply(self._init_weights)
|
| 175 |
+
|
| 176 |
+
def _init_weights(self, m):
|
| 177 |
+
if isinstance(m, nn.Linear):
|
| 178 |
+
trunc_normal_(m.weight, std=.02)
|
| 179 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 180 |
+
nn.init.constant_(m.bias, 0)
|
| 181 |
+
elif isinstance(m, nn.LayerNorm):
|
| 182 |
+
nn.init.constant_(m.bias, 0)
|
| 183 |
+
nn.init.constant_(m.weight, 1.0)
|
| 184 |
+
elif isinstance(m, nn.Conv2d):
|
| 185 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 186 |
+
fan_out //= m.groups
|
| 187 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
| 188 |
+
if m.bias is not None:
|
| 189 |
+
m.bias.data.zero_()
|
| 190 |
+
|
| 191 |
+
def forward(self, x):
|
| 192 |
+
x = self.proj(x)
|
| 193 |
+
_, _, H, W = x.shape
|
| 194 |
+
x = x.flatten(2).transpose(1, 2)
|
| 195 |
+
x = self.norm(x)
|
| 196 |
+
|
| 197 |
+
return x, H, W
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class OverlapPatchEmbed43(nn.Module):
|
| 203 |
+
""" Image to Patch Embedding
|
| 204 |
+
"""
|
| 205 |
+
|
| 206 |
+
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
|
| 207 |
+
super().__init__()
|
| 208 |
+
img_size = to_2tuple(img_size)
|
| 209 |
+
patch_size = to_2tuple(patch_size)
|
| 210 |
+
|
| 211 |
+
self.img_size = img_size
|
| 212 |
+
self.patch_size = patch_size
|
| 213 |
+
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
|
| 214 |
+
self.num_patches = self.H * self.W
|
| 215 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
|
| 216 |
+
padding=(patch_size[0] // 2, patch_size[1] // 2))
|
| 217 |
+
self.norm = nn.LayerNorm(embed_dim)
|
| 218 |
+
|
| 219 |
+
self.apply(self._init_weights)
|
| 220 |
+
|
| 221 |
+
def _init_weights(self, m):
|
| 222 |
+
if isinstance(m, nn.Linear):
|
| 223 |
+
trunc_normal_(m.weight, std=.02)
|
| 224 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 225 |
+
nn.init.constant_(m.bias, 0)
|
| 226 |
+
elif isinstance(m, nn.LayerNorm):
|
| 227 |
+
nn.init.constant_(m.bias, 0)
|
| 228 |
+
nn.init.constant_(m.weight, 1.0)
|
| 229 |
+
elif isinstance(m, nn.Conv2d):
|
| 230 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 231 |
+
fan_out //= m.groups
|
| 232 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
| 233 |
+
if m.bias is not None:
|
| 234 |
+
m.bias.data.zero_()
|
| 235 |
+
|
| 236 |
+
def forward(self, x):
|
| 237 |
+
if x.shape[1]==4:
|
| 238 |
+
x = self.proj_4c(x)
|
| 239 |
+
else:
|
| 240 |
+
x = self.proj(x)
|
| 241 |
+
_, _, H, W = x.shape
|
| 242 |
+
x = x.flatten(2).transpose(1, 2)
|
| 243 |
+
x = self.norm(x)
|
| 244 |
+
|
| 245 |
+
return x, H, W
|
| 246 |
+
|
| 247 |
+
class MixVisionTransformer(nn.Module):
|
| 248 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
|
| 249 |
+
num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
|
| 250 |
+
attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
|
| 251 |
+
depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
|
| 252 |
+
super().__init__()
|
| 253 |
+
self.num_classes = num_classes
|
| 254 |
+
self.depths = depths
|
| 255 |
+
|
| 256 |
+
# patch_embed 43
|
| 257 |
+
self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
|
| 258 |
+
embed_dim=embed_dims[0])
|
| 259 |
+
self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
|
| 260 |
+
embed_dim=embed_dims[1])
|
| 261 |
+
self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
|
| 262 |
+
embed_dim=embed_dims[2])
|
| 263 |
+
self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
|
| 264 |
+
embed_dim=embed_dims[3])
|
| 265 |
+
|
| 266 |
+
# transformer encoder
|
| 267 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
| 268 |
+
cur = 0
|
| 269 |
+
self.block1 = nn.ModuleList([Block(
|
| 270 |
+
dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 271 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
|
| 272 |
+
sr_ratio=sr_ratios[0])
|
| 273 |
+
for i in range(depths[0])])
|
| 274 |
+
self.norm1 = norm_layer(embed_dims[0])
|
| 275 |
+
|
| 276 |
+
cur += depths[0]
|
| 277 |
+
self.block2 = nn.ModuleList([Block(
|
| 278 |
+
dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 279 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
|
| 280 |
+
sr_ratio=sr_ratios[1])
|
| 281 |
+
for i in range(depths[1])])
|
| 282 |
+
self.norm2 = norm_layer(embed_dims[1])
|
| 283 |
+
|
| 284 |
+
cur += depths[1]
|
| 285 |
+
self.block3 = nn.ModuleList([Block(
|
| 286 |
+
dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 287 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
|
| 288 |
+
sr_ratio=sr_ratios[2])
|
| 289 |
+
for i in range(depths[2])])
|
| 290 |
+
self.norm3 = norm_layer(embed_dims[2])
|
| 291 |
+
|
| 292 |
+
cur += depths[2]
|
| 293 |
+
self.block4 = nn.ModuleList([Block(
|
| 294 |
+
dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 295 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
|
| 296 |
+
sr_ratio=sr_ratios[3])
|
| 297 |
+
for i in range(depths[3])])
|
| 298 |
+
self.norm4 = norm_layer(embed_dims[3])
|
| 299 |
+
|
| 300 |
+
# classification head
|
| 301 |
+
# self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
|
| 302 |
+
|
| 303 |
+
self.apply(self._init_weights)
|
| 304 |
+
|
| 305 |
+
def _init_weights(self, m):
|
| 306 |
+
if isinstance(m, nn.Linear):
|
| 307 |
+
trunc_normal_(m.weight, std=.02)
|
| 308 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 309 |
+
nn.init.constant_(m.bias, 0)
|
| 310 |
+
elif isinstance(m, nn.LayerNorm):
|
| 311 |
+
nn.init.constant_(m.bias, 0)
|
| 312 |
+
nn.init.constant_(m.weight, 1.0)
|
| 313 |
+
elif isinstance(m, nn.Conv2d):
|
| 314 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 315 |
+
fan_out //= m.groups
|
| 316 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
| 317 |
+
if m.bias is not None:
|
| 318 |
+
m.bias.data.zero_()
|
| 319 |
+
|
| 320 |
+
def init_weights(self, pretrained=None):
|
| 321 |
+
if isinstance(pretrained, str):
|
| 322 |
+
logger = get_root_logger()
|
| 323 |
+
load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
|
| 324 |
+
|
| 325 |
+
def reset_drop_path(self, drop_path_rate):
|
| 326 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
|
| 327 |
+
cur = 0
|
| 328 |
+
for i in range(self.depths[0]):
|
| 329 |
+
self.block1[i].drop_path.drop_prob = dpr[cur + i]
|
| 330 |
+
|
| 331 |
+
cur += self.depths[0]
|
| 332 |
+
for i in range(self.depths[1]):
|
| 333 |
+
self.block2[i].drop_path.drop_prob = dpr[cur + i]
|
| 334 |
+
|
| 335 |
+
cur += self.depths[1]
|
| 336 |
+
for i in range(self.depths[2]):
|
| 337 |
+
self.block3[i].drop_path.drop_prob = dpr[cur + i]
|
| 338 |
+
|
| 339 |
+
cur += self.depths[2]
|
| 340 |
+
for i in range(self.depths[3]):
|
| 341 |
+
self.block4[i].drop_path.drop_prob = dpr[cur + i]
|
| 342 |
+
|
| 343 |
+
def freeze_patch_emb(self):
|
| 344 |
+
self.patch_embed1.requires_grad = False
|
| 345 |
+
|
| 346 |
+
@torch.jit.ignore
|
| 347 |
+
def no_weight_decay(self):
|
| 348 |
+
return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better
|
| 349 |
+
|
| 350 |
+
def get_classifier(self):
|
| 351 |
+
return self.head
|
| 352 |
+
|
| 353 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
| 354 |
+
self.num_classes = num_classes
|
| 355 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
| 356 |
+
|
| 357 |
+
def forward_features(self, x):
|
| 358 |
+
B = x.shape[0]
|
| 359 |
+
outs = []
|
| 360 |
+
|
| 361 |
+
# stage 1
|
| 362 |
+
x, H, W = self.patch_embed1(x)
|
| 363 |
+
for i, blk in enumerate(self.block1):
|
| 364 |
+
x = blk(x, H, W)
|
| 365 |
+
x = self.norm1(x)
|
| 366 |
+
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
| 367 |
+
outs.append(x)
|
| 368 |
+
|
| 369 |
+
# stage 2
|
| 370 |
+
x, H, W = self.patch_embed2(x)
|
| 371 |
+
for i, blk in enumerate(self.block2):
|
| 372 |
+
x = blk(x, H, W)
|
| 373 |
+
x = self.norm2(x)
|
| 374 |
+
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
| 375 |
+
outs.append(x)
|
| 376 |
+
|
| 377 |
+
# stage 3
|
| 378 |
+
x, H, W = self.patch_embed3(x)
|
| 379 |
+
for i, blk in enumerate(self.block3):
|
| 380 |
+
x = blk(x, H, W)
|
| 381 |
+
x = self.norm3(x)
|
| 382 |
+
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
| 383 |
+
outs.append(x)
|
| 384 |
+
|
| 385 |
+
# stage 4
|
| 386 |
+
x, H, W = self.patch_embed4(x)
|
| 387 |
+
for i, blk in enumerate(self.block4):
|
| 388 |
+
x = blk(x, H, W)
|
| 389 |
+
x = self.norm4(x)
|
| 390 |
+
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
| 391 |
+
outs.append(x)
|
| 392 |
+
|
| 393 |
+
return outs
|
| 394 |
+
|
| 395 |
+
def forward(self, x):
|
| 396 |
+
if x.dim() == 5:
|
| 397 |
+
x = x.reshape(x.shape[0]*x.shape[1],x.shape[2],x.shape[3],x.shape[4])
|
| 398 |
+
x = self.forward_features(x)
|
| 399 |
+
# x = self.head(x)
|
| 400 |
+
|
| 401 |
+
return x
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)

        return x


#@BACKBONES.register_module()
|
| 420 |
+
class mit_b0(MixVisionTransformer):
|
| 421 |
+
def __init__(self, **kwargs):
|
| 422 |
+
super(mit_b0, self).__init__(
|
| 423 |
+
patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
| 424 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
|
| 425 |
+
drop_rate=0.0, drop_path_rate=0.1)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
#@BACKBONES.register_module()
|
| 429 |
+
class mit_b1(MixVisionTransformer):
|
| 430 |
+
def __init__(self, **kwargs):
|
| 431 |
+
super(mit_b1, self).__init__(
|
| 432 |
+
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
| 433 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
|
| 434 |
+
drop_rate=0.0, drop_path_rate=0.1)
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
#@BACKBONES.register_module()
|
| 438 |
+
class mit_b2(MixVisionTransformer):
|
| 439 |
+
def __init__(self, **kwargs):
|
| 440 |
+
super(mit_b2, self).__init__(
|
| 441 |
+
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
| 442 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
|
| 443 |
+
drop_rate=0.0, drop_path_rate=0.1)
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
#@BACKBONES.register_module()
|
| 447 |
+
class mit_b3(MixVisionTransformer):
|
| 448 |
+
def __init__(self, **kwargs):
|
| 449 |
+
super(mit_b3, self).__init__(
|
| 450 |
+
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
| 451 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
|
| 452 |
+
drop_rate=0.0, drop_path_rate=0.1)
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
#@BACKBONES.register_module()
|
| 456 |
+
class mit_b4(MixVisionTransformer):
|
| 457 |
+
def __init__(self, **kwargs):
|
| 458 |
+
super(mit_b4, self).__init__(
|
| 459 |
+
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
| 460 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
|
| 461 |
+
drop_rate=0.0, drop_path_rate=0.1)
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
#@BACKBONES.register_module()
|
| 465 |
+
class mit_b5(MixVisionTransformer):
|
| 466 |
+
def __init__(self, **kwargs):
|
| 467 |
+
super(mit_b5, self).__init__(
|
| 468 |
+
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
| 469 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
|
| 470 |
+
drop_rate=0.0, drop_path_rate=0.1)
|
| 471 |
+
|
| 472 |
+
|
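# Illustrative usage sketch (shapes assumed for a 224x224 input): every mit_bX
# variant above returns a 4-level feature pyramid at strides 4 / 8 / 16 / 32.
#
#   backbone = mit_b2()
#   feats = backbone(torch.randn(1, 3, 224, 224))
#   # feats[0]: (1, 64, 56, 56)    stride 4
#   # feats[1]: (1, 128, 28, 28)   stride 8
#   # feats[2]: (1, 320, 14, 14)   stride 16
#   # feats[3]: (1, 512, 7, 7)     stride 32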
models/SpaTrackV2/models/depth_refiner/decode_head.py
ADDED
|
@@ -0,0 +1,619 @@
| 1 |
+
from abc import ABCMeta, abstractmethod
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
# from mmcv.cnn import normal_init
|
| 7 |
+
# from mmcv.runner import auto_fp16, force_fp32
|
| 8 |
+
|
| 9 |
+
# from mmseg.core import build_pixel_sampler
|
| 10 |
+
# from mmseg.ops import resize
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
|
| 14 |
+
"""Base class for BaseDecodeHead.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
in_channels (int|Sequence[int]): Input channels.
|
| 18 |
+
channels (int): Channels after modules, before conv_seg.
|
| 19 |
+
num_classes (int): Number of classes.
|
| 20 |
+
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
|
| 21 |
+
conv_cfg (dict|None): Config of conv layers. Default: None.
|
| 22 |
+
norm_cfg (dict|None): Config of norm layers. Default: None.
|
| 23 |
+
act_cfg (dict): Config of activation layers.
|
| 24 |
+
Default: dict(type='ReLU')
|
| 25 |
+
in_index (int|Sequence[int]): Input feature index. Default: -1
|
| 26 |
+
input_transform (str|None): Transformation type of input features.
|
| 27 |
+
Options: 'resize_concat', 'multiple_select', None.
|
| 28 |
+
'resize_concat': Multiple feature maps will be resize to the
|
| 29 |
+
same size as first one and than concat together.
|
| 30 |
+
Usually used in FCN head of HRNet.
|
| 31 |
+
'multiple_select': Multiple feature maps will be bundle into
|
| 32 |
+
a list and passed into decode head.
|
| 33 |
+
None: Only one select feature map is allowed.
|
| 34 |
+
Default: None.
|
| 35 |
+
loss_decode (dict): Config of decode loss.
|
| 36 |
+
Default: dict(type='CrossEntropyLoss').
|
| 37 |
+
ignore_index (int | None): The label index to be ignored. When using
|
| 38 |
+
masked BCE loss, ignore_index should be set to None. Default: 255
|
| 39 |
+
sampler (dict|None): The config of segmentation map sampler.
|
| 40 |
+
Default: None.
|
| 41 |
+
align_corners (bool): align_corners argument of F.interpolate.
|
| 42 |
+
Default: False.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(self,
|
| 46 |
+
in_channels,
|
| 47 |
+
channels,
|
| 48 |
+
*,
|
| 49 |
+
num_classes,
|
| 50 |
+
dropout_ratio=0.1,
|
| 51 |
+
conv_cfg=None,
|
| 52 |
+
norm_cfg=None,
|
| 53 |
+
act_cfg=dict(type='ReLU'),
|
| 54 |
+
in_index=-1,
|
| 55 |
+
input_transform=None,
|
| 56 |
+
loss_decode=dict(
|
| 57 |
+
type='CrossEntropyLoss',
|
| 58 |
+
use_sigmoid=False,
|
| 59 |
+
loss_weight=1.0),
|
| 60 |
+
decoder_params=None,
|
| 61 |
+
ignore_index=255,
|
| 62 |
+
sampler=None,
|
| 63 |
+
align_corners=False):
|
| 64 |
+
super(BaseDecodeHead, self).__init__()
|
| 65 |
+
self._init_inputs(in_channels, in_index, input_transform)
|
| 66 |
+
self.channels = channels
|
| 67 |
+
self.num_classes = num_classes
|
| 68 |
+
self.dropout_ratio = dropout_ratio
|
| 69 |
+
self.conv_cfg = conv_cfg
|
| 70 |
+
self.norm_cfg = norm_cfg
|
| 71 |
+
self.act_cfg = act_cfg
|
| 72 |
+
self.in_index = in_index
|
| 73 |
+
self.ignore_index = ignore_index
|
| 74 |
+
self.align_corners = align_corners
|
| 75 |
+
|
| 76 |
+
if sampler is not None:
|
| 77 |
+
self.sampler = build_pixel_sampler(sampler, context=self)
|
| 78 |
+
else:
|
| 79 |
+
self.sampler = None
|
| 80 |
+
|
| 81 |
+
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
|
| 82 |
+
if dropout_ratio > 0:
|
| 83 |
+
self.dropout = nn.Dropout2d(dropout_ratio)
|
| 84 |
+
else:
|
| 85 |
+
self.dropout = None
|
| 86 |
+
self.fp16_enabled = False
|
| 87 |
+
|
| 88 |
+
def extra_repr(self):
|
| 89 |
+
"""Extra repr."""
|
| 90 |
+
s = f'input_transform={self.input_transform}, ' \
|
| 91 |
+
f'ignore_index={self.ignore_index}, ' \
|
| 92 |
+
f'align_corners={self.align_corners}'
|
| 93 |
+
return s
|
| 94 |
+
|
| 95 |
+
def _init_inputs(self, in_channels, in_index, input_transform):
|
| 96 |
+
"""Check and initialize input transforms.
|
| 97 |
+
|
| 98 |
+
The in_channels, in_index and input_transform must match.
|
| 99 |
+
Specifically, when input_transform is None, only single feature map
|
| 100 |
+
will be selected. So in_channels and in_index must be of type int.
|
| 101 |
+
When input_transform
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
in_channels (int|Sequence[int]): Input channels.
|
| 105 |
+
in_index (int|Sequence[int]): Input feature index.
|
| 106 |
+
input_transform (str|None): Transformation type of input features.
|
| 107 |
+
Options: 'resize_concat', 'multiple_select', None.
|
| 108 |
+
'resize_concat': Multiple feature maps will be resize to the
|
| 109 |
+
same size as first one and than concat together.
|
| 110 |
+
Usually used in FCN head of HRNet.
|
| 111 |
+
'multiple_select': Multiple feature maps will be bundle into
|
| 112 |
+
a list and passed into decode head.
|
| 113 |
+
None: Only one select feature map is allowed.
|
| 114 |
+
"""
|
| 115 |
+
|
| 116 |
+
if input_transform is not None:
|
| 117 |
+
assert input_transform in ['resize_concat', 'multiple_select']
|
| 118 |
+
self.input_transform = input_transform
|
| 119 |
+
self.in_index = in_index
|
| 120 |
+
if input_transform is not None:
|
| 121 |
+
assert isinstance(in_channels, (list, tuple))
|
| 122 |
+
assert isinstance(in_index, (list, tuple))
|
| 123 |
+
assert len(in_channels) == len(in_index)
|
| 124 |
+
if input_transform == 'resize_concat':
|
| 125 |
+
self.in_channels = sum(in_channels)
|
| 126 |
+
else:
|
| 127 |
+
self.in_channels = in_channels
|
| 128 |
+
else:
|
| 129 |
+
assert isinstance(in_channels, int)
|
| 130 |
+
assert isinstance(in_index, int)
|
| 131 |
+
self.in_channels = in_channels
|
| 132 |
+
|
| 133 |
+
def init_weights(self):
|
| 134 |
+
"""Initialize weights of classification layer."""
|
| 135 |
+
normal_init(self.conv_seg, mean=0, std=0.01)
|
| 136 |
+
|
| 137 |
+
def _transform_inputs(self, inputs):
|
| 138 |
+
"""Transform inputs for decoder.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
inputs (list[Tensor]): List of multi-level img features.
|
| 142 |
+
|
| 143 |
+
Returns:
|
| 144 |
+
Tensor: The transformed inputs
|
| 145 |
+
"""
|
| 146 |
+
|
| 147 |
+
if self.input_transform == 'resize_concat':
|
| 148 |
+
inputs = [inputs[i] for i in self.in_index]
|
| 149 |
+
upsampled_inputs = [
|
| 150 |
+
resize(
|
| 151 |
+
input=x,
|
| 152 |
+
size=inputs[0].shape[2:],
|
| 153 |
+
mode='bilinear',
|
| 154 |
+
align_corners=self.align_corners) for x in inputs
|
| 155 |
+
]
|
| 156 |
+
inputs = torch.cat(upsampled_inputs, dim=1)
|
| 157 |
+
elif self.input_transform == 'multiple_select':
|
| 158 |
+
inputs = [inputs[i] for i in self.in_index]
|
| 159 |
+
else:
|
| 160 |
+
inputs = inputs[self.in_index]
|
| 161 |
+
|
| 162 |
+
return inputs
|
| 163 |
+
|
| 164 |
+
# @auto_fp16()
|
| 165 |
+
@abstractmethod
|
| 166 |
+
def forward(self, inputs):
|
| 167 |
+
"""Placeholder of forward function."""
|
| 168 |
+
pass
|
| 169 |
+
|
| 170 |
+
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
|
| 171 |
+
"""Forward function for training.
|
| 172 |
+
Args:
|
| 173 |
+
inputs (list[Tensor]): List of multi-level img features.
|
| 174 |
+
img_metas (list[dict]): List of image info dict where each dict
|
| 175 |
+
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
| 176 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
| 177 |
+
For details on the values of these keys see
|
| 178 |
+
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
| 179 |
+
gt_semantic_seg (Tensor): Semantic segmentation masks
|
| 180 |
+
used if the architecture supports semantic segmentation task.
|
| 181 |
+
train_cfg (dict): The training config.
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
dict[str, Tensor]: a dictionary of loss components
|
| 185 |
+
"""
|
| 186 |
+
seg_logits = self.forward(inputs)
|
| 187 |
+
losses = self.losses(seg_logits, gt_semantic_seg)
|
| 188 |
+
return losses
|
| 189 |
+
|
| 190 |
+
def forward_test(self, inputs, img_metas, test_cfg):
|
| 191 |
+
"""Forward function for testing.
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
inputs (list[Tensor]): List of multi-level img features.
|
| 195 |
+
img_metas (list[dict]): List of image info dict where each dict
|
| 196 |
+
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
| 197 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
| 198 |
+
For details on the values of these keys see
|
| 199 |
+
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
| 200 |
+
test_cfg (dict): The testing config.
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
Tensor: Output segmentation map.
|
| 204 |
+
"""
|
| 205 |
+
return self.forward(inputs)
|
| 206 |
+
|
| 207 |
+
def cls_seg(self, feat):
|
| 208 |
+
"""Classify each pixel."""
|
| 209 |
+
if self.dropout is not None:
|
| 210 |
+
feat = self.dropout(feat)
|
| 211 |
+
output = self.conv_seg(feat)
|
| 212 |
+
return output
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
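# Illustrative sketch (shapes assumed): what the 'resize_concat' input
# transform amounts to, written with plain F.interpolate in place of mmseg's
# resize(): upsample every selected level to the first level's size, then
# concatenate along channels before cls_seg.
#
#   feats = [torch.randn(1, 64, 56, 56), torch.randn(1, 128, 28, 28),
#            torch.randn(1, 320, 14, 14)]
#   up = [F.interpolate(f, size=feats[0].shape[2:], mode='bilinear',
#                       align_corners=False) for f in feats]
#   fused = torch.cat(up, dim=1)          # (1, 64 + 128 + 320, 56, 56)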
class BaseDecodeHead_clips(nn.Module, metaclass=ABCMeta):
|
| 216 |
+
"""Base class for BaseDecodeHead_clips.
|
| 217 |
+
|
| 218 |
+
Args:
|
| 219 |
+
in_channels (int|Sequence[int]): Input channels.
|
| 220 |
+
channels (int): Channels after modules, before conv_seg.
|
| 221 |
+
num_classes (int): Number of classes.
|
| 222 |
+
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
|
| 223 |
+
conv_cfg (dict|None): Config of conv layers. Default: None.
|
| 224 |
+
norm_cfg (dict|None): Config of norm layers. Default: None.
|
| 225 |
+
act_cfg (dict): Config of activation layers.
|
| 226 |
+
Default: dict(type='ReLU')
|
| 227 |
+
in_index (int|Sequence[int]): Input feature index. Default: -1
|
| 228 |
+
input_transform (str|None): Transformation type of input features.
|
| 229 |
+
Options: 'resize_concat', 'multiple_select', None.
|
| 230 |
+
'resize_concat': Multiple feature maps will be resize to the
|
| 231 |
+
same size as first one and than concat together.
|
| 232 |
+
Usually used in FCN head of HRNet.
|
| 233 |
+
'multiple_select': Multiple feature maps will be bundle into
|
| 234 |
+
a list and passed into decode head.
|
| 235 |
+
None: Only one select feature map is allowed.
|
| 236 |
+
Default: None.
|
| 237 |
+
loss_decode (dict): Config of decode loss.
|
| 238 |
+
Default: dict(type='CrossEntropyLoss').
|
| 239 |
+
ignore_index (int | None): The label index to be ignored. When using
|
| 240 |
+
masked BCE loss, ignore_index should be set to None. Default: 255
|
| 241 |
+
sampler (dict|None): The config of segmentation map sampler.
|
| 242 |
+
Default: None.
|
| 243 |
+
align_corners (bool): align_corners argument of F.interpolate.
|
| 244 |
+
Default: False.
|
| 245 |
+
"""
|
| 246 |
+
|
| 247 |
+
def __init__(self,
|
| 248 |
+
in_channels,
|
| 249 |
+
channels,
|
| 250 |
+
*,
|
| 251 |
+
num_classes,
|
| 252 |
+
dropout_ratio=0.1,
|
| 253 |
+
conv_cfg=None,
|
| 254 |
+
norm_cfg=None,
|
| 255 |
+
act_cfg=dict(type='ReLU'),
|
| 256 |
+
in_index=-1,
|
| 257 |
+
input_transform=None,
|
| 258 |
+
loss_decode=dict(
|
| 259 |
+
type='CrossEntropyLoss',
|
| 260 |
+
use_sigmoid=False,
|
| 261 |
+
loss_weight=1.0),
|
| 262 |
+
decoder_params=None,
|
| 263 |
+
ignore_index=255,
|
| 264 |
+
sampler=None,
|
| 265 |
+
align_corners=False,
|
| 266 |
+
num_clips=5):
|
| 267 |
+
super(BaseDecodeHead_clips, self).__init__()
|
| 268 |
+
self._init_inputs(in_channels, in_index, input_transform)
|
| 269 |
+
self.channels = channels
|
| 270 |
+
self.num_classes = num_classes
|
| 271 |
+
self.dropout_ratio = dropout_ratio
|
| 272 |
+
self.conv_cfg = conv_cfg
|
| 273 |
+
self.norm_cfg = norm_cfg
|
| 274 |
+
self.act_cfg = act_cfg
|
| 275 |
+
self.in_index = in_index
|
| 276 |
+
self.ignore_index = ignore_index
|
| 277 |
+
self.align_corners = align_corners
|
| 278 |
+
self.num_clips=num_clips
|
| 279 |
+
|
| 280 |
+
if sampler is not None:
|
| 281 |
+
self.sampler = build_pixel_sampler(sampler, context=self)
|
| 282 |
+
else:
|
| 283 |
+
self.sampler = None
|
| 284 |
+
|
| 285 |
+
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
|
| 286 |
+
if dropout_ratio > 0:
|
| 287 |
+
self.dropout = nn.Dropout2d(dropout_ratio)
|
| 288 |
+
else:
|
| 289 |
+
self.dropout = None
|
| 290 |
+
self.fp16_enabled = False
|
| 291 |
+
|
| 292 |
+
def extra_repr(self):
|
| 293 |
+
"""Extra repr."""
|
| 294 |
+
s = f'input_transform={self.input_transform}, ' \
|
| 295 |
+
f'ignore_index={self.ignore_index}, ' \
|
| 296 |
+
f'align_corners={self.align_corners}'
|
| 297 |
+
return s
|
| 298 |
+
|
| 299 |
+
def _init_inputs(self, in_channels, in_index, input_transform):
|
| 300 |
+
"""Check and initialize input transforms.
|
| 301 |
+
|
| 302 |
+
The in_channels, in_index and input_transform must match.
|
| 303 |
+
Specifically, when input_transform is None, only single feature map
|
| 304 |
+
will be selected. So in_channels and in_index must be of type int.
|
| 305 |
+
When input_transform
|
| 306 |
+
|
| 307 |
+
Args:
|
| 308 |
+
in_channels (int|Sequence[int]): Input channels.
|
| 309 |
+
in_index (int|Sequence[int]): Input feature index.
|
| 310 |
+
input_transform (str|None): Transformation type of input features.
|
| 311 |
+
Options: 'resize_concat', 'multiple_select', None.
|
| 312 |
+
'resize_concat': Multiple feature maps will be resize to the
|
| 313 |
+
same size as first one and than concat together.
|
| 314 |
+
Usually used in FCN head of HRNet.
|
| 315 |
+
'multiple_select': Multiple feature maps will be bundle into
|
| 316 |
+
a list and passed into decode head.
|
| 317 |
+
None: Only one select feature map is allowed.
|
| 318 |
+
"""
|
| 319 |
+
|
| 320 |
+
if input_transform is not None:
|
| 321 |
+
assert input_transform in ['resize_concat', 'multiple_select']
|
| 322 |
+
self.input_transform = input_transform
|
| 323 |
+
self.in_index = in_index
|
| 324 |
+
if input_transform is not None:
|
| 325 |
+
assert isinstance(in_channels, (list, tuple))
|
| 326 |
+
assert isinstance(in_index, (list, tuple))
|
| 327 |
+
assert len(in_channels) == len(in_index)
|
| 328 |
+
if input_transform == 'resize_concat':
|
| 329 |
+
self.in_channels = sum(in_channels)
|
| 330 |
+
else:
|
| 331 |
+
self.in_channels = in_channels
|
| 332 |
+
else:
|
| 333 |
+
assert isinstance(in_channels, int)
|
| 334 |
+
assert isinstance(in_index, int)
|
| 335 |
+
self.in_channels = in_channels
|
| 336 |
+
|
| 337 |
+
def init_weights(self):
|
| 338 |
+
"""Initialize weights of classification layer."""
|
| 339 |
+
normal_init(self.conv_seg, mean=0, std=0.01)
|
| 340 |
+
|
| 341 |
+
def _transform_inputs(self, inputs):
|
| 342 |
+
"""Transform inputs for decoder.
|
| 343 |
+
|
| 344 |
+
Args:
|
| 345 |
+
inputs (list[Tensor]): List of multi-level img features.
|
| 346 |
+
|
| 347 |
+
Returns:
|
| 348 |
+
Tensor: The transformed inputs
|
| 349 |
+
"""
|
| 350 |
+
|
| 351 |
+
if self.input_transform == 'resize_concat':
|
| 352 |
+
inputs = [inputs[i] for i in self.in_index]
|
| 353 |
+
upsampled_inputs = [
|
| 354 |
+
resize(
|
| 355 |
+
input=x,
|
| 356 |
+
size=inputs[0].shape[2:],
|
| 357 |
+
mode='bilinear',
|
| 358 |
+
align_corners=self.align_corners) for x in inputs
|
| 359 |
+
]
|
| 360 |
+
inputs = torch.cat(upsampled_inputs, dim=1)
|
| 361 |
+
elif self.input_transform == 'multiple_select':
|
| 362 |
+
inputs = [inputs[i] for i in self.in_index]
|
| 363 |
+
else:
|
| 364 |
+
inputs = inputs[self.in_index]
|
| 365 |
+
|
| 366 |
+
return inputs
|
| 367 |
+
|
| 368 |
+
# @auto_fp16()
|
| 369 |
+
@abstractmethod
|
| 370 |
+
def forward(self, inputs):
|
| 371 |
+
"""Placeholder of forward function."""
|
| 372 |
+
pass
|
| 373 |
+
|
| 374 |
+
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg,batch_size, num_clips):
|
| 375 |
+
"""Forward function for training.
|
| 376 |
+
Args:
|
| 377 |
+
inputs (list[Tensor]): List of multi-level img features.
|
| 378 |
+
img_metas (list[dict]): List of image info dict where each dict
|
| 379 |
+
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
| 380 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
| 381 |
+
For details on the values of these keys see
|
| 382 |
+
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
| 383 |
+
gt_semantic_seg (Tensor): Semantic segmentation masks
|
| 384 |
+
used if the architecture supports semantic segmentation task.
|
| 385 |
+
train_cfg (dict): The training config.
|
| 386 |
+
|
| 387 |
+
Returns:
|
| 388 |
+
dict[str, Tensor]: a dictionary of loss components
|
| 389 |
+
"""
|
| 390 |
+
seg_logits = self.forward(inputs,batch_size, num_clips)
|
| 391 |
+
losses = self.losses(seg_logits, gt_semantic_seg)
|
| 392 |
+
return losses
|
| 393 |
+
|
| 394 |
+
def forward_test(self, inputs, img_metas, test_cfg, batch_size, num_clips):
|
| 395 |
+
"""Forward function for testing.
|
| 396 |
+
|
| 397 |
+
Args:
|
| 398 |
+
inputs (list[Tensor]): List of multi-level img features.
|
| 399 |
+
img_metas (list[dict]): List of image info dict where each dict
|
| 400 |
+
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
| 401 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
| 402 |
+
For details on the values of these keys see
|
| 403 |
+
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
| 404 |
+
test_cfg (dict): The testing config.
|
| 405 |
+
|
| 406 |
+
Returns:
|
| 407 |
+
Tensor: Output segmentation map.
|
| 408 |
+
"""
|
| 409 |
+
return self.forward(inputs, batch_size, num_clips)
|
| 410 |
+
|
| 411 |
+
def cls_seg(self, feat):
|
| 412 |
+
"""Classify each pixel."""
|
| 413 |
+
if self.dropout is not None:
|
| 414 |
+
feat = self.dropout(feat)
|
| 415 |
+
output = self.conv_seg(feat)
|
| 416 |
+
return output
|
| 417 |
+
|
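The clips head above is abstract: a concrete subclass must implement forward() so that it accepts the extra (batch_size, num_clips) arguments that forward_train/forward_test thread through. A minimal sketch of how such a subclass is driven; ToyClipsHead, the channel sizes, and the mean-pooling over clips are illustrative assumptions, not the repo's actual head:

import torch
import torch.nn as nn

class ToyClipsHead(BaseDecodeHead_clips):
    # Hypothetical example head: reduce channels, average the clip axis, classify.
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.squeeze = nn.Conv2d(self.in_channels, self.channels, kernel_size=1)

    def forward(self, inputs, batch_size, num_clips):
        x = self._transform_inputs(inputs)                             # pick/concat feature levels
        x = self.squeeze(x)                                            # in_channels -> channels
        x = x.view(batch_size, num_clips, *x.shape[1:]).mean(dim=1)    # fuse the clip axis
        return self.cls_seg(x)                                         # per-pixel class logits

head = ToyClipsHead(in_channels=64, channels=32, num_classes=2, num_clips=4)
feats = torch.randn(2 * 4, 64, 32, 32)                                 # (B*T, C, H, W) backbone features
logits = head.forward_test([feats], img_metas=None, test_cfg=None, batch_size=2, num_clips=4)
print(logits.shape)                                                    # torch.Size([2, 2, 32, 32])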
class BaseDecodeHead_clips_flow(nn.Module, metaclass=ABCMeta):
    """Base class for BaseDecodeHead_clips_flow.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes.
        dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers. Default: None.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        in_index (int|Sequence[int]): Input feature index. Default: -1
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resized to the
                same size as the first one and then concatenated together.
                Usually used in FCN head of HRNet.
            'multiple_select': Multiple feature maps will be bundled into
                a list and passed into decode head.
            None: Only one select feature map is allowed.
            Default: None.
        loss_decode (dict): Config of decode loss.
            Default: dict(type='CrossEntropyLoss').
        ignore_index (int | None): The label index to be ignored. When using
            masked BCE loss, ignore_index should be set to None. Default: 255
        sampler (dict|None): The config of segmentation map sampler.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 loss_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 decoder_params=None,
                 ignore_index=255,
                 sampler=None,
                 align_corners=False,
                 num_clips=5):
        super(BaseDecodeHead_clips_flow, self).__init__()
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index
        self.ignore_index = ignore_index
        self.align_corners = align_corners
        self.num_clips = num_clips

        if sampler is not None:
            self.sampler = build_pixel_sampler(sampler, context=self)
        else:
            self.sampler = None

        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None
        self.fp16_enabled = False

    def extra_repr(self):
        """Extra repr."""
        s = f'input_transform={self.input_transform}, ' \
            f'ignore_index={self.ignore_index}, ' \
            f'align_corners={self.align_corners}'
        return s

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature map
        will be selected. So in_channels and in_index must be of type int.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated together.
                    Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into decode head.
                None: Only one select feature map is allowed.
        """

        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def init_weights(self):
        """Initialize weights of classification layer."""
        normal_init(self.conv_seg, mean=0, std=0.01)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    # @auto_fp16()
    @abstractmethod
    def forward(self, inputs):
        """Placeholder of forward function."""
        pass

    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg, batch_size, num_clips, img=None):
        """Forward function for training.
        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs, batch_size, num_clips, img)
        losses = self.losses(seg_logits, gt_semantic_seg)
        return losses

    def forward_test(self, inputs, img_metas, test_cfg, batch_size=None, num_clips=None, img=None):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation map.
        """
        return self.forward(inputs, batch_size, num_clips, img)

    def cls_seg(self, feat):
        """Classify each pixel."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output
models/SpaTrackV2/models/depth_refiner/depth_refiner.py
ADDED
@@ -0,0 +1,115 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.monoD.depth_anything_v2.dinov2_layers.patch_embed import PatchEmbed
from models.SpaTrackV2.models.depth_refiner.backbone import mit_b3
from models.SpaTrackV2.models.depth_refiner.stablizer import Stabilization_Network_Cross_Attention
from einops import rearrange

class TrackStablizer(nn.Module):
    def __init__(self):
        super().__init__()

        self.backbone = mit_b3()

        old_conv = self.backbone.patch_embed1.proj
        new_conv = nn.Conv2d(old_conv.in_channels + 4, old_conv.out_channels,
                             kernel_size=old_conv.kernel_size, stride=old_conv.stride, padding=old_conv.padding)

        new_conv.weight[:, :3, :, :].data.copy_(old_conv.weight.clone())
        self.backbone.patch_embed1.proj = new_conv

        self.Track_Stabilizer = Stabilization_Network_Cross_Attention(in_channels=[64, 128, 320, 512],
                                                                      in_index=[0, 1, 2, 3],
                                                                      feature_strides=[4, 8, 16, 32],
                                                                      channels=128,
                                                                      dropout_ratio=0.1,
                                                                      num_classes=1,
                                                                      align_corners=False,
                                                                      decoder_params=dict(embed_dim=256, depths=4),
                                                                      num_clips=16,
                                                                      norm_cfg=dict(type='SyncBN', requires_grad=True))

        self.edge_conv = nn.Sequential(nn.Conv2d(in_channels=4, out_channels=64, kernel_size=3, padding=1, stride=1, bias=True),
                                       nn.ReLU(inplace=True))
        self.edge_conv1 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, stride=2, bias=True),
                                        nn.ReLU(inplace=True))
        self.success = False
        self.x = None

    def buffer_forward(self, inputs, num_clips=16):
        """
        buffer forward for getting the pointmap and image features
        """
        B, T, C, H, W = inputs.shape
        self.x = self.backbone(inputs)
        scale, shift = self.Track_Stabilizer.buffer_forward(self.x, num_clips=num_clips)
        self.success = True
        return scale, shift

    def forward(self, inputs, tracks, tracks_uvd, num_clips=16, imgs=None, vis_track=None):

        """
        Args:
            inputs: [B, T, C, H, W], RGB + PointMap + Mask
            tracks: [B, T, N, 4], 3D tracks in camera coordinate + visibility
            num_clips: int, number of clips to use
        """
        B, T, C, H, W = inputs.shape
        edge_feat = self.edge_conv(inputs.view(B*T, 4, H, W))
        edge_feat1 = self.edge_conv1(edge_feat)

        if not self.success:
            scale, shift = self.Track_Stabilizer.buffer_forward(self.x, num_clips=num_clips)
            self.success = True
            update = self.Track_Stabilizer(self.x, edge_feat, edge_feat1, tracks, tracks_uvd,
                                           num_clips=num_clips, imgs=imgs, vis_track=vis_track)
        else:
            update = self.Track_Stabilizer(self.x, edge_feat, edge_feat1, tracks, tracks_uvd,
                                           num_clips=num_clips, imgs=imgs, vis_track=vis_track)

        return update

    def reset_success(self):
        self.success = False
        self.x = None
        self.Track_Stabilizer.reset_success()


if __name__ == "__main__":
    # Create test input tensors
    batch_size = 1
    seq_len = 16
    channels = 7  # 3 for RGB + 3 for PointMap + 1 for Mask
    height = 384
    width = 512

    # Create random input tensor with shape [B, T, C, H, W]
    inputs = torch.randn(batch_size, seq_len, channels, height, width)

    # Create random tracks
    tracks = torch.randn(batch_size, seq_len, 1024, 4)

    # Create random test images
    test_imgs = torch.randn(batch_size, seq_len, 3, height, width)

    # Initialize model and move to GPU
    model = TrackStablizer().cuda()

    # Move inputs to GPU and run forward pass
    inputs = inputs.cuda()
    tracks = tracks.cuda()
    outputs = model.buffer_forward(inputs, num_clips=seq_len)
    import time
    start_time = time.time()
    outputs = model(inputs, tracks, num_clips=seq_len)
    end_time = time.time()
    print(f"Time taken: {end_time - start_time} seconds")
    import pdb; pdb.set_trace()
    # # Print shapes for verification
    # print(f"Input shape: {inputs.shape}")
    # print(f"Output shape: {outputs.shape}")

    # # Basic tests
    # assert outputs.shape[0] == batch_size, "Batch size mismatch"
    # assert len(outputs.shape) == 4, "Output should be 4D: [B,C,H,W]"
    # assert torch.all(outputs >= 0), "Output should be non-negative after ReLU"

    # print("All tests passed!")
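The patch-embed surgery at the top of TrackStablizer (widening the first conv from 3 to 3+4 input channels and copying over the pretrained RGB filters) is a standard channel-inflation trick. Below is a self-contained sketch of the idea with assumed kernel/stride values rather than the mit_b3 ones; unlike the repo code above, it also zeroes the new channels and copies the bias, which is one common choice for keeping the initial outputs unchanged:

import torch
import torch.nn as nn

old = nn.Conv2d(3, 64, kernel_size=7, stride=4, padding=3)      # stand-in for a pretrained patch embed
new = nn.Conv2d(3 + 4, 64, kernel_size=7, stride=4, padding=3)  # RGB + PointMap/Mask channels
with torch.no_grad():
    new.weight[:, :3].copy_(old.weight)   # keep the pretrained RGB filters
    new.weight[:, 3:].zero_()             # extra channels start at zero (a choice, not what the repo does)
    new.bias.copy_(old.bias)

x = torch.randn(1, 7, 64, 64)             # 7-channel input: RGB + pointmap + mask
print(new(x).shape)                        # torch.Size([1, 64, 16, 16])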
models/SpaTrackV2/models/depth_refiner/network.py
ADDED
@@ -0,0 +1,429 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
'''
Author: Ke Xian
Email: kexian@hust.edu.cn
Date: 2020/07/20
'''

import torch
import torch.nn as nn
import torch.nn.init as init

# ==============================================================================================================

class FTB(nn.Module):
    def __init__(self, inchannels, midchannels=512):
        super(FTB, self).__init__()
        self.in1 = inchannels
        self.mid = midchannels

        self.conv1 = nn.Conv2d(in_channels=self.in1, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True)
        self.conv_branch = nn.Sequential(nn.ReLU(inplace=True),
                                         nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True),
                                         # nn.BatchNorm2d(num_features=self.mid),
                                         nn.ReLU(inplace=True),
                                         nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True))
        self.relu = nn.ReLU(inplace=True)

        self.init_params()

    def forward(self, x):
        x = self.conv1(x)
        x = x + self.conv_branch(x)
        x = self.relu(x)

        return x

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):  # nn.BatchNorm2d
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

class ATA(nn.Module):
    def __init__(self, inchannels, reduction=8):
        super(ATA, self).__init__()
        self.inchannels = inchannels
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(nn.Linear(self.inchannels*2, self.inchannels // reduction),
                                nn.ReLU(inplace=True),
                                nn.Linear(self.inchannels // reduction, self.inchannels),
                                nn.Sigmoid())
        self.init_params()

    def forward(self, low_x, high_x):
        n, c, _, _ = low_x.size()
        x = torch.cat([low_x, high_x], 1)
        x = self.avg_pool(x)
        x = x.view(n, -1)
        x = self.fc(x).view(n, c, 1, 1)
        x = low_x * x + high_x

        return x

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                # init.normal(m.weight, std=0.01)
                init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                # init.normal_(m.weight, std=0.01)
                init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):  # nn.BatchNorm2d
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    init.constant_(m.bias, 0)


class FFM(nn.Module):
    def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
        super(FFM, self).__init__()
        self.inchannels = inchannels
        self.midchannels = midchannels
        self.outchannels = outchannels
        self.upfactor = upfactor

        self.ftb1 = FTB(inchannels=self.inchannels, midchannels=self.midchannels)
        self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)

        self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)

        self.init_params()
        # self.p1 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
        # self.p2 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
        # self.p3 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)

    def forward(self, low_x, high_x):
        x = self.ftb1(low_x)

        '''
        x = torch.cat((x,high_x),1)
        if x.shape[2] == 12:
            x = self.p1(x)
        elif x.shape[2] == 24:
            x = self.p2(x)
        elif x.shape[2] == 48:
            x = self.p3(x)
        '''
        x = x + high_x  ### high_x
        x = self.ftb2(x)
        x = self.upsample(x)

        return x

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):  # nn.BatchNorm2d
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    init.constant_(m.bias, 0)



class noFFM(nn.Module):
    def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
        super(noFFM, self).__init__()
        self.inchannels = inchannels
        self.midchannels = midchannels
        self.outchannels = outchannels
        self.upfactor = upfactor

        self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)

        self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)

        self.init_params()
        # self.p1 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
        # self.p2 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
        # self.p3 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)

    def forward(self, low_x, high_x):

        # x = self.ftb1(low_x)
        x = high_x  ### high_x
        x = self.ftb2(x)
        x = self.upsample(x)

        return x

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):  # nn.BatchNorm2d
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    init.constant_(m.bias, 0)




class AO(nn.Module):
    # Adaptive output module
    def __init__(self, inchannels, outchannels, upfactor=2):
        super(AO, self).__init__()
        self.inchannels = inchannels
        self.outchannels = outchannels
        self.upfactor = upfactor

        """
        self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels//2, kernel_size=3, padding=1, stride=1, bias=True),
                                        nn.BatchNorm2d(num_features=self.inchannels//2),
                                        nn.ReLU(inplace=True),
                                        nn.Conv2d(in_channels=self.inchannels//2, out_channels=self.outchannels, kernel_size=3, padding=1, stride=1, bias=True),
                                        nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True))  # ,
                                        # nn.ReLU(inplace=True))  ## get positive values
        """
        self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels//2, kernel_size=3, padding=1, stride=1, bias=True),
                                        # nn.BatchNorm2d(num_features=self.inchannels//2),
                                        nn.ReLU(inplace=True),
                                        nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True),
                                        nn.Conv2d(in_channels=self.inchannels//2, out_channels=self.outchannels, kernel_size=1, padding=0, stride=1))

        # nn.ReLU(inplace=True))  ## get positive values

        self.init_params()

    def forward(self, x):
        x = self.adapt_conv(x)
        return x

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):  # nn.BatchNorm2d
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

class ASPP(nn.Module):
    def __init__(self, inchannels=256, planes=128, rates=[1, 6, 12, 18]):
        super(ASPP, self).__init__()
        self.inchannels = inchannels
        self.planes = planes
        self.rates = rates
        self.kernel_sizes = []
        self.paddings = []
        for rate in self.rates:
            if rate == 1:
                self.kernel_sizes.append(1)
                self.paddings.append(0)
            else:
                self.kernel_sizes.append(3)
                self.paddings.append(rate)
        self.atrous_0 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[0],
                                                stride=1, padding=self.paddings[0], dilation=self.rates[0], bias=True),
                                      nn.ReLU(inplace=True),
                                      nn.BatchNorm2d(num_features=self.planes)
                                      )
        self.atrous_1 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[1],
                                                stride=1, padding=self.paddings[1], dilation=self.rates[1], bias=True),
                                      nn.ReLU(inplace=True),
                                      nn.BatchNorm2d(num_features=self.planes),
                                      )
        self.atrous_2 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[2],
                                                stride=1, padding=self.paddings[2], dilation=self.rates[2], bias=True),
                                      nn.ReLU(inplace=True),
                                      nn.BatchNorm2d(num_features=self.planes),
                                      )
        self.atrous_3 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[3],
                                                stride=1, padding=self.paddings[3], dilation=self.rates[3], bias=True),
                                      nn.ReLU(inplace=True),
                                      nn.BatchNorm2d(num_features=self.planes),
                                      )

        # self.conv = nn.Conv2d(in_channels=self.planes * 4, out_channels=self.inchannels, kernel_size=3, padding=1, stride=1, bias=True)
    def forward(self, x):
        x = torch.cat([self.atrous_0(x), self.atrous_1(x), self.atrous_2(x), self.atrous_3(x)], 1)
        # x = self.conv(x)

        return x

# ==============================================================================================================


class ResidualConv(nn.Module):
    def __init__(self, inchannels):
        super(ResidualConv, self).__init__()
        # nn.BatchNorm2d
        self.conv = nn.Sequential(
            # nn.BatchNorm2d(num_features=inchannels),
            nn.ReLU(inplace=False),
            # nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=3, padding=1, stride=1, groups=inchannels, bias=True),
            # nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
            nn.Conv2d(in_channels=inchannels, out_channels=inchannels//2, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(num_features=inchannels//2),
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=inchannels//2, out_channels=inchannels, kernel_size=3, padding=1, stride=1, bias=False)
        )
        self.init_params()

    def forward(self, x):
        x = self.conv(x) + x
        return x

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):  # nn.BatchNorm2d
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    init.constant_(m.bias, 0)


class FeatureFusion(nn.Module):
    def __init__(self, inchannels, outchannels):
        super(FeatureFusion, self).__init__()
        self.conv = ResidualConv(inchannels=inchannels)
        # nn.BatchNorm2d
        self.up = nn.Sequential(ResidualConv(inchannels=inchannels),
                                nn.ConvTranspose2d(in_channels=inchannels, out_channels=outchannels, kernel_size=3, stride=2, padding=1, output_padding=1),
                                nn.BatchNorm2d(num_features=outchannels),
                                nn.ReLU(inplace=True))

    def forward(self, lowfeat, highfeat):
        return self.up(highfeat + self.conv(lowfeat))

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):  # nn.BatchNorm2d
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    init.constant_(m.bias, 0)


class SenceUnderstand(nn.Module):
    def __init__(self, channels):
        super(SenceUnderstand, self).__init__()
        self.channels = channels
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
                                   nn.ReLU(inplace=True))
        self.pool = nn.AdaptiveAvgPool2d(8)
        self.fc = nn.Sequential(nn.Linear(512*8*8, self.channels),
                                nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(nn.Conv2d(in_channels=self.channels, out_channels=self.channels, kernel_size=1, padding=0),
                                   nn.ReLU(inplace=True))
        self.initial_params()

    def forward(self, x):
        n, c, h, w = x.size()
        x = self.conv1(x)
        x = self.pool(x)
        x = x.view(n, -1)
        x = self.fc(x)
        x = x.view(n, self.channels, 1, 1)
        x = self.conv2(x)
        x = x.repeat(1, 1, h, w)
        return x

    def initial_params(self, dev=0.01):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # print torch.sum(m.weight)
                m.weight.data.normal_(0, dev)
                if m.bias is not None:
                    m.bias.data.fill_(0)
            elif isinstance(m, nn.ConvTranspose2d):
                # print torch.sum(m.weight)
                m.weight.data.normal_(0, dev)
                if m.bias is not None:
                    m.bias.data.fill_(0)
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, dev)
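As a quick orientation for the decoder blocks defined above, here is a hedged shape walk-through of FFM and AO on dummy tensors; the channel sizes are illustrative assumptions, not the ones SpaTrackV2 actually configures:

import torch

ffm = FFM(inchannels=256, midchannels=128, outchannels=64, upfactor=2)
ao = AO(inchannels=64, outchannels=1, upfactor=2)

low = torch.randn(1, 256, 24, 32)    # lower-level feature (inchannels)
high = torch.randn(1, 128, 24, 32)   # running decoder state (midchannels)

fused = ffm(low, high)               # FTB(256->128), add high, FTB(128->64), x2 upsample
print(fused.shape)                    # torch.Size([1, 64, 48, 64])

depth = ao(fused)                     # conv 64->32, ReLU, x2 upsample, 1x1 conv -> 1 channel
print(depth.shape)                    # torch.Size([1, 1, 96, 128])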
models/SpaTrackV2/models/depth_refiner/stablilization_attention.py
ADDED
@@ -0,0 +1,1187 @@
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_partition_noreshape(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (B, num_windows_h, num_windows_w, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return windows

def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
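window_partition and window_reverse are exact inverses as long as H and W are divisible by window_size; a small sanity check with arbitrarily chosen shapes:

import torch

x = torch.randn(2, 8, 8, 96)                      # (B, H, W, C)
win = window_partition(x, window_size=4)          # (B * num_windows, 4, 4, 96) == (8, 4, 4, 96)
y = window_reverse(win, window_size=4, H=8, W=8)  # back to (2, 8, 8, 96)
assert torch.equal(x, y)                          # the round trip is lossless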
| 72 |
+
|
| 73 |
+
def get_roll_masks(H, W, window_size, shift_size):
|
| 74 |
+
#####################################
|
| 75 |
+
# move to top-left
|
| 76 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
| 77 |
+
h_slices = (slice(0, H-window_size),
|
| 78 |
+
slice(H-window_size, H-shift_size),
|
| 79 |
+
slice(H-shift_size, H))
|
| 80 |
+
w_slices = (slice(0, W-window_size),
|
| 81 |
+
slice(W-window_size, W-shift_size),
|
| 82 |
+
slice(W-shift_size, W))
|
| 83 |
+
cnt = 0
|
| 84 |
+
for h in h_slices:
|
| 85 |
+
for w in w_slices:
|
| 86 |
+
img_mask[:, h, w, :] = cnt
|
| 87 |
+
cnt += 1
|
| 88 |
+
|
| 89 |
+
mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1
|
| 90 |
+
mask_windows = mask_windows.view(-1, window_size * window_size)
|
| 91 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
| 92 |
+
attn_mask_tl = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
| 93 |
+
|
| 94 |
+
####################################
|
| 95 |
+
# move to top right
|
| 96 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
| 97 |
+
h_slices = (slice(0, H-window_size),
|
| 98 |
+
slice(H-window_size, H-shift_size),
|
| 99 |
+
slice(H-shift_size, H))
|
| 100 |
+
w_slices = (slice(0, shift_size),
|
| 101 |
+
slice(shift_size, window_size),
|
| 102 |
+
slice(window_size, W))
|
| 103 |
+
cnt = 0
|
| 104 |
+
for h in h_slices:
|
| 105 |
+
for w in w_slices:
|
| 106 |
+
img_mask[:, h, w, :] = cnt
|
| 107 |
+
cnt += 1
|
| 108 |
+
|
| 109 |
+
mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1
|
| 110 |
+
mask_windows = mask_windows.view(-1, window_size * window_size)
|
| 111 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
| 112 |
+
attn_mask_tr = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
| 113 |
+
|
| 114 |
+
####################################
|
| 115 |
+
# move to bottom left
|
| 116 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
| 117 |
+
h_slices = (slice(0, shift_size),
|
| 118 |
+
slice(shift_size, window_size),
|
| 119 |
+
slice(window_size, H))
|
| 120 |
+
w_slices = (slice(0, W-window_size),
|
| 121 |
+
slice(W-window_size, W-shift_size),
|
| 122 |
+
slice(W-shift_size, W))
|
| 123 |
+
cnt = 0
|
| 124 |
+
for h in h_slices:
|
| 125 |
+
for w in w_slices:
|
| 126 |
+
img_mask[:, h, w, :] = cnt
|
| 127 |
+
cnt += 1
|
| 128 |
+
|
| 129 |
+
mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1
|
| 130 |
+
mask_windows = mask_windows.view(-1, window_size * window_size)
|
| 131 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
| 132 |
+
attn_mask_bl = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
| 133 |
+
|
| 134 |
+
####################################
|
| 135 |
+
# move to bottom right
|
| 136 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
| 137 |
+
h_slices = (slice(0, shift_size),
|
| 138 |
+
slice(shift_size, window_size),
|
| 139 |
+
slice(window_size, H))
|
| 140 |
+
w_slices = (slice(0, shift_size),
|
| 141 |
+
slice(shift_size, window_size),
|
| 142 |
+
slice(window_size, W))
|
| 143 |
+
cnt = 0
|
| 144 |
+
for h in h_slices:
|
| 145 |
+
for w in w_slices:
|
| 146 |
+
img_mask[:, h, w, :] = cnt
|
| 147 |
+
cnt += 1
|
| 148 |
+
|
| 149 |
+
mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1
|
| 150 |
+
mask_windows = mask_windows.view(-1, window_size * window_size)
|
| 151 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
| 152 |
+
attn_mask_br = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
| 153 |
+
|
| 154 |
+
# append all
|
| 155 |
+
attn_mask_all = torch.cat((attn_mask_tl, attn_mask_tr, attn_mask_bl, attn_mask_br), -1)
|
| 156 |
+
return attn_mask_all
|
| 157 |
+
|
def get_relative_position_index(q_windows, k_windows):
    """
    Args:
        q_windows: tuple (query_window_height, query_window_width)
        k_windows: tuple (key_window_height, key_window_width)

    Returns:
        relative_position_index: query_window_height*query_window_width, key_window_height*key_window_width
    """
    # get pair-wise relative position index for each token inside the window
    coords_h_q = torch.arange(q_windows[0])
    coords_w_q = torch.arange(q_windows[1])
    coords_q = torch.stack(torch.meshgrid([coords_h_q, coords_w_q]))  # 2, Wh_q, Ww_q

    coords_h_k = torch.arange(k_windows[0])
    coords_w_k = torch.arange(k_windows[1])
    coords_k = torch.stack(torch.meshgrid([coords_h_k, coords_w_k]))  # 2, Wh, Ww

    coords_flatten_q = torch.flatten(coords_q, 1)  # 2, Wh_q*Ww_q
    coords_flatten_k = torch.flatten(coords_k, 1)  # 2, Wh_k*Ww_k

    relative_coords = coords_flatten_q[:, :, None] - coords_flatten_k[:, None, :]  # 2, Wh_q*Ww_q, Wh_k*Ww_k
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh_q*Ww_q, Wh_k*Ww_k, 2
    relative_coords[:, :, 0] += k_windows[0] - 1  # shift to start from 0
    relative_coords[:, :, 1] += k_windows[1] - 1
    relative_coords[:, :, 0] *= (q_windows[1] + k_windows[1]) - 1
    relative_position_index = relative_coords.sum(-1)  # Wh_q*Ww_q, Wh_k*Ww_k
    return relative_position_index

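As a quick sanity check (not part of the commit), the helper above can be probed with illustrative window sizes; the printed values follow directly from the arithmetic in the function.

idx = get_relative_position_index((7, 7), (5, 5))
print(idx.shape)       # torch.Size([49, 25]): one row per query token, one column per key token
print(int(idx.max()))  # 120, so a bias table with (7 + 5 - 1) ** 2 = 121 entries per head covers every offset
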
def get_relative_position_index3d(q_windows, k_windows, num_clips):
    """
    Args:
        q_windows: tuple (query_window_height, query_window_width)
        k_windows: tuple (key_window_height, key_window_width)
        num_clips: number of clips stacked along the temporal axis

    Returns:
        relative_position_index: num_clips*query_window_height*query_window_width, num_clips*key_window_height*key_window_width
    """
    # get pair-wise relative position index for each token inside the spatio-temporal window
    coords_d_q = torch.arange(num_clips)
    coords_h_q = torch.arange(q_windows[0])
    coords_w_q = torch.arange(q_windows[1])
    coords_q = torch.stack(torch.meshgrid([coords_d_q, coords_h_q, coords_w_q]))  # 3, Wd, Wh_q, Ww_q

    coords_d_k = torch.arange(num_clips)
    coords_h_k = torch.arange(k_windows[0])
    coords_w_k = torch.arange(k_windows[1])
    coords_k = torch.stack(torch.meshgrid([coords_d_k, coords_h_k, coords_w_k]))  # 3, Wd, Wh_k, Ww_k

    coords_flatten_q = torch.flatten(coords_q, 1)  # 3, Wd*Wh_q*Ww_q
    coords_flatten_k = torch.flatten(coords_k, 1)  # 3, Wd*Wh_k*Ww_k

    relative_coords = coords_flatten_q[:, :, None] - coords_flatten_k[:, None, :]  # 3, Wd*Wh_q*Ww_q, Wd*Wh_k*Ww_k
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wd*Wh_q*Ww_q, Wd*Wh_k*Ww_k, 3
    relative_coords[:, :, 0] += num_clips - 1  # shift to start from 0
    relative_coords[:, :, 1] += k_windows[0] - 1
    relative_coords[:, :, 2] += k_windows[1] - 1
    relative_coords[:, :, 0] *= (q_windows[0] + k_windows[0] - 1) * (q_windows[1] + k_windows[1] - 1)
    relative_coords[:, :, 1] *= (q_windows[1] + k_windows[1] - 1)
    relative_position_index = relative_coords.sum(-1)  # Wd*Wh_q*Ww_q, Wd*Wh_k*Ww_k
    return relative_position_index


| 221 |
+
class WindowAttention3d3(nn.Module):
|
| 222 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
| 223 |
+
|
| 224 |
+
Args:
|
| 225 |
+
dim (int): Number of input channels.
|
| 226 |
+
expand_size (int): The expand size at focal level 1.
|
| 227 |
+
window_size (tuple[int]): The height and width of the window.
|
| 228 |
+
focal_window (int): Focal region size.
|
| 229 |
+
focal_level (int): Focal attention level.
|
| 230 |
+
num_heads (int): Number of attention heads.
|
| 231 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 232 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
| 233 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
| 234 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
| 235 |
+
pool_method (str): window pooling method. Default: none
|
| 236 |
+
"""
|
| 237 |
+
|
| 238 |
+
def __init__(self, dim, expand_size, window_size, focal_window, focal_level, num_heads,
|
| 239 |
+
qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., pool_method="none", focal_l_clips=[7,1,2], focal_kernel_clips=[7,5,3]):
|
| 240 |
+
|
| 241 |
+
super().__init__()
|
| 242 |
+
self.dim = dim
|
| 243 |
+
self.expand_size = expand_size
|
| 244 |
+
self.window_size = window_size # Wh, Ww
|
| 245 |
+
self.pool_method = pool_method
|
| 246 |
+
self.num_heads = num_heads
|
| 247 |
+
head_dim = dim // num_heads
|
| 248 |
+
self.scale = qk_scale or head_dim ** -0.5
|
| 249 |
+
self.focal_level = focal_level
|
| 250 |
+
self.focal_window = focal_window
|
| 251 |
+
|
| 252 |
+
# define a parameter table of relative position bias for each window
|
| 253 |
+
self.relative_position_bias_table = nn.Parameter(
|
| 254 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
| 255 |
+
|
| 256 |
+
# get pair-wise relative position index for each token inside the window
|
| 257 |
+
coords_h = torch.arange(self.window_size[0])
|
| 258 |
+
coords_w = torch.arange(self.window_size[1])
|
| 259 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
| 260 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
| 261 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
| 262 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
| 263 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
| 264 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
| 265 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
| 266 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
| 267 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
| 268 |
+
|
| 269 |
+
num_clips=4
|
| 270 |
+
# # define a parameter table of relative position bias
|
| 271 |
+
# self.relative_position_bias_table = nn.Parameter(
|
| 272 |
+
# torch.zeros((2 * num_clips - 1) * (2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH
|
| 273 |
+
|
| 274 |
+
# # get pair-wise relative position index for each token inside the window
|
| 275 |
+
# coords_d = torch.arange(num_clips)
|
| 276 |
+
# coords_h = torch.arange(self.window_size[0])
|
| 277 |
+
# coords_w = torch.arange(self.window_size[1])
|
| 278 |
+
# coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) # 3, Wd, Wh, Ww
|
| 279 |
+
# coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww
|
| 280 |
+
# relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww
|
| 281 |
+
# relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3
|
| 282 |
+
# relative_coords[:, :, 0] += num_clips - 1 # shift to start from 0
|
| 283 |
+
# relative_coords[:, :, 1] += self.window_size[0] - 1
|
| 284 |
+
# relative_coords[:, :, 2] += self.window_size[1] - 1
|
| 285 |
+
|
| 286 |
+
# relative_coords[:, :, 0] *= (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1)
|
| 287 |
+
# relative_coords[:, :, 1] *= (2 * self.window_size[1] - 1)
|
| 288 |
+
# relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww
|
| 289 |
+
# self.register_buffer("relative_position_index", relative_position_index)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
if self.expand_size > 0 and focal_level > 0:
|
| 293 |
+
# define a parameter table of position bias between window and its fine-grained surroundings
|
| 294 |
+
self.window_size_of_key = self.window_size[0] * self.window_size[1] if self.expand_size == 0 else \
|
| 295 |
+
(4 * self.window_size[0] * self.window_size[1] - 4 * (self.window_size[0] - self.expand_size) * (self.window_size[0] - self.expand_size))
|
| 296 |
+
self.relative_position_bias_table_to_neighbors = nn.Parameter(
|
| 297 |
+
torch.zeros(1, num_heads, self.window_size[0] * self.window_size[1], self.window_size_of_key)) # Wh*Ww, nH, nSurrounding
|
| 298 |
+
trunc_normal_(self.relative_position_bias_table_to_neighbors, std=.02)
|
| 299 |
+
|
| 300 |
+
# get mask for rolled k and rolled v
|
| 301 |
+
mask_tl = torch.ones(self.window_size[0], self.window_size[1]); mask_tl[:-self.expand_size, :-self.expand_size] = 0
|
| 302 |
+
mask_tr = torch.ones(self.window_size[0], self.window_size[1]); mask_tr[:-self.expand_size, self.expand_size:] = 0
|
| 303 |
+
mask_bl = torch.ones(self.window_size[0], self.window_size[1]); mask_bl[self.expand_size:, :-self.expand_size] = 0
|
| 304 |
+
mask_br = torch.ones(self.window_size[0], self.window_size[1]); mask_br[self.expand_size:, self.expand_size:] = 0
|
| 305 |
+
mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br), 0).flatten(0)
|
| 306 |
+
self.register_buffer("valid_ind_rolled", mask_rolled.nonzero().view(-1))
|
| 307 |
+
|
| 308 |
+
if pool_method != "none" and focal_level > 1:
|
| 309 |
+
#self.relative_position_bias_table_to_windows = nn.ParameterList()
|
| 310 |
+
#self.relative_position_bias_table_to_windows_clips = nn.ParameterList()
|
| 311 |
+
#self.register_parameter('relative_position_bias_table_to_windows',[])
|
| 312 |
+
#self.register_parameter('relative_position_bias_table_to_windows_clips',[])
|
| 313 |
+
self.unfolds = nn.ModuleList()
|
| 314 |
+
self.unfolds_clips=nn.ModuleList()
|
| 315 |
+
|
| 316 |
+
# build relative position bias between local patch and pooled windows
|
| 317 |
+
for k in range(focal_level-1):
|
| 318 |
+
stride = 2**k
|
| 319 |
+
kernel_size = 2*(self.focal_window // 2) + 2**k + (2**k-1)
|
| 320 |
+
# define unfolding operations
|
| 321 |
+
self.unfolds += [nn.Unfold(
|
| 322 |
+
kernel_size=(kernel_size, kernel_size),
|
| 323 |
+
stride=stride, padding=kernel_size // 2)
|
| 324 |
+
]
|
| 325 |
+
|
| 326 |
+
# define relative position bias table
|
| 327 |
+
relative_position_bias_table_to_windows = nn.Parameter(
|
| 328 |
+
torch.zeros(
|
| 329 |
+
self.num_heads,
|
| 330 |
+
(self.window_size[0] + self.focal_window + 2**k - 2) * (self.window_size[1] + self.focal_window + 2**k - 2),
|
| 331 |
+
)
|
| 332 |
+
)
|
| 333 |
+
trunc_normal_(relative_position_bias_table_to_windows, std=.02)
|
| 334 |
+
#self.relative_position_bias_table_to_windows.append(relative_position_bias_table_to_windows)
|
| 335 |
+
self.register_parameter('relative_position_bias_table_to_windows_{}'.format(k),relative_position_bias_table_to_windows)
|
| 336 |
+
|
| 337 |
+
# define relative position bias index
|
| 338 |
+
relative_position_index_k = get_relative_position_index(self.window_size, to_2tuple(self.focal_window + 2**k - 1))
|
| 339 |
+
# relative_position_index_k = get_relative_position_index3d(self.window_size, to_2tuple(self.focal_window + 2**k - 1), num_clips)
|
| 340 |
+
self.register_buffer("relative_position_index_{}".format(k), relative_position_index_k)
|
| 341 |
+
|
| 342 |
+
# define unfolding index for focal_level > 0
|
| 343 |
+
if k > 0:
|
| 344 |
+
mask = torch.zeros(kernel_size, kernel_size); mask[(2**k)-1:, (2**k)-1:] = 1
|
| 345 |
+
self.register_buffer("valid_ind_unfold_{}".format(k), mask.flatten(0).nonzero().view(-1))
|
| 346 |
+
|
| 347 |
+
for k in range(len(focal_l_clips)):
|
| 348 |
+
# kernel_size=focal_kernel_clips[k]
|
| 349 |
+
focal_l_big_flag=False
|
| 350 |
+
if focal_l_clips[k]>self.window_size[0]:
|
| 351 |
+
stride=1
|
| 352 |
+
padding=0
|
| 353 |
+
kernel_size=focal_kernel_clips[k]
|
| 354 |
+
kernel_size_true=kernel_size
|
| 355 |
+
focal_l_big_flag=True
|
| 356 |
+
# stride=math.ceil(self.window_size/focal_l_clips[k])
|
| 357 |
+
# padding=(kernel_size-stride)/2
|
| 358 |
+
else:
|
| 359 |
+
stride = focal_l_clips[k]
|
| 360 |
+
# kernel_size
|
| 361 |
+
# kernel_size = 2*(focal_kernel_clips[k]// 2) + 2**focal_l_clips[k] + (2**focal_l_clips[k]-1)
|
| 362 |
+
kernel_size = focal_kernel_clips[k]  ## kernel_size must be odd
|
| 363 |
+
assert kernel_size%2==1
|
| 364 |
+
padding=kernel_size // 2
|
| 365 |
+
# kernel_size_true=focal_kernel_clips[k]+2**focal_l_clips[k]-1
|
| 366 |
+
kernel_size_true=kernel_size
|
| 367 |
+
# stride=math.ceil(self.window_size/focal_l_clips[k])
|
| 368 |
+
|
| 369 |
+
self.unfolds_clips += [nn.Unfold(
|
| 370 |
+
kernel_size=(kernel_size, kernel_size),
|
| 371 |
+
stride=stride,
|
| 372 |
+
padding=padding)
|
| 373 |
+
]
|
| 374 |
+
relative_position_bias_table_to_windows = nn.Parameter(
|
| 375 |
+
torch.zeros(
|
| 376 |
+
self.num_heads,
|
| 377 |
+
(self.window_size[0] + kernel_size_true - 1) * (self.window_size[0] + kernel_size_true - 1),
|
| 378 |
+
)
|
| 379 |
+
)
|
| 380 |
+
trunc_normal_(relative_position_bias_table_to_windows, std=.02)
|
| 381 |
+
#self.relative_position_bias_table_to_windows_clips.append(relative_position_bias_table_to_windows)
|
| 382 |
+
self.register_parameter('relative_position_bias_table_to_windows_clips_{}'.format(k),relative_position_bias_table_to_windows)
|
| 383 |
+
relative_position_index_k = get_relative_position_index(self.window_size, to_2tuple(kernel_size_true))
|
| 384 |
+
self.register_buffer("relative_position_index_clips_{}".format(k), relative_position_index_k)
|
| 385 |
+
# if (not focal_l_big_flag) and focal_l_clips[k]>0:
|
| 386 |
+
# mask = torch.zeros(kernel_size, kernel_size); mask[(2**focal_l_clips[k])-1:, (2**focal_l_clips[k])-1:] = 1
|
| 387 |
+
# self.register_buffer("valid_ind_unfold_clips_{}".format(k), mask.flatten(0).nonzero().view(-1))
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 392 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 393 |
+
self.proj = nn.Linear(dim, dim)
|
| 394 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 395 |
+
|
| 396 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 397 |
+
self.focal_l_clips=focal_l_clips
|
| 398 |
+
self.focal_kernel_clips=focal_kernel_clips
|
| 399 |
+
|
| 400 |
+
def forward(self, x_all, mask_all=None, batch_size=None, num_clips=None):
|
| 401 |
+
"""
|
| 402 |
+
Args:
|
| 403 |
+
x_all (list[Tensors]): input features at different granularity
|
| 404 |
+
mask_all (list[Tensors/None]): masks for input features at different granularity
|
| 405 |
+
"""
|
| 406 |
+
x = x_all[0][0] #
|
| 407 |
+
|
| 408 |
+
B0, nH, nW, C = x.shape
|
| 409 |
+
# assert B==batch_size*num_clips
|
| 410 |
+
assert B0==batch_size
|
| 411 |
+
qkv = self.qkv(x).reshape(B0, nH, nW, 3, C).permute(3, 0, 1, 2, 4).contiguous()
|
| 412 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B0, nH, nW, C
|
| 413 |
+
|
| 414 |
+
# partition q map
|
| 415 |
+
# print("x.shape: ", x.shape)
|
| 416 |
+
# print("q.shape: ", q.shape) # [4, 126, 126, 256]
|
| 417 |
+
(q_windows, k_windows, v_windows) = map(
|
| 418 |
+
lambda t: window_partition(t, self.window_size[0]).view(
|
| 419 |
+
-1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads
|
| 420 |
+
).transpose(1, 2),
|
| 421 |
+
(q, k, v)
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
# q_dim0, q_dim1, q_dim2, q_dim3=q_windows.shape
|
| 425 |
+
# q_windows=q_windows.view(batch_size, num_clips, (nH//self.window_size[0])*(nW//self.window_size[1]), q_dim1, q_dim2, q_dim3)
|
| 426 |
+
# q_windows=q_windows[:,-1].contiguous().view(-1, q_dim1, q_dim2, q_dim3) # query for the last frame (target frame)
|
| 427 |
+
|
| 428 |
+
# k_windows.shape [1296, 8, 49, 32]
|
| 429 |
+
|
| 430 |
+
if self.expand_size > 0 and self.focal_level > 0:
|
| 431 |
+
(k_tl, v_tl) = map(
|
| 432 |
+
lambda t: torch.roll(t, shifts=(-self.expand_size, -self.expand_size), dims=(1, 2)), (k, v)
|
| 433 |
+
)
|
| 434 |
+
(k_tr, v_tr) = map(
|
| 435 |
+
lambda t: torch.roll(t, shifts=(-self.expand_size, self.expand_size), dims=(1, 2)), (k, v)
|
| 436 |
+
)
|
| 437 |
+
(k_bl, v_bl) = map(
|
| 438 |
+
lambda t: torch.roll(t, shifts=(self.expand_size, -self.expand_size), dims=(1, 2)), (k, v)
|
| 439 |
+
)
|
| 440 |
+
(k_br, v_br) = map(
|
| 441 |
+
lambda t: torch.roll(t, shifts=(self.expand_size, self.expand_size), dims=(1, 2)), (k, v)
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
(k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
|
| 445 |
+
lambda t: window_partition(t, self.window_size[0]).view(-1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads),
|
| 446 |
+
(k_tl, k_tr, k_bl, k_br)
|
| 447 |
+
)
|
| 448 |
+
(v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
|
| 449 |
+
lambda t: window_partition(t, self.window_size[0]).view(-1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads),
|
| 450 |
+
(v_tl, v_tr, v_bl, v_br)
|
| 451 |
+
)
|
| 452 |
+
k_rolled = torch.cat((k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows), 1).transpose(1, 2)
|
| 453 |
+
v_rolled = torch.cat((v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows), 1).transpose(1, 2)
|
| 454 |
+
|
| 455 |
+
# mask out tokens in current window
|
| 456 |
+
# print("self.valid_ind_rolled.shape: ", self.valid_ind_rolled.shape) # [132]
|
| 457 |
+
# print("k_rolled.shape: ", k_rolled.shape) # [1296, 8, 196, 32]
|
| 458 |
+
k_rolled = k_rolled[:, :, self.valid_ind_rolled]
|
| 459 |
+
v_rolled = v_rolled[:, :, self.valid_ind_rolled]
|
| 460 |
+
k_rolled = torch.cat((k_windows, k_rolled), 2)
|
| 461 |
+
v_rolled = torch.cat((v_windows, v_rolled), 2)
|
| 462 |
+
else:
|
| 463 |
+
k_rolled = k_windows; v_rolled = v_windows;
|
| 464 |
+
|
| 465 |
+
# print("k_rolled.shape: ", k_rolled.shape) # [1296, 8, 181, 32]
|
| 466 |
+
|
| 467 |
+
if self.pool_method != "none" and self.focal_level > 1:
|
| 468 |
+
k_pooled = []
|
| 469 |
+
v_pooled = []
|
| 470 |
+
for k in range(self.focal_level-1):
|
| 471 |
+
stride = 2**k
|
| 472 |
+
x_window_pooled = x_all[0][k+1] # B0, nWh, nWw, C
|
| 473 |
+
nWh, nWw = x_window_pooled.shape[1:3]
|
| 474 |
+
|
| 475 |
+
# generate mask for pooled windows
|
| 476 |
+
# print("x_window_pooled.shape: ", x_window_pooled.shape)
|
| 477 |
+
mask = x_window_pooled.new(nWh, nWw).fill_(1)
|
| 478 |
+
# print("here: ",x_window_pooled.shape, self.unfolds[k].kernel_size, self.unfolds[k](mask.unsqueeze(0).unsqueeze(1)).shape)
|
| 479 |
+
# print(mask.unique())
|
| 480 |
+
unfolded_mask = self.unfolds[k](mask.unsqueeze(0).unsqueeze(1)).view(
|
| 481 |
+
1, 1, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
|
| 482 |
+
view(nWh*nWw // stride // stride, -1, 1)
|
| 483 |
+
|
| 484 |
+
if k > 0:
|
| 485 |
+
valid_ind_unfold_k = getattr(self, "valid_ind_unfold_{}".format(k))
|
| 486 |
+
unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]
|
| 487 |
+
|
| 488 |
+
# print("unfolded_mask.shape: ", unfolded_mask.shape, unfolded_mask.unique())
|
| 489 |
+
x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
|
| 490 |
+
# print((x_window_masks == 0).sum(), (x_window_masks > 0).sum(), x_window_masks.unique())
|
| 491 |
+
x_window_masks = x_window_masks.masked_fill(x_window_masks == 0, float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
|
| 492 |
+
# print(x_window_masks.shape)
|
| 493 |
+
mask_all[0][k+1] = x_window_masks
|
| 494 |
+
|
| 495 |
+
# generate k and v for pooled windows
|
| 496 |
+
qkv_pooled = self.qkv(x_window_pooled).reshape(B0, nWh, nWw, 3, C).permute(3, 0, 4, 1, 2).contiguous()
|
| 497 |
+
k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2] # B0, C, nWh, nWw
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
(k_pooled_k, v_pooled_k) = map(
|
| 501 |
+
lambda t: self.unfolds[k](t).view(
|
| 502 |
+
B0, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
|
| 503 |
+
view(-1, self.unfolds[k].kernel_size[0]*self.unfolds[k].kernel_size[1], self.num_heads, C // self.num_heads).transpose(1, 2),
|
| 504 |
+
(k_pooled_k, v_pooled_k) # (B0 x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim
|
| 505 |
+
)
|
| 506 |
+
|
| 507 |
+
# print("k_pooled_k.shape: ", k_pooled_k.shape)
|
| 508 |
+
# print("valid_ind_unfold_k.shape: ", valid_ind_unfold_k.shape)
|
| 509 |
+
|
| 510 |
+
if k > 0:
|
| 511 |
+
(k_pooled_k, v_pooled_k) = map(
|
| 512 |
+
lambda t: t[:, :, valid_ind_unfold_k], (k_pooled_k, v_pooled_k)
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
# print("k_pooled_k.shape: ", k_pooled_k.shape)
|
| 516 |
+
|
| 517 |
+
k_pooled += [k_pooled_k]
|
| 518 |
+
v_pooled += [v_pooled_k]
|
| 519 |
+
|
| 520 |
+
for k in range(len(self.focal_l_clips)):
|
| 521 |
+
focal_l_big_flag=False
|
| 522 |
+
if self.focal_l_clips[k]>self.window_size[0]:
|
| 523 |
+
stride=1
|
| 524 |
+
focal_l_big_flag=True
|
| 525 |
+
else:
|
| 526 |
+
stride = self.focal_l_clips[k]
|
| 527 |
+
# if self.window_size>=focal_l_clips[k]:
|
| 528 |
+
# stride=math.ceil(self.window_size/focal_l_clips[k])
|
| 529 |
+
# # padding=(kernel_size-stride)/2
|
| 530 |
+
# else:
|
| 531 |
+
# stride=1
|
| 532 |
+
# padding=0
|
| 533 |
+
x_window_pooled = x_all[k+1]
|
| 534 |
+
nWh, nWw = x_window_pooled.shape[1:3]
|
| 535 |
+
mask = x_window_pooled.new(nWh, nWw).fill_(1)
|
| 536 |
+
|
| 537 |
+
# import pdb; pdb.set_trace()
|
| 538 |
+
# print(x_window_pooled.shape, self.unfolds_clips[k].kernel_size, self.unfolds_clips[k](mask.unsqueeze(0).unsqueeze(1)).shape)
|
| 539 |
+
|
| 540 |
+
unfolded_mask = self.unfolds_clips[k](mask.unsqueeze(0).unsqueeze(1)).view(
|
| 541 |
+
1, 1, self.unfolds_clips[k].kernel_size[0], self.unfolds_clips[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
|
| 542 |
+
view(nWh*nWw // stride // stride, -1, 1)
|
| 543 |
+
|
| 544 |
+
# if (not focal_l_big_flag) and self.focal_l_clips[k]>0:
|
| 545 |
+
# valid_ind_unfold_k = getattr(self, "valid_ind_unfold_clips_{}".format(k))
|
| 546 |
+
# unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]
|
| 547 |
+
|
| 548 |
+
# print("unfolded_mask.shape: ", unfolded_mask.shape, unfolded_mask.unique())
|
| 549 |
+
x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
|
| 550 |
+
# print((x_window_masks == 0).sum(), (x_window_masks > 0).sum(), x_window_masks.unique())
|
| 551 |
+
x_window_masks = x_window_masks.masked_fill(x_window_masks == 0, float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
|
| 552 |
+
# print(x_window_masks.shape)
|
| 553 |
+
mask_all[k+1] = x_window_masks
|
| 554 |
+
|
| 555 |
+
# generate k and v for pooled windows
|
| 556 |
+
qkv_pooled = self.qkv(x_window_pooled).reshape(B0, nWh, nWw, 3, C).permute(3, 0, 4, 1, 2).contiguous()
|
| 557 |
+
k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2] # B0, C, nWh, nWw
|
| 558 |
+
|
| 559 |
+
if (not focal_l_big_flag):
|
| 560 |
+
(k_pooled_k, v_pooled_k) = map(
|
| 561 |
+
lambda t: self.unfolds_clips[k](t).view(
|
| 562 |
+
B0, C, self.unfolds_clips[k].kernel_size[0], self.unfolds_clips[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
|
| 563 |
+
view(-1, self.unfolds_clips[k].kernel_size[0]*self.unfolds_clips[k].kernel_size[1], self.num_heads, C // self.num_heads).transpose(1, 2),
|
| 564 |
+
(k_pooled_k, v_pooled_k) # (B0 x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim
|
| 565 |
+
)
|
| 566 |
+
else:
|
| 567 |
+
|
| 568 |
+
(k_pooled_k, v_pooled_k) = map(
|
| 569 |
+
lambda t: self.unfolds_clips[k](t),
|
| 570 |
+
(k_pooled_k, v_pooled_k) # (B0 x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim
|
| 571 |
+
)
|
| 572 |
+
LLL=k_pooled_k.size(2)
|
| 573 |
+
LLL_h=int(LLL**0.5)
|
| 574 |
+
assert LLL_h**2==LLL
|
| 575 |
+
k_pooled_k=k_pooled_k.reshape(B0, -1, LLL_h, LLL_h)
|
| 576 |
+
v_pooled_k=v_pooled_k.reshape(B0, -1, LLL_h, LLL_h)
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
# print("k_pooled_k.shape: ", k_pooled_k.shape)
|
| 581 |
+
# print("valid_ind_unfold_k.shape: ", valid_ind_unfold_k.shape)
|
| 582 |
+
# if (not focal_l_big_flag) and self.focal_l_clips[k]:
|
| 583 |
+
# (k_pooled_k, v_pooled_k) = map(
|
| 584 |
+
# lambda t: t[:, :, valid_ind_unfold_k], (k_pooled_k, v_pooled_k)
|
| 585 |
+
# )
|
| 586 |
+
|
| 587 |
+
# print("k_pooled_k.shape: ", k_pooled_k.shape)
|
| 588 |
+
|
| 589 |
+
k_pooled += [k_pooled_k]
|
| 590 |
+
v_pooled += [v_pooled_k]
|
| 591 |
+
|
| 592 |
+
# qkv_pooled = self.qkv(x_window_pooled).reshape(B0, nWh, nWw, 3, C).permute(3, 0, 4, 1, 2).contiguous()
|
| 593 |
+
# k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2] # B0, C, nWh, nWw
|
| 594 |
+
# (k_pooled_k, v_pooled_k) = map(
|
| 595 |
+
# lambda t: self.unfolds[k](t).view(
|
| 596 |
+
# B0, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
|
| 597 |
+
# view(-1, self.unfolds[k].kernel_size[0]*self.unfolds[k].kernel_size[1], self.num_heads, C // self.num_heads).transpose(1, 2),
|
| 598 |
+
# (k_pooled_k, v_pooled_k) # (B0 x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim
|
| 599 |
+
# )
|
| 600 |
+
# k_pooled += [k_pooled_k]
|
| 601 |
+
# v_pooled += [v_pooled_k]
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
k_all = torch.cat([k_rolled] + k_pooled, 2)
|
| 605 |
+
v_all = torch.cat([v_rolled] + v_pooled, 2)
|
| 606 |
+
else:
|
| 607 |
+
k_all = k_rolled
|
| 608 |
+
v_all = v_rolled
|
| 609 |
+
|
| 610 |
+
N = k_all.shape[-2]
|
| 611 |
+
q_windows = q_windows * self.scale
|
| 612 |
+
# print(q_windows.shape, k_all.shape, v_all.shape)
|
| 613 |
+
# exit()
|
| 614 |
+
# k_all_dim0, k_all_dim1, k_all_dim2, k_all_dim3=k_all.shape
|
| 615 |
+
# k_all=k_all.contiguous().view(batch_size, num_clips, (nH//self.window_size[0])*(nW//self.window_size[1]),
|
| 616 |
+
# k_all_dim1, k_all_dim2, k_all_dim3).permute(0,2,3,4,1,5).contiguous().view(-1, k_all_dim1, k_all_dim2*num_clips, k_all_dim3)
|
| 617 |
+
# v_all=v_all.contiguous().view(batch_size, num_clips, (nH//self.window_size[0])*(nW//self.window_size[1]),
|
| 618 |
+
# k_all_dim1, k_all_dim2, k_all_dim3).permute(0,2,3,4,1,5).contiguous().view(-1, k_all_dim1, k_all_dim2*num_clips, k_all_dim3)
|
| 619 |
+
|
| 620 |
+
# print(q_windows.shape, k_all.shape, v_all.shape, k_rolled.shape)
|
| 621 |
+
# exit()
|
| 622 |
+
attn = (q_windows @ k_all.transpose(-2, -1)) # B0*nW, nHead, window_size*window_size, focal_window_size*focal_window_size
|
| 623 |
+
|
| 624 |
+
window_area = self.window_size[0] * self.window_size[1]
|
| 625 |
+
# window_area_clips= num_clips*self.window_size[0] * self.window_size[1]
|
| 626 |
+
window_area_rolled = k_rolled.shape[2]
|
| 627 |
+
|
| 628 |
+
# add relative position bias for tokens inside window
|
| 629 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
| 630 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
| 631 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
| 632 |
+
# print(relative_position_bias.shape, attn.shape)
|
| 633 |
+
attn[:, :, :window_area, :window_area] = attn[:, :, :window_area, :window_area] + relative_position_bias.unsqueeze(0)
|
| 634 |
+
|
| 635 |
+
# relative_position_bias = self.relative_position_bias_table[self.relative_position_index[-window_area:, :window_area_clips].reshape(-1)].view(
|
| 636 |
+
# window_area, window_area_clips, -1) # Wh*Ww,Wd*Wh*Ww,nH
|
| 637 |
+
# relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().view(self.num_heads,window_area,num_clips,window_area
|
| 638 |
+
# ).permute(0,1,3,2).contiguous().view(self.num_heads,window_area,window_area_clips).contiguous() # nH, Wh*Ww, Wh*Ww*Wd
|
| 639 |
+
# # attn_dim0, attn_dim1, attn_dim2, attn_dim3=attn.shape
|
| 640 |
+
# # attn=attn.view(attn_dim0,attn_dim1,attn_dim2,num_clips,-1)
|
| 641 |
+
# # print(attn.shape, relative_position_bias.shape)
|
| 642 |
+
# attn[:,:,:window_area, :window_area_clips]=attn[:,:,:window_area, :window_area_clips] + relative_position_bias.unsqueeze(0)
|
| 643 |
+
# attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N
|
| 644 |
+
|
| 645 |
+
# add relative position bias for patches inside a window
|
| 646 |
+
if self.expand_size > 0 and self.focal_level > 0:
|
| 647 |
+
attn[:, :, :window_area, window_area:window_area_rolled] = attn[:, :, :window_area, window_area:window_area_rolled] + self.relative_position_bias_table_to_neighbors
|
| 648 |
+
|
| 649 |
+
if self.pool_method != "none" and self.focal_level > 1:
|
| 650 |
+
# add relative position bias for different windows in an image
|
| 651 |
+
offset = window_area_rolled
|
| 652 |
+
# print(offset)
|
| 653 |
+
for k in range(self.focal_level-1):
|
| 654 |
+
# add relative position bias
|
| 655 |
+
relative_position_index_k = getattr(self, 'relative_position_index_{}'.format(k))
|
| 656 |
+
relative_position_bias_to_windows = getattr(self,'relative_position_bias_table_to_windows_{}'.format(k))[:, relative_position_index_k.view(-1)].view(
|
| 657 |
+
-1, self.window_size[0] * self.window_size[1], (self.focal_window+2**k-1)**2,
|
| 658 |
+
) # nH, NWh*NWw,focal_region*focal_region
|
| 659 |
+
attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] = \
|
| 660 |
+
attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] + relative_position_bias_to_windows.unsqueeze(0)
|
| 661 |
+
# add attentional mask
|
| 662 |
+
if mask_all[0][k+1] is not None:
|
| 663 |
+
attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] = \
|
| 664 |
+
attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] + \
|
| 665 |
+
mask_all[0][k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[0][k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[0][k+1].shape[-1])
|
| 666 |
+
|
| 667 |
+
offset += (self.focal_window+2**k-1)**2
|
| 668 |
+
# print(offset)
|
| 669 |
+
for k in range(len(self.focal_l_clips)):
|
| 670 |
+
focal_l_big_flag=False
|
| 671 |
+
if self.focal_l_clips[k]>self.window_size[0]:
|
| 672 |
+
stride=1
|
| 673 |
+
padding=0
|
| 674 |
+
kernel_size=self.focal_kernel_clips[k]
|
| 675 |
+
kernel_size_true=kernel_size
|
| 676 |
+
focal_l_big_flag=True
|
| 677 |
+
# stride=math.ceil(self.window_size/focal_l_clips[k])
|
| 678 |
+
# padding=(kernel_size-stride)/2
|
| 679 |
+
else:
|
| 680 |
+
stride = self.focal_l_clips[k]
|
| 681 |
+
# kernel_size
|
| 682 |
+
# kernel_size = 2*(self.focal_kernel_clips[k]// 2) + 2**self.focal_l_clips[k] + (2**self.focal_l_clips[k]-1)
|
| 683 |
+
kernel_size = self.focal_kernel_clips[k]
|
| 684 |
+
padding=kernel_size // 2
|
| 685 |
+
# kernel_size_true=self.focal_kernel_clips[k]+2**self.focal_l_clips[k]-1
|
| 686 |
+
kernel_size_true=kernel_size
|
| 687 |
+
relative_position_index_k = getattr(self, 'relative_position_index_clips_{}'.format(k))
|
| 688 |
+
relative_position_bias_to_windows = getattr(self,'relative_position_bias_table_to_windows_clips_{}'.format(k))[:, relative_position_index_k.view(-1)].view(
|
| 689 |
+
-1, self.window_size[0] * self.window_size[1], (kernel_size_true)**2,
|
| 690 |
+
)
|
| 691 |
+
attn[:, :, :window_area, offset:(offset + (kernel_size_true)**2)] = \
|
| 692 |
+
attn[:, :, :window_area, offset:(offset + (kernel_size_true)**2)] + relative_position_bias_to_windows.unsqueeze(0)
|
| 693 |
+
if mask_all[k+1] is not None:
|
| 694 |
+
attn[:, :, :window_area, offset:(offset + (kernel_size_true)**2)] = \
|
| 695 |
+
attn[:, :, :window_area, offset:(offset + (kernel_size_true)**2)] + \
|
| 696 |
+
mask_all[k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])
|
| 697 |
+
offset += (kernel_size_true)**2
|
| 698 |
+
# print(offset)
|
| 699 |
+
# relative_position_index_k = getattr(self, 'relative_position_index_{}'.format(k))
|
| 700 |
+
# # relative_position_bias_to_windows = self.relative_position_bias_table_to_windows[k][:, relative_position_index_k.view(-1)].view(
|
| 701 |
+
# # -1, self.window_size[0] * self.window_size[1], (self.focal_window+2**k-1)**2,
|
| 702 |
+
# # ) # nH, NWh*NWw,focal_region*focal_region
|
| 703 |
+
# # attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] = \
|
| 704 |
+
# # attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] + relative_position_bias_to_windows.unsqueeze(0)
|
| 705 |
+
# relative_position_bias_to_windows = self.relative_position_bias_table_to_windows[k][:, relative_position_index_k[-window_area:, :].view(-1)].view(
|
| 706 |
+
# -1, self.window_size[0] * self.window_size[1], num_clips*(self.focal_window+2**k-1)**2,
|
| 707 |
+
# ).contiguous() # nH, NWh*NWw, num_clips*focal_region*focal_region
|
| 708 |
+
# relative_position_bias_to_windows = relative_position_bias_to_windows.view(self.num_heads,
|
| 709 |
+
# window_area,num_clips,-1).permute(0,1,3,2).contiguous().view(self.num_heads,window_area,-1)
|
| 710 |
+
# attn[:, :, :window_area, offset:(offset + num_clips*(self.focal_window+2**k-1)**2)] = \
|
| 711 |
+
# attn[:, :, :window_area, offset:(offset + num_clips*(self.focal_window+2**k-1)**2)] + relative_position_bias_to_windows.unsqueeze(0)
|
| 712 |
+
# # add attentional mask
|
| 713 |
+
# if mask_all[k+1] is not None:
|
| 714 |
+
# # print("inside the mask, be careful 1")
|
| 715 |
+
# # attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] = \
|
| 716 |
+
# # attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] + \
|
| 717 |
+
# # mask_all[k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])
|
| 718 |
+
# # print("here: ", mask_all[k+1].shape, mask_all[k+1][:, :, None, None, :].shape)
|
| 719 |
+
|
| 720 |
+
# attn[:, :, :window_area, offset:(offset + num_clips*(self.focal_window+2**k-1)**2)] = \
|
| 721 |
+
# attn[:, :, :window_area, offset:(offset + num_clips*(self.focal_window+2**k-1)**2)] + \
|
| 722 |
+
# mask_all[k+1][:, :, None, None, :,None].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1, num_clips).view(-1, 1, 1, mask_all[k+1].shape[-1]*num_clips)
|
| 723 |
+
# # print()
|
| 724 |
+
|
| 725 |
+
# offset += (self.focal_window+2**k-1)**2
|
| 726 |
+
|
| 727 |
+
# print("mask_all[0]: ", mask_all[0])
|
| 728 |
+
# exit()
|
| 729 |
+
if mask_all[0][0] is not None:
|
| 730 |
+
print("inside the mask, be careful 0")
|
| 731 |
+
nW = mask_all[0].shape[0]
|
| 732 |
+
attn = attn.view(attn.shape[0] // nW, nW, self.num_heads, window_area, N)
|
| 733 |
+
attn[:, :, :, :, :window_area] = attn[:, :, :, :, :window_area] + mask_all[0][None, :, None, :, :]
|
| 734 |
+
attn = attn.view(-1, self.num_heads, window_area, N)
|
| 735 |
+
attn = self.softmax(attn)
|
| 736 |
+
else:
|
| 737 |
+
attn = self.softmax(attn)
|
| 738 |
+
|
| 739 |
+
attn = self.attn_drop(attn)
|
| 740 |
+
|
| 741 |
+
x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area, C)
|
| 742 |
+
x = self.proj(x)
|
| 743 |
+
x = self.proj_drop(x)
|
| 744 |
+
# print(x.shape)
|
| 745 |
+
# x = x.view(B/num_clips, nH, nW, C )
|
| 746 |
+
# exit()
|
| 747 |
+
return x
|
| 748 |
+
|
| 749 |
+
def extra_repr(self) -> str:
|
| 750 |
+
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
|
| 751 |
+
|
| 752 |
+
def flops(self, N, window_size, unfold_size):
|
| 753 |
+
# calculate flops for 1 window with token length of N
|
| 754 |
+
flops = 0
|
| 755 |
+
# qkv = self.qkv(x)
|
| 756 |
+
flops += N * self.dim * 3 * self.dim
|
| 757 |
+
# attn = (q @ k.transpose(-2, -1))
|
| 758 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
| 759 |
+
if self.pool_method != "none" and self.focal_level > 1:
|
| 760 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * (unfold_size * unfold_size)
|
| 761 |
+
if self.expand_size > 0 and self.focal_level > 0:
|
| 762 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * ((window_size + 2*self.expand_size)**2-window_size**2)
|
| 763 |
+
|
| 764 |
+
# x = (attn @ v)
|
| 765 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
| 766 |
+
if self.pool_method != "none" and self.focal_level > 1:
|
| 767 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * (unfold_size * unfold_size)
|
| 768 |
+
if self.expand_size > 0 and self.focal_level > 0:
|
| 769 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * ((window_size + 2*self.expand_size)**2-window_size**2)
|
| 770 |
+
|
| 771 |
+
# x = self.proj(x)
|
| 772 |
+
flops += N * self.dim * self.dim
|
| 773 |
+
return flops
|
| 774 |
+
|
| 775 |
+
|
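A hedged construction sketch, not part of the commit: it assumes the helpers imported elsewhere in this file (timm's trunc_normal_ and to_2tuple, torch.nn) are in scope, and the hyper-parameters are illustrative rather than values prescribed by SpaTrackV2.

attn = WindowAttention3d3(
    dim=256, expand_size=3, window_size=(7, 7),
    focal_window=5, focal_level=2, num_heads=8,
    pool_method="fc",
    focal_l_clips=[7, 2, 4], focal_kernel_clips=[7, 5, 3],
)
# forward() expects x_all = [[target-frame map, pooled focal levels], pooled per-clip maps ...]
# and is normally driven by CffmTransformerBlock3d3.forward below rather than called directly.
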
| 776 |
+
class CffmTransformerBlock3d3(nn.Module):
|
| 777 |
+
r""" Focal Transformer Block.
|
| 778 |
+
|
| 779 |
+
Args:
|
| 780 |
+
dim (int): Number of input channels.
|
| 781 |
+
input_resolution (tuple[int]): Input resolution.
|
| 782 |
+
num_heads (int): Number of attention heads.
|
| 783 |
+
window_size (int): Window size.
|
| 784 |
+
expand_size (int): expand size at first focal level (finest level).
|
| 785 |
+
shift_size (int): Shift size for SW-MSA.
|
| 786 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 787 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 788 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
| 789 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 790 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
| 791 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
| 792 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
| 793 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 794 |
+
pool_method (str): window pooling method. Default: none, options: [none|fc|conv]
|
| 795 |
+
focal_level (int): number of focal levels. Default: 1.
|
| 796 |
+
focal_window (int): region size of focal attention. Default: 1
|
| 797 |
+
use_layerscale (bool): whether use layer scale for training stability. Default: False
|
| 798 |
+
layerscale_value (float): scaling value for layer scale. Default: 1e-4
|
| 799 |
+
"""
|
| 800 |
+
|
| 801 |
+
def __init__(self, dim, input_resolution, num_heads, window_size=7, expand_size=0, shift_size=0,
|
| 802 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
| 803 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm, pool_method="none",
|
| 804 |
+
focal_level=1, focal_window=1, use_layerscale=False, layerscale_value=1e-4, focal_l_clips=[7,2,4], focal_kernel_clips=[7,5,3]):
|
| 805 |
+
super().__init__()
|
| 806 |
+
self.dim = dim
|
| 807 |
+
self.input_resolution = input_resolution
|
| 808 |
+
self.num_heads = num_heads
|
| 809 |
+
self.window_size = window_size
|
| 810 |
+
self.shift_size = shift_size
|
| 811 |
+
self.expand_size = expand_size
|
| 812 |
+
self.mlp_ratio = mlp_ratio
|
| 813 |
+
self.pool_method = pool_method
|
| 814 |
+
self.focal_level = focal_level
|
| 815 |
+
self.focal_window = focal_window
|
| 816 |
+
self.use_layerscale = use_layerscale
|
| 817 |
+
self.focal_l_clips=focal_l_clips
|
| 818 |
+
self.focal_kernel_clips=focal_kernel_clips
|
| 819 |
+
|
| 820 |
+
if min(self.input_resolution) <= self.window_size:
|
| 821 |
+
# if window size is larger than input resolution, we don't partition windows
|
| 822 |
+
self.expand_size = 0
|
| 823 |
+
self.shift_size = 0
|
| 824 |
+
self.window_size = min(self.input_resolution)
|
| 825 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"
|
| 826 |
+
|
| 827 |
+
self.window_size_glo = self.window_size
|
| 828 |
+
|
| 829 |
+
self.pool_layers = nn.ModuleList()
|
| 830 |
+
self.pool_layers_clips = nn.ModuleList()
|
| 831 |
+
if self.pool_method != "none":
|
| 832 |
+
for k in range(self.focal_level-1):
|
| 833 |
+
window_size_glo = math.floor(self.window_size_glo / (2 ** k))
|
| 834 |
+
if self.pool_method == "fc":
|
| 835 |
+
self.pool_layers.append(nn.Linear(window_size_glo * window_size_glo, 1))
|
| 836 |
+
self.pool_layers[-1].weight.data.fill_(1./(window_size_glo * window_size_glo))
|
| 837 |
+
self.pool_layers[-1].bias.data.fill_(0)
|
| 838 |
+
elif self.pool_method == "conv":
|
| 839 |
+
self.pool_layers.append(nn.Conv2d(dim, dim, kernel_size=window_size_glo, stride=window_size_glo, groups=dim))
|
| 840 |
+
for k in range(len(focal_l_clips)):
|
| 841 |
+
# window_size_glo = math.floor(self.window_size_glo / (2 ** k))
|
| 842 |
+
if focal_l_clips[k]>self.window_size:
|
| 843 |
+
window_size_glo = focal_l_clips[k]
|
| 844 |
+
else:
|
| 845 |
+
window_size_glo = math.floor(self.window_size_glo / (focal_l_clips[k]))
|
| 846 |
+
# window_size_glo = focal_l_clips[k]
|
| 847 |
+
if self.pool_method == "fc":
|
| 848 |
+
self.pool_layers_clips.append(nn.Linear(window_size_glo * window_size_glo, 1))
|
| 849 |
+
self.pool_layers_clips[-1].weight.data.fill_(1./(window_size_glo * window_size_glo))
|
| 850 |
+
self.pool_layers_clips[-1].bias.data.fill_(0)
|
| 851 |
+
elif self.pool_method == "conv":
|
| 852 |
+
self.pool_layers_clips.append(nn.Conv2d(dim, dim, kernel_size=window_size_glo, stride=window_size_glo, groups=dim))
|
| 853 |
+
|
| 854 |
+
self.norm1 = norm_layer(dim)
|
| 855 |
+
|
| 856 |
+
self.attn = WindowAttention3d3(
|
| 857 |
+
dim, expand_size=self.expand_size, window_size=to_2tuple(self.window_size),
|
| 858 |
+
focal_window=focal_window, focal_level=focal_level, num_heads=num_heads,
|
| 859 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, pool_method=pool_method, focal_l_clips=focal_l_clips, focal_kernel_clips=focal_kernel_clips)
|
| 860 |
+
|
| 861 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 862 |
+
self.norm2 = norm_layer(dim)
|
| 863 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 864 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 865 |
+
|
| 866 |
+
# print("******self.shift_size: ", self.shift_size)
|
| 867 |
+
|
| 868 |
+
if self.shift_size > 0:
|
| 869 |
+
# calculate attention mask for SW-MSA
|
| 870 |
+
H, W = self.input_resolution
|
| 871 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
| 872 |
+
h_slices = (slice(0, -self.window_size),
|
| 873 |
+
slice(-self.window_size, -self.shift_size),
|
| 874 |
+
slice(-self.shift_size, None))
|
| 875 |
+
w_slices = (slice(0, -self.window_size),
|
| 876 |
+
slice(-self.window_size, -self.shift_size),
|
| 877 |
+
slice(-self.shift_size, None))
|
| 878 |
+
cnt = 0
|
| 879 |
+
for h in h_slices:
|
| 880 |
+
for w in w_slices:
|
| 881 |
+
img_mask[:, h, w, :] = cnt
|
| 882 |
+
cnt += 1
|
| 883 |
+
|
| 884 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
| 885 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
| 886 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
| 887 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
| 888 |
+
else:
|
| 889 |
+
# print("here mask none")
|
| 890 |
+
attn_mask = None
|
| 891 |
+
self.register_buffer("attn_mask", attn_mask)
|
| 892 |
+
|
| 893 |
+
if self.use_layerscale:
|
| 894 |
+
self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
|
| 895 |
+
self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
|
| 896 |
+
|
| 897 |
+
def forward(self, x):
|
| 898 |
+
H0, W0 = self.input_resolution
|
| 899 |
+
# B, L, C = x.shape
|
| 900 |
+
B0, D0, H0, W0, C = x.shape
|
| 901 |
+
shortcut = x
|
| 902 |
+
# assert L == H * W, "input feature has wrong size"
|
| 903 |
+
x=x.reshape(B0*D0,H0,W0,C).reshape(B0*D0,H0*W0,C)
|
| 904 |
+
|
| 905 |
+
|
| 906 |
+
x = self.norm1(x)
|
| 907 |
+
x = x.reshape(B0*D0, H0, W0, C)
|
| 908 |
+
# print("here")
|
| 909 |
+
# exit()
|
| 910 |
+
|
| 911 |
+
# pad feature maps to multiples of window size
|
| 912 |
+
pad_l = pad_t = 0
|
| 913 |
+
pad_r = (self.window_size - W0 % self.window_size) % self.window_size
|
| 914 |
+
pad_b = (self.window_size - H0 % self.window_size) % self.window_size
|
| 915 |
+
if pad_r > 0 or pad_b > 0:
|
| 916 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
| 917 |
+
|
| 918 |
+
B, H, W, C = x.shape ## B=B0*D0
|
| 919 |
+
|
| 920 |
+
if self.shift_size > 0:
|
| 921 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
| 922 |
+
else:
|
| 923 |
+
shifted_x = x
|
| 924 |
+
|
| 925 |
+
# print("shifted_x.shape: ", shifted_x.shape)
|
| 926 |
+
shifted_x=shifted_x.view(B0,D0,H,W,C)
|
| 927 |
+
x_windows_all = [shifted_x[:,-1]]
|
| 928 |
+
x_windows_all_clips=[]
|
| 929 |
+
x_window_masks_all = [self.attn_mask]
|
| 930 |
+
x_window_masks_all_clips=[]
|
| 931 |
+
|
| 932 |
+
if self.focal_level > 1 and self.pool_method != "none":
|
| 933 |
+
# if we add coarser granularity and the pool method is not none
|
| 934 |
+
# pooling_index=0
|
| 935 |
+
for k in range(self.focal_level-1):
|
| 936 |
+
window_size_glo = math.floor(self.window_size_glo / (2 ** k))
|
| 937 |
+
pooled_h = math.ceil(H / self.window_size) * (2 ** k)
|
| 938 |
+
pooled_w = math.ceil(W / self.window_size) * (2 ** k)
|
| 939 |
+
H_pool = pooled_h * window_size_glo
|
| 940 |
+
W_pool = pooled_w * window_size_glo
|
| 941 |
+
|
| 942 |
+
x_level_k = shifted_x[:,-1]
|
| 943 |
+
# trim or pad shifted_x depending on the required size
|
| 944 |
+
if H > H_pool:
|
| 945 |
+
trim_t = (H - H_pool) // 2
|
| 946 |
+
trim_b = H - H_pool - trim_t
|
| 947 |
+
x_level_k = x_level_k[:, trim_t:-trim_b]
|
| 948 |
+
elif H < H_pool:
|
| 949 |
+
pad_t = (H_pool - H) // 2
|
| 950 |
+
pad_b = H_pool - H - pad_t
|
| 951 |
+
x_level_k = F.pad(x_level_k, (0,0,0,0,pad_t,pad_b))
|
| 952 |
+
|
| 953 |
+
if W > W_pool:
|
| 954 |
+
trim_l = (W - W_pool) // 2
|
| 955 |
+
trim_r = W - W_pool - trim_l
|
| 956 |
+
x_level_k = x_level_k[:, :, trim_l:-trim_r]
|
| 957 |
+
elif W < W_pool:
|
| 958 |
+
pad_l = (W_pool - W) // 2
|
| 959 |
+
pad_r = W_pool - W - pad_l
|
| 960 |
+
x_level_k = F.pad(x_level_k, (0,0,pad_l,pad_r))
|
| 961 |
+
|
| 962 |
+
x_windows_noreshape = window_partition_noreshape(x_level_k.contiguous(), window_size_glo) # B0, nw, nw, window_size, window_size, C
|
| 963 |
+
nWh, nWw = x_windows_noreshape.shape[1:3]
|
| 964 |
+
if self.pool_method == "mean":
|
| 965 |
+
x_windows_pooled = x_windows_noreshape.mean([3, 4]) # B0, nWh, nWw, C
|
| 966 |
+
elif self.pool_method == "max":
|
| 967 |
+
x_windows_pooled = x_windows_noreshape.max(-2)[0].max(-2)[0].view(B0, nWh, nWw, C) # B0, nWh, nWw, C
|
| 968 |
+
elif self.pool_method == "fc":
|
| 969 |
+
x_windows_noreshape = x_windows_noreshape.view(B0, nWh, nWw, window_size_glo*window_size_glo, C).transpose(3, 4) # B0, nWh, nWw, C, wsize**2
|
| 970 |
+
x_windows_pooled = self.pool_layers[k](x_windows_noreshape).flatten(-2) # B0, nWh, nWw, C
|
| 971 |
+
elif self.pool_method == "conv":
|
| 972 |
+
x_windows_noreshape = x_windows_noreshape.view(-1, window_size_glo, window_size_glo, C).permute(0, 3, 1, 2).contiguous() # B0 * nw * nw, C, wsize, wsize
|
| 973 |
+
x_windows_pooled = self.pool_layers[k](x_windows_noreshape).view(B0, nWh, nWw, C) # B0, nWh, nWw, C
|
| 974 |
+
|
| 975 |
+
x_windows_all += [x_windows_pooled]
|
| 976 |
+
# print(x_windows_pooled.shape)
|
| 977 |
+
x_window_masks_all += [None]
|
| 978 |
+
# pooling_index=pooling_index+1
|
| 979 |
+
|
| 980 |
+
x_windows_all_clips += [x_windows_all]
|
| 981 |
+
x_window_masks_all_clips += [x_window_masks_all]
|
| 982 |
+
for k in range(len(self.focal_l_clips)):
|
| 983 |
+
if self.focal_l_clips[k]>self.window_size:
|
| 984 |
+
window_size_glo = self.focal_l_clips[k]
|
| 985 |
+
else:
|
| 986 |
+
window_size_glo = math.floor(self.window_size_glo / (self.focal_l_clips[k]))
|
| 987 |
+
|
| 988 |
+
pooled_h = math.ceil(H / self.window_size) * (self.focal_l_clips[k])
|
| 989 |
+
pooled_w = math.ceil(W / self.window_size) * (self.focal_l_clips[k])
|
| 990 |
+
|
| 991 |
+
H_pool = pooled_h * window_size_glo
|
| 992 |
+
W_pool = pooled_w * window_size_glo
|
| 993 |
+
|
| 994 |
+
x_level_k = shifted_x[:,k]
|
| 995 |
+
if H!=H_pool or W!=W_pool:
|
| 996 |
+
x_level_k=F.interpolate(x_level_k.permute(0,3,1,2), size=(H_pool, W_pool), mode='bilinear').permute(0,2,3,1)
|
| 997 |
+
|
| 998 |
+
# print(x_level_k.shape)
|
| 999 |
+
x_windows_noreshape = window_partition_noreshape(x_level_k.contiguous(), window_size_glo) # B0, nw, nw, window_size, window_size, C
|
| 1000 |
+
nWh, nWw = x_windows_noreshape.shape[1:3]
|
| 1001 |
+
if self.pool_method == "mean":
|
| 1002 |
+
x_windows_pooled = x_windows_noreshape.mean([3, 4]) # B0, nWh, nWw, C
|
| 1003 |
+
elif self.pool_method == "max":
|
| 1004 |
+
x_windows_pooled = x_windows_noreshape.max(-2)[0].max(-2)[0].view(B0, nWh, nWw, C) # B0, nWh, nWw, C
|
| 1005 |
+
elif self.pool_method == "fc":
|
| 1006 |
+
x_windows_noreshape = x_windows_noreshape.view(B0, nWh, nWw, window_size_glo*window_size_glo, C).transpose(3, 4) # B0, nWh, nWw, C, wsize**2
|
| 1007 |
+
x_windows_pooled = self.pool_layers_clips[k](x_windows_noreshape).flatten(-2) # B0, nWh, nWw, C
|
| 1008 |
+
elif self.pool_method == "conv":
|
| 1009 |
+
x_windows_noreshape = x_windows_noreshape.view(-1, window_size_glo, window_size_glo, C).permute(0, 3, 1, 2).contiguous() # B0 * nw * nw, C, wsize, wsize
|
| 1010 |
+
x_windows_pooled = self.pool_layers_clips[k](x_windows_noreshape).view(B0, nWh, nWw, C) # B0, nWh, nWw, C
|
| 1011 |
+
|
| 1012 |
+
x_windows_all_clips += [x_windows_pooled]
|
| 1013 |
+
# print(x_windows_pooled.shape)
|
| 1014 |
+
x_window_masks_all_clips += [None]
|
| 1015 |
+
# pooling_index=pooling_index+1
|
| 1016 |
+
# exit()
|
| 1017 |
+
|
| 1018 |
+
attn_windows = self.attn(x_windows_all_clips, mask_all=x_window_masks_all_clips, batch_size=B0, num_clips=D0) # nW*B0, window_size*window_size, C
|
| 1019 |
+
|
| 1020 |
+
attn_windows = attn_windows[:, :self.window_size ** 2]
|
| 1021 |
+
|
| 1022 |
+
# merge windows
|
| 1023 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
| 1024 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H(padded) W(padded) C
|
| 1025 |
+
|
| 1026 |
+
# reverse cyclic shift
|
| 1027 |
+
if self.shift_size > 0:
|
| 1028 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
| 1029 |
+
else:
|
| 1030 |
+
x = shifted_x
|
| 1031 |
+
# x = x[:, :self.input_resolution[0], :self.input_resolution[1]].contiguous().view(B, -1, C)
|
| 1032 |
+
x = x[:, :H0, :W0].contiguous().view(B0, -1, C)
|
| 1033 |
+
|
| 1034 |
+
# FFN
|
| 1035 |
+
# x = shortcut + self.drop_path(x if (not self.use_layerscale) else (self.gamma_1 * x))
|
| 1036 |
+
# x = x + self.drop_path(self.mlp(self.norm2(x)) if (not self.use_layerscale) else (self.gamma_2 * self.mlp(self.norm2(x))))
|
| 1037 |
+
|
| 1038 |
+
# print(x.shape, shortcut[:,-1].view(B0, -1, C).shape)
|
| 1039 |
+
x = shortcut[:,-1].view(B0, -1, C) + self.drop_path(x if (not self.use_layerscale) else (self.gamma_1 * x))
|
| 1040 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)) if (not self.use_layerscale) else (self.gamma_2 * self.mlp(self.norm2(x))))
|
| 1041 |
+
|
| 1042 |
+
# x=torch.cat([shortcut[:,:-1],x.view(B0,self.input_resolution[0],self.input_resolution[1],C).unsqueeze(1)],1)
|
| 1043 |
+
x=torch.cat([shortcut[:,:-1],x.view(B0,H0,W0,C).unsqueeze(1)],1)
|
| 1044 |
+
|
| 1045 |
+
assert x.shape==shortcut.shape
|
| 1046 |
+
|
| 1047 |
+
# exit()
|
| 1048 |
+
|
| 1049 |
+
return x
|
| 1050 |
+
|
| 1051 |
+
def extra_repr(self) -> str:
|
| 1052 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
| 1053 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
| 1054 |
+
|
| 1055 |
+
def flops(self):
|
| 1056 |
+
flops = 0
|
| 1057 |
+
H, W = self.input_resolution
|
| 1058 |
+
# norm1
|
| 1059 |
+
flops += self.dim * H * W
|
| 1060 |
+
|
| 1061 |
+
# W-MSA/SW-MSA
|
| 1062 |
+
nW = H * W / self.window_size / self.window_size
|
| 1063 |
+
flops += nW * self.attn.flops(self.window_size * self.window_size, self.window_size, self.focal_window)
|
| 1064 |
+
|
| 1065 |
+
if self.pool_method != "none" and self.focal_level > 1:
|
| 1066 |
+
for k in range(self.focal_level-1):
|
| 1067 |
+
window_size_glo = math.floor(self.window_size_glo / (2 ** k))
|
| 1068 |
+
nW_glo = nW * (2**k)
|
| 1069 |
+
# (sub)-window pooling
|
| 1070 |
+
flops += nW_glo * self.dim * window_size_glo * window_size_glo
|
| 1071 |
+
# qkv for global levels
|
| 1072 |
+
# NOTE: in our implementation, we pass the pooled window embedding to qkv embedding layer,
|
| 1073 |
+
# but theoretically, we only need to compute k and v.
|
| 1074 |
+
flops += nW_glo * self.dim * 3 * self.dim
|
| 1075 |
+
|
| 1076 |
+
# mlp
|
| 1077 |
+
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
| 1078 |
+
# norm2
|
| 1079 |
+
flops += self.dim * H * W
|
| 1080 |
+
return flops
|
| 1081 |
+
|
| 1082 |
+
|
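A minimal usage sketch, not part of the commit. It assumes the helpers imported at the top of this file (Mlp, DropPath, window_partition, window_partition_noreshape, window_reverse, trunc_normal_, to_2tuple) are available, and all sizes are illustrative; the last clip along dim 1 is treated as the target frame.

import torch

block = CffmTransformerBlock3d3(
    dim=256, input_resolution=(14, 14), num_heads=8,
    window_size=7, expand_size=3, shift_size=0,
    pool_method="fc", focal_level=2, focal_window=5,
    focal_l_clips=[7, 2, 4], focal_kernel_clips=[7, 5, 3],
)
clips = torch.randn(1, 4, 14, 14, 256)  # B, num_clips, H, W, C
out = block(clips)
print(out.shape)  # torch.Size([1, 4, 14, 14, 256]); only the target-frame slice is updated
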
| 1083 |
+
class BasicLayer3d3(nn.Module):
|
| 1084 |
+
""" A basic Focal Transformer layer for one stage.
|
| 1085 |
+
|
| 1086 |
+
Args:
|
| 1087 |
+
dim (int): Number of input channels.
|
| 1088 |
+
input_resolution (tuple[int]): Input resolution.
|
| 1089 |
+
depth (int): Number of blocks.
|
| 1090 |
+
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        expand_size (int): Expand size for focal level 1.
        expand_layer (str): Which layers use the expanded window. Default: "all".
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True.
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0.
        attn_drop (float, optional): Attention dropout rate. Default: 0.0.
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm.
        pool_method (str): Window pooling method. Default: "none".
        focal_level (int): Number of focal levels. Default: 1.
        focal_window (int): Region size at each focal level. Default: 1.
        use_conv_embed (bool): Whether to use an overlapped convolutional patch embedding layer. Default: False.
        use_shift (bool): Whether to use window shift as in Swin Transformer. Default: False.
        use_pre_norm (bool): Whether to apply pre-norm before the patch embedding projection for stability. Default: False.
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        use_layerscale (bool): Whether to use layer scale for stability. Default: False.
        layerscale_value (float): Layerscale value. Default: 1e-4.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size, expand_size, expand_layer="all",
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, pool_method="none",
                 focal_level=1, focal_window=1, use_conv_embed=False, use_shift=False, use_pre_norm=False,
                 downsample=None, use_checkpoint=False, use_layerscale=False, layerscale_value=1e-4,
                 focal_l_clips=[16, 8, 2], focal_kernel_clips=[7, 5, 3]):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # decide which blocks get the expanded window
        if expand_layer == "even":
            expand_factor = 0
        elif expand_layer == "odd":
            expand_factor = 1
        elif expand_layer == "all":
            expand_factor = -1

        # build blocks
        self.blocks = nn.ModuleList([
            CffmTransformerBlock3d3(dim=dim, input_resolution=input_resolution,
                                    num_heads=num_heads, window_size=window_size,
                                    shift_size=(0 if (i % 2 == 0) else window_size // 2) if use_shift else 0,
                                    expand_size=0 if (i % 2 == expand_factor) else expand_size,
                                    mlp_ratio=mlp_ratio,
                                    qkv_bias=qkv_bias, qk_scale=qk_scale,
                                    drop=drop,
                                    attn_drop=attn_drop,
                                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                    norm_layer=norm_layer,
                                    pool_method=pool_method,
                                    focal_level=focal_level,
                                    focal_window=focal_window,
                                    use_layerscale=use_layerscale,
                                    layerscale_value=layerscale_value,
                                    focal_l_clips=focal_l_clips,
                                    focal_kernel_clips=focal_kernel_clips)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                img_size=input_resolution, patch_size=2, in_chans=dim, embed_dim=2*dim,
                use_conv_embed=use_conv_embed, norm_layer=norm_layer, use_pre_norm=use_pre_norm,
                is_stem=False
            )
        else:
            self.downsample = None

    def forward(self, x, batch_size=None, num_clips=None, reg_tokens=None):
        B, D, C, H, W = x.shape
        x = rearrange(x, 'b d c h w -> b d h w c')
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        if self.downsample is not None:
            x = x.view(x.shape[0], self.input_resolution[0], self.input_resolution[1], -1).permute(0, 3, 1, 2).contiguous()
            x = self.downsample(x)
        x = rearrange(x, 'b d h w c -> b d c h w')
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops
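A minimal construction sketch for this layer (not part of the diff). It mirrors how stablizer.py below instantiates it as `decoder_focal`; the embedding dimension (256) and depth (2) are placeholder values, since the real ones come from `decoder_params` in the config.

    import torch.nn as nn

    layer = BasicLayer3d3(
        dim=256, input_resolution=(96, 96), depth=2, num_heads=8,
        window_size=7, expand_size=3, expand_layer="all",
        mlp_ratio=4., qkv_bias=True, qk_scale=None,
        drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm,
        pool_method='fc', focal_level=2, focal_window=5,
        downsample=None, use_checkpoint=False,
        focal_l_clips=[7, 4, 2], focal_kernel_clips=[7, 5, 3])
    # forward expects x of shape (B, D, C, H, W) with (H, W) == input_resolution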
models/SpaTrackV2/models/depth_refiner/stablizer.py
ADDED
@@ -0,0 +1,342 @@
import numpy as np
import torch.nn as nn
import torch
# from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from collections import OrderedDict
# from mmseg.ops import resize
from torch.nn.functional import interpolate as resize
# from builder import HEADS
from models.SpaTrackV2.models.depth_refiner.decode_head import BaseDecodeHead, BaseDecodeHead_clips, BaseDecodeHead_clips_flow
# from mmseg.models.utils import *
import attr
from IPython import embed
from models.SpaTrackV2.models.depth_refiner.stablilization_attention import BasicLayer3d3
import cv2
from models.SpaTrackV2.models.depth_refiner.network import *
import warnings
# from mmcv.utils import Registry, build_from_cfg
from torch import nn
from einops import rearrange
import torch.nn.functional as F
from models.SpaTrackV2.models.blocks import (
    AttnBlock, CrossAttnBlock, Mlp
)


class MLP(nn.Module):
    """
    Linear Embedding
    """
    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x

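Usage sketch for MLP (illustration only, not part of the file): it flattens a (B, C, H, W) feature map into (B, H*W, embed_dim) tokens.

    import torch

    feats = torch.randn(2, 512, 24, 24)                    # e.g. a backbone feature map
    tokens = MLP(input_dim=512, embed_dim=256)(feats)
    assert tokens.shape == (2, 24 * 24, 256)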
def scatter_multiscale_fast(
    track2d: torch.Tensor,
    trackfeature: torch.Tensor,
    H: int,
    W: int,
    kernel_sizes=[1]
) -> torch.Tensor:
    """
    Scatter sparse track features onto a dense image grid with weighted multi-scale pooling to handle zero-value gaps.

    This function scatters sparse track features into a dense image grid and applies multi-scale average pooling
    while excluding zero-value holes. The weight mask ensures that only valid feature regions contribute to the pooling,
    avoiding dilution by empty pixels.

    Args:
        track2d (torch.Tensor): Float tensor of shape (B, T, N, 2) containing (x, y) pixel coordinates
            for each track point across batches, frames, and points.
        trackfeature (torch.Tensor): Float tensor of shape (B, T, N, C) with C-dimensional features
            for each track point.
        H (int): Height of the target output image.
        W (int): Width of the target output image.
        kernel_sizes (List[int]): List of odd integers for average pooling kernel sizes. Default: [1].

    Returns:
        torch.Tensor: Multi-scale fused feature map of shape (B, T, C, H, W) with hole-resistant pooling.
    """
    B, T, N, C = trackfeature.shape
    device = trackfeature.device

    # 1. Flatten coordinates and filter valid points within image bounds
    coords_flat = track2d.round().long().reshape(-1, 2)  # (B*T*N, 2)
    x = coords_flat[:, 0]  # x coordinates
    y = coords_flat[:, 1]  # y coordinates
    feat_flat = trackfeature.reshape(-1, C)  # Flatten features

    valid_mask = (x >= 0) & (x < W) & (y >= 0) & (y < H)
    x = x[valid_mask]
    y = y[valid_mask]
    feat_flat = feat_flat[valid_mask]
    valid_count = x.shape[0]

    if valid_count == 0:
        return torch.zeros(B, T, C, H, W, device=device)  # Handle no-valid-point case

    # 2. Calculate linear indices and batch-frame indices for scattering
    lin_idx = y * W + x  # Linear index within a single frame (H*W range)

    # Generate batch-frame indices (0 ~ B*T-1 for each frame in the batch)
    bt_idx_raw = (
        torch.arange(B * T, device=device)
        .view(B, T, 1)
        .expand(B, T, N)
        .reshape(-1)
    )
    bt_idx = bt_idx_raw[valid_mask]  # Indices for valid points across batch and frames

    # 3. Create accumulation buffers for features and weights
    total_space = B * T * H * W
    img_accum_flat = torch.zeros(total_space, C, device=device)      # Feature accumulator
    weight_accum_flat = torch.zeros(total_space, 1, device=device)   # Weight accumulator (counts)

    # 4. Scatter features and weights into accumulation buffers
    idx_in_accum = bt_idx * (H * W) + lin_idx  # Global index: batch_frame * H*W + pixel_index

    # Add features to corresponding indices (index_add_ is efficient for sparse updates)
    img_accum_flat.index_add_(0, idx_in_accum, feat_flat)
    weight_accum_flat.index_add_(0, idx_in_accum, torch.ones((valid_count, 1), device=device))

    # 5. Normalize features by valid weights, keep zeros for invalid regions
    valid_mask_flat = weight_accum_flat > 0  # Binary mask for valid pixels
    img_accum_flat = img_accum_flat / (weight_accum_flat + 1e-6)  # Avoid division by zero
    img_accum_flat = img_accum_flat * valid_mask_flat.float()     # Mask out invalid regions

    # 6. Reshape to (B, T, C, H, W) for further processing
    img = (
        img_accum_flat.view(B, T, H, W, C)
        .permute(0, 1, 4, 2, 3)
        .contiguous()
    )  # Shape: (B, T, C, H, W)

    # 7. Multi-scale pooling with weight masking to exclude zero holes
    blurred_outputs = []
    for k in kernel_sizes:
        pad = k // 2
        img_bt = img.view(B*T, C, H, W)  # Flatten batch and time for pooling

        # Create weight mask for valid regions (1 where features exist, 0 otherwise)
        weight_mask = (
            weight_accum_flat.view(B, T, 1, H, W) > 0
        ).float().view(B*T, 1, H, W)  # Shape: (B*T, 1, H, W)

        # Calculate number of valid neighbors in each pooling window
        weight_sum = F.conv2d(
            weight_mask,
            torch.ones((1, 1, k, k), device=device),
            stride=1,
            padding=pad
        )  # Shape: (B*T, 1, H, W)

        # Sum features only in valid regions
        feat_sum = F.conv2d(
            img_bt * weight_mask,  # Mask out invalid regions before summing
            torch.ones((1, 1, k, k), device=device).expand(C, 1, k, k),
            stride=1,
            padding=pad,
            groups=C
        )  # Shape: (B*T, C, H, W)

        # Compute average only over valid neighbors
        feat_avg = feat_sum / (weight_sum + 1e-6)
        blurred_outputs.append(feat_avg)

    # 8. Fuse multi-scale results by averaging across kernel sizes
    fused = torch.stack(blurred_outputs).mean(dim=0)  # Average over kernel sizes
    return fused.view(B, T, C, H, W)  # Restore original shape

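Usage sketch (illustration only): scatter a few track points per frame onto a 64x64 grid and pool at three scales, as the forward pass below does with kernel_sizes=[1, 3, 5].

    import torch

    B, T, N, C = 1, 4, 8, 16
    uv = torch.rand(B, T, N, 2) * 64          # (x, y) pixel coordinates
    feat = torch.randn(B, T, N, C)            # per-point features
    dense = scatter_multiscale_fast(uv, feat, H=64, W=64, kernel_sizes=[1, 3, 5])
    assert dense.shape == (B, T, C, 64, 64)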
#@HEADS.register_module()
class Stabilization_Network_Cross_Attention(BaseDecodeHead_clips_flow):

    def __init__(self, feature_strides, **kwargs):
        super(Stabilization_Network_Cross_Attention, self).__init__(input_transform='multiple_select', **kwargs)
        self.training = False
        assert len(feature_strides) == len(self.in_channels)
        assert min(feature_strides) == feature_strides[0]
        self.feature_strides = feature_strides

        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels

        decoder_params = kwargs['decoder_params']
        embedding_dim = decoder_params['embed_dim']

        self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim)

        self.linear_fuse = nn.Sequential(
            nn.Conv2d(embedding_dim*4, embedding_dim, kernel_size=(1, 1), stride=(1, 1), bias=False),
            nn.ReLU(inplace=True))

        self.proj_track = nn.Conv2d(100, 128, kernel_size=(1, 1), stride=(1, 1), bias=True)

        depths = decoder_params['depths']

        self.reg_tokens = nn.Parameter(torch.zeros(1, 2, embedding_dim))
        self.global_patch = nn.Conv2d(embedding_dim, embedding_dim, kernel_size=(8, 8), stride=(8, 8), bias=True)

        self.att_temporal = nn.ModuleList(
            [
                AttnBlock(embedding_dim, 8,
                          mlp_ratio=4, flash=True, ckpt_fwd=True)
                for _ in range(8)
            ]
        )
        self.att_spatial = nn.ModuleList(
            [
                AttnBlock(embedding_dim, 8,
                          mlp_ratio=4, flash=True, ckpt_fwd=True)
                for _ in range(8)
            ]
        )
        self.scale_shift_head = nn.Sequential(nn.Linear(embedding_dim, embedding_dim), nn.GELU(), nn.Linear(embedding_dim, 4))

        # Initialize reg tokens
        nn.init.trunc_normal_(self.reg_tokens, std=0.02)

        self.decoder_focal = BasicLayer3d3(dim=embedding_dim,
                                           input_resolution=(96, 96),
                                           depth=depths,
                                           num_heads=8,
                                           window_size=7,
                                           mlp_ratio=4.,
                                           qkv_bias=True,
                                           qk_scale=None,
                                           drop=0.,
                                           attn_drop=0.,
                                           drop_path=0.,
                                           norm_layer=nn.LayerNorm,
                                           pool_method='fc',
                                           downsample=None,
                                           focal_level=2,
                                           focal_window=5,
                                           expand_size=3,
                                           expand_layer="all",
                                           use_conv_embed=False,
                                           use_shift=False,
                                           use_pre_norm=False,
                                           use_checkpoint=False,
                                           use_layerscale=False,
                                           layerscale_value=1e-4,
                                           focal_l_clips=[7, 4, 2],
                                           focal_kernel_clips=[7, 5, 3])

        self.ffm2 = FFM(inchannels=256, midchannels=256, outchannels=128)
        self.ffm1 = FFM(inchannels=128, midchannels=128, outchannels=64)
        self.ffm0 = FFM(inchannels=64, midchannels=64, outchannels=32, upfactor=1)
        self.AO = AO(32, outchannels=3, upfactor=1)
        self._c2 = None
        self._c_further = None

    def buffer_forward(self, inputs, num_clips=None, imgs=None):  # ,infermode=1):

        # input: B T 7 H W (7 means 3 rgb + 3 pointmap + 1 uncertainty), normalized
        if self.training:
            assert self.num_clips == num_clips

        x = self._transform_inputs(inputs)  # len=4, 1/4, 1/8, 1/16, 1/32
        c1, c2, c3, c4 = x

        ############## MLP decoder on C1-C4 ###########
        n, _, h, w = c4.shape
        batch_size = n // num_clips

        _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3])
        _c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3])
        _c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3])
        _c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3])
        _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))

        _, _, h, w = _c.shape
        _c_further = _c.reshape(batch_size, num_clips, -1, h, w)  # h2w2

        # Expand reg_tokens to match batch size
        reg_tokens = self.reg_tokens.expand(batch_size*num_clips, -1, -1)  # [B, 2, C]

        _c2 = self.decoder_focal(_c_further, batch_size=batch_size, num_clips=num_clips, reg_tokens=reg_tokens)

        assert _c_further.shape == _c2.shape
        self._c2 = _c2
        self._c_further = _c_further

        # compute the scale and shift of the global patch
        global_patch = self.global_patch(_c2.view(batch_size*num_clips, -1, h, w)).view(batch_size*num_clips, _c2.shape[2], -1).permute(0, 2, 1)
        global_patch = torch.cat([global_patch, reg_tokens], dim=1)
        for i in range(8):
            global_patch = self.att_temporal[i](global_patch)
            global_patch = rearrange(global_patch, '(b t) n c -> (b n) t c', b=batch_size, t=num_clips, c=_c2.shape[2])
            global_patch = self.att_spatial[i](global_patch)
            global_patch = rearrange(global_patch, '(b n) t c -> (b t) n c', b=batch_size, t=num_clips, c=_c2.shape[2])

        reg_tokens = global_patch[:, -2:, :]
        s_ = self.scale_shift_head(reg_tokens)
        scale = 1 + s_[:, 0, :1].view(batch_size, num_clips, 1, 1, 1)
        shift = s_[:, 1, 1:].view(batch_size, num_clips, 3, 1, 1)
        shift[:, :, :2, ...] = 0
        return scale, shift

    def forward(self, inputs, edge_feat, edge_feat1, tracks, tracks_uvd, num_clips=None, imgs=None, vis_track=None):  # ,infermode=1):

        if self._c2 is None:
            scale, shift = self.buffer_forward(inputs, num_clips, imgs)

        B, T, N, _ = tracks.shape

        _c2 = self._c2
        _c_further = self._c_further

        # skip and head
        _c_further = rearrange(_c_further, 'b t c h w -> (b t) c h w', b=B, t=T)
        _c2 = rearrange(_c2, 'b t c h w -> (b t) c h w', b=B, t=T)

        outframe = self.ffm2(_c_further, _c2)

        tracks_uv = tracks_uvd[..., :2].clone()
        track_feature = scatter_multiscale_fast(tracks_uv/2, tracks, outframe.shape[-2], outframe.shape[-1], kernel_sizes=[1, 3, 5])
        # visualize track_feature as video
        # import cv2
        # import imageio
        # import os
        # BT, C, H, W = outframe.shape
        # track_feature_vis = track_feature.view(B, T, 3, H, W).float().detach().cpu().numpy()
        # track_feature_vis = track_feature_vis.transpose(0,1,3,4,2)
        # track_feature_vis = (track_feature_vis - track_feature_vis.min()) / (track_feature_vis.max() - track_feature_vis.min() + 1e-6)
        # track_feature_vis = (track_feature_vis * 255).astype(np.uint8)
        # imgs = (imgs.detach() + 1) * 127.5
        # vis_track.visualize(video=imgs, tracks=tracks_uv, filename="test")
        # for b in range(B):
        #     frames = []
        #     for t in range(T):
        #         frame = track_feature_vis[b,t]
        #         frame = cv2.applyColorMap(frame[...,0], cv2.COLORMAP_JET)
        #         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        #         frames.append(frame)
        #     # Save as gif
        #     imageio.mimsave(f'track_feature_b{b}.gif', frames, duration=0.1)
        # import pdb; pdb.set_trace()
        track_feature = rearrange(track_feature, 'b t c h w -> (b t) c h w')
        track_feature = self.proj_track(track_feature)
        outframe = self.ffm1(edge_feat1 + track_feature, outframe)
        outframe = self.ffm0(edge_feat, outframe)
        outframe = self.AO(outframe)

        return outframe

    def reset_success(self):
        self._c2 = None
        self._c_further = None
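Shape sketch (illustration only, not part of the file) for the alternating attention in buffer_forward: tokens start grouped per frame as (b*t, n, c); the first rearrange regroups them per token as (b*n, t, c) so attention mixes information across time, and the second restores the per-frame grouping. The sizes below are placeholders (144 patch tokens from an 8x8-stride conv over a 96x96 map, plus 2 reg tokens).

    import torch
    from einops import rearrange

    b, t, n, c = 2, 8, 146, 256
    tokens = torch.randn(b * t, n, c)
    across_time = rearrange(tokens, '(b t) n c -> (b n) t c', b=b, t=t)
    assert across_time.shape == (b * n, t, c)
    back = rearrange(across_time, '(b n) t c -> (b t) n c', b=b, t=t)
    assert back.shape == (b * t, n, c)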
models/SpaTrackV2/models/predictor.py
ADDED
@@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F

from tqdm import tqdm
from models.SpaTrackV2.models.SpaTrack import SpaTrack2
from typing import Literal
import numpy as np
from pathlib import Path
from typing import Union, Optional
import cv2
import os
import decord


class Predictor(torch.nn.Module):
    def __init__(self, args=None):
        super().__init__()
        self.args = args
        self.spatrack = SpaTrack2(loggers=[None, None, None], **args)
        self.S_wind = 200
        self.overlap = 8

    def to(self, device: Union[str, torch.device]):
        self.spatrack.to(device)
        self.spatrack.base_model.to(device)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, Path],
        *,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        device: Optional[Union[str, torch.device]] = None,
        model_cfg: Optional[dict] = None,
        **kwargs,
    ) -> "Predictor":
        """
        Load a pretrained model from a local file or a remote repository.

        Args:
            pretrained_model_name_or_path (str or Path):
                - Path to a local model file (e.g., `./model.pth`).
                - HuggingFace Hub model ID (e.g., `username/model-name`).
            force_download (bool, optional):
                Whether to force re-download even if cached. Default: False.
            cache_dir (str, optional):
                Custom cache directory. Default: None (use default cache).
            device (str or torch.device, optional):
                Target device (e.g., "cuda", "cpu"). Default: None (keep original).
            **kwargs:
                Additional config overrides.

        Returns:
            Predictor: predictor wrapping the loaded SpaTrack2 weights.
        """
        # (1) check whether the path is local or remote
        if isinstance(pretrained_model_name_or_path, Path):
            model_path = str(pretrained_model_name_or_path)
        else:
            model_path = pretrained_model_name_or_path
        # (2) if the path is remote, download it
        if not os.path.exists(model_path):
            raise NotImplementedError("Remote download not implemented yet. Use a local path.")
        # (3) load the model weights
        state_dict = torch.load(model_path, map_location="cpu")
        # (4) initialize the model (can load config.json if it exists)
        config_path = os.path.join(os.path.dirname(model_path), "config.json")
        config = {}
        if os.path.exists(config_path):
            import json
            with open(config_path, "r") as f:
                config.update(json.load(f))
        config.update(kwargs)  # allow overriding the config
        if model_cfg is not None:
            config = model_cfg
        model = cls(config)
        if "model" in state_dict:
            model.spatrack.load_state_dict(state_dict["model"], strict=False)
        else:
            model.spatrack.load_state_dict(state_dict, strict=False)
        # (5) device management
        if device is not None:
            model.to(device)

        return model

    def forward(self, video: str|torch.Tensor|np.ndarray,
                depth: str|torch.Tensor|np.ndarray=None,
                unc_metric: str|torch.Tensor|np.ndarray=None,
                intrs: str|torch.Tensor|np.ndarray=None,
                extrs: str|torch.Tensor|np.ndarray=None,
                queries=None, queries_3d=None, iters_track=4,
                full_point=False, fps=30, track2d_gt=None,
                fixed_cam=False, query_no_BA=False, stage=0,
                support_frame=0, replace_ratio=0.6):
        """
        video: either a path to a video, a tensor of shape (T, C, H, W), or a numpy array of shape (T, C, H, W)
        queries: (B, N, 2)
        """

        if isinstance(video, str):
            video = decord.VideoReader(video)
            video = video[::fps].asnumpy()  # Convert to numpy array
            video = np.array(video)  # Ensure numpy array
            video = torch.from_numpy(video).permute(0, 3, 1, 2).float()
        elif isinstance(video, np.ndarray):
            video = torch.from_numpy(video).float()

        if isinstance(depth, np.ndarray):
            depth = torch.from_numpy(depth).float()
        if isinstance(intrs, np.ndarray):
            intrs = torch.from_numpy(intrs).float()
        if isinstance(extrs, np.ndarray):
            extrs = torch.from_numpy(extrs).float()
        if isinstance(unc_metric, np.ndarray):
            unc_metric = torch.from_numpy(unc_metric).float()

        T_, C, H, W = video.shape
        step_slide = self.S_wind - self.overlap
        if T_ > self.S_wind:
            # pad the sequence so it splits evenly into sliding windows
            num_windows = (T_ - self.S_wind + step_slide) // step_slide
            T = num_windows * step_slide + self.S_wind
            pad_len = T - T_

            video = torch.cat([video, video[-1:].repeat(T-video.shape[0], 1, 1, 1)], dim=0)
            if depth is not None:
                depth = torch.cat([depth, depth[-1:].repeat(T-depth.shape[0], 1, 1)], dim=0)
            if intrs is not None:
                intrs = torch.cat([intrs, intrs[-1:].repeat(T-intrs.shape[0], 1, 1)], dim=0)
            if extrs is not None:
                extrs = torch.cat([extrs, extrs[-1:].repeat(T-extrs.shape[0], 1, 1)], dim=0)
            if unc_metric is not None:
                unc_metric = torch.cat([unc_metric, unc_metric[-1:].repeat(T-unc_metric.shape[0], 1)], dim=0)
        with torch.no_grad():
            ret = self.spatrack.forward_stream(video, queries, T_org=T_,
                      depth=depth, intrs=intrs, unc_metric_in=unc_metric, extrs=extrs, queries_3d=queries_3d,
                      window_len=self.S_wind, overlap_len=self.overlap, track2d_gt=track2d_gt, full_point=full_point, iters_track=iters_track,
                      fixed_cam=fixed_cam, query_no_BA=query_no_BA, stage=stage, support_frame=support_frame, replace_ratio=replace_ratio) + (video[:T_],)

        return ret
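A minimal usage sketch (not part of the file). The checkpoint path is a placeholder, and `cfg` stands for the SpaTrack2 constructor arguments (e.g. loaded from the repository's config/magic_infer_moge.yaml); both are assumptions for illustration.

    import numpy as np

    model = Predictor.from_pretrained("./checkpoints/spatrack2.pth", model_cfg=cfg, device="cuda")
    video = np.zeros((48, 3, 384, 512), dtype=np.float32)   # (T, C, H, W)
    ret = model(video)   # forward_stream outputs plus the (un-padded) video tensor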
models/SpaTrackV2/models/tracker3D/TrackRefiner.py
ADDED
@@ -0,0 +1,1478 @@
import os, sys
import torch
import torch.amp
from models.SpaTrackV2.models.tracker3D.co_tracker.cotracker_base import CoTrackerThreeOffline, get_1d_sincos_pos_embed_from_grid
import torch.nn.functional as F
from models.SpaTrackV2.utils.visualizer import Visualizer
from models.SpaTrackV2.utils.model_utils import sample_features5d
from models.SpaTrackV2.models.blocks import bilinear_sampler
import torch.nn as nn
from models.SpaTrackV2.models.tracker3D.co_tracker.utils import (
    EfficientUpdateFormer, AttnBlock, Attention, CrossAttnBlock,
    sequence_BCE_loss, sequence_loss, sequence_prob_loss, sequence_dyn_prob_loss, sequence_loss_xyz, balanced_binary_cross_entropy
)
from torchvision.io import write_video
import math
from models.SpaTrackV2.models.tracker3D.co_tracker.utils import (
    Mlp, BasicEncoder, EfficientUpdateFormer, GeometryEncoder, NeighborTransformer, CorrPointformer
)
from models.SpaTrackV2.utils.embeddings import get_3d_sincos_pos_embed_from_grid
from einops import rearrange, repeat
from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import (
    EfficientUpdateFormer3D, weighted_procrustes_torch, posenc, key_fr_wprocrustes, get_topo_mask,
    TrackFusion, get_nth_visible_time_index
)
from models.SpaTrackV2.models.tracker3D.spatrack_modules.ba import extract_static_from_3DTracks, ba_pycolmap
from models.SpaTrackV2.models.tracker3D.spatrack_modules.pointmap_updator import PointMapUpdator
from models.SpaTrackV2.models.depth_refiner.depth_refiner import TrackStablizer
from models.SpaTrackV2.models.tracker3D.spatrack_modules.alignment import affine_invariant_global_loss
from models.SpaTrackV2.models.tracker3D.delta_utils.upsample_transformer import UpsampleTransformerAlibi


class TrackRefiner3D(CoTrackerThreeOffline):

    def __init__(self, args=None):
        super().__init__(**args.base)

        """
        This is a 3D wrapper around CoTracker. It loads the CoTracker pretrain and
        jointly refines the `camera pose`, `3D tracks`, `video depth`, `visibility` and `conf`.
        """
        self.updateformer3D = EfficientUpdateFormer3D(self.updateformer)
        self.corr_depth_mlp = Mlp(in_features=256, hidden_features=256, out_features=256)
        self.rel_pos_mlp = Mlp(in_features=75, hidden_features=128, out_features=128)
        self.rel_pos_glob_mlp = Mlp(in_features=75, hidden_features=128, out_features=256)
        self.corr_xyz_mlp = Mlp(in_features=256, hidden_features=128, out_features=128)
        self.xyz_mlp = Mlp(in_features=126, hidden_features=128, out_features=84)
        # self.track_feat_mlp = Mlp(in_features=1110, hidden_features=128, out_features=128)
        self.proj_xyz_embed = Mlp(in_features=1210+50, hidden_features=1110, out_features=1110)
        # get the anchor point's embedding, and init the pts refiner
        update_pts = True
        # self.corr_transformer = nn.ModuleList([
        #         CorrPointformer(
        #             dim=128,
        #             num_heads=8,
        #             head_dim=128 // 8,
        #             mlp_ratio=4.0,
        #         )
        #     for _ in range(self.corr_levels)
        # ])
        self.corr_transformer = nn.ModuleList([
            CorrPointformer(
                dim=128,
                num_heads=8,
                head_dim=128 // 8,
                mlp_ratio=4.0,
            )
        ]
        )
        self.fnet = BasicEncoder(input_dim=3,
                                 output_dim=self.latent_dim, stride=self.stride)
        self.corr3d_radius = 3

        if args.stablizer:
            self.scale_shift_tokens = nn.Parameter(torch.randn(1, 2, self.latent_dim, requires_grad=True))
            self.upsample_kernel_size = 5
            self.residual_embedding = nn.Parameter(torch.randn(
                self.latent_dim, self.model_resolution[0]//16,
                self.model_resolution[1]//16, requires_grad=True))
            self.dense_mlp = nn.Conv2d(2*self.latent_dim+63, self.latent_dim, kernel_size=1, stride=1, padding=0)
            self.upsample_factor = 4
            self.upsample_transformer = UpsampleTransformerAlibi(
                kernel_size=self.upsample_kernel_size,  # kernel_size=3, #
                stride=self.stride,
                latent_dim=self.latent_dim,
                num_attn_blocks=2,
                upsample_factor=4,
            )
        else:
            self.update_pointmap = None

        self.mode = args.mode
        if self.mode == "online":
            self.s_wind = args.s_wind
            self.overlap = args.overlap

    def upsample_with_mask(
        self, inp: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Upsample flow field [H/P, W/P, 2] -> [H, W, 2] using convex combination"""
        H, W = inp.shape[-2:]
        up_inp = F.unfold(
            inp, [self.upsample_kernel_size, self.upsample_kernel_size], padding=(self.upsample_kernel_size - 1) // 2
        )
        up_inp = rearrange(up_inp, "b c (h w) -> b c h w", h=H, w=W)
        up_inp = F.interpolate(up_inp, scale_factor=self.upsample_factor, mode="nearest")
        up_inp = rearrange(
            up_inp, "b (c i j) h w -> b c (i j) h w", i=self.upsample_kernel_size, j=self.upsample_kernel_size
        )

        up_inp = torch.sum(mask * up_inp, dim=2)
        return up_inp

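Shape sketch for upsample_with_mask (illustration only, not part of the file): with kernel k=5 and upsample_factor f=4, the unfolded neighbourhood has shape (B, C, k*k, H*f, W*f), so `mask` should broadcast as (B, 1, k*k, H*f, W*f) and sum to 1 over dim=2 to give a convex combination of coarse neighbours.

    import torch

    B, C, H, W, k, f = 1, 16, 24, 32, 5, 4
    inp = torch.randn(B, C, H, W)
    mask = torch.softmax(torch.randn(B, 1, k * k, H * f, W * f), dim=2)
    # up = refiner.upsample_with_mask(inp, mask)   # -> (B, C, H*f, W*f); `refiner` is hypothetical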
    def track_from_cam(self, queries, c2w_traj, intrs,
                       rgbs=None, visualize=False):
        """
        This function will generate tracks by camera transform

        Args:
            queries: B T N 4
            c2w_traj: B T 4 4
            intrs: B T 3 3
        """
        B, T, N, _ = queries.shape
        query_t = queries[:, 0, :, 0].to(torch.int64)  # B N
        query_c2w = torch.gather(c2w_traj,
                                 dim=1, index=query_t[..., None, None].expand(-1, -1, 4, 4))  # B N 4 4
        query_intr = torch.gather(intrs,
                                  dim=1, index=query_t[..., None, None].expand(-1, -1, 3, 3))  # B N 3 3
        query_pts = queries[:, 0, :, 1:4].clone()  # B N 3
        query_d = queries[:, 0, :, 3:4]            # B N 1
        query_pts[..., 2] = 1

        cam_pts = torch.einsum("bnij,bnj->bni", torch.inverse(query_intr), query_pts)*query_d  # B N 3
        # convert to world
        cam_pts_h = torch.zeros(B, N, 4, device=cam_pts.device)
        cam_pts_h[..., :3] = cam_pts
        cam_pts_h[..., 3] = 1
        world_pts = torch.einsum("bnij,bnj->bni", query_c2w, cam_pts_h)
        # convert to other frames
        cam_other_pts_ = torch.einsum("btnij,btnj->btni",
                                      torch.inverse(c2w_traj[:, :, None].float().repeat(1, 1, N, 1, 1)),
                                      world_pts[:, None].repeat(1, T, 1, 1))
        cam_depth = cam_other_pts_[..., 2:3]
        cam_other_pts = cam_other_pts_[..., :3] / (cam_other_pts_[..., 2:3].abs()+1e-6)
        cam_other_pts = torch.einsum("btnij,btnj->btni", intrs[:, :, None].repeat(1, 1, N, 1, 1), cam_other_pts[..., :3])
        cam_other_pts[..., 2:] = cam_depth

        if visualize:
            viser = Visualizer(save_dir=".", grayscale=True,
                               fps=10, pad_value=50, tracks_leave_trace=0)
            cam_other_pts[..., 0] /= self.factor_x
            cam_other_pts[..., 1] /= self.factor_y
            viser.visualize(video=rgbs, tracks=cam_other_pts[..., :2], filename="test")

        init_xyzs = cam_other_pts

        return init_xyzs, world_pts[..., :3], cam_other_pts_[..., :3]

    def cam_from_track(self, tracks, intrs,
                       dyn_prob=None, metric_unc=None,
                       vis_est=None, only_cam_pts=False,
                       track_feat_concat=None,
                       tracks_xyz=None,
                       query_pts=None,
                       fixed_cam=False,
                       depth_unproj=None,
                       cam_gt=None,
                       init_pose=False,
                       ):
        """
        This function will generate tracks by camera transform

        Args:
            queries: B T N 3
            scale_est: 1 1
            shift_est: 1 1
            intrs: B T 3 3
            dyn_prob: B T N
            metric_unc: B N 1
            query_pts: B T N 3
        """
        if tracks_xyz is not None:
            B, T, N, _ = tracks.shape
            cam_pts = tracks_xyz
            intr_repeat = intrs[:, :, None].repeat(1, 1, N, 1, 1)
        else:
            B, T, N, _ = tracks.shape
            # get the pts in cam coordinate
            tracks_xy = tracks[..., :3].clone().detach()  # B T N 3
            # tracks_z = 1/(tracks[...,2:] * scale_est + shift_est)  # B T N 1
            tracks_z = tracks[..., 2:].detach()  # B T N 1
            tracks_xy[..., 2] = 1
            intr_repeat = intrs[:, :, None].repeat(1, 1, N, 1, 1)
            cam_pts = torch.einsum("bnij,bnj->bni",
                                   torch.inverse(intr_repeat.view(B*T, N, 3, 3)).float(),
                                   tracks_xy.view(B*T, N, 3))*(tracks_z.view(B*T, N, 1).abs())  # B*T N 3
            cam_pts[..., 2] *= torch.sign(tracks_z.view(B*T, N))
            # get the normalized cam pts, and pts refiner
            mask_z = (tracks_z.max(dim=1)[0] < 200).squeeze()
            cam_pts = cam_pts.view(B, T, N, 3)

        if only_cam_pts:
            return cam_pts
        dyn_prob = dyn_prob.mean(dim=1)[..., None]
        # B T N 3 -> local frames coordinates. transformer static points B T N 3 -> B T N 3 static (B T N 3) -> same -> dynamic points @ C2T.inverse()
        # get the cam pose
        vis_est_ = vis_est[:, :, None, :]
        graph_matrix = (vis_est_*vis_est_.permute(0, 2, 1, 3)).detach()
        # find the max connected component
        key_fr_idx = [0]
        weight_final = (metric_unc)  # * vis_est

        with torch.amp.autocast(enabled=False, device_type='cuda'):
            if fixed_cam:
                c2w_traj_init = self.c2w_est_curr
                c2w_traj_glob = c2w_traj_init
                cam_pts_refine = cam_pts
                intrs_refine = intrs
                xy_refine = query_pts[..., 1:3]
                world_tracks_init = torch.einsum("btij,btnj->btni", c2w_traj_init[:, :, :3, :3], cam_pts) + c2w_traj_init[:, :, None, :3, 3]
                world_tracks_refined = world_tracks_init
                # extract the stable static points for refining the camera pose
                intrs_dn = intrs.clone()
                intrs_dn[..., 0, :] *= self.factor_x
                intrs_dn[..., 1, :] *= self.factor_y
                _, query_world_pts, _ = self.track_from_cam(query_pts, c2w_traj_init, intrs_dn)
                world_tracks_static, mask_static, mask_topk, vis_mask_static, tracks2d_static = extract_static_from_3DTracks(world_tracks_init,
                                                                                            dyn_prob, query_world_pts,
                                                                                            vis_est, tracks, img_size=self.image_size,
                                                                                            K=0)
                world_static_refine = world_tracks_static

            else:

                if (not self.training):
                    # if (self.c2w_est_curr==torch.eye(4, device=cam_pts.device).repeat(B, T, 1, 1)).all():
                    campts_update = torch.einsum("btij,btnj->btni", self.c2w_est_curr[..., :3, :3], cam_pts) + self.c2w_est_curr[..., None, :3, 3]
                    # campts_update = cam_pts
                    c2w_traj_init_update = key_fr_wprocrustes(campts_update, graph_matrix,
                                                              (weight_final*(1-dyn_prob)).permute(0, 2, 1), vis_est_.permute(0, 1, 3, 2))
                    c2w_traj_init = c2w_traj_init_update@self.c2w_est_curr
                    # else:
                    #     c2w_traj_init = self.c2w_est_curr  # extract the stable static points for refining the camera pose
                else:
                    # if (self.c2w_est_curr==torch.eye(4, device=cam_pts.device).repeat(B, T, 1, 1)).all():
                    campts_update = torch.einsum("btij,btnj->btni", self.c2w_est_curr[..., :3, :3], cam_pts) + self.c2w_est_curr[..., None, :3, 3]
                    # campts_update = cam_pts
                    c2w_traj_init_update = key_fr_wprocrustes(campts_update, graph_matrix,
                                                              (weight_final*(1-dyn_prob)).permute(0, 2, 1), vis_est_.permute(0, 1, 3, 2))
                    c2w_traj_init = c2w_traj_init_update@self.c2w_est_curr
                    # else:
                    #     c2w_traj_init = self.c2w_est_curr  # extract the stable static points for refining the camera pose

                intrs_dn = intrs.clone()
                intrs_dn[..., 0, :] *= self.factor_x
                intrs_dn[..., 1, :] *= self.factor_y
                _, query_world_pts, _ = self.track_from_cam(query_pts, c2w_traj_init, intrs_dn)
                # refine the world tracks
                world_tracks_init = torch.einsum("btij,btnj->btni", c2w_traj_init[:, :, :3, :3], cam_pts) + c2w_traj_init[:, :, None, :3, 3]
                world_tracks_static, mask_static, mask_topk, vis_mask_static, tracks2d_static = extract_static_from_3DTracks(world_tracks_init,
                                                                                            dyn_prob, query_world_pts,
                                                                                            vis_est, tracks, img_size=self.image_size,
                                                                                            K=150 if self.training else 1500)
                # calculate the efficient ba
                cam_tracks_static = cam_pts[:, :, mask_static.squeeze(), :][:, :, mask_topk.squeeze(), :]
                cam_tracks_static[..., 2] = depth_unproj.view(B, T, N)[:, :, mask_static.squeeze()][:, :, mask_topk.squeeze()]

                c2w_traj_glob, world_static_refine, intrs_refine = ba_pycolmap(world_tracks_static, intrs,
                                                                               c2w_traj_init, vis_mask_static,
                                                                               tracks2d_static, self.image_size,
                                                                               cam_tracks_static=cam_tracks_static,
                                                                               training=self.training, query_pts=query_pts)
                c2w_traj_glob = c2w_traj_glob.view(B, T, 4, 4)
                world_tracks_refined = world_tracks_init

                #NOTE: merge the index of static points and topk points
                # merge_idx = torch.where(mask_static.squeeze()>0)[0][mask_topk.squeeze()]
                # world_tracks_refined[:,:,merge_idx] = world_static_refine

            # test the procrustes
            w2c_traj_glob = torch.inverse(c2w_traj_init.detach())
            cam_pts_refine = torch.einsum("btij,btnj->btni", w2c_traj_glob[:, :, :3, :3], world_tracks_refined) + w2c_traj_glob[:, :, None, :3, 3]
            # get the xy_refine
            #TODO: refiner
            cam_pts4_proj = cam_pts_refine.clone()
            cam_pts4_proj[..., 2] *= torch.sign(cam_pts4_proj[..., 2:3].view(B*T, N))
            xy_refine = torch.einsum("btnij,btnj->btni", intrs_refine.view(B, T, 1, 3, 3).repeat(1, 1, N, 1, 1), cam_pts4_proj/cam_pts4_proj[..., 2:3].abs())
            xy_refine[..., 2] = cam_pts4_proj[..., 2:3].view(B*T, N)
            # xy_refine = torch.zeros_like(cam_pts_refine)[...,:2]
        return c2w_traj_glob, cam_pts_refine, intrs_refine, xy_refine, world_tracks_init, world_tracks_refined, c2w_traj_init

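Worked example (illustration only) of the pinhole unprojection used by the two methods above: cam_pt = depth * K^-1 @ [u, v, 1]^T, then world_pt = c2w @ [cam_pt, 1]^T. The intrinsics and pixel below are made-up numbers.

    import torch

    K = torch.tensor([[500., 0., 320.],
                      [0., 500., 240.],
                      [0., 0., 1.]])
    uv1 = torch.tensor([420., 340., 1.])          # pixel (u, v) in homogeneous form
    depth = 2.0
    cam_pt = torch.linalg.inv(K) @ uv1 * depth    # -> tensor([0.4, 0.4, 2.0])
    c2w = torch.eye(4)
    world_pt = (c2w @ torch.cat([cam_pt, torch.ones(1)]))[:3]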
    def extract_img_feat(self, video, fmaps_chunk_size=200):
        B, T, C, H, W = video.shape
        dtype = video.dtype
        H4, W4 = H // self.stride, W // self.stride
        # Compute convolutional features for the video or for the current chunk in case of online mode
        if T > fmaps_chunk_size:
            fmaps = []
            for t in range(0, T, fmaps_chunk_size):
                video_chunk = video[:, t : t + fmaps_chunk_size]
                fmaps_chunk = self.fnet(video_chunk.reshape(-1, C, H, W))
                T_chunk = video_chunk.shape[1]
                C_chunk, H_chunk, W_chunk = fmaps_chunk.shape[1:]
                fmaps.append(fmaps_chunk.reshape(B, T_chunk, C_chunk, H_chunk, W_chunk))
            fmaps = torch.cat(fmaps, dim=1).reshape(-1, C_chunk, H_chunk, W_chunk)
        else:
            fmaps = self.fnet(video.reshape(-1, C, H, W))
        fmaps = fmaps.permute(0, 2, 3, 1)
        # L2-normalize the feature vectors
        fmaps = fmaps / torch.sqrt(
            torch.maximum(
                torch.sum(torch.square(fmaps), axis=-1, keepdims=True),
                torch.tensor(1e-12, device=fmaps.device),
            )
        )
        fmaps = fmaps.permute(0, 3, 1, 2).reshape(
            B, -1, self.latent_dim, H // self.stride, W // self.stride
        )
        fmaps = fmaps.to(dtype)

        return fmaps

    def norm_xyz(self, xyz):
        """
        xyz can be (B T N 3) or (B T 3 H W) or (B N 3)
        """
        if xyz.ndim == 3:
            min_pts = self.min_pts
            max_pts = self.max_pts
            return (xyz - min_pts[None, None, :]) / (max_pts - min_pts)[None, None, :] * 2 - 1
        elif xyz.ndim == 4:
            min_pts = self.min_pts
            max_pts = self.max_pts
            return (xyz - min_pts[None, None, None, :]) / (max_pts - min_pts)[None, None, None, :] * 2 - 1
        elif xyz.ndim == 5:
            if xyz.shape[2] == 3:
                min_pts = self.min_pts
                max_pts = self.max_pts
                return (xyz - min_pts[None, None, :, None, None]) / (max_pts - min_pts)[None, None, :, None, None] * 2 - 1
            elif xyz.shape[-1] == 3:
                min_pts = self.min_pts
                max_pts = self.max_pts
                return (xyz - min_pts[None, None, None, None, :]) / (max_pts - min_pts)[None, None, None, None, :] * 2 - 1

    def denorm_xyz(self, xyz):
        """
        xyz can be (B T N 3) or (B T 3 H W) or (B N 3)
        """
        if xyz.ndim == 3:
            min_pts = self.min_pts
            max_pts = self.max_pts
            return (xyz + 1) / 2 * (max_pts - min_pts)[None, None, :] + min_pts[None, None, :]
        elif xyz.ndim == 4:
            min_pts = self.min_pts
            max_pts = self.max_pts
            return (xyz + 1) / 2 * (max_pts - min_pts)[None, None, None, :] + min_pts[None, None, None, :]
        elif xyz.ndim == 5:
            if xyz.shape[2] == 3:
                min_pts = self.min_pts
                max_pts = self.max_pts
                return (xyz + 1) / 2 * (max_pts - min_pts)[None, None, :, None, None] + min_pts[None, None, :, None, None]
            elif xyz.shape[-1] == 3:
                min_pts = self.min_pts
                max_pts = self.max_pts
                return (xyz + 1) / 2 * (max_pts - min_pts)[None, None, None, None, :] + min_pts[None, None, None, None, :]

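Round-trip sketch for the normalization above (illustration only, with made-up bounds): points are mapped affinely into [-1, 1] per axis using min_pts/max_pts, and denorm_xyz inverts the mapping exactly.

    import torch

    min_pts = torch.tensor([-2.0, -2.0, 0.0])
    max_pts = torch.tensor([2.0, 2.0, 10.0])
    xyz = torch.tensor([[[0.0, 1.0, 5.0]]])                       # (B=1, N=1, 3)
    norm = (xyz - min_pts) / (max_pts - min_pts) * 2 - 1          # -> [0.0, 0.5, 0.0]
    back = (norm + 1) / 2 * (max_pts - min_pts) + min_pts
    assert torch.allclose(back, xyz)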
| 367 |
+
def forward(
|
| 368 |
+
self,
|
| 369 |
+
video,
|
| 370 |
+
metric_depth,
|
| 371 |
+
metric_unc,
|
| 372 |
+
point_map,
|
| 373 |
+
queries,
|
| 374 |
+
pts_q_3d=None,
|
| 375 |
+
overlap_d=None,
|
| 376 |
+
iters=4,
|
| 377 |
+
add_space_attn=True,
|
| 378 |
+
fmaps_chunk_size=200,
|
| 379 |
+
intrs=None,
|
| 380 |
+
traj3d_gt=None,
|
| 381 |
+
custom_vid=False,
|
| 382 |
+
vis_gt=None,
|
| 383 |
+
prec_fx=None,
|
| 384 |
+
prec_fy=None,
|
| 385 |
+
cam_gt=None,
|
| 386 |
+
init_pose=False,
|
| 387 |
+
support_pts_q=None,
|
| 388 |
+
update_pointmap=True,
|
| 389 |
+
fixed_cam=False,
|
| 390 |
+
query_no_BA=False,
|
| 391 |
+
stage=0,
|
| 392 |
+
cache=None,
|
| 393 |
+
points_map_gt=None,
|
| 394 |
+
valid_only=False,
|
| 395 |
+
replace_ratio=0.6,
|
| 396 |
+
):
|
| 397 |
+
"""Predict tracks
|
| 398 |
+
|
| 399 |
+
Args:
|
| 400 |
+
video (FloatTensor[B, T, 3 H W]): input videos.
|
| 401 |
+
queries (FloatTensor[B, N, 3]): point queries.
|
| 402 |
+
iters (int, optional): number of updates. Defaults to 4.
|
| 403 |
+
vdp_feats_cache: last layer's feature of depth
|
| 404 |
+
tracks_init: B T N 3 the initialization of 3D tracks computed by cam pose
|
| 405 |
+
Returns:
|
| 406 |
+
- coords_predicted (FloatTensor[B, T, N, 2]):
|
| 407 |
+
- vis_predicted (FloatTensor[B, T, N]):
|
| 408 |
+
- train_data: `None` if `is_train` is false, otherwise:
|
| 409 |
+
- all_vis_predictions (List[FloatTensor[B, S, N, 1]]):
|
| 410 |
+
- all_coords_predictions (List[FloatTensor[B, S, N, 2]]):
|
| 411 |
+
- mask (BoolTensor[B, T, N]):
|
| 412 |
+
"""
|
| 413 |
+
        self.stage = stage

        if cam_gt is not None:
            cam_gt = cam_gt.clone()
            cam_gt = torch.inverse(cam_gt[:,:1,...])@cam_gt
        B, T, C, _, _ = video.shape
        _, _, H_, W_ = metric_depth.shape
        _, _, N, __ = queries.shape
        if (vis_gt is not None)&(queries.shape[1] == T):
            aug_visb = True
            if aug_visb:
                number_visible = vis_gt.sum(dim=1)
                ratio_rand = torch.rand(B, N, device=vis_gt.device)
                # first_positive_inds = get_nth_visible_time_index(vis_gt, 1)
                first_positive_inds = get_nth_visible_time_index(vis_gt, (number_visible*ratio_rand).long().clamp(min=1, max=T))

                assert (torch.gather(vis_gt, 1, first_positive_inds[:, None, :].repeat(1, T, 1)) < 0).sum() == 0
            else:
                __, first_positive_inds = torch.max(vis_gt, dim=1)
                first_positive_inds = first_positive_inds.long()
            gather = torch.gather(
                queries, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 2)
            )
            xys = torch.diagonal(gather, dim1=1, dim2=2).permute(0, 2, 1)
            gather_xyz = torch.gather(
                traj3d_gt, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 3)
            )
            z_gt_query = torch.diagonal(gather_xyz, dim1=1, dim2=2).permute(0, 2, 1)[...,2]
            queries = torch.cat([first_positive_inds[:, :, None], xys], dim=-1)
            queries = torch.cat([queries, support_pts_q[:,0]], dim=1)
        else:
            # Generate the 768 points randomly in the whole video
            queries = queries.squeeze(1)
            ba_len = queries.shape[1]
            z_gt_query = None
            if support_pts_q is not None:
                queries = torch.cat([queries, support_pts_q[:,0]], dim=1)

        if (abs(prec_fx-1.0) > 1e-4) & (self.training) & (traj3d_gt is not None):
            traj3d_gt[..., 0] /= prec_fx
            traj3d_gt[..., 1] /= prec_fy
            queries[...,1] /= prec_fx
            queries[...,2] /= prec_fy

        video_vis = F.interpolate(video.clone().view(B*T, 3, video.shape[-2], video.shape[-1]), (H_, W_), mode="bilinear", align_corners=False).view(B, T, 3, H_, W_)
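When visibility labels are available, the query frame for each point is drawn from its visible time steps via `get_nth_visible_time_index`; the fallback branch simply takes the first visible frame with `torch.max` over the time axis. A small self-contained illustration of that fallback (the visibility mask below is made up):

```python
import torch

vis_gt = torch.tensor([[[0., 0., 1.],
                        [1., 0., 1.],
                        [1., 1., 1.]]])           # B=1, T=3, N=3 visibility flags
# torch.max over time returns the index of the first maximal value,
# i.e. the first frame in which each point becomes visible
_, first_positive_inds = torch.max(vis_gt, dim=1)
print(first_positive_inds)                         # tensor([[1, 2, 0]])
```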
|
| 459 |
+
self.image_size = torch.tensor([H_, W_])
|
| 460 |
+
# self.model_resolution = (H_, W_)
|
| 461 |
+
# resize the queries and intrs
|
| 462 |
+
self.factor_x = self.model_resolution[1]/W_
|
| 463 |
+
self.factor_y = self.model_resolution[0]/H_
|
| 464 |
+
queries[...,1] *= self.factor_x
|
| 465 |
+
queries[...,2] *= self.factor_y
|
| 466 |
+
intrs_org = intrs.clone()
|
| 467 |
+
intrs[...,0,:] *= self.factor_x
|
| 468 |
+
intrs[...,1,:] *= self.factor_y
|
| 469 |
+
|
| 470 |
+
# get the fmaps and color features
|
| 471 |
+
video = F.interpolate(video.view(B*T, 3, video.shape[-2], video.shape[-1]),
|
| 472 |
+
(self.model_resolution[0], self.model_resolution[1])).view(B, T, 3, self.model_resolution[0], self.model_resolution[1])
|
| 473 |
+
_, _, _, H, W = video.shape
|
| 474 |
+
if cache is not None:
|
| 475 |
+
T_cache = cache["fmaps"].shape[0]
|
| 476 |
+
fmaps = self.extract_img_feat(video[:,T_cache:], fmaps_chunk_size=fmaps_chunk_size)
|
| 477 |
+
fmaps = torch.cat([cache["fmaps"][None], fmaps], dim=1)
|
| 478 |
+
else:
|
| 479 |
+
fmaps = self.extract_img_feat(video, fmaps_chunk_size=fmaps_chunk_size)
|
| 480 |
+
fmaps_org = fmaps.clone()
|
| 481 |
+
|
| 482 |
+
metric_depth = F.interpolate(metric_depth.view(B*T, 1, H_, W_),
|
| 483 |
+
(self.model_resolution[0], self.model_resolution[1]),mode="nearest").view(B*T, 1, self.model_resolution[0], self.model_resolution[1]).clamp(0.01, 200)
|
| 484 |
+
self.metric_unc_org = metric_unc.clone()
|
| 485 |
+
metric_unc = F.interpolate(metric_unc.view(B*T, 1, H_, W_),
|
| 486 |
+
(self.model_resolution[0], self.model_resolution[1]),mode="nearest").view(B*T, 1, self.model_resolution[0], self.model_resolution[1])
|
| 487 |
+
if (self.stage == 2) & (self.training):
|
| 488 |
+
scale_rand = (torch.rand(B, T, device=video.device) - 0.5) + 1
|
| 489 |
+
point_map = scale_rand.view(B*T,1,1,1) * point_map
|
| 490 |
+
|
| 491 |
+
point_map_org = point_map.permute(0,3,1,2).view(B*T, 3, H_, W_).clone()
|
| 492 |
+
point_map = F.interpolate(point_map_org.clone(),
|
| 493 |
+
(self.model_resolution[0], self.model_resolution[1]),mode="nearest").view(B*T, 3, self.model_resolution[0], self.model_resolution[1])
|
| 494 |
+
# align the point map
|
| 495 |
+
point_map_org_train = point_map_org.view(B*T, 3, H_, W_).clone()
|
| 496 |
+
|
| 497 |
+
if (stage == 2):
|
| 498 |
+
# align the point map
|
| 499 |
+
try:
|
| 500 |
+
self.pred_points, scale_gt, shift_gt = affine_invariant_global_loss(
|
| 501 |
+
point_map_org_train.permute(0,2,3,1),
|
| 502 |
+
points_map_gt,
|
| 503 |
+
mask=self.metric_unc_org[:,0]>0.5,
|
| 504 |
+
align_resolution=32,
|
| 505 |
+
only_align=True
|
| 506 |
+
)
|
| 507 |
+
except:
|
| 508 |
+
scale_gt, shift_gt = torch.ones(B*T).to(video.device), torch.zeros(B*T,3).to(video.device)
|
| 509 |
+
self.scale_gt, self.shift_gt = scale_gt, shift_gt
|
| 510 |
+
else:
|
| 511 |
+
scale_est, shift_est = None, None
|
| 512 |
+
|
| 513 |
+
# extract the pts features
|
| 514 |
+
device = queries.device
|
| 515 |
+
assert H % self.stride == 0 and W % self.stride == 0
|
| 516 |
+
|
| 517 |
+
B, N, __ = queries.shape
|
| 518 |
+
queries_z = sample_features5d(metric_depth.view(B, T, 1, H, W),
|
| 519 |
+
queries[:,None], interp_mode="nearest").squeeze(1)
|
| 520 |
+
queries_z_unc = sample_features5d(metric_unc.view(B, T, 1, H, W),
|
| 521 |
+
queries[:,None], interp_mode="nearest").squeeze(1)
|
| 522 |
+
|
| 523 |
+
queries_rgb = sample_features5d(video.view(B, T, C, H, W),
|
| 524 |
+
queries[:,None], interp_mode="nearest").squeeze(1)
|
| 525 |
+
queries_point_map = sample_features5d(point_map.view(B, T, 3, H, W),
|
| 526 |
+
queries[:,None], interp_mode="nearest").squeeze(1)
|
| 527 |
+
if ((queries_z > 100)*(queries_z == 0)).sum() > 0:
|
| 528 |
+
import pdb; pdb.set_trace()
|
| 529 |
+
|
| 530 |
+
if overlap_d is not None:
|
| 531 |
+
queries_z[:,:overlap_d.shape[1],:] = overlap_d[...,None]
|
| 532 |
+
queries_point_map[:,:overlap_d.shape[1],2:] = overlap_d[...,None]
|
| 533 |
+
|
| 534 |
+
if pts_q_3d is not None:
|
| 535 |
+
scale_factor = (pts_q_3d[...,-1].permute(0,2,1) / queries_z[:,:pts_q_3d.shape[2],:]).squeeze().median()
|
| 536 |
+
queries_z[:,:pts_q_3d.shape[2],:] = pts_q_3d[...,-1].permute(0,2,1) / scale_factor
|
| 537 |
+
queries_point_map[:,:pts_q_3d.shape[2],2:] = pts_q_3d[...,-1].permute(0,2,1) / scale_factor
|
| 538 |
+
|
| 539 |
+
# normalize the points
|
| 540 |
+
self.min_pts, self.max_pts = queries_point_map.mean(dim=(0,1)) - 3*queries_point_map.std(dim=(0,1)), queries_point_map.mean(dim=(0,1)) + 3*queries_point_map.std(dim=(0,1))
|
| 541 |
+
queries_point_map = self.norm_xyz(queries_point_map)
|
| 542 |
+
queries_point_map_ = queries_point_map.reshape(B, 1, N, 3).expand(B, T, N, 3).clone()
|
| 543 |
+
point_map = self.norm_xyz(point_map.view(B, T, 3, H, W)).view(B*T, 3, H, W)
|
| 544 |
+
|
| 545 |
+
if z_gt_query is not None:
|
| 546 |
+
queries_z[:,:z_gt_query.shape[1],:] = z_gt_query[:,:,None]
|
| 547 |
+
mask_traj_gt = ((queries_z[:,:z_gt_query.shape[1],:] - z_gt_query[:,:,None])).abs() < 0.1
|
| 548 |
+
else:
|
| 549 |
+
if traj3d_gt is not None:
|
| 550 |
+
mask_traj_gt = torch.ones_like(queries_z[:, :traj3d_gt.shape[2]]).bool()
|
| 551 |
+
else:
|
| 552 |
+
mask_traj_gt = torch.ones_like(queries_z).bool()
|
| 553 |
+
|
| 554 |
+
queries_xyz = torch.cat([queries, queries_z], dim=-1)[:,None].repeat(1, T, 1, 1)
|
| 555 |
+
if cache is not None:
|
| 556 |
+
cache_T, cache_N = cache["track2d_pred_cache"].shape[0], cache["track2d_pred_cache"].shape[1]
|
| 557 |
+
cachexy = cache["track2d_pred_cache"].clone()
|
| 558 |
+
cachexy[...,0] = cachexy[...,0] * self.factor_x
|
| 559 |
+
cachexy[...,1] = cachexy[...,1] * self.factor_y
|
| 560 |
+
# initialize the 2d points with cache
|
| 561 |
+
queries_xyz[:,:cache_T,:cache_N,1:] = cachexy
|
| 562 |
+
queries_xyz[:,cache_T:,:cache_N,1:] = cachexy[-1:]
|
| 563 |
+
# initialize the 3d points with cache
|
| 564 |
+
queries_point_map_[:,:cache_T,:cache_N,:] = self.norm_xyz(cache["track3d_pred_cache"][None])
|
| 565 |
+
queries_point_map_[:,cache_T:,:cache_N,:] = self.norm_xyz(cache["track3d_pred_cache"][-1:][None])
|
| 566 |
+
|
| 567 |
+
if cam_gt is not None:
|
| 568 |
+
q_static_proj, q_xyz_world, q_xyz_cam = self.track_from_cam(queries_xyz, cam_gt,
|
| 569 |
+
intrs, rgbs=video_vis, visualize=False)
|
| 570 |
+
q_static_proj[..., 0] /= self.factor_x
|
| 571 |
+
q_static_proj[..., 1] /= self.factor_y
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
        assert T >= 1  # at least one frame is required; meaningful tracking needs two or more
|
| 575 |
+
video = 2 * (video / 255.0) - 1.0
|
| 576 |
+
dtype = video.dtype
|
| 577 |
+
queried_frames = queries[:, :, 0].long()
|
| 578 |
+
|
| 579 |
+
queried_coords = queries[..., 1:3]
|
| 580 |
+
queried_coords = queried_coords / self.stride
|
| 581 |
+
|
| 582 |
+
# We store our predictions here
|
| 583 |
+
(all_coords_predictions, all_coords_xyz_predictions,all_vis_predictions,
|
| 584 |
+
all_confidence_predictions, all_cam_predictions, all_dynamic_prob_predictions,
|
| 585 |
+
all_cam_pts_predictions, all_world_tracks_predictions, all_world_tracks_refined_predictions,
|
| 586 |
+
all_scale_est, all_shift_est) = (
|
| 587 |
+
[],
|
| 588 |
+
[],
|
| 589 |
+
[],
|
| 590 |
+
[],
|
| 591 |
+
[],
|
| 592 |
+
[],
|
| 593 |
+
[],
|
| 594 |
+
[],
|
| 595 |
+
[],
|
| 596 |
+
[],
|
| 597 |
+
[]
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
# We compute track features
|
| 601 |
+
fmaps_pyramid = []
|
| 602 |
+
point_map_pyramid = []
|
| 603 |
+
track_feat_pyramid = []
|
| 604 |
+
track_feat_support_pyramid = []
|
| 605 |
+
track_feat3d_pyramid = []
|
| 606 |
+
track_feat_support3d_pyramid = []
|
| 607 |
+
track_depth_support_pyramid = []
|
| 608 |
+
track_point_map_pyramid = []
|
| 609 |
+
track_point_map_support_pyramid = []
|
| 610 |
+
fmaps_pyramid.append(fmaps)
|
| 611 |
+
metric_depth = metric_depth
|
| 612 |
+
point_map = point_map
|
| 613 |
+
metric_depth_align = F.interpolate(metric_depth, scale_factor=0.25, mode='nearest')
|
| 614 |
+
point_map_align = F.interpolate(point_map, scale_factor=0.25, mode='nearest')
|
| 615 |
+
point_map_pyramid.append(point_map_align.view(B, T, 3, point_map_align.shape[-2], point_map_align.shape[-1]))
|
| 616 |
+
for i in range(self.corr_levels - 1):
|
| 617 |
+
fmaps_ = fmaps.reshape(
|
| 618 |
+
B * T, self.latent_dim, fmaps.shape[-2], fmaps.shape[-1]
|
| 619 |
+
)
|
| 620 |
+
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
|
| 621 |
+
fmaps = fmaps_.reshape(
|
| 622 |
+
B, T, self.latent_dim, fmaps_.shape[-2], fmaps_.shape[-1]
|
| 623 |
+
)
|
| 624 |
+
fmaps_pyramid.append(fmaps)
|
| 625 |
+
# downsample the depth
|
| 626 |
+
metric_depth_ = metric_depth_align.reshape(B*T,1,metric_depth_align.shape[-2],metric_depth_align.shape[-1])
|
| 627 |
+
metric_depth_ = F.interpolate(metric_depth_, scale_factor=0.5, mode='nearest')
|
| 628 |
+
metric_depth_align = metric_depth_.reshape(B,T,1,metric_depth_.shape[-2], metric_depth_.shape[-1])
|
| 629 |
+
# downsample the point map
|
| 630 |
+
point_map_ = point_map_align.reshape(B*T,3,point_map_align.shape[-2],point_map_align.shape[-1])
|
| 631 |
+
point_map_ = F.interpolate(point_map_, scale_factor=0.5, mode='nearest')
|
| 632 |
+
point_map_align = point_map_.reshape(B,T,3,point_map_.shape[-2], point_map_.shape[-1])
|
| 633 |
+
point_map_pyramid.append(point_map_align)
|
| 634 |
+
|
| 635 |
+
for i in range(self.corr_levels):
|
| 636 |
+
if cache is not None:
|
| 637 |
+
cache_N = cache["track_feat_pyramid"][i].shape[2]
|
| 638 |
+
track_feat_cached, track_feat_support_cached = cache["track_feat_pyramid"][i], cache["track_feat_support_pyramid"][i]
|
| 639 |
+
track_feat3d_cached, track_feat_support3d_cached = cache["track_feat3d_pyramid"][i], cache["track_feat_support3d_pyramid"][i]
|
| 640 |
+
track_point_map_cached, track_point_map_support_cached = self.norm_xyz(cache["track_point_map_pyramid"][i]), self.norm_xyz(cache["track_point_map_support_pyramid"][i])
|
| 641 |
+
queried_coords_new = queried_coords[:,cache_N:,:] / 2**i
|
| 642 |
+
queried_frames_new = queried_frames[:,cache_N:]
|
| 643 |
+
else:
|
| 644 |
+
queried_coords_new = queried_coords / 2**i
|
| 645 |
+
queried_frames_new = queried_frames
|
| 646 |
+
track_feat, track_feat_support = self.get_track_feat(
|
| 647 |
+
fmaps_pyramid[i],
|
| 648 |
+
queried_frames_new,
|
| 649 |
+
queried_coords_new,
|
| 650 |
+
support_radius=self.corr_radius,
|
| 651 |
+
)
|
| 652 |
+
# get 3d track feat
|
| 653 |
+
track_point_map, track_point_map_support = self.get_track_feat(
|
| 654 |
+
point_map_pyramid[i],
|
| 655 |
+
queried_frames_new,
|
| 656 |
+
queried_coords_new,
|
| 657 |
+
support_radius=self.corr3d_radius,
|
| 658 |
+
)
|
| 659 |
+
track_feat3d, track_feat_support3d = self.get_track_feat(
|
| 660 |
+
fmaps_pyramid[i],
|
| 661 |
+
queried_frames_new,
|
| 662 |
+
queried_coords_new,
|
| 663 |
+
support_radius=self.corr3d_radius,
|
| 664 |
+
)
|
| 665 |
+
if cache is not None:
|
| 666 |
+
track_feat = torch.cat([track_feat_cached, track_feat], dim=2)
|
| 667 |
+
track_point_map = torch.cat([track_point_map_cached, track_point_map], dim=2)
|
| 668 |
+
track_feat_support = torch.cat([track_feat_support_cached[:,0], track_feat_support], dim=2)
|
| 669 |
+
track_point_map_support = torch.cat([track_point_map_support_cached[:,0], track_point_map_support], dim=2)
|
| 670 |
+
track_feat3d = torch.cat([track_feat3d_cached, track_feat3d], dim=2)
|
| 671 |
+
track_feat_support3d = torch.cat([track_feat_support3d_cached[:,0], track_feat_support3d], dim=2)
|
| 672 |
+
track_feat_pyramid.append(track_feat.repeat(1, T, 1, 1))
|
| 673 |
+
track_feat_support_pyramid.append(track_feat_support.unsqueeze(1))
|
| 674 |
+
track_feat3d_pyramid.append(track_feat3d.repeat(1, T, 1, 1))
|
| 675 |
+
track_feat_support3d_pyramid.append(track_feat_support3d.unsqueeze(1))
|
| 676 |
+
track_point_map_pyramid.append(track_point_map.repeat(1, T, 1, 1))
|
| 677 |
+
track_point_map_support_pyramid.append(track_point_map_support.unsqueeze(1))
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
D_coords = 2
|
| 681 |
+
(coord_preds, coords_xyz_preds, vis_preds, confidence_preds,
|
| 682 |
+
dynamic_prob_preds, cam_preds, pts3d_cam_pred, world_tracks_pred,
|
| 683 |
+
world_tracks_refined_pred, point_map_preds, scale_ests, shift_ests) = (
|
| 684 |
+
[], [], [], [], [], [], [], [], [], [], [], []
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
c2w_ests = []
|
| 688 |
+
vis = torch.zeros((B, T, N), device=device).float()
|
| 689 |
+
confidence = torch.zeros((B, T, N), device=device).float()
|
| 690 |
+
dynamic_prob = torch.zeros((B, T, N), device=device).float()
|
| 691 |
+
pro_analysis_w = torch.zeros((B, T, N), device=device).float()
|
| 692 |
+
|
| 693 |
+
coords = queries_xyz[...,1:].clone()
|
| 694 |
+
coords[...,:2] /= self.stride
|
| 695 |
+
# coords[...,:2] = queried_coords.reshape(B, 1, N, 2).expand(B, T, N, 2).float()[...,:2]
|
| 696 |
+
# initialize the 3d points
|
| 697 |
+
coords_xyz = queries_point_map_.clone()
|
| 698 |
+
|
| 699 |
+
# if cache is not None:
|
| 700 |
+
# viser = Visualizer(save_dir=".", grayscale=True,
|
| 701 |
+
# fps=10, pad_value=50, tracks_leave_trace=0)
|
| 702 |
+
# coords_clone = coords.clone()
|
| 703 |
+
# coords_clone[...,:2] *= self.stride
|
| 704 |
+
# coords_clone[..., 0] /= self.factor_x
|
| 705 |
+
# coords_clone[..., 1] /= self.factor_y
|
| 706 |
+
# viser.visualize(video=video_vis, tracks=coords_clone[..., :2], filename="test")
|
| 707 |
+
# import pdb; pdb.set_trace()
|
| 708 |
+
|
| 709 |
+
if init_pose:
|
| 710 |
+
q_init_proj, q_xyz_world, q_xyz_cam = self.track_from_cam(queries_xyz, cam_gt,
|
| 711 |
+
intrs, rgbs=video_vis, visualize=False)
|
| 712 |
+
q_init_proj[..., 0] /= self.stride
|
| 713 |
+
q_init_proj[..., 1] /= self.stride
|
| 714 |
+
|
| 715 |
+
r = 2 * self.corr_radius + 1
|
| 716 |
+
r_depth = 2 * self.corr3d_radius + 1
|
| 717 |
+
anchor_loss = 0
|
| 718 |
+
# two current states
|
| 719 |
+
self.c2w_est_curr = torch.eye(4, device=device).repeat(B, T , 1, 1)
|
| 720 |
+
coords_proj_curr = coords.view(B * T, N, 3)[...,:2]
|
| 721 |
+
if init_pose:
|
| 722 |
+
self.c2w_est_curr = cam_gt.to(coords_proj_curr.device).to(coords_proj_curr.dtype)
|
| 723 |
+
sync_loss = 0
|
| 724 |
+
if stage == 2:
|
| 725 |
+
extra_sparse_tokens = self.scale_shift_tokens[:,:,None,:].repeat(B, 1, T, 1)
|
| 726 |
+
extra_dense_tokens = self.residual_embedding[None,None].repeat(B, T, 1, 1, 1)
|
| 727 |
+
xyz_pos_enc = posenc(point_map_pyramid[-2].permute(0,1,3,4,2), min_deg=0, max_deg=10).permute(0,1,4,2,3)
|
| 728 |
+
extra_dense_tokens = torch.cat([xyz_pos_enc, extra_dense_tokens, fmaps_pyramid[-2]], dim=2)
|
| 729 |
+
extra_dense_tokens = rearrange(extra_dense_tokens, 'b t c h w -> (b t) c h w')
|
| 730 |
+
extra_dense_tokens = self.dense_mlp(extra_dense_tokens)
|
| 731 |
+
extra_dense_tokens = rearrange(extra_dense_tokens, '(b t) c h w -> b t c h w', b=B, t=T)
|
| 732 |
+
else:
|
| 733 |
+
extra_sparse_tokens = None
|
| 734 |
+
extra_dense_tokens = None
|
| 735 |
+
|
| 736 |
+
scale_est, shift_est = torch.ones(B, T, 1, 1, device=device), torch.zeros(B, T, 1, 3, device=device)
|
| 737 |
+
residual_point = torch.zeros(B, T, 3, self.model_resolution[0]//self.stride,
|
| 738 |
+
self.model_resolution[1]//self.stride, device=device)
|
| 739 |
+
|
| 740 |
+
for it in range(iters):
|
| 741 |
+
# query points scale and shift
|
| 742 |
+
scale_est_query = torch.gather(scale_est, dim=1, index=queries[:,:,None,:1].long())
|
| 743 |
+
shift_est_query = torch.gather(shift_est, dim=1, index=queries[:,:,None,:1].long().repeat(1, 1, 1, 3))
|
| 744 |
+
|
| 745 |
+
coords = coords.detach() # B T N 3
|
| 746 |
+
coords_xyz = coords_xyz.detach()
|
| 747 |
+
vis = vis.detach()
|
| 748 |
+
confidence = confidence.detach()
|
| 749 |
+
dynamic_prob = dynamic_prob.detach()
|
| 750 |
+
pro_analysis_w = pro_analysis_w.detach()
|
| 751 |
+
coords_init = coords.view(B * T, N, 3)
|
| 752 |
+
coords_xyz_init = coords_xyz.view(B * T, N, 3)
|
| 753 |
+
corr_embs = []
|
| 754 |
+
corr_depth_embs = []
|
| 755 |
+
corr_feats = []
|
| 756 |
+
for i in range(self.corr_levels):
|
| 757 |
+
# K_level = int(32*0.8**(i))
|
| 758 |
+
K_level = 16
|
| 759 |
+
corr_feat = self.get_correlation_feat(
|
| 760 |
+
fmaps_pyramid[i], coords_init[...,:2] / 2**i
|
| 761 |
+
)
|
| 762 |
+
#NOTE: update the point map
|
| 763 |
+
residual_point_i = F.interpolate(residual_point.view(B*T,3,residual_point.shape[-2],residual_point.shape[-1]),
|
| 764 |
+
size=(point_map_pyramid[i].shape[-2], point_map_pyramid[i].shape[-1]), mode='nearest')
|
| 765 |
+
point_map_pyramid_i = (self.denorm_xyz(point_map_pyramid[i]) * scale_est[...,None]
|
| 766 |
+
+ shift_est.permute(0,1,3,2)[...,None] + residual_point_i.view(B,T,3,point_map_pyramid[i].shape[-2], point_map_pyramid[i].shape[-1])).clone().detach()
|
| 767 |
+
|
| 768 |
+
corr_point_map = self.get_correlation_feat(
|
| 769 |
+
self.norm_xyz(point_map_pyramid_i), coords_proj_curr / 2**i, radius=self.corr3d_radius
|
| 770 |
+
)
|
| 771 |
+
|
| 772 |
+
corr_point_feat = self.get_correlation_feat(
|
| 773 |
+
fmaps_pyramid[i], coords_proj_curr / 2**i, radius=self.corr3d_radius
|
| 774 |
+
)
|
| 775 |
+
track_feat_support = (
|
| 776 |
+
track_feat_support_pyramid[i]
|
| 777 |
+
.view(B, 1, r, r, N, self.latent_dim)
|
| 778 |
+
.squeeze(1)
|
| 779 |
+
.permute(0, 3, 1, 2, 4)
|
| 780 |
+
)
|
| 781 |
+
track_feat_support3d = (
|
| 782 |
+
track_feat_support3d_pyramid[i]
|
| 783 |
+
.view(B, 1, r_depth, r_depth, N, self.latent_dim)
|
| 784 |
+
.squeeze(1)
|
| 785 |
+
.permute(0, 3, 1, 2, 4)
|
| 786 |
+
)
|
| 787 |
+
#NOTE: update the point map
|
| 788 |
+
track_point_map_support_pyramid_i = (self.denorm_xyz(track_point_map_support_pyramid[i]) * scale_est_query.view(B,1,1,N,1)
|
| 789 |
+
+ shift_est_query.view(B,1,1,N,3)).clone().detach()
|
| 790 |
+
|
| 791 |
+
track_point_map_support = (
|
| 792 |
+
self.norm_xyz(track_point_map_support_pyramid_i)
|
| 793 |
+
.view(B, 1, r_depth, r_depth, N, 3)
|
| 794 |
+
.squeeze(1)
|
| 795 |
+
.permute(0, 3, 1, 2, 4)
|
| 796 |
+
)
|
| 797 |
+
corr_volume = torch.einsum(
|
| 798 |
+
"btnhwc,bnijc->btnhwij", corr_feat, track_feat_support
|
| 799 |
+
)
|
| 800 |
+
corr_emb = self.corr_mlp(corr_volume.reshape(B, T, N, r * r * r * r))
|
| 801 |
+
|
| 802 |
+
with torch.no_grad():
|
| 803 |
+
rel_pos_query_ = track_point_map_support - track_point_map_support[:,:,self.corr3d_radius,self.corr3d_radius,:][...,None,None,:]
|
| 804 |
+
rel_pos_target_ = corr_point_map - coords_xyz_init.view(B, T, N, 1, 1, 3)
|
| 805 |
+
# select the top 9 points
|
| 806 |
+
rel_pos_query_idx = rel_pos_query_.norm(dim=-1).view(B, N, -1).topk(K_level+1, dim=-1, largest=False)[1][...,1:,None]
|
| 807 |
+
rel_pos_target_idx = rel_pos_target_.norm(dim=-1).view(B, T, N, -1).topk(K_level+1, dim=-1, largest=False)[1][...,1:,None]
|
| 808 |
+
rel_pos_query_ = torch.gather(rel_pos_query_.view(B, N, -1, 3), dim=-2, index=rel_pos_query_idx.expand(B, N, K_level, 3))
|
| 809 |
+
rel_pos_target_ = torch.gather(rel_pos_target_.view(B, T, N, -1, 3), dim=-2, index=rel_pos_target_idx.expand(B, T, N, K_level, 3))
|
| 810 |
+
rel_pos_query = rel_pos_query_
|
| 811 |
+
rel_pos_target = rel_pos_target_
|
| 812 |
+
rel_pos_query = posenc(rel_pos_query, min_deg=0, max_deg=12)
|
| 813 |
+
rel_pos_target = posenc(rel_pos_target, min_deg=0, max_deg=12)
|
| 814 |
+
rel_pos_target = self.rel_pos_mlp(rel_pos_target)
|
| 815 |
+
rel_pos_query = self.rel_pos_mlp(rel_pos_query)
|
| 816 |
+
with torch.no_grad():
|
| 817 |
+
# integrate with feature
|
| 818 |
+
track_feat_support_ = rearrange(track_feat_support3d, 'b n r k c -> b n (r k) c', r=r_depth, k=r_depth, n=N, b=B)
|
| 819 |
+
track_feat_support_ = torch.gather(track_feat_support_, dim=-2, index=rel_pos_query_idx.expand(B, N, K_level, 128))
|
| 820 |
+
queried_feat = torch.cat([rel_pos_query, track_feat_support_], dim=-1)
|
| 821 |
+
corr_feat_ = rearrange(corr_point_feat, 'b t n r k c -> b t n (r k) c', t=T, n=N, b=B)
|
| 822 |
+
corr_feat_ = torch.gather(corr_feat_, dim=-2, index=rel_pos_target_idx.expand(B, T, N, K_level, 128))
|
| 823 |
+
target_feat = torch.cat([rel_pos_target, corr_feat_], dim=-1)
|
| 824 |
+
|
| 825 |
+
# 3d attention
|
| 826 |
+
queried_feat = self.corr_xyz_mlp(queried_feat)
|
| 827 |
+
target_feat = self.corr_xyz_mlp(target_feat)
|
| 828 |
+
queried_feat = repeat(queried_feat, 'b n k c -> b t n k c', k=K_level, t=T, n=N, b=B)
|
| 829 |
+
corr_depth_emb = self.corr_transformer[0](queried_feat.reshape(B*T*N,-1,128),
|
| 830 |
+
target_feat.reshape(B*T*N,-1,128),
|
| 831 |
+
target_rel_pos=rel_pos_target.reshape(B*T*N,-1,128))
|
| 832 |
+
corr_depth_emb = rearrange(corr_depth_emb, '(b t n) 1 c -> b t n c', t=T, n=N, b=B)
|
| 833 |
+
corr_depth_emb = self.corr_depth_mlp(corr_depth_emb)
|
| 834 |
+
valid_mask = self.denorm_xyz(coords_xyz_init).view(B, T, N, -1)[...,2:3] > 0
|
| 835 |
+
corr_depth_embs.append(corr_depth_emb*valid_mask)
|
| 836 |
+
|
| 837 |
+
corr_embs.append(corr_emb)
|
| 838 |
+
corr_embs = torch.cat(corr_embs, dim=-1)
|
| 839 |
+
corr_embs = corr_embs.view(B, T, N, corr_embs.shape[-1])
|
| 840 |
+
corr_depth_embs = torch.cat(corr_depth_embs, dim=-1)
|
| 841 |
+
corr_depth_embs = corr_depth_embs.view(B, T, N, corr_depth_embs.shape[-1])
|
| 842 |
+
transformer_input = [vis[..., None], confidence[..., None], corr_embs]
|
| 843 |
+
transformer_input_depth = [vis[..., None], confidence[..., None], corr_depth_embs]
|
| 844 |
+
|
| 845 |
+
rel_coords_forward = coords[:,:-1,...,:2] - coords[:,1:,...,:2]
|
| 846 |
+
rel_coords_backward = coords[:, 1:,...,:2] - coords[:, :-1,...,:2]
|
| 847 |
+
|
| 848 |
+
rel_xyz_forward = coords_xyz[:,:-1,...,:3] - coords_xyz[:,1:,...,:3]
|
| 849 |
+
rel_xyz_backward = coords_xyz[:, 1:,...,:3] - coords_xyz[:, :-1,...,:3]
|
| 850 |
+
|
| 851 |
+
rel_coords_forward = torch.nn.functional.pad(
|
| 852 |
+
rel_coords_forward, (0, 0, 0, 0, 0, 1)
|
| 853 |
+
)
|
| 854 |
+
rel_coords_backward = torch.nn.functional.pad(
|
| 855 |
+
rel_coords_backward, (0, 0, 0, 0, 1, 0)
|
| 856 |
+
)
|
| 857 |
+
rel_xyz_forward = torch.nn.functional.pad(
|
| 858 |
+
rel_xyz_forward, (0, 0, 0, 0, 0, 1)
|
| 859 |
+
)
|
| 860 |
+
rel_xyz_backward = torch.nn.functional.pad(
|
| 861 |
+
rel_xyz_backward, (0, 0, 0, 0, 1, 0)
|
| 862 |
+
)
|
| 863 |
+
|
| 864 |
+
scale = (
|
| 865 |
+
torch.tensor(
|
| 866 |
+
[self.model_resolution[1], self.model_resolution[0]],
|
| 867 |
+
device=coords.device,
|
| 868 |
+
)
|
| 869 |
+
/ self.stride
|
| 870 |
+
)
|
| 871 |
+
rel_coords_forward = rel_coords_forward / scale
|
| 872 |
+
rel_coords_backward = rel_coords_backward / scale
|
| 873 |
+
|
| 874 |
+
rel_pos_emb_input = posenc(
|
| 875 |
+
torch.cat([rel_coords_forward, rel_coords_backward], dim=-1),
|
| 876 |
+
min_deg=0,
|
| 877 |
+
max_deg=10,
|
| 878 |
+
) # batch, num_points, num_frames, 84
|
| 879 |
+
rel_xyz_emb_input = posenc(
|
| 880 |
+
torch.cat([rel_xyz_forward, rel_xyz_backward], dim=-1),
|
| 881 |
+
min_deg=0,
|
| 882 |
+
max_deg=10,
|
| 883 |
+
) # batch, num_points, num_frames, 126
|
| 884 |
+
rel_xyz_emb_input = self.xyz_mlp(rel_xyz_emb_input)
|
| 885 |
+
transformer_input.append(rel_pos_emb_input)
|
| 886 |
+
transformer_input_depth.append(rel_xyz_emb_input)
|
| 887 |
+
# get the queries world
|
| 888 |
+
with torch.no_grad():
|
| 889 |
+
# update the query points with scale and shift
|
| 890 |
+
queries_xyz_i = queries_xyz.clone().detach()
|
| 891 |
+
queries_xyz_i[..., -1] = queries_xyz_i[..., -1] * scale_est_query.view(B,1,N) + shift_est_query.view(B,1,N,3)[...,2]
|
| 892 |
+
_, _, q_xyz_cam = self.track_from_cam(queries_xyz_i, self.c2w_est_curr,
|
| 893 |
+
intrs, rgbs=None, visualize=False)
|
| 894 |
+
q_xyz_cam = self.norm_xyz(q_xyz_cam)
|
| 895 |
+
|
| 896 |
+
query_t = queries[:,None,:,:1].repeat(B, T, 1, 1)
|
| 897 |
+
q_xyz_cam = torch.cat([query_t/T, q_xyz_cam], dim=-1)
|
| 898 |
+
T_all = torch.arange(T, device=device)[None,:,None,None].repeat(B, 1, N, 1)
|
| 899 |
+
current_xyzt = torch.cat([T_all/T, coords_xyz_init.view(B, T, N, -1)], dim=-1)
|
| 900 |
+
rel_pos_query_glob = q_xyz_cam - current_xyzt
|
| 901 |
+
# embed the confidence and dynamic probability
|
| 902 |
+
confidence_curr = torch.sigmoid(confidence[...,None])
|
| 903 |
+
dynamic_prob_curr = torch.sigmoid(dynamic_prob[...,None]).mean(dim=1, keepdim=True).repeat(1,T,1,1)
|
| 904 |
+
# embed the confidence and dynamic probability
|
| 905 |
+
rel_pos_query_glob = torch.cat([rel_pos_query_glob, confidence_curr, dynamic_prob_curr], dim=-1)
|
| 906 |
+
rel_pos_query_glob = posenc(rel_pos_query_glob, min_deg=0, max_deg=12)
|
| 907 |
+
transformer_input_depth.append(rel_pos_query_glob)
|
| 908 |
+
|
| 909 |
+
x = (
|
| 910 |
+
torch.cat(transformer_input, dim=-1)
|
| 911 |
+
.permute(0, 2, 1, 3)
|
| 912 |
+
.reshape(B * N, T, -1)
|
| 913 |
+
)
|
| 914 |
+
x_depth = (
|
| 915 |
+
torch.cat(transformer_input_depth, dim=-1)
|
| 916 |
+
.permute(0, 2, 1, 3)
|
| 917 |
+
.reshape(B * N, T, -1)
|
| 918 |
+
)
|
| 919 |
+
x_depth = self.proj_xyz_embed(x_depth)
|
| 920 |
+
|
| 921 |
+
x = x + self.interpolate_time_embed(x, T)
|
| 922 |
+
x = x.view(B, N, T, -1) # (B N) T D -> B N T D
|
| 923 |
+
x_depth = x_depth + self.interpolate_time_embed(x_depth, T)
|
| 924 |
+
x_depth = x_depth.view(B, N, T, -1) # (B N) T D -> B N T D
|
| 925 |
+
delta, delta_depth, delta_dynamic_prob, delta_pro_analysis_w, scale_shift_out, dense_res_out = self.updateformer3D(
|
| 926 |
+
x,
|
| 927 |
+
x_depth,
|
| 928 |
+
self.updateformer,
|
| 929 |
+
add_space_attn=add_space_attn,
|
| 930 |
+
extra_sparse_tokens=extra_sparse_tokens,
|
| 931 |
+
extra_dense_tokens=extra_dense_tokens,
|
| 932 |
+
)
|
| 933 |
+
# update the scale and shift
|
| 934 |
+
if scale_shift_out is not None:
|
| 935 |
+
extra_sparse_tokens = extra_sparse_tokens + scale_shift_out[...,:128]
|
| 936 |
+
scale_update = scale_shift_out[:,:1,:,-1].permute(0,2,1)[...,None]
|
| 937 |
+
shift_update = scale_shift_out[:,1:,:,-1].permute(0,2,1)[...,None]
|
| 938 |
+
scale_est = scale_est + scale_update
|
| 939 |
+
shift_est[...,2:] = shift_est[...,2:] + shift_update / 10
|
| 940 |
+
# dense tokens update
|
| 941 |
+
extra_dense_tokens = extra_dense_tokens + dense_res_out[:,:,-128:]
|
| 942 |
+
res_low = dense_res_out[:,:,:3]
|
| 943 |
+
up_mask = self.upsample_transformer(extra_dense_tokens.mean(dim=1), res_low)
|
| 944 |
+
up_mask = repeat(up_mask, "b k h w -> b s k h w", s=T)
|
| 945 |
+
up_mask = rearrange(up_mask, "b s c h w -> (b s) 1 c h w")
|
| 946 |
+
res_up = self.upsample_with_mask(
|
| 947 |
+
rearrange(res_low, 'b t c h w -> (b t) c h w'),
|
| 948 |
+
up_mask,
|
| 949 |
+
)
|
| 950 |
+
res_up = rearrange(res_up, "(b t) c h w -> b t c h w", b=B, t=T)
|
| 951 |
+
# residual_point = residual_point + res_up
|
| 952 |
+
|
| 953 |
+
delta_coords = delta[..., :D_coords].permute(0, 2, 1, 3)
|
| 954 |
+
delta_vis = delta[..., D_coords].permute(0, 2, 1)
|
| 955 |
+
delta_confidence = delta[..., D_coords + 1].permute(0, 2, 1)
|
| 956 |
+
|
| 957 |
+
vis = vis + delta_vis
|
| 958 |
+
confidence = confidence + delta_confidence
|
| 959 |
+
dynamic_prob = dynamic_prob + delta_dynamic_prob[...,0].permute(0, 2, 1)
|
| 960 |
+
pro_analysis_w = pro_analysis_w + delta_pro_analysis_w[...,0].permute(0, 2, 1)
|
| 961 |
+
# update the depth
|
| 962 |
+
vis_est = torch.sigmoid(vis.detach())
|
| 963 |
+
|
| 964 |
+
delta_xyz = delta_depth[...,:3].permute(0,2,1,3)
|
| 965 |
+
denorm_delta_depth = (self.denorm_xyz(coords_xyz+delta_xyz)-self.denorm_xyz(coords_xyz))[...,2:3]
|
| 966 |
+
|
| 967 |
+
|
| 968 |
+
delta_depth_ = denorm_delta_depth.detach()
|
| 969 |
+
delta_coords = torch.cat([delta_coords, delta_depth_],dim=-1)
|
| 970 |
+
coords = coords + delta_coords
|
| 971 |
+
coords_append = coords.clone()
|
| 972 |
+
coords_xyz_append = self.denorm_xyz(coords_xyz + delta_xyz).clone()
|
| 973 |
+
|
| 974 |
+
coords_append[..., :2] = coords_append[..., :2] * float(self.stride)
|
| 975 |
+
coords_append[..., 0] /= self.factor_x
|
| 976 |
+
coords_append[..., 1] /= self.factor_y
|
| 977 |
+
|
| 978 |
+
# get the camera pose from tracks
|
| 979 |
+
dynamic_prob_curr = torch.sigmoid(dynamic_prob.detach())*torch.sigmoid(pro_analysis_w)
|
| 980 |
+
mask_out = (coords_append[...,0]<W_)&(coords_append[...,0]>0)&(coords_append[...,1]<H_)&(coords_append[...,1]>0)
|
| 981 |
+
if query_no_BA:
|
| 982 |
+
dynamic_prob_curr[:,:,:ba_len] = torch.ones_like(dynamic_prob_curr[:,:,:ba_len])
|
| 983 |
+
point_map_org_i = scale_est.view(B*T,1,1,1)*point_map_org.clone().detach() + shift_est.view(B*T,3,1,1)
|
| 984 |
+
# depth_unproj = bilinear_sampler(point_map_org_i, coords_append[...,:2].view(B*T, N, 1, 2), mode="nearest")[:,2,:,0].detach()
|
| 985 |
+
|
| 986 |
+
depth_unproj_neg = self.get_correlation_feat(
|
| 987 |
+
point_map_org_i.view(B,T,3,point_map_org_i.shape[-2], point_map_org_i.shape[-1]),
|
| 988 |
+
coords_append[...,:2].view(B*T, N, 2), radius=self.corr3d_radius
|
| 989 |
+
)[..., 2]
|
| 990 |
+
depth_diff = (depth_unproj_neg.view(B,T,N,-1) - coords_append[...,2:]).abs()
|
| 991 |
+
idx_neg = torch.argmin(depth_diff, dim=-1)
|
| 992 |
+
depth_unproj = depth_unproj_neg.view(B,T,N,-1)[torch.arange(B)[:, None, None, None],
|
| 993 |
+
torch.arange(T)[None, :, None, None],
|
| 994 |
+
torch.arange(N)[None, None, :, None],
|
| 995 |
+
idx_neg.view(B,T,N,1)].view(B*T, N)
|
| 996 |
+
|
| 997 |
+
unc_unproj = bilinear_sampler(self.metric_unc_org, coords_append[...,:2].view(B*T, N, 1, 2), mode="nearest")[:,0,:,0].detach()
|
| 998 |
+
depth_unproj[unc_unproj<0.5] = 0.0
|
| 999 |
+
|
| 1000 |
+
# replace the depth for visible and solid points
|
| 1001 |
+
conf_est = torch.sigmoid(confidence.detach())
|
| 1002 |
+
replace_mask = (depth_unproj.view(B,T,N)>0.0) * (vis_est>0.5) # * (conf_est>0.5)
|
| 1003 |
+
#NOTE: way1: find the jitter points
|
| 1004 |
+
depth_rel = (depth_unproj.view(B, T, N) - queries_z.permute(0, 2, 1))
|
| 1005 |
+
depth_ddt1 = depth_rel[:, 1:, :] - depth_rel[:, :-1, :]
|
| 1006 |
+
depth_ddt2 = depth_rel[:, 2:, :] - 2 * depth_rel[:, 1:-1, :] + depth_rel[:, :-2, :]
|
| 1007 |
+
jitter_mask = torch.zeros_like(depth_rel, dtype=torch.bool)
|
| 1008 |
+
if depth_ddt2.abs().max()>0:
|
| 1009 |
+
thre2 = torch.quantile(depth_ddt2.abs()[depth_ddt2.abs()>0], replace_ratio)
|
| 1010 |
+
jitter_mask[:, 1:-1, :] = (depth_ddt2.abs() < thre2)
|
| 1011 |
+
thre1 = torch.quantile(depth_ddt1.abs()[depth_ddt1.abs()>0], replace_ratio)
|
| 1012 |
+
jitter_mask[:, :-1, :] *= (depth_ddt1.abs() < thre1)
|
| 1013 |
+
replace_mask = replace_mask * jitter_mask
|
| 1014 |
+
|
| 1015 |
+
#NOTE: way2: top k topological change detection
|
| 1016 |
+
# coords_2d_lift = coords_append.clone()
|
| 1017 |
+
# coords_2d_lift[...,2][replace_mask] = depth_unproj.view(B,T,N)[replace_mask]
|
| 1018 |
+
# coords_2d_lift = self.cam_from_track(coords_2d_lift.clone(), intrs_org, only_cam_pts=True)
|
| 1019 |
+
# coords_2d_lift[~replace_mask] = coords_xyz_append[~replace_mask]
|
| 1020 |
+
# import pdb; pdb.set_trace()
|
| 1021 |
+
# jitter_mask = get_topo_mask(coords_xyz_append, coords_2d_lift, replace_ratio)
|
| 1022 |
+
# replace_mask = replace_mask * jitter_mask
|
| 1023 |
+
|
| 1024 |
+
# replace the depth
|
| 1025 |
+
if self.training:
|
| 1026 |
+
replace_mask = torch.zeros_like(replace_mask)
|
| 1027 |
+
coords_append[...,2][replace_mask] = depth_unproj.view(B,T,N)[replace_mask]
|
| 1028 |
+
coords_xyz_unproj = self.cam_from_track(coords_append.clone(), intrs_org, only_cam_pts=True)
|
| 1029 |
+
coords[...,2][replace_mask] = depth_unproj.view(B,T,N)[replace_mask]
|
| 1030 |
+
# coords_xyz_append[replace_mask] = coords_xyz_unproj[replace_mask]
|
| 1031 |
+
coords_xyz_append_refine = coords_xyz_append.clone()
|
| 1032 |
+
coords_xyz_append_refine[replace_mask] = coords_xyz_unproj[replace_mask]
|
| 1033 |
+
|
| 1034 |
+
c2w_traj_est, cam_pts_est, intrs_refine, coords_refine, world_tracks, world_tracks_refined, c2w_traj_init = self.cam_from_track(coords_append.clone(),
|
| 1035 |
+
intrs_org, dynamic_prob_curr, queries_z_unc, conf_est*vis_est*mask_out.float(),
|
| 1036 |
+
track_feat_concat=x_depth, tracks_xyz=coords_xyz_append_refine, init_pose=init_pose,
|
| 1037 |
+
query_pts=queries_xyz_i, fixed_cam=fixed_cam, depth_unproj=depth_unproj, cam_gt=cam_gt)
|
| 1038 |
+
intrs_org = intrs_refine.view(B, T, 3, 3).to(intrs_org.dtype)
|
| 1039 |
+
|
| 1040 |
+
# get the queries world
|
| 1041 |
+
self.c2w_est_curr = c2w_traj_est.detach()
|
| 1042 |
+
|
| 1043 |
+
# update coords and coords_append
|
| 1044 |
+
coords[..., 2] = (cam_pts_est)[...,2]
|
| 1045 |
+
coords_append[..., 2] = (cam_pts_est)[...,2]
|
| 1046 |
+
|
| 1047 |
+
# update coords_xyz_append
|
| 1048 |
+
# coords_xyz_append = cam_pts_est
|
| 1049 |
+
coords_xyz = self.norm_xyz(cam_pts_est)
|
| 1050 |
+
|
| 1051 |
+
|
| 1052 |
+
# proj
|
| 1053 |
+
coords_xyz_de = coords_xyz_append.clone()
|
| 1054 |
+
coords_xyz_de[coords_xyz_de[...,2].abs()<1e-6] = -1e-4
|
| 1055 |
+
mask_nan = coords_xyz_de[...,2].abs()<1e-2
|
| 1056 |
+
coords_proj = torch.einsum("btij,btnj->btni", intrs_org, coords_xyz_de/coords_xyz_de[...,2:3].abs())[...,:2]
|
| 1057 |
+
coords_proj[...,0] *= self.factor_x
|
| 1058 |
+
coords_proj[...,1] *= self.factor_y
|
| 1059 |
+
coords_proj[...,:2] /= float(self.stride)
|
| 1060 |
+
# make sure it is aligned with 2d tracking
|
| 1061 |
+
coords_proj_curr = coords[...,:2].view(B*T, N, 2).detach()
|
| 1062 |
+
vis_est = (vis_est>0.5).float()
|
| 1063 |
+
sync_loss += (vis_est.detach()[...,None]*(coords_proj_curr - coords_proj).norm(dim=-1, keepdim=True)*(1-mask_nan[...,None].float())).mean()
|
| 1064 |
+
# coords_proj_curr[~mask_nan.view(B*T, N)] = coords_proj.view(B*T, N, 2)[~mask_nan.view(B*T, N)].to(coords_proj_curr.dtype)
|
| 1065 |
+
# if torch.isnan(coords_proj_curr).sum()>0:
|
| 1066 |
+
# import pdb; pdb.set_trace()
|
| 1067 |
+
|
| 1068 |
+
if False:
|
| 1069 |
+
point_map_resize = point_map.clone().view(B, T, 3, H, W)
|
| 1070 |
+
update_input = torch.cat([point_map_resize, metric_unc.view(B,T,1,H,W)], dim=2)
|
| 1071 |
+
coords_append_resize = coords.clone().detach()
|
| 1072 |
+
coords_append_resize[..., :2] = coords_append_resize[..., :2] * float(self.stride)
|
| 1073 |
+
update_track_input = self.norm_xyz(cam_pts_est)*5
|
| 1074 |
+
update_track_input = torch.cat([update_track_input, vis_est[...,None]], dim=-1)
|
| 1075 |
+
update_track_input = posenc(update_track_input, min_deg=0, max_deg=12)
|
| 1076 |
+
update = self.update_pointmap.stablizer(update_input,
|
| 1077 |
+
update_track_input, coords_append_resize)#, imgs=video, vis_track=viser)
|
| 1078 |
+
#NOTE: update the point map
|
| 1079 |
+
point_map_resize += update
|
| 1080 |
+
point_map_refine_out = F.interpolate(point_map_resize.view(B*T, -1, H, W),
|
| 1081 |
+
size=(self.image_size[0].item(), self.image_size[1].item()), mode='nearest')
|
| 1082 |
+
point_map_refine_out = rearrange(point_map_refine_out, '(b t) c h w -> b t c h w', t=T, b=B)
|
| 1083 |
+
point_map_preds.append(self.denorm_xyz(point_map_refine_out))
|
| 1084 |
+
point_map_org = self.denorm_xyz(point_map_refine_out).view(B*T, 3, H_, W_)
|
| 1085 |
+
|
| 1086 |
+
# if torch.isnan(coords).sum()>0:
|
| 1087 |
+
# import pdb; pdb.set_trace()
|
| 1088 |
+
#NOTE: the 2d tracking + unproject depth
|
| 1089 |
+
fix_cam_est = coords_append.clone()
|
| 1090 |
+
fix_cam_est[...,2] = depth_unproj
|
| 1091 |
+
fix_cam_pts = self.cam_from_track(
|
| 1092 |
+
fix_cam_est, intrs_org, only_cam_pts=True
|
| 1093 |
+
)
|
| 1094 |
+
|
| 1095 |
+
coord_preds.append(coords_append)
|
| 1096 |
+
coords_xyz_preds.append(coords_xyz_append)
|
| 1097 |
+
vis_preds.append(vis)
|
| 1098 |
+
cam_preds.append(c2w_traj_init)
|
| 1099 |
+
pts3d_cam_pred.append(cam_pts_est)
|
| 1100 |
+
world_tracks_pred.append(world_tracks)
|
| 1101 |
+
world_tracks_refined_pred.append(world_tracks_refined)
|
| 1102 |
+
confidence_preds.append(confidence)
|
| 1103 |
+
dynamic_prob_preds.append(dynamic_prob)
|
| 1104 |
+
scale_ests.append(scale_est)
|
| 1105 |
+
shift_ests.append(shift_est)
|
| 1106 |
+
|
| 1107 |
+
if stage!=0:
|
| 1108 |
+
all_coords_predictions.append([coord for coord in coord_preds])
|
| 1109 |
+
all_coords_xyz_predictions.append([coord_xyz for coord_xyz in coords_xyz_preds])
|
| 1110 |
+
all_vis_predictions.append(vis_preds)
|
| 1111 |
+
all_confidence_predictions.append(confidence_preds)
|
| 1112 |
+
all_dynamic_prob_predictions.append(dynamic_prob_preds)
|
| 1113 |
+
all_cam_predictions.append([cam for cam in cam_preds])
|
| 1114 |
+
all_cam_pts_predictions.append([pts for pts in pts3d_cam_pred])
|
| 1115 |
+
all_world_tracks_predictions.append([world_tracks for world_tracks in world_tracks_pred])
|
| 1116 |
+
all_world_tracks_refined_predictions.append([world_tracks_refined for world_tracks_refined in world_tracks_refined_pred])
|
| 1117 |
+
all_scale_est.append(scale_ests)
|
| 1118 |
+
all_shift_est.append(shift_ests)
|
| 1119 |
+
if stage!=0:
|
| 1120 |
+
train_data = (
|
| 1121 |
+
all_coords_predictions,
|
| 1122 |
+
all_coords_xyz_predictions,
|
| 1123 |
+
all_vis_predictions,
|
| 1124 |
+
all_confidence_predictions,
|
| 1125 |
+
all_dynamic_prob_predictions,
|
| 1126 |
+
all_cam_predictions,
|
| 1127 |
+
all_cam_pts_predictions,
|
| 1128 |
+
all_world_tracks_predictions,
|
| 1129 |
+
all_world_tracks_refined_predictions,
|
| 1130 |
+
all_scale_est,
|
| 1131 |
+
all_shift_est,
|
| 1132 |
+
torch.ones_like(vis_preds[-1], device=vis_preds[-1].device),
|
| 1133 |
+
)
|
| 1134 |
+
else:
|
| 1135 |
+
train_data = None
|
| 1136 |
+
# resize back
|
| 1137 |
+
# init the trajectories by camera motion
|
| 1138 |
+
|
| 1139 |
+
# if cache is not None:
|
| 1140 |
+
# viser = Visualizer(save_dir=".", grayscale=True,
|
| 1141 |
+
# fps=10, pad_value=50, tracks_leave_trace=0)
|
| 1142 |
+
# coords_clone = coords.clone()
|
| 1143 |
+
# coords_clone[...,:2] *= self.stride
|
| 1144 |
+
# coords_clone[..., 0] /= self.factor_x
|
| 1145 |
+
# coords_clone[..., 1] /= self.factor_y
|
| 1146 |
+
# viser.visualize(video=video_vis, tracks=coords_clone[..., :2], filename="test_refine")
|
| 1147 |
+
# import pdb; pdb.set_trace()
|
| 1148 |
+
|
| 1149 |
+
if train_data is not None:
|
| 1150 |
+
# get the gt pts in the world coordinate
|
| 1151 |
+
self_supervised = False
|
| 1152 |
+
if (traj3d_gt is not None):
|
| 1153 |
+
if traj3d_gt[...,2].abs().max()>0:
|
| 1154 |
+
gt_cam_pts = self.cam_from_track(
|
| 1155 |
+
traj3d_gt, intrs_org, only_cam_pts=True
|
| 1156 |
+
)
|
| 1157 |
+
else:
|
| 1158 |
+
self_supervised = True
|
| 1159 |
+
else:
|
| 1160 |
+
self_supervised = True
|
| 1161 |
+
|
| 1162 |
+
if self_supervised:
|
| 1163 |
+
gt_cam_pts = self.cam_from_track(
|
| 1164 |
+
coord_preds[-1].detach(), intrs_org, only_cam_pts=True
|
| 1165 |
+
)
|
| 1166 |
+
|
| 1167 |
+
if cam_gt is not None:
|
| 1168 |
+
gt_world_pts = torch.einsum(
|
| 1169 |
+
"btij,btnj->btni",
|
| 1170 |
+
cam_gt[...,:3,:3],
|
| 1171 |
+
gt_cam_pts
|
| 1172 |
+
) + cam_gt[...,None, :3,3] # B T N 3
|
| 1173 |
+
else:
|
| 1174 |
+
gt_world_pts = torch.einsum(
|
| 1175 |
+
"btij,btnj->btni",
|
| 1176 |
+
self.c2w_est_curr[...,:3,:3],
|
| 1177 |
+
gt_cam_pts
|
| 1178 |
+
) + self.c2w_est_curr[...,None, :3,3] # B T N 3
|
| 1179 |
+
# update the query points with scale and shift
|
| 1180 |
+
queries_xyz_i = queries_xyz.clone().detach()
|
| 1181 |
+
queries_xyz_i[..., -1] = queries_xyz_i[..., -1] * scale_est_query.view(B,1,N) + shift_est_query.view(B,1,N,3)[...,2]
|
| 1182 |
+
q_static_proj, q_xyz_world, q_xyz_cam = self.track_from_cam(queries_xyz_i,
|
| 1183 |
+
self.c2w_est_curr,
|
| 1184 |
+
intrs, rgbs=video_vis, visualize=False)
|
| 1185 |
+
|
| 1186 |
+
q_static_proj[..., 0] /= self.factor_x
|
| 1187 |
+
q_static_proj[..., 1] /= self.factor_y
|
| 1188 |
+
cam_gt = self.c2w_est_curr[:,:,:3,:]
|
| 1189 |
+
|
| 1190 |
+
if traj3d_gt is not None:
|
| 1191 |
+
ret_loss = self.loss(train_data, traj3d_gt,
|
| 1192 |
+
vis_gt, None, cam_gt, queries_z_unc,
|
| 1193 |
+
q_xyz_world, q_static_proj, anchor_loss=anchor_loss, fix_cam_pts=fix_cam_pts, video_vis=video_vis, stage=stage,
|
| 1194 |
+
gt_world_pts=gt_world_pts, mask_traj_gt=mask_traj_gt, intrs=intrs_org, custom_vid=custom_vid, valid_only=valid_only,
|
| 1195 |
+
c2w_ests=c2w_ests, point_map_preds=point_map_preds, points_map_gt=points_map_gt, metric_unc=metric_unc, scale_est=scale_est,
|
| 1196 |
+
shift_est=shift_est, point_map_org_train=point_map_org_train)
|
| 1197 |
+
else:
|
| 1198 |
+
ret_loss = self.loss(train_data, traj3d_gt,
|
| 1199 |
+
vis_gt, None, cam_gt, queries_z_unc,
|
| 1200 |
+
q_xyz_world, q_static_proj, anchor_loss=anchor_loss, fix_cam_pts=fix_cam_pts, video_vis=video_vis, stage=stage,
|
| 1201 |
+
gt_world_pts=gt_world_pts, mask_traj_gt=mask_traj_gt, intrs=intrs_org, custom_vid=custom_vid, valid_only=valid_only,
|
| 1202 |
+
c2w_ests=c2w_ests, point_map_preds=point_map_preds, points_map_gt=points_map_gt, metric_unc=metric_unc, scale_est=scale_est,
|
| 1203 |
+
shift_est=shift_est, point_map_org_train=point_map_org_train)
|
| 1204 |
+
if custom_vid:
|
| 1205 |
+
sync_loss = 0*sync_loss
|
| 1206 |
+
if (sync_loss > 50) and (stage==1):
|
| 1207 |
+
ret_loss = (0*sync_loss, 0*sync_loss, 0*sync_loss, 0*sync_loss, 0*sync_loss, 0*sync_loss, 0*sync_loss) + (0*sync_loss,)
|
| 1208 |
+
else:
|
| 1209 |
+
ret_loss = ret_loss+(10*sync_loss,)
|
| 1210 |
+
|
| 1211 |
+
else:
|
| 1212 |
+
ret_loss = None
|
| 1213 |
+
|
| 1214 |
+
        color_pts = torch.cat([pts3d_cam_pred[-1], queries_rgb[:,None].repeat(1, T, 1, 1)], dim=-1)

        #TODO: for evaluation only; the model shows some bias on invisible points after training (to be fixed)
        vis_pred_out = torch.sigmoid(vis_preds[-1]) + 0.2

        ret = {"preds": coord_preds[-1], "vis_pred": vis_pred_out,
               "conf_pred": torch.sigmoid(confidence_preds[-1]),
               "cam_pred": self.c2w_est_curr, "loss": ret_loss}

        cache = {
            "fmaps": fmaps_org[0].detach(),
            "track_feat_support3d_pyramid": [track_feat_support3d_pyramid[i].detach() for i in range(len(track_feat_support3d_pyramid))],
            "track_point_map_support_pyramid": [self.denorm_xyz(track_point_map_support_pyramid[i].detach()) for i in range(len(track_point_map_support_pyramid))],
            "track_feat3d_pyramid": [track_feat3d_pyramid[i].detach() for i in range(len(track_feat3d_pyramid))],
            "track_point_map_pyramid": [self.denorm_xyz(track_point_map_pyramid[i].detach()) for i in range(len(track_point_map_pyramid))],
            "track_feat_pyramid": [track_feat_pyramid[i].detach() for i in range(len(track_feat_pyramid))],
            "track_feat_support_pyramid": [track_feat_support_pyramid[i].detach() for i in range(len(track_feat_support_pyramid))],
            "track2d_pred_cache": coord_preds[-1][0].clone().detach(),
            "track3d_pred_cache": pts3d_cam_pred[-1][0].clone().detach(),
        }
        #NOTE: update the point map
        point_map_org = scale_est.view(B*T,1,1,1)*point_map_org + shift_est.view(B*T,3,1,1)
        point_map_org_refined = point_map_org
        return ret, torch.sigmoid(dynamic_prob_preds[-1])*queries_z_unc[:,None,:,0], coord_preds[-1], color_pts, intrs_org, point_map_org_refined, cache

    def track_d2_loss(self, tracks3d, stride=[1,2,3], dyn_prob=None, mask=None):
        """
        tracks3d: B T N 3
        dyn_prob: B T N 1
        """
        r = 0.8
        t_diff_total = 0.0
        for i, s_ in enumerate(stride):
            w_ = r**i
            tracks3d_stride = tracks3d[:, ::s_, :, :] # B T//s_ N 3
            t_diff_tracks3d = (tracks3d_stride[:, 1:, :, :] - tracks3d_stride[:, :-1, :, :])
            t_diff2 = (t_diff_tracks3d[:, 1:, :, :] - t_diff_tracks3d[:, :-1, :, :])
            t_diff_total += w_*(t_diff2.norm(dim=-1).mean())

        return 1e2*t_diff_total

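`track_d2_loss` penalizes the second temporal difference of the 3D tracks (an acceleration-style smoothness term), accumulated over several temporal strides with geometrically decaying weights 0.8^i and scaled by 1e2. A standalone sketch of the same quantity on dummy data; shapes and values are illustrative:

```python
import torch

def d2_smoothness(tracks3d, strides=(1, 2, 3), r=0.8):
    # tracks3d: (B, T, N, 3); weighted mean norm of the second temporal difference
    total = 0.0
    for i, s in enumerate(strides):
        sub = tracks3d[:, ::s]                     # subsample time by stride s
        d1 = sub[:, 1:] - sub[:, :-1]              # first difference (velocity)
        d2 = d1[:, 1:] - d1[:, :-1]                # second difference (acceleration)
        total = total + (r ** i) * d2.norm(dim=-1).mean()
    return 1e2 * total

print(d2_smoothness(torch.randn(1, 24, 64, 3)))
```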
    def loss(self, train_data, traj3d_gt=None,
             vis_gt=None, static_tracks_gt=None, cam_gt=None,
             z_unc=None, q_xyz_world=None, q_static_proj=None, anchor_loss=0, valid_only=False,
             gt_world_pts=None, mask_traj_gt=None, intrs=None, c2w_ests=None, custom_vid=False, video_vis=None, stage=0,
             fix_cam_pts=None, point_map_preds=None, points_map_gt=None, metric_unc=None, scale_est=None, shift_est=None, point_map_org_train=None):
        """
        Compute the losses of the 3D tracking problem.
        """
|
| 1265 |
+
(
|
| 1266 |
+
coord_predictions, coords_xyz_predictions, vis_predictions, confidence_predicitons,
|
| 1267 |
+
dynamic_prob_predictions, camera_predictions, cam_pts_predictions, world_tracks_predictions,
|
| 1268 |
+
world_tracks_refined_predictions, scale_ests, shift_ests, valid_mask
|
| 1269 |
+
) = train_data
|
| 1270 |
+
B, T, _, _ = cam_gt.shape
|
| 1271 |
+
if (stage == 2) and self.training:
|
| 1272 |
+
# get the scale and shift gt
|
| 1273 |
+
self.metric_unc_org[:,0] = self.metric_unc_org[:,0] * (points_map_gt.norm(dim=-1)>0).float() * (self.metric_unc_org[:,0]>0.5).float()
|
| 1274 |
+
if not (self.scale_gt==torch.ones(B*T).to(self.scale_gt.device)).all():
|
| 1275 |
+
scale_gt, shift_gt = self.scale_gt, self.shift_gt
|
| 1276 |
+
scale_re = scale_gt[:4].mean()
|
| 1277 |
+
scale_loss = 0.0
|
| 1278 |
+
shift_loss = 0.0
|
| 1279 |
+
for i_scale in range(len(scale_ests[0])):
|
| 1280 |
+
scale_loss += 0.8**(len(scale_ests[0])-i_scale-1)*10*(scale_gt - scale_re*scale_ests[0][i_scale].view(-1)).abs().mean()
|
| 1281 |
+
shift_loss += 0.8**(len(shift_ests[0])-i_scale-1)*10*(shift_gt - scale_re*shift_ests[0][i_scale].view(-1,3)).abs().mean()
|
| 1282 |
+
else:
|
| 1283 |
+
scale_loss = 0.0 * scale_ests[0][0].mean()
|
| 1284 |
+
shift_loss = 0.0 * shift_ests[0][0].mean()
|
| 1285 |
+
scale_re = 1.0
|
| 1286 |
+
else:
|
| 1287 |
+
scale_loss = 0.0
|
| 1288 |
+
shift_loss = 0.0
|
| 1289 |
+
|
| 1290 |
+
if len(point_map_preds)>0:
|
| 1291 |
+
point_map_loss = 0.0
|
| 1292 |
+
for i in range(len(point_map_preds)):
|
| 1293 |
+
point_map_preds_i = point_map_preds[i]
|
| 1294 |
+
point_map_preds_i = rearrange(point_map_preds_i, 'b t c h w -> (b t) c h w', b=B, t=T)
|
| 1295 |
+
base_loss = ((self.pred_points - points_map_gt).norm(dim=-1) * self.metric_unc_org[:,0]).mean()
|
| 1296 |
+
point_map_loss_i = ((point_map_preds_i - points_map_gt.permute(0,3,1,2)).norm(dim=1) * self.metric_unc_org[:,0]).mean()
|
| 1297 |
+
point_map_loss += point_map_loss_i
|
| 1298 |
+
# point_map_loss += ((point_map_org_train - points_map_gt.permute(0,3,1,2)).norm(dim=1) * self.metric_unc_org[:,0]).mean()
|
| 1299 |
+
if scale_loss == 0.0:
|
| 1300 |
+
point_map_loss = 0*point_map_preds_i.sum()
|
| 1301 |
+
else:
|
| 1302 |
+
point_map_loss = 0.0
|
| 1303 |
+
|
| 1304 |
+
# camera loss
|
| 1305 |
+
cam_loss = 0.0
|
| 1306 |
+
dyn_loss = 0.0
|
| 1307 |
+
N_gt = gt_world_pts.shape[2]
|
| 1308 |
+
|
| 1309 |
+
# self supervised dynamic mask
|
| 1310 |
+
H_org, W_org = self.image_size[0], self.image_size[1]
|
| 1311 |
+
q_static_proj[torch.isnan(q_static_proj)] = -200
|
| 1312 |
+
in_view_mask = (q_static_proj[...,0]>0) & (q_static_proj[...,0]<W_org) & (q_static_proj[...,1]>0) & (q_static_proj[...,1]<H_org)
|
| 1313 |
+
dyn_mask_final = (((coord_predictions[0][-1] - q_static_proj))[...,:2].norm(dim=-1) * in_view_mask)
|
| 1314 |
+
dyn_mask_final = dyn_mask_final.sum(dim=1) / (in_view_mask.sum(dim=1) + 1e-2)
|
| 1315 |
+
dyn_mask_final = dyn_mask_final > 6
|
| 1316 |
+
|
| 1317 |
+
for iter_, cam_pred_i in enumerate(camera_predictions[0]):
|
| 1318 |
+
# points loss
|
| 1319 |
+
pts_i_world = world_tracks_predictions[0][iter_].view(B, T, -1, 3)
|
| 1320 |
+
|
| 1321 |
+
coords_xyz_i_world = coords_xyz_predictions[0][iter_].view(B, T, -1, 3)
|
| 1322 |
+
coords_i = coord_predictions[0][iter_].view(B, T, -1, 3)[..., :2]
|
| 1323 |
+
pts_i_world_refined = torch.einsum(
|
| 1324 |
+
"btij,btnj->btni",
|
| 1325 |
+
cam_gt[...,:3,:3],
|
| 1326 |
+
coords_xyz_i_world
|
| 1327 |
+
) + cam_gt[...,None, :3,3] # B T N 3
|
| 1328 |
+
|
| 1329 |
+
# pts_i_world_refined = world_tracks_refined_predictions[0][iter_].view(B, T, -1, 3)
|
| 1330 |
+
pts_world = pts_i_world
|
| 1331 |
+
dyn_prob_i_logits = dynamic_prob_predictions[0][iter_].mean(dim=1)
|
| 1332 |
+
dyn_prob_i = torch.sigmoid(dyn_prob_i_logits).detach()
|
| 1333 |
+
mask = pts_world.norm(dim=-1) < 200
|
| 1334 |
+
|
| 1335 |
+
# general
|
| 1336 |
+
vis_i_logits = vis_predictions[0][iter_]
|
| 1337 |
+
vis_i = torch.sigmoid(vis_i_logits).detach()
|
| 1338 |
+
if mask_traj_gt is not None:
|
| 1339 |
+
try:
|
| 1340 |
+
N_gt_mask = mask_traj_gt.shape[1]
|
| 1341 |
+
align_loss = (gt_world_pts - q_xyz_world[:,None,:N_gt,:,]).norm(dim=-1)[...,:N_gt_mask] * (mask_traj_gt.permute(0,2,1))
|
| 1342 |
+
visb_traj = (align_loss * vis_i[:,:,:N_gt_mask]).sum(dim=1)/vis_i[:,:,:N_gt_mask].sum(dim=1)
|
| 1343 |
+
except:
|
| 1344 |
+
import pdb; pdb.set_trace()
|
| 1345 |
+
else:
|
| 1346 |
+
visb_traj = ((gt_world_pts - q_xyz_world[:,None,:N_gt,:,]).norm(dim=-1) * vis_i[:,:,:N_gt]).sum(dim=1)/vis_i[:,:,:N_gt].sum(dim=1)
|
| 1347 |
+
|
| 1348 |
+
# pts_loss = ((q_xyz_world[:,None,...] - pts_world)[:,:,:N_gt,:].norm(dim=-1)*(1-dyn_prob_i[:,None,:N_gt])) # - 0.1*(1-dyn_prob_i[:,None,:N_gt]).log()
|
| 1349 |
+
pts_loss = 0
|
| 1350 |
+
static_mask = ~dyn_mask_final # more strict for static points
|
| 1351 |
+
dyn_mask = dyn_mask_final
|
| 1352 |
+
pts_loss_refined = ((q_xyz_world[:,None,...] - pts_i_world_refined).norm(dim=-1)*static_mask[:,None,:]).sum()/static_mask.sum() # - 0.1*(1-dyn_prob_i[:,None,:N_gt]).log()
|
| 1353 |
+
vis_logits_final = vis_predictions[0][-1].detach()
|
| 1354 |
+
vis_final = torch.sigmoid(vis_logits_final)+0.2 > 0.5 # more strict for visible points
|
| 1355 |
+
dyn_vis_mask = dyn_mask*vis_final * (fix_cam_pts[...,2] > 0.1)
|
| 1356 |
+
pts_loss_dynamic = ((fix_cam_pts - coords_xyz_i_world).norm(dim=-1)*dyn_vis_mask[:,None,:]).sum()/dyn_vis_mask.sum()
|
| 1357 |
+
|
| 1358 |
+
# pts_loss_refined = 0
|
| 1359 |
+
if traj3d_gt is not None:
|
| 1360 |
+
tap_traj = (gt_world_pts[:,:-1,...] - gt_world_pts[:,1:,...]).norm(dim=-1).sum(dim=1)[...,:N_gt_mask]
|
| 1361 |
+
mask_dyn = tap_traj>0.5
|
| 1362 |
+
if mask_traj_gt.sum() > 0:
|
| 1363 |
+
dyn_loss_i = 20*balanced_binary_cross_entropy(dyn_prob_i_logits[:,:N_gt_mask][mask_traj_gt.squeeze(-1)],
|
| 1364 |
+
mask_dyn.float()[mask_traj_gt.squeeze(-1)])
|
| 1365 |
+
else:
|
| 1366 |
+
dyn_loss_i = 0
|
| 1367 |
+
else:
|
| 1368 |
+
dyn_loss_i = 10*balanced_binary_cross_entropy(dyn_prob_i_logits, dyn_mask_final.float())
|
| 1369 |
+
|
| 1370 |
+
dyn_loss += dyn_loss_i
|
| 1371 |
+
|
| 1372 |
+
# visible loss for out of view points
|
| 1373 |
+
vis_i_train = torch.sigmoid(vis_i_logits)
|
| 1374 |
+
out_of_view_mask = (coords_i[...,0]<0)|(coords_i[...,0]>self.image_size[1])|(coords_i[...,1]<0)|(coords_i[...,1]>self.image_size[0])
|
| 1375 |
+
vis_loss_out_of_view = vis_i_train[out_of_view_mask].sum() / out_of_view_mask.sum()
|
| 1376 |
+
|
| 1377 |
+
|
| 1378 |
+
if traj3d_gt is not None:
|
| 1379 |
+
world_pts_loss = (((gt_world_pts - pts_i_world_refined[:,:,:gt_world_pts.shape[2],...]).norm(dim=-1))[...,:N_gt_mask] * mask_traj_gt.permute(0,2,1)).sum() / mask_traj_gt.sum()
|
| 1380 |
+
# world_pts_init_loss = (((gt_world_pts - pts_i_world[:,:,:gt_world_pts.shape[2],...]).norm(dim=-1))[...,:N_gt_mask] * mask_traj_gt.permute(0,2,1)).sum() / mask_traj_gt.sum()
|
| 1381 |
+
else:
|
| 1382 |
+
world_pts_loss = 0
|
| 1383 |
+
|
| 1384 |
+
# cam regress
|
| 1385 |
+
t_err = (cam_pred_i[...,:3,3] - cam_gt[...,:3,3]).norm(dim=-1).sum()
|
| 1386 |
+
|
| 1387 |
+
# xyz loss
|
| 1388 |
+
in_view_mask_large = (q_static_proj[...,0]>-50) & (q_static_proj[...,0]<W_org+50) & (q_static_proj[...,1]>-50) & (q_static_proj[...,1]<H_org+50)
|
| 1389 |
+
static_vis_mask = (q_static_proj[...,2]>0.05).float() * static_mask[:,None,:] * in_view_mask_large
|
| 1390 |
+
xyz_loss = ((coord_predictions[0][iter_] - q_static_proj)).abs()[...,:2].norm(dim=-1)*static_vis_mask
|
| 1391 |
+
xyz_loss = xyz_loss.sum()/static_vis_mask.sum()
|
| 1392 |
+
|
| 1393 |
+
# visualize the q_static_proj
|
| 1394 |
+
# viser = Visualizer(save_dir=".", grayscale=True,
|
| 1395 |
+
# fps=10, pad_value=50, tracks_leave_trace=0)
|
| 1396 |
+
# video_vis_ = F.interpolate(video_vis.view(B*T,3,video_vis.shape[-2],video_vis.shape[-1]), (H_org, W_org), mode='bilinear', align_corners=False)
|
| 1397 |
+
# viser.visualize(video=video_vis_, tracks=q_static_proj[:,:,dyn_mask_final.squeeze(), :2], filename="test")
|
| 1398 |
+
# viser.visualize(video=video_vis_, tracks=coord_predictions[0][-1][:,:,dyn_mask_final.squeeze(), :2], filename="test_pred")
|
| 1399 |
+
# import pdb; pdb.set_trace()
|
| 1400 |
+
|
| 1401 |
+
# temporal loss
|
| 1402 |
+
t_loss = self.track_d2_loss(pts_i_world_refined, [1,2,3], dyn_prob=dyn_prob_i, mask=mask)
|
| 1403 |
+
R_err = (cam_pred_i[...,:3,:3] - cam_gt[...,:3,:3]).abs().sum(dim=-1).mean()
|
| 1404 |
+
if self.stage == 1:
|
| 1405 |
+
cam_loss += 0.8**(len(camera_predictions[0])-iter_-1)*(10*t_err + 500*R_err + 20*pts_loss_refined + 10*xyz_loss + 20*pts_loss_dynamic + 10*vis_loss_out_of_view) #+ 5*(pts_loss + pts_loss_refined + world_pts_loss) + t_loss)
|
| 1406 |
+
elif self.stage == 3:
|
| 1407 |
+
cam_loss += 0.8**(len(camera_predictions[0])-iter_-1)*(10*t_err + 500*R_err + 10*vis_loss_out_of_view) #+ 5*(pts_loss + pts_loss_refined + world_pts_loss) + t_loss)
|
| 1408 |
+
else:
|
| 1409 |
+
cam_loss += 0*vis_loss_out_of_view
|
| 1410 |
+
|
| 1411 |
+
if (cam_loss > 20000)|(torch.isnan(cam_loss)):
|
| 1412 |
+
cam_loss = torch.zeros_like(cam_loss)
|
| 1413 |
+
|
| 1414 |
+
|
| 1415 |
+
if traj3d_gt is None:
|
| 1416 |
+
# ================ Condition 1: The self-supervised signals from the self-consistency ===================
|
| 1417 |
+
return cam_loss, train_data[0][0][0].mean()*0, dyn_loss, train_data[0][0][0].mean()*0, point_map_loss, scale_loss, shift_loss
|
| 1418 |
+
|
| 1419 |
+
|
| 1420 |
+
# ================ Condition 2: The supervision signal given by the ground truth trajectories ===================
|
| 1421 |
+
if (
|
| 1422 |
+
(torch.isnan(traj3d_gt).any()
|
| 1423 |
+
or traj3d_gt.abs().max() > 2000) and (custom_vid==False)
|
| 1424 |
+
):
|
| 1425 |
+
return cam_loss, train_data[0][0][0].mean()*0, dyn_loss, train_data[0][0][0].mean()*0, point_map_loss, scale_loss, shift_loss
|
| 1426 |
+
|
| 1427 |
+
|
| 1428 |
+
vis_gts = [vis_gt.float()]
|
| 1429 |
+
invis_gts = [1-vis_gt.float()]
|
| 1430 |
+
traj_gts = [traj3d_gt]
|
| 1431 |
+
valids_gts = [valid_mask]
|
| 1432 |
+
seq_loss_all = sequence_loss(
|
| 1433 |
+
coord_predictions,
|
| 1434 |
+
traj_gts,
|
| 1435 |
+
valids_gts,
|
| 1436 |
+
vis=vis_gts,
|
| 1437 |
+
gamma=0.8,
|
| 1438 |
+
add_huber_loss=False,
|
| 1439 |
+
loss_only_for_visible=False if custom_vid==False else True,
|
| 1440 |
+
z_unc=z_unc,
|
| 1441 |
+
mask_traj_gt=mask_traj_gt
|
| 1442 |
+
)
|
| 1443 |
+
|
| 1444 |
+
confidence_loss = sequence_prob_loss(
|
| 1445 |
+
coord_predictions, confidence_predicitons, traj_gts, vis_gts
|
| 1446 |
+
)
|
| 1447 |
+
|
| 1448 |
+
seq_loss_xyz = sequence_loss_xyz(
|
| 1449 |
+
coords_xyz_predictions,
|
| 1450 |
+
traj_gts,
|
| 1451 |
+
valids_gts,
|
| 1452 |
+
intrs=intrs,
|
| 1453 |
+
vis=vis_gts,
|
| 1454 |
+
gamma=0.8,
|
| 1455 |
+
add_huber_loss=False,
|
| 1456 |
+
loss_only_for_visible=False,
|
| 1457 |
+
mask_traj_gt=mask_traj_gt
|
| 1458 |
+
)
|
| 1459 |
+
|
| 1460 |
+
# filter the blinking points
|
| 1461 |
+
mask_vis = vis_gts[0].clone() # B T N
|
| 1462 |
+
mask_vis[mask_vis==0] = -1
|
| 1463 |
+
blink_mask = mask_vis[:,:-1,:] * mask_vis[:,1:,:] # first derivative B (T-1) N
|
| 1464 |
+
mask_vis[:,:-1,:], mask_vis[:,-1,:] = (blink_mask == 1), 0
|
| 1465 |
+
|
| 1466 |
+
vis_loss = sequence_BCE_loss(vis_predictions, vis_gts, mask=[mask_vis])
|
| 1467 |
+
|
| 1468 |
+
track_loss_out = (seq_loss_all+2*seq_loss_xyz + cam_loss)
|
| 1469 |
+
if valid_only:
|
| 1470 |
+
vis_loss = 0.0*vis_loss
|
| 1471 |
+
if custom_vid:
|
| 1472 |
+
return seq_loss_all, 0.0*seq_loss_all, 0.0*seq_loss_all, 10*vis_loss, 0.0*seq_loss_all, 0.0*seq_loss_all, 0.0*seq_loss_all
|
| 1473 |
+
|
| 1474 |
+
return track_loss_out, confidence_loss, dyn_loss, 10*vis_loss, point_map_loss, scale_loss, shift_loss
|
| 1475 |
+
|
| 1476 |
+
|
| 1477 |
+
|
| 1478 |
+
|
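Editor's note: the stage-dependent camera/track losses above all weight every refinement iteration by an exponentially decaying factor (0.8 ** (num_iters - iter_ - 1)), so the last iteration contributes most. A minimal, self-contained sketch of that weighting scheme, with illustrative names and random data (not taken from the repository):

import torch

def weighted_iterative_loss(predictions, target, gamma=0.8):
    # predictions: list of per-iteration estimates; later iterations get weights closer to 1
    n = len(predictions)
    total = torch.zeros(())
    for i, pred in enumerate(predictions):
        weight = gamma ** (n - i - 1)
        total = total + weight * (pred - target).abs().mean()
    return total

preds = [torch.randn(2, 8, 3) for _ in range(4)]   # four refinement iterations
gt = torch.randn(2, 8, 3)
loss = weighted_iterative_loss(preds, gt)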
models/SpaTrackV2/models/tracker3D/co_tracker/cotracker_base.py
ADDED
|
@@ -0,0 +1,418 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from models.SpaTrackV2.utils.model_utils import sample_features5d, bilinear_sampler
|
| 11 |
+
|
| 12 |
+
from models.SpaTrackV2.models.tracker3D.co_tracker.utils import (
|
| 13 |
+
Mlp, BasicEncoder, EfficientUpdateFormer
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
torch.manual_seed(0)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_1d_sincos_pos_embed_from_grid(
|
| 20 |
+
embed_dim: int, pos: torch.Tensor
|
| 21 |
+
) -> torch.Tensor:
|
| 22 |
+
"""
|
| 23 |
+
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
- embed_dim: The embedding dimension.
|
| 27 |
+
- pos: The position to generate the embedding from.
|
| 28 |
+
|
| 29 |
+
Returns:
|
| 30 |
+
- emb: The generated 1D positional embedding.
|
| 31 |
+
"""
|
| 32 |
+
assert embed_dim % 2 == 0
|
| 33 |
+
omega = torch.arange(embed_dim // 2, dtype=torch.double)
|
| 34 |
+
omega /= embed_dim / 2.0
|
| 35 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
| 36 |
+
|
| 37 |
+
pos = pos.reshape(-1) # (M,)
|
| 38 |
+
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 39 |
+
|
| 40 |
+
emb_sin = torch.sin(out) # (M, D/2)
|
| 41 |
+
emb_cos = torch.cos(out) # (M, D/2)
|
| 42 |
+
|
| 43 |
+
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
| 44 |
+
return emb[None].float()
|
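A quick shape check for the helper above (the positions are made up for illustration; assumes torch is imported as in this file):

positions = torch.arange(8, dtype=torch.float32)         # e.g. 8 time steps
emb = get_1d_sincos_pos_embed_from_grid(64, positions)   # sin/cos halves concatenated
assert emb.shape == (1, 8, 64)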
| 45 |
+
|
| 46 |
+
def posenc(x, min_deg, max_deg):
|
| 47 |
+
"""Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].
|
| 48 |
+
Instead of computing [sin(x), cos(x)], we use the trig identity
|
| 49 |
+
cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
|
| 50 |
+
Args:
|
| 51 |
+
x: torch.Tensor, variables to be encoded. Note that x should be in [-pi, pi].
|
| 52 |
+
min_deg: int, the minimum (inclusive) degree of the encoding.
|
| 53 |
+
max_deg: int, the maximum (exclusive) degree of the encoding.
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
encoded: torch.Tensor, encoded variables.
|
| 57 |
+
"""
|
| 58 |
+
if min_deg == max_deg:
|
| 59 |
+
return x
|
| 60 |
+
scales = torch.tensor(
|
| 61 |
+
[2**i for i in range(min_deg, max_deg)], dtype=x.dtype, device=x.device
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
xb = (x[..., None, :] * scales[:, None]).reshape(list(x.shape[:-1]) + [-1])
|
| 65 |
+
four_feat = torch.sin(torch.cat([xb, xb + 0.5 * torch.pi], dim=-1))
|
| 66 |
+
return torch.cat([x] + [four_feat], dim=-1)
|
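For reference, posenc returns the input concatenated with sin/cos features over max_deg - min_deg octaves, giving d + 2*d*(max_deg - min_deg) channels for a d-dimensional input. A small, illustrative shape check:

x = torch.rand(2, 16, 4) * 2 - 1             # values in [-1, 1], within the [-pi, pi] range the docstring asks for
enc = posenc(x, min_deg=0, max_deg=10)
assert enc.shape == (2, 16, 4 + 2 * 4 * 10)  # 84 channels, matching the comment in the forward pass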
| 67 |
+
|
| 68 |
+
class CoTrackerThreeBase(nn.Module):
|
| 69 |
+
def __init__(
|
| 70 |
+
self,
|
| 71 |
+
window_len=8,
|
| 72 |
+
stride=4,
|
| 73 |
+
corr_radius=3,
|
| 74 |
+
corr_levels=4,
|
| 75 |
+
num_virtual_tracks=64,
|
| 76 |
+
model_resolution=(384, 512),
|
| 77 |
+
add_space_attn=True,
|
| 78 |
+
linear_layer_for_vis_conf=True,
|
| 79 |
+
):
|
| 80 |
+
super(CoTrackerThreeBase, self).__init__()
|
| 81 |
+
self.window_len = window_len
|
| 82 |
+
self.stride = stride
|
| 83 |
+
self.corr_radius = corr_radius
|
| 84 |
+
self.corr_levels = corr_levels
|
| 85 |
+
self.hidden_dim = 256
|
| 86 |
+
self.latent_dim = 128
|
| 87 |
+
|
| 88 |
+
self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
|
| 89 |
+
self.fnet = BasicEncoder(input_dim=3, output_dim=self.latent_dim, stride=stride)
|
| 90 |
+
|
| 91 |
+
highres_dim = 128
|
| 92 |
+
lowres_dim = 256
|
| 93 |
+
|
| 94 |
+
self.num_virtual_tracks = num_virtual_tracks
|
| 95 |
+
self.model_resolution = model_resolution
|
| 96 |
+
|
| 97 |
+
self.input_dim = 1110
|
| 98 |
+
|
| 99 |
+
self.updateformer = EfficientUpdateFormer(
|
| 100 |
+
space_depth=3,
|
| 101 |
+
time_depth=3,
|
| 102 |
+
input_dim=self.input_dim,
|
| 103 |
+
hidden_size=384,
|
| 104 |
+
output_dim=4,
|
| 105 |
+
mlp_ratio=4.0,
|
| 106 |
+
num_virtual_tracks=num_virtual_tracks,
|
| 107 |
+
add_space_attn=add_space_attn,
|
| 108 |
+
linear_layer_for_vis_conf=linear_layer_for_vis_conf,
|
| 109 |
+
)
|
| 110 |
+
self.corr_mlp = Mlp(in_features=49 * 49, hidden_features=384, out_features=256)
|
| 111 |
+
|
| 112 |
+
time_grid = torch.linspace(0, window_len - 1, window_len).reshape(
|
| 113 |
+
1, window_len, 1
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
self.register_buffer(
|
| 117 |
+
"time_emb", get_1d_sincos_pos_embed_from_grid(self.input_dim, time_grid[0])
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
def get_support_points(self, coords, r, reshape_back=True):
|
| 121 |
+
B, _, N, _ = coords.shape
|
| 122 |
+
device = coords.device
|
| 123 |
+
centroid_lvl = coords.reshape(B, N, 1, 1, 3)
|
| 124 |
+
|
| 125 |
+
dx = torch.linspace(-r, r, 2 * r + 1, device=device)
|
| 126 |
+
dy = torch.linspace(-r, r, 2 * r + 1, device=device)
|
| 127 |
+
|
| 128 |
+
xgrid, ygrid = torch.meshgrid(dy, dx, indexing="ij")
|
| 129 |
+
zgrid = torch.zeros_like(xgrid, device=device)
|
| 130 |
+
delta = torch.stack([zgrid, xgrid, ygrid], axis=-1)
|
| 131 |
+
delta_lvl = delta.view(1, 1, 2 * r + 1, 2 * r + 1, 3)
|
| 132 |
+
coords_lvl = centroid_lvl + delta_lvl
|
| 133 |
+
|
| 134 |
+
if reshape_back:
|
| 135 |
+
return coords_lvl.reshape(B, N, (2 * r + 1) ** 2, 3).permute(0, 2, 1, 3)
|
| 136 |
+
else:
|
| 137 |
+
return coords_lvl
|
| 138 |
+
|
| 139 |
+
def get_track_feat(self, fmaps, queried_frames, queried_coords, support_radius=0):
|
| 140 |
+
|
| 141 |
+
sample_frames = queried_frames[:, None, :, None]
|
| 142 |
+
sample_coords = torch.cat(
|
| 143 |
+
[
|
| 144 |
+
sample_frames,
|
| 145 |
+
queried_coords[:, None],
|
| 146 |
+
],
|
| 147 |
+
dim=-1,
|
| 148 |
+
)
|
| 149 |
+
support_points = self.get_support_points(sample_coords, support_radius)
|
| 150 |
+
support_track_feats = sample_features5d(fmaps, support_points)
|
| 151 |
+
return (
|
| 152 |
+
support_track_feats[:, None, support_track_feats.shape[1] // 2],
|
| 153 |
+
support_track_feats,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
def get_correlation_feat(self, fmaps, queried_coords, radius=None, padding_mode="border"):
|
| 157 |
+
B, T, D, H_, W_ = fmaps.shape
|
| 158 |
+
N = queried_coords.shape[1]
|
| 159 |
+
if radius is None:
|
| 160 |
+
r = self.corr_radius
|
| 161 |
+
else:
|
| 162 |
+
r = radius
|
| 163 |
+
sample_coords = torch.cat(
|
| 164 |
+
[torch.zeros_like(queried_coords[..., :1]), queried_coords], dim=-1
|
| 165 |
+
)[:, None]
|
| 166 |
+
support_points = self.get_support_points(sample_coords, r, reshape_back=False)
|
| 167 |
+
correlation_feat = bilinear_sampler(
|
| 168 |
+
fmaps.reshape(B * T, D, 1, H_, W_), support_points, padding_mode=padding_mode
|
| 169 |
+
)
|
| 170 |
+
return correlation_feat.view(B, T, D, N, (2 * r + 1), (2 * r + 1)).permute(
|
| 171 |
+
0, 1, 3, 4, 5, 2
|
| 172 |
+
)
|
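For orientation: get_correlation_feat samples a (2r+1) x (2r+1) feature patch around every query, so with the default corr_radius=3 each track sees a 7x7 window per pyramid level, and the correlation volume built from it in the forward pass has 7*7*7*7 entries, which is exactly what corr_mlp (in_features=49*49) consumes. Quick arithmetic check:

r = 3                                         # default corr_radius
window = 2 * r + 1                            # 7x7 sampling window
assert window ** 2 * window ** 2 == 49 * 49   # matches Mlp(in_features=49 * 49)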
| 173 |
+
|
| 174 |
+
def interpolate_time_embed(self, x, t):
|
| 175 |
+
previous_dtype = x.dtype
|
| 176 |
+
T = self.time_emb.shape[1]
|
| 177 |
+
|
| 178 |
+
if t == T:
|
| 179 |
+
return self.time_emb
|
| 180 |
+
|
| 181 |
+
time_emb = self.time_emb.float()
|
| 182 |
+
time_emb = F.interpolate(
|
| 183 |
+
time_emb.permute(0, 2, 1), size=t, mode="linear"
|
| 184 |
+
).permute(0, 2, 1)
|
| 185 |
+
return time_emb.to(previous_dtype)
|
| 186 |
+
|
| 187 |
+
class CoTrackerThreeOffline(CoTrackerThreeBase):
|
| 188 |
+
def __init__(self, **args):
|
| 189 |
+
super(CoTrackerThreeOffline, self).__init__(**args)
|
| 190 |
+
|
| 191 |
+
def forward(
|
| 192 |
+
self,
|
| 193 |
+
video,
|
| 194 |
+
queries,
|
| 195 |
+
iters=4,
|
| 196 |
+
is_train=False,
|
| 197 |
+
add_space_attn=True,
|
| 198 |
+
fmaps_chunk_size=200,
|
| 199 |
+
):
|
| 200 |
+
"""Predict tracks
|
| 201 |
+
|
| 202 |
+
Args:
|
| 203 |
+
video (FloatTensor[B, T, 3, H, W]): input videos.
|
| 204 |
+
queries (FloatTensor[B, N, 3]): point queries.
|
| 205 |
+
iters (int, optional): number of updates. Defaults to 4.
|
| 206 |
+
is_train (bool, optional): enables training mode. Defaults to False.
|
| 207 |
+
Returns:
|
| 208 |
+
- coords_predicted (FloatTensor[B, T, N, 2]):
|
| 209 |
+
- vis_predicted (FloatTensor[B, T, N]):
|
| 210 |
+
- train_data: `None` if `is_train` is false, otherwise:
|
| 211 |
+
- all_vis_predictions (List[FloatTensor[B, S, N, 1]]):
|
| 212 |
+
- all_coords_predictions (List[FloatTensor[B, S, N, 2]]):
|
| 213 |
+
- mask (BoolTensor[B, T, N]):
|
| 214 |
+
"""
|
| 215 |
+
|
| 216 |
+
B, T, C, H, W = video.shape
|
| 217 |
+
device = queries.device
|
| 218 |
+
assert H % self.stride == 0 and W % self.stride == 0
|
| 219 |
+
|
| 220 |
+
B, N, __ = queries.shape
|
| 221 |
+
# B = batch size
|
| 222 |
+
# S_trimmed = actual number of frames in the window
|
| 223 |
+
# N = number of tracks
|
| 224 |
+
# C = color channels (3 for RGB)
|
| 225 |
+
# E = positional embedding size
|
| 226 |
+
# LRR = local receptive field radius
|
| 227 |
+
# D = dimension of the transformer input tokens
|
| 228 |
+
|
| 229 |
+
# video = B T C H W
|
| 230 |
+
# queries = B N 3
|
| 231 |
+
# coords_init = B T N 2
|
| 232 |
+
# vis_init = B T N 1
|
| 233 |
+
|
| 234 |
+
assert T >= 1  # at least one frame is required (two or more for meaningful tracking)
|
| 235 |
+
|
| 236 |
+
video = 2 * (video / 255.0) - 1.0
|
| 237 |
+
dtype = video.dtype
|
| 238 |
+
queried_frames = queries[:, :, 0].long()
|
| 239 |
+
|
| 240 |
+
queried_coords = queries[..., 1:3]
|
| 241 |
+
queried_coords = queried_coords / self.stride
|
| 242 |
+
|
| 243 |
+
# We store our predictions here
|
| 244 |
+
all_coords_predictions, all_vis_predictions, all_confidence_predictions = (
|
| 245 |
+
[],
|
| 246 |
+
[],
|
| 247 |
+
[],
|
| 248 |
+
)
|
| 249 |
+
C_ = C
|
| 250 |
+
H4, W4 = H // self.stride, W // self.stride
|
| 251 |
+
# Compute convolutional features for the video or for the current chunk in case of online mode
|
| 252 |
+
|
| 253 |
+
if T > fmaps_chunk_size:
|
| 254 |
+
fmaps = []
|
| 255 |
+
for t in range(0, T, fmaps_chunk_size):
|
| 256 |
+
video_chunk = video[:, t : t + fmaps_chunk_size]
|
| 257 |
+
fmaps_chunk = self.fnet(video_chunk.reshape(-1, C_, H, W))
|
| 258 |
+
T_chunk = video_chunk.shape[1]
|
| 259 |
+
C_chunk, H_chunk, W_chunk = fmaps_chunk.shape[1:]
|
| 260 |
+
fmaps.append(fmaps_chunk.reshape(B, T_chunk, C_chunk, H_chunk, W_chunk))
|
| 261 |
+
fmaps = torch.cat(fmaps, dim=1).reshape(-1, C_chunk, H_chunk, W_chunk)
|
| 262 |
+
else:
|
| 263 |
+
fmaps = self.fnet(video.reshape(-1, C_, H, W))
|
| 264 |
+
fmaps = fmaps.permute(0, 2, 3, 1)
|
| 265 |
+
fmaps = fmaps / torch.sqrt(
|
| 266 |
+
torch.maximum(
|
| 267 |
+
torch.sum(torch.square(fmaps), axis=-1, keepdims=True),
|
| 268 |
+
torch.tensor(1e-12, device=fmaps.device),
|
| 269 |
+
)
|
| 270 |
+
)
|
| 271 |
+
fmaps = fmaps.permute(0, 3, 1, 2).reshape(
|
| 272 |
+
B, -1, self.latent_dim, H // self.stride, W // self.stride
|
| 273 |
+
)
|
| 274 |
+
fmaps = fmaps.to(dtype)
|
| 275 |
+
|
| 276 |
+
# We compute track features
|
| 277 |
+
fmaps_pyramid = []
|
| 278 |
+
track_feat_pyramid = []
|
| 279 |
+
track_feat_support_pyramid = []
|
| 280 |
+
fmaps_pyramid.append(fmaps)
|
| 281 |
+
for i in range(self.corr_levels - 1):
|
| 282 |
+
fmaps_ = fmaps.reshape(
|
| 283 |
+
B * T, self.latent_dim, fmaps.shape[-2], fmaps.shape[-1]
|
| 284 |
+
)
|
| 285 |
+
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
|
| 286 |
+
fmaps = fmaps_.reshape(
|
| 287 |
+
B, T, self.latent_dim, fmaps_.shape[-2], fmaps_.shape[-1]
|
| 288 |
+
)
|
| 289 |
+
fmaps_pyramid.append(fmaps)
|
| 290 |
+
|
| 291 |
+
for i in range(self.corr_levels):
|
| 292 |
+
track_feat, track_feat_support = self.get_track_feat(
|
| 293 |
+
fmaps_pyramid[i],
|
| 294 |
+
queried_frames,
|
| 295 |
+
queried_coords / 2**i,
|
| 296 |
+
support_radius=self.corr_radius,
|
| 297 |
+
)
|
| 298 |
+
track_feat_pyramid.append(track_feat.repeat(1, T, 1, 1))
|
| 299 |
+
track_feat_support_pyramid.append(track_feat_support.unsqueeze(1))
|
| 300 |
+
|
| 301 |
+
D_coords = 2
|
| 302 |
+
|
| 303 |
+
coord_preds, vis_preds, confidence_preds = [], [], []
|
| 304 |
+
|
| 305 |
+
vis = torch.zeros((B, T, N), device=device).float()
|
| 306 |
+
confidence = torch.zeros((B, T, N), device=device).float()
|
| 307 |
+
coords = queried_coords.reshape(B, 1, N, 2).expand(B, T, N, 2).float()
|
| 308 |
+
|
| 309 |
+
r = 2 * self.corr_radius + 1
|
| 310 |
+
|
| 311 |
+
for it in range(iters):
|
| 312 |
+
coords = coords.detach() # B T N 2
|
| 313 |
+
coords_init = coords.view(B * T, N, 2)
|
| 314 |
+
corr_embs = []
|
| 315 |
+
corr_feats = []
|
| 316 |
+
for i in range(self.corr_levels):
|
| 317 |
+
corr_feat = self.get_correlation_feat(
|
| 318 |
+
fmaps_pyramid[i], coords_init / 2**i
|
| 319 |
+
)
|
| 320 |
+
track_feat_support = (
|
| 321 |
+
track_feat_support_pyramid[i]
|
| 322 |
+
.view(B, 1, r, r, N, self.latent_dim)
|
| 323 |
+
.squeeze(1)
|
| 324 |
+
.permute(0, 3, 1, 2, 4)
|
| 325 |
+
)
|
| 326 |
+
corr_volume = torch.einsum(
|
| 327 |
+
"btnhwc,bnijc->btnhwij", corr_feat, track_feat_support
|
| 328 |
+
)
|
| 329 |
+
corr_emb = self.corr_mlp(corr_volume.reshape(B * T * N, r * r * r * r))
|
| 330 |
+
corr_embs.append(corr_emb)
|
| 331 |
+
corr_embs = torch.cat(corr_embs, dim=-1)
|
| 332 |
+
corr_embs = corr_embs.view(B, T, N, corr_embs.shape[-1])
|
| 333 |
+
|
| 334 |
+
transformer_input = [vis[..., None], confidence[..., None], corr_embs]
|
| 335 |
+
|
| 336 |
+
rel_coords_forward = coords[:, :-1] - coords[:, 1:]
|
| 337 |
+
rel_coords_backward = coords[:, 1:] - coords[:, :-1]
|
| 338 |
+
|
| 339 |
+
rel_coords_forward = torch.nn.functional.pad(
|
| 340 |
+
rel_coords_forward, (0, 0, 0, 0, 0, 1)
|
| 341 |
+
)
|
| 342 |
+
rel_coords_backward = torch.nn.functional.pad(
|
| 343 |
+
rel_coords_backward, (0, 0, 0, 0, 1, 0)
|
| 344 |
+
)
|
| 345 |
+
scale = (
|
| 346 |
+
torch.tensor(
|
| 347 |
+
[self.model_resolution[1], self.model_resolution[0]],
|
| 348 |
+
device=coords.device,
|
| 349 |
+
)
|
| 350 |
+
/ self.stride
|
| 351 |
+
)
|
| 352 |
+
rel_coords_forward = rel_coords_forward / scale
|
| 353 |
+
rel_coords_backward = rel_coords_backward / scale
|
| 354 |
+
|
| 355 |
+
rel_pos_emb_input = posenc(
|
| 356 |
+
torch.cat([rel_coords_forward, rel_coords_backward], dim=-1),
|
| 357 |
+
min_deg=0,
|
| 358 |
+
max_deg=10,
|
| 359 |
+
) # batch, num_points, num_frames, 84
|
| 360 |
+
transformer_input.append(rel_pos_emb_input)
|
| 361 |
+
|
| 362 |
+
x = (
|
| 363 |
+
torch.cat(transformer_input, dim=-1)
|
| 364 |
+
.permute(0, 2, 1, 3)
|
| 365 |
+
.reshape(B * N, T, -1)
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
x = x + self.interpolate_time_embed(x, T)
|
| 369 |
+
x = x.view(B, N, T, -1) # (B N) T D -> B N T D
|
| 370 |
+
|
| 371 |
+
delta = self.updateformer(
|
| 372 |
+
x,
|
| 373 |
+
add_space_attn=add_space_attn,
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
delta_coords = delta[..., :D_coords].permute(0, 2, 1, 3)
|
| 377 |
+
delta_vis = delta[..., D_coords].permute(0, 2, 1)
|
| 378 |
+
delta_confidence = delta[..., D_coords + 1].permute(0, 2, 1)
|
| 379 |
+
|
| 380 |
+
vis = vis + delta_vis
|
| 381 |
+
confidence = confidence + delta_confidence
|
| 382 |
+
|
| 383 |
+
coords = coords + delta_coords
|
| 384 |
+
coords_append = coords.clone()
|
| 385 |
+
coords_append[..., :2] = coords_append[..., :2] * float(self.stride)
|
| 386 |
+
coord_preds.append(coords_append)
|
| 387 |
+
vis_preds.append(torch.sigmoid(vis))
|
| 388 |
+
confidence_preds.append(torch.sigmoid(confidence))
|
| 389 |
+
|
| 390 |
+
if is_train:
|
| 391 |
+
all_coords_predictions.append([coord[..., :2] for coord in coord_preds])
|
| 392 |
+
all_vis_predictions.append(vis_preds)
|
| 393 |
+
all_confidence_predictions.append(confidence_preds)
|
| 394 |
+
|
| 395 |
+
if is_train:
|
| 396 |
+
train_data = (
|
| 397 |
+
all_coords_predictions,
|
| 398 |
+
all_vis_predictions,
|
| 399 |
+
all_confidence_predictions,
|
| 400 |
+
torch.ones_like(vis_preds[-1], device=vis_preds[-1].device),
|
| 401 |
+
)
|
| 402 |
+
else:
|
| 403 |
+
train_data = None
|
| 404 |
+
|
| 405 |
+
return coord_preds[-1][..., :2], vis_preds[-1], confidence_preds[-1], train_data
|
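A hedged inference sketch for the offline tracker above; tensor sizes are placeholders and no pretrained weights are loaded here (see the __main__ block below for checkpoint loading):

model = CoTrackerThreeOffline(stride=4, corr_radius=3, window_len=60).eval()
video = torch.randint(0, 255, (1, 24, 3, 384, 512)).float()      # B T C H W, raw RGB in [0, 255]
queries = torch.tensor([[[0.0, 100.0, 150.0],                     # (start frame, x, y) per query
                         [0.0, 200.0, 300.0]]])
with torch.no_grad():
    coords, vis, conf, _ = model(video, queries, iters=4)
# coords: (1, 24, 2, 2) pixel tracks; vis / conf: (1, 24, 2) probabilities in [0, 1]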
| 406 |
+
|
| 407 |
+
|
| 408 |
+
if __name__ == "__main__":
|
| 409 |
+
cotrack_cktp = "/data0/xyx/scaled_offline.pth"
|
| 410 |
+
cotracker = CoTrackerThreeOffline(
|
| 411 |
+
stride=4, corr_radius=3, window_len=60
|
| 412 |
+
)
|
| 413 |
+
with open(cotrack_cktp, "rb") as f:
|
| 414 |
+
state_dict = torch.load(f, map_location="cpu")
|
| 415 |
+
if "model" in state_dict:
|
| 416 |
+
state_dict = state_dict["model"]
|
| 417 |
+
cotracker.load_state_dict(state_dict)
|
| 418 |
+
import pdb; pdb.set_trace()
|
models/SpaTrackV2/models/tracker3D/co_tracker/utils.py
ADDED
|
@@ -0,0 +1,929 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from functools import partial
|
| 5 |
+
from typing import Callable, List
|
| 6 |
+
import collections
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from itertools import repeat
|
| 9 |
+
from models.SpaTrackV2.utils.model_utils import bilinear_sampler
|
| 10 |
+
from models.SpaTrackV2.models.blocks import CrossAttnBlock as CrossAttnBlock_F
|
| 11 |
+
from torch.nn.functional import scaled_dot_product_attention
|
| 12 |
+
from torch.nn.attention import sdpa_kernel, SDPBackend
|
| 13 |
+
# import flash_attn
|
| 14 |
+
EPS = 1e-6
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ResidualBlock(nn.Module):
|
| 18 |
+
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
| 19 |
+
super(ResidualBlock, self).__init__()
|
| 20 |
+
|
| 21 |
+
self.conv1 = nn.Conv2d(
|
| 22 |
+
in_planes,
|
| 23 |
+
planes,
|
| 24 |
+
kernel_size=3,
|
| 25 |
+
padding=1,
|
| 26 |
+
stride=stride,
|
| 27 |
+
padding_mode="zeros",
|
| 28 |
+
)
|
| 29 |
+
self.conv2 = nn.Conv2d(
|
| 30 |
+
planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
|
| 31 |
+
)
|
| 32 |
+
self.relu = nn.ReLU(inplace=True)
|
| 33 |
+
|
| 34 |
+
num_groups = planes // 8
|
| 35 |
+
|
| 36 |
+
if norm_fn == "group":
|
| 37 |
+
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
| 38 |
+
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
| 39 |
+
if not stride == 1:
|
| 40 |
+
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
| 41 |
+
|
| 42 |
+
elif norm_fn == "batch":
|
| 43 |
+
self.norm1 = nn.BatchNorm2d(planes)
|
| 44 |
+
self.norm2 = nn.BatchNorm2d(planes)
|
| 45 |
+
if not stride == 1:
|
| 46 |
+
self.norm3 = nn.BatchNorm2d(planes)
|
| 47 |
+
|
| 48 |
+
elif norm_fn == "instance":
|
| 49 |
+
self.norm1 = nn.InstanceNorm2d(planes)
|
| 50 |
+
self.norm2 = nn.InstanceNorm2d(planes)
|
| 51 |
+
if not stride == 1:
|
| 52 |
+
self.norm3 = nn.InstanceNorm2d(planes)
|
| 53 |
+
|
| 54 |
+
elif norm_fn == "none":
|
| 55 |
+
self.norm1 = nn.Sequential()
|
| 56 |
+
self.norm2 = nn.Sequential()
|
| 57 |
+
if not stride == 1:
|
| 58 |
+
self.norm3 = nn.Sequential()
|
| 59 |
+
|
| 60 |
+
if stride == 1:
|
| 61 |
+
self.downsample = None
|
| 62 |
+
|
| 63 |
+
else:
|
| 64 |
+
self.downsample = nn.Sequential(
|
| 65 |
+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
def forward(self, x):
|
| 69 |
+
y = x
|
| 70 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
| 71 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
| 72 |
+
|
| 73 |
+
if self.downsample is not None:
|
| 74 |
+
x = self.downsample(x)
|
| 75 |
+
|
| 76 |
+
return self.relu(x + y)
|
| 77 |
+
|
| 78 |
+
def reduce_masked_mean(input, mask, dim=None, keepdim=False):
|
| 79 |
+
r"""Masked mean
|
| 80 |
+
|
| 81 |
+
`reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input`
|
| 82 |
+
over a mask :attr:`mask`, returning
|
| 83 |
+
|
| 84 |
+
.. math::
|
| 85 |
+
\text{output} =
|
| 86 |
+
\frac
|
| 87 |
+
{\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i}
|
| 88 |
+
{\epsilon + \sum_{i=1}^N \text{mask}_i}
|
| 89 |
+
|
| 90 |
+
where :math:`N` is the number of elements in :attr:`input` and
|
| 91 |
+
:attr:`mask`, and :math:`\epsilon` is a small constant to avoid
|
| 92 |
+
division by zero.
|
| 93 |
+
|
| 94 |
+
`reduced_masked_mean(x, mask, dim)` computes the mean of a tensor
|
| 95 |
+
:attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`.
|
| 96 |
+
Optionally, the dimension can be kept in the output by setting
|
| 97 |
+
:attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to
|
| 98 |
+
the same dimension as :attr:`input`.
|
| 99 |
+
|
| 100 |
+
The interface is similar to `torch.mean()`.
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
input (Tensor): input tensor.
|
| 104 |
+
mask (Tensor): mask.
|
| 105 |
+
dim (int, optional): Dimension to sum over. Defaults to None.
|
| 106 |
+
keepdim (bool, optional): Keep the summed dimension. Defaults to False.
|
| 107 |
+
|
| 108 |
+
Returns:
|
| 109 |
+
Tensor: mean tensor.
|
| 110 |
+
"""
|
| 111 |
+
|
| 112 |
+
mask = mask.expand_as(input)
|
| 113 |
+
|
| 114 |
+
prod = input * mask
|
| 115 |
+
|
| 116 |
+
if dim is None:
|
| 117 |
+
numer = torch.sum(prod)
|
| 118 |
+
denom = torch.sum(mask)
|
| 119 |
+
else:
|
| 120 |
+
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
|
| 121 |
+
denom = torch.sum(mask, dim=dim, keepdim=keepdim)
|
| 122 |
+
|
| 123 |
+
mean = numer / (EPS + denom)
|
| 124 |
+
return mean
|
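A small sanity check for the masked mean above (values chosen only for illustration):

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
m = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
print(reduce_masked_mean(x, m))          # ~1.5: only the first two entries count
print(reduce_masked_mean(x, m, dim=1))   # same value, reduced along dim 1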
| 125 |
+
|
| 126 |
+
class GeometryEncoder(nn.Module):
|
| 127 |
+
def __init__(self, input_dim=3, output_dim=128, stride=4):
|
| 128 |
+
super(GeometryEncoder, self).__init__()
|
| 129 |
+
self.stride = stride
|
| 130 |
+
self.norm_fn = "instance"
|
| 131 |
+
self.in_planes = output_dim // 2
|
| 132 |
+
self.norm1 = nn.InstanceNorm2d(self.in_planes)
|
| 133 |
+
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
|
| 134 |
+
self.conv1 = nn.Conv2d(
|
| 135 |
+
input_dim,
|
| 136 |
+
self.in_planes,
|
| 137 |
+
kernel_size=7,
|
| 138 |
+
stride=2,
|
| 139 |
+
padding=3,
|
| 140 |
+
padding_mode="zeros",
|
| 141 |
+
)
|
| 142 |
+
self.relu1 = nn.ReLU(inplace=True)
|
| 143 |
+
self.layer1 = self._make_layer(output_dim // 2, stride=1)
|
| 144 |
+
self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
|
| 145 |
+
|
| 146 |
+
self.conv2 = nn.Conv2d(
|
| 147 |
+
output_dim * 5 // 4,
|
| 148 |
+
output_dim,
|
| 149 |
+
kernel_size=3,
|
| 150 |
+
padding=1,
|
| 151 |
+
padding_mode="zeros",
|
| 152 |
+
)
|
| 153 |
+
self.relu2 = nn.ReLU(inplace=True)
|
| 154 |
+
self.conv3 = nn.Conv2d(output_dim, output_dim, kernel_size=1)
|
| 155 |
+
for m in self.modules():
|
| 156 |
+
if isinstance(m, nn.Conv2d):
|
| 157 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
| 158 |
+
elif isinstance(m, (nn.InstanceNorm2d)):
|
| 159 |
+
if m.weight is not None:
|
| 160 |
+
nn.init.constant_(m.weight, 1)
|
| 161 |
+
if m.bias is not None:
|
| 162 |
+
nn.init.constant_(m.bias, 0)
|
| 163 |
+
|
| 164 |
+
def _make_layer(self, dim, stride=1):
|
| 165 |
+
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
| 166 |
+
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
| 167 |
+
layers = (layer1, layer2)
|
| 168 |
+
|
| 169 |
+
self.in_planes = dim
|
| 170 |
+
return nn.Sequential(*layers)
|
| 171 |
+
|
| 172 |
+
def forward(self, x):
|
| 173 |
+
_, _, H, W = x.shape
|
| 174 |
+
x = self.conv1(x)
|
| 175 |
+
x = self.norm1(x)
|
| 176 |
+
x = self.relu1(x)
|
| 177 |
+
a = self.layer1(x)
|
| 178 |
+
b = self.layer2(a)
|
| 179 |
+
def _bilinear_intepolate(x):
|
| 180 |
+
return F.interpolate(
|
| 181 |
+
x,
|
| 182 |
+
(H // self.stride, W // self.stride),
|
| 183 |
+
mode="bilinear",
|
| 184 |
+
align_corners=True,
|
| 185 |
+
)
|
| 186 |
+
a = _bilinear_intepolate(a)
|
| 187 |
+
b = _bilinear_intepolate(b)
|
| 188 |
+
x = self.conv2(torch.cat([a, b], dim=1))
|
| 189 |
+
x = self.norm2(x)
|
| 190 |
+
x = self.relu2(x)
|
| 191 |
+
x = self.conv3(x)
|
| 192 |
+
return x
|
| 193 |
+
|
| 194 |
+
class BasicEncoder(nn.Module):
|
| 195 |
+
def __init__(self, input_dim=3, output_dim=128, stride=4):
|
| 196 |
+
super(BasicEncoder, self).__init__()
|
| 197 |
+
self.stride = stride
|
| 198 |
+
self.norm_fn = "instance"
|
| 199 |
+
self.in_planes = output_dim // 2
|
| 200 |
+
self.norm1 = nn.InstanceNorm2d(self.in_planes)
|
| 201 |
+
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
|
| 202 |
+
|
| 203 |
+
self.conv1 = nn.Conv2d(
|
| 204 |
+
input_dim,
|
| 205 |
+
self.in_planes,
|
| 206 |
+
kernel_size=7,
|
| 207 |
+
stride=2,
|
| 208 |
+
padding=3,
|
| 209 |
+
padding_mode="zeros",
|
| 210 |
+
)
|
| 211 |
+
self.relu1 = nn.ReLU(inplace=True)
|
| 212 |
+
self.layer1 = self._make_layer(output_dim // 2, stride=1)
|
| 213 |
+
self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
|
| 214 |
+
self.layer3 = self._make_layer(output_dim, stride=2)
|
| 215 |
+
self.layer4 = self._make_layer(output_dim, stride=2)
|
| 216 |
+
|
| 217 |
+
self.conv2 = nn.Conv2d(
|
| 218 |
+
output_dim * 3 + output_dim // 4,
|
| 219 |
+
output_dim * 2,
|
| 220 |
+
kernel_size=3,
|
| 221 |
+
padding=1,
|
| 222 |
+
padding_mode="zeros",
|
| 223 |
+
)
|
| 224 |
+
self.relu2 = nn.ReLU(inplace=True)
|
| 225 |
+
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
|
| 226 |
+
for m in self.modules():
|
| 227 |
+
if isinstance(m, nn.Conv2d):
|
| 228 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
| 229 |
+
elif isinstance(m, (nn.InstanceNorm2d)):
|
| 230 |
+
if m.weight is not None:
|
| 231 |
+
nn.init.constant_(m.weight, 1)
|
| 232 |
+
if m.bias is not None:
|
| 233 |
+
nn.init.constant_(m.bias, 0)
|
| 234 |
+
|
| 235 |
+
def _make_layer(self, dim, stride=1):
|
| 236 |
+
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
| 237 |
+
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
| 238 |
+
layers = (layer1, layer2)
|
| 239 |
+
|
| 240 |
+
self.in_planes = dim
|
| 241 |
+
return nn.Sequential(*layers)
|
| 242 |
+
|
| 243 |
+
def forward(self, x):
|
| 244 |
+
_, _, H, W = x.shape
|
| 245 |
+
|
| 246 |
+
x = self.conv1(x)
|
| 247 |
+
x = self.norm1(x)
|
| 248 |
+
x = self.relu1(x)
|
| 249 |
+
|
| 250 |
+
a = self.layer1(x)
|
| 251 |
+
b = self.layer2(a)
|
| 252 |
+
c = self.layer3(b)
|
| 253 |
+
d = self.layer4(c)
|
| 254 |
+
|
| 255 |
+
def _bilinear_intepolate(x):
|
| 256 |
+
return F.interpolate(
|
| 257 |
+
x,
|
| 258 |
+
(H // self.stride, W // self.stride),
|
| 259 |
+
mode="bilinear",
|
| 260 |
+
align_corners=True,
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
a = _bilinear_intepolate(a)
|
| 264 |
+
b = _bilinear_intepolate(b)
|
| 265 |
+
c = _bilinear_intepolate(c)
|
| 266 |
+
d = _bilinear_intepolate(d)
|
| 267 |
+
|
| 268 |
+
x = self.conv2(torch.cat([a, b, c, d], dim=1))
|
| 269 |
+
x = self.norm2(x)
|
| 270 |
+
x = self.relu2(x)
|
| 271 |
+
x = self.conv3(x)
|
| 272 |
+
return x
|
| 273 |
+
|
| 274 |
+
# From PyTorch internals
|
| 275 |
+
def _ntuple(n):
|
| 276 |
+
def parse(x):
|
| 277 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
| 278 |
+
return tuple(x)
|
| 279 |
+
return tuple(repeat(x, n))
|
| 280 |
+
|
| 281 |
+
return parse
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def exists(val):
|
| 285 |
+
return val is not None
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def default(val, d):
|
| 289 |
+
return val if exists(val) else d
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
to_2tuple = _ntuple(2)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class Mlp(nn.Module):
|
| 296 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
| 297 |
+
|
| 298 |
+
def __init__(
|
| 299 |
+
self,
|
| 300 |
+
in_features,
|
| 301 |
+
hidden_features=None,
|
| 302 |
+
out_features=None,
|
| 303 |
+
act_layer=nn.GELU,
|
| 304 |
+
norm_layer=None,
|
| 305 |
+
bias=True,
|
| 306 |
+
drop=0.0,
|
| 307 |
+
use_conv=False,
|
| 308 |
+
):
|
| 309 |
+
super().__init__()
|
| 310 |
+
out_features = out_features or in_features
|
| 311 |
+
hidden_features = hidden_features or in_features
|
| 312 |
+
bias = to_2tuple(bias)
|
| 313 |
+
drop_probs = to_2tuple(drop)
|
| 314 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
| 315 |
+
|
| 316 |
+
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
| 317 |
+
self.act = act_layer()
|
| 318 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
| 319 |
+
self.norm = (
|
| 320 |
+
norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
| 321 |
+
)
|
| 322 |
+
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
| 323 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
| 324 |
+
|
| 325 |
+
def forward(self, x):
|
| 326 |
+
x = self.fc1(x)
|
| 327 |
+
x = self.act(x)
|
| 328 |
+
x = self.drop1(x)
|
| 329 |
+
x = self.fc2(x)
|
| 330 |
+
x = self.drop2(x)
|
| 331 |
+
return x
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
class Attention(nn.Module):
|
| 335 |
+
def __init__(
|
| 336 |
+
self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False
|
| 337 |
+
):
|
| 338 |
+
super().__init__()
|
| 339 |
+
inner_dim = dim_head * num_heads
|
| 340 |
+
self.inner_dim = inner_dim
|
| 341 |
+
context_dim = default(context_dim, query_dim)
|
| 342 |
+
self.scale = dim_head**-0.5
|
| 343 |
+
self.heads = num_heads
|
| 344 |
+
|
| 345 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
|
| 346 |
+
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
|
| 347 |
+
self.to_out = nn.Linear(inner_dim, query_dim)
|
| 348 |
+
|
| 349 |
+
def forward(self, x, context=None, attn_bias=None, flash=True):
|
| 350 |
+
B, N1, C = x.shape
|
| 351 |
+
h = self.heads
|
| 352 |
+
|
| 353 |
+
q = self.to_q(x).reshape(B, N1, h, self.inner_dim // h).permute(0, 2, 1, 3)
|
| 354 |
+
context = default(context, x)
|
| 355 |
+
k, v = self.to_kv(context).chunk(2, dim=-1)
|
| 356 |
+
|
| 357 |
+
N2 = context.shape[1]
|
| 358 |
+
k = k.reshape(B, N2, h, self.inner_dim // h).permute(0, 2, 1, 3)
|
| 359 |
+
v = v.reshape(B, N2, h, self.inner_dim // h).permute(0, 2, 1, 3)
|
| 360 |
+
|
| 361 |
+
if (
|
| 362 |
+
(N1 < 64 and N2 < 64) or
|
| 363 |
+
(B > 1e4) or
|
| 364 |
+
(q.shape[1] != k.shape[1]) or
|
| 365 |
+
(q.shape[1] % k.shape[1] != 0)
|
| 366 |
+
):
|
| 367 |
+
flash = False
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
if flash == False:
|
| 371 |
+
sim = (q @ k.transpose(-2, -1)) * self.scale
|
| 372 |
+
if attn_bias is not None:
|
| 373 |
+
sim = sim + attn_bias
|
| 374 |
+
if sim.abs().max() > 1e2:
|
| 375 |
+
import pdb; pdb.set_trace()
|
| 376 |
+
attn = sim.softmax(dim=-1)
|
| 377 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N1, self.inner_dim)
|
| 378 |
+
else:
|
| 379 |
+
|
| 380 |
+
input_args = [x.contiguous() for x in [q, k, v]]
|
| 381 |
+
try:
|
| 382 |
+
# print(f"q.shape: {q.shape}, dtype: {q.dtype}, device: {q.device}")
|
| 383 |
+
# print(f"Flash SDP available: {torch.backends.cuda.flash_sdp_enabled()}")
|
| 384 |
+
# print(f"Flash SDP allowed: {torch.backends.cuda.enable_flash_sdp}")
|
| 385 |
+
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):  # use the non-deprecated API imported at the top of this file
|
| 386 |
+
x = F.scaled_dot_product_attention(*input_args).permute(0,2,1,3).reshape(B,N1,-1) # type: ignore
|
| 387 |
+
except Exception as e:
|
| 388 |
+
print(e)
|
| 389 |
+
|
| 390 |
+
if self.to_out.bias.dtype != x.dtype:
|
| 391 |
+
x = x.to(self.to_out.bias.dtype)
|
| 392 |
+
|
| 393 |
+
return self.to_out(x)
|
| 394 |
+
|
| 395 |
+
class CrossAttnBlock(nn.Module):
|
| 396 |
+
def __init__(
|
| 397 |
+
self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs
|
| 398 |
+
):
|
| 399 |
+
super().__init__()
|
| 400 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 401 |
+
self.norm_context = nn.LayerNorm(context_dim)
|
| 402 |
+
self.cross_attn = Attention(
|
| 403 |
+
hidden_size,
|
| 404 |
+
context_dim=context_dim,
|
| 405 |
+
num_heads=num_heads,
|
| 406 |
+
qkv_bias=True,
|
| 407 |
+
**block_kwargs
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 411 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 412 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
| 413 |
+
self.mlp = Mlp(
|
| 414 |
+
in_features=hidden_size,
|
| 415 |
+
hidden_features=mlp_hidden_dim,
|
| 416 |
+
act_layer=approx_gelu,
|
| 417 |
+
drop=0,
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
def forward(self, x, context, mask=None):
|
| 421 |
+
attn_bias = None
|
| 422 |
+
if mask is not None:
|
| 423 |
+
if mask.shape[1] == x.shape[1]:
|
| 424 |
+
mask = mask[:, None, :, None].expand(
|
| 425 |
+
-1, self.cross_attn.heads, -1, context.shape[1]
|
| 426 |
+
)
|
| 427 |
+
else:
|
| 428 |
+
mask = mask[:, None, None].expand(
|
| 429 |
+
-1, self.cross_attn.heads, x.shape[1], -1
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
max_neg_value = -torch.finfo(x.dtype).max
|
| 433 |
+
attn_bias = (~mask) * max_neg_value
|
| 434 |
+
x = x + self.cross_attn(
|
| 435 |
+
self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias
|
| 436 |
+
)
|
| 437 |
+
x = x + self.mlp(self.norm2(x))
|
| 438 |
+
return x
|
| 439 |
+
|
| 440 |
+
class AttnBlock(nn.Module):
|
| 441 |
+
def __init__(
|
| 442 |
+
self,
|
| 443 |
+
hidden_size,
|
| 444 |
+
num_heads,
|
| 445 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
| 446 |
+
mlp_ratio=4.0,
|
| 447 |
+
**block_kwargs
|
| 448 |
+
):
|
| 449 |
+
super().__init__()
|
| 450 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 451 |
+
self.attn = attn_class(
|
| 452 |
+
hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 456 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 457 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
| 458 |
+
self.mlp = Mlp(
|
| 459 |
+
in_features=hidden_size,
|
| 460 |
+
hidden_features=mlp_hidden_dim,
|
| 461 |
+
act_layer=approx_gelu,
|
| 462 |
+
drop=0,
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
def forward(self, x, mask=None):
|
| 466 |
+
attn_bias = mask
|
| 467 |
+
if mask is not None:
|
| 468 |
+
mask = (
|
| 469 |
+
(mask[:, None] * mask[:, :, None])
|
| 470 |
+
.unsqueeze(1)
|
| 471 |
+
.expand(-1, self.attn.num_heads, -1, -1)
|
| 472 |
+
)
|
| 473 |
+
max_neg_value = -torch.finfo(x.dtype).max
|
| 474 |
+
attn_bias = (~mask) * max_neg_value
|
| 475 |
+
x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
|
| 476 |
+
x = x + self.mlp(self.norm2(x))
|
| 477 |
+
return x
|
| 478 |
+
|
| 479 |
+
class EfficientUpdateFormer(nn.Module):
|
| 480 |
+
"""
|
| 481 |
+
Transformer model that updates track estimates.
|
| 482 |
+
"""
|
| 483 |
+
|
| 484 |
+
def __init__(
|
| 485 |
+
self,
|
| 486 |
+
space_depth=6,
|
| 487 |
+
time_depth=6,
|
| 488 |
+
input_dim=320,
|
| 489 |
+
hidden_size=384,
|
| 490 |
+
num_heads=8,
|
| 491 |
+
output_dim=130,
|
| 492 |
+
mlp_ratio=4.0,
|
| 493 |
+
num_virtual_tracks=64,
|
| 494 |
+
add_space_attn=True,
|
| 495 |
+
linear_layer_for_vis_conf=False,
|
| 496 |
+
patch_feat=False,
|
| 497 |
+
patch_dim=128,
|
| 498 |
+
):
|
| 499 |
+
super().__init__()
|
| 500 |
+
self.out_channels = 2
|
| 501 |
+
self.num_heads = num_heads
|
| 502 |
+
self.hidden_size = hidden_size
|
| 503 |
+
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
|
| 504 |
+
if linear_layer_for_vis_conf:
|
| 505 |
+
self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True)
|
| 506 |
+
self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
|
| 507 |
+
else:
|
| 508 |
+
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
|
| 509 |
+
|
| 510 |
+
if patch_feat==False:
|
| 511 |
+
self.virual_tracks = nn.Parameter(
|
| 512 |
+
torch.randn(1, num_virtual_tracks, 1, hidden_size)
|
| 513 |
+
)
|
| 514 |
+
self.num_virtual_tracks = num_virtual_tracks
|
| 515 |
+
else:
|
| 516 |
+
self.patch_proj = nn.Linear(patch_dim, hidden_size, bias=True)
|
| 517 |
+
|
| 518 |
+
self.add_space_attn = add_space_attn
|
| 519 |
+
self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
|
| 520 |
+
self.time_blocks = nn.ModuleList(
|
| 521 |
+
[
|
| 522 |
+
AttnBlock(
|
| 523 |
+
hidden_size,
|
| 524 |
+
num_heads,
|
| 525 |
+
mlp_ratio=mlp_ratio,
|
| 526 |
+
attn_class=Attention,
|
| 527 |
+
)
|
| 528 |
+
for _ in range(time_depth)
|
| 529 |
+
]
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
if add_space_attn:
|
| 533 |
+
self.space_virtual_blocks = nn.ModuleList(
|
| 534 |
+
[
|
| 535 |
+
AttnBlock(
|
| 536 |
+
hidden_size,
|
| 537 |
+
num_heads,
|
| 538 |
+
mlp_ratio=mlp_ratio,
|
| 539 |
+
attn_class=Attention,
|
| 540 |
+
)
|
| 541 |
+
for _ in range(space_depth)
|
| 542 |
+
]
|
| 543 |
+
)
|
| 544 |
+
self.space_point2virtual_blocks = nn.ModuleList(
|
| 545 |
+
[
|
| 546 |
+
CrossAttnBlock(
|
| 547 |
+
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
| 548 |
+
)
|
| 549 |
+
for _ in range(space_depth)
|
| 550 |
+
]
|
| 551 |
+
)
|
| 552 |
+
self.space_virtual2point_blocks = nn.ModuleList(
|
| 553 |
+
[
|
| 554 |
+
CrossAttnBlock(
|
| 555 |
+
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
| 556 |
+
)
|
| 557 |
+
for _ in range(space_depth)
|
| 558 |
+
]
|
| 559 |
+
)
|
| 560 |
+
assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
|
| 561 |
+
self.initialize_weights()
|
| 562 |
+
|
| 563 |
+
def initialize_weights(self):
|
| 564 |
+
def _basic_init(module):
|
| 565 |
+
if isinstance(module, nn.Linear):
|
| 566 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 567 |
+
if module.bias is not None:
|
| 568 |
+
nn.init.constant_(module.bias, 0)
|
| 569 |
+
torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
|
| 570 |
+
if self.linear_layer_for_vis_conf:
|
| 571 |
+
torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001)
|
| 572 |
+
|
| 573 |
+
def _trunc_init(module):
|
| 574 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
| 575 |
+
if isinstance(module, nn.Linear):
|
| 576 |
+
torch.nn.init.trunc_normal_(module.weight, std=0.02)
|
| 577 |
+
if module.bias is not None:
|
| 578 |
+
nn.init.zeros_(module.bias)
|
| 579 |
+
|
| 580 |
+
self.apply(_basic_init)
|
| 581 |
+
|
| 582 |
+
def forward(self, input_tensor, mask=None, add_space_attn=True, patch_feat=None):
|
| 583 |
+
tokens = self.input_transform(input_tensor)
|
| 584 |
+
|
| 585 |
+
B, _, T, _ = tokens.shape
|
| 586 |
+
if patch_feat is None:
|
| 587 |
+
virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
|
| 588 |
+
tokens = torch.cat([tokens, virtual_tokens], dim=1)
|
| 589 |
+
else:
|
| 590 |
+
patch_feat = self.patch_proj(patch_feat.detach())
|
| 591 |
+
tokens = torch.cat([tokens, patch_feat], dim=1)
|
| 592 |
+
self.num_virtual_tracks = patch_feat.shape[1]
|
| 593 |
+
|
| 594 |
+
_, N, _, _ = tokens.shape
|
| 595 |
+
j = 0
|
| 596 |
+
layers = []
|
| 597 |
+
for i in range(len(self.time_blocks)):
|
| 598 |
+
time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
|
| 599 |
+
time_tokens = torch.utils.checkpoint.checkpoint(
|
| 600 |
+
self.time_blocks[i],
|
| 601 |
+
time_tokens
|
| 602 |
+
)
|
| 603 |
+
|
| 604 |
+
tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
|
| 605 |
+
if (
|
| 606 |
+
add_space_attn
|
| 607 |
+
and hasattr(self, "space_virtual_blocks")
|
| 608 |
+
and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0)
|
| 609 |
+
):
|
| 610 |
+
space_tokens = (
|
| 611 |
+
tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
|
| 612 |
+
) # B N T C -> (B T) N C
|
| 613 |
+
|
| 614 |
+
point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
|
| 615 |
+
virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
|
| 616 |
+
|
| 617 |
+
virtual_tokens = torch.utils.checkpoint.checkpoint(
|
| 618 |
+
self.space_virtual2point_blocks[j],
|
| 619 |
+
virtual_tokens, point_tokens, mask
|
| 620 |
+
)
|
| 621 |
+
|
| 622 |
+
virtual_tokens = torch.utils.checkpoint.checkpoint(
|
| 623 |
+
self.space_virtual_blocks[j],
|
| 624 |
+
virtual_tokens
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
point_tokens = torch.utils.checkpoint.checkpoint(
|
| 628 |
+
self.space_point2virtual_blocks[j],
|
| 629 |
+
point_tokens, virtual_tokens, mask
|
| 630 |
+
)
|
| 631 |
+
|
| 632 |
+
space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
|
| 633 |
+
tokens = space_tokens.view(B, T, N, -1).permute(
|
| 634 |
+
0, 2, 1, 3
|
| 635 |
+
) # (B T) N C -> B N T C
|
| 636 |
+
j += 1
|
| 637 |
+
tokens = tokens[:, : N - self.num_virtual_tracks]
|
| 638 |
+
|
| 639 |
+
flow = self.flow_head(tokens)
|
| 640 |
+
if self.linear_layer_for_vis_conf:
|
| 641 |
+
vis_conf = self.vis_conf_head(tokens)
|
| 642 |
+
flow = torch.cat([flow, vis_conf], dim=-1)
|
| 643 |
+
|
| 644 |
+
return flow
|
| 645 |
+
|
| 646 |
+
def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
|
| 647 |
+
probs = torch.sigmoid(logits)
|
| 648 |
+
ce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
|
| 649 |
+
p_t = probs * targets + (1 - probs) * (1 - targets)
|
| 650 |
+
loss = alpha * (1 - p_t) ** gamma * ce_loss
|
| 651 |
+
return loss.mean()
|
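The function above is the standard focal loss, FL(p_t) = alpha * (1 - p_t)**gamma * CE, which down-weights well-classified examples. Illustrative call:

logits = torch.tensor([2.0, -1.0, 0.5])
targets = torch.tensor([1.0, 0.0, 1.0])
print(focal_loss(logits, targets))       # scalar; confident correct predictions contribute little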
| 652 |
+
|
| 653 |
+
def balanced_binary_cross_entropy(logits, targets, balance_weight=1.0, eps=1e-6, reduction="mean", pos_bias=0.0, mask=None):
|
| 654 |
+
"""
|
| 655 |
+
logits: Tensor of arbitrary shape
|
| 656 |
+
targets: same shape as logits
|
| 657 |
+
balance_weight: scaling the loss
|
| 658 |
+
reduction: 'mean', 'sum', or 'none'
|
| 659 |
+
"""
|
| 660 |
+
targets = targets.float()
|
| 661 |
+
positive = (targets == 1).float().sum()
|
| 662 |
+
total = targets.numel()
|
| 663 |
+
positive_ratio = positive / (total + eps)
|
| 664 |
+
|
| 665 |
+
pos_weight = (1 - positive_ratio) / (positive_ratio + eps)
|
| 666 |
+
pos_weight = pos_weight.clamp(min=0.1, max=10.0)
|
| 667 |
+
loss = F.binary_cross_entropy_with_logits(
|
| 668 |
+
logits,
|
| 669 |
+
targets,
|
| 670 |
+
pos_weight=pos_weight+pos_bias,
|
| 671 |
+
reduction=reduction
|
| 672 |
+
)
|
| 673 |
+
if mask is not None:
|
| 674 |
+
loss = (loss * mask).sum() / (mask.sum() + eps)
|
| 675 |
+
return balance_weight * loss
|
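balanced_binary_cross_entropy estimates the positive ratio of the batch on the fly and uses its (clamped) inverse as pos_weight, so rare positives are up-weighted; when a mask is supplied it should be combined with reduction="none", as done in sequence_BCE_loss below. Illustrative call with synthetic labels:

logits = torch.randn(4, 16)
targets = (torch.rand(4, 16) > 0.9).float()   # ~10% positives -> pos_weight near its 10.0 cap
loss = balanced_binary_cross_entropy(logits, targets)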
| 676 |
+
|
| 677 |
+
def sequence_loss(
|
| 678 |
+
flow_preds,
|
| 679 |
+
flow_gt,
|
| 680 |
+
valids,
|
| 681 |
+
vis=None,
|
| 682 |
+
gamma=0.8,
|
| 683 |
+
add_huber_loss=False,
|
| 684 |
+
loss_only_for_visible=False,
|
| 685 |
+
depth_sample=None,
|
| 686 |
+
z_unc=None,
|
| 687 |
+
mask_traj_gt=None
|
| 688 |
+
):
|
| 689 |
+
"""Loss function defined over sequence of flow predictions"""
|
| 690 |
+
total_flow_loss = 0.0
|
| 691 |
+
for j in range(len(flow_gt)):
|
| 692 |
+
B, S, N, D = flow_gt[j].shape
|
| 693 |
+
B, S2, N = valids[j].shape
|
| 694 |
+
assert S == S2
|
| 695 |
+
n_predictions = len(flow_preds[j])
|
| 696 |
+
flow_loss = 0.0
|
| 697 |
+
for i in range(n_predictions):
|
| 698 |
+
i_weight = gamma ** (n_predictions - i - 1)
|
| 699 |
+
flow_pred = flow_preds[j][i][:,:,:flow_gt[j].shape[2]]
|
| 700 |
+
if flow_pred.shape[-1] == 3:
|
| 701 |
+
flow_pred[...,2] = flow_pred[...,2]
|
| 702 |
+
if add_huber_loss:
|
| 703 |
+
i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0)
|
| 704 |
+
else:
|
| 705 |
+
if flow_gt[j][...,2].abs().max() != 0:
|
| 706 |
+
track_z_loss = (flow_pred- flow_gt[j])[...,2].abs().mean()
|
| 707 |
+
if mask_traj_gt is not None:
|
| 708 |
+
track_z_loss = ((flow_pred- flow_gt[j])[...,2].abs() * mask_traj_gt.permute(0,2,1)).sum() / (mask_traj_gt.sum(dim=1)+1e-6)
|
| 709 |
+
else:
|
| 710 |
+
track_z_loss = 0
|
| 711 |
+
i_loss = (flow_pred[...,:2] - flow_gt[j][...,:2]).abs() # B, S, N, 2
|
| 712 |
+
# print((flow_pred - flow_gt[j])[...,2].abs()[vis[j].bool()].mean())
|
| 713 |
+
i_loss = torch.mean(i_loss, dim=3) # B, S, N
|
| 714 |
+
valid_ = valids[j].clone()[:,:, :flow_gt[j].shape[2]] # Ensure valid_ has the same shape as i_loss
|
| 715 |
+
valid_ = valid_ * (flow_gt[j][...,:2].norm(dim=-1) > 0).float()
|
| 716 |
+
if loss_only_for_visible:
|
| 717 |
+
valid_ = valid_ * vis[j]
|
| 718 |
+
# print(reduce_masked_mean(i_loss, valid_).item(), track_z_loss.item()/16)
|
| 719 |
+
flow_loss += i_weight * (reduce_masked_mean(i_loss, valid_) + track_z_loss + 10*reduce_masked_mean(i_loss, valid_* vis[j]))
|
| 720 |
+
# if flow_loss > 5e2:
|
| 721 |
+
# import pdb; pdb.set_trace()
|
| 722 |
+
flow_loss = flow_loss / n_predictions
|
| 723 |
+
total_flow_loss += flow_loss
|
| 724 |
+
return total_flow_loss / len(flow_gt)
|
| 725 |
+
|
| 726 |
+
def sequence_loss_xyz(
|
| 727 |
+
flow_preds,
|
| 728 |
+
flow_gt,
|
| 729 |
+
valids,
|
| 730 |
+
intrs,
|
| 731 |
+
vis=None,
|
| 732 |
+
gamma=0.8,
|
| 733 |
+
add_huber_loss=False,
|
| 734 |
+
loss_only_for_visible=False,
|
| 735 |
+
mask_traj_gt=None
|
| 736 |
+
):
|
| 737 |
+
"""Loss function defined over sequence of flow predictions"""
|
| 738 |
+
total_flow_loss = 0.0
|
| 739 |
+
for j in range(len(flow_gt)):
|
| 740 |
+
B, S, N, D = flow_gt[j].shape
|
| 741 |
+
B, S2, N = valids[j].shape
|
| 742 |
+
assert S == S2
|
| 743 |
+
n_predictions = len(flow_preds[j])
|
| 744 |
+
flow_loss = 0.0
|
| 745 |
+
for i in range(n_predictions):
|
| 746 |
+
i_weight = gamma ** (n_predictions - i - 1)
|
| 747 |
+
flow_pred = flow_preds[j][i][:,:,:flow_gt[j].shape[2]]
|
| 748 |
+
flow_gt_ = flow_gt[j]
|
| 749 |
+
flow_gt_one = torch.cat([flow_gt_[...,:2], torch.ones_like(flow_gt_[:,:,:,:1])], dim=-1)
|
| 750 |
+
flow_gt_cam = torch.einsum('btsc,btnc->btns', torch.inverse(intrs), flow_gt_one)
|
| 751 |
+
flow_gt_cam *= flow_gt_[...,2:3].abs()
|
| 752 |
+
flow_gt_cam[...,2] *= torch.sign(flow_gt_cam[...,2])
|
| 753 |
+
|
| 754 |
+
if add_huber_loss:
|
| 755 |
+
i_loss = huber_loss(flow_pred, flow_gt_cam, delta=6.0)
|
| 756 |
+
else:
|
| 757 |
+
i_loss = (flow_pred- flow_gt_cam).norm(dim=-1,keepdim=True) # B, S, N, 2
|
| 758 |
+
|
| 759 |
+
# print((flow_pred - flow_gt[j])[...,2].abs()[vis[j].bool()].mean())
|
| 760 |
+
i_loss = torch.mean(i_loss, dim=3) # B, S, N
|
| 761 |
+
valid_ = valids[j].clone()[:,:, :flow_gt[j].shape[2]] # Ensure valid_ has the same shape as i_loss
|
| 762 |
+
if loss_only_for_visible:
|
| 763 |
+
valid_ = valid_ * vis[j]
|
| 764 |
+
# print(reduce_masked_mean(i_loss, valid_).item(), track_z_loss.item()/16)
|
| 765 |
+
flow_loss += i_weight * (reduce_masked_mean(i_loss, valid_)) * 1000
|
| 766 |
+
# if flow_loss > 5e2:
|
| 767 |
+
# import pdb; pdb.set_trace()
|
| 768 |
+
flow_loss = flow_loss / n_predictions
|
| 769 |
+
total_flow_loss += flow_loss
|
| 770 |
+
return total_flow_loss / len(flow_gt)
|
| 771 |
+
|
| 772 |
+
def huber_loss(x, y, delta=1.0):
|
| 773 |
+
"""Calculate element-wise Huber loss between x and y"""
|
| 774 |
+
diff = x - y
|
| 775 |
+
abs_diff = diff.abs()
|
| 776 |
+
flag = (abs_diff <= delta).float()
|
| 777 |
+
return flag * 0.5 * diff**2 + (1 - flag) * delta * (abs_diff - 0.5 * delta)
|
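The piecewise form implemented above is 0.5 * d**2 for |d| <= delta and delta * (|d| - 0.5 * delta) otherwise. Quick numeric check:

print(huber_loss(torch.tensor([0.5]), torch.tensor([0.0]), delta=1.0))  # tensor([0.1250]), quadratic branch
print(huber_loss(torch.tensor([3.0]), torch.tensor([0.0]), delta=1.0))  # tensor([2.5000]), linear branch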
| 778 |
+
|
def sequence_BCE_loss(vis_preds, vis_gts, mask=None):
    total_bce_loss = 0.0
    for j in range(len(vis_preds)):
        n_predictions = len(vis_preds[j])
        bce_loss = 0.0
        for i in range(n_predictions):
            N_gt = vis_gts[j].shape[-1]
            if mask is not None:
                vis_loss = balanced_binary_cross_entropy(vis_preds[j][i][..., :N_gt], vis_gts[j], mask=mask[j], reduction="none")
            else:
                vis_loss = balanced_binary_cross_entropy(vis_preds[j][i][..., :N_gt], vis_gts[j]) + focal_loss(vis_preds[j][i][..., :N_gt], vis_gts[j])
            # print(vis_loss, ((torch.sigmoid(vis_preds[j][i][...,:N_gt])>0.5).float() - vis_gts[j]).abs().sum())
            bce_loss += vis_loss
        bce_loss = bce_loss / n_predictions
        total_bce_loss += bce_loss
    return total_bce_loss / len(vis_preds)


def sequence_prob_loss(
    tracks: torch.Tensor,
    confidence: torch.Tensor,
    target_points: torch.Tensor,
    visibility: torch.Tensor,
    expected_dist_thresh: float = 12.0,
):
    """Loss for classifying if a point is within pixel threshold of its target."""
    # Points with an error larger than 12 pixels are likely to be useless; marking
    # them as occluded will actually improve Jaccard metrics and give
    # qualitatively better results.
    total_logprob_loss = 0.0
    for j in range(len(tracks)):
        n_predictions = len(tracks[j])
        logprob_loss = 0.0
        for i in range(n_predictions):
            N_gt = target_points[j].shape[2]
            err = torch.sum((tracks[j][i].detach()[:, :, :N_gt, :2] - target_points[j][..., :2]) ** 2, dim=-1)
            valid = (err <= expected_dist_thresh**2).float()
            logprob = balanced_binary_cross_entropy(confidence[j][i][..., :N_gt], valid, reduction="none")
            logprob *= visibility[j]
            logprob = torch.mean(logprob, dim=[1, 2])
            logprob_loss += logprob
        logprob_loss = logprob_loss / n_predictions
        total_logprob_loss += logprob_loss
    return total_logprob_loss / len(tracks)


def sequence_dyn_prob_loss(
    tracks: torch.Tensor,
    confidence: torch.Tensor,
    target_points: torch.Tensor,
    visibility: torch.Tensor,
    expected_dist_thresh: float = 6.0,
):
    """Loss for classifying if a point is within pixel threshold of its target."""
    # Points with an error larger than 12 pixels are likely to be useless; marking
    # them as occluded will actually improve Jaccard metrics and give
    # qualitatively better results.
    total_logprob_loss = 0.0
    for j in range(len(tracks)):
        n_predictions = len(tracks[j])
        logprob_loss = 0.0
        for i in range(n_predictions):
            err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=-1)
            valid = (err <= expected_dist_thresh**2).float()
            valid = (valid.sum(dim=1) > 0).float()
            logprob = balanced_binary_cross_entropy(confidence[j][i].mean(dim=1), valid, reduction="none")
            # logprob *= visibility[j]
            logprob = torch.mean(logprob, dim=[0, 1])
            logprob_loss += logprob
        logprob_loss = logprob_loss / n_predictions
        total_logprob_loss += logprob_loss
    return total_logprob_loss / len(tracks)


def masked_mean(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
    if mask is None:
        return data.mean(dim=dim, keepdim=True)
    mask = mask.float()
    mask_sum = torch.sum(mask, dim=dim, keepdim=True)
    mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
        mask_sum, min=1.0
    )
    return mask_mean


def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
    if mask is None:
        return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
    mask = mask.float()
    mask_sum = torch.sum(mask, dim=dim, keepdim=True)
    mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
        mask_sum, min=1.0
    )
    mask_var = torch.sum(
        mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
    ) / torch.clamp(mask_sum, min=1.0)
    return mask_mean.squeeze(dim), mask_var.squeeze(dim)


class NeighborTransformer(nn.Module):
    def __init__(self, dim: int, num_heads: int, head_dim: int, mlp_ratio: float):
        super().__init__()
        self.dim = dim
        self.output_token_1 = nn.Parameter(torch.randn(1, dim))
        self.output_token_2 = nn.Parameter(torch.randn(1, dim))
        self.xblock1_2 = CrossAttnBlock(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim, mlp_ratio=mlp_ratio)
        self.xblock2_1 = CrossAttnBlock(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim, mlp_ratio=mlp_ratio)
        self.aggr1 = Attention(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim)
        self.aggr2 = Attention(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        from einops import rearrange, repeat
        import torch.utils.checkpoint as checkpoint

        assert len(x.shape) == 3, "x should be of shape (B, N, D)"
        assert len(y.shape) == 3, "y should be of shape (B, N, D)"

        # not work so well ...

        def forward_chunk(x, y):
            new_x = self.xblock1_2(x, y)
            new_y = self.xblock2_1(y, x)
            out1 = self.aggr1(repeat(self.output_token_1, 'n d -> b n d', b=x.shape[0]), context=new_x)
            out2 = self.aggr2(repeat(self.output_token_2, 'n d -> b n d', b=x.shape[0]), context=new_y)
            return out1 + out2

        return checkpoint.checkpoint(forward_chunk, x, y)


class CorrPointformer(nn.Module):
    def __init__(self, dim: int, num_heads: int, head_dim: int, mlp_ratio: float):
        super().__init__()
        self.dim = dim
        self.xblock1_2 = CrossAttnBlock(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim, mlp_ratio=mlp_ratio)
        # self.xblock2_1 = CrossAttnBlock(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim, mlp_ratio=mlp_ratio)
        self.aggr = CrossAttnBlock(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim, mlp_ratio=mlp_ratio)
        self.out_proj = nn.Linear(dim, 2 * dim)

    def forward(self, query: torch.Tensor, target: torch.Tensor, target_rel_pos: torch.Tensor) -> torch.Tensor:
        from einops import rearrange, repeat
        import torch.utils.checkpoint as checkpoint

        def forward_chunk(query, target, target_rel_pos):
            new_query = self.xblock1_2(query, target).mean(dim=1, keepdim=True)
            # new_target = self.xblock2_1(target, query).mean(dim=1, keepdim=True)
            # new_aggr = new_query + new_target
            out = self.aggr(new_query, target + target_rel_pos)  # (potential delta xyz) (target - center)
            out = self.out_proj(out)
            return out

        return checkpoint.checkpoint(forward_chunk, query, target, target_rel_pos)
models/SpaTrackV2/models/tracker3D/delta_utils/__init__.py
ADDED
File without changes
models/SpaTrackV2/models/tracker3D/delta_utils/blocks.py
ADDED
@@ -0,0 +1,842 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import collections
from functools import partial
from itertools import repeat
from typing import Callable

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.SpaTrackV2.models.blocks import bilinear_sampler
from einops import rearrange
from torch import Tensor, einsum


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


to_2tuple = _ntuple(2)


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
        zero_init=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

        if zero_init:
            self.zero_init()

    def zero_init(self):
        nn.init.constant_(self.fc2.weight, 0)
        if self.fc2.bias is not None:
            nn.init.constant_(self.fc2.bias, 0)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x, mode="nearest"):
        x = F.interpolate(x, scale_factor=2.0, mode=mode)
        if self.with_conv:
            x = self.conv(x)
        return x


class ResidualBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=3,
            padding=1,
            stride=stride,
            padding_mode="zeros",
        )
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode="zeros")
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)


class BasicEncoder(nn.Module):
    def __init__(self, input_dim=3, output_dim=128, stride=4):
        super(BasicEncoder, self).__init__()
        self.stride = stride
        self.norm_fn = "instance"
        self.in_planes = output_dim // 2

        self.norm1 = nn.InstanceNorm2d(self.in_planes)
        self.norm2 = nn.InstanceNorm2d(output_dim * 2)

        self.conv1 = nn.Conv2d(
            input_dim,
            self.in_planes,
            kernel_size=7,
            stride=2,
            padding=3,
            padding_mode="zeros",
        )
        self.relu1 = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(output_dim // 2, stride=1)
        self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
        self.layer3 = self._make_layer(output_dim, stride=2)
        self.layer4 = self._make_layer(output_dim, stride=2)

        self.conv2 = nn.Conv2d(
            output_dim * 3 + output_dim // 4,
            output_dim * 2,
            kernel_size=3,
            padding=1,
            padding_mode="zeros",
        )
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.InstanceNorm2d)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x, return_intermediate=False):
        _, _, H, W = x.shape

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        a = self.layer1(x)
        b = self.layer2(a)
        c = self.layer3(b)
        d = self.layer4(c)

        def _bilinear_intepolate(x):
            return F.interpolate(
                x,
                (H // self.stride, W // self.stride),
                mode="bilinear",
                align_corners=True,
            )

        # a = _bilinear_intepolate(a)
        # b = _bilinear_intepolate(b)
        # c = _bilinear_intepolate(c)
        # d = _bilinear_intepolate(d)

        cat_feat = torch.cat(
            [_bilinear_intepolate(a), _bilinear_intepolate(b), _bilinear_intepolate(c), _bilinear_intepolate(d)], dim=1
        )
        x = self.conv2(cat_feat)
        x = self.norm2(x)
        x = self.relu2(x)
        x = self.conv3(x)

        # breakpoint()
        if return_intermediate:
            if self.stride == 4:
                return x, a, c  # 128, h/4, w/4, - 64, h/2, w/2 - 128, h/8, w/8
            elif self.stride == 8:
                return x, b, d
            else:
                raise NotImplementedError
        return x


class CorrBlockFP16:
    def __init__(
        self,
        fmaps,
        num_levels=4,
        radius=4,
        multiple_track_feats=False,
        padding_mode="zeros",
    ):
        B, S, C, H, W = fmaps.shape
        self.S, self.C, self.H, self.W = S, C, H, W
        self.padding_mode = padding_mode
        self.num_levels = num_levels
        self.radius = radius
        self.fmaps_pyramid = []
        self.multiple_track_feats = multiple_track_feats

        self.fmaps_pyramid.append(fmaps)
        for i in range(self.num_levels - 1):
            fmaps_ = fmaps.reshape(B * S, C, H, W)
            fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
            _, _, H, W = fmaps_.shape
            fmaps = fmaps_.reshape(B, S, C, H, W)
            self.fmaps_pyramid.append(fmaps)

    def sample(self, coords):
        r = self.radius
        B, S, N, D = coords.shape
        assert D == 2

        H, W = self.H, self.W
        out_pyramid = []
        for i in range(self.num_levels):
            corrs = self.corrs_pyramid[i]  # B, S, N, H, W
            *_, H, W = corrs.shape

            dx = torch.linspace(-r, r, 2 * r + 1)
            dy = torch.linspace(-r, r, 2 * r + 1)
            delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)

            centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            # breakpoint()
            corrs = bilinear_sampler(
                corrs.reshape(B * S * N, 1, H, W),
                coords_lvl,
                padding_mode=self.padding_mode,
            )
            corrs = corrs.view(B, S, N, -1)
            out_pyramid.append(corrs)

        del self.corrs_pyramid

        out = torch.cat(out_pyramid, dim=-1)  # B, S, N, LRR*2
        out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
        return out

    def corr(self, targets):
        B, S, N, C = targets.shape
        if self.multiple_track_feats:
            targets_split = targets.split(C // self.num_levels, dim=-1)
            B, S, N, C = targets_split[0].shape

        assert C == self.C
        assert S == self.S

        fmap1 = targets

        self.corrs_pyramid = []
        for i, fmaps in enumerate(self.fmaps_pyramid):
            *_, H, W = fmaps.shape
            fmap2s = fmaps.view(B, S, C, H * W)  # B S C H W -> B S C (H W)
            if self.multiple_track_feats:
                fmap1 = targets_split[i]
            corrs = torch.matmul(fmap1, fmap2s)
            corrs = corrs.view(B, S, N, H, W)  # B S N (H W) -> B S N H W
            corrs = corrs / torch.sqrt(torch.tensor(C).float())
            # breakpoint()
            self.corrs_pyramid.append(corrs)


class CorrBlock:
    def __init__(
        self,
        fmaps,
        num_levels=4,
        radius=4,
        multiple_track_feats=False,
        padding_mode="zeros",
    ):
        B, S, C, H, W = fmaps.shape
        self.S, self.C, self.H, self.W = S, C, H, W
        self.padding_mode = padding_mode
        self.num_levels = num_levels
        self.radius = radius
        self.fmaps_pyramid = []
        self.multiple_track_feats = multiple_track_feats

        self.fmaps_pyramid.append(fmaps)
        for i in range(self.num_levels - 1):
            fmaps_ = fmaps.reshape(B * S, C, H, W)
            fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
            _, _, H, W = fmaps_.shape
            fmaps = fmaps_.reshape(B, S, C, H, W)
            self.fmaps_pyramid.append(fmaps)

    def sample(self, coords, delete=True):
        r = self.radius
        B, S, N, D = coords.shape
        assert D == 2

        H, W = self.H, self.W
        out_pyramid = []
        for i in range(self.num_levels):
            corrs = self.corrs_pyramid[i]  # B, S, N, H, W
            *_, H, W = corrs.shape

            dx = torch.linspace(-r, r, 2 * r + 1)
            dy = torch.linspace(-r, r, 2 * r + 1)
            delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)

            centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            # breakpoint()

            # t1 = time.time()
            corrs = bilinear_sampler(
                corrs.reshape(B * S * N, 1, H, W),
                coords_lvl,
                padding_mode=self.padding_mode,
            )
            # t2 = time.time()

            # print(coords_lvl.shape, t2 - t1)
            corrs = corrs.view(B, S, N, -1)
            out_pyramid.append(corrs)

        if delete:
            del self.corrs_pyramid

        out = torch.cat(out_pyramid, dim=-1)  # B, S, N, LRR*2
        out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
        return out

    def corr(self, targets):
        B, S, N, C = targets.shape
        if self.multiple_track_feats:
            targets_split = targets.split(C // self.num_levels, dim=-1)
            B, S, N, C = targets_split[0].shape

        assert C == self.C
        assert S == self.S

        fmap1 = targets

        self.corrs_pyramid = []
        for i, fmaps in enumerate(self.fmaps_pyramid):
            *_, H, W = fmaps.shape
            fmap2s = fmaps.view(B, S, C, H * W)  # B S C H W -> B S C (H W)
            if self.multiple_track_feats:
                fmap1 = targets_split[i]
            corrs = torch.matmul(fmap1, fmap2s)
            corrs = corrs.view(B, S, N, H, W)  # B S N (H W) -> B S N H W
            corrs = corrs / torch.sqrt(torch.tensor(C).float())
            # breakpoint()
            self.corrs_pyramid.append(corrs)

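# --- Usage sketch (editor's note, not part of the original file) ---
# CorrBlock builds an average-pooled feature pyramid once; correlation is then a
# two-step call: corr() computes per-level cost volumes against the track features,
# and sample() reads a (2r+1)x(2r+1) window around each track location. A
# hypothetical call with fmaps of shape (B, S, C, H, W) and N tracks:
#
#   corr_block = CorrBlock(fmaps, num_levels=4, radius=3)
#   corr_block.corr(track_feats)            # track_feats: (B, S, N, C)
#   corr_feats = corr_block.sample(coords)  # coords: (B, S, N, 2) -> (B*N, S, num_levels*(2r+1)**2)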

class Attention(nn.Module):
    def __init__(
        self,
        query_dim,
        context_dim=None,
        num_heads=8,
        dim_head=48,
        qkv_bias=False,
        flash=False,
        alibi=False,
        zero_init=False,
    ):
        super().__init__()
        inner_dim = dim_head * num_heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head**-0.5
        self.heads = num_heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
        self.to_out = nn.Linear(inner_dim, query_dim)

        self.flash = flash
        self.alibi = alibi

        if zero_init:
            self.zero_init()
        # if self.alibi:
        #     self.training_length = 24

        #     bias_forward = get_alibi_slope(self.heads // 2) * get_relative_positions(self.training_length)
        #     bias_forward = bias_forward + torch.triu(torch.full_like(bias_forward, -1e9), diagonal=1)
        #     bias_backward = get_alibi_slope(self.heads // 2) * get_relative_positions(self.training_length, reverse=True)
        #     bias_backward = bias_backward + torch.tril(torch.full_like(bias_backward, -1e9), diagonal=-1)

        #     self.precomputed_attn_bias = self.register_buffer("precomputed_attn_bias", torch.cat([bias_forward, bias_backward], dim=0), persistent=False)

    def zero_init(self):
        nn.init.constant_(self.to_out.weight, 0)
        nn.init.constant_(self.to_out.bias, 0)

        # breakpoint()

    def forward(self, x, context=None, attn_bias=None):
        B, N1, C = x.shape
        h = self.heads

        q = self.to_q(x).reshape(B, N1, h, C // h)
        context = default(context, x)
        N2 = context.shape[1]
        k, v = self.to_kv(context).chunk(2, dim=-1)
        k = k.reshape(B, N2, h, C // h)
        v = v.reshape(B, N2, h, C // h)

        if self.flash:
            with torch.autocast(device_type="cuda", enabled=True):
                x = flash_attn_func(q.half(), k.half(), v.half())
                x = x.reshape(B, N1, C)
                x = x.float()
        else:
            q = q.permute(0, 2, 1, 3)
            k = k.permute(0, 2, 1, 3)
            v = v.permute(0, 2, 1, 3)

            sim = (q @ k.transpose(-2, -1)) * self.scale

            if attn_bias is not None:
                sim = sim + attn_bias
            attn = sim.softmax(dim=-1)

            x = attn @ v
            x = x.transpose(1, 2).reshape(B, N1, C)
        x = self.to_out(x)
        return x

    def forward_noattn(self, x):
        # B, N1, C = x.shape
        # h = self.heads
        _, x = self.to_kv(x).chunk(2, dim=-1)
        # x = x.reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
        # x = x.transpose(1, 2).reshape(B, N1, C)
        x = self.to_out(x)

        return x

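# --- Editor's note (not part of the original file) ---
# The flash=True branch above calls flash_attn_func, which is not imported in this
# module; it presumably comes from the optional flash-attn package
# (from flash_attn import flash_attn_func). With the default flash=False the class
# uses the plain softmax attention path and needs no extra dependency.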

def get_relative_positions(seq_len, reverse=False, device="cpu"):
    x = torch.arange(seq_len, device=device)[None, :]
    y = torch.arange(seq_len, device=device)[:, None]
    return torch.tril(x - y) if not reverse else torch.triu(y - x)


def get_alibi_slope(num_heads, device="cpu"):
    x = (24) ** (1 / num_heads)
    return torch.tensor([1 / x ** (i + 1) for i in range(num_heads)], device=device, dtype=torch.float32).view(
        -1, 1, 1
    )

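# --- Usage sketch (editor's note, not part of the original file) ---
# get_alibi_slope and get_relative_positions together form an ALiBi-style additive
# attention bias: per-head slopes scale a (seq_len, seq_len) matrix of signed
# relative offsets, and the result is added to the attention logits before softmax.
# A quick shape check, assuming 8 heads and 24 frames:
#
#   bias = get_alibi_slope(8) * get_relative_positions(24)   # (8, 24, 24)
#   # entries in the kept triangle become more negative as |i - j| grows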

class RelativeAttention(nn.Module):
    """Multi-headed attention (MHA) module."""

    def __init__(self, query_dim, num_heads=8, qkv_bias=True, model_size=None, flash=False):
        super(RelativeAttention, self).__init__()

        query_dim = query_dim // num_heads
        self.num_heads = num_heads
        self.query_dim = query_dim
        self.value_size = query_dim
        self.model_size = query_dim * num_heads

        self.qkv_bias = qkv_bias

        self.query_proj = nn.Linear(num_heads * query_dim, num_heads * query_dim, bias=qkv_bias)
        self.key_proj = nn.Linear(num_heads * query_dim, num_heads * query_dim, bias=qkv_bias)
        self.value_proj = nn.Linear(num_heads * self.value_size, num_heads * self.value_size, bias=qkv_bias)
        self.final_proj = nn.Linear(num_heads * self.value_size, self.model_size, bias=qkv_bias)

        self.training_length = 24

        bias_forward = get_alibi_slope(self.num_heads // 2) * get_relative_positions(self.training_length)
        bias_forward = bias_forward + torch.triu(torch.full_like(bias_forward, -1e9), diagonal=1)
        bias_backward = get_alibi_slope(self.num_heads // 2) * get_relative_positions(
            self.training_length, reverse=True
        )
        bias_backward = bias_backward + torch.tril(torch.full_like(bias_backward, -1e9), diagonal=-1)

        self.register_buffer(
            "precomputed_attn_bias", torch.cat([bias_forward, bias_backward], dim=0), persistent=False
        )

    def forward(self, x, attn_bias=None):
        batch_size, sequence_length, _ = x.size()

        query_heads = self._linear_projection(x, self.query_dim, self.query_proj)  # [T', H, Q=K]
        key_heads = self._linear_projection(x, self.query_dim, self.key_proj)  # [T, H, K]
        value_heads = self._linear_projection(x, self.value_size, self.value_proj)  # [T, H, V]

        if self.training_length == sequence_length:
            new_attn_bias = self.precomputed_attn_bias
        else:
            device = x.device
            bias_forward = get_alibi_slope(self.num_heads // 2, device=device) * get_relative_positions(
                sequence_length, device=device
            )
            bias_forward = bias_forward + torch.triu(torch.full_like(bias_forward, -1e9), diagonal=1)
            bias_backward = get_alibi_slope(self.num_heads // 2, device=device) * get_relative_positions(
                sequence_length, reverse=True, device=device
            )
            bias_backward = bias_backward + torch.tril(torch.full_like(bias_backward, -1e9), diagonal=-1)
            new_attn_bias = torch.cat([bias_forward, bias_backward], dim=0)

        if attn_bias is not None:
            attn_bias = attn_bias + new_attn_bias
        else:
            attn_bias = new_attn_bias

        attn = F.scaled_dot_product_attention(
            query_heads, key_heads, value_heads, attn_mask=new_attn_bias, scale=1 / np.sqrt(self.query_dim)
        )
        attn = attn.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, -1)

        return self.final_proj(attn)  # [T', D']

        # attn_logits = torch.einsum("...thd,...Thd->...htT", query_heads, key_heads)
        # attn_logits = attn_logits / np.sqrt(self.query_dim) + new_attn_bias

        # # breakpoint()
        # if attn_bias is not None:
        #     if attn_bias.ndim != attn_logits.ndim:
        #         raise ValueError(f"Mask dimensionality {attn_bias.ndim} must match logits dimensionality {attn_logits.ndim}.")
        #     attn_logits = torch.where(attn_bias, attn_logits, torch.tensor(-1e30))

        # attn_weights = F.softmax(attn_logits, dim=-1)  # [H, T', T]

        # attn = torch.einsum("...htT,...Thd->...thd", attn_weights, value_heads)
        # attn = attn.reshape(batch_size, sequence_length, -1)  # [T', H*V]

        # return self.final_proj(attn)  # [T', D']

    # def _linear_projection(self, x, head_size, proj_layer):
    #     y = proj_layer(x)
    #     *leading_dims, _ = x.shape
    #     return y.reshape((*leading_dims, self.num_heads, head_size))

    def _linear_projection(self, x, head_size, proj_layer):
        y = proj_layer(x)
        batch_size, sequence_length, _ = x.shape
        return y.reshape((batch_size, sequence_length, self.num_heads, head_size)).permute(0, 2, 1, 3)


class AttnBlock(nn.Module):
    def __init__(
        self, hidden_size, num_heads, attn_class: Callable[..., nn.Module] = Attention, mlp_ratio=4.0, **block_kwargs
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)

        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden_dim,
            act_layer=approx_gelu,
            drop=0,
        )

    def forward(self, x, mask=None):
        attn_bias = mask
        if mask is not None:
            mask = (mask[:, None] * mask[:, :, None]).unsqueeze(1).expand(-1, self.attn.heads, -1, -1)
            max_neg_value = -torch.finfo(x.dtype).max
            attn_bias = (~mask) * max_neg_value
        x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
        x = x + self.mlp(self.norm2(x))
        return x

    def forward_noattn(self, x):
        x = x + self.attn.forward_noattn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


def pix2cam(coords, intr, detach=True):
    """
    Args:
        coords: [B, T, N, 3]
        intr: [B, T, 3, 3]
    """
    if detach:
        coords = coords.detach()

    (
        B,
        S,
        N,
        _,
    ) = coords.shape
    xy_src = coords.reshape(B * S * N, 3)
    intr = intr[:, :, None, ...].repeat(1, 1, N, 1, 1).reshape(B * S * N, 3, 3)
    xy_src = torch.cat([xy_src[..., :2], torch.ones_like(xy_src[..., :1])], dim=-1)
    xyz_src = (torch.inverse(intr) @ xy_src[..., None])[..., 0]
    dp_pred = coords[..., 2]
    xyz_src_ = xyz_src * (dp_pred.reshape(S * N, 1))
    xyz_src_ = xyz_src_.reshape(B, S, N, 3)
    return xyz_src_


def cam2pix(coords, intr):
    """
    Args:
        coords: [B, T, N, 3]
        intr: [B, T, 3, 3]
    """
    coords = coords.detach()
    (
        B,
        S,
        N,
        _,
    ) = coords.shape
    xy_src = coords.reshape(B * S * N, 3).clone()
    intr = intr[:, :, None, ...].repeat(1, 1, N, 1, 1).reshape(B * S * N, 3, 3)
    xy_src = xy_src / (xy_src[..., 2:] + 1e-5)
    xyz_src = (intr @ xy_src[..., None])[..., 0]
    dp_pred = coords[..., 2]
    xyz_src[..., 2] *= dp_pred.reshape(S * N)
    xyz_src = xyz_src.reshape(B, S, N, 3)
    return xyz_src


class BroadMultiHeadAttention(nn.Module):
    def __init__(self, dim, heads):
        super(BroadMultiHeadAttention, self).__init__()
        self.dim = dim
        self.heads = heads
        self.scale = (dim / heads) ** -0.5
        self.attend = nn.Softmax(dim=-1)

    def attend_with_rpe(self, Q, K):
        Q = rearrange(Q.squeeze(), "i (heads d) -> heads i d", heads=self.heads)
        K = rearrange(K, "b j (heads d) -> b heads j d", heads=self.heads)

        dots = einsum("hid, bhjd -> bhij", Q, K) * self.scale  # (b hw) heads 1 pointnum

        return self.attend(dots)

    def forward(self, Q, K, V):
        attn = self.attend_with_rpe(Q, K)
        B, _, _ = K.shape
        _, N, _ = Q.shape

        V = rearrange(V, "b j (heads d) -> b heads j d", heads=self.heads)

        out = einsum("bhij, bhjd -> bhid", attn, V)
        out = rearrange(out, "b heads n d -> b n (heads d)", b=B, n=N)

        return out


class CrossAttentionLayer(nn.Module):
    def __init__(
        self,
        qk_dim,
        v_dim,
        query_token_dim,
        tgt_token_dim,
        num_heads=8,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        dropout=0.0,
    ):
        super(CrossAttentionLayer, self).__init__()
        assert qk_dim % num_heads == 0, f"dim {qk_dim} should be divided by num_heads {num_heads}."
        assert v_dim % num_heads == 0, f"dim {v_dim} should be divided by num_heads {num_heads}."
        """
        Query Token: [N, C] -> [N, qk_dim] (Q)
        Target Token: [M, D] -> [M, qk_dim] (K), [M, v_dim] (V)
        """
        self.num_heads = num_heads
        head_dim = qk_dim // num_heads
        self.scale = head_dim**-0.5

        self.norm1 = nn.LayerNorm(query_token_dim)
        self.norm2 = nn.LayerNorm(query_token_dim)
        self.multi_head_attn = BroadMultiHeadAttention(qk_dim, num_heads)
        self.q, self.k, self.v = (
            nn.Linear(query_token_dim, qk_dim, bias=True),
            nn.Linear(tgt_token_dim, qk_dim, bias=True),
            nn.Linear(tgt_token_dim, v_dim, bias=True),
        )

        self.proj = nn.Linear(v_dim, query_token_dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.ffn = nn.Sequential(
            nn.Linear(query_token_dim, query_token_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(query_token_dim, query_token_dim),
            nn.Dropout(dropout),
        )

    def forward(self, query, tgt_token):
        """
        x: [BH1W1, H3W3, D]
        """
        short_cut = query
        query = self.norm1(query)

        q, k, v = self.q(query), self.k(tgt_token), self.v(tgt_token)

        x = self.multi_head_attn(q, k, v)

        x = short_cut + self.proj_drop(self.proj(x))

        x = x + self.drop_path(self.ffn(self.norm2(x)))

        return x


class LayerNormProxy(nn.Module):
    def __init__(self, dim):

        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):

        x = rearrange(x, "b c h w -> b h w c")
        x = self.norm(x)
        return rearrange(x, "b h w c -> b c h w")


def posenc(x, min_deg, max_deg, legacy_posenc_order=False):
    """Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].

    Instead of computing [sin(x), cos(x)], we use the trig identity
    cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).

    Args:
      x: torch.Tensor, variables to be encoded. Note that x should be in [-pi, pi].
      min_deg: int, the minimum (inclusive) degree of the encoding.
      max_deg: int, the maximum (exclusive) degree of the encoding.
      legacy_posenc_order: bool, keep the same ordering as the original tf code.

    Returns:
      encoded: torch.Tensor, encoded variables.
    """
    if min_deg == max_deg:
        return x
    scales = torch.tensor([2**i for i in range(min_deg, max_deg)], dtype=x.dtype, device=x.device)
    if legacy_posenc_order:
        xb = x[..., None, :] * scales[:, None]
        four_feat = torch.reshape(torch.sin(torch.stack([xb, xb + 0.5 * np.pi], dim=-2)), list(x.shape[:-1]) + [-1])
    else:
        xb = torch.reshape((x[..., None, :] * scales[:, None]), list(x.shape[:-1]) + [-1])
        four_feat = torch.sin(torch.cat([xb, xb + 0.5 * np.pi], dim=-1))
    return torch.cat([x] + [four_feat], dim=-1)

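# --- Usage sketch (editor's note, not part of the original file) ---
# posenc concatenates the input with sin/cos features at max_deg - min_deg octaves,
# so the last dimension grows from d to d + 2 * d * (max_deg - min_deg). For example:
#
#   x = torch.rand(1, 100, 3) * 2 * np.pi - np.pi   # values in [-pi, pi]
#   posenc(x, min_deg=0, max_deg=4).shape           # -> (1, 100, 3 + 2*3*4) = (1, 100, 27)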

def gaussian2D2(shape, sigma=(1, 1), rho=0):
    if not isinstance(sigma, tuple):
        sigma = (sigma, sigma)
    sigma_x, sigma_y = sigma

    m, n = [(ss - 1.0) / 2.0 for ss in shape]
    y, x = np.ogrid[-m : m + 1, -n : n + 1]

    energy = (x * x) / (sigma_x * sigma_x) - 2 * rho * x * y / (sigma_x * sigma_y) + (y * y) / (sigma_y * sigma_y)
    h = np.exp(-energy / (2 * (1 - rho * rho)))
    h[h < np.finfo(h.dtype).eps * h.max()] = 0
    return h / h.sum()

models/SpaTrackV2/models/tracker3D/delta_utils/upsample_transformer.py
ADDED
@@ -0,0 +1,438 @@
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import einsum, rearrange, repeat
from jaxtyping import Float, Int64
from torch import Tensor, nn

from models.SpaTrackV2.models.tracker3D.delta_utils.blocks import (
    Attention,
    AttnBlock,
    BasicEncoder,
    CorrBlock,
    Mlp,
    ResidualBlock,
    Upsample,
    cam2pix,
    pix2cam,
)

from models.SpaTrackV2.models.blocks import bilinear_sampler

def get_grid(height, width, shape=None, dtype="torch", device="cpu", align_corners=True, normalize=True):
    H, W = height, width
    S = shape if shape else []
    if align_corners:
        x = torch.linspace(0, 1, W, device=device)
        y = torch.linspace(0, 1, H, device=device)
        if not normalize:
            x = x * (W - 1)
            y = y * (H - 1)
    else:
        x = torch.linspace(0.5 / W, 1.0 - 0.5 / W, W, device=device)
        y = torch.linspace(0.5 / H, 1.0 - 0.5 / H, H, device=device)
        if not normalize:
            x = x * W
            y = y * H
    x_view, y_view, exp = [1 for _ in S] + [1, -1], [1 for _ in S] + [-1, 1], S + [H, W]
    x = x.view(*x_view).expand(*exp)
    y = y.view(*y_view).expand(*exp)
    grid = torch.stack([x, y], dim=-1)
    if dtype == "numpy":
        grid = grid.numpy()
    return grid

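# --- Usage sketch (editor's note, not part of the original file) ---
# get_grid returns per-pixel (x, y) coordinates, normalized to [0, 1] by default,
# stacked along the last dimension:
#
#   grid = get_grid(4, 5)                      # (4, 5, 2), torch.float32
#   grid_px = get_grid(4, 5, normalize=False)  # pixel coordinates, 0..W-1 and 0..H-1
#   grid_b = get_grid(4, 5, shape=[2])         # (2, 4, 5, 2), broadcast over a leading dim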
| 48 |
+
class RelativeAttention(nn.Module):
|
| 49 |
+
"""Multi-headed attention (MHA) module."""
|
| 50 |
+
|
| 51 |
+
def __init__(self, query_dim, num_heads=8, qkv_bias=True, model_size=None, flash=False):
|
| 52 |
+
super(RelativeAttention, self).__init__()
|
| 53 |
+
|
| 54 |
+
query_dim = query_dim // num_heads
|
| 55 |
+
self.num_heads = num_heads
|
| 56 |
+
self.query_dim = query_dim
|
| 57 |
+
self.value_size = query_dim
|
| 58 |
+
self.model_size = query_dim * num_heads
|
| 59 |
+
|
| 60 |
+
self.qkv_bias = qkv_bias
|
| 61 |
+
|
| 62 |
+
self.flash = flash
|
| 63 |
+
|
| 64 |
+
self.query_proj = nn.Linear(num_heads * query_dim, num_heads * query_dim, bias=qkv_bias)
|
| 65 |
+
self.key_proj = nn.Linear(num_heads * query_dim, num_heads * query_dim, bias=qkv_bias)
|
| 66 |
+
self.value_proj = nn.Linear(num_heads * self.value_size, num_heads * self.value_size, bias=qkv_bias)
|
| 67 |
+
self.final_proj = nn.Linear(num_heads * self.value_size, self.model_size, bias=qkv_bias)
|
| 68 |
+
|
| 69 |
+
self.scale = 1.0 / math.sqrt(self.query_dim)
|
| 70 |
+
# self.training_length = 24
|
| 71 |
+
|
| 72 |
+
# bias_forward = get_alibi_slope(self.num_heads // 2) * get_relative_positions(self.training_length)
|
| 73 |
+
# bias_forward = bias_forward + torch.triu(torch.full_like(bias_forward, -1e9), diagonal=1)
|
| 74 |
+
# bias_backward = get_alibi_slope(self.num_heads // 2) * get_relative_positions(self.training_length, reverse=True)
|
| 75 |
+
# bias_backward = bias_backward + torch.tril(torch.full_like(bias_backward, -1e9), diagonal=-1)
|
| 76 |
+
|
| 77 |
+
# self.register_buffer("precomputed_attn_bias", torch.cat([bias_forward, bias_backward], dim=0), persistent=False)
|
| 78 |
+
|
| 79 |
+
def forward(self, x, context, attn_bias=None):
|
| 80 |
+
B, N1, C = x.size()
|
| 81 |
+
|
| 82 |
+
q = self._linear_projection(x, self.query_dim, self.query_proj) # [T', H, Q=K]
|
| 83 |
+
k = self._linear_projection(context, self.query_dim, self.key_proj) # [T, H, K]
|
| 84 |
+
v = self._linear_projection(context, self.value_size, self.value_proj) # [T, H, V]
|
| 85 |
+
|
| 86 |
+
if self.flash:
|
| 87 |
+
with torch.autocast(device_type="cuda", enabled=True):
|
| 88 |
+
x = flash_attn_func(q.half(), k.half(), v.half())
|
| 89 |
+
x = x.reshape(B, N1, C)
|
| 90 |
+
x = x.float()
|
| 91 |
+
else:
|
| 92 |
+
q = q.permute(0, 2, 1, 3)
|
| 93 |
+
k = k.permute(0, 2, 1, 3)
|
| 94 |
+
v = v.permute(0, 2, 1, 3)
|
| 95 |
+
|
| 96 |
+
sim = (q @ k.transpose(-2, -1)) * self.scale
|
| 97 |
+
|
| 98 |
+
if attn_bias is not None:
|
| 99 |
+
sim = sim + attn_bias
|
| 100 |
+
attn = sim.softmax(dim=-1)
|
| 101 |
+
|
| 102 |
+
x = attn @ v
|
| 103 |
+
x = x.transpose(1, 2).reshape(B, N1, C)
|
| 104 |
+
|
| 105 |
+
# with torch.autocast(device_type="cuda", dtype=torch.float32):
|
| 106 |
+
# attn = F.scaled_dot_product_attention(query_heads, key_heads, value_heads, attn_mask=attn_bias, scale=1.0 / math.sqrt(self.query_dim))
|
| 107 |
+
# else:
|
| 108 |
+
|
| 109 |
+
# sim = (query_heads @ key_heads.transpose(-2, -1)) * self.scale
|
| 110 |
+
|
| 111 |
+
# if attn_bias is not None:
|
| 112 |
+
# sim = sim + attn_bias
|
| 113 |
+
# attn = sim.softmax(dim=-1)
|
| 114 |
+
|
| 115 |
+
# attn = (attn @ value_heads)
|
| 116 |
+
# attn = attn.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, -1)
|
| 117 |
+
|
| 118 |
+
return self.final_proj(x) # [T', D']
|
| 119 |
+
|
| 120 |
+
def _linear_projection(self, x, head_size, proj_layer):
|
| 121 |
+
batch_size, sequence_length, _ = x.shape
|
| 122 |
+
y = proj_layer(x)
|
| 123 |
+
y = y.reshape((batch_size, sequence_length, self.num_heads, head_size))
|
| 124 |
+
|
| 125 |
+
return y
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class UpsampleCrossAttnBlock(nn.Module):
|
| 129 |
+
def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
|
| 130 |
+
super().__init__()
|
| 131 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 132 |
+
self.norm_context = nn.LayerNorm(hidden_size)
|
| 133 |
+
self.cross_attn = RelativeAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
|
| 134 |
+
|
| 135 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 136 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 137 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
| 138 |
+
self.mlp = Mlp(
|
| 139 |
+
in_features=hidden_size,
|
| 140 |
+
hidden_features=mlp_hidden_dim,
|
| 141 |
+
act_layer=approx_gelu,
|
| 142 |
+
drop=0,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
def forward(self, x, context, attn_bias=None):
|
| 146 |
+
x = x + self.cross_attn(x=self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias)
|
| 147 |
+
x = x + self.mlp(self.norm2(x))
|
| 148 |
+
return x
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class DecoderUpsampler(nn.Module):
|
| 152 |
+
def __init__(self, in_channels: int, middle_channels: int, out_channels: int = None, stride: int = 4):
|
| 153 |
+
super().__init__()
|
| 154 |
+
|
| 155 |
+
self.stride = stride
|
| 156 |
+
|
| 157 |
+
if out_channels is None:
|
| 158 |
+
out_channels = middle_channels
|
| 159 |
+
|
| 160 |
+
self.conv_in = nn.Conv2d(in_channels, middle_channels, kernel_size=(3, 3), stride=(1, 1), padding=1)
|
| 161 |
+
self.norm1 = nn.GroupNorm(num_groups=middle_channels // 8, num_channels=middle_channels, eps=1e-6)
|
| 162 |
+
|
| 163 |
+
self.res_blocks = nn.ModuleList()
|
| 164 |
+
self.upsample_blocks = nn.ModuleList()
|
| 165 |
+
|
| 166 |
+
for i in range(int(math.log2(self.stride))):
|
| 167 |
+
self.res_blocks.append(ResidualBlock(middle_channels, middle_channels))
|
| 168 |
+
self.upsample_blocks.append(Upsample(middle_channels, with_conv=True))
|
| 169 |
+
|
| 170 |
+
# in_channels = middle_channels
|
| 171 |
+
|
| 172 |
+
self.norm2 = nn.GroupNorm(num_groups=middle_channels // 8, num_channels=middle_channels, eps=1e-6)
|
| 173 |
+
self.conv_out = nn.Conv2d(middle_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=1)
|
| 174 |
+
|
| 175 |
+
self.initialize_weight()
|
| 176 |
+
|
| 177 |
+
def initialize_weight(self):
|
| 178 |
+
def _basic_init(module):
|
| 179 |
+
if isinstance(module, nn.Conv2d):
|
| 180 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 181 |
+
if module.bias is not None:
|
| 182 |
+
nn.init.constant_(module.bias, 0)
|
| 183 |
+
|
| 184 |
+
self.res_blocks.apply(_basic_init)
|
| 185 |
+
self.conv_in.apply(_basic_init)
|
| 186 |
+
self.conv_out.apply(_basic_init)
|
| 187 |
+
|
| 188 |
+
def forward(
|
| 189 |
+
self,
|
| 190 |
+
x: Float[Tensor, "b c1 h_down w_down"],
|
| 191 |
+
mode: str = "nearest",
|
| 192 |
+
) -> Float[Tensor, "b c1 h_up w_up"]:
|
| 193 |
+
|
| 194 |
+
x = F.relu(self.norm1(self.conv_in(x)))
|
| 195 |
+
|
| 196 |
+
for i in range(len(self.res_blocks)):
|
| 197 |
+
x = self.res_blocks[i](x)
|
| 198 |
+
x = self.upsample_blocks[i](x, mode=mode)
|
| 199 |
+
|
| 200 |
+
x = self.conv_out(F.relu(self.norm2(x)))
|
| 201 |
+
return x
|
| 202 |
+
|
| 203 |
+
|
class UpsampleTransformer(nn.Module):
    def __init__(
        self,
        kernel_size: int = 3,
        stride: int = 4,
        latent_dim: int = 128,
        n_heads: int = 4,
        num_attn_blocks: int = 2,
        use_rel_emb: bool = True,
        flash: bool = False,
    ):
        super().__init__()

        self.kernel_size = kernel_size
        self.stride = stride
        self.latent_dim = latent_dim

        self.n_heads = n_heads

        self.attnup_feat_cnn = DecoderUpsampler(
            in_channels=self.latent_dim, middle_channels=self.latent_dim, out_channels=self.latent_dim
        )

        self.cross_blocks = nn.ModuleList(
            [
                UpsampleCrossAttnBlock(latent_dim + 64, latent_dim + 64, num_heads=n_heads, mlp_ratio=4, flash=flash)
                for _ in range(num_attn_blocks)
            ]
        )

        self.flow_mlp = nn.Sequential(
            nn.Conv2d(2 * 16, 128, 7, padding=3),
            nn.ReLU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
        )

        self.out = nn.Linear(latent_dim + 64, kernel_size * kernel_size, bias=True)

        if use_rel_emb:
            self.rpb_attnup = nn.Parameter(torch.zeros(kernel_size * kernel_size))
            torch.nn.init.trunc_normal_(self.rpb_attnup, std=0.1, mean=0.0, a=-2.0, b=2.0)
        else:
            self.rpb_attnup = None

    def forward(
        self,
        feat_map: Float[Tensor, "b c1 h w"],
        flow_map: Float[Tensor, "b c2 h w"],
    ):
        B = feat_map.shape[0]
        H_down, W_down = feat_map.shape[-2:]
        # x0, y0 = x0y0

        feat_map_up = self.attnup_feat_cnn(feat_map)  # learnable upsample by 4
        # feat_map_down = F.interpolate(feat_map_up, scale_factor=1/self.stride, mode='nearest') # B C H*4 W*4
        feat_map_down = feat_map
        # depths_down = F.interpolate(depths, scale_factor=1/self.stride, mode='nearest')

        # NOTE prepare attention bias
        # depths_down_ = torch.stack([depths_down[b, :, y0_:y0_+H_down, x0_:x0_+W_down] for b, (x0_,y0_) in enumerate(zip(x0, y0))], dim=0)
        # depths_ = torch.stack([depths[b, :, y0_*4:y0_*4+H_down*4, x0_*4:x0_*4+W_down*4] for b, (x0_,y0_) in enumerate(zip(x0, y0))], dim=0)
        # guidance_downsample = F.interpolate(guidance, size=(H, W), mode='nearest')
        pad_val = (self.kernel_size - 1) // 2
        # depths_down_padded = F.pad(depths_down_, (pad_val, pad_val, pad_val, pad_val), "replicate")

        if self.rpb_attnup is not None:
            relative_pos_attn_map = self.rpb_attnup.view(1, 1, -1, 1, 1).repeat(
                B, self.n_heads, 1, H_down * 4, W_down * 4
            )
            relative_pos_attn_map = rearrange(relative_pos_attn_map, "b k n h w -> (b h w) k 1 n")
            attn_bias = relative_pos_attn_map
        else:
            attn_bias = None

        # NOTE prepare context (low-reso feat)
        context = feat_map_down
        context = F.unfold(context, kernel_size=self.kernel_size, padding=pad_val)  # B C*kernel**2 H W
        context = rearrange(context, "b c (h w) -> b c h w", h=H_down, w=W_down)
        context = F.interpolate(context, scale_factor=self.stride, mode="nearest")  # B C*kernel**2 H*4 W*4
        context = rearrange(context, "b (c i j) h w -> (b h w) (i j) c", i=self.kernel_size, j=self.kernel_size)

        # NOTE prepare queries (high-reso feat)
        x = feat_map_up
        x = rearrange(x, "b c h w -> (b h w) 1 c")

        assert flow_map.shape[-2:] == feat_map.shape[-2:]

        flow_map = rearrange(flow_map, "b t c h w -> b (t c) h w")
        flow_map = self.flow_mlp(flow_map)

        nn_flow_map = F.unfold(flow_map, kernel_size=self.kernel_size, padding=pad_val)  # B C*kernel**2 H W
        nn_flow_map = rearrange(nn_flow_map, "b c (h w) -> b c h w", h=H_down, w=W_down)
        nn_flow_map = F.interpolate(nn_flow_map, scale_factor=self.stride, mode="nearest")  # B C*kernel**2 H*4 W*4
        nn_flow_map = rearrange(
            nn_flow_map, "b (c i j) h w -> (b h w) (i j) c", i=self.kernel_size, j=self.kernel_size
        )

        up_flow_map = F.interpolate(flow_map, scale_factor=4, mode="nearest")  # NN up # b 2 h w
        up_flow_map = rearrange(up_flow_map, "b c h w -> (b h w) 1 c")

        context = torch.cat([context, nn_flow_map], dim=-1)
        x = torch.cat([x, up_flow_map], dim=-1)

        for lvl in range(len(self.cross_blocks)):
            x = self.cross_blocks[lvl](x, context, attn_bias)

        mask_out = self.out(x)
        mask_out = F.softmax(mask_out, dim=-1)
        mask_out = rearrange(mask_out, "(b h w) 1 c -> b c h w", h=H_down * self.stride, w=W_down * self.stride)

        return mask_out
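The 3x3 softmax weights returned by `forward` above are meant to be consumed as a convex combination over the low-resolution neighborhood, in the spirit of RAFT-style convex upsampling. The helper below is a hedged sketch, not part of this file; `apply_upsample_mask` is a hypothetical name showing one way such a mask can be applied to an arbitrary low-resolution field.

```python
import torch.nn.functional as F
from einops import rearrange

def apply_upsample_mask(field, mask, kernel_size=3, stride=4):
    # field: (B, C, H, W) low-resolution field; mask: (B, k*k, H*stride, W*stride), softmaxed over dim 1.
    B, C, H, W = field.shape
    pad = (kernel_size - 1) // 2
    neighbors = F.unfold(field, kernel_size=kernel_size, padding=pad)          # (B, C*k*k, H*W)
    neighbors = rearrange(neighbors, "b c (h w) -> b c h w", h=H, w=W)         # (B, C*k*k, H, W)
    neighbors = F.interpolate(neighbors, scale_factor=stride, mode="nearest")  # (B, C*k*k, H*s, W*s)
    neighbors = rearrange(neighbors, "b (c n) h w -> b c n h w", c=C)          # (B, C, k*k, H*s, W*s)
    return (neighbors * mask.unsqueeze(1)).sum(dim=2)                          # (B, C, H*s, W*s)
```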
def get_alibi_slope(num_heads):
    x = (24) ** (1 / num_heads)
    return torch.tensor([1 / x ** (i + 1) for i in range(num_heads)]).float()
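For intuition, a quick numeric check of `get_alibi_slope` (illustrative only, not from the commit): the slopes decay geometrically with the head index, as in ALiBi.

```python
print(get_alibi_slope(2))  # ≈ tensor([0.2041, 0.0417]), since x = 24 ** (1/2) ≈ 4.899
print(get_alibi_slope(4))  # ≈ tensor([0.4518, 0.2041, 0.0922, 0.0417])
```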
class UpsampleTransformerAlibi(nn.Module):
    def __init__(
        self,
        kernel_size: int = 3,
        stride: int = 4,
        latent_dim: int = 128,
        n_heads: int = 4,
        num_attn_blocks: int = 2,
        upsample_factor: int = 4,
    ):
        super().__init__()

        self.kernel_size = kernel_size
        self.stride = stride
        self.latent_dim = latent_dim
        self.upsample_factor = upsample_factor

        self.n_heads = n_heads

        self.attnup_feat_cnn = DecoderUpsampler(
            in_channels=self.latent_dim,
            middle_channels=self.latent_dim,
            out_channels=self.latent_dim,
            # stride=self.upsample_factor
        )

        self.cross_blocks = nn.ModuleList(
            [
                UpsampleCrossAttnBlock(
                    latent_dim + 64,
                    latent_dim + 64,
                    num_heads=n_heads,
                    mlp_ratio=4,
                    flash=False
                )
                for _ in range(num_attn_blocks)
            ]
        )

        self.flow_mlp = nn.Sequential(
            nn.Conv2d(3 * 32, 128, 7, padding=3),
            nn.ReLU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
        )

        self.out = nn.Linear(latent_dim + 64, kernel_size * kernel_size, bias=True)

        alibi_slope = get_alibi_slope(n_heads // 2)
        grid_kernel = get_grid(kernel_size, kernel_size, normalize=False).reshape(kernel_size, kernel_size, 2)
        grid_kernel = grid_kernel - (kernel_size - 1) / 2
        grid_kernel = -torch.abs(grid_kernel)
        alibi_bias = torch.cat([
            alibi_slope.view(-1, 1, 1) * grid_kernel[..., 0].view(1, kernel_size, kernel_size),
            alibi_slope.view(-1, 1, 1) * grid_kernel[..., 1].view(1, kernel_size, kernel_size)
        ])  # n_heads, kernel_size, kernel_size

        self.register_buffer("alibi_bias", alibi_bias)

    def forward(
        self,
        feat_map: Float[Tensor, "b c1 h w"],
        flow_map: Float[Tensor, "b c2 h w"],
    ):
        B = feat_map.shape[0]
        H_down, W_down = feat_map.shape[-2:]

        feat_map_up = self.attnup_feat_cnn(feat_map)  # learnable upsample by 4
        if self.upsample_factor != 4:
            additional_scale = float(self.upsample_factor / 4)
            if additional_scale > 1:
                feat_map_up = F.interpolate(feat_map_up, scale_factor=additional_scale, mode='bilinear', align_corners=False)
            else:
                feat_map_up = F.interpolate(feat_map_up, scale_factor=additional_scale, mode='nearest')

        feat_map_down = feat_map

        pad_val = (self.kernel_size - 1) // 2

        attn_bias = self.alibi_bias.view(1, self.n_heads, self.kernel_size**2, 1, 1).repeat(B, 1, 1, H_down * self.upsample_factor, W_down * self.upsample_factor)
        attn_bias = rearrange(attn_bias, "b k n h w -> (b h w) k 1 n")

        # NOTE prepare context (low-reso feat)
        context = feat_map_down
        context = F.unfold(context, kernel_size=self.kernel_size, padding=pad_val)  # B C*kernel**2 H W
        context = rearrange(context, 'b c (h w) -> b c h w', h=H_down, w=W_down)
        context = F.interpolate(context, scale_factor=self.upsample_factor, mode='nearest')  # B C*kernel**2 H*4 W*4
        context = rearrange(context, 'b (c i j) h w -> (b h w) (i j) c', i=self.kernel_size, j=self.kernel_size)

        # NOTE prepare queries (high-reso feat)
        x = feat_map_up
        x = rearrange(x, 'b c h w -> (b h w) 1 c')

        assert flow_map.shape[-2:] == feat_map.shape[-2:]

        flow_map = rearrange(flow_map, 'b t c h w -> b (t c) h w')
        flow_map = self.flow_mlp(flow_map)

        nn_flow_map = F.unfold(flow_map, kernel_size=self.kernel_size, padding=pad_val)  # B C*kernel**2 H W
        nn_flow_map = rearrange(nn_flow_map, 'b c (h w) -> b c h w', h=H_down, w=W_down)
        nn_flow_map = F.interpolate(nn_flow_map, scale_factor=self.upsample_factor, mode='nearest')  # B C*kernel**2 H*4 W*4
        nn_flow_map = rearrange(nn_flow_map, 'b (c i j) h w -> (b h w) (i j) c', i=self.kernel_size, j=self.kernel_size)
        up_flow_map = F.interpolate(flow_map, scale_factor=self.upsample_factor, mode="nearest")  # NN up # b 2 h w
        up_flow_map = rearrange(up_flow_map, 'b c h w -> (b h w) 1 c')
        context = torch.cat([context, nn_flow_map], dim=-1)
        x = torch.cat([x, up_flow_map], dim=-1)
        for lvl in range(len(self.cross_blocks)):
            x = self.cross_blocks[lvl](x, context, attn_bias)

        mask_out = self.out(x)
        mask_out = F.softmax(mask_out, dim=-1)
        mask_out = rearrange(mask_out, '(b h w) 1 c -> b c h w', h=H_down * self.upsample_factor, w=W_down * self.upsample_factor)

        return mask_out
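A hedged usage sketch of `UpsampleTransformerAlibi` (not part of the commit). The flow input is assumed to be a 5D tensor whose time and channel axes multiply to 3*32 = 96, matching the first `nn.Conv2d(3*32, ...)` of `flow_mlp`; the other shapes below are arbitrary.

```python
model = UpsampleTransformerAlibi(kernel_size=3, latent_dim=128, n_heads=4, upsample_factor=4)
feat_map = torch.randn(1, 128, 24, 32)     # low-resolution feature map (B, latent_dim, H, W)
flow_map = torch.randn(1, 32, 3, 24, 32)   # (B, T, C, H, W); flattened to 96 channels inside forward
mask = model(feat_map, flow_map)           # (1, 9, 96, 128): softmax weights over the 3x3 kernel
```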
models/SpaTrackV2/models/tracker3D/spatrack_modules/alignment.py
ADDED
@@ -0,0 +1,471 @@
from typing import *
import math
from collections import namedtuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.types
import utils3d

from models.SpaTrackV2.models.tracker3D.spatrack_modules.geometry_torch import (
    weighted_mean,
    harmonic_mean,
    geometric_mean,
    mask_aware_nearest_resize,
    normalized_view_plane_uv,
    angle_diff_vec3
)

def scatter_min(size: int, dim: int, index: torch.LongTensor, src: torch.Tensor) -> torch.return_types.min:
    "Scatter-reduce the minimum of `src` along the given dimension into a tensor of length `size`, at the indices specified in `index`."
    shape = src.shape[:dim] + (size,) + src.shape[dim + 1:]
    minimum = torch.full(shape, float('inf'), dtype=src.dtype, device=src.device).scatter_reduce(dim=dim, index=index, src=src, reduce='amin', include_self=False)
    minimum_where = torch.where(src == torch.gather(minimum, dim=dim, index=index))
    indices = torch.full(shape, -1, dtype=torch.long, device=src.device)
    indices[(*minimum_where[:dim], index[minimum_where], *minimum_where[dim + 1:])] = minimum_where[dim]
    return torch.return_types.min((minimum, indices))


def split_batch_fwd(fn: Callable, chunk_size: int, *args, **kwargs):
    batch_size = next(x for x in (*args, *kwargs.values()) if isinstance(x, torch.Tensor)).shape[0]
    n_chunks = batch_size // chunk_size + (batch_size % chunk_size > 0)
    splited_args = tuple(arg.split(chunk_size, dim=0) if isinstance(arg, torch.Tensor) else [arg] * n_chunks for arg in args)
    splited_kwargs = {k: v.split(chunk_size, dim=0) if isinstance(v, torch.Tensor) else [v] * n_chunks for k, v in kwargs.items()}
    results = []
    for i in range(n_chunks):
        chunk_args = tuple(arg[i] for arg in splited_args)
        chunk_kwargs = {k: v[i] for k, v in splited_kwargs.items()}
        results.append(fn(*chunk_args, **chunk_kwargs))

    if isinstance(results[0], tuple):
        return tuple(torch.cat(r, dim=0) for r in zip(*results))
    else:
        return torch.cat(results, dim=0)
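A small illustration of `scatter_min` (not part of the commit): it returns both the per-bucket minimum of `src` and the position inside `src` where that minimum was found.

```python
src = torch.tensor([3.0, 1.0, 5.0, 2.0])
index = torch.tensor([0, 0, 1, 1])
values, argmins = scatter_min(size=2, dim=0, index=index, src=src)
# values  -> tensor([1., 2.])  (minimum of each bucket)
# argmins -> tensor([1, 3])    (positions of those minima in src)
```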
def _pad_inf(x_: torch.Tensor):
    return torch.cat([torch.full_like(x_[..., :1], -torch.inf), x_, torch.full_like(x_[..., :1], torch.inf)], dim=-1)


def _pad_cumsum(cumsum: torch.Tensor):
    return torch.cat([torch.zeros_like(cumsum[..., :1]), cumsum, cumsum[..., -1:]], dim=-1)


def _compute_residual(a: torch.Tensor, xyw: torch.Tensor, trunc: float):
    return a.mul(xyw[..., 0]).sub_(xyw[..., 1]).abs_().mul_(xyw[..., 2]).clamp_max_(trunc).sum(dim=-1)


def align(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, trunc: Optional[Union[float, torch.Tensor]] = None, eps: float = 1e-7) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
    """
    If trunc is None, solve `min sum_i w_i * |a * x_i - y_i|`, otherwise solve `min sum_i min(trunc, w_i * |a * x_i - y_i|)`.

    w_i must be >= 0.

    ### Parameters:
    - `x`: tensor of shape (..., n)
    - `y`: tensor of shape (..., n)
    - `w`: tensor of shape (..., n)
    - `trunc`: optional, float or tensor of shape (..., n) or None

    ### Returns:
    - `a`: tensor of shape (...), differentiable
    - `loss`: tensor of shape (...), value of loss function at `a`, detached
    - `index`: tensor of shape (...), where a = y[idx] / x[idx]
    """
    if trunc is None:
        x, y, w = torch.broadcast_tensors(x, y, w)
        sign = torch.sign(x)
        x, y = x * sign, y * sign
        y_div_x = y / x.clamp_min(eps)
        y_div_x, argsort = y_div_x.sort(dim=-1)

        wx = torch.gather(x * w, dim=-1, index=argsort)
        derivatives = 2 * wx.cumsum(dim=-1) - wx.sum(dim=-1, keepdim=True)
        search = torch.searchsorted(derivatives, torch.zeros_like(derivatives[..., :1]), side='left').clamp_max(derivatives.shape[-1] - 1)

        a = y_div_x.gather(dim=-1, index=search).squeeze(-1)
        index = argsort.gather(dim=-1, index=search).squeeze(-1)
        loss = (w * (a[..., None] * x - y).abs()).sum(dim=-1)

    else:
        # Reshape to (batch_size, n) for simplicity
        x, y, w = torch.broadcast_tensors(x, y, w)
        batch_shape = x.shape[:-1]
        batch_size = math.prod(batch_shape)
        x, y, w = x.reshape(-1, x.shape[-1]), y.reshape(-1, y.shape[-1]), w.reshape(-1, w.shape[-1])

        sign = torch.sign(x)
        x, y = x * sign, y * sign
        wx, wy = w * x, w * y
        xyw = torch.stack([x, y, w], dim=-1)  # Stacked for convenient gathering

        y_div_x = A = y / x.clamp_min(eps)
        B = (wy - trunc) / wx.clamp_min(eps)
        C = (wy + trunc) / wx.clamp_min(eps)
        with torch.no_grad():
            # Calculate prefix sums in the orders of A, B, C
            A, A_argsort = A.sort(dim=-1)
            Q_A = torch.cumsum(torch.gather(wx, dim=-1, index=A_argsort), dim=-1)
            A, Q_A = _pad_inf(A), _pad_cumsum(Q_A)  # Pad [-inf, A1, ..., An, inf] and [0, Q1, ..., Qn, Qn] to handle edge cases.

            B, B_argsort = B.sort(dim=-1)
            Q_B = torch.cumsum(torch.gather(wx, dim=-1, index=B_argsort), dim=-1)
            B, Q_B = _pad_inf(B), _pad_cumsum(Q_B)

            C, C_argsort = C.sort(dim=-1)
            Q_C = torch.cumsum(torch.gather(wx, dim=-1, index=C_argsort), dim=-1)
            C, Q_C = _pad_inf(C), _pad_cumsum(Q_C)

            # Calculate left and right derivatives at A
            j_A = torch.searchsorted(A, y_div_x, side='left').sub_(1)
            j_B = torch.searchsorted(B, y_div_x, side='left').sub_(1)
            j_C = torch.searchsorted(C, y_div_x, side='left').sub_(1)
            left_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)
            j_A = torch.searchsorted(A, y_div_x, side='right').sub_(1)
            j_B = torch.searchsorted(B, y_div_x, side='right').sub_(1)
            j_C = torch.searchsorted(C, y_div_x, side='right').sub_(1)
            right_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)

            # Find extrema
            is_extrema = (left_derivative < 0) & (right_derivative >= 0)
            is_extrema[..., 0] |= ~is_extrema.any(dim=-1)  # In case all derivatives are zero, take the first one as extrema.
            where_extrema_batch, where_extrema_index = torch.where(is_extrema)

            # Calculate objective value at extrema
            extrema_a = y_div_x[where_extrema_batch, where_extrema_index]  # (num_extrema,)
            MAX_ELEMENTS = 4096 ** 2  # Split into small batches to avoid OOM in case there are too many extrema. (~1G)
            SPLIT_SIZE = MAX_ELEMENTS // x.shape[-1]
            extrema_value = torch.cat([
                _compute_residual(extrema_a_split[:, None], xyw[extrema_i_split, :, :], trunc)
                for extrema_a_split, extrema_i_split in zip(extrema_a.split(SPLIT_SIZE), where_extrema_batch.split(SPLIT_SIZE))
            ])  # (num_extrema,)

        # Find minima among corresponding extrema
        minima, indices = scatter_min(size=batch_size, dim=0, index=where_extrema_batch, src=extrema_value)  # (batch_size,)
        index = where_extrema_index[indices]

        a = torch.gather(y, dim=-1, index=index[..., None]) / torch.gather(x, dim=-1, index=index[..., None]).clamp_min(eps)
        a = a.reshape(batch_shape)
        loss = minima.reshape(batch_shape)
        index = index.reshape(batch_shape)

    return a, loss, index
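A worked example of `align` (illustrative, not from the commit): with unit weights and `trunc=None`, the optimal scale is a weighted median of the ratios `y_i / x_i`, which makes the fit robust to a single outlier.

```python
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
y = torch.tensor([[2.0, 4.0, 6.0, 100.0]])   # true scale 2, last pair is an outlier
a, loss, idx = align(x, y, torch.ones_like(x))
# a ≈ 2.0: the L1 / weighted-median fit ignores the outlier,
# whereas a least-squares fit would be pulled far above 2.
```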
def align_depth_scale(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Align `depth_src` to `depth_tgt` with a single scale, using the given constant weights.

    ### Parameters:
    - `depth_src: torch.Tensor` of shape (..., N)
    - `depth_tgt: torch.Tensor` of shape (..., N)
    - `weight: torch.Tensor` of shape (..., N)

    ### Returns:
    - `scale: torch.Tensor` of shape (...)
    """
    scale, _, _ = align(depth_src, depth_tgt, weight, trunc)

    return scale


def align_depth_affine(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Align `depth_src` to `depth_tgt` with a scale and shift, using the given constant weights.

    ### Parameters:
    - `depth_src: torch.Tensor` of shape (..., N)
    - `depth_tgt: torch.Tensor` of shape (..., N)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc: float` or tensor of shape (..., N) or None

    ### Returns:
    - `scale: torch.Tensor` of shape (...).
    - `shift: torch.Tensor` of shape (...).
    """
    dtype, device = depth_src.dtype, depth_src.device

    # Flatten batch dimensions for simplicity
    batch_shape, n = depth_src.shape[:-1], depth_src.shape[-1]
    batch_size = math.prod(batch_shape)
    depth_src, depth_tgt, weight = depth_src.reshape(batch_size, n), depth_tgt.reshape(batch_size, n), weight.reshape(batch_size, n)

    # Here, we take anchors only for non-zero weights.
    # Although the results would still be correct even if anchor points had zero weight,
    # it wastes computation and may cause instability in some cases, e.g. too many extrema.
    anchors_where_batch, anchors_where_n = torch.where(weight > 0)

    # Stop gradient when solving optimal anchors
    with torch.no_grad():
        depth_src_anchor = depth_src[anchors_where_batch, anchors_where_n]  # (anchors,)
        depth_tgt_anchor = depth_tgt[anchors_where_batch, anchors_where_n]  # (anchors,)

        depth_src_anchored = depth_src[anchors_where_batch, :] - depth_src_anchor[..., None]  # (anchors, n)
        depth_tgt_anchored = depth_tgt[anchors_where_batch, :] - depth_tgt_anchor[..., None]  # (anchors, n)
        weight_anchored = weight[anchors_where_batch, :]  # (anchors, n)

        scale, loss, index = align(depth_src_anchored, depth_tgt_anchored, weight_anchored, trunc)  # (anchors,)

        loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchors_where_batch, src=loss)  # (batch_size,)

    # Reproduce by indexing for a shorter compute graph
    index_1 = anchors_where_n[index_anchor]  # (batch_size,)
    index_2 = index[index_anchor]  # (batch_size,)

    tgt_1, src_1 = torch.gather(depth_tgt, dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_1[..., None]).squeeze(-1)
    tgt_2, src_2 = torch.gather(depth_tgt, dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_2[..., None]).squeeze(-1)

    scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1e-7)
    shift = tgt_1 - scale * src_1

    scale, shift = scale.reshape(batch_shape), shift.reshape(batch_shape)

    return scale, shift

def align_depth_affine_irls(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], max_iter: int = 100, eps: float = 1e-12):
    """
    Align `depth_src` to `depth_tgt` with given constant weights using IRLS.
    """
    dtype, device = depth_src.dtype, depth_src.device

    w = weight
    x = torch.stack([depth_src, torch.ones_like(depth_src)], dim=-1)
    y = depth_tgt

    for i in range(max_iter):
        beta = (x.transpose(-1, -2) @ (w * y)) @ (x.transpose(-1, -2) @ (w[..., None] * x)).inverse().transpose(-2, -1)
        w = 1 / (y - (x @ beta[..., None])[..., 0]).abs().clamp_min(eps)

    return beta[..., 0], beta[..., 1]
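An illustrative use of `align_depth_affine` (not part of the commit): recover a per-sample scale and shift that maps an affine-invariant depth prediction onto reference depth, then apply it.

```python
pred = torch.rand(2, 1024)             # predicted depth, flattened to (B, N)
gt = 3.0 * pred + 0.5                  # reference depth (here an exact affine transform)
weight = torch.ones_like(pred)         # e.g. a validity mask or confidence
scale, shift = align_depth_affine(pred, gt, weight, trunc=1.0)
aligned = scale[..., None] * pred + shift[..., None]   # scale ≈ 3.0, shift ≈ 0.5
```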
def align_points_scale(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)

    ### Returns:
    - `a: torch.Tensor` of shape (...). Only positive solutions are guaranteed. You should filter out negative scales before using it.
    - `b: torch.Tensor` of shape (...)
    """
    dtype, device = points_src.dtype, points_src.device

    scale, _, _ = align(points_src.flatten(-2), points_tgt.flatten(-2), weight[..., None].expand_as(points_src).flatten(-2), trunc)

    return scale


def align_points_scale_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Align `points_src` to `points_tgt` with respect to a shared xyz scale and z shift.
    It is similar to `align_affine` but scale and shift are applied to different dimensions.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)

    ### Returns:
    - `scale: torch.Tensor` of shape (...).
    - `shift: torch.Tensor` of shape (..., 3). x and y shifts are zeros.
    """
    dtype, device = points_src.dtype, points_src.device

    # Flatten batch dimensions for simplicity
    batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
    batch_size = math.prod(batch_shape)
    points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)

    # Take anchors
    anchor_where_batch, anchor_where_n = torch.where(weight > 0)
    with torch.no_grad():
        zeros = torch.zeros(anchor_where_batch.shape[0], device=device, dtype=dtype)
        points_src_anchor = torch.stack([zeros, zeros, points_src[anchor_where_batch, anchor_where_n, 2]], dim=-1)  # (anchors, 3)
        points_tgt_anchor = torch.stack([zeros, zeros, points_tgt[anchor_where_batch, anchor_where_n, 2]], dim=-1)  # (anchors, 3)

        points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :]  # (anchors, n, 3)
        points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :]  # (anchors, n, 3)
        weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3)  # (anchors, n, 3)

        # Solve optimal scale and shift for each anchor
        MAX_ELEMENTS = 2 ** 20
        scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // n, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc)  # (anchors,)

        loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss)  # (batch_size,)

    # Reproduce by indexing for a shorter compute graph
    index_2 = index[index_anchor]  # (batch_size,) in [0, 3n)
    index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3  # (batch_size,) in [0, 3n)

    zeros = torch.zeros((batch_size, n), device=device, dtype=dtype)
    points_tgt_00z, points_src_00z = torch.stack([zeros, zeros, points_tgt[..., 2]], dim=-1), torch.stack([zeros, zeros, points_src[..., 2]], dim=-1)
    tgt_1, src_1 = torch.gather(points_tgt_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_src_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
    tgt_2, src_2 = torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)

    scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
    shift = torch.gather(points_tgt_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
    scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)

    return scale, shift


def align_points_scale_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
    """
    Align `points_src` to `points_tgt` with respect to a shared xyz scale and xyz shift.
    It is similar to `align_affine` but scale and shift are applied to different dimensions.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)

    ### Returns:
    - `scale: torch.Tensor` of shape (...).
    - `shift: torch.Tensor` of shape (..., 3)
    """
    dtype, device = points_src.dtype, points_src.device

    # Flatten batch dimensions for simplicity
    batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
    batch_size = math.prod(batch_shape)
    points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)

    # Take anchors
    anchor_where_batch, anchor_where_n = torch.where(weight > 0)

    with torch.no_grad():
        points_src_anchor = points_src[anchor_where_batch, anchor_where_n]  # (anchors, 3)
        points_tgt_anchor = points_tgt[anchor_where_batch, anchor_where_n]  # (anchors, 3)

        points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :]  # (anchors, n, 3)
        points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :]  # (anchors, n, 3)
        weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3)  # (anchors, n, 3)

        # Solve optimal scale and shift for each anchor
        MAX_ELEMENTS = 2 ** 20
        scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // 2, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc)  # (anchors,)

        # Get optimal scale and shift for each batch element
        loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss)  # (batch_size,)

    index_2 = index[index_anchor]  # (batch_size,) in [0, 3n)
    index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3  # (batch_size,) in [0, 3n)

    src_1, tgt_1 = torch.gather(points_src.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
    src_2, tgt_2 = torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)

    scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
    shift = torch.gather(points_tgt, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)

    scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)

    return scale, shift


def align_points_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
    """
    Align `points_src` to `points_tgt` with respect to a Z-axis shift.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)

    ### Returns:
    - `shift: torch.Tensor` of shape (..., 3). Only the z component is non-zero.
    """
    dtype, device = points_src.dtype, points_src.device

    shift, _, _ = align(torch.ones_like(points_src[..., 2]), points_tgt[..., 2] - points_src[..., 2], weight, trunc)
    shift = torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)

    return shift


def align_points_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
    """
    Align `points_src` to `points_tgt` with respect to a per-axis (x, y, z) shift.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)

    ### Returns:
    - `shift: torch.Tensor` of shape (..., 3)
    """
    dtype, device = points_src.dtype, points_src.device

    shift, _, _ = align(torch.ones_like(points_src).swapaxes(-2, -1), (points_tgt - points_src).swapaxes(-2, -1), weight[..., None, :], trunc)

    return shift


def align_affine_lstsq(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Solve `min sum_i w_i * (a * x_i + b - y_i) ^ 2`, where `a` and `b` are scalars, with respect to `a` and `b` using least squares.

    ### Parameters:
    - `x: torch.Tensor` of shape (..., N)
    - `y: torch.Tensor` of shape (..., N)
    - `w: torch.Tensor` of shape (..., N)

    ### Returns:
    - `a: torch.Tensor` of shape (...,)
    - `b: torch.Tensor` of shape (...,)
    """
    w_sqrt = torch.ones_like(x) if w is None else w.sqrt()
    A = torch.stack([w_sqrt * x, torch.ones_like(x)], dim=-1)
    B = (w_sqrt * y)[..., None]
    a, b = torch.linalg.lstsq(A, B)[0].squeeze(-1).unbind(-1)
    return a, b

def _smooth(err: torch.FloatTensor, beta: float = 0.0) -> torch.FloatTensor:
    if beta == 0:
        return err
    else:
        return torch.where(err < beta, 0.5 * err.square() / beta, err - 0.5 * beta)

def affine_invariant_global_loss(
    pred_points: torch.Tensor,
    gt_points: torch.Tensor,
    mask: torch.Tensor,
    align_resolution: int = 64,
    beta: float = 0.0,
    trunc: float = 1.0,
    sparsity_aware: bool = False,
    only_align: bool = False
):
    device = pred_points.device

    # Align
    (pred_points_lr, gt_points_lr), lr_mask = mask_aware_nearest_resize((pred_points, gt_points), mask=mask, size=(align_resolution, align_resolution))
    scale, shift = align_points_scale_z_shift(pred_points_lr.flatten(-3, -2), gt_points_lr.flatten(-3, -2), lr_mask.flatten(-2, -1) / gt_points_lr[..., 2].flatten(-2, -1).clamp_min(1e-2), trunc=trunc)
    valid = scale > 0
    scale, shift = torch.where(valid, scale, 0), torch.where(valid[..., None], shift, 0)

    pred_points = scale[..., None, None, None] * pred_points + shift[..., None, None, :]
    if only_align:
        return pred_points, scale, shift
    # Compute loss
    weight = (valid[..., None, None] & mask).float() / gt_points[..., 2].clamp_min(1e-5)
    weight = weight.clamp_max(10.0 * weighted_mean(weight, mask, dim=(-2, -1), keepdim=True))  # In case your data contains extremely small depth values
    loss = _smooth((pred_points - gt_points).abs() * weight[..., None], beta=beta).mean(dim=(-3, -2, -1))

    if sparsity_aware:
        # Reweighting improves performance on sparse depth data. NOTE: this is not used in MoGe-1.
        sparsity = mask.float().mean(dim=(-2, -1)) / lr_mask.float().mean(dim=(-2, -1))
        loss = loss / (sparsity + 1e-7)

    err = (pred_points.detach() - gt_points).norm(dim=-1) / gt_points[..., 2]

    # Record any scalar metric
    misc = {
        'truncated_error': weighted_mean(err.clamp_max(1.0), mask).item(),
        'delta': weighted_mean((err < 1).float(), mask).item()
    }

    return loss, misc, scale.detach(), shift.detach()
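A hedged usage sketch of `affine_invariant_global_loss` (not part of the commit). Point maps are assumed to be channel-last `(B, H, W, 3)` with a boolean validity mask, following the reductions inside the function; `mask_aware_nearest_resize` comes from the geometry_torch module imported at the top of this file.

```python
pred_points = torch.rand(2, 64, 64, 3, requires_grad=True)   # predicted point map
gt_points = torch.rand(2, 64, 64, 3)                          # reference point map
mask = torch.ones(2, 64, 64, dtype=torch.bool)                # valid-pixel mask
loss, misc, scale, shift = affine_invariant_global_loss(pred_points, gt_points, mask)
loss.mean().backward()   # gradients flow through the aligned prediction, not the alignment solve
```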
models/SpaTrackV2/models/tracker3D/spatrack_modules/ba.py
ADDED
@@ -0,0 +1,538 @@
import pycolmap
import torch
import numpy as np
import pyceres
from pyceres import SolverOptions, LinearSolverType, PreconditionerType, TrustRegionStrategyType, LoggingType
import logging
from scipy.spatial.transform import Rotation as R

# configure logging and make sure it prints to the console
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def extract_static_from_3DTracks(world_tracks, dyn_prob,
                                 query_3d_pts, vis_est, tracks2d, img_size, K=100, maintain_invisb=False):
    """
    world_tracks: B T N 3   coarse 3d tracks in world coordinates
    dyn_prob:     B T N     dynamic probability of the 3d tracks
    query_3d_pts: B T N 3   query 3d points in world coordinates (coarse, from camera pose)
    vis_est:      B T N     visibility of the 3d tracks
    tracks2d:     B T N 2   2d tracks
    K:            int       number of top static points to keep
    """
    B, T, N, _ = world_tracks.shape
    static_msk = (dyn_prob < 0.5).bool()
    world_tracks_static = world_tracks[:, :, static_msk.squeeze(), :]
    query_3d_pts_static = query_3d_pts[:, static_msk.squeeze(), :]
    if maintain_invisb:
        vis = (tracks2d[..., 0] > 0).bool() * (tracks2d[..., 1] > 0).bool()
        vis_mask = vis * (img_size[1] > tracks2d[..., 0]) * (img_size[0] > tracks2d[..., 1])
        vis_mask = vis_mask[:, :, static_msk.squeeze()]
    else:
        vis_mask = (vis_est > 0.5).bool()[:, :, static_msk.squeeze()]
    tracks2d_static = tracks2d[:, :, static_msk.squeeze(), :]
    world_tracks_static = (world_tracks_static * vis_mask[..., None]).sum(dim=1) / (vis_mask.sum(dim=1)[..., None] + 1e-6)
    # get the distance between the query_3d_pts_static and the world_tracks_static
    dist = (query_3d_pts_static - world_tracks_static).norm(dim=-1)
    # get the top K static points, which have the smallest distance
    topk_idx = torch.argsort(dist, dim=-1)[:, :K]
    world_tracks_static = world_tracks_static[torch.arange(B)[:, None, None], topk_idx]
    query_3d_pts_static = query_3d_pts_static[torch.arange(B)[:, None, None], topk_idx]
    # get the visible selected
    vis_mask_static = vis_mask[:, :, topk_idx.squeeze()]
    tracks2d_static = tracks2d_static[:, :, topk_idx.squeeze(), :]

    return world_tracks_static, static_msk, topk_idx, vis_mask_static, tracks2d_static
+
def log_ba_summary(summary):
|
| 47 |
+
logging.info(f"Residuals : {summary.num_residuals_reduced}")
|
| 48 |
+
if summary.num_residuals_reduced > 0:
|
| 49 |
+
logging.info(f"Parameters : {summary.num_effective_parameters_reduced}")
|
| 50 |
+
logging.info(
|
| 51 |
+
f"Iterations : {summary.num_successful_steps + summary.num_unsuccessful_steps}"
|
| 52 |
+
)
|
| 53 |
+
logging.info(f"Time : {summary.total_time_in_seconds} [s]")
|
| 54 |
+
logging.info(
|
| 55 |
+
f"Initial cost : {np.sqrt(summary.initial_cost / summary.num_residuals_reduced)} [px]"
|
| 56 |
+
)
|
| 57 |
+
logging.info(
|
| 58 |
+
f"Final cost : {np.sqrt(summary.final_cost / summary.num_residuals_reduced)} [px]"
|
| 59 |
+
)
|
| 60 |
+
return True
|
| 61 |
+
else:
|
| 62 |
+
print("No residuals reduced")
|
| 63 |
+
return False
|
| 64 |
+
|
| 65 |
+
# def solve_bundle_adjustment(reconstruction, ba_options, ba_config):
|
| 66 |
+
# bundle_adjuster = pycolmap.BundleAdjuster(ba_options, ba_config)
|
| 67 |
+
# bundle_adjuster.set_up_problem(
|
| 68 |
+
# reconstruction, ba_options.create_loss_function()
|
| 69 |
+
# )
|
| 70 |
+
# solver_options = bundle_adjuster.set_up_solver_options(
|
| 71 |
+
# bundle_adjuster.problem, ba_options.solver_options
|
| 72 |
+
# )
|
| 73 |
+
# summary = pyceres.SolverSummary()
|
| 74 |
+
# pyceres.solve(solver_options, bundle_adjuster.problem, summary)
|
| 75 |
+
# return summary
|
| 76 |
+
|
| 77 |
+
def efficient_solver(solver_options, stability_mode=True):
|
| 78 |
+
# Set linear solver to ITERATIVE_SCHUR (using PCG to solve Schur complement)
|
| 79 |
+
solver_options.linear_solver_type = LinearSolverType.ITERATIVE_SCHUR
|
| 80 |
+
|
| 81 |
+
# Set preconditioner (critical for PCG)
|
| 82 |
+
solver_options.preconditioner_type = PreconditionerType.SCHUR_JACOBI
|
| 83 |
+
|
| 84 |
+
# Optimize trust region strategy
|
| 85 |
+
solver_options.trust_region_strategy_type = TrustRegionStrategyType.LEVENBERG_MARQUARDT
|
| 86 |
+
|
| 87 |
+
# Enable multi-threading acceleration
|
| 88 |
+
solver_options.num_threads = 32 # Adjust based on CPU cores
|
| 89 |
+
|
| 90 |
+
if stability_mode:
|
| 91 |
+
# Stability-first configuration
|
| 92 |
+
solver_options.initial_trust_region_radius = 1.0 # Reduce initial step size
|
| 93 |
+
solver_options.max_trust_region_radius = 10.0 # Limit max step size
|
| 94 |
+
solver_options.min_trust_region_radius = 1e-6 # Allow small step convergence
|
| 95 |
+
|
| 96 |
+
# Increase regularization parameters
|
| 97 |
+
solver_options.use_nonmonotonic_steps = True # Allow non-monotonic steps
|
| 98 |
+
solver_options.max_consecutive_nonmonotonic_steps = 10
|
| 99 |
+
|
| 100 |
+
# Adjust iteration termination conditions
|
| 101 |
+
solver_options.max_num_iterations = 100 # Increase max iterations
|
| 102 |
+
solver_options.function_tolerance = 1e-8 # Stricter function convergence
|
| 103 |
+
solver_options.gradient_tolerance = 1e-12 # Stricter gradient convergence
|
| 104 |
+
solver_options.parameter_tolerance = 1e-10 # Stricter parameter convergence
|
| 105 |
+
|
| 106 |
+
# Control PCG iterations and precision
|
| 107 |
+
solver_options.min_linear_solver_iterations = 10
|
| 108 |
+
solver_options.max_linear_solver_iterations = 100
|
| 109 |
+
solver_options.inner_iteration_tolerance = 0.01 # Higher inner iteration precision
|
| 110 |
+
|
| 111 |
+
# Increase damping factor
|
| 112 |
+
solver_options.min_lm_diagonal = 1e-3 # Increase min LM diagonal
|
| 113 |
+
solver_options.max_lm_diagonal = 1e+10 # Limit max LM diagonal
|
| 114 |
+
|
| 115 |
+
# Enable parameter change limits
|
| 116 |
+
solver_options.update_state_every_iteration = True # Update state each iteration
|
| 117 |
+
|
| 118 |
+
else:
|
| 119 |
+
# Efficiency-first configuration (original settings)
|
| 120 |
+
solver_options.initial_trust_region_radius = 10000.0
|
| 121 |
+
solver_options.max_trust_region_radius = 1e+16
|
| 122 |
+
solver_options.max_num_iterations = 50
|
| 123 |
+
solver_options.function_tolerance = 1e-6
|
| 124 |
+
solver_options.gradient_tolerance = 1e-10
|
| 125 |
+
solver_options.parameter_tolerance = 1e-8
|
| 126 |
+
solver_options.min_linear_solver_iterations = 5
|
| 127 |
+
solver_options.max_linear_solver_iterations = 50
|
| 128 |
+
solver_options.inner_iteration_tolerance = 0.1
|
| 129 |
+
|
| 130 |
+
# Enable Jacobi scaling for better numerical stability
|
| 131 |
+
solver_options.jacobi_scaling = True
|
| 132 |
+
|
| 133 |
+
# Disable verbose logging for better performance (enable for debugging)
|
| 134 |
+
solver_options.logging_type = LoggingType.SILENT
|
| 135 |
+
solver_options.minimizer_progress_to_stdout = False
|
| 136 |
+
|
| 137 |
+
return solver_options
|
| 138 |
+
|
| 139 |
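An illustrative sketch of how `efficient_solver` is meant to be used (not part of the commit): it mutates a pyceres `SolverOptions` in place and returns it, so it can be dropped in right before `pyceres.solve`. Constructing `SolverOptions()` directly is an assumption here; in `solve_bundle_adjustment` below the options instead come from `ba_options.create_solver_options`.

```python
options = efficient_solver(SolverOptions(), stability_mode=True)  # SolverOptions imported from pyceres above
summary = pyceres.SolverSummary()
# pyceres.solve(options, problem, summary)   # `problem` would be a pyceres.Problem built elsewhere
# log_ba_summary(summary)
```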
+
class SpatTrackCost_static(pyceres.CostFunction):
|
| 140 |
+
def __init__(self, observed_depth):
|
| 141 |
+
"""
|
| 142 |
+
observed_depth: float
|
| 143 |
+
"""
|
| 144 |
+
super().__init__()
|
| 145 |
+
self.observed_depth = float(observed_depth)
|
| 146 |
+
self.set_num_residuals(1)
|
| 147 |
+
self.set_parameter_block_sizes([4, 3, 3]) # [rotation_quat, translation, xyz]
|
| 148 |
+
|
| 149 |
+
def Evaluate(self, parameters, residuals, jacobians):
|
| 150 |
+
# Unpack parameters
|
| 151 |
+
quat = parameters[0] # shape: (4,) [w, x, y, z]
|
| 152 |
+
t = parameters[1] # shape: (3,)
|
| 153 |
+
point = parameters[2] # shape: (3,)
|
| 154 |
+
|
| 155 |
+
# Convert COLMAP-style quat [w, x, y, z] to scipy format [x, y, z, w]
|
| 156 |
+
r = R.from_quat([quat[1], quat[2], quat[3], quat[0]])
|
| 157 |
+
R_mat = r.as_matrix() # (3, 3)
|
| 158 |
+
|
| 159 |
+
# Transform point to camera frame
|
| 160 |
+
X_cam = R_mat @ point + t
|
| 161 |
+
z = X_cam[2]
|
| 162 |
+
|
| 163 |
+
# Compute residual (normalized depth error)
|
| 164 |
+
residuals[0] = 20.0 * (z - self.observed_depth) / self.observed_depth
|
| 165 |
+
|
| 166 |
+
if jacobians is not None:
|
| 167 |
+
if jacobians[2] is not None:
|
| 168 |
+
# dr/d(point3D): only z-axis matters, so only 3rd row of R
|
| 169 |
+
jacobians[2][0] = 20.0 * R_mat[2, 0] / self.observed_depth
|
| 170 |
+
jacobians[2][1] = 20.0 * R_mat[2, 1] / self.observed_depth
|
| 171 |
+
jacobians[2][2] = 20.0 * R_mat[2, 2] / self.observed_depth
|
| 172 |
+
|
| 173 |
+
if jacobians[1] is not None:
|
| 174 |
+
# dr/dt = ∂residual/∂translation = d(z)/dt = [0, 0, 1]
|
| 175 |
+
jacobians[1][0] = 0.0
|
| 176 |
+
jacobians[1][1] = 0.0
|
| 177 |
+
jacobians[1][2] = 20.0 / self.observed_depth
|
| 178 |
+
|
| 179 |
+
if jacobians[0] is not None:
|
| 180 |
+
# Optional: dr/d(quat) — not trivial to derive, can be left for autodiff if needed
|
| 181 |
+
# Set zero for now (not ideal but legal)
|
| 182 |
+
jacobians[0][:] = 0.0
|
| 183 |
+
|
| 184 |
+
return True
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class SpatTrackCost_dynamic(pyceres.CostFunction):
|
| 188 |
+
|
| 189 |
+
def __init__(self, observed_uv, image, point3D, camera):
|
| 190 |
+
"""
|
| 191 |
+
observed_uv: 1 1 K 2 this is the 2d tracks
|
| 192 |
+
image: pycolmap.Image object
|
| 193 |
+
point3D: pycolmap.Point3D object
|
| 194 |
+
camera: pycolmap.Camera object
|
| 195 |
+
"""
|
| 196 |
+
sizes = [image.cam_from_world.params.shape[0], point3D.xyz.shape[0], camera.params.shape[0]]
|
| 197 |
+
super().__init__(self, residual_size=2, parameter_block_sizes=sizes)
|
| 198 |
+
self.observed_uv = observed_uv
|
| 199 |
+
self.image = image
|
| 200 |
+
self.point3D = point3D
|
| 201 |
+
self.camera = camera
|
| 202 |
+
|
| 203 |
+
def solve_bundle_adjustment(reconstruction, ba_options,
|
| 204 |
+
ba_config=None, extra_residual=None):
|
| 205 |
+
"""
|
| 206 |
+
Perform bundle adjustment optimization (compatible with pycolmap 0.5+)
|
| 207 |
+
|
| 208 |
+
Args:
|
| 209 |
+
reconstruction: pycolmap.Reconstruction object
|
| 210 |
+
ba_options: pycolmap.BundleAdjustmentOptions object
|
| 211 |
+
ba_config: pycolmap.BundleAdjustmentConfig object (optional)
|
| 212 |
+
"""
|
| 213 |
+
# Alternatively, you can customize the existing problem or options as:
|
| 214 |
+
# import pyceres
|
| 215 |
+
bundle_adjuster = pycolmap.create_default_bundle_adjuster(
|
| 216 |
+
ba_options, ba_config, reconstruction
|
| 217 |
+
)
|
| 218 |
+
solver_options = ba_options.create_solver_options(
|
| 219 |
+
ba_config, bundle_adjuster.problem
|
| 220 |
+
)
|
| 221 |
+
summary = pyceres.SolverSummary()
|
| 222 |
+
solver_options = efficient_solver(solver_options)
|
| 223 |
+
problem = bundle_adjuster.problem
|
| 224 |
+
# problem = pyceres.Problem()
|
| 225 |
+
# if (extra_residual is not None):
|
| 226 |
+
# observed_depths = []
|
| 227 |
+
# quaternions = []
|
| 228 |
+
# translations = []
|
| 229 |
+
# points3d = []
|
| 230 |
+
# for res_ in extra_residual:
|
| 231 |
+
# point_id_i = res_["point3D_id"]
|
| 232 |
+
# for img_id_i, obs_depth_i in zip(res_["image_ids"], res_["observed_depth"]):
|
| 233 |
+
# if obs_depth_i > 0:
|
| 234 |
+
# observed_depths.append(obs_depth_i)
|
| 235 |
+
# quaternions.append(reconstruction.images[img_id_i].cam_from_world.rotation.quat)
|
| 236 |
+
# translations.append(reconstruction.images[img_id_i].cam_from_world.translation)
|
| 237 |
+
# points3d.append(reconstruction.points3D[point_id_i].xyz)
|
| 238 |
+
# pyceres.add_spatrack_static_problem(
|
| 239 |
+
# problem,
|
| 240 |
+
# observed_depths,
|
| 241 |
+
# quaternions,
|
| 242 |
+
# translations,
|
| 243 |
+
# points3d,
|
| 244 |
+
# huber_loss_delta=5.0
|
| 245 |
+
# )
|
| 246 |
+
|
| 247 |
+
pyceres.solve(solver_options, problem, summary)
|
| 248 |
+
|
| 249 |
+
return summary
|
| 250 |
+
|
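A hedged end-to-end sketch (not part of the commit). It assumes `recon` is a populated `pycolmap.Reconstruction`, for example the output of `batch_matrix_to_pycolmap` defined just below (its return statement lies outside this hunk), and leaves `ba_config` and `extra_residual` at their defaults.

```python
ba_options = pycolmap.BundleAdjustmentOptions()
summary = solve_bundle_adjustment(recon, ba_options)
log_ba_summary(summary)
```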
def batch_matrix_to_pycolmap(
    points3d,
    extrinsics,
    intrinsics,
    tracks,
    masks,
    image_size,
    max_points3D_val=3000,
    shared_camera=False,
    camera_type="SIMPLE_PINHOLE",
    extra_params=None,
    cam_tracks_static=None,
    query_pts=None,
):
    """
    Convert Batched Pytorch Tensors to PyCOLMAP

    Check https://github.com/colmap/pycolmap for more details about its format
    """
    # points3d: Px3
    # extrinsics: Nx3x4
    # intrinsics: Nx3x3
    # tracks: NxPx2
    # masks: NxP
    # image_size: 2, assume all the frames have been padded to the same size
    # where N is the number of frames and P is the number of tracks

    N, P, _ = tracks.shape
    assert len(extrinsics) == N
    assert len(intrinsics) == N
    assert len(points3d) == P
    assert image_size.shape[0] == 2

    extrinsics = extrinsics.cpu().numpy()
    intrinsics = intrinsics.cpu().numpy()

    if extra_params is not None:
        extra_params = extra_params.cpu().numpy()

    tracks = tracks.cpu().numpy()
    masks = masks.cpu().numpy()
    points3d = points3d.cpu().numpy()
    image_size = image_size.cpu().numpy()
    if cam_tracks_static is not None:
        cam_tracks_static = cam_tracks_static.cpu().numpy()

    # Reconstruction object, following the format of PyCOLMAP/COLMAP
    reconstruction = pycolmap.Reconstruction()

    inlier_num = masks.sum(0)
    valid_mask = inlier_num >= 2  # a track is invalid without at least two inliers
    valid_idx = np.nonzero(valid_mask)[0]

    # Only add 3D points that have sufficient 2D points
    point3d_ids = []
    for vidx in valid_idx:
        point3d_id = reconstruction.add_point3D(
            points3d[vidx], pycolmap.Track(), np.zeros(3)
        )
        point3d_ids.append(point3d_id)

    # add the residual pair
    if cam_tracks_static is not None:
        extra_residual = []
        for id_x, vidx in enumerate(valid_idx):
            points_3d_id = point3d_ids[id_x]
            point_residual = {
                "point3D_id": points_3d_id,
                "image_ids": [],
                "observed_depth": [],
            }
            query_i = query_pts[:, :, vidx]
            point_residual["image_ids"].append(int(query_i[0, 0, 0]))
            point_residual["observed_depth"].append(query_i[0, 0, -1])
            extra_residual.append(point_residual)
    else:
        extra_residual = None

    num_points3D = len(valid_idx)

    camera = None
    # frame idx
    for fidx in range(N):
        # set camera
        if camera is None or (not shared_camera):
            if camera_type == "SIMPLE_RADIAL":
                pycolmap_intri = np.array(
                    [
                        intrinsics[fidx][0, 0],
                        intrinsics[fidx][0, 2],
                        intrinsics[fidx][1, 2],
                        extra_params[fidx][0],
                    ]
                )
            elif camera_type == "SIMPLE_PINHOLE":
                pycolmap_intri = np.array(
                    [
                        intrinsics[fidx][0, 0],
                        intrinsics[fidx][0, 2],
                        intrinsics[fidx][1, 2],
                    ]
                )
            else:
                raise ValueError(
                    f"Camera type {camera_type} is not supported yet"
                )

            camera = pycolmap.Camera(
                model=camera_type,
                width=image_size[0],
                height=image_size[1],
                params=pycolmap_intri,
                camera_id=fidx,
            )

            # add camera
            reconstruction.add_camera(camera)

        # set image
        cam_from_world = pycolmap.Rigid3d(
            pycolmap.Rotation3d(extrinsics[fidx][:3, :3]),
            extrinsics[fidx][:3, 3],
        )  # Rot and Trans
        image = pycolmap.Image(
            id=fidx,
            name=f"image_{fidx}",
            camera_id=camera.camera_id,
            cam_from_world=cam_from_world,
        )

        points2D_list = []

        point2D_idx = 0
        # NOTE point3D_id start by 1
        for point3D_id in range(1, num_points3D + 1):
            original_track_idx = valid_idx[point3D_id - 1]

            if (
                reconstruction.points3D[point3D_id].xyz < max_points3D_val
            ).all():
                if masks[fidx][original_track_idx]:
                    # It seems we don't need +0.5 for BA
                    point2D_xy = tracks[fidx][original_track_idx]
                    # Please note when adding the Point2D object
                    # It not only requires the 2D xy location, but also the id to 3D point
                    points2D_list.append(
                        pycolmap.Point2D(point2D_xy, point3D_id)
                    )

                    # add element
|
| 401 |
+
track = reconstruction.points3D[point3D_id].track
|
| 402 |
+
track.add_element(fidx, point2D_idx)
|
| 403 |
+
point2D_idx += 1
|
| 404 |
+
|
| 405 |
+
assert point2D_idx == len(points2D_list)
|
| 406 |
+
try:
|
| 407 |
+
image.points2D = pycolmap.ListPoint2D(points2D_list)
|
| 408 |
+
except Exception as e:
|
| 409 |
+
print(f"frame {fidx} is out of BA: {e}")
|
| 410 |
+
|
| 411 |
+
# add image
|
| 412 |
+
reconstruction.add_image(image)
|
| 413 |
+
|
| 414 |
+
return reconstruction, valid_idx, extra_residual
|
| 415 |
+
|
| 416 |
+
def pycolmap_to_batch_matrix(
|
| 417 |
+
reconstruction, device="cuda", camera_type="SIMPLE_PINHOLE"
|
| 418 |
+
):
|
| 419 |
+
"""
|
| 420 |
+
Convert a PyCOLMAP Reconstruction Object to batched PyTorch tensors.
|
| 421 |
+
|
| 422 |
+
Args:
|
| 423 |
+
reconstruction (pycolmap.Reconstruction): The reconstruction object from PyCOLMAP.
|
| 424 |
+
device (str): The device to place the tensors on (default: "cuda").
|
| 425 |
+
camera_type (str): The type of camera model used (default: "SIMPLE_PINHOLE").
|
| 426 |
+
|
| 427 |
+
Returns:
|
| 428 |
+
tuple: A tuple containing points3D, extrinsics, intrinsics, and optionally extra_params.
|
| 429 |
+
"""
|
| 430 |
+
|
| 431 |
+
num_images = len(reconstruction.images)
|
| 432 |
+
max_points3D_id = max(reconstruction.point3D_ids())
|
| 433 |
+
points3D = np.zeros((max_points3D_id, 3))
|
| 434 |
+
|
| 435 |
+
for point3D_id in reconstruction.points3D:
|
| 436 |
+
points3D[point3D_id - 1] = reconstruction.points3D[point3D_id].xyz
|
| 437 |
+
points3D = torch.from_numpy(points3D).to(device)
|
| 438 |
+
|
| 439 |
+
extrinsics = []
|
| 440 |
+
intrinsics = []
|
| 441 |
+
|
| 442 |
+
extra_params = [] if camera_type == "SIMPLE_RADIAL" else None
|
| 443 |
+
|
| 444 |
+
for i in range(num_images):
|
| 445 |
+
# Extract and append extrinsics
|
| 446 |
+
pyimg = reconstruction.images[i]
|
| 447 |
+
pycam = reconstruction.cameras[pyimg.camera_id]
|
| 448 |
+
matrix = pyimg.cam_from_world.matrix()
|
| 449 |
+
extrinsics.append(matrix)
|
| 450 |
+
|
| 451 |
+
# Extract and append intrinsics
|
| 452 |
+
calibration_matrix = pycam.calibration_matrix()
|
| 453 |
+
intrinsics.append(calibration_matrix)
|
| 454 |
+
|
| 455 |
+
if camera_type == "SIMPLE_RADIAL":
|
| 456 |
+
extra_params.append(pycam.params[-1])
|
| 457 |
+
|
| 458 |
+
# Convert lists to torch tensors
|
| 459 |
+
extrinsics = torch.from_numpy(np.stack(extrinsics)).to(device)
|
| 460 |
+
|
| 461 |
+
intrinsics = torch.from_numpy(np.stack(intrinsics)).to(device)
|
| 462 |
+
|
| 463 |
+
if camera_type == "SIMPLE_RADIAL":
|
| 464 |
+
extra_params = torch.from_numpy(np.stack(extra_params)).to(device)
|
| 465 |
+
extra_params = extra_params[:, None]
|
| 466 |
+
|
| 467 |
+
return points3D, extrinsics, intrinsics, extra_params
|
| 468 |
+
|
| 469 |
+
def ba_pycolmap(world_tracks, intrs, c2w_traj, visb, tracks2d, image_size, cam_tracks_static=None, training=True, query_pts=None):
|
| 470 |
+
"""
|
| 471 |
+
world_tracks: 1 1 K 3 this is the coarse 3d tracks in world coordinate (coarse 3d tracks)
|
| 472 |
+
intrs: B T 3 3 this is the intrinsic matrix
|
| 473 |
+
c2w_traj: B T 4 4 this is the camera trajectory
|
| 474 |
+
visb: B T K this is the visibility of the 3d tracks
|
| 475 |
+
tracks2d: B T K 2 this is the 2d tracks
|
| 476 |
+
"""
|
| 477 |
+
with torch.no_grad():
|
| 478 |
+
B, _, K, _ = world_tracks.shape
|
| 479 |
+
T = c2w_traj.shape[1]
|
| 480 |
+
world_tracks = world_tracks.view(K, 3).detach()
|
| 481 |
+
world_tracks_refine = world_tracks.view(K, 3).detach().clone()
|
| 482 |
+
c2w_traj_glob = c2w_traj.view(B*T, 4, 4).detach().clone()
|
| 483 |
+
c2w_traj = c2w_traj.view(B*T, 4, 4).detach()
|
| 484 |
+
intrs = intrs.view(B*T, 3, 3).detach()
|
| 485 |
+
visb = visb.view(B*T, K).detach()
|
| 486 |
+
tracks2d = tracks2d[...,:2].view(B*T, K, 2).detach()
|
| 487 |
+
|
| 488 |
+
rec, valid_idx_pts, extra_residual = batch_matrix_to_pycolmap(
|
| 489 |
+
world_tracks,
|
| 490 |
+
torch.inverse(c2w_traj)[:,:3,:],
|
| 491 |
+
intrs,
|
| 492 |
+
tracks2d,
|
| 493 |
+
visb,
|
| 494 |
+
image_size,
|
| 495 |
+
cam_tracks_static=cam_tracks_static,
|
| 496 |
+
query_pts=query_pts,
|
| 497 |
+
)
|
| 498 |
+
# NOTE It is window_size + 1 instead of window_size
|
| 499 |
+
ba_options = pycolmap.BundleAdjustmentOptions()
|
| 500 |
+
ba_options.refine_focal_length = False
|
| 501 |
+
ba_options.refine_principal_point = False
|
| 502 |
+
ba_options.refine_extra_params = False
|
| 503 |
+
ba_config = pycolmap.BundleAdjustmentConfig()
|
| 504 |
+
for image_id in rec.reg_image_ids():
|
| 505 |
+
ba_config.add_image(image_id)
|
| 506 |
+
# Fix frame 0, i.e, the end frame of the last window
|
| 507 |
+
ba_config.set_constant_cam_pose(0)
|
| 508 |
+
|
| 509 |
+
# fix the 3d points
|
| 510 |
+
for point3D_id in rec.points3D:
|
| 511 |
+
if training:
|
| 512 |
+
# ba_config.add_constant_point(point3D_id)
|
| 513 |
+
ba_config.add_variable_point(point3D_id)
|
| 514 |
+
else:
|
| 515 |
+
ba_config.add_variable_point(point3D_id)
|
| 516 |
+
# ba_config.add_constant_point(point3D_id)
|
| 517 |
+
if (len(ba_config.variable_point3D_ids) < 50) and (len(ba_config.constant_point3D_ids) < 50):
|
| 518 |
+
return c2w_traj_glob, world_tracks_refine, intrs
|
| 519 |
+
summary = solve_bundle_adjustment(rec, ba_options, ba_config, extra_residual=extra_residual)
|
| 520 |
+
# free the 3d points
|
| 521 |
+
# for point3D_id in rec.points3D:
|
| 522 |
+
# ba_config.remove_constant_point(point3D_id)
|
| 523 |
+
# ba_config.add_variable_point(point3D_id)
|
| 524 |
+
# summary = solve_bundle_adjustment(rec, ba_options, ba_config)
|
| 525 |
+
if not training:
|
| 526 |
+
ba_success = log_ba_summary(summary)
|
| 527 |
+
# get the refined results
|
| 528 |
+
points3D, extrinsics, intrinsics, extra_params = pycolmap_to_batch_matrix(rec, device="cuda", camera_type="SIMPLE_PINHOLE")
|
| 529 |
+
c2w_traj_glob[:, :3, :] = extrinsics
|
| 530 |
+
c2w_traj_glob = torch.inverse(c2w_traj_glob)
|
| 531 |
+
world_tracks_refine[valid_idx_pts] = points3D.to(world_tracks_refine.device).to(world_tracks_refine.dtype)
|
| 532 |
+
intrinsics = intrinsics.to(world_tracks_refine.device).to(world_tracks_refine.dtype)
|
| 533 |
+
# import pdb; pdb.set_trace()
|
| 534 |
+
return c2w_traj_glob, world_tracks_refine, intrinsics
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
|
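Note (illustrative, not part of the committed file): a minimal sketch of how the ba_pycolmap wrapper above might be invoked. The tensor shapes follow its docstring (B=1 batch, T frames, K tracks); the random values, the 512x512 image size, and the focal length of 500 are placeholder assumptions, and pycolmap/pyceres must be installed for the solve to run.

import torch

B, T, K = 1, 8, 128
world_tracks = torch.rand(1, 1, K, 3)           # coarse 3D tracks in world coordinates
intrs = torch.eye(3).repeat(B, T, 1, 1) * 500.0  # dummy pinhole intrinsics, fx = fy = 500
intrs[..., 0, 2], intrs[..., 1, 2], intrs[..., 2, 2] = 256.0, 256.0, 1.0
c2w_traj = torch.eye(4).repeat(B, T, 1, 1)       # camera-to-world trajectory per frame
visb = torch.ones(B, T, K, dtype=torch.bool)     # per-frame visibility of each track
tracks2d = torch.rand(B, T, K, 2) * 512          # 2D tracks in pixels
image_size = torch.tensor([512, 512])

# Returns the BA-refined camera trajectory, 3D tracks and intrinsics.
c2w_refined, tracks_refined, intrs_refined = ba_pycolmap(
    world_tracks, intrs, c2w_traj, visb, tracks2d, image_size, training=False
)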
models/SpaTrackV2/models/tracker3D/spatrack_modules/blocks.py
ADDED
@@ -0,0 +1,15 @@
import torch
import torch.nn as nn
import torch.nn.functional as F



class PointDinoV2(nn.Module):
    """
    PointDinoV2 is a 3D point tracking model that uses a backbone and head to extract features from points and track them.
    """
    def __init__(self, ):
        super(PointDinoV2, self).__init__()
        # self.backbone = PointDinoV2Backbone()
        # self.head = PointDinoV2Head()

models/SpaTrackV2/models/tracker3D/spatrack_modules/dynamic_point_refine.py
ADDED
File without changes
models/SpaTrackV2/models/tracker3D/spatrack_modules/geometry_numpy.py
ADDED
@@ -0,0 +1,401 @@
from typing import *
from functools import partial
import math

import cv2
import numpy as np
from scipy.signal import fftconvolve
import numpy as np
import utils3d

from .tools import timeit


def weighted_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
    if w is None:
        return np.mean(x, axis=axis)
    else:
        w = w.astype(x.dtype)
        return (x * w).mean(axis=axis) / np.clip(w.mean(axis=axis), eps, None)


def harmonic_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
    if w is None:
        return 1 / (1 / np.clip(x, eps, None)).mean(axis=axis)
    else:
        w = w.astype(x.dtype)
        return 1 / (weighted_mean_numpy(1 / (x + eps), w, axis=axis, keepdims=keepdims, eps=eps) + eps)


def normalized_view_plane_uv_numpy(width: int, height: int, aspect_ratio: float = None, dtype: np.dtype = np.float32) -> np.ndarray:
    "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
    if aspect_ratio is None:
        aspect_ratio = width / height

    span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
    span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5

    u = np.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype)
    v = np.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype)
    u, v = np.meshgrid(u, v, indexing='xy')
    uv = np.stack([u, v], axis=-1)
    return uv


def focal_to_fov_numpy(focal: np.ndarray):
    return 2 * np.arctan(0.5 / focal)


def fov_to_focal_numpy(fov: np.ndarray):
    return 0.5 / np.tan(fov / 2)


def intrinsics_to_fov_numpy(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    fov_x = focal_to_fov_numpy(intrinsics[..., 0, 0])
    fov_y = focal_to_fov_numpy(intrinsics[..., 1, 1])
    return fov_x, fov_y


def point_map_to_depth_legacy_numpy(points: np.ndarray):
    height, width = points.shape[-3:-1]
    diagonal = (height ** 2 + width ** 2) ** 0.5
    uv = normalized_view_plane_uv_numpy(width, height, dtype=points.dtype)  # (H, W, 2)
    _, uv = np.broadcast_arrays(points[..., :2], uv)

    # Solve least squares problem
    b = (uv * points[..., 2:]).reshape(*points.shape[:-3], -1)  # (..., H * W * 2)
    A = np.stack([points[..., :2], -uv], axis=-1).reshape(*points.shape[:-3], -1, 2)  # (..., H * W * 2, 2)

    M = A.swapaxes(-2, -1) @ A
    solution = (np.linalg.inv(M + 1e-6 * np.eye(2)) @ (A.swapaxes(-2, -1) @ b[..., None])).squeeze(-1)
    focal, shift = solution

    depth = points[..., 2] + shift[..., None, None]
    fov_x = np.arctan(width / diagonal / focal) * 2
    fov_y = np.arctan(height / diagonal / focal) * 2
    return depth, fov_x, fov_y, shift


def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray):
    "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal"
    from scipy.optimize import least_squares
    uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)

    def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
        xy_proj = xy / (z + shift)[: , None]
        f = (xy_proj * uv).sum() / np.square(xy_proj).sum()
        err = (f * xy_proj - uv).ravel()
        return err

    solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
    optim_shift = solution['x'].squeeze().astype(np.float32)

    xy_proj = xy / (z + optim_shift)[: , None]
    optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum()

    return optim_shift, optim_focal


def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float):
    "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift"
    from scipy.optimize import least_squares
    uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)

    def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
        xy_proj = xy / (z + shift)[: , None]
        err = (focal * xy_proj - uv).ravel()
        return err

    solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
    optim_shift = solution['x'].squeeze().astype(np.float32)

    return optim_shift


def recover_focal_shift_numpy(points: np.ndarray, mask: np.ndarray = None, focal: float = None, downsample_size: Tuple[int, int] = (64, 64)):
    import cv2
    assert points.shape[-1] == 3, "Points should (H, W, 3)"

    height, width = points.shape[-3], points.shape[-2]
    diagonal = (height ** 2 + width ** 2) ** 0.5

    uv = normalized_view_plane_uv_numpy(width=width, height=height)

    if mask is None:
        points_lr = cv2.resize(points, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 3)
        uv_lr = cv2.resize(uv, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 2)
    else:
        (points_lr, uv_lr), mask_lr = mask_aware_nearest_resize_numpy((points, uv), mask, downsample_size)

    if points_lr.size < 2:
        return 1., 0.

    if focal is None:
        focal, shift = solve_optimal_focal_shift(uv_lr, points_lr)
    else:
        shift = solve_optimal_shift(uv_lr, points_lr, focal)

    return focal, shift


def mask_aware_nearest_resize_numpy(
    inputs: Union[np.ndarray, Tuple[np.ndarray, ...], None],
    mask: np.ndarray,
    size: Tuple[int, int],
    return_index: bool = False
) -> Tuple[Union[np.ndarray, Tuple[np.ndarray, ...], None], np.ndarray, Tuple[np.ndarray, ...]]:
    """
    Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.

    ### Parameters
    - `inputs`: a single or a list of input 2D map(s) of shape (..., H, W, ...).
    - `mask`: input 2D mask of shape (..., H, W)
    - `size`: target size (width, height)

    ### Returns
    - `*resized_maps`: resized map(s) of shape (..., target_height, target_width, ...).
    - `resized_mask`: mask of the resized map of shape (..., target_height, target_width)
    - `nearest_idx`: if return_index is True, nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension.
    """
    height, width = mask.shape[-2:]
    target_width, target_height = size
    filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
    filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
    filter_size = filter_h_i * filter_w_i
    padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1

    # Window the original mask and uv
    uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32)
    indices = np.arange(height * width, dtype=np.int32).reshape(height, width)
    padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32)
    padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
    padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool)
    padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
    padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32)
    padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
    windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1))
    windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1))
    windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1))

    # Gather the target pixels's local window
    target_centers = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32)
    target_lefttop = target_centers - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
    target_window = np.round(target_lefttop).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32)

    target_window_centers = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size)  # (target_height, tgt_width, 2, filter_size)
    target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size)  # (..., target_height, tgt_width, filter_size)
    target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(*([-1] * (mask.ndim - 2)), target_height, target_width, filter_size)  # (target_height, tgt_width, filter_size)

    # Compute nearest neighbor in the local window for each pixel
    dist = np.square(target_window_centers - target_centers[..., None])
    dist = dist[..., 0, :] + dist[..., 1, :]
    dist = np.where(target_window_mask, dist, np.inf)  # (..., target_height, tgt_width, filter_size)
    nearest_in_window = np.argmin(dist, axis=-1, keepdims=True)  # (..., target_height, tgt_width, 1)
    nearest_idx = np.take_along_axis(target_window_indices, nearest_in_window, axis=-1).squeeze(-1)  # (..., target_height, tgt_width)
    nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
    target_mask = np.any(target_window_mask, axis=-1)
    batch_indices = [np.arange(n).reshape([1] * i + [n] + [1] * (mask.ndim - i - 1)) for i, n in enumerate(mask.shape[:-2])]

    index = (*batch_indices, nearest_i, nearest_j)

    if inputs is None:
        outputs = None
    elif isinstance(inputs, np.ndarray):
        outputs = inputs[index]
    elif isinstance(inputs, Sequence):
        outputs = tuple(x[index] for x in inputs)
    else:
        raise ValueError(f'Invalid input type: {type(inputs)}')

    if return_index:
        return outputs, target_mask, index
    else:
        return outputs, target_mask


def mask_aware_area_resize_numpy(image: np.ndarray, mask: np.ndarray, target_width: int, target_height: int) -> Tuple[Tuple[np.ndarray, ...], np.ndarray]:
    """
    Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.

    ### Parameters
    - `image`: Input 2D image of shape (..., H, W, C)
    - `mask`: Input 2D mask of shape (..., H, W)
    - `target_width`: target width of the resized map
    - `target_height`: target height of the resized map

    ### Returns
    - `nearest_idx`: Nearest neighbor index of the resized map of shape (..., target_height, target_width).
    - `target_mask`: Mask of the resized map of shape (..., target_height, target_width)
    """
    height, width = mask.shape[-2:]

    if image.shape[-2:] == (height, width):
        omit_channel_dim = True
    else:
        omit_channel_dim = False
    if omit_channel_dim:
        image = image[..., None]

    image = np.where(mask[..., None], image, 0)

    filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
    filter_h_i, filter_w_i = math.ceil(filter_h_f) + 1, math.ceil(filter_w_f) + 1
    filter_size = filter_h_i * filter_w_i
    padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1

    # Window the original mask and uv (non-copy)
    uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32)
    indices = np.arange(height * width, dtype=np.int32).reshape(height, width)
    padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32)
    padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
    padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool)
    padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
    padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32)
    padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
    windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1))
    windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1))
    windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1))

    # Gather the target pixels's local window
    target_center = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32)
    target_lefttop = target_center - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
    target_bottomright = target_center + np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
    target_window = np.floor(target_lefttop).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32)

    target_window_centers = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size)  # (target_height, tgt_width, 2, filter_size)
    target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size)  # (..., target_height, tgt_width, filter_size)
    target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(target_height, target_width, filter_size)  # (target_height, tgt_width, filter_size)

    # Compute pixel area in the local windows
    target_window_lefttop = np.maximum(target_window_centers - 0.5, target_lefttop[..., None])
    target_window_bottomright = np.minimum(target_window_centers + 0.5, target_bottomright[..., None])
    target_window_area = (target_window_bottomright - target_window_lefttop).clip(0, None)
    target_window_area = np.where(target_window_mask, target_window_area[..., 0, :] * target_window_area[..., 1, :], 0)

    # Weighted sum by area
    target_window_image = image.reshape(*image.shape[:-3], height * width, -1)[..., target_window_indices, :].swapaxes(-2, -1)
    target_mask = np.sum(target_window_area, axis=-1) >= 0.25
    target_image = weighted_mean_numpy(target_window_image, target_window_area[..., None, :], axis=-1)

    if omit_channel_dim:
        target_image = target_image[..., 0]

    return target_image, target_mask


def norm3d(x: np.ndarray) -> np.ndarray:
    "Faster `np.linalg.norm(x, axis=-1)` for 3D vectors"
    return np.sqrt(np.square(x[..., 0]) + np.square(x[..., 1]) + np.square(x[..., 2]))


def depth_occlusion_edge_numpy(depth: np.ndarray, mask: np.ndarray, kernel_size: int = 3, tol: float = 0.1):
    disp = np.where(mask, 1 / depth, 0)
    disp_pad = np.pad(disp, (kernel_size // 2, kernel_size // 2), constant_values=0)
    mask_pad = np.pad(mask, (kernel_size // 2, kernel_size // 2), constant_values=False)
    disp_window = utils3d.numpy.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, axis=(-2, -1))  # [..., H, W, kernel_size ** 2]
    mask_window = utils3d.numpy.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, axis=(-2, -1))  # [..., H, W, kernel_size ** 2]

    disp_mean = weighted_mean_numpy(disp_window, mask_window, axis=(-2, -1))
    fg_edge_mask = mask & (disp > (1 + tol) * disp_mean)
    bg_edge_mask = mask & (disp_mean > (1 + tol) * disp)
    return fg_edge_mask, bg_edge_mask


def disk_kernel(radius: int) -> np.ndarray:
    """
    Generate disk kernel with given radius.

    Args:
        radius (int): Radius of the disk (in pixels).

    Returns:
        np.ndarray: (2*radius+1, 2*radius+1) normalized convolution kernel.
    """
    # Create coordinate grid centered at (0,0)
    L = np.arange(-radius, radius + 1)
    X, Y = np.meshgrid(L, L)
    # Generate disk: region inside circle with radius R is 1
    kernel = ((X**2 + Y**2) <= radius**2).astype(np.float32)
    # Normalize the kernel
    kernel /= np.sum(kernel)
    return kernel


def disk_blur(image: np.ndarray, radius: int) -> np.ndarray:
    """
    Apply disk blur to an image using FFT convolution.

    Args:
        image (np.ndarray): Input image, can be grayscale or color.
        radius (int): Blur radius (in pixels).

    Returns:
        np.ndarray: Blurred image.
    """
    if radius == 0:
        return image
    kernel = disk_kernel(radius)
    if image.ndim == 2:
        blurred = fftconvolve(image, kernel, mode='same')
    elif image.ndim == 3:
        channels = []
        for i in range(image.shape[2]):
            blurred_channel = fftconvolve(image[..., i], kernel, mode='same')
            channels.append(blurred_channel)
        blurred = np.stack(channels, axis=-1)
    else:
        raise ValueError("Image must be 2D or 3D.")
    return blurred


def depth_of_field(
    img: np.ndarray,
    disp: np.ndarray,
    focus_disp : float,
    max_blur_radius : int = 10,
) -> np.ndarray:
    """
    Apply depth of field effect to an image.

    Args:
        img (numpy.ndarray): (H, W, 3) input image.
        depth (numpy.ndarray): (H, W) depth map of the scene.
        focus_depth (float): Focus depth of the lens.
        strength (float): Strength of the depth of field effect.
        max_blur_radius (int): Maximum blur radius (in pixels).

    Returns:
        numpy.ndarray: (H, W, 3) output image with depth of field effect applied.
    """
    # Precalculate dialated depth map for each blur radius
    max_disp = np.max(disp)
    disp = disp / max_disp
    focus_disp = focus_disp / max_disp
    dilated_disp = []
    for radius in range(max_blur_radius + 1):
        dilated_disp.append(cv2.dilate(disp, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*radius+1, 2*radius+1)), iterations=1))

    # Determine the blur radius for each pixel based on the depth map
    blur_radii = np.clip(abs(disp - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32)
    for radius in range(max_blur_radius + 1):
        dialted_blur_radii = np.clip(abs(dilated_disp[radius] - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32)
        mask = (dialted_blur_radii >= radius) & (dialted_blur_radii >= blur_radii) & (dilated_disp[radius] > disp)
        blur_radii[mask] = dialted_blur_radii[mask]
    blur_radii = np.clip(blur_radii, 0, max_blur_radius)
    blur_radii = cv2.blur(blur_radii, (5, 5))

    # Precalculate the blured image for each blur radius
    unique_radii = np.unique(blur_radii)
    precomputed = {}
    for radius in range(max_blur_radius + 1):
        if radius not in unique_radii:
            continue
        precomputed[radius] = disk_blur(img, radius)

    # Composit the blured image for each pixel
    output = np.zeros_like(img)
    for r in unique_radii:
        mask = blur_radii == r
        output[mask] = precomputed[r][mask]

    return output
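Note (illustrative, not part of the committed file): a small sketch of using recover_focal_shift_numpy above on a synthetic point map. The array sizes and random values are assumptions; the FoV conversion follows the convention of point_map_to_depth_legacy_numpy in the same file (focal relative to the half diagonal).

import numpy as np

H, W = 48, 64
points = np.random.rand(H, W, 3).astype(np.float32)   # stand-in for a predicted point map
points[..., 2] += 1.0                                  # keep z away from zero

focal, shift = recover_focal_shift_numpy(points, mask=None)
depth = points[..., 2] + shift                         # z plus recovered shift gives depth
diagonal = np.hypot(W, H)
fov_x = 2 * np.arctan(W / diagonal / focal)            # horizontal FoV in radians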
models/SpaTrackV2/models/tracker3D/spatrack_modules/geometry_torch.py
ADDED
@@ -0,0 +1,323 @@
from typing import *
import math
from collections import namedtuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.types
import utils3d

from .tools import timeit
from .geometry_numpy import solve_optimal_focal_shift, solve_optimal_shift


def weighted_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
    if w is None:
        return x.mean(dim=dim, keepdim=keepdim)
    else:
        w = w.to(x.dtype)
        return (x * w).mean(dim=dim, keepdim=keepdim) / w.mean(dim=dim, keepdim=keepdim).add(eps)


def harmonic_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
    if w is None:
        return x.add(eps).reciprocal().mean(dim=dim, keepdim=keepdim).reciprocal()
    else:
        w = w.to(x.dtype)
        return weighted_mean(x.add(eps).reciprocal(), w, dim=dim, keepdim=keepdim, eps=eps).add(eps).reciprocal()


def geometric_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
    if w is None:
        return x.add(eps).log().mean(dim=dim).exp()
    else:
        w = w.to(x.dtype)
        return weighted_mean(x.add(eps).log(), w, dim=dim, keepdim=keepdim, eps=eps).exp()


def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor:
    "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
    if aspect_ratio is None:
        aspect_ratio = width / height

    span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
    span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5

    u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device)
    v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device)
    u, v = torch.meshgrid(u, v, indexing='xy')
    uv = torch.stack([u, v], dim=-1)
    return uv


def gaussian_blur_2d(input: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
    kernel = torch.exp(-(torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=input.dtype, device=input.device) ** 2) / (2 * sigma ** 2))
    kernel = kernel / kernel.sum()
    kernel = (kernel[:, None] * kernel[None, :]).reshape(1, 1, kernel_size, kernel_size)
    input = F.pad(input, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), mode='replicate')
    input = F.conv2d(input, kernel, groups=input.shape[1])
    return input


def focal_to_fov(focal: torch.Tensor):
    return 2 * torch.atan(0.5 / focal)


def fov_to_focal(fov: torch.Tensor):
    return 0.5 / torch.tan(fov / 2)


def angle_diff_vec3(v1: torch.Tensor, v2: torch.Tensor, eps: float = 1e-12):
    return torch.atan2(torch.cross(v1, v2, dim=-1).norm(dim=-1) + eps, (v1 * v2).sum(dim=-1))

def intrinsics_to_fov(intrinsics: torch.Tensor):
    """
    Returns field of view in radians from normalized intrinsics matrix.
    ### Parameters:
    - intrinsics: torch.Tensor of shape (..., 3, 3)

    ### Returns:
    - fov_x: torch.Tensor of shape (...)
    - fov_y: torch.Tensor of shape (...)
    """
    focal_x = intrinsics[..., 0, 0]
    focal_y = intrinsics[..., 1, 1]
    return 2 * torch.atan(0.5 / focal_x), 2 * torch.atan(0.5 / focal_y)


def point_map_to_depth_legacy(points: torch.Tensor):
    height, width = points.shape[-3:-1]
    diagonal = (height ** 2 + width ** 2) ** 0.5
    uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device)  # (H, W, 2)

    # Solve least squares problem
    b = (uv * points[..., 2:]).flatten(-3, -1)  # (..., H * W * 2)
    A = torch.stack([points[..., :2], -uv.expand_as(points[..., :2])], dim=-1).flatten(-4, -2)  # (..., H * W * 2, 2)

    M = A.transpose(-2, -1) @ A
    solution = (torch.inverse(M + 1e-6 * torch.eye(2).to(A)) @ (A.transpose(-2, -1) @ b[..., None])).squeeze(-1)
    focal, shift = solution.unbind(-1)

    depth = points[..., 2] + shift[..., None, None]
    fov_x = torch.atan(width / diagonal / focal) * 2
    fov_y = torch.atan(height / diagonal / focal) * 2
    return depth, fov_x, fov_y, shift


def view_plane_uv_to_focal(uv: torch.Tensor):
    normed_uv = normalized_view_plane_uv(width=uv.shape[-2], height=uv.shape[-3], device=uv.device, dtype=uv.dtype)
    focal = (uv * normed_uv).sum() / uv.square().sum().add(1e-12)
    return focal


def recover_focal_shift(points: torch.Tensor, mask: torch.Tensor = None, focal: torch.Tensor = None, downsample_size: Tuple[int, int] = (64, 64)):
    """
    Recover the depth map and FoV from a point map with unknown z shift and focal.

    Note that it assumes:
    - the optical center is at the center of the map
    - the map is undistorted
    - the map is isometric in the x and y directions

    ### Parameters:
    - `points: torch.Tensor` of shape (..., H, W, 3)
    - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps.

    ### Returns:
    - `focal`: torch.Tensor of shape (...) the estimated focal length, relative to the half diagonal of the map
    - `shift`: torch.Tensor of shape (...) Z-axis shift to translate the point map to camera space
    """
    shape = points.shape
    height, width = points.shape[-3], points.shape[-2]
    diagonal = (height ** 2 + width ** 2) ** 0.5

    points = points.reshape(-1, *shape[-3:])
    mask = None if mask is None else mask.reshape(-1, *shape[-3:-1])
    focal = focal.reshape(-1) if focal is not None else None
    uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device)  # (H, W, 2)

    points_lr = F.interpolate(points.permute(0, 3, 1, 2), downsample_size, mode='nearest').permute(0, 2, 3, 1)
    uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode='nearest').squeeze(0).permute(1, 2, 0)
    mask_lr = None if mask is None else F.interpolate(mask.to(torch.float32).unsqueeze(1), downsample_size, mode='nearest').squeeze(1) > 0

    uv_lr_np = uv_lr.cpu().numpy()
    points_lr_np = points_lr.detach().cpu().numpy()
    focal_np = focal.cpu().numpy() if focal is not None else None
    mask_lr_np = None if mask is None else mask_lr.cpu().numpy()
    optim_shift, optim_focal = [], []
    for i in range(points.shape[0]):
        points_lr_i_np = points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]]
        uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]]
        if uv_lr_i_np.shape[0] < 2:
            optim_focal.append(1)
            optim_shift.append(0)
            continue
        if focal is None:
            optim_shift_i, optim_focal_i = solve_optimal_focal_shift(uv_lr_i_np, points_lr_i_np)
            optim_focal.append(float(optim_focal_i))
        else:
            optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i])
        optim_shift.append(float(optim_shift_i))
    optim_shift = torch.tensor(optim_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3])

    if focal is None:
        optim_focal = torch.tensor(optim_focal, device=points.device, dtype=points.dtype).reshape(shape[:-3])
    else:
        optim_focal = focal.reshape(shape[:-3])

    return optim_focal, optim_shift


def mask_aware_nearest_resize(
    inputs: Union[torch.Tensor, Sequence[torch.Tensor], None],
    mask: torch.BoolTensor,
    size: Tuple[int, int],
    return_index: bool = False
) -> Tuple[Union[torch.Tensor, Sequence[torch.Tensor], None], torch.BoolTensor, Tuple[torch.LongTensor, ...]]:
    """
    Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.

    ### Parameters
    - `inputs`: a single or a list of input 2D map(s) of shape (..., H, W, ...).
    - `mask`: input 2D mask of shape (..., H, W)
    - `size`: target size (target_width, target_height)

    ### Returns
    - `*resized_maps`: resized map(s) of shape (..., target_height, target_width, ...).
    - `resized_mask`: mask of the resized map of shape (..., target_height, target_width)
    - `nearest_idx`: if return_index is True, nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension, .
    """
    height, width = mask.shape[-2:]
    target_width, target_height = size
    device = mask.device
    filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
    filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
    filter_size = filter_h_i * filter_w_i
    padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1

    # Window the original mask and uv
    uv = utils3d.torch.image_pixel_center(width=width, height=height, dtype=torch.float32, device=device)
    indices = torch.arange(height * width, dtype=torch.long, device=device).reshape(height, width)
    padded_uv = torch.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=torch.float32, device=device)
    padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
    padded_mask = torch.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=torch.bool, device=device)
    padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
    padded_indices = torch.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=torch.long, device=device)
    padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
    windowed_uv = utils3d.torch.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, dim=(0, 1))
    windowed_mask = utils3d.torch.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, dim=(-2, -1))
    windowed_indices = utils3d.torch.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, dim=(0, 1))

    # Gather the target pixels's local window
    target_uv = utils3d.torch.image_uv(width=target_width, height=target_height, dtype=torch.float32, device=device) * torch.tensor([width, height], dtype=torch.float32, device=device)
    target_lefttop = target_uv - torch.tensor((filter_w_f / 2, filter_h_f / 2), dtype=torch.float32, device=device)
    target_window = torch.round(target_lefttop).long() + torch.tensor((padding_w, padding_h), dtype=torch.long, device=device)

    target_window_uv = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size)  # (target_height, tgt_width, 2, filter_size)
    target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size)  # (..., target_height, tgt_width, filter_size)
    target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(target_height, target_width, filter_size)  # (target_height, tgt_width, filter_size)
    target_window_indices = target_window_indices.expand_as(target_window_mask)

    # Compute nearest neighbor in the local window for each pixel
    dist = torch.where(target_window_mask, torch.norm(target_window_uv - target_uv[..., None], dim=-2), torch.inf)  # (..., target_height, tgt_width, filter_size)
    nearest = torch.argmin(dist, dim=-1, keepdim=True)  # (..., target_height, tgt_width, 1)
    nearest_idx = torch.gather(target_window_indices, index=nearest, dim=-1).squeeze(-1)  # (..., target_height, tgt_width)
    target_mask = torch.any(target_window_mask, dim=-1)
    nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
    batch_indices = [torch.arange(n, device=device).reshape([1] * i + [n] + [1] * (mask.dim() - i - 1)) for i, n in enumerate(mask.shape[:-2])]

    index = (*batch_indices, nearest_i, nearest_j)

    if inputs is None:
        outputs = None
    elif isinstance(inputs, torch.Tensor):
        outputs = inputs[index]
    elif isinstance(inputs, Sequence):
        outputs = tuple(x[index] for x in inputs)
    else:
        raise ValueError(f'Invalid input type: {type(inputs)}')

    if return_index:
        return outputs, target_mask, index
    else:
        return outputs, target_mask


def theshold_depth_change(depth: torch.Tensor, mask: torch.Tensor, pooler: Literal['min', 'max'], rtol: float = 0.2, kernel_size: int = 3):
    *batch_shape, height, width = depth.shape
    depth = depth.reshape(-1, 1, height, width)
    mask = mask.reshape(-1, 1, height, width)
    if pooler =='max':
        pooled_depth = F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2)
        output_mask = pooled_depth > depth * (1 + rtol)
    elif pooler =='min':
        pooled_depth = -F.max_pool2d(-torch.where(mask, depth, torch.inf), kernel_size, stride=1, padding=kernel_size // 2)
        output_mask = pooled_depth < depth * (1 - rtol)
    else:
        raise ValueError(f'Unsupported pooler: {pooler}')
    output_mask = output_mask.reshape(*batch_shape, height, width)
    return output_mask


def depth_occlusion_edge(depth: torch.FloatTensor, mask: torch.BoolTensor, kernel_size: int = 3, tol: float = 0.1):
    device, dtype = depth.device, depth.dtype

    disp = torch.where(mask, 1 / depth, 0)
    disp_pad = F.pad(disp, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=0)
    mask_pad = F.pad(mask, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=False)
    disp_window = utils3d.torch.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)).flatten(-2)  # [..., H, W, kernel_size ** 2]
    mask_window = utils3d.torch.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)).flatten(-2)  # [..., H, W, kernel_size ** 2]

    x = torch.linspace(-kernel_size // 2, kernel_size // 2, kernel_size, device=device, dtype=dtype)
    A = torch.stack([*torch.meshgrid(x, x, indexing='xy'), torch.ones((kernel_size, kernel_size), device=device, dtype=dtype)], dim=-1).reshape(kernel_size ** 2, 3)  # [kernel_size ** 2, 3]
    A = mask_window[..., None] * A
    I = torch.eye(3, device=device, dtype=dtype)

    affine_disp_window = (disp_window[..., None, :] @ A @ torch.inverse(A.mT @ A + 1e-5 * I) @ A.mT).clamp_min(1e-12)[..., 0, :]  # [..., H, W, kernel_size ** 2]
    diff = torch.where(mask_window, torch.maximum(affine_disp_window, disp_window) / torch.minimum(affine_disp_window, disp_window) - 1, 0)

    edge_mask = mask & (diff > tol).any(dim=-1)

    disp_mean = weighted_mean(disp_window, mask_window, dim=-1)
    fg_edge_mask = edge_mask & (disp > disp_mean)
    # fg_edge_mask = edge_mask & theshold_depth_change(depth, mask, pooler='max', rtol=tol, kernel_size=kernel_size)
    bg_edge_mask = edge_mask & ~fg_edge_mask
    return fg_edge_mask, bg_edge_mask


def depth_occlusion_edge(depth: torch.FloatTensor, mask: torch.BoolTensor, kernel_size: int = 3, tol: float = 0.1):
    device, dtype = depth.device, depth.dtype

    disp = torch.where(mask, 1 / depth, 0)
    disp_pad = F.pad(disp, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=0)
    mask_pad = F.pad(mask, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=False)
    disp_window = utils3d.torch.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, dim=(-2, -1))  # [..., H, W, kernel_size ** 2]
    mask_window = utils3d.torch.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, dim=(-2, -1))  # [..., H, W, kernel_size ** 2]

    disp_mean = weighted_mean(disp_window, mask_window, dim=(-2, -1))
    fg_edge_mask = mask & (disp / disp_mean > 1 + tol)
    bg_edge_mask = mask & (disp_mean / disp > 1 + tol)

    fg_edge_mask = fg_edge_mask & F.max_pool2d(bg_edge_mask.float(), kernel_size + 2, stride=1, padding=kernel_size // 2 + 1).bool()
    bg_edge_mask = bg_edge_mask & F.max_pool2d(fg_edge_mask.float(), kernel_size + 2, stride=1, padding=kernel_size // 2 + 1).bool()

    return fg_edge_mask, bg_edge_mask


def dilate_with_mask(input: torch.Tensor, mask: torch.BoolTensor, filter: Literal['min', 'max', 'mean', 'median'] = 'mean', iterations: int = 1) -> torch.Tensor:
    kernel = torch.tensor([[False, True, False], [True, True, True], [False, True, False]], device=input.device, dtype=torch.bool)
    for _ in range(iterations):
        input_window = utils3d.torch.sliding_window_2d(F.pad(input, (1, 1, 1, 1), mode='constant', value=0), window_size=3, stride=1, dim=(-2, -1))
        mask_window = kernel & utils3d.torch.sliding_window_2d(F.pad(mask, (1, 1, 1, 1), mode='constant', value=False), window_size=3, stride=1, dim=(-2, -1))
        if filter =='min':
            input = torch.where(mask, input, torch.where(mask_window, input_window, torch.inf).min(dim=(-2, -1)).values)
        elif filter =='max':
            input = torch.where(mask, input, torch.where(mask_window, input_window, -torch.inf).max(dim=(-2, -1)).values)
        elif filter == 'mean':
            input = torch.where(mask, input, torch.where(mask_window, input_window, torch.nan).nanmean(dim=(-2, -1)))
        elif filter =='median':
            input = torch.where(mask, input, torch.where(mask_window, input_window, torch.nan).flatten(-2).nanmedian(dim=-1).values)
        mask = mask_window.any(dim=(-2, -1))
    return input, mask
models/SpaTrackV2/models/tracker3D/spatrack_modules/pointmap_updator.py
ADDED
|
@@ -0,0 +1,104 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from models.SpaTrackV2.models.blocks import bilinear_sampler
|
| 4 |
+
from models.SpaTrackV2.models.tracker3D.spatrack_modules.alignment import align_points_scale, align_points_scale_xyz_shift
|
| 5 |
+
|
| 6 |
+
def compute_affine_scale_and_shift(points, pointmap, mask, weights=None, eps=1e-6):
|
| 7 |
+
"""
|
| 8 |
+
Compute global affine transform (scale * pointmap + shift = points)
|
| 9 |
+
using least-squares fitting with optional weights and mask.
|
| 10 |
+
|
| 11 |
+
Args:
|
| 12 |
+
points (BT, N, 3): Target points
|
| 13 |
+
pointmap (BT, N, 3): Source points
|
| 14 |
+
mask (BT, N): Binary mask indicating valid points
|
| 15 |
+
weights (BT, N): Optional weights per point
|
| 16 |
+
eps (float): Numerical stability
|
| 17 |
+
|
| 18 |
+
Returns:
|
| 19 |
+
scale (BT, 1): Scalar scale per batch
|
| 20 |
+
shift (BT, 3): Shift vector per batch
|
| 21 |
+
"""
|
| 22 |
+
if weights is None:
|
| 23 |
+
weights = mask.float()
|
| 24 |
+
else:
|
| 25 |
+
weights = weights * mask # combine mask
|
| 26 |
+
|
| 27 |
+
# Sum of weights
|
| 28 |
+
weight_sum = weights.sum(dim=1, keepdim=True) + eps # (BT, 1)
|
| 29 |
+
|
| 30 |
+
# Compute weighted centroids
|
| 31 |
+
centroid_p = (points * weights.unsqueeze(-1)).sum(dim=1) / weight_sum # (BT, 3)
|
| 32 |
+
centroid_m = (pointmap * weights.unsqueeze(-1)).sum(dim=1) / weight_sum # (BT, 3)
|
| 33 |
+
|
| 34 |
+
# Center the point sets
|
| 35 |
+
p_centered = points - centroid_p.unsqueeze(1) # (BT, N, 3)
|
| 36 |
+
m_centered = pointmap - centroid_m.unsqueeze(1) # (BT, N, 3)
|
| 37 |
+
|
| 38 |
+
# Compute scale: ratio of dot products
|
| 39 |
+
numerator = (weights.unsqueeze(-1) * (p_centered * m_centered)).sum(dim=1).sum(dim=-1) # (BT,)
|
| 40 |
+
denominator = (weights.unsqueeze(-1) * (m_centered ** 2)).sum(dim=1).sum(dim=-1) + eps # (BT,)
|
| 41 |
+
scale = (numerator / denominator).unsqueeze(-1) # (BT, 1)
|
| 42 |
+
|
| 43 |
+
# Compute shift: t = c_p - s * c_m
|
| 44 |
+
shift = centroid_p - scale * centroid_m # (BT, 3)
|
| 45 |
+
|
| 46 |
+
return scale, shift
|
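A minimal sketch (not part of the commit) of the least-squares fit described in the docstring: a known scale and shift applied to a random point map should be recovered. The shapes follow the (BT, N, 3) convention documented above.
import torch
BT, N = 2, 100
pointmap = torch.randn(BT, N, 3)
points = 2.5 * pointmap + torch.tensor([0.1, -0.2, 0.3])   # ground truth: s = 2.5
mask = torch.ones(BT, N)
scale, shift = compute_affine_scale_and_shift(points, pointmap, mask)
# scale is (BT, 1) and should be close to 2.5; shift is (BT, 3)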
| 47 |
+
|
| 48 |
+
def compute_weighted_std(track2d, vis_est, eps=1e-6):
|
| 49 |
+
"""
|
| 50 |
+
Compute the weighted standard deviation of 2D tracks across time.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
track2d (Tensor): shape (B, T, N, 2), 2D tracked points.
|
| 54 |
+
vis_est (Tensor): shape (B, T, N), visibility weights (0~1).
|
| 55 |
+
eps (float): small epsilon to avoid division by zero.
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
std (Tensor): shape (B, N, 2), weighted standard deviation for each point.
|
| 59 |
+
"""
|
| 60 |
+
B, T, N, _ = track2d.shape
|
| 61 |
+
|
| 62 |
+
# Compute weighted mean
|
| 63 |
+
weighted_sum = (track2d * vis_est[..., None]).sum(dim=1) # (B, N, 2)
|
| 64 |
+
weight_sum = vis_est.sum(dim=1)[..., None] + eps # (B, N, 1)
|
| 65 |
+
track_mean = weighted_sum / weight_sum # (B, N, 2)
|
| 66 |
+
|
| 67 |
+
# Compute squared residuals
|
| 68 |
+
residuals = track2d - track_mean[:, None, :, :] # (B, T, N, 2)
|
| 69 |
+
weighted_sq_res = (residuals ** 2) * vis_est[..., None] # (B, T, N, 2)
|
| 70 |
+
|
| 71 |
+
# Compute weighted variance and std
|
| 72 |
+
var = weighted_sq_res.sum(dim=1) / (weight_sum + eps) # (B, N, 2)
|
| 73 |
+
std = var.sqrt() # (B, N, 2)
|
| 74 |
+
|
| 75 |
+
return std
|
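A short usage sketch with assumed toy shapes (not part of the commit): the per-track spread of 2D positions, weighted by visibility, which is later used to normalise the reprojection loss.
import torch
B, T, N = 1, 8, 16
track2d = torch.randn(B, T, N, 2)
vis_est = torch.rand(B, T, N)                 # soft visibility in [0, 1]
std = compute_weighted_std(track2d, vis_est)  # (B, N, 2)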
| 76 |
+
|
| 77 |
+
class PointMapUpdator(nn.Module):
|
| 78 |
+
def __init__(self, stablizer):
|
| 79 |
+
super(PointMapUpdator, self).__init__()
|
| 80 |
+
self.stablizer = stablizer()
|
| 81 |
+
|
| 82 |
+
def init_pointmap(self, points_map):
|
| 83 |
+
|
| 84 |
+
pass
|
| 85 |
+
|
| 86 |
+
def scale_update_from_tracks(self, cam_pts_est, coords_append, point_map_org, vis_est, reproj_loss):
|
| 87 |
+
B, T, N, _ = coords_append.shape
|
| 88 |
+
track2d = coords_append[...,:2].view(B*T, N, 2)
|
| 89 |
+
|
| 90 |
+
track_len_std = compute_weighted_std(track2d.view(B, T, N, 2), vis_est.view(B, T, N)).norm(dim=-1)
|
| 91 |
+
|
| 92 |
+
point_samp = bilinear_sampler(point_map_org, track2d[:,None], mode="nearest")
|
| 93 |
+
point_samp = point_samp.permute(0,3,1,2).view(B*T, N, 3)
|
| 94 |
+
cam_pts_est = cam_pts_est.view(B*T, N, 3)
|
| 95 |
+
# mask
|
| 96 |
+
mask = vis_est.view(B*T, N)
|
| 97 |
+
# using gaussian weights, mean is 2 pixels
|
| 98 |
+
nm_reproj_loss = (reproj_loss.view(B*T, N) / (track_len_std.view(B, N) + 1e-6)).clamp(0, 5)
|
| 99 |
+
std = nm_reproj_loss.std(dim=-1).view(B*T, 1) # B*T 1
|
| 100 |
+
weights = torch.exp(-(0.5-nm_reproj_loss.view(B*T, N))**2 / (2*std**2))
|
| 101 |
+
mask = mask*(point_samp[...,2]>0)*(cam_pts_est[...,2]>0)*weights
|
| 102 |
+
scales, shift = align_points_scale_xyz_shift(point_samp, cam_pts_est, mask)
|
| 103 |
+
|
| 104 |
+
return scales, shift
|
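For orientation, a hedged sketch of how the scale and shift returned by scale_update_from_tracks could be applied to rescale a point map. The shapes below are assumptions for illustration only; the exact alignment convention is defined by align_points_scale_xyz_shift.
import torch
scales = torch.rand(4, 1)               # assumed (B*T, 1)
shift = torch.randn(4, 3)               # assumed (B*T, 3)
point_map = torch.randn(4, 1000, 3)     # (B*T, N, 3)
point_map_aligned = scales[:, None] * point_map + shift[:, None, :]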
models/SpaTrackV2/models/tracker3D/spatrack_modules/simple_vit_1d.py
ADDED
|
@@ -0,0 +1,125 @@
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
from einops.layers.torch import Rearrange
|
| 6 |
+
|
| 7 |
+
# helpers
|
| 8 |
+
|
| 9 |
+
def posemb_sincos_1d(patches, temperature = 10000, dtype = torch.float32):
|
| 10 |
+
_, n, dim, device, dtype = *patches.shape, patches.device, patches.dtype
|
| 11 |
+
|
| 12 |
+
n = torch.arange(n, device = device)
|
| 13 |
+
assert (dim % 2) == 0, 'feature dimension must be multiple of 2 for sincos emb'
|
| 14 |
+
omega = torch.arange(dim // 2, device = device) / (dim // 2 - 1)
|
| 15 |
+
omega = 1. / (temperature ** omega)
|
| 16 |
+
|
| 17 |
+
n = n.flatten()[:, None] * omega[None, :]
|
| 18 |
+
pe = torch.cat((n.sin(), n.cos()), dim = 1)
|
| 19 |
+
return pe.type(dtype)
|
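A quick sketch (not part of the commit) showing that the embedding depends only on the token count and the (even) feature dimension of its input.
import torch
patches = torch.zeros(2, 16, 128)   # (batch, n, dim); dim must be even
pe = posemb_sincos_1d(patches)      # (16, 128); added to the tokens in SimpleViT.forward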
| 20 |
+
|
| 21 |
+
# classes
|
| 22 |
+
|
| 23 |
+
class FeedForward(nn.Module):
|
| 24 |
+
def __init__(self, dim, hidden_dim):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.net = nn.Sequential(
|
| 27 |
+
nn.LayerNorm(dim),
|
| 28 |
+
nn.Linear(dim, hidden_dim),
|
| 29 |
+
nn.GELU(),
|
| 30 |
+
nn.Linear(hidden_dim, dim),
|
| 31 |
+
)
|
| 32 |
+
def forward(self, x):
|
| 33 |
+
return self.net(x)
|
| 34 |
+
|
| 35 |
+
class Attention(nn.Module):
|
| 36 |
+
def __init__(self, dim, heads = 8, dim_head = 64):
|
| 37 |
+
super().__init__()
|
| 38 |
+
inner_dim = dim_head * heads
|
| 39 |
+
self.heads = heads
|
| 40 |
+
self.scale = dim_head ** -0.5
|
| 41 |
+
self.norm = nn.LayerNorm(dim)
|
| 42 |
+
|
| 43 |
+
self.attend = nn.Softmax(dim = -1)
|
| 44 |
+
|
| 45 |
+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
| 46 |
+
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
| 47 |
+
|
| 48 |
+
def forward(self, x):
|
| 49 |
+
x = self.norm(x)
|
| 50 |
+
|
| 51 |
+
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
| 52 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
| 53 |
+
|
| 54 |
+
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
| 55 |
+
|
| 56 |
+
attn = self.attend(dots)
|
| 57 |
+
|
| 58 |
+
out = torch.matmul(attn, v)
|
| 59 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
| 60 |
+
return self.to_out(out)
|
| 61 |
+
|
| 62 |
+
class Transformer(nn.Module):
|
| 63 |
+
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
|
| 64 |
+
super().__init__()
|
| 65 |
+
self.norm = nn.LayerNorm(dim)
|
| 66 |
+
self.layers = nn.ModuleList([])
|
| 67 |
+
for _ in range(depth):
|
| 68 |
+
self.layers.append(nn.ModuleList([
|
| 69 |
+
Attention(dim, heads = heads, dim_head = dim_head),
|
| 70 |
+
FeedForward(dim, mlp_dim)
|
| 71 |
+
]))
|
| 72 |
+
def forward(self, x):
|
| 73 |
+
for attn, ff in self.layers:
|
| 74 |
+
x = attn(x) + x
|
| 75 |
+
x = ff(x) + x
|
| 76 |
+
return self.norm(x)
|
| 77 |
+
|
| 78 |
+
class SimpleViT(nn.Module):
|
| 79 |
+
def __init__(self, *, seq_len, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
|
| 80 |
+
super().__init__()
|
| 81 |
+
|
| 82 |
+
assert seq_len % patch_size == 0
|
| 83 |
+
|
| 84 |
+
num_patches = seq_len // patch_size
|
| 85 |
+
patch_dim = channels * patch_size
|
| 86 |
+
|
| 87 |
+
self.to_patch_embedding = nn.Sequential(
|
| 88 |
+
Rearrange('b c (n p) -> b n (p c)', p = patch_size),
|
| 89 |
+
nn.LayerNorm(patch_dim),
|
| 90 |
+
nn.Linear(patch_dim, dim),
|
| 91 |
+
nn.LayerNorm(dim),
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
| 95 |
+
|
| 96 |
+
self.to_latent = nn.Identity()
|
| 97 |
+
self.linear_head = nn.Linear(dim, num_classes)
|
| 98 |
+
|
| 99 |
+
def forward(self, series):
|
| 100 |
+
*_, n, dtype = *series.shape, series.dtype
|
| 101 |
+
|
| 102 |
+
x = self.to_patch_embedding(series)
|
| 103 |
+
pe = posemb_sincos_1d(x)
|
| 104 |
+
x = rearrange(x, 'b ... d -> b (...) d') + pe
|
| 105 |
+
|
| 106 |
+
x = self.transformer(x)
|
| 107 |
+
x = x.mean(dim = 1)
|
| 108 |
+
|
| 109 |
+
x = self.to_latent(x)
|
| 110 |
+
return self.linear_head(x)
|
| 111 |
+
|
| 112 |
+
if __name__ == '__main__':
|
| 113 |
+
|
| 114 |
+
v = SimpleViT(
|
| 115 |
+
seq_len = 256,
|
| 116 |
+
patch_size = 16,
|
| 117 |
+
num_classes = 1000,
|
| 118 |
+
dim = 1024,
|
| 119 |
+
depth = 6,
|
| 120 |
+
heads = 8,
|
| 121 |
+
mlp_dim = 2048
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
time_series = torch.randn(4, 3, 256)
|
| 125 |
+
logits = v(time_series) # (4, 1000)
|
models/SpaTrackV2/models/tracker3D/spatrack_modules/tools.py
ADDED
|
@@ -0,0 +1,289 @@
|
| 1 |
+
from typing import *
|
| 2 |
+
import time
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from numbers import Number
|
| 5 |
+
from functools import wraps
|
| 6 |
+
import warnings
|
| 7 |
+
import math
|
| 8 |
+
import json
|
| 9 |
+
import os
|
| 10 |
+
import importlib
|
| 11 |
+
import importlib.util
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def catch_exception(fn):
|
| 15 |
+
@wraps(fn)
|
| 16 |
+
def wrapper(*args, **kwargs):
|
| 17 |
+
try:
|
| 18 |
+
return fn(*args, **kwargs)
|
| 19 |
+
except Exception as e:
|
| 20 |
+
import traceback
|
| 21 |
+
print(f"Exception in {fn.__name__}", end='r')
|
| 22 |
+
# print({', '.join(repr(arg) for arg in args)}, {', '.join(f'{k}={v!r}' for k, v in kwargs.items())})
|
| 23 |
+
traceback.print_exc(chain=False)
|
| 24 |
+
time.sleep(0.1)
|
| 25 |
+
return None
|
| 26 |
+
return wrapper
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class CallbackOnException:
|
| 30 |
+
def __init__(self, callback: Callable, exception: type):
|
| 31 |
+
self.exception = exception
|
| 32 |
+
self.callback = callback
|
| 33 |
+
|
| 34 |
+
def __enter__(self):
|
| 35 |
+
return self
|
| 36 |
+
|
| 37 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 38 |
+
if isinstance(exc_val, self.exception):
|
| 39 |
+
self.callback()
|
| 40 |
+
return True
|
| 41 |
+
return False
|
| 42 |
+
|
| 43 |
+
def traverse_nested_dict_keys(d: Dict[str, Dict]) -> Generator[Tuple[str, ...], None, None]:
|
| 44 |
+
for k, v in d.items():
|
| 45 |
+
if isinstance(v, dict):
|
| 46 |
+
for sub_key in traverse_nested_dict_keys(v):
|
| 47 |
+
yield (k, ) + sub_key
|
| 48 |
+
else:
|
| 49 |
+
yield (k, )
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], default: Any = None):
|
| 53 |
+
for k in keys:
|
| 54 |
+
d = d.get(k, default)
|
| 55 |
+
if d is None:
|
| 56 |
+
break
|
| 57 |
+
return d
|
| 58 |
+
|
| 59 |
+
def set_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], value: Any):
|
| 60 |
+
for k in keys[:-1]:
|
| 61 |
+
d = d.setdefault(k, {})
|
| 62 |
+
d[keys[-1]] = value
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def key_average(list_of_dicts: list) -> Dict[str, Any]:
|
| 66 |
+
"""
|
| 67 |
+
Returns a dictionary with the average value of each key in the input list of dictionaries.
|
| 68 |
+
"""
|
| 69 |
+
_nested_dict_keys = set()
|
| 70 |
+
for d in list_of_dicts:
|
| 71 |
+
_nested_dict_keys.update(traverse_nested_dict_keys(d))
|
| 72 |
+
_nested_dict_keys = sorted(_nested_dict_keys)
|
| 73 |
+
result = {}
|
| 74 |
+
for k in _nested_dict_keys:
|
| 75 |
+
values = []
|
| 76 |
+
for d in list_of_dicts:
|
| 77 |
+
v = get_nested_dict(d, k)
|
| 78 |
+
if v is not None and not math.isnan(v):
|
| 79 |
+
values.append(v)
|
| 80 |
+
avg = sum(values) / len(values) if values else float('nan')
|
| 81 |
+
set_nested_dict(result, k, avg)
|
| 82 |
+
return result
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def flatten_nested_dict(d: Dict[str, Any], parent_key: Tuple[str, ...] = None) -> Dict[Tuple[str, ...], Any]:
|
| 86 |
+
"""
|
| 87 |
+
Flattens a nested dictionary into a single-level dictionary, with keys as tuples.
|
| 88 |
+
"""
|
| 89 |
+
items = []
|
| 90 |
+
if parent_key is None:
|
| 91 |
+
parent_key = ()
|
| 92 |
+
for k, v in d.items():
|
| 93 |
+
new_key = parent_key + (k, )
|
| 94 |
+
if isinstance(v, MutableMapping):
|
| 95 |
+
items.extend(flatten_nested_dict(v, new_key).items())
|
| 96 |
+
else:
|
| 97 |
+
items.append((new_key, v))
|
| 98 |
+
return dict(items)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]:
|
| 102 |
+
"""
|
| 103 |
+
Unflattens a single-level dictionary into a nested dictionary, with keys as tuples.
|
| 104 |
+
"""
|
| 105 |
+
result = {}
|
| 106 |
+
for k, v in d.items():
|
| 107 |
+
sub_dict = result
|
| 108 |
+
for k_ in k[:-1]:
|
| 109 |
+
if k_ not in sub_dict:
|
| 110 |
+
sub_dict[k_] = {}
|
| 111 |
+
sub_dict = sub_dict[k_]
|
| 112 |
+
sub_dict[k[-1]] = v
|
| 113 |
+
return result
|
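A small round-trip example (not part of the commit) for the two helpers above:
nested = {'a': {'b': 1, 'c': 2}, 'd': 3}
flat = flatten_nested_dict(nested)   # {('a', 'b'): 1, ('a', 'c'): 2, ('d',): 3}
assert unflatten_nested_dict(flat) == nested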
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def read_jsonl(file):
|
| 117 |
+
import json
|
| 118 |
+
with open(file, 'r') as f:
|
| 119 |
+
data = f.readlines()
|
| 120 |
+
return [json.loads(line) for line in data]
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def write_jsonl(data: List[dict], file):
|
| 124 |
+
import json
|
| 125 |
+
with open(file, 'w') as f:
|
| 126 |
+
for item in data:
|
| 127 |
+
f.write(json.dumps(item) + '\n')
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def to_hierachical_dataframe(data: List[Dict[Tuple[str, ...], Any]]):
|
| 131 |
+
import pandas as pd
|
| 132 |
+
data = [flatten_nested_dict(d) for d in data]
|
| 133 |
+
df = pd.DataFrame(data)
|
| 134 |
+
df = df.sort_index(axis=1)
|
| 135 |
+
df.columns = pd.MultiIndex.from_tuples(df.columns)
|
| 136 |
+
return df
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def recursive_replace(d: Union[List, Dict, str], mapping: Dict[str, str]):
|
| 140 |
+
if isinstance(d, str):
|
| 141 |
+
for old, new in mapping.items():
|
| 142 |
+
d = d.replace(old, new)
|
| 143 |
+
elif isinstance(d, list):
|
| 144 |
+
for i, item in enumerate(d):
|
| 145 |
+
d[i] = recursive_replace(item, mapping)
|
| 146 |
+
elif isinstance(d, dict):
|
| 147 |
+
for k, v in d.items():
|
| 148 |
+
d[k] = recursive_replace(v, mapping)
|
| 149 |
+
return d
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class timeit:
|
| 153 |
+
_history: Dict[str, List['timeit']] = {}
|
| 154 |
+
|
| 155 |
+
def __init__(self, name: str = None, verbose: bool = True, average: bool = False):
|
| 156 |
+
self.name = name
|
| 157 |
+
self.verbose = verbose
|
| 158 |
+
self.start = None
|
| 159 |
+
self.end = None
|
| 160 |
+
self.average = average
|
| 161 |
+
if average and name not in timeit._history:
|
| 162 |
+
timeit._history[name] = []
|
| 163 |
+
|
| 164 |
+
def __call__(self, func: Callable):
|
| 165 |
+
import inspect
|
| 166 |
+
if inspect.iscoroutinefunction(func):
|
| 167 |
+
async def wrapper(*args, **kwargs):
|
| 168 |
+
with timeit(self.name or func.__qualname__):
|
| 169 |
+
ret = await func(*args, **kwargs)
|
| 170 |
+
return ret
|
| 171 |
+
return wrapper
|
| 172 |
+
else:
|
| 173 |
+
def wrapper(*args, **kwargs):
|
| 174 |
+
with timeit(self.name or func.__qualname__):
|
| 175 |
+
ret = func(*args, **kwargs)
|
| 176 |
+
return ret
|
| 177 |
+
return wrapper
|
| 178 |
+
|
| 179 |
+
def __enter__(self):
|
| 180 |
+
self.start = time.time()
|
| 181 |
+
return self
|
| 182 |
+
|
| 183 |
+
@property
|
| 184 |
+
def time(self) -> float:
|
| 185 |
+
assert self.start is not None, "Time not yet started."
|
| 186 |
+
assert self.end is not None, "Time not yet ended."
|
| 187 |
+
return self.end - self.start
|
| 188 |
+
|
| 189 |
+
@property
|
| 190 |
+
def average_time(self) -> float:
|
| 191 |
+
assert self.average, "Average time not available."
|
| 192 |
+
return sum(t.time for t in timeit._history[self.name]) / len(timeit._history[self.name])
|
| 193 |
+
|
| 194 |
+
@property
|
| 195 |
+
def history(self) -> List['timeit']:
|
| 196 |
+
return timeit._history.get(self.name, [])
|
| 197 |
+
|
| 198 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 199 |
+
self.end = time.time()
|
| 200 |
+
if self.average:
|
| 201 |
+
timeit._history[self.name].append(self)
|
| 202 |
+
if self.verbose:
|
| 203 |
+
if self.average:
|
| 204 |
+
avg = self.average_time
|
| 205 |
+
print(f"{self.name or 'It'} took {avg:.6f} seconds in average.")
|
| 206 |
+
else:
|
| 207 |
+
print(f"{self.name or 'It'} took {self.time:.6f} seconds.")
|
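A usage sketch (not part of the commit): timeit works both as a context manager and as a decorator; with average=True, repeated runs under the same name are pooled.
with timeit("toy block"):
    sum(range(1000))            # prints "toy block took ... seconds."

@timeit("toy function")
def f():
    return sum(range(1000))

f()                             # prints "toy function took ... seconds."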
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def strip_common_prefix_suffix(strings: List[str]) -> List[str]:
|
| 211 |
+
first = strings[0]
|
| 212 |
+
|
| 213 |
+
for start in range(len(first)):
|
| 214 |
+
if any(s[start] != strings[0][start] for s in strings):
|
| 215 |
+
break
|
| 216 |
+
|
| 217 |
+
for end in range(1, min(len(s) for s in strings)):
|
| 218 |
+
if any(s[-end] != first[-end] for s in strings):
|
| 219 |
+
break
|
| 220 |
+
|
| 221 |
+
return [s[start:len(s) - end + 1] for s in strings]
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def multithead_execute(inputs: List[Any], num_workers: int, pbar = None):
|
| 225 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 226 |
+
from contextlib import nullcontext
|
| 227 |
+
from tqdm import tqdm
|
| 228 |
+
|
| 229 |
+
if pbar is not None:
|
| 230 |
+
pbar.total = len(inputs) if hasattr(inputs, '__len__') else None
|
| 231 |
+
else:
|
| 232 |
+
pbar = tqdm(total=len(inputs) if hasattr(inputs, '__len__') else None)
|
| 233 |
+
|
| 234 |
+
def decorator(fn: Callable):
|
| 235 |
+
with (
|
| 236 |
+
ThreadPoolExecutor(max_workers=num_workers) as executor,
|
| 237 |
+
pbar
|
| 238 |
+
):
|
| 239 |
+
pbar.refresh()
|
| 240 |
+
@catch_exception
|
| 241 |
+
@suppress_traceback
|
| 242 |
+
def _fn(input):
|
| 243 |
+
ret = fn(input)
|
| 244 |
+
pbar.update()
|
| 245 |
+
return ret
|
| 246 |
+
executor.map(_fn, inputs)
|
| 247 |
+
executor.shutdown(wait=True)
|
| 248 |
+
|
| 249 |
+
return decorator
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def suppress_traceback(fn):
|
| 253 |
+
@wraps(fn)
|
| 254 |
+
def wrapper(*args, **kwargs):
|
| 255 |
+
try:
|
| 256 |
+
return fn(*args, **kwargs)
|
| 257 |
+
except Exception as e:
|
| 258 |
+
e.__traceback__ = e.__traceback__.tb_next.tb_next
|
| 259 |
+
raise
|
| 260 |
+
return wrapper
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class no_warnings:
|
| 264 |
+
def __init__(self, action: str = 'ignore', **kwargs):
|
| 265 |
+
self.action = action
|
| 266 |
+
self.filter_kwargs = kwargs
|
| 267 |
+
|
| 268 |
+
def __call__(self, fn):
|
| 269 |
+
@wraps(fn)
|
| 270 |
+
def wrapper(*args, **kwargs):
|
| 271 |
+
with warnings.catch_warnings():
|
| 272 |
+
warnings.simplefilter(self.action, **self.filter_kwargs)
|
| 273 |
+
return fn(*args, **kwargs)
|
| 274 |
+
return wrapper
|
| 275 |
+
|
| 276 |
+
def __enter__(self):
|
| 277 |
+
self.warnings_manager = warnings.catch_warnings()
|
| 278 |
+
self.warnings_manager.__enter__()
|
| 279 |
+
warnings.simplefilter(self.action, **self.filter_kwargs)
|
| 280 |
+
|
| 281 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 282 |
+
self.warnings_manager.__exit__(exc_type, exc_val, exc_tb)
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def import_file_as_module(file_path: Union[str, os.PathLike], module_name: str):
|
| 286 |
+
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
| 287 |
+
module = importlib.util.module_from_spec(spec)
|
| 288 |
+
spec.loader.exec_module(module)
|
| 289 |
+
return module
|
models/SpaTrackV2/models/tracker3D/spatrack_modules/utils.py
ADDED
|
@@ -0,0 +1,1006 @@
|
| 1 |
+
import os, sys
|
| 2 |
+
import torch
|
| 3 |
+
import torch.amp
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from models.SpaTrackV2.models.tracker3D.co_tracker.utils import (
|
| 7 |
+
EfficientUpdateFormer, AttnBlock, Attention, CrossAttnBlock,
|
| 8 |
+
sequence_BCE_loss, sequence_loss, sequence_prob_loss, sequence_dyn_prob_loss
|
| 9 |
+
)
|
| 10 |
+
import math
|
| 11 |
+
from models.SpaTrackV2.models.tracker3D.co_tracker.utils import (
|
| 12 |
+
Mlp, BasicEncoder, EfficientUpdateFormer, GeometryEncoder, NeighborTransformer
|
| 13 |
+
)
|
| 14 |
+
import numpy as np
|
| 15 |
+
from models.SpaTrackV2.models.tracker3D.spatrack_modules.simple_vit_1d import Transformer,posemb_sincos_1d
|
| 16 |
+
from einops import rearrange
|
| 17 |
+
|
| 18 |
+
def self_grid_pos_embedding(B, T, H, W, level=None):
|
| 19 |
+
import pdb; pdb.set_trace()
|
| 20 |
+
|
| 21 |
+
def random_se3_transformation(
|
| 22 |
+
batch_size: int = 1,
|
| 23 |
+
max_rotation_angle: float = math.pi,
|
| 24 |
+
max_translation: float = 1.0,
|
| 25 |
+
device: str = "cpu",
|
| 26 |
+
dtype: torch.dtype = torch.float32,
|
| 27 |
+
) -> torch.Tensor:
|
| 28 |
+
"""
|
| 29 |
+
Randomly generate rigid-body transformation matrices (SE(3) transformation matrices).
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
batch_size (int): Batch size, default 1.
|
| 33 |
+
max_rotation_angle (float): Maximum rotation angle in radians, default π (180°).
|
| 34 |
+
max_translation (float): Maximum translation magnitude, default 1.0.
|
| 35 |
+
device (str): Device ('cpu' or 'cuda').
|
| 36 |
+
dtype (torch.dtype): Data type (float32 recommended).
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
torch.Tensor: Homogeneous transformation matrices of shape (batch_size, 4, 4).
|
| 40 |
+
"""
|
| 41 |
+
# Randomly generate rotation matrices R (batch_size, 3, 3)
|
| 42 |
+
# Method 1: use an axis-angle representation and convert it to a rotation matrix
|
| 43 |
+
axis = torch.randn(batch_size, 3, device=device, dtype=dtype) # random rotation axis
|
| 44 |
+
axis = axis / torch.norm(axis, dim=1, keepdim=True) # normalize
|
| 45 |
+
angle = torch.rand(batch_size, 1, device=device, dtype=dtype) * max_rotation_angle # random angle in [0, max_angle]
|
| 46 |
+
|
| 47 |
+
# Compute the rotation matrix (Rodrigues' rotation formula)
|
| 48 |
+
K = torch.zeros(batch_size, 3, 3, device=device, dtype=dtype)
|
| 49 |
+
K[:, 0, 1] = -axis[:, 2]
|
| 50 |
+
K[:, 0, 2] = axis[:, 1]
|
| 51 |
+
K[:, 1, 0] = axis[:, 2]
|
| 52 |
+
K[:, 1, 2] = -axis[:, 0]
|
| 53 |
+
K[:, 2, 0] = -axis[:, 1]
|
| 54 |
+
K[:, 2, 1] = axis[:, 0]
|
| 55 |
+
|
| 56 |
+
I = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand(batch_size, -1, -1)
|
| 57 |
+
R = I + torch.sin(angle).unsqueeze(-1) * K + (1 - torch.cos(angle).unsqueeze(-1)) * (K @ K)
|
| 58 |
+
|
| 59 |
+
# Randomly generate translation vectors t (batch_size, 3)
|
| 60 |
+
t = (torch.rand(batch_size, 3, device=device, dtype=dtype) - 0.5) * 2 * max_translation
|
| 61 |
+
|
| 62 |
+
# Assemble the homogeneous transformation matrices T (batch_size, 4, 4)
|
| 63 |
+
T = torch.eye(4, device=device, dtype=dtype).unsqueeze(0).expand(batch_size, -1, -1)
|
| 64 |
+
T[:, :3, :3] = R
|
| 65 |
+
T[:, :3, 3] = t
|
| 66 |
+
|
| 67 |
+
return T
|
| 68 |
+
|
| 69 |
+
def weighted_procrustes_torch(X, Y, W=None, RT=None):
|
| 70 |
+
"""
|
| 71 |
+
Weighted Procrustes Analysis in PyTorch (batched).
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
X: (B, 1, N, 3), source point cloud.
|
| 75 |
+
Y: (B, T, N, 3), target point cloud.
|
| 76 |
+
W: (B, T, N) or (B, 1, N), optional weights for each point.
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
t: (B, T, 3), optimal translation vectors.
|
| 80 |
+
R: (B, T, 3, 3), optimal rotation matrices.
|
| 81 |
+
"""
|
| 82 |
+
device = X.device
|
| 83 |
+
B, T, N, _ = Y.shape
|
| 84 |
+
|
| 85 |
+
# Default weights: uniform
|
| 86 |
+
if W is None:
|
| 87 |
+
W = torch.ones(B, 1, N, device=device)
|
| 88 |
+
elif W.dim() == 3: # (B, T, N) -> expand to match Y
|
| 89 |
+
W = W.unsqueeze(-1) # (B, T, N, 1)
|
| 90 |
+
else: # (B, 1, N)
|
| 91 |
+
W = W.unsqueeze(-1).expand(B, T, N, 1)
|
| 92 |
+
|
| 93 |
+
# Reshape X to (B, T, N, 3) by broadcasting
|
| 94 |
+
X = X.expand(B, T, N, 3)
|
| 95 |
+
|
| 96 |
+
# Compute weighted centroids
|
| 97 |
+
sum_W = torch.sum(W, dim=2, keepdim=True) # (B, T, 1, 1)
|
| 98 |
+
centroid_X = torch.sum(W * X, dim=2) / sum_W.squeeze(-1) # (B, T, 3)
|
| 99 |
+
centroid_Y = torch.sum(W * Y, dim=2) / sum_W.squeeze(-1) # (B, T, 3)
|
| 100 |
+
|
| 101 |
+
# Center the point clouds
|
| 102 |
+
X_centered = X - centroid_X.unsqueeze(2) # (B, T, N, 3)
|
| 103 |
+
Y_centered = Y - centroid_Y.unsqueeze(2) # (B, T, N, 3)
|
| 104 |
+
|
| 105 |
+
# Compute weighted covariance matrix H = X^T W Y
|
| 106 |
+
X_weighted = X_centered * W # (B, T, N, 3)
|
| 107 |
+
H = torch.matmul(X_weighted.transpose(2, 3), Y_centered) # (B, T, 3, 3)
|
| 108 |
+
|
| 109 |
+
# SVD decomposition
|
| 110 |
+
U, S, Vt = torch.linalg.svd(H) # U/Vt: (B, T, 3, 3)
|
| 111 |
+
|
| 112 |
+
# Ensure right-handed rotation (det(R) = +1)
|
| 113 |
+
det = torch.det(torch.matmul(U, Vt)) # (B, T)
|
| 114 |
+
Vt_corrected = Vt.clone()
|
| 115 |
+
mask = det < 0
|
| 116 |
+
B_idx, T_idx = torch.nonzero(mask, as_tuple=True)
|
| 117 |
+
Vt_corrected[B_idx, T_idx, -1, :] *= -1 # Flip last row for those needing correction
|
| 118 |
+
|
| 119 |
+
# Optimal rotation and translation
|
| 120 |
+
R = torch.matmul(U, Vt_corrected).inverse() # (B, T, 3, 3)
|
| 121 |
+
t = centroid_Y - torch.matmul(R, centroid_X.unsqueeze(-1)).squeeze(-1) # (B, T, 3)
|
| 122 |
+
w2c = torch.eye(4, device=device).unsqueeze(0).unsqueeze(0).repeat(B, T, 1, 1)
|
| 123 |
+
if (torch.det(R) - 1).abs().max() < 1e-3:
|
| 124 |
+
w2c[:, :, :3, :3] = R
|
| 125 |
+
else:
|
| 126 |
+
import pdb; pdb.set_trace()
|
| 127 |
+
w2c[:, :, :3, 3] = t
|
| 128 |
+
try:
|
| 129 |
+
c2w_traj = torch.inverse(w2c) # or torch.linalg.inv()
|
| 130 |
+
except:
|
| 131 |
+
c2w_traj = torch.eye(4, device=device).unsqueeze(0).unsqueeze(0).repeat(B, T, 1, 1)
|
| 132 |
+
|
| 133 |
+
return c2w_traj
|
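A sanity-check sketch (not part of the commit): with the target equal to the source, the recovered camera-to-world transforms should be close to the identity; W down-weights unreliable points.
import torch
B, T, N = 1, 4, 200
X = torch.randn(B, 1, N, 3)
Y = X.expand(B, T, N, 3).clone()
W = torch.ones(B, T, N)
c2w = weighted_procrustes_torch(X, Y, W)   # (B, T, 4, 4), approximately identity here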
| 134 |
+
|
| 135 |
+
def key_fr_wprocrustes(cam_pts, graph_matrix, dyn_weight, vis_mask,slide_len=16, overlap=8, K=3, mode="keyframe"):
|
| 136 |
+
"""
|
| 137 |
+
cam_pts: (B, T, N, 3)
|
| 138 |
+
graph_matrix: (B, 1, N)
|
| 139 |
+
dyn_weight: (B, T, N)
|
| 140 |
+
K: number of keyframes to select (including start and end)
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
c2w_traj: (B, T, 4, 4)
|
| 144 |
+
"""
|
| 145 |
+
B, T, N, _ = cam_pts.shape
|
| 146 |
+
device = cam_pts.device
|
| 147 |
+
|
| 148 |
+
if mode == "keyframe":
|
| 149 |
+
# Step 1: Keyframe selection
|
| 150 |
+
ky_fr_idx = [0, T - 1]
|
| 151 |
+
graph_sum = torch.sum(graph_matrix, dim=-1) # (B, T, T)
|
| 152 |
+
dist = torch.max(graph_sum[:, 0, :], graph_sum[:, T - 1, :]) # (B, T)
|
| 153 |
+
dist[:, [0, T - 1]] = float('inf')
|
| 154 |
+
for _ in range(K - 2): # already have 2
|
| 155 |
+
last_idx = ky_fr_idx[-1]
|
| 156 |
+
dist = torch.max(dist, graph_sum[:, last_idx, :])
|
| 157 |
+
dist[:, last_idx] = float('inf')
|
| 158 |
+
next_id = torch.argmin(dist, dim=1)[0].item() # Assuming batch=1 or shared
|
| 159 |
+
ky_fr_idx.append(next_id)
|
| 160 |
+
|
| 161 |
+
ky_fr_idx = sorted(ky_fr_idx)
|
| 162 |
+
elif mode == "slide":
|
| 163 |
+
id_slide = torch.arange(0, T)
|
| 164 |
+
id_slide = id_slide.unfold(0, slide_len, overlap)
|
| 165 |
+
vis_mask_slide = vis_mask.unfold(1, slide_len, overlap)
|
| 166 |
+
cam_pts_slide = cam_pts.unfold(1, slide_len, overlap)
|
| 167 |
+
ky_fr_idx = torch.arange(0, T - slide_len + 1, overlap)
|
| 168 |
+
if ky_fr_idx[-1] + slide_len < T:
|
| 169 |
+
# if the last keyframe does not cover the whole sequence, add one more keyframe
|
| 170 |
+
ky_fr_idx = torch.cat([ky_fr_idx, ky_fr_idx[-1:] + overlap])
|
| 171 |
+
id_add = torch.arange(ky_fr_idx[-1], ky_fr_idx[-1] + slide_len).clamp(max=T-1)
|
| 172 |
+
id_slide = torch.cat([id_slide, id_add[None, :]], dim=0)
|
| 173 |
+
cam_pts_add = cam_pts[:, id_add, :, :]
|
| 174 |
+
cam_pts_slide = torch.cat([cam_pts_slide, cam_pts_add.permute(0,2,3,1)[:, None, ...]], dim=1)
|
| 175 |
+
vis_mask_add = vis_mask[:, id_add, :]
|
| 176 |
+
vis_mask_slide = torch.cat([vis_mask_slide, vis_mask_add.permute(0,2,3,1)[:, None, ...]], dim=1)
|
| 177 |
+
|
| 178 |
+
if mode == "keyframe":
|
| 179 |
+
# Step 2: Weighted Procrustes in windows
|
| 180 |
+
base_pose = torch.eye(4, device=cam_pts.device).view(1, 1, 4, 4).repeat(B, 1, 1, 1) # (B, 1, 4, 4)
|
| 181 |
+
c2w_traj_out = []
|
| 182 |
+
for i in range(len(ky_fr_idx) - 1):
|
| 183 |
+
start_idx = ky_fr_idx[i]
|
| 184 |
+
end_idx = ky_fr_idx[i + 1]
|
| 185 |
+
|
| 186 |
+
# Visibility mask
|
| 187 |
+
vis_mask_i = graph_matrix[:, start_idx, end_idx, :] # (B, N) or (N,)
|
| 188 |
+
if vis_mask_i.dim() == 1:
|
| 189 |
+
vis_mask_i = vis_mask_i.unsqueeze(0) # (1, N)
|
| 190 |
+
|
| 191 |
+
# Broadcast cam_pts and dyn_weight
|
| 192 |
+
cam_ref = cam_pts[:, start_idx:start_idx+1, :, :] # (B, 1, M, 3)
|
| 193 |
+
cam_win = cam_pts[:, start_idx:end_idx+1, :, :] # (B, W, M, 3)
|
| 194 |
+
weight = dyn_weight[:, :, :] * vis_mask_i[:, None, :] # (B, W, M)
|
| 195 |
+
|
| 196 |
+
# Compute relative transformations
|
| 197 |
+
if weight.sum() < 50:
|
| 198 |
+
weight = weight.clamp(min=5e-2)
|
| 199 |
+
relative_tfms = weighted_procrustes_torch(cam_ref, cam_win, weight) # (B, W, 4, 4)
|
| 200 |
+
|
| 201 |
+
# Apply to original c2w_traj
|
| 202 |
+
updated_pose = base_pose.detach() @ relative_tfms # (B, W, 4, 4)
|
| 203 |
+
base_pose = relative_tfms[:, -1:, :, :].detach() # (B, 1, 4, 4)
|
| 204 |
+
|
| 205 |
+
# Assign to output trajectory (avoid in-place on autograd path)
|
| 206 |
+
c2w_traj_out.append(updated_pose[:, 1:, ...])
|
| 207 |
+
|
| 208 |
+
c2w_traj_out = torch.cat(c2w_traj_out, dim=1)
|
| 209 |
+
c2w_traj_out = torch.cat([torch.eye(4, device=device).repeat(B, 1, 1, 1), c2w_traj_out], dim=1)
|
| 210 |
+
elif mode == "slide":
|
| 211 |
+
c2w_traj_out = torch.eye(4, device=device).repeat(B, T, 1, 1)
|
| 212 |
+
for i in range(cam_pts_slide.shape[1]):
|
| 213 |
+
cam_pts_slide_i = cam_pts_slide[:, i, :, :].permute(0,3,1,2)
|
| 214 |
+
id_slide_i = id_slide[i, :]
|
| 215 |
+
vis_mask_i = vis_mask_slide[:, i, :, 0, :].permute(0,2,1) # (B, N) or (N,)
|
| 216 |
+
vis_mask_i = vis_mask_i[:,:1] * vis_mask_i
|
| 217 |
+
weight_i = dyn_weight * vis_mask_i
|
| 218 |
+
if weight_i.sum() < 50:
|
| 219 |
+
weight_i = weight_i.clamp(min=5e-2)
|
| 220 |
+
if i == 0:
|
| 221 |
+
c2w_traj_out[:, id_slide_i, :, :] = weighted_procrustes_torch(cam_pts_slide_i[:,:1], cam_pts_slide_i, weight_i)
|
| 222 |
+
else:
|
| 223 |
+
campts_update = torch.einsum("btij,btnj->btni", c2w_traj_out[:,id_slide_i][...,:3,:3], cam_pts_slide_i) + c2w_traj_out[:,id_slide_i][...,None,:3,3]
|
| 224 |
+
c2w_traj_update = weighted_procrustes_torch(campts_update[:,:1], campts_update, weight_i)
|
| 225 |
+
c2w_traj_out[:, id_slide_i, :, :] = c2w_traj_update@c2w_traj_out[:,id_slide_i]
|
| 226 |
+
|
| 227 |
+
return c2w_traj_out
|
| 228 |
+
|
| 229 |
+
def posenc(x, min_deg, max_deg):
|
| 230 |
+
"""Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].
|
| 231 |
+
Instead of computing [sin(x), cos(x)], we use the trig identity
|
| 232 |
+
cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
|
| 233 |
+
Args:
|
| 234 |
+
x: torch.Tensor, variables to be encoded. Note that x should be in [-pi, pi].
|
| 235 |
+
min_deg: int, the minimum (inclusive) degree of the encoding.
|
| 236 |
+
max_deg: int, the maximum (exclusive) degree of the encoding.
|
| 237 |
+
legacy_posenc_order: bool, keep the same ordering as the original tf code.
|
| 238 |
+
Returns:
|
| 239 |
+
encoded: torch.Tensor, encoded variables.
|
| 240 |
+
"""
|
| 241 |
+
if min_deg == max_deg:
|
| 242 |
+
return x
|
| 243 |
+
scales = torch.tensor(
|
| 244 |
+
[2**i for i in range(min_deg, max_deg)], dtype=x.dtype, device=x.device
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
xb = (x[..., None, :] * scales[:, None]).reshape(list(x.shape[:-1]) + [-1])
|
| 248 |
+
four_feat = torch.sin(torch.cat([xb, xb + 0.5 * torch.pi], dim=-1))
|
| 249 |
+
return torch.cat([x] + [four_feat], dim=-1)
|
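A minimal sketch (not part of the commit) of the output size: the raw input is kept and sin/cos features are appended at scales 2^min_deg .. 2^(max_deg - 1).
import torch
x = torch.rand(8, 3) * 2 * torch.pi - torch.pi   # values in [-pi, pi]
enc = posenc(x, min_deg=0, max_deg=4)            # (8, 3 + 3*4*2) = (8, 27)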
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class EfficientUpdateFormer3D(nn.Module):
|
| 253 |
+
"""
|
| 254 |
+
Transformer model that updates track in 3D
|
| 255 |
+
"""
|
| 256 |
+
|
| 257 |
+
def __init__(
|
| 258 |
+
self,
|
| 259 |
+
EFormer: EfficientUpdateFormer,
|
| 260 |
+
update_points=True
|
| 261 |
+
):
|
| 262 |
+
super().__init__()
|
| 263 |
+
|
| 264 |
+
hidden_size = EFormer.hidden_size
|
| 265 |
+
num_virtual_tracks = EFormer.num_virtual_tracks
|
| 266 |
+
num_heads = EFormer.num_heads
|
| 267 |
+
mlp_ratio = 4.0
|
| 268 |
+
|
| 269 |
+
#NOTE: we design a switcher to bridge the camera pose, 3d tracks and 2d tracks
|
| 270 |
+
|
| 271 |
+
# interact with pretrained 2d tracking
|
| 272 |
+
self.switcher_tokens = nn.Parameter(
|
| 273 |
+
torch.randn(1, num_virtual_tracks, 1, hidden_size)
|
| 274 |
+
)
|
| 275 |
+
# cross attention
|
| 276 |
+
space_depth=len(EFormer.space_virtual_blocks)
|
| 277 |
+
self.space_switcher_blocks = nn.ModuleList(
|
| 278 |
+
[
|
| 279 |
+
AttnBlock(
|
| 280 |
+
hidden_size,
|
| 281 |
+
num_heads,
|
| 282 |
+
mlp_ratio=mlp_ratio,
|
| 283 |
+
attn_class=Attention,
|
| 284 |
+
)
|
| 285 |
+
for _ in range(space_depth)
|
| 286 |
+
]
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
# config 3d tracks blocks
|
| 290 |
+
self.space_track3d2switcher_blocks = nn.ModuleList(
|
| 291 |
+
[
|
| 292 |
+
CrossAttnBlock(
|
| 293 |
+
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
| 294 |
+
)
|
| 295 |
+
for _ in range(space_depth)
|
| 296 |
+
]
|
| 297 |
+
)
|
| 298 |
+
self.space_switcher2track3d_blocks = nn.ModuleList(
|
| 299 |
+
[
|
| 300 |
+
CrossAttnBlock(
|
| 301 |
+
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
| 302 |
+
)
|
| 303 |
+
for _ in range(space_depth)
|
| 304 |
+
]
|
| 305 |
+
)
|
| 306 |
+
# config switcher blocks
|
| 307 |
+
self.space_virtual2switcher_blocks = nn.ModuleList(
|
| 308 |
+
[
|
| 309 |
+
CrossAttnBlock(
|
| 310 |
+
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
| 311 |
+
)
|
| 312 |
+
for _ in range(space_depth)
|
| 313 |
+
]
|
| 314 |
+
)
|
| 315 |
+
self.space_switcher2virtual_blocks = nn.ModuleList(
|
| 316 |
+
[
|
| 317 |
+
CrossAttnBlock(
|
| 318 |
+
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
| 319 |
+
)
|
| 320 |
+
for _ in range(space_depth)
|
| 321 |
+
]
|
| 322 |
+
)
|
| 323 |
+
# config the temporal blocks
|
| 324 |
+
self.time_blocks_new = nn.ModuleList(
|
| 325 |
+
[
|
| 326 |
+
AttnBlock(
|
| 327 |
+
hidden_size,
|
| 328 |
+
num_heads,
|
| 329 |
+
mlp_ratio=mlp_ratio,
|
| 330 |
+
attn_class=Attention,
|
| 331 |
+
)
|
| 332 |
+
for _ in range(len(EFormer.time_blocks))
|
| 333 |
+
]
|
| 334 |
+
)
|
| 335 |
+
# scale and shift cross attention
|
| 336 |
+
self.scale_shift_cross_attn = nn.ModuleList(
|
| 337 |
+
[
|
| 338 |
+
CrossAttnBlock(
|
| 339 |
+
128, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
| 340 |
+
)
|
| 341 |
+
for _ in range(len(EFormer.time_blocks))
|
| 342 |
+
]
|
| 343 |
+
)
|
| 344 |
+
self.scale_shift_self_attn = nn.ModuleList(
|
| 345 |
+
[
|
| 346 |
+
AttnBlock(
|
| 347 |
+
128, num_heads, mlp_ratio=mlp_ratio, attn_class=Attention
|
| 348 |
+
)
|
| 349 |
+
for _ in range(len(EFormer.time_blocks))
|
| 350 |
+
]
|
| 351 |
+
)
|
| 352 |
+
self.scale_shift_dec = torch.nn.Linear(128, 128+1, bias=True)
|
| 353 |
+
|
| 354 |
+
# dense cross attention
|
| 355 |
+
self.dense_res_cross_attn = nn.ModuleList(
|
| 356 |
+
[
|
| 357 |
+
CrossAttnBlock(
|
| 358 |
+
128, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
| 359 |
+
)
|
| 360 |
+
for _ in range(len(EFormer.time_blocks))
|
| 361 |
+
]
|
| 362 |
+
)
|
| 363 |
+
self.dense_res_self_attn = nn.ModuleList(
|
| 364 |
+
[
|
| 365 |
+
AttnBlock(
|
| 366 |
+
128, num_heads, mlp_ratio=mlp_ratio, attn_class=Attention
|
| 367 |
+
)
|
| 368 |
+
for _ in range(len(EFormer.time_blocks))
|
| 369 |
+
]
|
| 370 |
+
)
|
| 371 |
+
self.dense_res_dec = torch.nn.Conv2d(128, 3+128, kernel_size=1, stride=1, padding=0)
|
| 372 |
+
|
| 373 |
+
# set different heads
|
| 374 |
+
self.update_points = update_points
|
| 375 |
+
if update_points:
|
| 376 |
+
self.point_head = torch.nn.Linear(hidden_size, 4, bias=True)
|
| 377 |
+
else:
|
| 378 |
+
self.depth_head = torch.nn.Linear(hidden_size, 1, bias=True)
|
| 379 |
+
self.pro_analysis_w_head = torch.nn.Linear(hidden_size, 1, bias=True)
|
| 380 |
+
self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
|
| 381 |
+
self.residual_head = torch.nn.Linear(hidden_size,
|
| 382 |
+
hidden_size, bias=True)
|
| 383 |
+
|
| 384 |
+
self.initialize_weights()
|
| 385 |
+
|
| 386 |
+
def initialize_weights(self):
|
| 387 |
+
def _basic_init(module):
|
| 388 |
+
if isinstance(module, nn.Linear):
|
| 389 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 390 |
+
if module.bias is not None:
|
| 391 |
+
nn.init.constant_(module.bias, 0)
|
| 392 |
+
if getattr(self, "point_head", None) is not None:
|
| 393 |
+
torch.nn.init.trunc_normal_(self.point_head.weight, std=1e-6)
|
| 394 |
+
torch.nn.init.constant_(self.point_head.bias, 0)
|
| 395 |
+
if getattr(self, "depth_head", None) is not None:
|
| 396 |
+
torch.nn.init.trunc_normal_(self.depth_head.weight, std=0.001)
|
| 397 |
+
if getattr(self, "vis_conf_head", None) is not None:
|
| 398 |
+
torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=1e-6)
|
| 399 |
+
if getattr(self, "scale_shift_dec", None) is not None:
|
| 400 |
+
torch.nn.init.trunc_normal_(self.scale_shift_dec.weight, std=0.001)
|
| 401 |
+
if getattr(self, "residual_head", None) is not None:
|
| 402 |
+
torch.nn.init.trunc_normal_(self.residual_head.weight, std=0.001)
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def _trunc_init(module):
|
| 406 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
| 407 |
+
if isinstance(module, nn.Linear):
|
| 408 |
+
torch.nn.init.trunc_normal_(module.weight, std=0.02)
|
| 409 |
+
if module.bias is not None:
|
| 410 |
+
nn.init.zeros_(module.bias)
|
| 411 |
+
|
| 412 |
+
self.apply(_basic_init)
|
| 413 |
+
|
| 414 |
+
def forward(self, input_tensor, input_tensor3d, EFormer: EfficientUpdateFormer,
|
| 415 |
+
mask=None, add_space_attn=True, extra_sparse_tokens=None, extra_dense_tokens=None):
|
| 416 |
+
|
| 417 |
+
#NOTE: prepare the pose and 3d tracks features
|
| 418 |
+
tokens3d = EFormer.input_transform(input_tensor3d)
|
| 419 |
+
|
| 420 |
+
tokens = EFormer.input_transform(input_tensor)
|
| 421 |
+
B, _, T, _ = tokens.shape
|
| 422 |
+
virtual_tokens = EFormer.virual_tracks.repeat(B, 1, T, 1)
|
| 423 |
+
switcher_tokens = self.switcher_tokens.repeat(B, 1, T, 1)
|
| 424 |
+
|
| 425 |
+
tokens = torch.cat([tokens, virtual_tokens], dim=1)
|
| 426 |
+
tokens3d = torch.cat([tokens3d, switcher_tokens], dim=1)
|
| 427 |
+
|
| 428 |
+
_, N, _, _ = tokens.shape
|
| 429 |
+
j = 0
|
| 430 |
+
layers = []
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
for i in range(len(EFormer.time_blocks)):
|
| 434 |
+
if extra_sparse_tokens is not None:
|
| 435 |
+
extra_sparse_tokens = rearrange(extra_sparse_tokens, 'b n t c -> (b t) n c')
|
| 436 |
+
extra_sparse_tokens = self.scale_shift_cross_attn[i](extra_sparse_tokens, rearrange(tokens3d, 'b n t c -> (b t) n c'))
|
| 437 |
+
extra_sparse_tokens = rearrange(extra_sparse_tokens, '(b t) n c -> (b n) t c', b=B, t=T)
|
| 438 |
+
extra_sparse_tokens = self.scale_shift_self_attn[i](extra_sparse_tokens)
|
| 439 |
+
extra_sparse_tokens = rearrange(extra_sparse_tokens, '(b n) t c -> b n t c', b=B, n=2, t=T)
|
| 440 |
+
|
| 441 |
+
if extra_dense_tokens is not None:
|
| 442 |
+
h_p, w_p = extra_dense_tokens.shape[-2:]
|
| 443 |
+
extra_dense_tokens = rearrange(extra_dense_tokens, 'b t c h w -> (b t) (h w) c')
|
| 444 |
+
extra_dense_tokens = self.dense_res_cross_attn[i](extra_dense_tokens, rearrange(tokens3d, 'b n t c -> (b t) n c'))
|
| 445 |
+
extra_dense_tokens = rearrange(extra_dense_tokens, '(b t) n c -> (b n) t c', b=B, t=T)
|
| 446 |
+
extra_dense_tokens = self.dense_res_self_attn[i](extra_dense_tokens)
|
| 447 |
+
extra_dense_tokens = rearrange(extra_dense_tokens, '(b h w) t c -> b t c h w', b=B, h=h_p, w=w_p)
|
| 448 |
+
|
| 449 |
+
# temporal
|
| 450 |
+
time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
|
| 451 |
+
time_tokens = EFormer.time_blocks[i](time_tokens)
|
| 452 |
+
|
| 453 |
+
# temporal 3d
|
| 454 |
+
time_tokens3d = tokens3d.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
|
| 455 |
+
time_tokens3d = self.time_blocks_new[i](time_tokens3d)
|
| 456 |
+
|
| 457 |
+
tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
|
| 458 |
+
tokens3d = time_tokens3d.view(B, N, T, -1)
|
| 459 |
+
|
| 460 |
+
if (
|
| 461 |
+
add_space_attn
|
| 462 |
+
and hasattr(EFormer, "space_virtual_blocks")
|
| 463 |
+
and (i % (len(EFormer.time_blocks) // len(EFormer.space_virtual_blocks)) == 0)
|
| 464 |
+
):
|
| 465 |
+
space_tokens = (
|
| 466 |
+
tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
|
| 467 |
+
) # B N T C -> (B T) N C
|
| 468 |
+
space_tokens3d = (
|
| 469 |
+
tokens3d.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
|
| 470 |
+
) # B N T C -> (B T) N C
|
| 471 |
+
|
| 472 |
+
point_tokens = space_tokens[:, : N - EFormer.num_virtual_tracks]
|
| 473 |
+
virtual_tokens = space_tokens[:, N - EFormer.num_virtual_tracks :]
|
| 474 |
+
# get the 3d relevant tokens
|
| 475 |
+
track3d_tokens = space_tokens3d[:, : N - EFormer.num_virtual_tracks]
|
| 476 |
+
switcher_tokens = space_tokens[:, N - EFormer.num_virtual_tracks + 1:]
|
| 477 |
+
|
| 478 |
+
# interact switcher with pose and tracks3d
|
| 479 |
+
switcher_tokens = self.space_track3d2switcher_blocks[j](
|
| 480 |
+
switcher_tokens, track3d_tokens, mask=mask
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
virtual_tokens = EFormer.space_virtual2point_blocks[j](
|
| 485 |
+
virtual_tokens, point_tokens, mask=mask
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
# get the switcher_tokens
|
| 489 |
+
switcher_tokens = self.space_virtual2switcher_blocks[j](
|
| 490 |
+
switcher_tokens, virtual_tokens
|
| 491 |
+
)
|
| 492 |
+
virtual_tokens_res = self.residual_head(
|
| 493 |
+
self.space_switcher2virtual_blocks[j](
|
| 494 |
+
virtual_tokens, switcher_tokens
|
| 495 |
+
)
|
| 496 |
+
)
|
| 497 |
+
switcher_tokens_res = self.residual_head(
|
| 498 |
+
self.space_switcher2virtual_blocks[j](
|
| 499 |
+
switcher_tokens, virtual_tokens
|
| 500 |
+
)
|
| 501 |
+
)
|
| 502 |
+
# add residual
|
| 503 |
+
virtual_tokens = virtual_tokens + virtual_tokens_res
|
| 504 |
+
switcher_tokens = switcher_tokens + switcher_tokens_res
|
| 505 |
+
|
| 506 |
+
virtual_tokens = EFormer.space_virtual_blocks[j](virtual_tokens)
|
| 507 |
+
switcher_tokens = self.space_switcher_blocks[j](switcher_tokens)
|
| 508 |
+
# decode
|
| 509 |
+
point_tokens = EFormer.space_point2virtual_blocks[j](
|
| 510 |
+
point_tokens, virtual_tokens, mask=mask
|
| 511 |
+
)
|
| 512 |
+
track3d_tokens = self.space_switcher2track3d_blocks[j](
|
| 513 |
+
track3d_tokens, switcher_tokens, mask=mask
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
|
| 517 |
+
space_tokens3d = torch.cat([track3d_tokens, virtual_tokens], dim=1)
|
| 518 |
+
tokens = space_tokens.view(B, T, N, -1).permute(
|
| 519 |
+
0, 2, 1, 3
|
| 520 |
+
) # (B T) N C -> B N T C
|
| 521 |
+
tokens3d = space_tokens3d.view(B, T, N, -1).permute(
|
| 522 |
+
0, 2, 1, 3
|
| 523 |
+
) # (B T) N C -> B N T C
|
| 524 |
+
|
| 525 |
+
j += 1
|
| 526 |
+
|
| 527 |
+
tokens = tokens[:, : N - EFormer.num_virtual_tracks]
|
| 528 |
+
track3d_tokens = tokens3d[:, : N - EFormer.num_virtual_tracks]
|
| 529 |
+
|
| 530 |
+
if self.update_points:
|
| 531 |
+
depth_update, dynamic_prob_update = self.point_head(track3d_tokens)[..., :3], self.point_head(track3d_tokens)[..., 3:]
|
| 532 |
+
else:
|
| 533 |
+
depth_update, dynamic_prob_update = self.depth_head(track3d_tokens)[..., :1], self.depth_head(track3d_tokens)[..., 1:]
|
| 534 |
+
pro_analysis_w = self.pro_analysis_w_head(track3d_tokens)
|
| 535 |
+
flow = EFormer.flow_head(tokens)
|
| 536 |
+
if EFormer.linear_layer_for_vis_conf:
|
| 537 |
+
vis_conf = EFormer.vis_conf_head(tokens)
|
| 538 |
+
flow = torch.cat([flow, vis_conf], dim=-1)
|
| 539 |
+
if extra_sparse_tokens is not None:
|
| 540 |
+
scale_shift_out = self.scale_shift_dec(extra_sparse_tokens)
|
| 541 |
+
dense_res_out = self.dense_res_dec(extra_dense_tokens.view(B*T, -1, h_p, w_p)).view(B, T, -1, h_p, w_p)
|
| 542 |
+
return flow, depth_update, dynamic_prob_update, pro_analysis_w, scale_shift_out, dense_res_out
|
| 543 |
+
else:
|
| 544 |
+
return flow, depth_update, dynamic_prob_update, pro_analysis_w, None, None
|
| 545 |
+
|
| 546 |
+
def recover_global_translations_batch(global_rot, c2w_traj, graph_weight):
|
| 547 |
+
B, T = global_rot.shape[:2]
|
| 548 |
+
device = global_rot.device
|
| 549 |
+
|
| 550 |
+
# Compute R_i @ t_ij
|
| 551 |
+
t_rel = c2w_traj[:, :, :, :3, 3] # (B, T, T, 3)
|
| 552 |
+
R_i = global_rot[:, :, None, :, :] # (B, T, 1, 3, 3)
|
| 553 |
+
t_rhs = torch.matmul(R_i, t_rel.unsqueeze(-1)).squeeze(-1) # (B, T, T, 3)
|
| 554 |
+
|
| 555 |
+
# Mask: exclude self-loops and small weights
|
| 556 |
+
valid_mask = (graph_weight > 1e-5) & (~torch.eye(T, dtype=bool, device=device)[None, :, :]) # (B, T, T)
|
| 557 |
+
|
| 558 |
+
# Get all valid (i, j) edge indices
|
| 559 |
+
i_idx, j_idx = torch.meshgrid(
|
| 560 |
+
torch.arange(T, device=device),
|
| 561 |
+
torch.arange(T, device=device),
|
| 562 |
+
indexing="ij"
|
| 563 |
+
)
|
| 564 |
+
i_idx = i_idx.reshape(-1) # (T*T,)
|
| 565 |
+
j_idx = j_idx.reshape(-1)
|
| 566 |
+
|
| 567 |
+
# Expand to batch (B, T*T)
|
| 568 |
+
i_idx = i_idx[None, :].repeat(B, 1)
|
| 569 |
+
j_idx = j_idx[None, :].repeat(B, 1)
|
| 570 |
+
|
| 571 |
+
# Flatten everything
|
| 572 |
+
valid_mask_flat = valid_mask.view(B, -1) # (B, T*T)
|
| 573 |
+
w_flat = graph_weight.view(B, -1) # (B, T*T)
|
| 574 |
+
rhs_flat = t_rhs.view(B, -1, 3) # (B, T*T, 3)
|
| 575 |
+
|
| 576 |
+
# Initialize output translations
|
| 577 |
+
global_translations = torch.zeros(B, T, 3, device=device)
|
| 578 |
+
|
| 579 |
+
for b_id in range(B):
|
| 580 |
+
mask = valid_mask_flat[b_id]
|
| 581 |
+
i_valid = i_idx[b_id][mask]
|
| 582 |
+
j_valid = j_idx[b_id][mask]
|
| 583 |
+
w_valid = w_flat[b_id][mask]
|
| 584 |
+
rhs_valid = rhs_flat[b_id][mask]
|
| 585 |
+
|
| 586 |
+
n_edges = i_valid.shape[0]
|
| 587 |
+
|
| 588 |
+
# Build A matrix: (n_edges*3, T*3)
|
| 589 |
+
A = torch.zeros(n_edges*3, T*3, device=device)
|
| 590 |
+
|
| 591 |
+
# Build b vector: (n_edges*3,)
|
| 592 |
+
b = torch.zeros(n_edges*3, device=device)
|
| 593 |
+
|
| 594 |
+
for k in range(n_edges):
|
| 595 |
+
i, j = i_valid[k], j_valid[k]
|
| 596 |
+
weight = w_valid[k]
|
| 597 |
+
|
| 598 |
+
# Fill A matrix for x,y,z components
|
| 599 |
+
for dim in range(3):
|
| 600 |
+
row = k*3 + dim
|
| 601 |
+
A[row, i*3 + dim] = -weight
|
| 602 |
+
A[row, j*3 + dim] = weight
|
| 603 |
+
|
| 604 |
+
# Fill b vector
|
| 605 |
+
b[row] = rhs_valid[k, dim] * weight
|
| 606 |
+
|
| 607 |
+
# Solve least squares
|
| 608 |
+
try:
|
| 609 |
+
# Add small regularization for stability
|
| 610 |
+
AtA = A.transpose(-1, -2) @ A + 1e-4 * torch.eye(A.shape[-1], device=A.device)
|
| 611 |
+
Atb = A.transpose(-1, -2) @ b.unsqueeze(-1)
|
| 612 |
+
|
| 613 |
+
solution = torch.linalg.solve(AtA, Atb).squeeze(-1) # (3*T,)
|
| 614 |
+
t_batch = solution.view(T, 3)
|
| 615 |
+
|
| 616 |
+
# Fix scale by setting first frame to origin
|
| 617 |
+
t_batch = t_batch - t_batch[0:1]
|
| 618 |
+
global_translations[b_id] = t_batch
|
| 619 |
+
|
| 620 |
+
except RuntimeError as e:
|
| 621 |
+
print(f"Error in batch {b_id}: {e}")
|
| 622 |
+
global_translations[b_id] = torch.zeros(T, 3, device=device)
|
| 623 |
+
return global_translations
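# Hedged usage sketch, added for illustration only (not part of the original commit). It assumes
# recover_global_translations_batch above is in scope and torch is imported; identity relative
# poses with uniform graph weights should recover (near-)zero global translations.
def _example_recover_global_translations():
    B, T = 1, 4
    global_rot = torch.eye(3).expand(B, T, 3, 3).contiguous()       # per-frame global rotations
    c2w_traj = torch.eye(4).expand(B, T, T, 4, 4).contiguous()      # pairwise relative poses
    graph_weight = torch.ones(B, T, T)                              # fully connected, uniform weights
    t_glob = recover_global_translations_batch(global_rot, c2w_traj, graph_weight)
    assert t_glob.shape == (B, T, 3)
    return t_glob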
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
def global_graph_motion_average(c2w_traj, graph_weight):
|
| 627 |
+
"""
|
| 628 |
+
Average the pairwise c2w_traj estimates into per-frame global poses, weighted by graph_weight.
|
| 629 |
+
"""
|
| 630 |
+
B, T, T, _, _ = c2w_traj.shape
|
| 631 |
+
mask = graph_weight[..., 0, 0] < 1e-5 # (B, T, T)
|
| 632 |
+
mask = mask.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, 4, 4) # (B, T, T, 4, 4)
|
| 633 |
+
identity = torch.eye(4, device=c2w_traj.device).view(1, 1, 1, 4, 4).expand(B, T, T, 4, 4)
|
| 634 |
+
c2w_traj = torch.where(mask, identity, c2w_traj)
|
| 635 |
+
|
| 636 |
+
Rot_rel_weighted = c2w_traj[:,:,:,:3,:3].contiguous() * graph_weight # B T T 3 3
|
| 637 |
+
Rot_big = Rot_rel_weighted.permute(0, 1, 3, 2, 4).reshape(B, 3*T, 3*T) # B 3T 3T
|
| 638 |
+
epsilon = 1e-8
|
| 639 |
+
I_big = torch.eye(3*T, device=Rot_big.device).unsqueeze(0) # (1, 3T, 3T)
|
| 640 |
+
Rot_big_reg = Rot_big + epsilon * I_big # (B, 3T, 3T)
|
| 641 |
+
#NOTE: cal the global rotation
|
| 642 |
+
# Step 1: batch eigendecomposition
|
| 643 |
+
try:
|
| 644 |
+
eigvals, eigvecs = torch.linalg.eigh(Rot_big_reg) # eigvecs: (B, 3T, 3T)
|
| 645 |
+
except RuntimeError as e:
raise RuntimeError(f"eigendecomposition of the weighted rotation graph failed: {e}") from e
|
| 647 |
+
# Step 2: get the largest 3 eigenvectors
|
| 648 |
+
X = eigvecs[:, :, -3:] # (B, 3T, 3)
|
| 649 |
+
# Step 3: split into (B, T, 3, 3)
|
| 650 |
+
X = X.view(B, T, 3, 3) # each frame's rotation block (non-orthogonal)
|
| 651 |
+
# Step 4: project to SO(3), using SVD
|
| 652 |
+
U, _, Vh = torch.linalg.svd(X) # (B, T, 3, 3)
|
| 653 |
+
R = U @ Vh
|
| 654 |
+
# Step 5: ensure det(R)=1 (right-handed coordinate system)
|
| 655 |
+
det = torch.linalg.det(R) # (B, T)
|
| 656 |
+
neg_det_mask = det < 0
|
| 657 |
+
# if det < 0, flip the sign of U's last column and recompute R
|
| 658 |
+
U_flip = U.clone()
|
| 659 |
+
U_flip[neg_det_mask, :, -1] *= -1
|
| 660 |
+
R = U_flip @ Vh
|
| 661 |
+
# global rotation
|
| 662 |
+
Rot_glob = R[:,:1].inverse() @ R
|
| 663 |
+
# global translation
|
| 664 |
+
t_glob = recover_global_translations_batch(Rot_glob,
|
| 665 |
+
c2w_traj, graph_weight[...,0,0])
|
| 666 |
+
c2w_traj_final = torch.eye(4, device=c2w_traj.device)[None,None].repeat(B, T, 1, 1)
|
| 667 |
+
c2w_traj_final[:,:,:3,:3] = Rot_glob
|
| 668 |
+
c2w_traj_final[:,:,:3,3] = t_glob
|
| 669 |
+
|
| 670 |
+
return c2w_traj_final
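# Hedged usage sketch, added for illustration only (not part of the original commit). It assumes
# global_graph_motion_average above is in scope; graph_weight is expected to broadcast against the
# (B, T, T, 3, 3) rotation blocks, so a (B, T, T, 1, 1) weight tensor is used here.
def _example_global_graph_motion_average():
    B, T = 1, 4
    c2w_traj = torch.eye(4).expand(B, T, T, 4, 4).contiguous()      # identity pairwise poses
    graph_weight = torch.ones(B, T, T, 1, 1)
    c2w_traj_final = global_graph_motion_average(c2w_traj, graph_weight)
    assert c2w_traj_final.shape == (B, T, 4, 4)                     # expect (near-)identity global poses
    return c2w_traj_final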
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
def depth_to_points_colmap(metric_depth: torch.Tensor,
|
| 674 |
+
intrinsics: torch.Tensor) -> torch.Tensor:
|
| 675 |
+
"""
|
| 676 |
+
Unproject a depth map to a point cloud in COLMAP convention.
|
| 677 |
+
|
| 678 |
+
Args:
|
| 679 |
+
metric_depth: (B, H, W) depth map, meters.
|
| 680 |
+
intrinsics: (B, 3, 3) COLMAP-style K matrix.
|
| 681 |
+
Returns:
|
| 682 |
+
points_map: (B, H, W, 3) point cloud in camera coordinates.
|
| 683 |
+
"""
|
| 684 |
+
# metric_depth has shape (B, H, W)
|
| 685 |
+
B, H, W = metric_depth.shape
|
| 686 |
+
|
| 687 |
+
# build homogeneous pixel coordinates [u, v, 1] for every pixel
|
| 688 |
+
u = torch.arange(W, device=metric_depth.device, dtype=metric_depth.dtype)
|
| 689 |
+
v = torch.arange(H, device=metric_depth.device, dtype=metric_depth.dtype)
|
| 690 |
+
uu, vv = torch.meshgrid(u, v, indexing='xy')
|
| 691 |
+
pix = torch.stack([uu, vv, torch.ones_like(uu)], dim=-1)
|
| 692 |
+
pix = pix.reshape(-1, 3) # (H*W, 3)
|
| 693 |
+
# repeat the pixel grid for all B images
|
| 694 |
+
pix = pix.unsqueeze(0).expand(B, -1, -1) # (B, H*W, 3)
|
| 695 |
+
# import pdb; pdb.set_trace()
|
| 696 |
+
# K is (B, 3, 3)
|
| 697 |
+
K_inv = torch.inverse(intrinsics) # (B, 3, 3)
|
| 698 |
+
|
| 699 |
+
# back-projection directions: X_cam = K^{-1} * pix
|
| 700 |
+
dirs = torch.einsum('bij,bkj->bki', K_inv, pix) # (B, H*W, 3)
|
| 701 |
+
|
| 702 |
+
# scale each ray direction by its depth
|
| 703 |
+
depths = metric_depth.reshape(B, -1) # (B, H*W)
|
| 704 |
+
pts = dirs * depths.unsqueeze(-1) # (B, H*W, 3)
|
| 705 |
+
|
| 706 |
+
# reshape the output to (B, H, W, 3)
|
| 707 |
+
points_map = pts.view(B, H, W, 3) # (B, H, W, 3)
|
| 708 |
+
|
| 709 |
+
return points_map
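# Hedged usage sketch, added for illustration only (not part of the original commit). It builds a toy
# COLMAP-style intrinsic matrix and checks that the z-component of the unprojected points equals the depth.
def _example_depth_to_points_colmap():
    B, H, W = 1, 4, 6
    depth = torch.ones(B, H, W)                                     # constant 1 m depth
    K = torch.tensor([[[10.0, 0.0, W / 2.0],
                       [0.0, 10.0, H / 2.0],
                       [0.0, 0.0, 1.0]]])                           # (B, 3, 3)
    pts = depth_to_points_colmap(depth, K)
    assert pts.shape == (B, H, W, 3)
    assert torch.allclose(pts[..., 2], depth)
    return pts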
|
| 710 |
+
|
| 711 |
+
def vec6d_to_R(vector_6D):
|
| 712 |
+
v1=vector_6D[:,:3]/vector_6D[:,:3].norm(dim=-1,keepdim=True)
|
| 713 |
+
v2=vector_6D[:,3:]-(vector_6D[:,3:]*v1).sum(dim=-1,keepdim=True)*v1
|
| 714 |
+
v2=v2/v2.norm(dim=-1,keepdim=True)
|
| 715 |
+
v3=torch.cross(v1,v2,dim=-1)
|
| 716 |
+
return torch.concatenate((v1.unsqueeze(1),v2.unsqueeze(1),v3.unsqueeze(1)),dim=1)
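# Hedged sanity check, added for illustration only (not part of the original commit): the rows
# returned by vec6d_to_R form an orthonormal basis, so R @ R^T should be the identity.
def _example_vec6d_to_R():
    v6 = torch.randn(5, 6)
    R = vec6d_to_R(v6)                                              # (5, 3, 3)
    eye = torch.eye(3).expand(5, 3, 3)
    assert torch.allclose(R @ R.transpose(1, 2), eye, atol=1e-5)
    return R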
|
| 717 |
+
|
| 718 |
+
class MyTransformerHead(nn.Module):
|
| 719 |
+
def __init__(self,input_dim,dim,use_positional_encoding_transformer):
|
| 720 |
+
super(MyTransformerHead,self).__init__()
|
| 721 |
+
|
| 722 |
+
patch_dim=input_dim+1
|
| 723 |
+
self.layers=3
|
| 724 |
+
# dim=128
|
| 725 |
+
self.use_positional_encoding_transformer=use_positional_encoding_transformer
|
| 726 |
+
self.to_patch_embedding = nn.Sequential(
|
| 727 |
+
nn.LayerNorm(patch_dim),
|
| 728 |
+
nn.Linear(patch_dim, dim),
|
| 729 |
+
nn.LayerNorm(dim),
|
| 730 |
+
)
|
| 731 |
+
self.transformer_frames=[]
|
| 732 |
+
self.transformer_points=[]
|
| 733 |
+
|
| 734 |
+
for i in range(self.layers):
|
| 735 |
+
self.transformer_frames.append(Transformer(dim, 1, 16, 64, 2048))
|
| 736 |
+
self.transformer_points.append(Transformer(dim, 1, 16, 64, 2048))
|
| 737 |
+
self.transformer_frames=nn.ModuleList(self.transformer_frames)
|
| 738 |
+
self.transformer_points=nn.ModuleList(self.transformer_points)
|
| 739 |
+
|
| 740 |
+
def forward(self, x):
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
x = torch.cat((x, torch.ones(x.shape[0], x.shape[1], 1, x.shape[3], device=x.device, dtype=x.dtype)), dim=2)
|
| 744 |
+
|
| 745 |
+
x=x.transpose(2,3)
|
| 746 |
+
|
| 747 |
+
b,n,f,c=x.shape
|
| 748 |
+
x=self.to_patch_embedding(x)
|
| 749 |
+
|
| 750 |
+
x=x.view(b*n,f,-1) # x.shape [390, 33, 256]
|
| 751 |
+
if self.use_positional_encoding_transformer:
|
| 752 |
+
pe = posemb_sincos_1d(x) #pe.shape= [33,256] (33 frame, 256 embedding dim)
|
| 753 |
+
x=pe.unsqueeze(0)+x
|
| 754 |
+
for i in range(self.layers):
|
| 755 |
+
#frames aggregation
|
| 756 |
+
x=self.transformer_frames[i](x)
|
| 757 |
+
|
| 758 |
+
#point sets aggregation
|
| 759 |
+
x=x.view(b,n,f,-1).transpose(1,2).reshape(b*f,n,-1)
|
| 760 |
+
|
| 761 |
+
x=self.transformer_points[i](x)
|
| 762 |
+
|
| 763 |
+
x=x.view(b,f,n,-1)
|
| 764 |
+
x=x.transpose(1,2).reshape(b*n,f,-1)
|
| 765 |
+
|
| 766 |
+
x=x.view(b,n,f,-1)
|
| 767 |
+
x=x.transpose(2,3)
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
return x
|
| 771 |
+
|
| 772 |
+
def positionalEncoding_vec(in_tensor, b):
|
| 773 |
+
proj = torch.einsum('ij, k -> ijk', in_tensor, b)
|
| 774 |
+
mapped_coords = torch.cat((torch.sin(proj), torch.cos(proj)), dim=1)
|
| 775 |
+
output = mapped_coords.transpose(2, 1).contiguous().view(mapped_coords.size(0), -1)
|
| 776 |
+
return output
|
| 777 |
+
|
| 778 |
+
class TrackFusion(nn.Module):
|
| 779 |
+
def __init__(self,width1=320,conv2_kernel_size=31,K=12,
|
| 780 |
+
conv_kernel_size=3,inputdim=2,use_positionl_encoding=True,
|
| 781 |
+
positional_dim=4,use_transformer=True,detach_cameras_dynamic=True,
|
| 782 |
+
use_positional_encoding_transformer=True,use_set_of_sets=False,predict_focal_length=False):
|
| 783 |
+
super(TrackFusion, self).__init__()
|
| 784 |
+
self.predict_focal_length=predict_focal_length
|
| 785 |
+
self.inputdim = inputdim
|
| 786 |
+
self.n1 = width1
|
| 787 |
+
|
| 788 |
+
self.K=K
|
| 789 |
+
self.n2 = 6+3+1+self.K+2
|
| 790 |
+
self.detach_cameras_dynamic=detach_cameras_dynamic
|
| 791 |
+
l=conv_kernel_size
|
| 792 |
+
# layers
|
| 793 |
+
self.use_set_of_sets=use_set_of_sets
|
| 794 |
+
self.use_positionl_encoding=use_positionl_encoding
|
| 795 |
+
self.positional_dim=positional_dim
|
| 796 |
+
actual_input_dim=inputdim
|
| 797 |
+
if self.use_positionl_encoding:
|
| 798 |
+
actual_input_dim=2 * inputdim * self.positional_dim+inputdim
|
| 799 |
+
|
| 800 |
+
self.use_transformer=use_transformer
|
| 801 |
+
|
| 802 |
+
if self.use_positionl_encoding:
|
| 803 |
+
self.b = torch.tensor([(2 ** j) * np.pi for j in range(self.positional_dim)],requires_grad = False)
|
| 804 |
+
|
| 805 |
+
if True:
|
| 806 |
+
if self.use_transformer:
|
| 807 |
+
self.transformer_my=MyTransformerHead(actual_input_dim,width1,use_positional_encoding_transformer)
|
| 808 |
+
|
| 809 |
+
self.conv_final = nn.Conv1d(self.n1, self.n2, kernel_size=conv2_kernel_size,stride=1, padding=conv2_kernel_size//2, padding_mode='circular')
|
| 810 |
+
|
| 811 |
+
self.fc1 = nn.Linear(self.n1,3*self.K+1)
|
| 812 |
+
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
torch.nn.init.xavier_uniform_(self.conv_final.weight)
|
| 816 |
+
|
| 817 |
+
torch.nn.init.xavier_uniform_(self.fc1.weight)
|
| 818 |
+
|
| 819 |
+
def forward(self, x, pts_miu=None, pts_radis=None, simple_return=True):
|
| 820 |
+
|
| 821 |
+
B, N, C, T = x.shape
|
| 822 |
+
if self.use_positionl_encoding:
|
| 823 |
+
x_original_shape=x.shape
|
| 824 |
+
x=x.transpose(2,3)
|
| 825 |
+
x=x.reshape(-1,x.shape[-1])
|
| 826 |
+
if self.b.device!=x.device:
|
| 827 |
+
self.b=self.b.to(x.device)
|
| 828 |
+
pos = positionalEncoding_vec(x,self.b)
|
| 829 |
+
x=torch.cat((x,pos),dim=1)
|
| 830 |
+
x=x.view(x_original_shape[0],x_original_shape[1],x_original_shape[3],x.shape[-1]).transpose(2,3)
|
| 831 |
+
|
| 832 |
+
b = len(x)
|
| 833 |
+
n= x.shape[1]
|
| 834 |
+
l= x.shape[-1]
|
| 835 |
+
if self.use_set_of_sets:
|
| 836 |
+
cameras,perpoint_features=self.set_of_sets_my(x)
|
| 837 |
+
else:
|
| 838 |
+
if self.use_transformer:
|
| 839 |
+
x=self.transformer_my(x)
|
| 840 |
+
else:
|
| 841 |
+
for i in range(len( self.conv1)):
|
| 842 |
+
if i==0:
|
| 843 |
+
x = x.reshape(n*b, x.shape[2],l)
|
| 844 |
+
else:
|
| 845 |
+
x = x.view(n * b, self.n1, l)
|
| 846 |
+
x1 = self.bn1[i](self.conv1[i](x)).view(b,n,self.n1,l)
|
| 847 |
+
x2 = self.bn1s[i](self.conv1s[i](x)).view(b,n,self.n1,l).mean(dim=1).view(b,1,self.n1,l).repeat(1,n,1,1)
|
| 848 |
+
x = F.relu(x1 + x2)
|
| 849 |
+
|
| 850 |
+
cameras=torch.mean(x,dim=1)
|
| 851 |
+
cameras=self.conv_final(cameras)
|
| 852 |
+
perpoint_features = torch.mean(x,dim=3)
|
| 853 |
+
perpoint_features = self.fc1(perpoint_features.view(n*b,self.n1))
|
| 854 |
+
|
| 855 |
+
B=perpoint_features[:,:self.K*3].view(b,n,3,self.K) # motion basis
|
| 856 |
+
NR=F.elu(perpoint_features[:,-1].view(b,n))+1+0.00001
|
| 857 |
+
|
| 858 |
+
position_params=cameras[:,:3,:]
|
| 859 |
+
if self.predict_focal_length:
|
| 860 |
+
focal_params=1+0.05*cameras[:,3:4,:].clone().transpose(1,2)
|
| 861 |
+
else:
|
| 862 |
+
focal_params=1.0
|
| 863 |
+
basis_params=cameras[:,4:4+self.K]
|
| 864 |
+
basis_params[:,0,:]=torch.clamp(basis_params[:,0,:].clone(),min=1.0,max=1.0)
|
| 865 |
+
basis_params.transpose(1,2).unsqueeze(1).unsqueeze(1)
|
| 866 |
+
rotation_params=cameras[:,4+self.K:4+self.K+6]
|
| 867 |
+
# Converting rotation parameters into a valid rotation matrix (probably better to move to 6d representation)
|
| 868 |
+
rotation_params=vec6d_to_R(rotation_params.transpose(1,2).reshape(b*l,6)).view(b,l,3,3)
|
| 869 |
+
|
| 870 |
+
# Transferring global 3D points into each camera's coordinates (using per-camera rotation and translation)
|
| 871 |
+
points3D_static=((basis_params.transpose(1,2).unsqueeze(1).unsqueeze(1))[:,:,:,:,:1]*B.unsqueeze(-2)[:,:,:,:,:1]).sum(-1)
|
| 872 |
+
|
| 873 |
+
if self.detach_cameras_dynamic==False:
|
| 874 |
+
points3D=((basis_params.transpose(1,2).unsqueeze(1).unsqueeze(1))[:,:,:,:,1:]*B.unsqueeze(-2)[:,:,:,:,1:]).sum(-1)+points3D_static
|
| 875 |
+
else:
|
| 876 |
+
points3D=((basis_params.transpose(1,2).unsqueeze(1).unsqueeze(1))[:,:,:,:,1:]*B.unsqueeze(-2)[:,:,:,:,1:]).sum(-1)+points3D_static.detach()
|
| 877 |
+
|
| 878 |
+
points3D=points3D.transpose(1,3)
|
| 879 |
+
points3D_static=points3D_static.transpose(1,3)
|
| 880 |
+
position_params=position_params.transpose(1,2)
|
| 881 |
+
if pts_miu is not None:
|
| 882 |
+
position_params=position_params*pts_radis.squeeze(-1)+pts_miu.squeeze(-2)
|
| 883 |
+
points3D_static = points3D_static*pts_radis.squeeze(-1)+pts_miu.permute(0,1,3,2)
|
| 884 |
+
points3D = points3D*pts_radis.squeeze(-1)+pts_miu.permute(0,1,3,2)
|
| 885 |
+
|
| 886 |
+
if self.detach_cameras_dynamic==False:
|
| 887 |
+
points3D_camera=(torch.bmm(rotation_params.view(b*l,3,3).transpose(1,2),points3D.reshape(b*l,3,n)-position_params.reshape(b*l,3).unsqueeze(-1)))
|
| 888 |
+
points3D_camera=points3D_camera.view(b,l,3,n)
|
| 889 |
+
else:
|
| 890 |
+
points3D_camera=(torch.bmm(rotation_params.view(b*l,3,3).transpose(1,2).detach(),points3D.reshape(b*l,3,n)-position_params.detach().reshape(b*l,3).unsqueeze(-1)))
|
| 891 |
+
points3D_camera=points3D_camera.view(b,l,3,n)
|
| 892 |
+
points3D_static_camera=(torch.bmm(rotation_params.view(b*l,3,3).transpose(1,2),points3D_static.reshape(b*l,3,n)-position_params.reshape(b*l,3).unsqueeze(-1)))
|
| 893 |
+
points3D_static_camera=points3D_static_camera.view(b,l,3,n)
|
| 894 |
+
|
| 895 |
+
# Projecting from 3D to 2D
|
| 896 |
+
projections=points3D_camera.clone()
|
| 897 |
+
projections_static=points3D_static_camera.clone()
|
| 898 |
+
|
| 899 |
+
depths=projections[:,:,2,:]
|
| 900 |
+
depths_static=projections_static[:,:,2,:]
|
| 901 |
+
|
| 902 |
+
projectionx=focal_params*projections[:,:,0,:]/torch.clamp(projections[:,:,2,:].clone(),min=0.01)
|
| 903 |
+
projectiony=focal_params*projections[:,:,1,:]/torch.clamp(projections[:,:,2,:].clone(),min=0.01)
|
| 904 |
+
|
| 905 |
+
projectionx_static=focal_params*projections_static[:,:,0,:]/torch.clamp(projections_static[:,:,2,:].clone(),min=0.01)
|
| 906 |
+
projectiony_static=focal_params*projections_static[:,:,1,:]/torch.clamp(projections_static[:,:,2,:].clone(),min=0.01)
|
| 907 |
+
|
| 908 |
+
projections2=torch.cat((projectionx.unsqueeze(2),projectiony.unsqueeze(2)),dim=2)
|
| 909 |
+
projections2_static=torch.cat((projectionx_static.unsqueeze(2),projectiony_static.unsqueeze(2)),dim=2)
|
| 910 |
+
|
| 911 |
+
if simple_return:
|
| 912 |
+
c2w_traj = torch.eye(4, device=x.device)[None,None].repeat(b,T,1,1)
|
| 913 |
+
c2w_traj[:,:,:3,:3] = rotation_params
|
| 914 |
+
c2w_traj[:,:,:3,3] = position_params
|
| 915 |
+
return c2w_traj, points3D, points3D_camera
|
| 916 |
+
else:
|
| 917 |
+
return focal_params,projections2,projections2_static,rotation_params,position_params,B,points3D,points3D_static,depths,depths_static,0,basis_params,0,0,points3D_camera,NR
|
| 918 |
+
|
| 919 |
+
|
| 920 |
+
def get_nth_visible_time_index(vis_gt: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
|
| 921 |
+
"""
|
| 922 |
+
vis_gt: [B, T, N] 0/1 binary tensor
|
| 923 |
+
n: [B, N] int tensor, the n-th visible time index to get (1-based)
|
| 924 |
+
Returns: [B, N] tensor of time indices into T, or -1 if not enough visible steps
|
| 925 |
+
"""
|
| 926 |
+
B, T, N = vis_gt.shape
|
| 927 |
+
|
| 928 |
+
# Create a tensor [0, 1, ..., T-1] for time indices
|
| 929 |
+
time_idx = torch.arange(T, device=vis_gt.device).view(1, T, 1).expand(B, T, N) # [B, T, N]
|
| 930 |
+
|
| 931 |
+
# Mask invisible steps with a large number (T)
|
| 932 |
+
masked_time = torch.where(vis_gt.bool(), time_idx, torch.full_like(time_idx, T))
|
| 933 |
+
|
| 934 |
+
# Sort along time dimension
|
| 935 |
+
sorted_time, _ = masked_time.sort(dim=1) # [B, T, N]
|
| 936 |
+
|
| 937 |
+
# Prepare index tensor for gather: [B, N] -> [B, 1, N]
|
| 938 |
+
gather_idx = (n - 1).clamp(min=0, max=T-1).unsqueeze(1) # shape: [B, 1, N]
|
| 939 |
+
assert gather_idx.shape == sorted_time.shape[:1] + (1, sorted_time.shape[2]) # [B, 1, N]
|
| 940 |
+
|
| 941 |
+
# Gather from sorted_time: result is [B, 1, N]
|
| 942 |
+
nth_time = sorted_time.gather(dim=1, index=gather_idx).squeeze(1) # [B, N]
|
| 943 |
+
|
| 944 |
+
# If value is T (i.e., masked), then not enough visible → set to -1
|
| 945 |
+
nth_time = torch.where(nth_time == T, torch.full_like(nth_time, -1), nth_time)
|
| 946 |
+
|
| 947 |
+
return nth_time # [B, N]
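# Hedged usage sketch, added for illustration only (not part of the original commit), showing the
# 1-based indexing convention and the -1 fallback when a track has fewer visible frames than n.
def _example_get_nth_visible_time_index():
    vis_gt = torch.tensor([[[1, 0],
                            [0, 1],
                            [1, 0],
                            [0, 0]]])                               # (B=1, T=4, N=2)
    n = torch.tensor([[2, 2]])                                      # ask for the 2nd visible frame
    idx = get_nth_visible_time_index(vis_gt, n)
    assert idx.tolist() == [[2, -1]]                                # track 0: t=2; track 1: not enough
    return idx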
|
| 948 |
+
|
| 949 |
+
def knn_torch(x, k):
|
| 950 |
+
"""
|
| 951 |
+
x: (B, T, N, 2)
|
| 952 |
+
return: indices of k-NN, shape (B, T, N, k)
|
| 953 |
+
"""
|
| 954 |
+
B, T, N, C = x.shape
|
| 955 |
+
# Reshape to (B*T, N, 2)
|
| 956 |
+
x = x.view(B*T, N, C) # Merge the first two dimensions for easier processing
|
| 957 |
+
# Calculate pairwise distance: (B*T, N, N)
|
| 958 |
+
dist = torch.cdist(x, x, p=2) # Euclidean distance
|
| 959 |
+
|
| 960 |
+
# Exclude self: set diagonal to a large number (to prevent self from being a neighbor)
|
| 961 |
+
mask = torch.eye(N, device=x.device).bool()[None, :, :] # (1, N, N)
|
| 962 |
+
dist.masked_fill_(mask, float('inf'))
|
| 963 |
+
|
| 964 |
+
# Get indices of top k smallest distances
|
| 965 |
+
knn_idx = dist.topk(k, largest=False).indices # (B*T, N, k)
|
| 966 |
+
# Restore dimensions (B, T, N, k)
|
| 967 |
+
knn_idx = knn_idx.view(B, T, N, k)
|
| 968 |
+
return knn_idx
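# Hedged usage sketch, added for illustration only (not part of the original commit): with the
# diagonal distances masked to infinity, a point is never returned as its own neighbour.
def _example_knn_torch():
    B, T, N, k = 1, 1, 5, 2
    pts = torch.rand(B, T, N, 2)
    idx = knn_torch(pts, k)                                         # (B, T, N, k)
    assert idx.shape == (B, T, N, k)
    assert not (idx == torch.arange(N).view(1, 1, N, 1)).any()
    return idx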
|
| 969 |
+
|
| 970 |
+
def get_topo_mask(coords_xyz_append: torch.Tensor,
|
| 971 |
+
coords_2d_lift: torch.Tensor, replace_ratio: float = 0.6) -> torch.Tensor:
|
| 972 |
+
"""
|
| 973 |
+
coords_xyz_append: [B, T, N, 3] 3d coordinates
|
| 974 |
+
coords_2d_lift: [B*T, N] depth map
|
| 975 |
+
replace_ratio: float, the ratio of the depth change to be considered as a topological change
|
| 976 |
+
"""
|
| 977 |
+
B, T, N, _ = coords_xyz_append.shape
depth_unproj = coords_2d_lift  # unprojected per-point depth (the coords_2d_lift argument), used below
|
| 978 |
+
# if N > 1024:
|
| 979 |
+
# pick_idx = torch.randperm(N)[:1024]
|
| 980 |
+
# else:
|
| 981 |
+
pick_idx = torch.arange(N, device=coords_xyz_append.device)
|
| 982 |
+
coords_xyz_append = coords_xyz_append[:,:,pick_idx,:]
|
| 983 |
+
knn_idx = knn_torch(coords_xyz_append, 49)
|
| 984 |
+
knn_idx = pick_idx[knn_idx]
|
| 985 |
+
# raw topology
|
| 986 |
+
raw_depth = coords_xyz_append[...,2:] # B T N 1 knn_idx B T N K
|
| 987 |
+
knn_depth = torch.gather(
|
| 988 |
+
raw_depth.expand(-1, -1, -1, knn_idx.shape[-1]), # (B, T, N, K)
|
| 989 |
+
dim=2,
|
| 990 |
+
index=knn_idx # (B, T, N, K)
|
| 991 |
+
).squeeze(-1) # → (B, T, N, K)
|
| 992 |
+
depth_rel_neg_raw = (knn_depth - raw_depth)
|
| 993 |
+
# unproj depth
|
| 994 |
+
knn_depth_unproj = torch.gather(
|
| 995 |
+
depth_unproj.view(B,T,N,1).expand(-1, -1, -1, knn_idx.shape[-1]), # (B, T, N, K)
|
| 996 |
+
dim=2,
|
| 997 |
+
index=knn_idx # (B, T, N, K)
|
| 998 |
+
).squeeze(-1) # → (B, T, N, K)
|
| 999 |
+
depth_rel_neg_unproj = (knn_depth_unproj - depth_unproj.view(B,T,N,1))
|
| 1000 |
+
# topological change threshold
|
| 1001 |
+
mask_topo = (depth_rel_neg_raw.abs() / (depth_rel_neg_unproj.abs()+1e-8) - 1).abs() < 0.4
|
| 1002 |
+
mask_topo = mask_topo.sum(dim=-1) > 9
|
| 1003 |
+
|
| 1004 |
+
return mask_topo
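# Hedged usage sketch, added for illustration only (not part of the original commit). It assumes the
# coords_2d_lift argument carries the unprojected per-point depth (see the alias above); when that
# depth matches the raw track depth, essentially every point passes the topology check.
def _example_get_topo_mask():
    B, T, N = 1, 2, 64                                              # N must exceed the K=49 neighbours
    coords = torch.rand(B, T, N, 3)
    depth_lift = coords[..., 2].reshape(B * T, N).clone()
    mask_topo = get_topo_mask(coords, depth_lift)
    assert mask_topo.shape == (B, T, N)
    return mask_topo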
|
| 1005 |
+
|
| 1006 |
+
|
models/SpaTrackV2/models/utils.py
ADDED
|
@@ -0,0 +1,1221 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Modified from https://github.com/facebookresearch/PoseDiffusion
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
from typing import Optional, Tuple, Union, List
|
| 14 |
+
from einops import rearrange, repeat
|
| 15 |
+
|
| 16 |
+
import cv2
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
# from torchmetrics.functional.regression import pearson_corrcoef
|
| 20 |
+
from easydict import EasyDict as edict
|
| 21 |
+
from enum import Enum
|
| 22 |
+
import torch.distributed as dist  # all_reduce / ReduceOp used in AverageMeter live in torch.distributed
|
| 23 |
+
from typing import Literal, Union, List, Tuple, Dict
|
| 24 |
+
from models.monoD.depth_anything_v2.util.transform import Resize
|
| 25 |
+
from models.SpaTrackV2.utils.model_utils import sample_features5d
|
| 26 |
+
EPS = 1e-9
|
| 27 |
+
|
| 28 |
+
class Summary(Enum):
|
| 29 |
+
NONE = 0
|
| 30 |
+
AVERAGE = 1
|
| 31 |
+
SUM = 2
|
| 32 |
+
COUNT = 3
|
| 33 |
+
|
| 34 |
+
class AverageMeter(object):
|
| 35 |
+
"""Computes and stores the average and current value"""
|
| 36 |
+
|
| 37 |
+
def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
|
| 38 |
+
self.name = name
|
| 39 |
+
self.fmt = fmt
|
| 40 |
+
self.summary_type = summary_type
|
| 41 |
+
self.reset()
|
| 42 |
+
|
| 43 |
+
def reset(self):
|
| 44 |
+
self.val = 0
|
| 45 |
+
self.avg = 0
|
| 46 |
+
self.sum = 0
|
| 47 |
+
self.count = 0
|
| 48 |
+
|
| 49 |
+
def update(self, val, n=1):
|
| 50 |
+
self.val = val
|
| 51 |
+
self.sum += val * n
|
| 52 |
+
self.count += n
|
| 53 |
+
self.avg = self.sum / self.count
|
| 54 |
+
|
| 55 |
+
def all_reduce(self):
|
| 56 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 57 |
+
if isinstance(self.sum, np.ndarray):
|
| 58 |
+
total = torch.tensor(
|
| 59 |
+
self.sum.tolist()
|
| 60 |
+
+ [
|
| 61 |
+
self.count,
|
| 62 |
+
],
|
| 63 |
+
dtype=torch.float32,
|
| 64 |
+
device=device,
|
| 65 |
+
)
|
| 66 |
+
else:
|
| 67 |
+
total = torch.tensor(
|
| 68 |
+
[self.sum, self.count], dtype=torch.float32, device=device
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
|
| 72 |
+
if total.shape[0] > 2:
|
| 73 |
+
self.sum, self.count = total[:-1].cpu().numpy(), total[-1].cpu().item()
|
| 74 |
+
else:
|
| 75 |
+
self.sum, self.count = total.tolist()
|
| 76 |
+
self.avg = self.sum / (self.count + 1e-5)
|
| 77 |
+
|
| 78 |
+
def __str__(self):
|
| 79 |
+
fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
|
| 80 |
+
return fmtstr.format(**self.__dict__)
|
| 81 |
+
|
| 82 |
+
def summary(self):
|
| 83 |
+
fmtstr = ""
|
| 84 |
+
if self.summary_type is Summary.NONE:
|
| 85 |
+
fmtstr = ""
|
| 86 |
+
elif self.summary_type is Summary.AVERAGE:
|
| 87 |
+
fmtstr = "{name} {avg:.3f}"
|
| 88 |
+
elif self.summary_type is Summary.SUM:
|
| 89 |
+
fmtstr = "{name} {sum:.3f}"
|
| 90 |
+
elif self.summary_type is Summary.COUNT:
|
| 91 |
+
fmtstr = "{name} {count:.3f}"
|
| 92 |
+
else:
|
| 93 |
+
raise ValueError("invalid summary type %r" % self.summary_type)
|
| 94 |
+
|
| 95 |
+
return fmtstr.format(**self.__dict__)
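# Hedged usage sketch, added for illustration only (not part of the original commit): plain
# single-process bookkeeping; all_reduce() is only meaningful once torch.distributed is initialised.
def _example_average_meter():
    meter = AverageMeter("loss", fmt=":.4f")
    for loss in [0.9, 0.7, 0.5]:
        meter.update(loss, n=1)
    assert abs(meter.avg - 0.7) < 1e-6
    return str(meter)                                               # "loss 0.5000 (0.7000)"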
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def procrustes_analysis(X0,X1): # [N,3]
|
| 99 |
+
# translation
|
| 100 |
+
t0 = X0.mean(dim=0,keepdim=True)
|
| 101 |
+
t1 = X1.mean(dim=0,keepdim=True)
|
| 102 |
+
X0c = X0-t0
|
| 103 |
+
X1c = X1-t1
|
| 104 |
+
# scale
|
| 105 |
+
s0 = (X0c**2).sum(dim=-1).mean().sqrt()
|
| 106 |
+
s1 = (X1c**2).sum(dim=-1).mean().sqrt()
|
| 107 |
+
X0cs = X0c/s0
|
| 108 |
+
X1cs = X1c/s1
|
| 109 |
+
# rotation (use double for SVD, float loses precision)
|
| 110 |
+
U,S,V = (X0cs.t()@X1cs).double().svd(some=True)
|
| 111 |
+
R = (U@V.t()).float()
|
| 112 |
+
if R.det()<0: R[2] *= -1
|
| 113 |
+
# align X1 to X0: X1to0 = (X1-t1)/s1@R.t()*s0+t0
|
| 114 |
+
sim3 = edict(t0=t0[0],t1=t1[0],s0=s0,s1=s1,R=R)
|
| 115 |
+
return sim3
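# Hedged usage sketch, added for illustration only (not part of the original commit). It builds a
# known similarity transform of X0 and applies the alignment formula from the comment above.
def _example_procrustes_analysis():
    import math
    X0 = torch.randn(100, 3)
    c, s = math.cos(0.3), math.sin(0.3)
    R_true = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
    X1 = 2.0 * X0 @ R_true.t() + torch.tensor([1.0, 2.0, 3.0])      # scale 2, rotate, translate
    sim3 = procrustes_analysis(X0, X1)
    X1to0 = (X1 - sim3.t1) / sim3.s1 @ sim3.R.t() * sim3.s0 + sim3.t0
    assert torch.allclose(X1to0, X0, atol=1e-4)
    return sim3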
|
| 116 |
+
|
| 117 |
+
def create_intri_matrix(focal_length, principal_point):
|
| 118 |
+
"""
|
| 119 |
+
Creates an intrinsic matrix from focal length and principal point.
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
focal_length (torch.Tensor): A Bx2 or BxSx2 tensor containing the focal lengths (fx, fy) for each image.
|
| 123 |
+
principal_point (torch.Tensor): A Bx2 or BxSx2 tensor containing the principal point coordinates (cx, cy) for each image.
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
torch.Tensor: A Bx3x3 or BxSx3x3 tensor containing the camera matrix for each image.
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
if len(focal_length.shape) == 2:
|
| 130 |
+
B = focal_length.shape[0]
|
| 131 |
+
intri_matrix = torch.zeros(B, 3, 3, dtype=focal_length.dtype, device=focal_length.device)
|
| 132 |
+
intri_matrix[:, 0, 0] = focal_length[:, 0]
|
| 133 |
+
intri_matrix[:, 1, 1] = focal_length[:, 1]
|
| 134 |
+
intri_matrix[:, 2, 2] = 1.0
|
| 135 |
+
intri_matrix[:, 0, 2] = principal_point[:, 0]
|
| 136 |
+
intri_matrix[:, 1, 2] = principal_point[:, 1]
|
| 137 |
+
else:
|
| 138 |
+
B, S = focal_length.shape[0], focal_length.shape[1]
|
| 139 |
+
intri_matrix = torch.zeros(B, S, 3, 3, dtype=focal_length.dtype, device=focal_length.device)
|
| 140 |
+
intri_matrix[:, :, 0, 0] = focal_length[:, :, 0]
|
| 141 |
+
intri_matrix[:, :, 1, 1] = focal_length[:, :, 1]
|
| 142 |
+
intri_matrix[:, :, 2, 2] = 1.0
|
| 143 |
+
intri_matrix[:, :, 0, 2] = principal_point[:, :, 0]
|
| 144 |
+
intri_matrix[:, :, 1, 2] = principal_point[:, :, 1]
|
| 145 |
+
|
| 146 |
+
return intri_matrix
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def closed_form_inverse_OpenCV(se3, R=None, T=None):
|
| 150 |
+
"""
|
| 151 |
+
Computes the inverse of each 4x4 SE3 matrix in the batch.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
- se3 (Tensor): Nx4x4 tensor of SE3 matrices.
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
- Tensor: Nx4x4 tensor of inverted SE3 matrices.
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
| R t |
|
| 161 |
+
| 0 1 |
|
| 162 |
+
-->
|
| 163 |
+
| R^T -R^T t|
|
| 164 |
+
| 0 1 |
|
| 165 |
+
"""
|
| 166 |
+
if R is None:
|
| 167 |
+
R = se3[:, :3, :3]
|
| 168 |
+
|
| 169 |
+
if T is None:
|
| 170 |
+
T = se3[:, :3, 3:]
|
| 171 |
+
|
| 172 |
+
# Compute the transpose of the rotation
|
| 173 |
+
R_transposed = R.transpose(1, 2)
|
| 174 |
+
|
| 175 |
+
# -R^T t
|
| 176 |
+
top_right = -R_transposed.bmm(T)
|
| 177 |
+
|
| 178 |
+
inverted_matrix = torch.eye(4, 4)[None].repeat(len(se3), 1, 1)
|
| 179 |
+
inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
|
| 180 |
+
|
| 181 |
+
inverted_matrix[:, :3, :3] = R_transposed
|
| 182 |
+
inverted_matrix[:, :3, 3:] = top_right
|
| 183 |
+
|
| 184 |
+
return inverted_matrix
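# Hedged sanity check, added for illustration only (not part of the original commit): composing each
# SE3 with its closed-form inverse should give the identity.
def _example_closed_form_inverse():
    se3 = torch.eye(4).repeat(2, 1, 1)
    se3[:, :3, 3] = torch.tensor([[1.0, 2.0, 3.0], [-1.0, 0.0, 2.0]])
    inv = closed_form_inverse_OpenCV(se3)
    assert torch.allclose(torch.bmm(se3, inv), torch.eye(4).expand(2, 4, 4), atol=1e-6)
    return inv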
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def get_EFP(pred_cameras, image_size, B, S, default_focal=False):
|
| 188 |
+
"""
|
| 189 |
+
Converting PyTorch3D cameras to extrinsics, intrinsics matrix
|
| 190 |
+
|
| 191 |
+
Return extrinsics, intrinsics, focal_length, principal_point
|
| 192 |
+
"""
|
| 193 |
+
scale = image_size.min()
|
| 194 |
+
|
| 195 |
+
focal_length = pred_cameras.focal_length
|
| 196 |
+
|
| 197 |
+
principal_point = torch.zeros_like(focal_length)
|
| 198 |
+
|
| 199 |
+
focal_length = focal_length * scale / 2
|
| 200 |
+
principal_point = (image_size[None] - principal_point * scale) / 2
|
| 201 |
+
|
| 202 |
+
Rots = pred_cameras.R.clone()
|
| 203 |
+
Trans = pred_cameras.T.clone()
|
| 204 |
+
|
| 205 |
+
extrinsics = torch.cat([Rots, Trans[..., None]], dim=-1)
|
| 206 |
+
|
| 207 |
+
# reshape
|
| 208 |
+
extrinsics = extrinsics.reshape(B, S, 3, 4)
|
| 209 |
+
focal_length = focal_length.reshape(B, S, 2)
|
| 210 |
+
principal_point = principal_point.reshape(B, S, 2)
|
| 211 |
+
|
| 212 |
+
# only one dof focal length
|
| 213 |
+
if default_focal:
|
| 214 |
+
focal_length[:] = scale
|
| 215 |
+
else:
|
| 216 |
+
focal_length = focal_length.mean(dim=-1, keepdim=True).expand(-1, -1, 2)
|
| 217 |
+
focal_length = focal_length.clamp(0.2 * scale, 5 * scale)
|
| 218 |
+
|
| 219 |
+
intrinsics = create_intri_matrix(focal_length, principal_point)
|
| 220 |
+
return extrinsics, intrinsics
|
| 221 |
+
|
| 222 |
+
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
|
| 223 |
+
"""
|
| 224 |
+
Convert rotations given as quaternions to rotation matrices.
|
| 225 |
+
|
| 226 |
+
Args:
|
| 227 |
+
quaternions: quaternions with real part first,
|
| 228 |
+
as tensor of shape (..., 4).
|
| 229 |
+
|
| 230 |
+
Returns:
|
| 231 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
| 232 |
+
"""
|
| 233 |
+
r, i, j, k = torch.unbind(quaternions, -1)
|
| 234 |
+
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
|
| 235 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
| 236 |
+
|
| 237 |
+
o = torch.stack(
|
| 238 |
+
(
|
| 239 |
+
1 - two_s * (j * j + k * k),
|
| 240 |
+
two_s * (i * j - k * r),
|
| 241 |
+
two_s * (i * k + j * r),
|
| 242 |
+
two_s * (i * j + k * r),
|
| 243 |
+
1 - two_s * (i * i + k * k),
|
| 244 |
+
two_s * (j * k - i * r),
|
| 245 |
+
two_s * (i * k - j * r),
|
| 246 |
+
two_s * (j * k + i * r),
|
| 247 |
+
1 - two_s * (i * i + j * j),
|
| 248 |
+
),
|
| 249 |
+
-1,
|
| 250 |
+
)
|
| 251 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
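# Hedged sanity check, added for illustration only (not part of the original commit): the identity
# quaternion maps to the identity matrix, and (0, 0, 0, 1) is a 180-degree rotation about z.
def _example_quaternion_to_matrix():
    q = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                      [0.0, 0.0, 0.0, 1.0]])                        # real part first
    R = quaternion_to_matrix(q)
    assert torch.allclose(R[0], torch.eye(3), atol=1e-6)
    assert torch.allclose(R[1], torch.diag(torch.tensor([-1.0, -1.0, 1.0])), atol=1e-6)
    return R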
|
| 252 |
+
|
| 253 |
+
def pose_encoding_to_camera(
|
| 254 |
+
pose_encoding,
|
| 255 |
+
pose_encoding_type="absT_quaR_logFL",
|
| 256 |
+
log_focal_length_bias=1.8,
|
| 257 |
+
min_focal_length=0.1,
|
| 258 |
+
max_focal_length=30,
|
| 259 |
+
return_dict=False,
|
| 260 |
+
to_OpenCV=True,
|
| 261 |
+
):
|
| 262 |
+
"""
|
| 263 |
+
Args:
|
| 264 |
+
pose_encoding: A tensor of shape `BxNxC`, containing a batch of
|
| 265 |
+
`BxN` `C`-dimensional pose encodings.
|
| 266 |
+
pose_encoding_type: The type of pose encoding,
|
| 267 |
+
"""
|
| 268 |
+
pose_encoding_reshaped = pose_encoding.reshape(-1, pose_encoding.shape[-1]) # Reshape to BNxC
|
| 269 |
+
|
| 270 |
+
if pose_encoding_type == "absT_quaR_logFL":
|
| 271 |
+
# 3 for absT, 4 for quaR, 2 for absFL
|
| 272 |
+
abs_T = pose_encoding_reshaped[:, :3]
|
| 273 |
+
quaternion_R = pose_encoding_reshaped[:, 3:7]
|
| 274 |
+
R = quaternion_to_matrix(quaternion_R)
|
| 275 |
+
log_focal_length = pose_encoding_reshaped[:, 7:9]
|
| 276 |
+
# log_focal_length_bias was the hyperparameter
|
| 277 |
+
# to ensure the mean of logFL close to 0 during training
|
| 278 |
+
# Now converted back
|
| 279 |
+
focal_length = (log_focal_length + log_focal_length_bias).exp()
|
| 280 |
+
# clamp to avoid weird fl values
|
| 281 |
+
focal_length = torch.clamp(focal_length,
|
| 282 |
+
min=min_focal_length, max=max_focal_length)
|
| 283 |
+
elif pose_encoding_type == "absT_quaR_OneFL":
|
| 284 |
+
# 3 for absT, 4 for quaR, 1 for absFL
|
| 285 |
+
# [absolute translation, quaternion rotation, normalized focal length]
|
| 286 |
+
abs_T = pose_encoding_reshaped[:, :3]
|
| 287 |
+
quaternion_R = pose_encoding_reshaped[:, 3:7]
|
| 288 |
+
R = quaternion_to_matrix(quaternion_R)
|
| 289 |
+
focal_length = pose_encoding_reshaped[:, 7:8]
|
| 290 |
+
focal_length = torch.clamp(focal_length,
|
| 291 |
+
min=min_focal_length, max=max_focal_length)
|
| 292 |
+
else:
|
| 293 |
+
raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
|
| 294 |
+
|
| 295 |
+
if to_OpenCV:
|
| 296 |
+
### From Pytorch3D coordinate to OpenCV coordinate:
|
| 297 |
+
# I hate coordinate conversion
|
| 298 |
+
R = R.clone()
|
| 299 |
+
abs_T = abs_T.clone()
|
| 300 |
+
R[:, :, :2] *= -1
|
| 301 |
+
abs_T[:, :2] *= -1
|
| 302 |
+
R = R.permute(0, 2, 1)
|
| 303 |
+
|
| 304 |
+
extrinsics_4x4 = torch.eye(4, 4).to(R.dtype).to(R.device)
|
| 305 |
+
extrinsics_4x4 = extrinsics_4x4[None].repeat(len(R), 1, 1)
|
| 306 |
+
|
| 307 |
+
extrinsics_4x4[:, :3, :3] = R.clone()
|
| 308 |
+
extrinsics_4x4[:, :3, 3] = abs_T.clone()
|
| 309 |
+
|
| 310 |
+
rel_transform = closed_form_inverse_OpenCV(extrinsics_4x4[0:1])
|
| 311 |
+
rel_transform = rel_transform.expand(len(extrinsics_4x4), -1, -1)
|
| 312 |
+
|
| 313 |
+
# relative to the first camera
|
| 314 |
+
# NOTE it is extrinsics_4x4 x rel_transform instead of rel_transform x extrinsics_4x4
|
| 315 |
+
extrinsics_4x4 = torch.bmm(extrinsics_4x4, rel_transform)
|
| 316 |
+
|
| 317 |
+
R = extrinsics_4x4[:, :3, :3].clone()
|
| 318 |
+
abs_T = extrinsics_4x4[:, :3, 3].clone()
|
| 319 |
+
|
| 320 |
+
if return_dict:
|
| 321 |
+
return {"focal_length": focal_length, "R": R, "T": abs_T}
|
| 322 |
+
|
| 323 |
+
pred_cameras = PerspectiveCameras(focal_length=focal_length,
|
| 324 |
+
R=R, T=abs_T, device=R.device, in_ndc=False)
|
| 325 |
+
return pred_cameras
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def camera_to_pose_encoding(
|
| 329 |
+
camera, pose_encoding_type="absT_quaR_logFL",
|
| 330 |
+
log_focal_length_bias=1.8, min_focal_length=0.1, max_focal_length=30
|
| 331 |
+
):
|
| 332 |
+
"""
|
| 333 |
+
Inverse to pose_encoding_to_camera
|
| 334 |
+
"""
|
| 335 |
+
if pose_encoding_type == "absT_quaR_logFL":
|
| 336 |
+
# Convert rotation matrix to quaternion
|
| 337 |
+
quaternion_R = matrix_to_quaternion(camera.R)
|
| 338 |
+
|
| 339 |
+
# Calculate log_focal_length
|
| 340 |
+
log_focal_length = (
|
| 341 |
+
torch.log(torch.clamp(camera.focal_length,
|
| 342 |
+
min=min_focal_length, max=max_focal_length))
|
| 343 |
+
- log_focal_length_bias
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
# Concatenate to form pose_encoding
|
| 347 |
+
pose_encoding = torch.cat([camera.T, quaternion_R, log_focal_length], dim=-1)
|
| 348 |
+
|
| 349 |
+
elif pose_encoding_type == "absT_quaR_OneFL":
|
| 350 |
+
# [absolute translation, quaternion rotation, normalized focal length]
|
| 351 |
+
quaternion_R = matrix_to_quaternion(camera.R)
|
| 352 |
+
focal_length = (torch.clamp(camera.focal_length,
|
| 353 |
+
min=min_focal_length,
|
| 354 |
+
max=max_focal_length))[..., 0:1]
|
| 355 |
+
pose_encoding = torch.cat([camera.T, quaternion_R, focal_length], dim=-1)
|
| 356 |
+
else:
|
| 357 |
+
raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
|
| 358 |
+
|
| 359 |
+
return pose_encoding
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def init_pose_enc(B: int,
|
| 363 |
+
S: int, pose_encoding_type: str="absT_quaR_logFL",
|
| 364 |
+
device: Optional[torch.device]=None):
|
| 365 |
+
"""
|
| 366 |
+
Initialize the pose encoding tensor
|
| 367 |
+
args:
|
| 368 |
+
B: batch size
|
| 369 |
+
S: number of frames
|
| 370 |
+
pose_encoding_type: the type of pose encoding
|
| 371 |
+
device: device to put the tensor
|
| 372 |
+
return:
|
| 373 |
+
pose_enc: [B S C]
|
| 374 |
+
"""
|
| 375 |
+
if pose_encoding_type == "absT_quaR_logFL":
|
| 376 |
+
C = 9
|
| 377 |
+
elif pose_encoding_type == "absT_quaR_OneFL":
|
| 378 |
+
C = 8
|
| 379 |
+
else:
|
| 380 |
+
raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
|
| 381 |
+
|
| 382 |
+
pose_enc = torch.zeros(B, S, C, device=device)
|
| 383 |
+
pose_enc[..., :3] = 0 # absT
|
| 384 |
+
pose_enc[..., 3] = 1 # quaR
|
| 385 |
+
pose_enc[..., 7:] = 1 # logFL
|
| 386 |
+
return pose_enc
|
| 387 |
+
|
| 388 |
+
def first_pose_enc_norm(pose_enc: torch.Tensor,
|
| 389 |
+
pose_encoding_type: str="absT_quaR_OneFL",
|
| 390 |
+
pose_mode: str = "W2C"):
|
| 391 |
+
"""
|
| 392 |
+
make sure the poses in one window are normalized by the first frame, where the
|
| 393 |
+
first frame transformation is the Identity Matrix.
|
| 394 |
+
NOTE: Poses are all W2C
|
| 395 |
+
args:
|
| 396 |
+
pose_enc: [B S C]
|
| 397 |
+
return:
|
| 398 |
+
pose_enc_norm: [B S C]
|
| 399 |
+
"""
|
| 400 |
+
B, S, C = pose_enc.shape
|
| 401 |
+
# Pose encoding to Cameras (Pytorch3D coordinate)
|
| 402 |
+
pred_cameras = pose_encoding_to_camera(
|
| 403 |
+
pose_enc, pose_encoding_type=pose_encoding_type,
|
| 404 |
+
to_OpenCV=False
|
| 405 |
+
) #NOTE: the camera parameters are not in NDC
|
| 406 |
+
|
| 407 |
+
R = pred_cameras.R # [B*S, 3, 3]
|
| 408 |
+
T = pred_cameras.T # [B*S, 3]
|
| 409 |
+
|
| 410 |
+
Tran_M = torch.cat([R, T.unsqueeze(-1)], dim=-1) # [B*S, 3, 4]
|
| 411 |
+
extra_ = torch.tensor([[[0, 0, 0, 1]]],
|
| 412 |
+
device=Tran_M.device).expand(Tran_M.shape[0], -1, -1)
|
| 413 |
+
Tran_M = torch.cat([Tran_M, extra_
|
| 414 |
+
], dim=1)
|
| 415 |
+
Tran_M = rearrange(Tran_M, '(b s) c d -> b s c d', b=B)
|
| 416 |
+
|
| 417 |
+
# Take the first frame as the base of world coordinate
|
| 418 |
+
if pose_mode == "C2W":
|
| 419 |
+
Tran_M_new = (Tran_M[:,:1,...].inverse())@Tran_M
|
| 420 |
+
elif pose_mode == "W2C":
|
| 421 |
+
Tran_M_new = Tran_M@(Tran_M[:,:1,...].inverse())
|
| 422 |
+
|
| 423 |
+
Tran_M_new = rearrange(Tran_M_new, 'b s c d -> (b s) c d')
|
| 424 |
+
|
| 425 |
+
R_ = Tran_M_new[:, :3, :3]
|
| 426 |
+
T_ = Tran_M_new[:, :3, 3]
|
| 427 |
+
|
| 428 |
+
# Cameras to Pose encoding
|
| 429 |
+
pred_cameras.R = R_
|
| 430 |
+
pred_cameras.T = T_
|
| 431 |
+
pose_enc_norm = camera_to_pose_encoding(pred_cameras,
|
| 432 |
+
pose_encoding_type=pose_encoding_type)
|
| 433 |
+
pose_enc_norm = rearrange(pose_enc_norm, '(b s) c -> b s c', b=B)
|
| 434 |
+
return pose_enc_norm
|
| 435 |
+
|
| 436 |
+
def first_pose_enc_denorm(
|
| 437 |
+
pose_enc: torch.Tensor,
|
| 438 |
+
pose_enc_1st: torch.Tensor,
|
| 439 |
+
pose_encoding_type: str="absT_quaR_OneFL",
|
| 440 |
+
pose_mode: str = "W2C"):
|
| 441 |
+
"""
|
| 442 |
+
make sure the poses in one window are de-normalized by the first frame, where the
|
| 443 |
+
first frame transformation is the Identity Matrix.
|
| 444 |
+
args:
|
| 445 |
+
pose_enc: [B S C]
|
| 446 |
+
pose_enc_1st: [B 1 C]
|
| 447 |
+
return:
|
| 448 |
+
pose_enc_denorm: [B S C]
|
| 449 |
+
"""
|
| 450 |
+
B, S, C = pose_enc.shape
|
| 451 |
+
pose_enc_all = torch.cat([pose_enc_1st, pose_enc], dim=1)
|
| 452 |
+
|
| 453 |
+
# Pose encoding to Cameras (Pytorch3D coordinate)
|
| 454 |
+
pred_cameras = pose_encoding_to_camera(
|
| 455 |
+
pose_enc_all, pose_encoding_type=pose_encoding_type,
|
| 456 |
+
to_OpenCV=False
|
| 457 |
+
) #NOTE: the camera parameters are not in NDC
|
| 458 |
+
R = pred_cameras.R # [B*(1+S), 3, 3]
|
| 459 |
+
T = pred_cameras.T # [B*(1+S), 3]
|
| 460 |
+
|
| 461 |
+
Tran_M = torch.cat([R, T.unsqueeze(-1)], dim=-1) # [B*(1+S), 3, 4]
|
| 462 |
+
extra_ = torch.tensor([[[0, 0, 0, 1]]],
|
| 463 |
+
device=Tran_M.device).expand(Tran_M.shape[0], -1, -1)
|
| 464 |
+
Tran_M = torch.cat([Tran_M, extra_
|
| 465 |
+
], dim=1)
|
| 466 |
+
Tran_M_new = rearrange(Tran_M, '(b s) c d -> b s c d', b=B)[:, 1:]
|
| 467 |
+
Tran_M_1st = rearrange(Tran_M, '(b s) c d -> b s c d', b=B)[:,:1]
|
| 468 |
+
|
| 469 |
+
if pose_mode == "C2W":
|
| 470 |
+
Tran_M_new = Tran_M_1st@Tran_M_new
|
| 471 |
+
elif pose_mode == "W2C":
|
| 472 |
+
Tran_M_new = Tran_M_new@Tran_M_1st
|
| 473 |
+
|
| 474 |
+
Tran_M_new_ = torch.cat([Tran_M_1st, Tran_M_new], dim=1)
|
| 475 |
+
R_ = Tran_M_new_[..., :3, :3].view(-1, 3, 3)
|
| 476 |
+
T_ = Tran_M_new_[..., :3, 3].view(-1, 3)
|
| 477 |
+
|
| 478 |
+
# Cameras to Pose encoding
|
| 479 |
+
pred_cameras.R = R_
|
| 480 |
+
pred_cameras.T = T_
|
| 481 |
+
|
| 482 |
+
# Cameras to Pose encoding
|
| 483 |
+
pose_enc_denorm = camera_to_pose_encoding(pred_cameras,
|
| 484 |
+
pose_encoding_type=pose_encoding_type)
|
| 485 |
+
pose_enc_denorm = rearrange(pose_enc_denorm, '(b s) c -> b s c', b=B)
|
| 486 |
+
return pose_enc_denorm[:, 1:]
|
| 487 |
+
|
| 488 |
+
def compute_scale_and_shift(prediction, target, mask):
|
| 489 |
+
# system matrix: A = [[a_00, a_01], [a_10, a_11]]
|
| 490 |
+
a_00 = torch.sum(mask * prediction * prediction, (1, 2))
|
| 491 |
+
a_01 = torch.sum(mask * prediction, (1, 2))
|
| 492 |
+
a_11 = torch.sum(mask, (1, 2))
|
| 493 |
+
|
| 494 |
+
# right hand side: b = [b_0, b_1]
|
| 495 |
+
b_0 = torch.sum(mask * prediction * target, (1, 2))
|
| 496 |
+
b_1 = torch.sum(mask * target, (1, 2))
|
| 497 |
+
|
| 498 |
+
# solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b
|
| 499 |
+
x_0 = torch.zeros_like(b_0)
|
| 500 |
+
x_1 = torch.zeros_like(b_1)
|
| 501 |
+
|
| 502 |
+
det = a_00 * a_11 - a_01 * a_01
|
| 503 |
+
# A needs to be a positive definite matrix.
|
| 504 |
+
valid = det > 0
|
| 505 |
+
|
| 506 |
+
x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]
|
| 507 |
+
x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]
|
| 508 |
+
|
| 509 |
+
return x_0, x_1
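# Hedged usage sketch, added for illustration only (not part of the original commit): when the target
# is an exact affine map of the prediction, the closed-form solution recovers that scale and shift.
def _example_compute_scale_and_shift():
    pred = torch.rand(2, 8, 8)
    target = 3.0 * pred + 0.5
    mask = torch.ones_like(pred)
    scale, shift = compute_scale_and_shift(pred, target, mask)
    assert torch.allclose(scale, torch.full((2,), 3.0), atol=1e-4)
    assert torch.allclose(shift, torch.full((2,), 0.5), atol=1e-4)
    return scale, shift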
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def normalize_prediction_robust(target, mask, Bs):
|
| 513 |
+
ssum = torch.sum(mask, (1, 2))
|
| 514 |
+
valid = ssum > 0
|
| 515 |
+
|
| 516 |
+
m = torch.zeros_like(ssum).to(target.dtype)
|
| 517 |
+
s = torch.ones_like(ssum).to(target.dtype)
|
| 518 |
+
m[valid] = torch.median(
|
| 519 |
+
(mask[valid] * target[valid]).view(valid.sum(), -1), dim=1
|
| 520 |
+
).values
|
| 521 |
+
target = rearrange(target, '(b c) h w -> b c h w', b=Bs)
|
| 522 |
+
m_vid = rearrange(m, '(b c) -> b c 1 1', b=Bs) #.mean(dim=1, keepdim=True)
|
| 523 |
+
mask = rearrange(mask, '(b c) h w -> b c h w', b=Bs)
|
| 524 |
+
|
| 525 |
+
target = target - m_vid
|
| 526 |
+
|
| 527 |
+
sq = torch.sum(mask * target.abs(), (2, 3))
|
| 528 |
+
sq = rearrange(sq, 'b c -> (b c)')
|
| 529 |
+
s[valid] = torch.clamp((sq[valid] / ssum[valid]), min=1e-6)
|
| 530 |
+
s_vid = rearrange(s, '(b c) -> b c 1 1', b=Bs) #.mean(dim=1, keepdim=True)
|
| 531 |
+
target = target / s_vid
|
| 532 |
+
target = rearrange(target, 'b c h w -> (b c) h w', b=Bs)
|
| 533 |
+
|
| 534 |
+
return target, m_vid, s_vid
|
| 535 |
+
|
| 536 |
+
def normalize_video_robust(target, mask, Bs):
|
| 537 |
+
|
| 538 |
+
vid_valid = target[mask]
|
| 539 |
+
# randomly subsample the valid values to 1/5 before computing quantiles
|
| 540 |
+
with torch.no_grad():
|
| 541 |
+
vid_valid = vid_valid[torch.randperm(vid_valid.shape[0], device=vid_valid.device)[:vid_valid.shape[0]//5]]
|
| 542 |
+
t_2, t_98 = torch.quantile(vid_valid, 0.02), torch.quantile(vid_valid, 0.98)
|
| 543 |
+
# normalize
|
| 544 |
+
target = (target - t_2) / (t_98 - t_2)*2 - 1
|
| 545 |
+
return target, t_2, t_98
|
| 546 |
+
|
| 547 |
+
def video_loss(prediction, target, mask, Bs):
|
| 548 |
+
# median norm
|
| 549 |
+
prediction_nm, a_norm, b_norm = normalize_video_robust(prediction, mask, Bs)
|
| 550 |
+
target_nm, a_norm_gt, b_norm_gt = normalize_video_robust(target.float(), mask, Bs)
|
| 551 |
+
depth_loss = nn.functional.l1_loss(prediction_nm[mask], target_nm[mask])
|
| 552 |
+
# rel depth 2 metric --> (pred - a')/(b'-a')*(b-a) + a
|
| 553 |
+
scale = (b_norm_gt - a_norm_gt) / (b_norm - a_norm)
|
| 554 |
+
shift = a_norm_gt - a_norm*scale
|
| 555 |
+
return depth_loss, scale, shift, prediction_nm, target_nm
|
| 556 |
+
|
| 557 |
+
def median_loss(prediction, target, mask, Bs):
|
| 558 |
+
# median norm
|
| 559 |
+
prediction_nm, a_norm, b_norm = normalize_prediction_robust(prediction, mask, Bs)
|
| 560 |
+
target_nm, a_norm_gt, b_norm_gt = normalize_prediction_robust(target.float(), mask, Bs)
|
| 561 |
+
depth_loss = nn.functional.l1_loss(prediction_nm[mask], target_nm[mask])
|
| 562 |
+
scale = b_norm_gt/b_norm
|
| 563 |
+
shift = a_norm_gt - a_norm*scale
|
| 564 |
+
return depth_loss, scale, shift, prediction_nm, target_nm
|
| 565 |
+
|
| 566 |
+
def reduction_batch_based(image_loss, M):
|
| 567 |
+
# average of all valid pixels of the batch
|
| 568 |
+
|
| 569 |
+
# avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0)
|
| 570 |
+
divisor = torch.sum(M)
|
| 571 |
+
|
| 572 |
+
if divisor == 0:
|
| 573 |
+
return 0
|
| 574 |
+
else:
|
| 575 |
+
return torch.sum(image_loss) / divisor
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
def reduction_image_based(image_loss, M):
|
| 579 |
+
# mean of average of valid pixels of an image
|
| 580 |
+
|
| 581 |
+
# avoid division by 0 (if M = sum(mask) = 0: image_loss = 0)
|
| 582 |
+
valid = M.nonzero()
|
| 583 |
+
|
| 584 |
+
image_loss[valid] = image_loss[valid] / M[valid]
|
| 585 |
+
|
| 586 |
+
return torch.mean(image_loss)
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
class ScaleAndShiftInvariantLoss(nn.Module):
|
| 590 |
+
def __init__(self):
|
| 591 |
+
super().__init__()
|
| 592 |
+
self.name = "SSILoss"
|
| 593 |
+
|
| 594 |
+
def forward(self, prediction, target, mask, Bs,
|
| 595 |
+
interpolate=True, return_interpolated=False):
|
| 596 |
+
|
| 597 |
+
if prediction.shape[-1] != target.shape[-1] and interpolate:
|
| 598 |
+
prediction = nn.functional.interpolate(prediction, target.shape[-2:], mode='bilinear', align_corners=True)
|
| 599 |
+
intr_input = prediction
|
| 600 |
+
else:
|
| 601 |
+
intr_input = prediction
|
| 602 |
+
|
| 603 |
+
prediction, target, mask = prediction.squeeze(), target.squeeze(), mask.squeeze()
|
| 604 |
+
assert prediction.shape == target.shape, f"Shape mismatch: Expected same shape but got {prediction.shape} and {target.shape}."
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
scale, shift = compute_scale_and_shift(prediction, target, mask)
|
| 608 |
+
a_norm = scale.view(Bs, -1, 1, 1).mean(dim=1, keepdim=True)
|
| 609 |
+
b_norm = shift.view(Bs, -1, 1, 1).mean(dim=1, keepdim=True)
|
| 610 |
+
prediction = rearrange(prediction, '(b c) h w -> b c h w', b=Bs)
|
| 611 |
+
target = rearrange(target, '(b c) h w -> b c h w', b=Bs)
|
| 612 |
+
mask = rearrange(mask, '(b c) h w -> b c h w', b=Bs)
|
| 613 |
+
scaled_prediction = a_norm * prediction + b_norm
|
| 614 |
+
loss = nn.functional.l1_loss(scaled_prediction[mask], target[mask])
|
| 615 |
+
# both branches return the same values, so return_interpolated currently has no effect
return loss, a_norm, b_norm
|
| 618 |
+
|
| 619 |
+
ScaleAndShiftInvariantLoss_fn = ScaleAndShiftInvariantLoss()
|
| 620 |
+
|
| 621 |
+
class GradientLoss(nn.Module):
|
| 622 |
+
def __init__(self, scales=4, reduction='batch-based'):
|
| 623 |
+
super().__init__()
|
| 624 |
+
|
| 625 |
+
if reduction == 'batch-based':
|
| 626 |
+
self.__reduction = reduction_batch_based
|
| 627 |
+
else:
|
| 628 |
+
self.__reduction = reduction_image_based
|
| 629 |
+
|
| 630 |
+
self.__scales = scales
|
| 631 |
+
|
| 632 |
+
def forward(self, prediction, target, mask):
|
| 633 |
+
total = 0
|
| 634 |
+
|
| 635 |
+
for scale in range(self.__scales):
|
| 636 |
+
step = pow(2, scale)
|
| 637 |
+
l1_ln, a_nm, b_nm = ScaleAndShiftInvariantLoss_fn(prediction[:, ::step, ::step],
|
| 638 |
+
target[:, ::step, ::step], mask[:, ::step, ::step], 1)
|
| 639 |
+
total += l1_ln
|
| 640 |
+
a_nm = a_nm.squeeze().detach() # [B, 1, 1]
|
| 641 |
+
b_nm = b_nm.squeeze().detach() # [B, 1, 1]
|
| 642 |
+
total += 2*gradient_loss(a_nm*prediction[:, ::step, ::step]+b_nm, target[:, ::step, ::step],
|
| 643 |
+
mask[:, ::step, ::step], reduction=self.__reduction)
|
| 644 |
+
|
| 645 |
+
return total
|
| 646 |
+
|
| 647 |
+
Grad_fn = GradientLoss()
|
| 648 |
+
|
| 649 |
+
def gradient_loss(prediction, target, mask, reduction=reduction_batch_based):
|
| 650 |
+
|
| 651 |
+
M = torch.sum(mask, (1, 2))
|
| 652 |
+
|
| 653 |
+
diff = prediction - target
|
| 654 |
+
diff = torch.mul(mask, diff)
|
| 655 |
+
grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
|
| 656 |
+
mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
|
| 657 |
+
grad_x = torch.mul(mask_x, grad_x)
|
| 658 |
+
|
| 659 |
+
grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
|
| 660 |
+
mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
|
| 661 |
+
grad_y = torch.mul(mask_y, grad_y)
|
| 662 |
+
|
| 663 |
+
image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))
|
| 664 |
+
|
| 665 |
+
return reduction(image_loss, M)
|
| 666 |
+
|
| 667 |
+
def loss_fn(
|
| 668 |
+
poses_preds: List[torch.Tensor],
|
| 669 |
+
poses_pred_all: List[torch.Tensor],
|
| 670 |
+
poses_gt: torch.Tensor,
|
| 671 |
+
inv_depth_preds: List[torch.Tensor],
|
| 672 |
+
inv_depth_raw: List[torch.Tensor],
|
| 673 |
+
depths_gt: torch.Tensor,
|
| 674 |
+
S: int = 16,
|
| 675 |
+
gamma: float = 0.8,
|
| 676 |
+
logger=None,
|
| 677 |
+
logger_tf=None,
|
| 678 |
+
global_step=0,
|
| 679 |
+
):
|
| 680 |
+
"""
|
| 681 |
+
Args:
|
| 682 |
+
poses_preds: list of predicted poses
|
| 683 |
+
poses_gt: ground truth poses
|
| 684 |
+
inv_depth_preds: list of predicted inverse depth maps
|
| 685 |
+
depths_gt: ground truth depth maps
|
| 686 |
+
S: length of sliding window
|
| 687 |
+
"""
|
| 688 |
+
B, T, _, H, W = depths_gt.shape
|
| 689 |
+
|
| 690 |
+
loss_total = 0
|
| 691 |
+
for i in range(len(poses_preds)):
|
| 692 |
+
poses_preds_i = poses_preds[i][0]
|
| 693 |
+
poses_unc_i = poses_preds[i][1]
|
| 694 |
+
poses_gt_i = poses_gt[:, i*S//2:i*S//2+S,:]
|
| 695 |
+
poses_gt_i_norm = first_pose_enc_norm(poses_gt_i,
|
| 696 |
+
pose_encoding_type="absT_quaR_OneFL")
|
| 697 |
+
pose_loss = 0.0
|
| 698 |
+
for idx, poses_preds_ij in enumerate(poses_preds_i):
|
| 699 |
+
i_weight = gamma ** (len(poses_preds_i) - idx - 1)
|
| 700 |
+
if logger is not None:
|
| 701 |
+
if poses_preds_ij.max()>5e1:
|
| 702 |
+
logger.info(f"pose_pred_max_and_mean: {poses_preds_ij.max(), poses_preds_ij.mean()}")
|
| 703 |
+
|
| 704 |
+
trans_loss = (poses_preds_ij[...,:3] - poses_gt_i_norm[...,:3]).abs().sum(dim=-1).mean()
|
| 705 |
+
rot_loss = (poses_preds_ij[...,3:7] - poses_gt_i_norm[...,3:7]).abs().sum(dim=-1).mean()
|
| 706 |
+
focal_loss = (poses_preds_ij[...,7:] - poses_gt_i_norm[...,7:]).abs().sum(dim=-1).mean()
|
| 707 |
+
if torch.isnan((trans_loss + rot_loss + focal_loss)).any():
|
| 708 |
+
pose_loss += 0
|
| 709 |
+
else:
|
| 710 |
+
pose_loss += i_weight*(trans_loss + rot_loss + focal_loss)
|
| 711 |
+
if (logger_tf is not None)&(i==len(poses_preds)-1):
|
| 712 |
+
logger_tf.add_scalar(f"loss@pose/trans_iter{idx}",
|
| 713 |
+
trans_loss, global_step=global_step)
|
| 714 |
+
logger_tf.add_scalar(f"loss@pose/rot_iter{idx}",
|
| 715 |
+
rot_loss, global_step=global_step)
|
| 716 |
+
logger_tf.add_scalar(f"loss@pose/focal_iter{idx}",
|
| 717 |
+
focal_loss, global_step=global_step)
|
| 718 |
+
# compute the uncertainty loss
|
| 719 |
+
with torch.no_grad():
|
| 720 |
+
pose_loss_dist = (poses_preds_ij-poses_gt_i_norm).detach().abs()
|
| 721 |
+
pose_loss_std = 3*pose_loss_dist.view(-1,8).std(dim=0)[None,None,:]
|
| 722 |
+
gt_dist = F.relu(pose_loss_std - pose_loss_dist) / (pose_loss_std + 1e-3)
|
| 723 |
+
unc_loss = (gt_dist - poses_unc_i).abs().mean()
|
| 724 |
+
if (logger_tf is not None)&(i==len(poses_preds)-1):
|
| 725 |
+
logger_tf.add_scalar(f"loss@uncertainty/unc",
|
| 726 |
+
unc_loss,
|
| 727 |
+
global_step=global_step)
|
| 728 |
+
# if logger is not None:
|
| 729 |
+
# logger.info(f"pose_loss: {pose_loss}, unc_loss: {unc_loss}")
|
| 730 |
+
# total loss
|
| 731 |
+
loss_total += 0.1*unc_loss + 2*pose_loss
|
| 732 |
+
|
| 733 |
+
poses_gt_norm = poses_gt
|
| 734 |
+
pose_all_loss = 0.0
|
| 735 |
+
prev_loss = None
|
| 736 |
+
for idx, poses_preds_all_j in enumerate(poses_pred_all):
|
| 737 |
+
i_weight = gamma ** (len(poses_pred_all) - idx - 1)
|
| 738 |
+
trans_loss = (poses_preds_all_j[...,:3] - poses_gt_norm[...,:3]).abs().sum(dim=-1).mean()
|
| 739 |
+
rot_loss = (poses_preds_all_j[...,3:7] - poses_gt_norm[...,3:7]).abs().sum(dim=-1).mean()
|
| 740 |
+
focal_loss = (poses_preds_all_j[...,7:] - poses_gt_norm[...,7:]).abs().sum(dim=-1).mean()
|
| 741 |
+
if (logger_tf is not None):
|
| 742 |
+
if prev_loss is None:
|
| 743 |
+
prev_loss = (trans_loss + rot_loss + focal_loss)
|
| 744 |
+
else:
|
| 745 |
+
des_loss = (trans_loss + rot_loss + focal_loss) - prev_loss
|
| 746 |
+
prev_loss = trans_loss + rot_loss + focal_loss
|
| 747 |
+
logger_tf.add_scalar(f"loss@global_pose/des_iter{idx}",
|
| 748 |
+
des_loss, global_step=global_step)
|
| 749 |
+
logger_tf.add_scalar(f"loss@global_pose/trans_iter{idx}",
|
| 750 |
+
trans_loss, global_step=global_step)
|
| 751 |
+
logger_tf.add_scalar(f"loss@global_pose/rot_iter{idx}",
|
| 752 |
+
rot_loss, global_step=global_step)
|
| 753 |
+
logger_tf.add_scalar(f"loss@global_pose/focal_iter{idx}",
|
| 754 |
+
focal_loss, global_step=global_step)
|
| 755 |
+
if torch.isnan((trans_loss + rot_loss + focal_loss)).any():
|
| 756 |
+
pose_all_loss += 0
|
| 757 |
+
else:
|
| 758 |
+
pose_all_loss += i_weight*(trans_loss + rot_loss + focal_loss)
|
| 759 |
+
|
| 760 |
+
# if logger is not None:
|
| 761 |
+
# logger.info(f"global_pose_loss: {pose_all_loss}")
|
| 762 |
+
|
| 763 |
+
# compute the depth loss
|
| 764 |
+
if inv_depth_preds[0] is not None:
|
| 765 |
+
depths_gt = depths_gt[:,:,0]
|
| 766 |
+
msk = depths_gt > 5e-2
|
| 767 |
+
inv_gt = 1.0 / (depths_gt.clamp(1e-3, 1e16))
|
| 768 |
+
inv_gt_reshp = rearrange(inv_gt, 'b t h w -> (b t) h w')
|
| 769 |
+
inv_depth_preds_reshp = rearrange(inv_depth_preds[0], 'b t h w -> (b t) h w')
|
| 770 |
+
inv_raw_reshp = rearrange(inv_depth_raw[0], 'b t h w -> (b t) h w')
|
| 771 |
+
msk_reshp = rearrange(msk, 'b t h w -> (b t) h w')
|
| 772 |
+
huber_loss = ScaleAndShiftInvariantLoss_fn(inv_depth_preds_reshp, inv_gt_reshp, msk_reshp)
|
| 773 |
+
huber_loss_raw = ScaleAndShiftInvariantLoss_fn(inv_raw_reshp, inv_gt_reshp, msk_reshp)
|
| 774 |
+
# huber_loss = (inv_depth_preds[0][msk]-inv_gt[msk]).abs().mean()
|
| 775 |
+
# cal perason loss
|
| 776 |
+
perason_loss = 0
|
| 777 |
+
# for i in range(B):
|
| 778 |
+
# perason_loss += (1 - pearson_corrcoef(inv_depth_preds[0].view(B*T,-1), inv_gt.view(B*T,-1))).mean()
|
| 779 |
+
# perason_loss = perason_loss/B
|
| 780 |
+
if torch.isnan(huber_loss).any():
|
| 781 |
+
huber_loss = 0
|
| 782 |
+
depth_loss = huber_loss + perason_loss
|
| 783 |
+
if (logger_tf is not None)&(i==len(poses_preds)-1):
|
| 784 |
+
logger_tf.add_scalar(f"loss@depth/huber_iter{idx}",
|
| 785 |
+
depth_loss,
|
| 786 |
+
global_step=global_step)
|
| 787 |
+
# if logger is not None:
|
| 788 |
+
# logger.info(f"opt_depth: {huber_loss_raw - huber_loss}")
|
| 789 |
+
else:
|
| 790 |
+
depth_loss = 0.0
|
| 791 |
+
|
| 792 |
+
|
| 793 |
+
loss_total = loss_total/(len(poses_preds)) + 20*depth_loss + pose_all_loss
|
| 794 |
+
|
| 795 |
+
return loss_total, (huber_loss_raw - huber_loss)
|
| 796 |
+
|
| 797 |
+
|
| 798 |
+
def vis_depth(x: torch.tensor,
|
| 799 |
+
logger_tf = None, title: str = "depth", step: int = 0):
|
| 800 |
+
"""
|
| 801 |
+
args:
|
| 802 |
+
x: H W
|
| 803 |
+
"""
|
| 804 |
+
assert len(x.shape) == 2
|
| 805 |
+
|
| 806 |
+
depth_map_normalized = cv2.normalize(x.cpu().numpy(),
|
| 807 |
+
None, 0, 255, cv2.NORM_MINMAX)
|
| 808 |
+
depth_map_colored = cv2.applyColorMap(depth_map_normalized.astype(np.uint8),
|
| 809 |
+
cv2.COLORMAP_JET)
|
| 810 |
+
depth_map_tensor = torch.from_numpy(depth_map_colored).permute(2, 0, 1).unsqueeze(0)
|
| 811 |
+
if logger_tf is not None:
|
| 812 |
+
logger_tf.add_image(title, depth_map_tensor[0], step)
|
| 813 |
+
else:
|
| 814 |
+
return depth_map_tensor
|
| 815 |
+
|
| 816 |
+
def vis_pcd(
|
| 817 |
+
rgbs: torch.Tensor,
|
| 818 |
+
R: torch.Tensor,
|
| 819 |
+
T: torch.Tensor,
|
| 820 |
+
xy_depth: torch.Tensor,
|
| 821 |
+
focal_length: torch.Tensor,
|
| 822 |
+
pick_idx: List = [0]
|
| 823 |
+
):
|
| 824 |
+
"""
|
| 825 |
+
args:
|
| 826 |
+
rgbs: [S C H W]
|
| 827 |
+
R: [S 3 3]
|
| 828 |
+
T: [S 3]
|
| 829 |
+
xy_depth: [S H W 3]
|
| 830 |
+
focal_length: [S]
|
| 831 |
+
pick_idx: list of the index to pick
|
| 832 |
+
"""
|
| 833 |
+
S, C, H, W = rgbs.shape
|
| 834 |
+
|
| 835 |
+
rgbs_pick = rgbs[pick_idx]
|
| 836 |
+
R_pick = R[pick_idx]
|
| 837 |
+
T_pick = T[pick_idx]
|
| 838 |
+
xy_depth_pick = xy_depth[pick_idx]
|
| 839 |
+
focal_length_pick = focal_length[pick_idx]
|
| 840 |
+
pcd_world = depth2pcd(xy_depth_pick.clone(),
|
| 841 |
+
focal_length_pick, R_pick.clone(), T_pick.clone(),
|
| 842 |
+
device=xy_depth.device, H=H, W=W)
|
| 843 |
+
pcd_world = pcd_world.permute(0, 2, 1) #[...,[1,0,2]]
|
| 844 |
+
mask = pcd_world.reshape(-1,3)[:,2] < 20
|
| 845 |
+
rgb_world = rgbs_pick.view(len(pick_idx), 3, -1).permute(0, 2, 1)
|
| 846 |
+
pcl = Pointclouds(points=[pcd_world.reshape(-1,3)[mask]],
|
| 847 |
+
features=[rgb_world.reshape(-1,3)[mask]/255])
|
| 848 |
+
return pcl
|
| 849 |
+
|
| 850 |
+
def vis_result(rgbs, poses_pred, poses_gt,
|
| 851 |
+
depth_gt, depth_pred, iter_num=0,
|
| 852 |
+
vis=None, logger_tf=None, cfg=None):
|
| 853 |
+
"""
|
| 854 |
+
Args:
|
| 855 |
+
rgbs: [S C H W]
|
| 856 |
+
depths_gt: [S C H W]
|
| 857 |
+
poses_gt: [S C]
|
| 858 |
+
poses_pred: [S C]
|
| 859 |
+
depth_pred: [S H W]
|
| 860 |
+
"""
|
| 861 |
+
assert len(rgbs.shape) == 4, "only support one sequence, T 3 H W of rbg"
|
| 862 |
+
|
| 863 |
+
if vis is None:
|
| 864 |
+
return
|
| 865 |
+
S, _, H, W = depth_gt.shape
|
| 866 |
+
# get the xy
|
| 867 |
+
yx = torch.meshgrid(torch.arange(H).to(depth_pred.device),
|
| 868 |
+
torch.arange(W).to(depth_pred.device),indexing='ij')
|
| 869 |
+
xy = torch.stack(yx[::-1], dim=0).float().to(depth_pred.device)
|
| 870 |
+
xy_norm = (xy / torch.tensor([W, H],
|
| 871 |
+
device=depth_pred.device).view(2, 1, 1) - 0.5)*2
|
| 872 |
+
xy = xy[None].repeat(S, 1, 1, 1)
|
| 873 |
+
xy_depth = torch.cat([xy, depth_pred[:,None]], dim=1).permute(0, 2, 3, 1)
|
| 874 |
+
xy_depth_gt = torch.cat([xy, depth_gt], dim=1).permute(0, 2, 3, 1)
|
| 875 |
+
# get the focal length
|
| 876 |
+
focal_length = poses_gt[:,-1]*max(H, W)
|
| 877 |
+
|
| 878 |
+
# vis the camera poses
|
| 879 |
+
poses_gt_vis = pose_encoding_to_camera(poses_gt,
|
| 880 |
+
pose_encoding_type="absT_quaR_OneFL",to_OpenCV=False)
|
| 881 |
+
poses_pred_vis = pose_encoding_to_camera(poses_pred,
|
| 882 |
+
pose_encoding_type="absT_quaR_OneFL",to_OpenCV=False)
|
| 883 |
+
|
| 884 |
+
R_gt = poses_gt_vis.R.float()
|
| 885 |
+
R_pred = poses_pred_vis.R.float()
|
| 886 |
+
T_gt = poses_gt_vis.T.float()
|
| 887 |
+
T_pred = poses_pred_vis.T.float()
|
| 888 |
+
# C2W poses
|
| 889 |
+
R_gt_c2w = R_gt.permute(0,2,1)
|
| 890 |
+
T_gt_c2w = (-R_gt_c2w @ T_gt[:, :, None]).squeeze(-1)
|
| 891 |
+
R_pred_c2w = R_pred.permute(0,2,1)
|
| 892 |
+
T_pred_c2w = (-R_pred_c2w @ T_pred[:, :, None]).squeeze(-1)
|
| 893 |
+
with torch.cuda.amp.autocast(enabled=False):
|
| 894 |
+
pick_idx = torch.randperm(S)[:min(24, S)]
|
| 895 |
+
# pick_idx = [1]
|
| 896 |
+
#NOTE: very strange that the camera need C2W Rotation and W2C translation as input
|
| 897 |
+
poses_gt_vis = PerspectiveCamerasVisual(
|
| 898 |
+
R=R_gt_c2w[pick_idx], T=T_gt[pick_idx],
|
| 899 |
+
device=poses_gt_vis.device, image_size=((H, W),)
|
| 900 |
+
)
|
| 901 |
+
poses_pred_vis = PerspectiveCamerasVisual(
|
| 902 |
+
R=R_pred_c2w[pick_idx], T=T_pred[pick_idx],
|
| 903 |
+
device=poses_pred_vis.device
|
| 904 |
+
)
|
| 905 |
+
visual_dict = {"scenes": {"cameras": poses_pred_vis, "cameras_gt": poses_gt_vis}}
|
| 906 |
+
env_name = f"train_visualize_iter_{iter_num:05d}"
|
| 907 |
+
print(f"Visualizing the scene by visdom at env: {env_name}")
|
| 908 |
+
# visualize the depth map
|
| 909 |
+
vis_depth(depth_pred[0].detach(), logger_tf, title="vis/depth_pred",step=iter_num)
|
| 910 |
+
msk = depth_pred[0] > 1e-3
|
| 911 |
+
vis_depth(depth_gt[0,0].detach(), logger_tf, title="vis/depth_gt",step=iter_num)
|
| 912 |
+
depth_res = (depth_gt[0,0] - depth_pred[0]).abs()
|
| 913 |
+
vis_depth(depth_res.detach(), logger_tf, title="vis/depth_res",step=iter_num)
|
| 914 |
+
# visualize the point cloud
|
| 915 |
+
if cfg.debug.vis_pcd:
|
| 916 |
+
visual_dict["scenes"]["points_gt"] = vis_pcd(rgbs, R_gt, T_gt,
|
| 917 |
+
xy_depth_gt, focal_length, pick_idx)
|
| 918 |
+
else:
|
| 919 |
+
visual_dict["scenes"]["points_pred"] = vis_pcd(rgbs, R_pred, T_pred,
|
| 920 |
+
xy_depth, focal_length, pick_idx)
|
| 921 |
+
# visualize in visdom
|
| 922 |
+
fig = plot_scene(visual_dict, camera_scale=0.05)
|
| 923 |
+
vis.plotlyplot(fig, env=env_name, win="3D")
|
| 924 |
+
vis.save([env_name])
|
| 925 |
+
|
| 926 |
+
return
|
| 927 |
+
|
| 928 |
+
def depth2pcd(
|
| 929 |
+
xy_depth: torch.Tensor,
|
| 930 |
+
focal_length: torch.Tensor,
|
| 931 |
+
R: torch.Tensor,
|
| 932 |
+
T: torch.Tensor,
|
| 933 |
+
device: torch.device = None,
|
| 934 |
+
H: int = 518,
|
| 935 |
+
W: int = 518
|
| 936 |
+
):
|
| 937 |
+
"""
|
| 938 |
+
args:
|
| 939 |
+
xy_depth: [S H W 3]
|
| 940 |
+
focal_length: [S]
|
| 941 |
+
R: [S 3 3] W2C
|
| 942 |
+
T: [S 3] W2C
|
| 943 |
+
return:
|
| 944 |
+
xyz: [S 3 (H W)]
|
| 945 |
+
"""
|
| 946 |
+
S, H, W, _ = xy_depth.shape
|
| 947 |
+
# get the intrinsic
|
| 948 |
+
K = torch.eye(3, device=device)[None].repeat(len(focal_length), 1, 1).to(device)
|
| 949 |
+
K[:, 0, 0] = focal_length
|
| 950 |
+
K[:, 1, 1] = focal_length
|
| 951 |
+
K[:, 0, 2] = 0.5 * W
|
| 952 |
+
K[:, 1, 2] = 0.5 * H
|
| 953 |
+
K_inv = K.inverse()
|
| 954 |
+
# xyz
|
| 955 |
+
xyz = xy_depth.view(S, -1, 3).permute(0, 2, 1) # S 3 (H W)
|
| 956 |
+
depth = xyz[:, 2:].clone() # S (H W) 1
|
| 957 |
+
xyz[:, 2] = 1
|
| 958 |
+
xyz = K_inv @ xyz # S 3 (H W)
|
| 959 |
+
xyz = xyz * depth
|
| 960 |
+
# to world coordinate
|
| 961 |
+
xyz = R.permute(0,2,1) @ (xyz - T[:, :, None])
|
| 962 |
+
|
| 963 |
+
return xyz
|
| 964 |
+
|
| 965 |
+
|
| 966 |
+
def pose_enc2mat(poses_pred,
|
| 967 |
+
H_resize, W_resize, resolution=336):
|
| 968 |
+
"""
|
| 969 |
+
This function convert the pose encoding into `intrinsic` and `extrinsic`
|
| 970 |
+
|
| 971 |
+
Args:
|
| 972 |
+
poses_pred: B T 8
|
| 973 |
+
Return:
|
| 974 |
+
Intrinsic B T 3 3
|
| 975 |
+
Extrinsic B T 4 4
|
| 976 |
+
"""
|
| 977 |
+
B, T, _ = poses_pred.shape
|
| 978 |
+
focal_pred = poses_pred[:, :, -1].clone()
|
| 979 |
+
pos_quat_preds = poses_pred[:, :, :7].clone()
|
| 980 |
+
pos_quat_preds = pos_quat_preds.view(B*T, -1)
|
| 981 |
+
# get extrinsic
|
| 982 |
+
c2w_rot = quaternion_to_matrix(pos_quat_preds[:, 3:])
|
| 983 |
+
c2w_tran = pos_quat_preds[:, :3]
|
| 984 |
+
c2w_traj = torch.eye(4)[None].repeat(B*T, 1, 1).to(poses_pred.device)
|
| 985 |
+
c2w_traj[:, :3, :3], c2w_traj[:, :3, 3] = c2w_rot, c2w_tran
|
| 986 |
+
c2w_traj = c2w_traj.view(B, T, 4, 4)
|
| 987 |
+
# get intrinsic
|
| 988 |
+
fxs, fys = focal_pred*resolution, focal_pred*resolution
|
| 989 |
+
intrs = torch.eye(3).to(c2w_traj.device).to(c2w_traj.dtype)[None, None].repeat(B, T, 1, 1)
|
| 990 |
+
intrs[:,:,0,0], intrs[:,:,1,1] = fxs, fys
|
| 991 |
+
intrs[:,:,0,2], intrs[:,:,1,2] = W_resize/2, H_resize/2
|
| 992 |
+
|
| 993 |
+
return intrs, c2w_traj
|
| 994 |
+
|
| 995 |
+
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
| 996 |
+
"""
|
| 997 |
+
Returns torch.sqrt(torch.max(0, x))
|
| 998 |
+
but with a zero subgradient where x is 0.
|
| 999 |
+
"""
|
| 1000 |
+
ret = torch.zeros_like(x)
|
| 1001 |
+
positive_mask = x > 0
|
| 1002 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
| 1003 |
+
return ret
|
| 1004 |
+
|
| 1005 |
+
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
|
| 1006 |
+
"""
|
| 1007 |
+
Convert a unit quaternion to a standard form: one in which the real
|
| 1008 |
+
part is non negative.
|
| 1009 |
+
|
| 1010 |
+
Args:
|
| 1011 |
+
quaternions: Quaternions with real part first,
|
| 1012 |
+
as tensor of shape (..., 4).
|
| 1013 |
+
|
| 1014 |
+
Returns:
|
| 1015 |
+
Standardized quaternions as tensor of shape (..., 4).
|
| 1016 |
+
"""
|
| 1017 |
+
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
|
| 1018 |
+
|
| 1019 |
+
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
| 1020 |
+
"""
|
| 1021 |
+
Convert rotations given as rotation matrices to quaternions.
|
| 1022 |
+
|
| 1023 |
+
Args:
|
| 1024 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
| 1025 |
+
|
| 1026 |
+
Returns:
|
| 1027 |
+
quaternions with real part first, as tensor of shape (..., 4).
|
| 1028 |
+
"""
|
| 1029 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
| 1030 |
+
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
| 1031 |
+
|
| 1032 |
+
batch_dim = matrix.shape[:-2]
|
| 1033 |
+
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
|
| 1034 |
+
|
| 1035 |
+
q_abs = _sqrt_positive_part(
|
| 1036 |
+
torch.stack(
|
| 1037 |
+
[1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1
|
| 1038 |
+
)
|
| 1039 |
+
)
|
| 1040 |
+
|
| 1041 |
+
# we produce the desired quaternion multiplied by each of r, i, j, k
|
| 1042 |
+
quat_by_rijk = torch.stack(
|
| 1043 |
+
[
|
| 1044 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
| 1045 |
+
# `int`.
|
| 1046 |
+
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
|
| 1047 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
| 1048 |
+
# `int`.
|
| 1049 |
+
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
|
| 1050 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
| 1051 |
+
# `int`.
|
| 1052 |
+
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
|
| 1053 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
| 1054 |
+
# `int`.
|
| 1055 |
+
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
|
| 1056 |
+
],
|
| 1057 |
+
dim=-2,
|
| 1058 |
+
)
|
| 1059 |
+
|
| 1060 |
+
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
|
| 1061 |
+
# the candidate won't be picked.
|
| 1062 |
+
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
|
| 1063 |
+
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
|
| 1064 |
+
|
| 1065 |
+
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
|
| 1066 |
+
# forall i; we pick the best-conditioned one (with the largest denominator)
|
| 1067 |
+
out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
|
| 1068 |
+
return standardize_quaternion(out)
|
| 1069 |
+
|
| 1070 |
+
|
| 1071 |
+
def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"):
|
| 1072 |
+
# returns a meshgrid sized B x Y x X
|
| 1073 |
+
|
| 1074 |
+
grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device))
|
| 1075 |
+
grid_y = torch.reshape(grid_y, [1, Y, 1])
|
| 1076 |
+
grid_y = grid_y.repeat(B, 1, X)
|
| 1077 |
+
|
| 1078 |
+
grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device))
|
| 1079 |
+
grid_x = torch.reshape(grid_x, [1, 1, X])
|
| 1080 |
+
grid_x = grid_x.repeat(B, Y, 1)
|
| 1081 |
+
|
| 1082 |
+
if stack:
|
| 1083 |
+
# note we stack in xy order
|
| 1084 |
+
# (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
|
| 1085 |
+
grid = torch.stack([grid_x, grid_y], dim=-1)
|
| 1086 |
+
return grid
|
| 1087 |
+
else:
|
| 1088 |
+
return grid_y, grid_x
|
| 1089 |
+
|
| 1090 |
+
def get_points_on_a_grid(grid_size, interp_shape,
|
| 1091 |
+
grid_center=(0, 0), device="cuda"):
|
| 1092 |
+
if grid_size == 1:
|
| 1093 |
+
return torch.tensor([interp_shape[1] / 2,
|
| 1094 |
+
interp_shape[0] / 2], device=device)[
|
| 1095 |
+
None, None
|
| 1096 |
+
]
|
| 1097 |
+
|
| 1098 |
+
grid_y, grid_x = meshgrid2d(
|
| 1099 |
+
1, grid_size, grid_size, stack=False, norm=False, device=device
|
| 1100 |
+
)
|
| 1101 |
+
step = interp_shape[1] // 64
|
| 1102 |
+
if grid_center[0] != 0 or grid_center[1] != 0:
|
| 1103 |
+
grid_y = grid_y - grid_size / 2.0
|
| 1104 |
+
grid_x = grid_x - grid_size / 2.0
|
| 1105 |
+
grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * (
|
| 1106 |
+
interp_shape[0] - step * 2
|
| 1107 |
+
)
|
| 1108 |
+
grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * (
|
| 1109 |
+
interp_shape[1] - step * 2
|
| 1110 |
+
)
|
| 1111 |
+
|
| 1112 |
+
grid_y = grid_y + grid_center[0]
|
| 1113 |
+
grid_x = grid_x + grid_center[1]
|
| 1114 |
+
xy = torch.stack([grid_x, grid_y], dim=-1).to(device)
|
| 1115 |
+
return xy
|
| 1116 |
+
|
| 1117 |
+
def normalize_rgb(x,input_size=224,
|
| 1118 |
+
resize_mode: Literal['resize', 'padding'] = 'resize',
|
| 1119 |
+
if_da=False):
|
| 1120 |
+
"""
|
| 1121 |
+
normalize the image for depth anything input
|
| 1122 |
+
|
| 1123 |
+
args:
|
| 1124 |
+
x: the input images [B T C H W]
|
| 1125 |
+
"""
|
| 1126 |
+
if isinstance(x, np.ndarray):
|
| 1127 |
+
x = torch.from_numpy(x) / 255.0
|
| 1128 |
+
elif isinstance(x, torch.Tensor):
|
| 1129 |
+
x = x / 255.0
|
| 1130 |
+
B, T, C, H, W = x.shape
|
| 1131 |
+
x = x.view(B * T, C, H, W)
|
| 1132 |
+
Resizer = Resize(
|
| 1133 |
+
width=input_size,
|
| 1134 |
+
height=input_size,
|
| 1135 |
+
resize_target=False,
|
| 1136 |
+
keep_aspect_ratio=True,
|
| 1137 |
+
ensure_multiple_of=14,
|
| 1138 |
+
resize_method='lower_bound',
|
| 1139 |
+
)
|
| 1140 |
+
if resize_mode == 'padding':
|
| 1141 |
+
# zero padding to make the input size to be multiple of 14
|
| 1142 |
+
if H > W:
|
| 1143 |
+
H_scale = input_size
|
| 1144 |
+
W_scale = W * input_size // H
|
| 1145 |
+
else:
|
| 1146 |
+
W_scale = input_size
|
| 1147 |
+
H_scale = H * input_size // W
|
| 1148 |
+
# resize the image
|
| 1149 |
+
x = F.interpolate(x, size=(H_scale, W_scale),
|
| 1150 |
+
mode='bilinear', align_corners=False)
|
| 1151 |
+
# central padding the image
|
| 1152 |
+
padding_x = (input_size - W_scale) // 2
|
| 1153 |
+
padding_y = (input_size - H_scale) // 2
|
| 1154 |
+
extra_x = (input_size - W_scale) % 2
|
| 1155 |
+
extra_y = (input_size - H_scale) % 2
|
| 1156 |
+
x = F.pad(x, (padding_x, padding_x+extra_x,
|
| 1157 |
+
padding_y, padding_y+extra_y), value=0.)
|
| 1158 |
+
elif resize_mode == 'resize':
|
| 1159 |
+
H_scale, W_scale = Resizer.get_size(H, W)
|
| 1160 |
+
x = F.interpolate(x, size=(int(H_scale), int(W_scale)),
|
| 1161 |
+
mode='bicubic', align_corners=True)
|
| 1162 |
+
# get the mean and std
|
| 1163 |
+
__mean__ = torch.tensor([0.485,
|
| 1164 |
+
0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
|
| 1165 |
+
__std__ = torch.tensor([0.229,
|
| 1166 |
+
0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
|
| 1167 |
+
# normalize the image
|
| 1168 |
+
if if_da:
|
| 1169 |
+
x = (x - __mean__) / __std__
|
| 1170 |
+
else:
|
| 1171 |
+
x = x
|
| 1172 |
+
return x.view(B, T, C, x.shape[-2], x.shape[-1])
|
| 1173 |
+
|
| 1174 |
+
def get_track_points(H, W, T, device, size=100, support_frame=0,
|
| 1175 |
+
query_size=768, unc_metric=None, mode="mixed"):
|
| 1176 |
+
"""
|
| 1177 |
+
This function is used to get the points on the grid
|
| 1178 |
+
args:
|
| 1179 |
+
H: the height of the grid.
|
| 1180 |
+
W: the width of the grid.
|
| 1181 |
+
T: the number of frames.
|
| 1182 |
+
device: the device of the points.
|
| 1183 |
+
size: the size of the grid.
|
| 1184 |
+
"""
|
| 1185 |
+
grid_pts = get_points_on_a_grid(size, (H, W), device=device)
|
| 1186 |
+
grid_pts = grid_pts.round()
|
| 1187 |
+
if mode == "incremental":
|
| 1188 |
+
queries = torch.cat(
|
| 1189 |
+
[torch.randint_like(grid_pts[:, :, :1], T), grid_pts],
|
| 1190 |
+
dim=2,
|
| 1191 |
+
)
|
| 1192 |
+
elif mode == "first":
|
| 1193 |
+
queries_first = torch.cat(
|
| 1194 |
+
[torch.zeros_like(grid_pts[:, :, :1]), grid_pts],
|
| 1195 |
+
dim=2,
|
| 1196 |
+
)
|
| 1197 |
+
queries_support = torch.cat(
|
| 1198 |
+
[torch.randint_like(grid_pts[:, :, :1], T), grid_pts],
|
| 1199 |
+
dim=2,
|
| 1200 |
+
)
|
| 1201 |
+
queries = torch.cat([queries_first, queries_support, queries_support], dim=1)
|
| 1202 |
+
elif mode == "mixed":
|
| 1203 |
+
queries = torch.cat(
|
| 1204 |
+
[torch.randint_like(grid_pts[:, :, :1], T), grid_pts],
|
| 1205 |
+
dim=2,
|
| 1206 |
+
)
|
| 1207 |
+
queries_first = torch.cat(
|
| 1208 |
+
[torch.ones_like(grid_pts[:, :, :1]) * support_frame, grid_pts],
|
| 1209 |
+
dim=2,
|
| 1210 |
+
)
|
| 1211 |
+
queries = torch.cat([queries_first, queries, queries], dim=1)
|
| 1212 |
+
if unc_metric is not None:
|
| 1213 |
+
# filter the points with high uncertainty
|
| 1214 |
+
sample_unc = sample_features5d(unc_metric[None], queries[:,None]).squeeze()
|
| 1215 |
+
if ((sample_unc>0.5).sum() < 20):
|
| 1216 |
+
queries = queries
|
| 1217 |
+
else:
|
| 1218 |
+
queries = queries[:,sample_unc>0.5,:]
|
| 1219 |
+
idx_ = torch.randperm(queries.shape[1], device=device)[:query_size]
|
| 1220 |
+
queries = queries[:, idx_]
|
| 1221 |
+
return queries
|
models/SpaTrackV2/utils/embeddings.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
| 11 |
+
"""
|
| 12 |
+
grid_size: int of the grid height and width
|
| 13 |
+
return:
|
| 14 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
| 15 |
+
"""
|
| 16 |
+
if isinstance(grid_size, tuple):
|
| 17 |
+
grid_size_h, grid_size_w = grid_size
|
| 18 |
+
else:
|
| 19 |
+
grid_size_h = grid_size_w = grid_size
|
| 20 |
+
grid_h = np.arange(grid_size_h, dtype=np.float32)
|
| 21 |
+
grid_w = np.arange(grid_size_w, dtype=np.float32)
|
| 22 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
| 23 |
+
grid = np.stack(grid, axis=0)
|
| 24 |
+
|
| 25 |
+
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
|
| 26 |
+
pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 27 |
+
if cls_token and extra_tokens > 0:
|
| 28 |
+
pos_embed = np.concatenate(
|
| 29 |
+
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
|
| 30 |
+
)
|
| 31 |
+
return pos_embed
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 35 |
+
assert embed_dim % 3 == 0
|
| 36 |
+
|
| 37 |
+
# use half of dimensions to encode grid_h
|
| 38 |
+
B, S, N, _ = grid.shape
|
| 39 |
+
gridx = grid[..., 0].view(B*S*N).detach().cpu().numpy()
|
| 40 |
+
gridy = grid[..., 1].view(B*S*N).detach().cpu().numpy()
|
| 41 |
+
gridz = grid[..., 2].view(B*S*N).detach().cpu().numpy()
|
| 42 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridx) # (N, D/3)
|
| 43 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridy) # (N, D/3)
|
| 44 |
+
emb_z = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridz) # (N, D/3)
|
| 45 |
+
|
| 46 |
+
emb = np.concatenate([emb_h, emb_w, emb_z], axis=1) # (N, D)
|
| 47 |
+
emb = torch.from_numpy(emb).to(grid.device)
|
| 48 |
+
return emb.view(B, S, N, embed_dim)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
| 52 |
+
"""
|
| 53 |
+
grid_size: int of the grid height and width
|
| 54 |
+
return:
|
| 55 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
| 56 |
+
"""
|
| 57 |
+
if isinstance(grid_size, tuple):
|
| 58 |
+
grid_size_h, grid_size_w = grid_size
|
| 59 |
+
else:
|
| 60 |
+
grid_size_h = grid_size_w = grid_size
|
| 61 |
+
grid_h = np.arange(grid_size_h, dtype=np.float32)
|
| 62 |
+
grid_w = np.arange(grid_size_w, dtype=np.float32)
|
| 63 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
| 64 |
+
grid = np.stack(grid, axis=0)
|
| 65 |
+
|
| 66 |
+
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
|
| 67 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 68 |
+
if cls_token and extra_tokens > 0:
|
| 69 |
+
pos_embed = np.concatenate(
|
| 70 |
+
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
|
| 71 |
+
)
|
| 72 |
+
return pos_embed
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 76 |
+
assert embed_dim % 2 == 0
|
| 77 |
+
|
| 78 |
+
# use half of dimensions to encode grid_h
|
| 79 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
| 80 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
| 81 |
+
|
| 82 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
| 83 |
+
return emb
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 87 |
+
"""
|
| 88 |
+
embed_dim: output dimension for each position
|
| 89 |
+
pos: a list of positions to be encoded: size (M,)
|
| 90 |
+
out: (M, D)
|
| 91 |
+
"""
|
| 92 |
+
assert embed_dim % 2 == 0
|
| 93 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
| 94 |
+
omega /= embed_dim / 2.0
|
| 95 |
+
omega = 1.0 / 10000 ** omega # (D/2,)
|
| 96 |
+
pos = pos.reshape(-1) # (M,)
|
| 97 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 98 |
+
|
| 99 |
+
emb_sin = np.sin(out) # (M, D/2)
|
| 100 |
+
emb_cos = np.cos(out) # (M, D/2)
|
| 101 |
+
|
| 102 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
| 103 |
+
return emb
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def get_2d_embedding(xy, C, cat_coords=True):
|
| 107 |
+
B, N, D = xy.shape
|
| 108 |
+
assert D == 2
|
| 109 |
+
|
| 110 |
+
x = xy[:, :, 0:1]
|
| 111 |
+
y = xy[:, :, 1:2]
|
| 112 |
+
div_term = (
|
| 113 |
+
torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
|
| 114 |
+
).reshape(1, 1, int(C / 2))
|
| 115 |
+
|
| 116 |
+
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
| 117 |
+
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
| 118 |
+
|
| 119 |
+
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
| 120 |
+
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
| 121 |
+
|
| 122 |
+
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
| 123 |
+
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
| 124 |
+
|
| 125 |
+
pe = torch.cat([pe_x, pe_y], dim=2) # B, N, C*3
|
| 126 |
+
if cat_coords:
|
| 127 |
+
pe = torch.cat([xy, pe], dim=2) # B, N, C*3+3
|
| 128 |
+
return pe
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def get_3d_embedding(xyz, C, cat_coords=True):
|
| 132 |
+
B, N, D = xyz.shape
|
| 133 |
+
assert D == 3
|
| 134 |
+
|
| 135 |
+
x = xyz[:, :, 0:1]
|
| 136 |
+
y = xyz[:, :, 1:2]
|
| 137 |
+
z = xyz[:, :, 2:3]
|
| 138 |
+
div_term = (
|
| 139 |
+
torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C)
|
| 140 |
+
).reshape(1, 1, int(C / 2))
|
| 141 |
+
|
| 142 |
+
pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
|
| 143 |
+
pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
|
| 144 |
+
pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
|
| 145 |
+
|
| 146 |
+
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
| 147 |
+
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
| 148 |
+
|
| 149 |
+
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
| 150 |
+
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
| 151 |
+
|
| 152 |
+
pe_z[:, :, 0::2] = torch.sin(z * div_term)
|
| 153 |
+
pe_z[:, :, 1::2] = torch.cos(z * div_term)
|
| 154 |
+
|
| 155 |
+
pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3
|
| 156 |
+
if cat_coords:
|
| 157 |
+
pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3
|
| 158 |
+
return pe
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def get_4d_embedding(xyzw, C, cat_coords=True):
|
| 162 |
+
B, N, D = xyzw.shape
|
| 163 |
+
assert D == 4
|
| 164 |
+
|
| 165 |
+
x = xyzw[:, :, 0:1]
|
| 166 |
+
y = xyzw[:, :, 1:2]
|
| 167 |
+
z = xyzw[:, :, 2:3]
|
| 168 |
+
w = xyzw[:, :, 3:4]
|
| 169 |
+
div_term = (
|
| 170 |
+
torch.arange(0, C, 2, device=xyzw.device, dtype=torch.float32) * (1000.0 / C)
|
| 171 |
+
).reshape(1, 1, int(C / 2))
|
| 172 |
+
|
| 173 |
+
pe_x = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
|
| 174 |
+
pe_y = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
|
| 175 |
+
pe_z = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
|
| 176 |
+
pe_w = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
|
| 177 |
+
|
| 178 |
+
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
| 179 |
+
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
| 180 |
+
|
| 181 |
+
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
| 182 |
+
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
| 183 |
+
|
| 184 |
+
pe_z[:, :, 0::2] = torch.sin(z * div_term)
|
| 185 |
+
pe_z[:, :, 1::2] = torch.cos(z * div_term)
|
| 186 |
+
|
| 187 |
+
pe_w[:, :, 0::2] = torch.sin(w * div_term)
|
| 188 |
+
pe_w[:, :, 1::2] = torch.cos(w * div_term)
|
| 189 |
+
|
| 190 |
+
pe = torch.cat([pe_x, pe_y, pe_z, pe_w], dim=2) # B, N, C*3
|
| 191 |
+
if cat_coords:
|
| 192 |
+
pe = torch.cat([pe, xyzw], dim=2) # B, N, C*3+3
|
| 193 |
+
return pe
|
| 194 |
+
|
| 195 |
+
import torch.nn as nn
|
| 196 |
+
class Embedder_Fourier(nn.Module):
|
| 197 |
+
def __init__(self, input_dim, max_freq_log2, N_freqs,
|
| 198 |
+
log_sampling=True, include_input=True,
|
| 199 |
+
periodic_fns=(torch.sin, torch.cos)):
|
| 200 |
+
'''
|
| 201 |
+
:param input_dim: dimension of input to be embedded
|
| 202 |
+
:param max_freq_log2: log2 of max freq; min freq is 1 by default
|
| 203 |
+
:param N_freqs: number of frequency bands
|
| 204 |
+
:param log_sampling: if True, frequency bands are linerly sampled in log-space
|
| 205 |
+
:param include_input: if True, raw input is included in the embedding
|
| 206 |
+
:param periodic_fns: periodic functions used to embed input
|
| 207 |
+
'''
|
| 208 |
+
super(Embedder_Fourier, self).__init__()
|
| 209 |
+
|
| 210 |
+
self.input_dim = input_dim
|
| 211 |
+
self.include_input = include_input
|
| 212 |
+
self.periodic_fns = periodic_fns
|
| 213 |
+
|
| 214 |
+
self.out_dim = 0
|
| 215 |
+
if self.include_input:
|
| 216 |
+
self.out_dim += self.input_dim
|
| 217 |
+
|
| 218 |
+
self.out_dim += self.input_dim * N_freqs * len(self.periodic_fns)
|
| 219 |
+
|
| 220 |
+
if log_sampling:
|
| 221 |
+
self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs)
|
| 222 |
+
else:
|
| 223 |
+
self.freq_bands = torch.linspace(
|
| 224 |
+
2. ** 0., 2. ** max_freq_log2, N_freqs)
|
| 225 |
+
|
| 226 |
+
self.freq_bands = self.freq_bands.numpy().tolist()
|
| 227 |
+
|
| 228 |
+
def forward(self,
|
| 229 |
+
input: torch.Tensor,
|
| 230 |
+
rescale: float = 1.0):
|
| 231 |
+
'''
|
| 232 |
+
:param input: tensor of shape [..., self.input_dim]
|
| 233 |
+
:return: tensor of shape [..., self.out_dim]
|
| 234 |
+
'''
|
| 235 |
+
assert (input.shape[-1] == self.input_dim)
|
| 236 |
+
out = []
|
| 237 |
+
if self.include_input:
|
| 238 |
+
out.append(input/rescale)
|
| 239 |
+
|
| 240 |
+
for i in range(len(self.freq_bands)):
|
| 241 |
+
freq = self.freq_bands[i]
|
| 242 |
+
for p_fn in self.periodic_fns:
|
| 243 |
+
out.append(p_fn(input * freq))
|
| 244 |
+
out = torch.cat(out, dim=-1)
|
| 245 |
+
|
| 246 |
+
assert (out.shape[-1] == self.out_dim)
|
| 247 |
+
return out
|
models/SpaTrackV2/utils/model_utils.py
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from easydict import EasyDict as edict
|
| 10 |
+
from sklearn.decomposition import PCA
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
|
| 13 |
+
EPS = 1e-6
|
| 14 |
+
|
| 15 |
+
def nearest_sample2d(im, x, y, return_inbounds=False):
|
| 16 |
+
# x and y are each B, N
|
| 17 |
+
# output is B, C, N
|
| 18 |
+
if len(im.shape) == 5:
|
| 19 |
+
B, N, C, H, W = list(im.shape)
|
| 20 |
+
else:
|
| 21 |
+
B, C, H, W = list(im.shape)
|
| 22 |
+
N = list(x.shape)[1]
|
| 23 |
+
|
| 24 |
+
x = x.float()
|
| 25 |
+
y = y.float()
|
| 26 |
+
H_f = torch.tensor(H, dtype=torch.float32)
|
| 27 |
+
W_f = torch.tensor(W, dtype=torch.float32)
|
| 28 |
+
|
| 29 |
+
# inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()
|
| 30 |
+
|
| 31 |
+
max_y = (H_f - 1).int()
|
| 32 |
+
max_x = (W_f - 1).int()
|
| 33 |
+
|
| 34 |
+
x0 = torch.floor(x).int()
|
| 35 |
+
x1 = x0 + 1
|
| 36 |
+
y0 = torch.floor(y).int()
|
| 37 |
+
y1 = y0 + 1
|
| 38 |
+
|
| 39 |
+
x0_clip = torch.clamp(x0, 0, max_x)
|
| 40 |
+
x1_clip = torch.clamp(x1, 0, max_x)
|
| 41 |
+
y0_clip = torch.clamp(y0, 0, max_y)
|
| 42 |
+
y1_clip = torch.clamp(y1, 0, max_y)
|
| 43 |
+
dim2 = W
|
| 44 |
+
dim1 = W * H
|
| 45 |
+
|
| 46 |
+
base = torch.arange(0, B, dtype=torch.int64, device=x.device) * dim1
|
| 47 |
+
base = torch.reshape(base, [B, 1]).repeat([1, N])
|
| 48 |
+
|
| 49 |
+
base_y0 = base + y0_clip * dim2
|
| 50 |
+
base_y1 = base + y1_clip * dim2
|
| 51 |
+
|
| 52 |
+
idx_y0_x0 = base_y0 + x0_clip
|
| 53 |
+
idx_y0_x1 = base_y0 + x1_clip
|
| 54 |
+
idx_y1_x0 = base_y1 + x0_clip
|
| 55 |
+
idx_y1_x1 = base_y1 + x1_clip
|
| 56 |
+
|
| 57 |
+
# use the indices to lookup pixels in the flat image
|
| 58 |
+
# im is B x C x H x W
|
| 59 |
+
# move C out to last dim
|
| 60 |
+
if len(im.shape) == 5:
|
| 61 |
+
im_flat = (im.permute(0, 3, 4, 1, 2)).reshape(B * H * W, N, C)
|
| 62 |
+
i_y0_x0 = torch.diagonal(im_flat[idx_y0_x0.long()], dim1=1, dim2=2).permute(
|
| 63 |
+
0, 2, 1
|
| 64 |
+
)
|
| 65 |
+
i_y0_x1 = torch.diagonal(im_flat[idx_y0_x1.long()], dim1=1, dim2=2).permute(
|
| 66 |
+
0, 2, 1
|
| 67 |
+
)
|
| 68 |
+
i_y1_x0 = torch.diagonal(im_flat[idx_y1_x0.long()], dim1=1, dim2=2).permute(
|
| 69 |
+
0, 2, 1
|
| 70 |
+
)
|
| 71 |
+
i_y1_x1 = torch.diagonal(im_flat[idx_y1_x1.long()], dim1=1, dim2=2).permute(
|
| 72 |
+
0, 2, 1
|
| 73 |
+
)
|
| 74 |
+
else:
|
| 75 |
+
im_flat = (im.permute(0, 2, 3, 1)).reshape(B * H * W, C)
|
| 76 |
+
i_y0_x0 = im_flat[idx_y0_x0.long()]
|
| 77 |
+
i_y0_x1 = im_flat[idx_y0_x1.long()]
|
| 78 |
+
i_y1_x0 = im_flat[idx_y1_x0.long()]
|
| 79 |
+
i_y1_x1 = im_flat[idx_y1_x1.long()]
|
| 80 |
+
|
| 81 |
+
# Finally calculate interpolated values.
|
| 82 |
+
x0_f = x0.float()
|
| 83 |
+
x1_f = x1.float()
|
| 84 |
+
y0_f = y0.float()
|
| 85 |
+
y1_f = y1.float()
|
| 86 |
+
|
| 87 |
+
w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
|
| 88 |
+
w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
|
| 89 |
+
w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
|
| 90 |
+
w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)
|
| 91 |
+
|
| 92 |
+
# w_yi_xo is B * N * 1
|
| 93 |
+
max_idx = torch.cat([w_y0_x0, w_y0_x1, w_y1_x0, w_y1_x1], dim=-1).max(dim=-1)[1]
|
| 94 |
+
output = torch.stack([i_y0_x0, i_y0_x1, i_y1_x0, i_y1_x1], dim=-1).gather(-1, max_idx[...,None,None].repeat(1,1,C,1)).squeeze(-1)
|
| 95 |
+
|
| 96 |
+
# output is B*N x C
|
| 97 |
+
output = output.view(B, -1, C)
|
| 98 |
+
output = output.permute(0, 2, 1)
|
| 99 |
+
# output is B x C x N
|
| 100 |
+
|
| 101 |
+
if return_inbounds:
|
| 102 |
+
x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
|
| 103 |
+
y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
|
| 104 |
+
inbounds = (x_valid & y_valid).float()
|
| 105 |
+
inbounds = inbounds.reshape(
|
| 106 |
+
B, N
|
| 107 |
+
) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
|
| 108 |
+
return output, inbounds
|
| 109 |
+
|
| 110 |
+
return output # B, C, N
|
| 111 |
+
|
| 112 |
+
def smart_cat(tensor1, tensor2, dim):
|
| 113 |
+
if tensor1 is None:
|
| 114 |
+
return tensor2
|
| 115 |
+
return torch.cat([tensor1, tensor2], dim=dim)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def normalize_single(d):
|
| 119 |
+
# d is a whatever shape torch tensor
|
| 120 |
+
dmin = torch.min(d)
|
| 121 |
+
dmax = torch.max(d)
|
| 122 |
+
d = (d - dmin) / (EPS + (dmax - dmin))
|
| 123 |
+
return d
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def normalize(d):
|
| 127 |
+
# d is B x whatever. normalize within each element of the batch
|
| 128 |
+
out = torch.zeros(d.size())
|
| 129 |
+
if d.is_cuda:
|
| 130 |
+
out = out.cuda()
|
| 131 |
+
B = list(d.size())[0]
|
| 132 |
+
for b in list(range(B)):
|
| 133 |
+
out[b] = normalize_single(d[b])
|
| 134 |
+
return out
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"):
|
| 138 |
+
# returns a meshgrid sized B x Y x X
|
| 139 |
+
|
| 140 |
+
grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device))
|
| 141 |
+
grid_y = torch.reshape(grid_y, [1, Y, 1])
|
| 142 |
+
grid_y = grid_y.repeat(B, 1, X)
|
| 143 |
+
|
| 144 |
+
grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device))
|
| 145 |
+
grid_x = torch.reshape(grid_x, [1, 1, X])
|
| 146 |
+
grid_x = grid_x.repeat(B, Y, 1)
|
| 147 |
+
|
| 148 |
+
if stack:
|
| 149 |
+
# note we stack in xy order
|
| 150 |
+
# (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
|
| 151 |
+
grid = torch.stack([grid_x, grid_y], dim=-1)
|
| 152 |
+
return grid
|
| 153 |
+
else:
|
| 154 |
+
return grid_y, grid_x
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def reduce_masked_mean(x, mask, dim=None, keepdim=False):
|
| 158 |
+
# x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting
|
| 159 |
+
# returns shape-1
|
| 160 |
+
# axis can be a list of axes
|
| 161 |
+
for (a, b) in zip(x.size(), mask.size()):
|
| 162 |
+
assert a == b # some shape mismatch!
|
| 163 |
+
prod = x * mask
|
| 164 |
+
if dim is None:
|
| 165 |
+
numer = torch.sum(prod)
|
| 166 |
+
denom = EPS + torch.sum(mask)
|
| 167 |
+
else:
|
| 168 |
+
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
|
| 169 |
+
denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim)
|
| 170 |
+
|
| 171 |
+
mean = numer / denom
|
| 172 |
+
return mean
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def bilinear_sample2d(im, x, y, return_inbounds=False):
|
| 176 |
+
# x and y are each B, N
|
| 177 |
+
# output is B, C, N
|
| 178 |
+
if len(im.shape) == 5:
|
| 179 |
+
B, N, C, H, W = list(im.shape)
|
| 180 |
+
else:
|
| 181 |
+
B, C, H, W = list(im.shape)
|
| 182 |
+
N = list(x.shape)[1]
|
| 183 |
+
|
| 184 |
+
x = x.float()
|
| 185 |
+
y = y.float()
|
| 186 |
+
H_f = torch.tensor(H, dtype=torch.float32)
|
| 187 |
+
W_f = torch.tensor(W, dtype=torch.float32)
|
| 188 |
+
|
| 189 |
+
# inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()
|
| 190 |
+
|
| 191 |
+
max_y = (H_f - 1).int()
|
| 192 |
+
max_x = (W_f - 1).int()
|
| 193 |
+
|
| 194 |
+
x0 = torch.floor(x).int()
|
| 195 |
+
x1 = x0 + 1
|
| 196 |
+
y0 = torch.floor(y).int()
|
| 197 |
+
y1 = y0 + 1
|
| 198 |
+
|
| 199 |
+
x0_clip = torch.clamp(x0, 0, max_x)
|
| 200 |
+
x1_clip = torch.clamp(x1, 0, max_x)
|
| 201 |
+
y0_clip = torch.clamp(y0, 0, max_y)
|
| 202 |
+
y1_clip = torch.clamp(y1, 0, max_y)
|
| 203 |
+
dim2 = W
|
| 204 |
+
dim1 = W * H
|
| 205 |
+
|
| 206 |
+
base = torch.arange(0, B, dtype=torch.int64, device=x.device) * dim1
|
| 207 |
+
base = torch.reshape(base, [B, 1]).repeat([1, N])
|
| 208 |
+
|
| 209 |
+
base_y0 = base + y0_clip * dim2
|
| 210 |
+
base_y1 = base + y1_clip * dim2
|
| 211 |
+
|
| 212 |
+
idx_y0_x0 = base_y0 + x0_clip
|
| 213 |
+
idx_y0_x1 = base_y0 + x1_clip
|
| 214 |
+
idx_y1_x0 = base_y1 + x0_clip
|
| 215 |
+
idx_y1_x1 = base_y1 + x1_clip
|
| 216 |
+
|
| 217 |
+
# use the indices to lookup pixels in the flat image
|
| 218 |
+
# im is B x C x H x W
|
| 219 |
+
# move C out to last dim
|
| 220 |
+
if len(im.shape) == 5:
|
| 221 |
+
im_flat = (im.permute(0, 3, 4, 1, 2)).reshape(B * H * W, N, C)
|
| 222 |
+
i_y0_x0 = torch.diagonal(im_flat[idx_y0_x0.long()], dim1=1, dim2=2).permute(
|
| 223 |
+
0, 2, 1
|
| 224 |
+
)
|
| 225 |
+
i_y0_x1 = torch.diagonal(im_flat[idx_y0_x1.long()], dim1=1, dim2=2).permute(
|
| 226 |
+
0, 2, 1
|
| 227 |
+
)
|
| 228 |
+
i_y1_x0 = torch.diagonal(im_flat[idx_y1_x0.long()], dim1=1, dim2=2).permute(
|
| 229 |
+
0, 2, 1
|
| 230 |
+
)
|
| 231 |
+
i_y1_x1 = torch.diagonal(im_flat[idx_y1_x1.long()], dim1=1, dim2=2).permute(
|
| 232 |
+
0, 2, 1
|
| 233 |
+
)
|
| 234 |
+
else:
|
| 235 |
+
im_flat = (im.permute(0, 2, 3, 1)).reshape(B * H * W, C)
|
| 236 |
+
i_y0_x0 = im_flat[idx_y0_x0.long()]
|
| 237 |
+
i_y0_x1 = im_flat[idx_y0_x1.long()]
|
| 238 |
+
i_y1_x0 = im_flat[idx_y1_x0.long()]
|
| 239 |
+
i_y1_x1 = im_flat[idx_y1_x1.long()]
|
| 240 |
+
|
| 241 |
+
# Finally calculate interpolated values.
|
| 242 |
+
x0_f = x0.float()
|
| 243 |
+
x1_f = x1.float()
|
| 244 |
+
y0_f = y0.float()
|
| 245 |
+
y1_f = y1.float()
|
| 246 |
+
|
| 247 |
+
w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
|
| 248 |
+
w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
|
| 249 |
+
w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
|
| 250 |
+
w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)
|
| 251 |
+
|
| 252 |
+
output = (
|
| 253 |
+
w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
|
| 254 |
+
)
|
| 255 |
+
# output is B*N x C
|
| 256 |
+
output = output.view(B, -1, C)
|
| 257 |
+
output = output.permute(0, 2, 1)
|
| 258 |
+
# output is B x C x N
|
| 259 |
+
|
| 260 |
+
if return_inbounds:
|
| 261 |
+
x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
|
| 262 |
+
y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
|
| 263 |
+
inbounds = (x_valid & y_valid).float()
|
| 264 |
+
inbounds = inbounds.reshape(
|
| 265 |
+
B, N
|
| 266 |
+
) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
|
| 267 |
+
return output, inbounds
|
| 268 |
+
|
| 269 |
+
return output # B, C, N
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def procrustes_analysis(X0,X1,Weight): # [B,N,3]
|
| 273 |
+
# translation
|
| 274 |
+
t0 = X0.mean(dim=1,keepdim=True)
|
| 275 |
+
t1 = X1.mean(dim=1,keepdim=True)
|
| 276 |
+
X0c = X0-t0
|
| 277 |
+
X1c = X1-t1
|
| 278 |
+
# scale
|
| 279 |
+
# s0 = (X0c**2).sum(dim=-1).mean().sqrt()
|
| 280 |
+
# s1 = (X1c**2).sum(dim=-1).mean().sqrt()
|
| 281 |
+
# X0cs = X0c/s0
|
| 282 |
+
# X1cs = X1c/s1
|
| 283 |
+
# rotation (use double for SVD, float loses precision)
|
| 284 |
+
U,_,V = (X0c.t()@X1c).double().svd(some=True)
|
| 285 |
+
R = (U@V.t()).float()
|
| 286 |
+
if R.det()<0: R[2] *= -1
|
| 287 |
+
# align X1 to X0: X1to0 = (X1-t1)/@R.t()+t0
|
| 288 |
+
se3 = edict(t0=t0[0],t1=t1[0],R=R)
|
| 289 |
+
|
| 290 |
+
return se3
|
| 291 |
+
|
| 292 |
+
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border", interp_mode="bilinear"):
|
| 293 |
+
r"""Sample a tensor using bilinear interpolation
|
| 294 |
+
|
| 295 |
+
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
|
| 296 |
+
coordinates :attr:`coords` using bilinear interpolation. It is the same
|
| 297 |
+
as `torch.nn.functional.grid_sample()` but with a different coordinate
|
| 298 |
+
convention.
|
| 299 |
+
|
| 300 |
+
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
|
| 301 |
+
:math:`B` is the batch size, :math:`C` is the number of channels,
|
| 302 |
+
:math:`H` is the height of the image, and :math:`W` is the width of the
|
| 303 |
+
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
|
| 304 |
+
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
|
| 305 |
+
|
| 306 |
+
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
|
| 307 |
+
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
|
| 308 |
+
that in this case the order of the components is slightly different
|
| 309 |
+
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
|
| 310 |
+
|
| 311 |
+
If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
|
| 312 |
+
in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
|
| 313 |
+
left-most image pixel :math:`W-1` to the center of the right-most
|
| 314 |
+
pixel.
|
| 315 |
+
|
| 316 |
+
If `align_corners` is `False`, the coordinate :math:`x` is assumed to
|
| 317 |
+
be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
|
| 318 |
+
the left-most pixel :math:`W` to the right edge of the right-most
|
| 319 |
+
pixel.
|
| 320 |
+
|
| 321 |
+
Similar conventions apply to the :math:`y` for the range
|
| 322 |
+
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
|
| 323 |
+
:math:`[0,T-1]` and :math:`[0,T]`.
|
| 324 |
+
|
| 325 |
+
Args:
|
| 326 |
+
input (Tensor): batch of input images.
|
| 327 |
+
coords (Tensor): batch of coordinates.
|
| 328 |
+
align_corners (bool, optional): Coordinate convention. Defaults to `True`.
|
| 329 |
+
padding_mode (str, optional): Padding mode. Defaults to `"border"`.
|
| 330 |
+
|
| 331 |
+
Returns:
|
| 332 |
+
Tensor: sampled points.
|
| 333 |
+
"""
|
| 334 |
+
|
| 335 |
+
sizes = input.shape[2:]
|
| 336 |
+
|
| 337 |
+
assert len(sizes) in [2, 3]
|
| 338 |
+
|
| 339 |
+
if len(sizes) == 3:
|
| 340 |
+
# t x y -> x y t to match dimensions T H W in grid_sample
|
| 341 |
+
coords = coords[..., [1, 2, 0]]
|
| 342 |
+
|
| 343 |
+
if align_corners:
|
| 344 |
+
coords = coords * torch.tensor(
|
| 345 |
+
[2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
|
| 346 |
+
)
|
| 347 |
+
else:
|
| 348 |
+
coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
|
| 349 |
+
|
| 350 |
+
coords -= 1
|
| 351 |
+
|
| 352 |
+
return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode, mode=interp_mode)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def sample_features4d(input, coords, interp_mode="bilinear"):
|
| 356 |
+
r"""Sample spatial features
|
| 357 |
+
|
| 358 |
+
`sample_features4d(input, coords)` samples the spatial features
|
| 359 |
+
:attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
|
| 360 |
+
|
| 361 |
+
The field is sampled at coordinates :attr:`coords` using bilinear
|
| 362 |
+
interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
|
| 363 |
+
3)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
|
| 364 |
+
same convention as :func:`bilinear_sampler` with `align_corners=True`.
|
| 365 |
+
|
| 366 |
+
The output tensor has one feature per point, and has shape :math:`(B,
|
| 367 |
+
R, C)`.
|
| 368 |
+
|
| 369 |
+
Args:
|
| 370 |
+
input (Tensor): spatial features.
|
| 371 |
+
coords (Tensor): points.
|
| 372 |
+
|
| 373 |
+
Returns:
|
| 374 |
+
Tensor: sampled features.
|
| 375 |
+
"""
|
| 376 |
+
|
| 377 |
+
B, _, _, _ = input.shape
|
| 378 |
+
|
| 379 |
+
# B R 2 -> B R 1 2
|
| 380 |
+
coords = coords.unsqueeze(2)
|
| 381 |
+
|
| 382 |
+
# B C R 1
|
| 383 |
+
feats = bilinear_sampler(input, coords, interp_mode=interp_mode)
|
| 384 |
+
|
| 385 |
+
return feats.permute(0, 2, 1, 3).view(
|
| 386 |
+
B, -1, feats.shape[1] * feats.shape[3]
|
| 387 |
+
) # B C R 1 -> B R C
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def sample_features5d(input, coords, interp_mode="bilinear"):
|
| 391 |
+
r"""Sample spatio-temporal features
|
| 392 |
+
|
| 393 |
+
`sample_features5d(input, coords)` works in the same way as
|
| 394 |
+
:func:`sample_features4d` but for spatio-temporal features and points:
|
| 395 |
+
:attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
|
| 396 |
+
a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
|
| 397 |
+
x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
|
| 398 |
+
|
| 399 |
+
Args:
|
| 400 |
+
input (Tensor): spatio-temporal features.
|
| 401 |
+
coords (Tensor): spatio-temporal points.
|
| 402 |
+
|
| 403 |
+
Returns:
|
| 404 |
+
Tensor: sampled features.
|
| 405 |
+
"""
|
| 406 |
+
|
| 407 |
+
B, T, _, _, _ = input.shape
|
| 408 |
+
|
| 409 |
+
# B T C H W -> B C T H W
|
| 410 |
+
input = input.permute(0, 2, 1, 3, 4)
|
| 411 |
+
|
| 412 |
+
# B R1 R2 3 -> B R1 R2 1 3
|
| 413 |
+
coords = coords.unsqueeze(3)
|
| 414 |
+
|
| 415 |
+
# B C R1 R2 1
|
| 416 |
+
feats = bilinear_sampler(input, coords, interp_mode=interp_mode)
|
| 417 |
+
|
| 418 |
+
return feats.permute(0, 2, 3, 1, 4).view(
|
| 419 |
+
B, feats.shape[2], feats.shape[3], feats.shape[1]
|
| 420 |
+
) # B C R1 R2 1 -> B R1 R2 C
|
| 421 |
+
|
| 422 |
+
def vis_PCA(fmaps, save_dir):
|
| 423 |
+
"""
|
| 424 |
+
visualize the PCA of the feature maps
|
| 425 |
+
args:
|
| 426 |
+
fmaps: feature maps 1 C H W
|
| 427 |
+
save_dir: the directory to save the PCA visualization
|
| 428 |
+
"""
|
| 429 |
+
|
| 430 |
+
pca = PCA(n_components=3)
|
| 431 |
+
fmap_vis = fmaps[0,...]
|
| 432 |
+
fmap_vnorm = (
|
| 433 |
+
(fmap_vis-fmap_vis.min())/
|
| 434 |
+
(fmap_vis.max()-fmap_vis.min()))
|
| 435 |
+
H_vis, W_vis = fmap_vis.shape[1:]
|
| 436 |
+
fmap_vnorm = fmap_vnorm.reshape(fmap_vnorm.shape[0],
|
| 437 |
+
-1).permute(1,0)
|
| 438 |
+
fmap_pca = pca.fit_transform(fmap_vnorm.detach().cpu().numpy())
|
| 439 |
+
pca = fmap_pca.reshape(H_vis,W_vis,3)
|
| 440 |
+
plt.imsave(save_dir,
|
| 441 |
+
(
|
| 442 |
+
(pca-pca.min())/
|
| 443 |
+
(pca.max()-pca.min())
|
| 444 |
+
))
|
models/SpaTrackV2/utils/visualizer.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import numpy as np
import cv2
import torch
import flow_vis

from matplotlib import cm
import torch.nn.functional as F
import torchvision.transforms as transforms
import moviepy
from moviepy.editor import ImageSequenceClip
import matplotlib.pyplot as plt


def read_video_from_path(path):
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        print("Error opening video file")
    else:
        frames = []
        while cap.isOpened():
            ret, frame = cap.read()
            if ret == True:
                frames.append(np.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            else:
                break
        cap.release()
    return np.stack(frames)


class Visualizer:
    def __init__(
        self,
        save_dir: str = "./results",
        grayscale: bool = False,
        pad_value: int = 0,
        fps: int = 10,
        mode: str = "rainbow",  # 'cool', 'optical_flow'
        linewidth: int = 2,
        show_first_frame: int = 10,
        tracks_leave_trace: int = 0,  # -1 for infinite
    ):
        self.mode = mode
        self.save_dir = save_dir
        if mode == "rainbow":
            self.color_map = cm.get_cmap("gist_rainbow")
        elif mode == "cool":
            self.color_map = cm.get_cmap(mode)
        self.show_first_frame = show_first_frame
        self.grayscale = grayscale
        self.tracks_leave_trace = tracks_leave_trace
        self.pad_value = pad_value
        self.linewidth = linewidth
        self.fps = fps

    def visualize(
        self,
        video: torch.Tensor,  # (B,T,C,H,W)
        tracks: torch.Tensor,  # (B,T,N,2)
        visibility: torch.Tensor = None,  # (B, T, N, 1) bool
        gt_tracks: torch.Tensor = None,  # (B,T,N,2)
        segm_mask: torch.Tensor = None,  # (B,1,H,W)
        filename: str = "video",
        writer=None,  # tensorboard Summary Writer, used for visualization during training
        step: int = 0,
        query_frame: int = 0,
        save_video: bool = True,
        compensate_for_camera_motion: bool = False,
        rigid_part=None,
        video_depth=None,  # (B,T,C,H,W)
    ):
        if compensate_for_camera_motion:
            assert segm_mask is not None
        if segm_mask is not None:
            coords = tracks[0, query_frame].round().long()
            segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()

        video = F.pad(
            video,
            (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
            "constant",
            255,
        )

        if video_depth is not None:
            video_depth = (video_depth * 255).cpu().numpy().astype(np.uint8)
            video_depth = ([cv2.applyColorMap(video_depth[0, i, 0], cv2.COLORMAP_INFERNO)
                            for i in range(video_depth.shape[1])])
            video_depth = np.stack(video_depth, axis=0)
            video_depth = torch.from_numpy(video_depth).permute(0, 3, 1, 2)[None]

        tracks = tracks + self.pad_value

        if self.grayscale:
            transform = transforms.Grayscale()
            video = transform(video)
            video = video.repeat(1, 1, 3, 1, 1)

        res_video = self.draw_tracks_on_video(
            video=video,
            tracks=tracks,
            visibility=visibility,
            segm_mask=segm_mask,
            gt_tracks=gt_tracks,
            query_frame=query_frame,
            compensate_for_camera_motion=compensate_for_camera_motion,
            rigid_part=rigid_part
        )

        if save_video:
            self.save_video(res_video, filename=filename,
                            writer=writer, step=step)
            if video_depth is not None:
                self.save_video(video_depth, filename=filename + "_depth",
                                writer=writer, step=step)
        return res_video

    def save_video(self, video, filename, writer=None, step=0):
        if writer is not None:
            writer.add_video(
                f"{filename}_pred_track",
                video.to(torch.uint8),
                global_step=step,
                fps=self.fps,
            )
        else:
            os.makedirs(self.save_dir, exist_ok=True)
            wide_list = list(video.unbind(1))
            wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
            clip = ImageSequenceClip(wide_list[2:-1], fps=self.fps)

            # Write the video file
            save_path = os.path.join(self.save_dir, f"{filename}_pred_track.mp4")
            clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None)

            print(f"Video saved to {save_path}")

    def draw_tracks_on_video(
        self,
        video: torch.Tensor,
        tracks: torch.Tensor,
        visibility: torch.Tensor = None,
        segm_mask: torch.Tensor = None,
        gt_tracks=None,
        query_frame: int = 0,
        compensate_for_camera_motion=False,
        rigid_part=None,
    ):
        B, T, C, H, W = video.shape
        _, _, N, D = tracks.shape

        assert D == 2
        assert C == 3
        video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy()  # S, H, W, C
        tracks = tracks[0].long().detach().cpu().numpy()  # S, N, 2
        if gt_tracks is not None:
            gt_tracks = gt_tracks.detach().cpu().numpy()

        res_video = []

        # process input video
        for rgb in video:
            res_video.append(rgb.copy())

        vector_colors = np.zeros((T, N, 3))
        if self.mode == "optical_flow":
            vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
        elif segm_mask is None:
            if self.mode == "rainbow":
                y_min, y_max = (
                    tracks[query_frame, :, 1].min(),
                    tracks[query_frame, :, 1].max(),
                )
                norm = plt.Normalize(y_min, y_max)
                for n in range(N):
                    color = self.color_map(norm(tracks[query_frame, n, 1]))
                    color = np.array(color[:3])[None] * 255
                    vector_colors[:, n] = np.repeat(color, T, axis=0)
            else:
                # color changes with time
                for t in range(T):
                    color = np.array(self.color_map(t / T)[:3])[None] * 255
                    vector_colors[t] = np.repeat(color, N, axis=0)
        else:
            if self.mode == "rainbow":
                vector_colors[:, segm_mask <= 0, :] = 255

                y_min, y_max = (
                    tracks[0, segm_mask > 0, 1].min(),
                    tracks[0, segm_mask > 0, 1].max(),
                )
                norm = plt.Normalize(y_min, y_max)
                for n in range(N):
                    if segm_mask[n] > 0:
                        color = self.color_map(norm(tracks[0, n, 1]))
                        color = np.array(color[:3])[None] * 255
                        vector_colors[:, n] = np.repeat(color, T, axis=0)

            else:
                # color changes with segm class
                segm_mask = segm_mask.cpu()
                color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
                color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
                color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
                vector_colors = np.repeat(color[None], T, axis=0)

        # draw tracks
        if self.tracks_leave_trace != 0:
            for t in range(1, T):
                first_ind = (
                    max(0, t - self.tracks_leave_trace)
                    if self.tracks_leave_trace >= 0
                    else 0
                )
                curr_tracks = tracks[first_ind : t + 1]
                curr_colors = vector_colors[first_ind : t + 1]
                if compensate_for_camera_motion:
                    diff = (
                        tracks[first_ind : t + 1, segm_mask <= 0]
                        - tracks[t : t + 1, segm_mask <= 0]
                    ).mean(1)[:, None]

                    curr_tracks = curr_tracks - diff
                    curr_tracks = curr_tracks[:, segm_mask > 0]
                    curr_colors = curr_colors[:, segm_mask > 0]

                res_video[t] = self._draw_pred_tracks(
                    res_video[t],
                    curr_tracks,
                    curr_colors,
                )
                if gt_tracks is not None:
                    res_video[t] = self._draw_gt_tracks(
                        res_video[t], gt_tracks[first_ind : t + 1]
                    )

        if rigid_part is not None:
            cls_label = torch.unique(rigid_part)
            cls_num = len(torch.unique(rigid_part))
            # visualize the clustering results
            cmap = plt.get_cmap('jet')  # get the color mapping
            colors = cmap(np.linspace(0, 1, cls_num))
            colors = (colors[:, :3] * 255)
            color_map = {lable.item(): color for lable, color in zip(cls_label, colors)}

        # draw points
        for t in range(T):
            for i in range(N):
                coord = (tracks[t, i, 0], tracks[t, i, 1])
                visibile = True
                if visibility is not None:
                    visibile = visibility[0, t, i] > 0.5
                if coord[0] != 0 and coord[1] != 0:
                    if not compensate_for_camera_motion or (
                        compensate_for_camera_motion and segm_mask[i] > 0
                    ):
                        if rigid_part is not None:
                            color = color_map[rigid_part.squeeze()[i].item()]
                            cv2.circle(
                                res_video[t],
                                coord,
                                int(self.linewidth * 2),
                                color.tolist(),
                                thickness=-1 if visibile else 2 - 1,
                            )
                        else:
                            cv2.circle(
                                res_video[t],
                                coord,
                                int(self.linewidth * 2),
                                vector_colors[t, i].tolist(),
                                thickness=-1 if visibile else 2 - 1,
                            )

        # construct the final rgb sequence
        if self.show_first_frame > 0:
            res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
        return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()

    def _draw_pred_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3
        tracks: np.ndarray,  # T x 2
        vector_colors: np.ndarray,
        alpha: float = 0.5,
    ):
        T, N, _ = tracks.shape

        for s in range(T - 1):
            vector_color = vector_colors[s]
            original = rgb.copy()
            alpha = (s / T) ** 2
            for i in range(N):
                coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
                coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
                if coord_y[0] != 0 and coord_y[1] != 0:
                    cv2.line(
                        rgb,
                        coord_y,
                        coord_x,
                        vector_color[i].tolist(),
                        self.linewidth,
                        cv2.LINE_AA,
                    )
            if self.tracks_leave_trace > 0:
                rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0)
        return rgb

    def _draw_gt_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3,
        gt_tracks: np.ndarray,  # T x 2
    ):
        T, N, _ = gt_tracks.shape
        color = np.array((211.0, 0.0, 0.0))

        for t in range(T):
            for i in range(N):
                gt_tracks_i = gt_tracks[t][i]
                # draw a red cross
                if gt_tracks_i[0] > 0 and gt_tracks_i[1] > 0:
                    length = self.linewidth * 3
                    coord_y = (int(gt_tracks_i[0]) + length, int(gt_tracks_i[1]) + length)
                    coord_x = (int(gt_tracks_i[0]) - length, int(gt_tracks_i[1]) - length)
                    cv2.line(
                        rgb,
                        coord_y,
                        coord_x,
                        color,
                        self.linewidth,
                        cv2.LINE_AA,
                    )
                    coord_y = (int(gt_tracks_i[0]) - length, int(gt_tracks_i[1]) + length)
                    coord_x = (int(gt_tracks_i[0]) + length, int(gt_tracks_i[1]) - length)
                    cv2.line(
                        rgb,
                        coord_y,
                        coord_x,
                        color,
                        self.linewidth,
                        cv2.LINE_AA,
                    )
        return rgb
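A minimal usage sketch for the Visualizer added above (not part of this commit; the import path assumes the repo root is on PYTHONPATH, demo.mp4 is a placeholder, and the track tensor is random dummy data):

import torch
from models.SpaTrackV2.utils.visualizer import Visualizer, read_video_from_path

frames = read_video_from_path("demo.mp4")                           # (T, H, W, 3) uint8 frames
video = torch.from_numpy(frames).permute(0, 3, 1, 2)[None].float()  # (1, T, 3, H, W)
tracks = torch.rand(1, video.shape[1], 16, 2) * 100 + 1             # (1, T, N, 2) dummy 2D tracks
vis = Visualizer(save_dir="./results", mode="rainbow", fps=10)
vis.visualize(video=video, tracks=tracks, filename="demo")          # writes ./results/demo_pred_track.mp4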
models/moge/__init__.py
ADDED
File without changes
models/moge/model/__init__.py
ADDED
@@ -0,0 +1,18 @@
import importlib
from typing import *

if TYPE_CHECKING:
    from .v1 import MoGeModel as MoGeModelV1
    from .v2 import MoGeModel as MoGeModelV2


def import_model_class_by_version(version: str) -> Type[Union['MoGeModelV1', 'MoGeModelV2']]:
    assert version in ['v1', 'v2'], f'Unsupported model version: {version}'

    try:
        module = importlib.import_module(f'.{version}', __package__)
    except ModuleNotFoundError:
        raise ValueError(f'Model version "{version}" not found.')

    cls = getattr(module, 'MoGeModel')
    return cls
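A quick sketch of how this lazy loader is meant to be used (not part of this commit):

from models.moge.model import import_model_class_by_version

MoGeModelV2 = import_model_class_by_version("v2")  # imports models.moge.model.v2 and returns its MoGeModel class
# Any other version string fails the assertion: only 'v1' and 'v2' are accepted.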
models/moge/model/dinov2/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

__version__ = "0.0.1"
models/moge/model/dinov2/hub/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
models/moge/model/dinov2/hub/backbones.py
ADDED
@@ -0,0 +1,156 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from enum import Enum
from typing import Union

import torch

from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name


class Weights(Enum):
    LVD142M = "LVD142M"


def _make_dinov2_model(
    *,
    arch_name: str = "vit_large",
    img_size: int = 518,
    patch_size: int = 14,
    init_values: float = 1.0,
    ffn_layer: str = "mlp",
    block_chunks: int = 0,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    from ..models import vision_transformer as vits

    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    model_base_name = _make_dinov2_model_name(arch_name, patch_size)
    vit_kwargs = dict(
        img_size=img_size,
        patch_size=patch_size,
        init_values=init_values,
        ffn_layer=ffn_layer,
        block_chunks=block_chunks,
        num_register_tokens=num_register_tokens,
        interpolate_antialias=interpolate_antialias,
        interpolate_offset=interpolate_offset,
    )
    vit_kwargs.update(**kwargs)
    model = vits.__dict__[arch_name](**vit_kwargs)

    if pretrained:
        model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
        url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

    return model


def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        **kwargs,
    )


def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_small",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_base",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_large",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
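A hedged usage sketch of the factory functions above (not part of this commit). With pretrained=True the state dict is downloaded from _DINOV2_BASE_URL, so pretrained=False is used here to build the architecture only; the import path assumes the sibling models/vision_transformer module ships with this DINOv2 copy:

from models.moge.model.dinov2.hub.backbones import dinov2_vitl14, dinov2_vitl14_reg

backbone = dinov2_vitl14(pretrained=False)          # plain ViT-L/14
backbone_reg = dinov2_vitl14_reg(pretrained=False)  # ViT-L/14 with 4 register tokens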
models/moge/model/dinov2/hub/utils.py
ADDED
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import itertools
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"


def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
    compact_arch_name = arch_name.replace("_", "")[:4]
    registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
    return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"


class CenterPadding(nn.Module):
    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        new_size = math.ceil(size / self.multiple) * self.multiple
        pad_size = new_size - size
        pad_size_left = pad_size // 2
        pad_size_right = pad_size - pad_size_left
        return pad_size_left, pad_size_right

    @torch.inference_mode()
    def forward(self, x):
        pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
        output = F.pad(x, pads)
        return output
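CenterPadding pads the trailing spatial dimensions up to the next multiple of `multiple`, splitting the padding as evenly as possible on both sides, which is how inputs are made divisible by the ViT patch size. A small sketch (not part of this commit):

import torch
from models.moge.model.dinov2.hub.utils import CenterPadding

pad = CenterPadding(multiple=14)
x = torch.randn(1, 3, 518, 515)
y = pad(x)
print(y.shape)  # torch.Size([1, 3, 518, 518]): 515 -> 518, with 1 px added on the left and 2 px on the right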
models/moge/model/dinov2/layers/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .dino_head import DINOHead
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention
models/moge/model/dinov2/layers/attention.py
ADDED
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import logging
import os
import warnings

from torch import Tensor
from torch import nn


logger = logging.getLogger("dinov2")


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import memory_efficient_attention, unbind

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (Attention)")
    else:
        # warnings.warn("xFormers is disabled (Attention)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    # warnings.warn("xFormers is not available (Attention)")


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
        attn = q @ k.transpose(-2, -1)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
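Finally, a small sketch exercising the attention blocks above (not part of this commit). MemEffAttention falls back to the plain Attention.forward when xFormers is unavailable, so the snippet runs either way; the import path assumes the repo root is on PYTHONPATH:

import torch
from models.moge.model.dinov2.layers.attention import MemEffAttention

attn = MemEffAttention(dim=384, num_heads=6)
tokens = torch.randn(2, 257, 384)  # (batch, sequence length, embedding dim)
out = attn(tokens)                 # same shape as the input: (2, 257, 384)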