Upload folder using huggingface_hub
Browse files. This view is limited to 50 files because it contains too many changes. See raw diff
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/.hydra/config.yaml +125 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/.hydra/hydra.yaml +186 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/.hydra/overrides.yaml +31 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/checkpoint-1.pth +3 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/checkpoint-2.pth +3 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/checkpoint-final.pth +3 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/checkpoint-last.pth +3 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/CHANGELOG.md +19 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/README.md +373 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/README_submap.md +225 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/base_opt.py +301 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/commons.py +102 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/dust3r_opt/__init__.py +31 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/dust3r_opt/base_opt.py +620 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/dust3r_opt/commons.py +102 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/dust3r_opt/init_im_poses.py +378 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/dust3r_opt/optimizer.py +301 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/init_all.py +222 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/utils.py +443 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/deepspeed_zero3_bf16.json +19 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/finetune.yaml +102 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/finetune_paper_h20.yaml +129 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/finetune_pseudo_gt_high_recall.yaml +129 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/finetune_sub_only.yaml +129 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/mytrain.yaml +92 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/environment.yml +245 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/eval_ate_scaled.py +54 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/get_ate.py +74 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/publish_submap.sh +138 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/requirements.txt +30 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/run_tum.sh +77 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/run_tum_top5.sh +66 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/setup.py +8 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/setup_env.sh +18 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/__init__.py +0 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/audit_dataset_num_views.py +412 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/batched_dynamic_router.py +243 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/demo.py +540 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/demo_infinite.py +493 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/demo_submap.py +927 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/download_data.sh +354 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/exp_joint_freeze_frontend_fsdp_8gpu.sh +438 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/graph_gated_memory.py +850 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/mine_pseudo_gt.py +588 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/pseudo_gt.py +348 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/rerun_helper/__init__.py +197 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/rerun_helper/generic_utils.py +274 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/rerun_helper/geometry_utils.py +232 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/rerun_helper/tmp.py +39 -0
- checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/rerun_helper/visualization_utils.py +167 -0
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/.hydra/config.yaml
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accum_iter: 1
|
| 2 |
+
allow_repeat: false
|
| 3 |
+
amp: 1
|
| 4 |
+
batch_size: 1
|
| 5 |
+
benchmark: false
|
| 6 |
+
custom_lr_scale: 1.0
|
| 7 |
+
data_root: /home/23068142r/work_dir/data
|
| 8 |
+
root_arkit: /home/23068142r/work_dir/data/processed_arkitscenes
|
| 9 |
+
root_scannetpp: /home/23068142r/work_dir/data/preprocessed_scannetpp
|
| 10 |
+
root_scannet: /home/23068142r/work_dir/data/processed_scannet
|
| 11 |
+
root_hypersim: /home/23068142r/work_dir/data/preprocessed_Hypersim
|
| 12 |
+
root_blendedmvs: /home/23068142r/work_dir/data/processed_blendedmvs
|
| 13 |
+
root_megadepth: /home/23068142r/work_dir/data/processed_megadepth
|
| 14 |
+
root_mvs_synth: /home/23068142r/work_dir/data/processed_mvs_synth
|
| 15 |
+
dataset_arkit: ARKitScenes_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_arkit}",
|
| 16 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 17 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views_arkit}, n_corres=${n_corres_train})
|
| 18 |
+
dataset_scannetpp: ScanNetpp_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_scannetpp}",
|
| 19 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 20 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views_scannetpp}, n_corres=${n_corres_train})
|
| 21 |
+
dataset_scannet: ScanNet_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_scannet}",
|
| 22 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 23 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views_scannet}, n_corres=${n_corres_train})
|
| 24 |
+
dataset_hypersim: HyperSim_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_hypersim}",
|
| 25 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 26 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views_hypersim}, n_corres=${n_corres_train})
|
| 27 |
+
dataset_blendedmvs: BlendedMVS_Multi(allow_repeat=${allow_repeat}, split='train',
|
| 28 |
+
ROOT="${root_blendedmvs}", aug_crop=16, resolution=[(518, 392), (518, 336), (518,
|
| 29 |
+
294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=${num_views_blendedmvs},
|
| 30 |
+
n_corres=${n_corres_train})
|
| 31 |
+
dataset_megadepth: MegaDepth_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_megadepth}",
|
| 32 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 33 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views_megadepth}, n_corres=${n_corres_train})
|
| 34 |
+
dataset_mvs_synth: MVS_Synth_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_mvs_synth}",
|
| 35 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 36 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views_mvs_synth}, n_corres=${n_corres_train})
|
| 37 |
+
desc_dim: 128
|
| 38 |
+
detach_frontend_tokens: true
|
| 39 |
+
dist_backend: nccl
|
| 40 |
+
dist_url: env://
|
| 41 |
+
distributed: false
|
| 42 |
+
enable_dynamic_boundary: false
|
| 43 |
+
enable_loop: true
|
| 44 |
+
enable_submap: true
|
| 45 |
+
enable_temporal: false
|
| 46 |
+
epochs: 2
|
| 47 |
+
eval_freq: 1
|
| 48 |
+
exp_name: paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12
|
| 49 |
+
fixed_length: true
|
| 50 |
+
freeze_encoder: true
|
| 51 |
+
gpu: 0
|
| 52 |
+
gradient_checkpointing: true
|
| 53 |
+
gumbel_tau: 5.0
|
| 54 |
+
loop_mask_mode: soft_all
|
| 55 |
+
retain_history_grad: true
|
| 56 |
+
submap_train_mode: full_token
|
| 57 |
+
submap_retrieval_topk: 0
|
| 58 |
+
submap_fetch_source: frontend
|
| 59 |
+
submap_descriptor_source: frontend
|
| 60 |
+
train_submap_modules_only: false
|
| 61 |
+
gumbel_tau_end: 0.1
|
| 62 |
+
gumbel_tau_start: 5.0
|
| 63 |
+
keep_freq: 1
|
| 64 |
+
load_only_encoder: false
|
| 65 |
+
local-rank: -1
|
| 66 |
+
logdir: ${save_dir}/${exp_name}/logs
|
| 67 |
+
long_context: false
|
| 68 |
+
lr: 1.0e-05
|
| 69 |
+
max_checkpoints: 10
|
| 70 |
+
max_recursive_submaps: 5
|
| 71 |
+
min_lr: 1.0e-08
|
| 72 |
+
n_corres_test: 0
|
| 73 |
+
n_corres_train: 0
|
| 74 |
+
num_imgs_vis: 4
|
| 75 |
+
num_test_views: 4
|
| 76 |
+
num_views: 24
|
| 77 |
+
num_views_arkit: 64
|
| 78 |
+
num_views_scannetpp: 24
|
| 79 |
+
num_views_scannet: 64
|
| 80 |
+
num_views_hypersim: 24
|
| 81 |
+
num_views_blendedmvs: 64
|
| 82 |
+
num_views_megadepth: 64
|
| 83 |
+
num_views_mvs_synth: 24
|
| 84 |
+
num_workers: 4
|
| 85 |
+
output_dir: ${save_dir}/${exp_name}/
|
| 86 |
+
pretrained: /home/23068142r/work_dir/projects/e2e-semantic-SLAM-submap/ckpt/checkpoint-10.pth.model
|
| 87 |
+
print_freq: 10
|
| 88 |
+
print_img_freq: 50000000
|
| 89 |
+
rank: 0
|
| 90 |
+
resume: null
|
| 91 |
+
retention_ratio: 0.5
|
| 92 |
+
pseudo_gt:
|
| 93 |
+
enable: false
|
| 94 |
+
cache_path: null
|
| 95 |
+
use_soft_targets: true
|
| 96 |
+
min_confidence: 0.65
|
| 97 |
+
min_support_pairs: 1
|
| 98 |
+
topk_pairs: 4
|
| 99 |
+
loss_type: hybrid
|
| 100 |
+
loss_weight_gate: 0.1
|
| 101 |
+
loss_weight_desc: 0.1
|
| 102 |
+
geometric_support_scale: 0.25
|
| 103 |
+
ranking_margin: 0.1
|
| 104 |
+
use_l2m: false
|
| 105 |
+
l2m_min_certainty: 0.0
|
| 106 |
+
l2m_min_inlier_ratio: 0.0
|
| 107 |
+
save_dir: /home/23068142r/work_dir/projects/e2e-semantic-SLAM/checkpoints/paper_smoke_local_8gpu/joint_freeze_frontend_fsdp_sub12
|
| 108 |
+
save_freq: 0.1
|
| 109 |
+
seed: 42
|
| 110 |
+
soft_mask_bias: 0.2
|
| 111 |
+
soft_mask_temperature: 0.25
|
| 112 |
+
start_epoch: 0
|
| 113 |
+
start_step: 0
|
| 114 |
+
submap_size: 6
|
| 115 |
+
task: SLAMFormer_Submap_Finetune
|
| 116 |
+
tbptt_window: 0
|
| 117 |
+
teacher: null
|
| 118 |
+
temporal_embed_mode: learned
|
| 119 |
+
test_criterion: DistillLoss()
|
| 120 |
+
test_dataset: ''
|
| 121 |
+
train_criterion: DistillLoss()
|
| 122 |
+
train_dataset: 16 @ ${dataset_scannetpp} + 16 @ ${dataset_hypersim} + 16 @ ${dataset_mvs_synth}
|
| 123 |
+
warmup_epochs: 0.5
|
| 124 |
+
weight_decay: 0.05
|
| 125 |
+
world_size: 1
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/.hydra/hydra.yaml
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
hydra:
|
| 2 |
+
run:
|
| 3 |
+
dir: ${save_dir}/${exp_name}
|
| 4 |
+
sweep:
|
| 5 |
+
dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
|
| 6 |
+
subdir: ${hydra.job.num}
|
| 7 |
+
launcher:
|
| 8 |
+
_target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
|
| 9 |
+
sweeper:
|
| 10 |
+
_target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
|
| 11 |
+
max_batch_size: null
|
| 12 |
+
params: null
|
| 13 |
+
help:
|
| 14 |
+
app_name: ${hydra.job.name}
|
| 15 |
+
header: '${hydra.help.app_name} is powered by Hydra.
|
| 16 |
+
|
| 17 |
+
'
|
| 18 |
+
footer: 'Powered by Hydra (https://hydra.cc)
|
| 19 |
+
|
| 20 |
+
Use --hydra-help to view Hydra specific help
|
| 21 |
+
|
| 22 |
+
'
|
| 23 |
+
template: '${hydra.help.header}
|
| 24 |
+
|
| 25 |
+
== Configuration groups ==
|
| 26 |
+
|
| 27 |
+
Compose your configuration from those groups (group=option)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
$APP_CONFIG_GROUPS
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
== Config ==
|
| 34 |
+
|
| 35 |
+
Override anything in the config (foo.bar=value)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
$CONFIG
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
${hydra.help.footer}
|
| 42 |
+
|
| 43 |
+
'
|
| 44 |
+
hydra_help:
|
| 45 |
+
template: 'Hydra (${hydra.runtime.version})
|
| 46 |
+
|
| 47 |
+
See https://hydra.cc for more info.
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
== Flags ==
|
| 51 |
+
|
| 52 |
+
$FLAGS_HELP
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
== Configuration groups ==
|
| 56 |
+
|
| 57 |
+
Compose your configuration from those groups (For example, append hydra/job_logging=disabled
|
| 58 |
+
to command line)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
$HYDRA_CONFIG_GROUPS
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
Use ''--cfg hydra'' to Show the Hydra config.
|
| 65 |
+
|
| 66 |
+
'
|
| 67 |
+
hydra_help: ???
|
| 68 |
+
hydra_logging:
|
| 69 |
+
version: 1
|
| 70 |
+
formatters:
|
| 71 |
+
simple:
|
| 72 |
+
format: '[%(asctime)s][HYDRA] %(message)s'
|
| 73 |
+
handlers:
|
| 74 |
+
console:
|
| 75 |
+
class: logging.StreamHandler
|
| 76 |
+
formatter: simple
|
| 77 |
+
stream: ext://sys.stdout
|
| 78 |
+
root:
|
| 79 |
+
level: INFO
|
| 80 |
+
handlers:
|
| 81 |
+
- console
|
| 82 |
+
loggers:
|
| 83 |
+
logging_example:
|
| 84 |
+
level: DEBUG
|
| 85 |
+
disable_existing_loggers: false
|
| 86 |
+
job_logging:
|
| 87 |
+
version: 1
|
| 88 |
+
formatters:
|
| 89 |
+
simple:
|
| 90 |
+
format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
|
| 91 |
+
handlers:
|
| 92 |
+
console:
|
| 93 |
+
class: logging.StreamHandler
|
| 94 |
+
formatter: simple
|
| 95 |
+
stream: ext://sys.stdout
|
| 96 |
+
file:
|
| 97 |
+
class: logging.FileHandler
|
| 98 |
+
formatter: simple
|
| 99 |
+
filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
|
| 100 |
+
root:
|
| 101 |
+
level: INFO
|
| 102 |
+
handlers:
|
| 103 |
+
- console
|
| 104 |
+
- file
|
| 105 |
+
disable_existing_loggers: false
|
| 106 |
+
env: {}
|
| 107 |
+
mode: RUN
|
| 108 |
+
searchpath: []
|
| 109 |
+
callbacks: {}
|
| 110 |
+
output_subdir: .hydra
|
| 111 |
+
overrides:
|
| 112 |
+
hydra:
|
| 113 |
+
- hydra.mode=RUN
|
| 114 |
+
task:
|
| 115 |
+
- exp_name=paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12
|
| 116 |
+
- save_dir=/home/23068142r/work_dir/projects/e2e-semantic-SLAM/checkpoints/paper_smoke_local_8gpu/joint_freeze_frontend_fsdp_sub12
|
| 117 |
+
- pretrained=/home/23068142r/work_dir/projects/e2e-semantic-SLAM-submap/ckpt/checkpoint-10.pth.model
|
| 118 |
+
- resume=null
|
| 119 |
+
- data_root=/home/23068142r/work_dir/data
|
| 120 |
+
- root_arkit=/home/23068142r/work_dir/data/processed_arkitscenes
|
| 121 |
+
- root_scannetpp=/home/23068142r/work_dir/data/preprocessed_scannetpp
|
| 122 |
+
- root_scannet=/home/23068142r/work_dir/data/processed_scannet
|
| 123 |
+
- root_hypersim=/home/23068142r/work_dir/data/preprocessed_Hypersim
|
| 124 |
+
- root_blendedmvs=/home/23068142r/work_dir/data/processed_blendedmvs
|
| 125 |
+
- root_megadepth=/home/23068142r/work_dir/data/processed_megadepth
|
| 126 |
+
- root_mvs_synth=/home/23068142r/work_dir/data/processed_mvs_synth
|
| 127 |
+
- num_views=24
|
| 128 |
+
- num_views_arkit=64
|
| 129 |
+
- num_views_scannetpp=24
|
| 130 |
+
- num_views_scannet=64
|
| 131 |
+
- num_views_hypersim=24
|
| 132 |
+
- num_views_blendedmvs=64
|
| 133 |
+
- num_views_megadepth=64
|
| 134 |
+
- num_views_mvs_synth=24
|
| 135 |
+
- train_submap_modules_only=false
|
| 136 |
+
- detach_frontend_tokens=true
|
| 137 |
+
- submap_train_mode=full_token
|
| 138 |
+
- submap_retrieval_topk=0
|
| 139 |
+
- submap_fetch_source=frontend
|
| 140 |
+
- submap_descriptor_source=frontend
|
| 141 |
+
- pseudo_gt.enable=false
|
| 142 |
+
- pseudo_gt.cache_path=null
|
| 143 |
+
- train_dataset=16 @ ${dataset_scannetpp} + 16 @ ${dataset_hypersim} + 16 @ ${dataset_mvs_synth}
|
| 144 |
+
- epochs=2
|
| 145 |
+
- test_dataset=
|
| 146 |
+
job:
|
| 147 |
+
name: finetune
|
| 148 |
+
chdir: null
|
| 149 |
+
override_dirname: data_root=/home/23068142r/work_dir/data,detach_frontend_tokens=true,epochs=2,exp_name=paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12,num_views=24,num_views_arkit=64,num_views_blendedmvs=64,num_views_hypersim=24,num_views_megadepth=64,num_views_mvs_synth=24,num_views_scannet=64,num_views_scannetpp=24,pretrained=/home/23068142r/work_dir/projects/e2e-semantic-SLAM-submap/ckpt/checkpoint-10.pth.model,pseudo_gt.cache_path=null,pseudo_gt.enable=false,resume=null,root_arkit=/home/23068142r/work_dir/data/processed_arkitscenes,root_blendedmvs=/home/23068142r/work_dir/data/processed_blendedmvs,root_hypersim=/home/23068142r/work_dir/data/preprocessed_Hypersim,root_megadepth=/home/23068142r/work_dir/data/processed_megadepth,root_mvs_synth=/home/23068142r/work_dir/data/processed_mvs_synth,root_scannet=/home/23068142r/work_dir/data/processed_scannet,root_scannetpp=/home/23068142r/work_dir/data/preprocessed_scannetpp,save_dir=/home/23068142r/work_dir/projects/e2e-semantic-SLAM/checkpoints/paper_smoke_local_8gpu/joint_freeze_frontend_fsdp_sub12,submap_descriptor_source=frontend,submap_fetch_source=frontend,submap_retrieval_topk=0,submap_train_mode=full_token,test_dataset=,train_dataset=16
|
| 150 |
+
@ ${dataset_scannetpp} + 16 @ ${dataset_hypersim} + 16 @ ${dataset_mvs_synth},train_submap_modules_only=false
|
| 151 |
+
id: ???
|
| 152 |
+
num: ???
|
| 153 |
+
config_name: finetune_paper_h20.yaml
|
| 154 |
+
env_set: {}
|
| 155 |
+
env_copy: []
|
| 156 |
+
config:
|
| 157 |
+
override_dirname:
|
| 158 |
+
kv_sep: '='
|
| 159 |
+
item_sep: ','
|
| 160 |
+
exclude_keys: []
|
| 161 |
+
runtime:
|
| 162 |
+
version: 1.3.2
|
| 163 |
+
version_base: '1.3'
|
| 164 |
+
cwd: /home/23068142r/work_dir/projects/e2e-semantic-SLAM
|
| 165 |
+
config_sources:
|
| 166 |
+
- path: hydra.conf
|
| 167 |
+
schema: pkg
|
| 168 |
+
provider: hydra
|
| 169 |
+
- path: /home/23068142r/work_dir/projects/e2e-semantic-SLAM/src/../config
|
| 170 |
+
schema: file
|
| 171 |
+
provider: main
|
| 172 |
+
- path: ''
|
| 173 |
+
schema: structured
|
| 174 |
+
provider: schema
|
| 175 |
+
output_dir: /home/23068142r/work_dir/projects/e2e-semantic-SLAM/checkpoints/paper_smoke_local_8gpu/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12
|
| 176 |
+
choices:
|
| 177 |
+
hydra/env: default
|
| 178 |
+
hydra/callbacks: null
|
| 179 |
+
hydra/job_logging: default
|
| 180 |
+
hydra/hydra_logging: default
|
| 181 |
+
hydra/hydra_help: default
|
| 182 |
+
hydra/help: default
|
| 183 |
+
hydra/sweeper: basic
|
| 184 |
+
hydra/launcher: basic
|
| 185 |
+
hydra/output: default
|
| 186 |
+
verbose: true
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/.hydra/overrides.yaml
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
- exp_name=paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12
|
| 2 |
+
- save_dir=/home/23068142r/work_dir/projects/e2e-semantic-SLAM/checkpoints/paper_smoke_local_8gpu/joint_freeze_frontend_fsdp_sub12
|
| 3 |
+
- pretrained=/home/23068142r/work_dir/projects/e2e-semantic-SLAM-submap/ckpt/checkpoint-10.pth.model
|
| 4 |
+
- resume=null
|
| 5 |
+
- data_root=/home/23068142r/work_dir/data
|
| 6 |
+
- root_arkit=/home/23068142r/work_dir/data/processed_arkitscenes
|
| 7 |
+
- root_scannetpp=/home/23068142r/work_dir/data/preprocessed_scannetpp
|
| 8 |
+
- root_scannet=/home/23068142r/work_dir/data/processed_scannet
|
| 9 |
+
- root_hypersim=/home/23068142r/work_dir/data/preprocessed_Hypersim
|
| 10 |
+
- root_blendedmvs=/home/23068142r/work_dir/data/processed_blendedmvs
|
| 11 |
+
- root_megadepth=/home/23068142r/work_dir/data/processed_megadepth
|
| 12 |
+
- root_mvs_synth=/home/23068142r/work_dir/data/processed_mvs_synth
|
| 13 |
+
- num_views=24
|
| 14 |
+
- num_views_arkit=64
|
| 15 |
+
- num_views_scannetpp=24
|
| 16 |
+
- num_views_scannet=64
|
| 17 |
+
- num_views_hypersim=24
|
| 18 |
+
- num_views_blendedmvs=64
|
| 19 |
+
- num_views_megadepth=64
|
| 20 |
+
- num_views_mvs_synth=24
|
| 21 |
+
- train_submap_modules_only=false
|
| 22 |
+
- detach_frontend_tokens=true
|
| 23 |
+
- submap_train_mode=full_token
|
| 24 |
+
- submap_retrieval_topk=0
|
| 25 |
+
- submap_fetch_source=frontend
|
| 26 |
+
- submap_descriptor_source=frontend
|
| 27 |
+
- pseudo_gt.enable=false
|
| 28 |
+
- pseudo_gt.cache_path=null
|
| 29 |
+
- train_dataset=16 @ ${dataset_scannetpp} + 16 @ ${dataset_hypersim} + 16 @ ${dataset_mvs_synth}
|
| 30 |
+
- epochs=2
|
| 31 |
+
- test_dataset=
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/checkpoint-1.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c96e1ed2d9231223d02406ef66794e5f9f39bae990354bd5bc4101ab2396afb6
|
| 3 |
+
size 4516140233
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/checkpoint-2.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:465a4b460ad29b4d539348adebfb85a01845b0f905cad11206d260cf1b95f76f
|
| 3 |
+
size 4516140233
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/checkpoint-final.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3a618c80dcef89b1a41580d858134bac0d79b39130586e8d036a26ac9582c9f2
|
| 3 |
+
size 3873507717
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/checkpoint-last.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:eb7f68ba708455c27ac4b78128cf8a34f0f9511caf47074e3eea5431d5dfdfdb
|
| 3 |
+
size 4516145306
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/CHANGELOG.md
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Changelog
|
| 2 |
+
|
| 3 |
+
## Hardware
|
| 4 |
+
|
| 5 |
+
| Component | Specification |
|
| 6 |
+
|-----------|---------------|
|
| 7 |
+
| GPU | 8 x NVIDIA L40S |
|
| 8 |
+
| CPU | AMD EPYC 7763 64-Core Processor (112 vCPUs) |
|
| 9 |
+
| Memory | 755 GiB |
|
| 10 |
+
|
| 11 |
+
## 2026-04-03
|
| 12 |
+
|
| 13 |
+
### Added
|
| 14 |
+
|
| 15 |
+
- New script `slam/exp_joint_freeze_frontend_fsdp_8gpu.sh`.
|
| 16 |
+
- `SKIP_TEST` flag in `slam/exp_joint_freeze_frontend_fsdp_8gpu.sh` (default `0`). When set to `1`, the test dataset Hydra override is cleared and no test data loaders are built.
|
| 17 |
+
- Guard in `src/finetune.py` to skip test dataset construction when `test_dataset` is empty, setting `data_loader_test = {}`.
|
| 18 |
+
- Guard in `src/finetune.py` (`train_one_epoch`) against `ZeroDivisionError` when `int(save_freq * len(data_loader))` truncates to 0 (e.g. small dataset on many GPUs). Intra-epoch checkpoint saving is now skipped gracefully in this case.
|
| 19 |
+
- Adjusted the dataset paths to match my local machine.
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/README.md
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
<h1>SLAM-Former Submap Release Guide</h1>
|
| 3 |
+
<h3>GitHub handoff, FSDP training, and full-sequence TUM runbook</h3>
|
| 4 |
+
</div>
|
| 5 |
+
|
| 6 |
+
## What this upload is for
|
| 7 |
+
|
| 8 |
+
This branch is the submap-oriented training and inference release prepared for GitHub handoff.
|
| 9 |
+
|
| 10 |
+
The release is built around the H20 launchers and the current submap memory path, with the following fixed requirements:
|
| 11 |
+
|
| 12 |
+
- **Distributed strategy**: `FSDP`
|
| 13 |
+
- **Submap length**: `submap_size=12`
|
| 14 |
+
- **Descriptor source**: `frontend`
|
| 15 |
+
- **Historical token fetch source**: `frontend`
|
| 16 |
+
- **External epoch control**: pass `EPOCHS` as an environment variable at launch time instead of editing the script
|
| 17 |
+
- **Comparison modes**:
|
| 18 |
+
- `submap-only`
|
| 19 |
+
- `backend + submap joint training with detached frontend tokens`
|
| 20 |
+
|
| 21 |
+
If you want the original upstream project README, see `README_ori.md`.
|
| 22 |
+
If you want more implementation detail on the submap system, see `README_submap.md` and `submap_handoff.md`.
|
| 23 |
+
|
| 24 |
+
## Clone the published branch
|
| 25 |
+
|
| 26 |
+
```bash
|
| 27 |
+
git clone -b submap https://github.com/SlamMate/e2e-semantic-SLAM.git
|
| 28 |
+
cd e2e-semantic-SLAM
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
## Environment
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
conda env create -f environment.yml
|
| 35 |
+
conda activate SLAM-Former
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
The exported `environment.yml` is a full snapshot of the current Conda environment, including CUDA-related Python packages. If you want to move it between machines, you can keep the file as-is or delete the final `prefix:` line for a cleaner portable export.
|
| 39 |
+
|
| 40 |
+
If you prefer the original manual install flow:
|
| 41 |
+
|
| 42 |
+
```bash
|
| 43 |
+
conda create -n SLAM-Former python=3.11
|
| 44 |
+
conda activate SLAM-Former
|
| 45 |
+
|
| 46 |
+
pip install -r requirements.txt
|
| 47 |
+
pip install -e .
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
## Paths and launch parameters to set on a new cluster
|
| 51 |
+
|
| 52 |
+
| Variable | Meaning | Typical example |
|
| 53 |
+
| --- | --- | --- |
|
| 54 |
+
| `PROJECT_DIR` | repository root | `/path/to/e2e-semantic-SLAM` |
|
| 55 |
+
| `DATA_ROOT` | training data root | `/path/to/data/train` |
|
| 56 |
+
| `PRETRAINED` | pretrained SLAM-Former checkpoint | `/path/to/ckpt/checkpoint-10.pth.model` |
|
| 57 |
+
| `SAVE_DIR` | checkpoint output root | `/path/to/checkpoints` |
|
| 58 |
+
| `CONDA_SH` | conda init script | `/path/to/miniconda3/etc/profile.d/conda.sh` |
|
| 59 |
+
| `CONDA_ENV_NAME` | conda env name | `SLAM-Former` |
|
| 60 |
+
| `MASTER_ADDR` | rank-0 node hostname or IP for multi-node bash launch | `node0` |
|
| 61 |
+
| `MACHINE_RANK` | machine rank for multi-node bash launch | `0..7` |
|
| 62 |
+
| `NUM_MACHINES` | machine count for remote long run | `8` |
|
| 63 |
+
| `GPUS_PER_NODE` | visible GPUs per node | `8` |
|
| 64 |
+
| `MASTER_PORT` | communication port | `29661` / `29671` etc. |
|
| 65 |
+
|
| 66 |
+
For migration to a new machine, the entries above are the first ones to review. If the data layout is unchanged and each dataset still lives under `DATA_ROOT/<dataset>`, you can usually keep the `ROOT_*` defaults.
|
| 67 |
+
|
| 68 |
+
The default data layout expected by the launchers is:
|
| 69 |
+
|
| 70 |
+
- `processed_arkitscenes`
|
| 71 |
+
- `processed_scannetpp`
|
| 72 |
+
- `processed_scannet` or `processed_scannetv2`
|
| 73 |
+
- `hypersim`
|
| 74 |
+
- `processed_blendedmvs`
|
| 75 |
+
- `processed_megadepth`
|
| 76 |
+
- `processed_mvs_synth`
|
| 77 |
+
|
| 78 |
+
## Fixed training invariants in this release
|
| 79 |
+
|
| 80 |
+
- `CONFIG_NAME=finetune_paper_h20.yaml`
|
| 81 |
+
- `DIST_STRATEGY=fsdp`
|
| 82 |
+
- `SUBMAP_SIZE=12`
|
| 83 |
+
- `SUBMAP_TRAIN_MODE=full_token`
|
| 84 |
+
- `SUBMAP_RETRIEVAL_TOPK=0`
|
| 85 |
+
- `SUBMAP_FETCH_SOURCE=frontend`
|
| 86 |
+
- `SUBMAP_DESCRIPTOR_SOURCE=frontend`
|
| 87 |
+
- `freeze_encoder=true` stays inherited from `config/finetune_paper_h20.yaml`
|
| 88 |
+
- `DETACH_FRONTEND_TOKENS=1` in both comparison modes
|
| 89 |
+
- `TRAIN_SUBMAP_MODULES_ONLY=1` for strict submap-only training
|
| 90 |
+
- `TRAIN_SUBMAP_MODULES_ONLY=0` for backend + submap joint training while frontend tokens stay detached
|
| 91 |
+
|
| 92 |
+
## `num_views` reference table
|
| 93 |
+
|
| 94 |
+
The remote 8-node scripts use aggressive long-sequence defaults.
|
| 95 |
+
These are **reference upper settings**, not no-skip guarantees: scenes shorter than the configured value may be skipped.
|
| 96 |
+
For ARKitScenes, the local audit shows a very small strict no-skip cap because of a few short scenes, but much longer clips are still possible if short scenes are skipped.
|
| 97 |
+
|
| 98 |
+
| Dataset | Remote 8x8 default | StrictCapNoSkip | MedianCap | MaxCap | Notes |
|
| 99 |
+
| --- | ---: | ---: | ---: | ---: | --- |
|
| 100 |
+
| ARKitScenes | 478 | 2 | 92 | 478 | strict cap is dominated by short scenes; long clips work with scene skipping |
|
| 101 |
+
| ScanNet++ | 150 | 45 | 143 | 150 | local processed data already supports long clips |
|
| 102 |
+
| ScanNet | 64 | N/A | N/A | N/A | local audit unavailable; rerun audit on target cluster if needed |
|
| 103 |
+
| HyperSim | 64 | N/A | N/A | N/A | local audit unavailable |
|
| 104 |
+
| BlendedMVS | 64 | N/A | N/A | N/A | local audit unavailable |
|
| 105 |
+
| MegaDepth | 64 | N/A | N/A | 64 | loader-side hard cap is 64 |
|
| 106 |
+
| MVS-Synth | 100 | 69 | 100 | 100 | long clips are supported on all local scenes |
|
| 107 |
+
|
| 108 |
+
If your target cluster has different processed data, rerun the audit (`slam/audit_dataset_num_views.py`) and override `NUM_VIEWS_*` as needed.
|
| 109 |
+
|
| 110 |
+
## Scripts provided in this release
|
| 111 |
+
|
| 112 |
+
| Script | Launch mode | Purpose |
|
| 113 |
+
| --- | --- | --- |
|
| 114 |
+
| `slam/sbatch_smoke_submap_only_fsdp_2gpu.sh` | local `sbatch` | 1-node 2-GPU smoke validation for strict submap-only mode |
|
| 115 |
+
| `slam/sbatch_smoke_joint_freeze_frontend_fsdp_2gpu.sh` | local `sbatch` | 1-node 2-GPU smoke validation for backend + submap with detached frontend |
|
| 116 |
+
| `slam/train_remote_submap_only_fsdp_8node8gpu.sh` | remote `bash` | 8-node 8-GPU-per-node long-sequence submap-only training |
|
| 117 |
+
| `slam/train_remote_joint_freeze_frontend_fsdp_8node8gpu.sh` | remote `bash` | 8-node 8-GPU-per-node long-sequence backend + submap training with detached frontend |
|
| 118 |
+
|
| 119 |
+
## Migration guide: run the smoke scripts on another machine
|
| 120 |
+
|
| 121 |
+
These two smoke launchers are self-contained. They no longer rely on the old wrapper scripts. Each launcher:
|
| 122 |
+
|
| 123 |
+
- resolves `PROJECT_DIR` from an explicitly set `PROJECT_DIR` environment variable, from `SLURM_SUBMIT_DIR`, or from the script location;
|
| 124 |
+
- requires `CONDA_SH` to point to a valid `conda.sh` file and activates `CONDA_ENV_NAME` from that shell;
|
| 125 |
+
- loads `cuda12.1/toolkit` when the `module` command is available;
|
| 126 |
+
- sets `PYTHONPATH` to include `src/`;
|
| 127 |
+
- launches `src/finetune.py` through `accelerate launch`.
|
| 128 |
+
|
| 129 |
+
### 1. Minimum checklist before the first run
|
| 130 |
+
|
| 131 |
+
1. Clone or copy the repository onto the new machine.
|
| 132 |
+
2. Create the environment with `conda env create -f environment.yml`, and make sure `CONDA_SH` points at the correct `conda.sh` on the new machine.
|
| 133 |
+
3. Make sure the pretrained checkpoint exists at `PRETRAINED` or override `PRETRAINED=...`.
|
| 134 |
+
4. Point `DATA_ROOT` to the processed training data on the new machine.
|
| 135 |
+
5. Verify the dataset roots under `DATA_ROOT`, or override the individual `ROOT_*` variables.
|
| 136 |
+
6. Check that the new machine has a compatible CUDA setup. If the cluster uses a different module name, edit the `module load cuda12.1/toolkit` line.
|
| 137 |
+
7. If you submit with Slurm, keep the `#SBATCH` resource requests consistent with the machine’s GPU, CPU, and memory limits.
|
| 138 |
+
|
| 139 |
+
### 2. Variables you usually need to change
|
| 140 |
+
|
| 141 |
+
| Variable | Meaning | Typical reason to change |
|
| 142 |
+
| --- | --- | --- |
|
| 143 |
+
| `PROJECT_DIR` | repository root used for `PYTHONPATH`, output paths, and the default checkpoint path | the repo lives in a different directory |
|
| 144 |
+
| `CONDA_SH` | path to `conda.sh` used to initialize Conda | the machine uses a different Miniconda install, or the default path does not exist |
|
| 145 |
+
| `CONDA_ENV_NAME` | environment name to activate | you created the env under a different name |
|
| 146 |
+
| `DATA_ROOT` | top-level directory that contains the processed datasets | the training data is mounted elsewhere |
|
| 147 |
+
| `ROOT_ARKIT`, `ROOT_SCANNETPP`, `ROOT_SCANNET`, `ROOT_SCANNET_FALLBACK`, `ROOT_HYPERSIM`, `ROOT_BLENDEDMVS`, `ROOT_MEGADEPTH`, `ROOT_MVS_SYNTH` | per-dataset roots used by the launchers | the dataset folders do not live directly under `DATA_ROOT` or ScanNet needs a fallback root |
|
| 148 |
+
| `PRETRAINED` | checkpoint loaded before training starts | the pretrained model is stored somewhere else |
|
| 149 |
+
| `SAVE_DIR` | root directory for checkpoints and logs | you want outputs on a different disk |
|
| 150 |
+
| `MASTER_PORT` | port used by `accelerate` to rendezvous the worker processes | another job is already using the default port |
|
| 151 |
+
| `NUM_GPUS` | number of processes / GPUs launched | the target node exposes a different GPU count |
|
| 152 |
+
| `AUTO_DISABLE_MISSING` | auto-disable a dataset whose root is missing or incomplete | set to `0` if you want the job to fail fast instead of silently skipping data |
|
| 153 |
+
| `EXPERIMENT_ROOT`, `VARIANT_NAME`, `EXP_NAME` | folder and experiment naming used for outputs | you want a different output namespace or to avoid collisions with old runs |
|
| 154 |
+
| `RESUME` | checkpoint path to resume from | you are continuing an interrupted run |
|
| 155 |
+
|
| 156 |
+
The following Slurm header lines also need cluster-specific tuning:
|
| 157 |
+
|
| 158 |
+
- `#SBATCH --gres=gpu:2` requests two GPUs.
|
| 159 |
+
- `#SBATCH --cpus-per-task` controls CPU cores.
|
| 160 |
+
- `#SBATCH --mem` controls RAM.
|
| 161 |
+
- `#SBATCH --time` is the wall-time limit.
|
| 162 |
+
|
| 163 |
+
### 3. Variables that control the training recipe
|
| 164 |
+
|
| 165 |
+
| Variable | Meaning | Notes |
|
| 166 |
+
| --- | --- | --- |
|
| 167 |
+
| `CONFIG_NAME` | Hydra config file passed to `src/finetune.py` | both smoke scripts use `finetune_paper_h20.yaml` |
|
| 168 |
+
| `DIST_STRATEGY` | distributed backend (`fsdp` or `ddp`) | both smoke scripts use `fsdp` |
|
| 169 |
+
| `TRAIN_SUBMAP_MODULES_ONLY` | `1` = strict submap-only training; `0` = joint backend+submap training | `1` for `slam/sbatch_smoke_submap_only_fsdp_2gpu.sh`, `0` for `slam/sbatch_smoke_joint_freeze_frontend_fsdp_2gpu.sh` |
|
| 170 |
+
| `DETACH_FRONTEND_TOKENS` | detach frontend tokens from gradients | `1` in both scripts |
|
| 171 |
+
| `SUBMAP_SIZE` | number of frames / tokens per submap | default `12` in both scripts |
|
| 172 |
+
| `SUBMAP_TRAIN_MODE` | submap training mode | default `full_token` |
|
| 173 |
+
| `SUBMAP_RETRIEVAL_TOPK` | retrieval top-k setting | default `0` disables retrieval |
|
| 174 |
+
| `SUBMAP_FETCH_SOURCE` | source used to fetch submap features | default `frontend` |
|
| 175 |
+
| `SUBMAP_DESCRIPTOR_SOURCE` | source used to build submap descriptors | default `frontend` |
|
| 176 |
+
| `ENABLE_PSEUDO_GT` | enable pseudo-GT cache usage | keep `0` unless you have a valid cache |
|
| 177 |
+
| `PSEUDO_GT_CACHE_PATH` | path to the pseudo-GT cache | required when `ENABLE_PSEUDO_GT=1` |
|
| 178 |
+
| `EPOCHS` | number of epochs passed to `src/finetune.py` | smoke scripts default to `2` |
|
| 179 |
+
|
| 180 |
+
### 4. Dataset-mixture knobs
|
| 181 |
+
|
| 182 |
+
| Variable | Meaning | Notes |
|
| 183 |
+
| --- | --- | --- |
|
| 184 |
+
| `SAMPLES_ARKIT`, `SAMPLES_SCANNETPP`, `SAMPLES_SCANNET`, `SAMPLES_HYPERSIM`, `SAMPLES_BLENDEDMVS`, `SAMPLES_MEGADEPTH`, `SAMPLES_MVS_SYNTH` | per-dataset sampling weights in the training mixture | increase or decrease them to rebalance the dataset mix |
|
| 185 |
+
| `NUM_VIEWS_ARKIT`, `NUM_VIEWS_SCANNETPP`, `NUM_VIEWS_SCANNET`, `NUM_VIEWS_HYPERSIM`, `NUM_VIEWS_BLENDEDMVS`, `NUM_VIEWS_MEGADEPTH`, `NUM_VIEWS_MVS_SYNTH` | per-dataset view caps | change these if the processed data on the new machine has different sequence lengths or hard caps |
|
| 186 |
+
| `GLOBAL_NUM_VIEWS` | optional global cap; if unset, the scripts derive it from the active datasets’ `NUM_VIEWS_*` values | set it when you want a single global value for all datasets |
|
| 187 |
+
| `NUM_VIEWS_ALL` | compatibility placeholder kept by the scripts | usually leave it at the default; the launcher mainly uses the per-dataset values above |
|
| 188 |
+
|
| 189 |
+
### 5. The two scripts differ only in these defaults
|
| 190 |
+
|
| 191 |
+
| Item | Submap-only script | Joint + frozen frontend script | Meaning |
|
| 192 |
+
| --- | --- | --- | --- |
|
| 193 |
+
| `TRAIN_SUBMAP_MODULES_ONLY` | `1` | `0` | whether to train only the submap modules or the joint backend + submap stack |
|
| 194 |
+
| `MASTER_PORT` | `29661` | `29662` | keep the two smoke jobs from colliding on the same node |
|
| 195 |
+
| `VARIANT_NAME` | `submap_only_fsdp_sub12` | `joint_freeze_frontend_fsdp_sub12` | output subdirectory name under `SAVE_DIR` |
|
| 196 |
+
| `EXP_NAME` | `paper_smoke_submap_only_fsdp_2gpu_sub12` | `paper_smoke_joint_freeze_frontend_fsdp_2gpu_sub12` | experiment name written into logs and Hydra config |
|
| 197 |
+
| `#SBATCH --cpus-per-task` | `24` | `12` | CPU reservation for the smoke job |
|
| 198 |
+
| `#SBATCH --mem` | `120G` | `24G` | memory reservation for the smoke job |
|
| 199 |
+
|
| 200 |
+
### 6. Direct launch examples on a new machine
|
| 201 |
+
|
| 202 |
+
To run on another Slurm machine, set the machine-specific variables inline and submit the launcher from the repo root.
|
| 203 |
+
|
| 204 |
+
Submap-only smoke:
|
| 205 |
+
|
| 206 |
+
```bash
|
| 207 |
+
PROJECT_DIR=/path/to/e2e-semantic-SLAM \
|
| 208 |
+
CONDA_SH=/path/to/miniconda3/etc/profile.d/conda.sh \
|
| 209 |
+
CONDA_ENV_NAME=SLAM-Former \
|
| 210 |
+
DATA_ROOT=/path/to/data/train \
|
| 211 |
+
PRETRAINED=/path/to/ckpt/checkpoint-10.pth.model \
|
| 212 |
+
SAVE_DIR=/path/to/checkpoints \
|
| 213 |
+
MASTER_PORT=29661 \
|
| 214 |
+
sbatch slam/sbatch_smoke_submap_only_fsdp_2gpu.sh
|
| 215 |
+
```
|
| 216 |
+
|
| 217 |
+
Joint + frozen frontend smoke:
|
| 218 |
+
|
| 219 |
+
```bash
|
| 220 |
+
PROJECT_DIR=/path/to/e2e-semantic-SLAM \
|
| 221 |
+
CONDA_SH=/path/to/miniconda3/etc/profile.d/conda.sh \
|
| 222 |
+
CONDA_ENV_NAME=SLAM-Former \
|
| 223 |
+
DATA_ROOT=/path/to/data/train \
|
| 224 |
+
PRETRAINED=/path/to/ckpt/checkpoint-10.pth.model \
|
| 225 |
+
SAVE_DIR=/path/to/checkpoints \
|
| 226 |
+
MASTER_PORT=29662 \
|
| 227 |
+
sbatch slam/sbatch_smoke_joint_freeze_frontend_fsdp_2gpu.sh
|
| 228 |
+
```
|
| 229 |
+
|
| 230 |
+
If the new machine does not have Slurm, you can still run the scripts with `bash ...` as long as two GPUs are visible to the shell and `accelerate` / CUDA are available; the `#SBATCH` lines are then ignored by Bash.
|
| 231 |
+
|
| 232 |
+
## Local smoke validation on the current cluster
|
| 233 |
+
|
| 234 |
+
The two local `sbatch` scripts are intended to validate:
|
| 235 |
+
|
| 236 |
+
- the launcher path
|
| 237 |
+
- the FSDP wiring
|
| 238 |
+
- the README instructions
|
| 239 |
+
- the two comparison modes
|
| 240 |
+
|
| 241 |
+
They intentionally use a small sample budget and `64` views so they can be checked quickly on 1 node and 2 GPUs.
|
| 242 |
+
|
| 243 |
+
### 1. Submap-only smoke
|
| 244 |
+
|
| 245 |
+
```bash
|
| 246 |
+
EPOCHS=1 \
|
| 247 |
+
bash slam/sbatch_smoke_submap_only_fsdp_2gpu.sh
|
| 248 |
+
```
|
| 249 |
+
|
| 250 |
+
### 2. Backend + submap smoke with detached frontend
|
| 251 |
+
|
| 252 |
+
```bash
|
| 253 |
+
EPOCHS=1 \
|
| 254 |
+
bash slam/sbatch_smoke_joint_freeze_frontend_fsdp_2gpu.sh
|
| 255 |
+
```
|
| 256 |
+
|
| 257 |
+
Local smoke defaults:
|
| 258 |
+
|
| 259 |
+
- `SAMPLES_ARKIT=32`
|
| 260 |
+
- `SAMPLES_SCANNETPP=16`
|
| 261 |
+
- all other local smoke sample weights default to `0`
|
| 262 |
+
- `NUM_VIEWS_* = 64`
|
| 263 |
+
- `SUBMAP_SIZE = 12`
|
| 264 |
+
|
| 265 |
+
This keeps the smoke validation focused on code path correctness rather than final throughput.
|
| 266 |
+
|
| 267 |
+
## Remote long-sequence training on 8 nodes x 8 GPUs
|
| 268 |
+
|
| 269 |
+
The remote scripts are the actual release launchers for long-sequence comparison.
|
| 270 |
+
Run the **same script on every node**, and only change `MACHINE_RANK`.
|
| 271 |
+
|
| 272 |
+
Remote defaults:
|
| 273 |
+
|
| 274 |
+
- `AUTO_DISABLE_MISSING=0`
|
| 275 |
+
- full paper dataset sample mix from `finetune_paper_h20.yaml`
|
| 276 |
+
- aggressive per-dataset `NUM_VIEWS_*` values from the table above
|
| 277 |
+
- `SUBMAP_SIZE=12`
|
| 278 |
+
- `SUBMAP_FETCH_SOURCE=frontend`
|
| 279 |
+
- `SUBMAP_DESCRIPTOR_SOURCE=frontend`
|
| 280 |
+
|
| 281 |
+
### 1. Remote 8x8 submap-only run
|
| 282 |
+
|
| 283 |
+
On rank 0:
|
| 284 |
+
|
| 285 |
+
```bash
|
| 286 |
+
MASTER_ADDR=node0 \
|
| 287 |
+
MACHINE_RANK=0 \
|
| 288 |
+
NUM_MACHINES=8 \
|
| 289 |
+
GPUS_PER_NODE=8 \
|
| 290 |
+
DATA_ROOT=/path/to/data/train \
|
| 291 |
+
SAVE_DIR=/path/to/checkpoints \
|
| 292 |
+
PRETRAINED=/path/to/ckpt/checkpoint-10.pth.model \
|
| 293 |
+
CONDA_SH=/path/to/miniconda3/etc/profile.d/conda.sh \
|
| 294 |
+
CONDA_ENV_NAME=SLAM-Former \
|
| 295 |
+
EPOCHS=10 \
|
| 296 |
+
bash slam/train_remote_submap_only_fsdp_8node8gpu.sh
|
| 297 |
+
```
|
| 298 |
+
|
| 299 |
+
On the remaining nodes, rerun the same command with `MACHINE_RANK=1` through `7`.
|
| 300 |
+
|
| 301 |
+
### 2. Remote 8x8 backend + submap run with detached frontend
|
| 302 |
+
|
| 303 |
+
On rank 0:
|
| 304 |
+
|
| 305 |
+
```bash
|
| 306 |
+
MASTER_ADDR=node0 \
|
| 307 |
+
MACHINE_RANK=0 \
|
| 308 |
+
NUM_MACHINES=8 \
|
| 309 |
+
GPUS_PER_NODE=8 \
|
| 310 |
+
DATA_ROOT=/path/to/data/train \
|
| 311 |
+
SAVE_DIR=/path/to/checkpoints \
|
| 312 |
+
PRETRAINED=/path/to/ckpt/checkpoint-10.pth.model \
|
| 313 |
+
CONDA_SH=/path/to/miniconda3/etc/profile.d/conda.sh \
|
| 314 |
+
CONDA_ENV_NAME=SLAM-Former \
|
| 315 |
+
EPOCHS=10 \
|
| 316 |
+
bash slam/train_remote_joint_freeze_frontend_fsdp_8node8gpu.sh
|
| 317 |
+
```
|
| 318 |
+
|
| 319 |
+
On the remaining nodes, rerun the same command with `MACHINE_RANK=1` through `7`.
|
| 320 |
+
|
| 321 |
+
## Whole-sequence TUM inference after training
|
| 322 |
+
|
| 323 |
+
After training, use the saved checkpoint to run the full Freiburg1 TUM sequences.
|
| 324 |
+
|
| 325 |
+
```bash
|
| 326 |
+
CKPT_PATH=/path/to/checkpoint-last.pth \
|
| 327 |
+
RUN_TAG=release_eval \
|
| 328 |
+
SUBMAP_INFERENCE_MODE=full \
|
| 329 |
+
SUBMAP_TRAIN_MODE=full_token \
|
| 330 |
+
SUBMAP_FETCH_SOURCE=frontend \
|
| 331 |
+
SUBMAP_DESCRIPTOR_SOURCE=frontend \
|
| 332 |
+
sbatch run_tum.sh
|
| 333 |
+
```
|
| 334 |
+
|
| 335 |
+
This writes outputs to:
|
| 336 |
+
|
| 337 |
+
```bash
|
| 338 |
+
tum_results_aligned/<RUN_TAG>/rgbd_dataset_freiburg1_*/
|
| 339 |
+
```
|
| 340 |
+
|
| 341 |
+
Each sequence directory contains at least:
|
| 342 |
+
|
| 343 |
+
- `final_traj.txt`
|
| 344 |
+
- `final.ply`
|
| 345 |
+
- `final_pc/`
|
| 346 |
+
|
| 347 |
+
To evaluate ATE after the sequence finishes:
|
| 348 |
+
|
| 349 |
+
```bash
|
| 350 |
+
evo_ape tum <ground_truth.txt> <final_traj.txt> -a --t_max_diff 0.02
|
| 351 |
+
```
|
| 352 |
+
|
| 353 |
+
## Output layout
|
| 354 |
+
|
| 355 |
+
Training outputs are written under:
|
| 356 |
+
|
| 357 |
+
```bash
|
| 358 |
+
$SAVE_DIR/$EXPERIMENT_ROOT/$VARIANT_NAME/$EXP_NAME/
|
| 359 |
+
```
|
| 360 |
+
|
| 361 |
+
Typical files include:
|
| 362 |
+
|
| 363 |
+
- `checkpoint-last.pth`
|
| 364 |
+
- `model.pth`
|
| 365 |
+
- `logs/`
|
| 366 |
+
- launcher stdout / stderr logs
|
| 367 |
+
|
| 368 |
+
## Practical notes
|
| 369 |
+
|
| 370 |
+
- The local `sbatch` scripts are smoke validators.
|
| 371 |
+
- The remote `bash` scripts are the actual long-sequence release launchers.
|
| 372 |
+
- If the target cluster has a different sequence-cap profile, rerun `slam/audit_dataset_num_views.py` there and override `NUM_VIEWS_*`.
|
| 373 |
+
- If a remote run fails on missing dataset roots, keep `AUTO_DISABLE_MISSING=0` and fix the paths instead of silently training on a partial dataset mix.
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/README_submap.md
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
<h1>SLAM-Former Submap Companion / 子图系统说明</h1>
|
| 3 |
+
</div>
|
| 4 |
+
|
| 5 |
+
### Updates
|
| 6 |
+
* [Mar 30, 2026] Added a submap-native training and inference path with switchable `full_token` / `top5_dual_queue` modes.
|
| 7 |
+
* [Mar 30, 2026] Added dedicated TUM top5 inference support and a standalone `run_tum_top5.sh` launcher.
|
| 8 |
+
* [Mar 26, 2026] Submap-only pseudo-GT training was exercised with the high-recall config and FSDP launchers.
|
| 9 |
+
|
| 10 |
+
### Getting Started
|
| 11 |
+
|
| 12 |
+
The environment setup is the same as in the original upstream README (kept as `README_ori.md` in this release):
|
| 13 |
+
|
| 14 |
+
```bash
|
| 15 |
+
git clone https://github.com/Tsinghua-MARS-Lab/SLAM-Former.git
|
| 16 |
+
cd SLAM-Former
|
| 17 |
+
|
| 18 |
+
conda create -n SLAM-Former python=3.11
|
| 19 |
+
conda activate SLAM-Former
|
| 20 |
+
|
| 21 |
+
pip install -r requirements.txt
|
| 22 |
+
pip install -e .
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
### Submap System Overview
|
| 26 |
+
|
| 27 |
+
This companion branch keeps the native SLAM-Former pipeline as close as possible to the original implementation, while adding a submap-oriented memory backend and switchable queue semantics.
|
| 28 |
+
|
| 29 |
+
Core ideas:
|
| 30 |
+
|
| 31 |
+
* **`GraphGatedMemoryManager`** stores historical submaps on CPU, keeps descriptors, and performs loop retrieval.
|
| 32 |
+
* **`slam/demo_submap.py`** is the submap-aware inference entrypoint.
|
| 33 |
+
* **`src/forward_pass.py`** and **`src/forward_pass_submap.py`** slice the backend output so supervision is applied to the current submap while still letting the backend see `prev + curr + retrieved` context.
|
| 34 |
+
* **`src/finetune.py`** and the H20 launchers expose the submap configuration through script variables.
|
| 35 |
+
* **`run_tum.sh`** and **`run_tum_top5.sh`** drive TUM inference with either the full-token path or the top5 path.
|
| 36 |
+
|
| 37 |
+
Important submap switches:
|
| 38 |
+
|
| 39 |
+
* `TRAIN_SUBMAP_MODULES_ONLY`
|
| 40 |
+
* `1` freezes the main SLAMFormer parameters and trains the submap-side modules.
|
| 41 |
+
* `0` keeps the joint training path available.
|
| 42 |
+
* `SUBMAP_TRAIN_MODE`
|
| 43 |
+
* `full_token`: keep the submap path close to the native full-token behavior.
|
| 44 |
+
* `top5_dual_queue`: use two queues, where the frontend queue is read-only and the backend queue receives write-back for retrieved historical submaps.
|
| 45 |
+
* `SUBMAP_RETRIEVAL_TOPK`
|
| 46 |
+
* number of historical submaps fetched in the soft retrieval mode.
|
| 47 |
+
* `SUBMAP_FETCH_SOURCE` / `SUBMAP_DESCRIPTOR_SOURCE`
|
| 48 |
+
* choose whether retrieval and descriptor storage read from `frontend` or `backend` banks.
|
| 49 |
+
* `SUBMAP_INFERENCE_MODE`
|
| 50 |
+
* `full` or `top5` for the TUM launch scripts.
|
| 51 |
+
|
| 52 |
+
### Training Modes
|
| 53 |
+
|
| 54 |
+
There are three common training setups in this branch:
|
| 55 |
+
|
| 56 |
+
#### 1. Joint baseline
|
| 57 |
+
|
| 58 |
+
Use the original joint configuration when you want the closest comparison to the official training branch:
|
| 59 |
+
|
| 60 |
+
* `config/finetune.yaml`
|
| 61 |
+
* `TRAIN_SUBMAP_MODULES_ONLY=0`
|
| 62 |
+
* `SUBMAP_TRAIN_MODE=full_token`
|
| 63 |
+
* `SUBMAP_RETRIEVAL_TOPK=0`
|
| 64 |
+
|
| 65 |
+
#### 2. Submap-only full-token training
|
| 66 |
+
|
| 67 |
+
This is the first submap stage: descriptors and historical submaps are trained with full-token submaps so historical submaps can still receive gradients.
|
| 68 |
+
|
| 69 |
+
Recommended config:
|
| 70 |
+
|
| 71 |
+
* `config/finetune_sub_only.yaml`
|
| 72 |
+
* or `config/finetune_pseudo_gt_high_recall.yaml`
|
| 73 |
+
|
| 74 |
+
Typical launch knobs:
|
| 75 |
+
|
| 76 |
+
* `TRAIN_SUBMAP_MODULES_ONLY=1`
|
| 77 |
+
* `SUBMAP_TRAIN_MODE=full_token`
|
| 78 |
+
* `SUBMAP_RETRIEVAL_TOPK=0`
|
| 79 |
+
|
| 80 |
+
Example:
|
| 81 |
+
|
| 82 |
+
```bash
|
| 83 |
+
ENABLE_PSEUDO_GT=1 \
|
| 84 |
+
PSEUDO_GT_CACHE_PATH=/var/scratch/qzhang2/SLAM-Former/data/train/pseudo_gt/arkitscenes_smoke_test.json \
|
| 85 |
+
CONFIG_NAME=finetune_pseudo_gt_high_recall.yaml \
|
| 86 |
+
TRAIN_SUBMAP_MODULES_ONLY=1 \
|
| 87 |
+
SUBMAP_TRAIN_MODE=full_token \
|
| 88 |
+
SUBMAP_RETRIEVAL_TOPK=0 \
|
| 89 |
+
sbatch slam/sbatch_finetune.sh
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
#### 3. Submap top5 fine-tuning
|
| 93 |
+
|
| 94 |
+
This is the second stage: the backend sees the top5 historical submaps, and the system uses the dual-queue semantics.
|
| 95 |
+
|
| 96 |
+
Typical launch knobs:
|
| 97 |
+
|
| 98 |
+
* `TRAIN_SUBMAP_MODULES_ONLY=1`
|
| 99 |
+
* `SUBMAP_TRAIN_MODE=top5_dual_queue`
|
| 100 |
+
* `SUBMAP_RETRIEVAL_TOPK=5`
|
| 101 |
+
* `SUBMAP_FETCH_SOURCE=frontend`
|
| 102 |
+
* `SUBMAP_DESCRIPTOR_SOURCE=frontend`
|
| 103 |
+
|
| 104 |
+
Example:
|
| 105 |
+
|
| 106 |
+
```bash
|
| 107 |
+
ENABLE_PSEUDO_GT=1 \
|
| 108 |
+
PSEUDO_GT_CACHE_PATH=/var/scratch/qzhang2/SLAM-Former/data/train/pseudo_gt/arkitscenes_smoke_test.json \
|
| 109 |
+
CONFIG_NAME=finetune_pseudo_gt_high_recall.yaml \
|
| 110 |
+
TRAIN_SUBMAP_MODULES_ONLY=1 \
|
| 111 |
+
SUBMAP_TRAIN_MODE=top5_dual_queue \
|
| 112 |
+
SUBMAP_RETRIEVAL_TOPK=5 \
|
| 113 |
+
sbatch slam/sbatch_finetune.sh
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
### Inference Modes
|
| 117 |
+
|
| 118 |
+
#### Official baseline comparison
|
| 119 |
+
|
| 120 |
+
For the native baseline, keep using the original demo path from the upstream README (kept as `README_ori.md` in this release):
|
| 121 |
+
|
| 122 |
+
```bash
|
| 123 |
+
python slam/demo.py \
|
| 124 |
+
--ckpt_path ckpt/checkpoint.pth.model \
|
| 125 |
+
--image_folder /path/to/your/images/ \
|
| 126 |
+
--output_dir /output/result \
|
| 127 |
+
--target_size 518 \
|
| 128 |
+
--retention_ratio 0.5
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
#### Submap TUM inference: full mode
|
| 132 |
+
|
| 133 |
+
This path uses the submap-aware demo but keeps the inference behavior in `full` mode.
|
| 134 |
+
|
| 135 |
+
```bash
|
| 136 |
+
SUBMAP_INFERENCE_MODE=full \
|
| 137 |
+
CKPT_PATH=/var/scratch/qzhang2/SLAM-Former/ckpt/checkpoint-10.pth.model \
|
| 138 |
+
sbatch run_tum.sh
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
#### Submap TUM inference: top5 mode
|
| 142 |
+
|
| 143 |
+
This is the dedicated top5 launcher for comparing against the full mode.
|
| 144 |
+
|
| 145 |
+
```bash
|
| 146 |
+
sbatch run_tum_top5.sh
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
You can also override the checkpoint and output root explicitly:
|
| 150 |
+
|
| 151 |
+
```bash
|
| 152 |
+
CKPT_PATH=/var/scratch/qzhang2/SLAM-Former/checkpoints/local_cluster_nv24_sub6/submap_only_pseudo_gt_high_recall_smoke/paper_local_submap_only_pseudo_gt_high_recall_smoke_nv24_sub6/checkpoint-last.pth \
|
| 153 |
+
SUBMAP_INFERENCE_MODE=top5 \
|
| 154 |
+
RUN_TAG=my_top5_compare \
|
| 155 |
+
OUT_DIR=/var/scratch/qzhang2/SLAM-Former/tum_results_aligned_top5/my_top5_compare \
|
| 156 |
+
sbatch run_tum_top5.sh
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
### Launch Scripts
|
| 160 |
+
|
| 161 |
+
| Script | Purpose | Notes |
|
| 162 |
+
| --- | --- | --- |
|
| 163 |
+
| `slam/sbatch_finetune.sh` | Local 3-GPU FSDP finetune launcher | Uses environment variables to select config, submap mode, and pseudo-GT cache. |
|
| 164 |
+
| `slam/run_train_h20_single.sh` | Single-node H20-style training launcher | Good when you want to run the training job directly without the wrapper. |
|
| 165 |
+
| `slam/run_train_h20_multi.sh` | Multi-node H20-style training launcher | Keeps the same submap knobs, but launches across multiple machines. |
|
| 166 |
+
| `run_tum.sh` | TUM inference launcher | Supports `SUBMAP_INFERENCE_MODE=full` or `SUBMAP_INFERENCE_MODE=top5`. |
|
| 167 |
+
| `run_tum_top5.sh` | Dedicated TUM top5 launcher | Defaults to the latest submap checkpoint and top5 mode. |
|
| 168 |
+
| `slam/demo_submap.py` | Manual inference entrypoint | Accepts `--submap_train_mode`, `--submap_retrieval_topk`, `--loop_mask_mode`, `--submap_fetch_source`, `--submap_descriptor_source`, and `--max_recursive_submaps`. |
|
| 169 |
+
|
| 170 |
+
### Configuration Files
|
| 171 |
+
|
| 172 |
+
| Config | Role |
|
| 173 |
+
| --- | --- |
|
| 174 |
+
| `config/finetune.yaml` | Joint training baseline. |
|
| 175 |
+
| `config/finetune_sub_only.yaml` | Submap-only training with full-token semantics. |
|
| 176 |
+
| `config/finetune_pseudo_gt_high_recall.yaml` | Submap-only training with higher-recall pseudo-GT settings. |
|
| 177 |
+
|
| 178 |
+
### Data and Checkpoints
|
| 179 |
+
|
| 180 |
+
The data layout is the same as the original project. The TUM root used by the launch scripts is:
|
| 181 |
+
|
| 182 |
+
```bash
|
| 183 |
+
/var/scratch/qzhang2/Feature-SLAM/datasets/tum
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
Typical checkpoint locations are under:
|
| 187 |
+
|
| 188 |
+
```bash
|
| 189 |
+
/var/scratch/qzhang2/SLAM-Former/checkpoints/
|
| 190 |
+
```
|
| 191 |
+
|
| 192 |
+
For inference, the scripts usually read `checkpoint-last.pth` from the latest experiment folder.
|
| 193 |
+
|
| 194 |
+
### Output Layout
|
| 195 |
+
|
| 196 |
+
Submap TUM inference writes one folder per sequence, for example:
|
| 197 |
+
|
| 198 |
+
```bash
|
| 199 |
+
tum_results_aligned_top5/<run_tag>/rgbd_dataset_freiburg1_360/
|
| 200 |
+
```
|
| 201 |
+
|
| 202 |
+
Inside each sequence directory you should expect:
|
| 203 |
+
|
| 204 |
+
* `final_traj.txt`
|
| 205 |
+
* `final.ply`
|
| 206 |
+
* `final_pc/`
|
| 207 |
+
|
| 208 |
+
### Visualization
|
| 209 |
+
|
| 210 |
+
Static visualization is the same as the original branch:
|
| 211 |
+
|
| 212 |
+
```bash
|
| 213 |
+
python slam/visualize_results.py --result_dir /path/to/output_dir
|
| 214 |
+
```
|
| 215 |
+
|
| 216 |
+
For TUM results generated by the submap scripts, point `--result_dir` to the corresponding output folder.
|
| 217 |
+
|
| 218 |
+
### Notes
|
| 219 |
+
|
| 220 |
+
* Keep the original upstream README (`README_ori.md` in this release) as the official baseline reference.
|
| 221 |
+
* Use this file when you want to describe or run the submap branch.
|
| 222 |
+
* The main comparison axes are:
|
| 223 |
+
* official baseline vs submap branch
|
| 224 |
+
* full-token submap mode vs top5 dual-queue mode
|
| 225 |
+
* joint training vs submap-only training
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/base_opt.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from copy import deepcopy
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import roma
|
| 7 |
+
from copy import deepcopy
|
| 8 |
+
import tqdm
|
| 9 |
+
import os
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
|
| 12 |
+
from cloud_opt.utils import *
|
| 13 |
+
from cloud_opt.utils import _check_edges, _compute_img_conf
|
| 14 |
+
import cloud_opt.init_all as init_fun
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class BaseOptimizer(nn.Module):
    """Optimize a global scene, given graph-organized observations.

    Graph node: images
    Graph edges: observations = (pred1, pred2), pred2 is in pred1's coordinate

    NOTE(review): this class reads members it never defines (``device``,
    ``im_poses``, ``conf_i``/``conf_j``, ``pred_i``/``pred_j``,
    ``get_pw_poses``, ``get_adaptors``, ``get_pts3d``, ``get_im_poses``);
    presumably a subclass supplies them -- confirm.
    """

    def __init__(self, *args, **kwargs):
        # No set-up happens here: nn.Module.__init__ is invoked from
        # _init_from_views(), which performs the real initialisation.
        pass

    def _init_from_views(
        self,
        view1s,
        view2s,
        pred1s,
        pred2s,  # whatever predictions, they should be organized into pairwise for graph optimization
        dist="l1",
        conf="log",
        min_conf_thr=3,
        thr_for_init_conf=False,
        base_scale=0.5,
        allow_pw_adaptors=False,
        pw_break=20,
        rand_pose=torch.randn,
        empty_cache=False,
        verbose=True,
    ):
        """Initialise the optimizer state from pairwise views and predictions.

        Args:
            view1s, view2s: per-edge views; only ``view["idx"]`` is read here.
            pred1s, pred2s: per-edge predictions. The keys
                ``"pts3d_is_self_view"`` / ``"conf_self"`` are read from
                ``pred1s`` and ``"pts3d_in_other_view"`` / ``"conf"`` from
                ``pred2s``.
            dist: key into ``ALL_DISTS`` selecting the distance function.
            conf: mode handed to ``get_conf_trf``.
            min_conf_thr, thr_for_init_conf, base_scale, pw_break: stored
                unchanged on the instance.
            allow_pw_adaptors: whether ``pw_adaptors`` requires gradients.
            rand_pose: callable used to draw the initial ``pw_poses``.
            empty_cache, verbose: accepted but not stored by this method.
                NOTE(review): global_alignment_loop/global_alignment_iter read
                ``net.verbose`` / ``net.empty_cache`` -- confirm a subclass
                sets them.
        """
        super().__init__()
        self.edges = [
            (int(view1["idx"]), int(view2["idx"]))
            for view1, view2 in zip(view1s, view2s)
        ]
        self.dist = ALL_DISTS[dist]
        self.n_imgs = _check_edges(self.edges)

        # Per-edge inputs, registered as frozen parameters keyed by edge name.
        self.edge2pts_i = NoGradParamDict(
            {ij: pred1s[n]["pts3d_is_self_view"] for n, ij in enumerate(self.str_edges)}
        )  # ij: the name of the edge
        self.edge2pts_j = NoGradParamDict(
            {
                ij: pred2s[n]["pts3d_in_other_view"]
                for n, ij in enumerate(self.str_edges)
            }
        )
        self.edge2conf_i = NoGradParamDict(
            {ij: pred1s[n]["conf_self"] for n, ij in enumerate(self.str_edges)}
        )
        self.edge2conf_j = NoGradParamDict(
            {ij: pred2s[n]["conf"] for n, ij in enumerate(self.str_edges)}
        )

        self.imshapes = get_imshapes(self.edges, pred1s, pred2s)
        self.min_conf_thr = min_conf_thr
        self.thr_for_init_conf = thr_for_init_conf
        self.conf_trf = get_conf_trf(conf)

        self.im_conf = _compute_img_conf(
            self.imshapes, self.device, self.edges, self.edge2conf_i, self.edge2conf_j
        )
        # Confidence maps are inputs, never optimized.
        for conf_map in self.im_conf:
            conf_map.requires_grad = False

        # Keep an untouched copy of the initial confidence maps.
        self.init_conf_maps = [c.clone() for c in self.im_conf]

        self.base_scale = base_scale
        self.norm_pw_scale = True
        self.pw_break = pw_break
        self.POSE_DIM = 7
        self.pw_poses = nn.Parameter(
            rand_pose((self.n_edges, 1 + self.POSE_DIM))
        )  # pairwise poses
        self.pw_adaptors = nn.Parameter(
            torch.zeros((self.n_edges, 2))
        )  # slight xy/z adaptation
        self.pw_adaptors.requires_grad_(allow_pw_adaptors)
        self.has_im_poses = False
        self.rand_pose = rand_pose

    def get_known_poses(self):
        """Return ``(count, mask, poses)`` describing the frozen image poses.

        ``mask`` flags the entries of ``self.im_poses`` that do not require
        gradients. Returns ``(0, None, None)`` when ``has_im_poses`` is false.
        """
        if self.has_im_poses:
            known_poses_msk = torch.tensor(
                [not (p.requires_grad) for p in self.im_poses]
            )
            known_poses = self.get_im_poses()
            return known_poses_msk.sum(), known_poses_msk, known_poses
        else:
            return 0, None, None

    def get_pw_norm_scale_factor(self):
        """Return the factor that re-centres pairwise scales on ``base_scale``."""
        if self.norm_pw_scale:
            # normalize scales so that things cannot go south
            # we want that exp(scale) ~= self.base_scale
            return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp()
        else:
            return 1  # don't norm scale for known poses

    def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
        """Write a pose into ``poses[idx]`` in place and return that row.

        Row layout written here: ``[0:4]`` unit quaternion from
        ``roma.rotmat_to_unitquat``, ``[4:7]`` ``signed_log1p(T / scale)``,
        last entry ``log(scale)``.

        Args:
            poses: tensor/parameter holding one pose per row.
            idx: row to update.
            R: 3x3 rotation, a 4x4 matrix (then ``T`` must be None and is
                taken from the last column), or None to leave rotation as is.
            T: optional translation.
            scale: optional scale; requires ``poses.shape[-1]`` in (8, 13).
            force: update even if the row does not require gradients.

        Returns:
            The (possibly updated) row ``poses[idx]``.
        """
        # all poses == cam-to-world
        pose = poses[idx]
        if not (pose.requires_grad or force):
            return pose

        # Guard against R=None before touching R.shape; the original crashed
        # here even though the rotation update below is already optional.
        if R is not None and R.shape == (4, 4):
            assert T is None
            T = R[:3, 3]
            R = R[:3, :3]

        if R is not None:
            pose.data[0:4] = roma.rotmat_to_unitquat(R)
        if T is not None:
            pose.data[4:7] = signed_log1p(
                T / (scale or 1)
            )  # translation is function of scale

        if scale is not None:
            assert poses.shape[-1] in (8, 13)
            pose.data[-1] = np.log(float(scale))
        return pose

    def forward(self, ret_details=False):
        """Compute the alignment loss averaged over all edges.

        With ``ret_details`` the result is ``(loss, details)`` where
        ``details`` is an ``(n_imgs, n_imgs)`` tensor holding the per-edge
        loss at ``[i, j]`` and ``-1`` where no edge exists.
        """
        pw_poses = self.get_pw_poses()  # cam-to-world
        pw_adapt = self.get_adaptors()
        proj_pts3d = self.get_pts3d()
        # pre-compute pixel weights
        weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()}
        weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()}

        loss = 0
        if ret_details:
            details = -torch.ones((self.n_imgs, self.n_imgs))

        for e, (i, j) in enumerate(self.edges):
            i_j = edge_str(i, j)
            # distance in image i and j
            aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j])
            aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j])
            li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean()
            lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean()
            loss = loss + li + lj

            if ret_details:
                details[i, j] = li + lj
        loss /= self.n_edges  # average over all pairs

        if ret_details:
            return loss, details
        return loss

    @torch.cuda.amp.autocast(enabled=False)
    def compute_global_alignment(self, init=None, niter_PnP=10, **kw):
        """Optionally initialise the poses, then run the optimisation loop.

        Args:
            init: None (no initialisation), ``"msp"``/``"mst"`` (minimum
                spanning tree initialisation) or ``"known_poses"`` (not
                implemented).
            niter_PnP: forwarded to the initialisation routine.
            **kw: forwarded to ``global_alignment_loop``.

        Raises:
            NotImplementedError: for ``init="known_poses"``.
            ValueError: for any other unrecognised ``init``.
        """
        if init is None:
            pass
        elif init in ("msp", "mst"):
            init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
        elif init == "known_poses":
            # Known-pose initialisation is not available in this class.
            raise NotImplementedError
        else:
            raise ValueError(f"bad value for {init=}")

        return global_alignment_loop(self, **kw)

    @property
    def str_edges(self):
        """Edge names built with ``edge_str`` in the order of ``self.edges``."""
        return [edge_str(i, j) for i, j in self.edges]

    @property
    def n_edges(self):
        """Number of edges in the graph."""
        return len(self.edges)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def global_alignment_loop(
    net,
    lr=0.01,
    niter=300,
    schedule="cosine",
    lr_min=1e-3,
    temporal_smoothing_weight=0,
    depth_map_save_dir=None,
):
    """Run ``niter`` Adam steps on the trainable parameters of ``net``.

    Returns ``net`` itself when nothing requires gradients, otherwise the
    loss reported by the last iteration. Depth maps are written to
    ``depth_map_save_dir`` every 500 iterations, only in the verbose
    (progress-bar) branch.
    """
    trainable = [param for param in net.parameters() if param.requires_grad]
    if len(trainable) == 0:
        return net

    verbose = net.verbose
    if verbose:
        print("Global alignement - optimizing for:")
        print([name for name, value in net.named_parameters() if value.requires_grad])

    lr_base = lr
    optimizer = torch.optim.Adam(trainable, lr=lr, betas=(0.9, 0.9))

    def run_iteration(step):
        # Single optimisation step shared by both branches below.
        return global_alignment_iter(
            net,
            step,
            niter,
            lr_base,
            lr_min,
            optimizer,
            schedule,
            temporal_smoothing_weight=temporal_smoothing_weight,
        )

    def dump_depthmaps(step):
        # visualize the depthmaps
        if not os.path.exists(depth_map_save_dir):
            os.makedirs(depth_map_save_dir)
        for i, depth_map in enumerate(net.get_depthmaps()):
            target = os.path.join(
                depth_map_save_dir, f"depthmaps_{i}_iter_{step}.png"
            )
            plt.imsave(target, depth_map.detach().cpu().numpy(), cmap="jet")
        print(f"Saved depthmaps at iteration {step} to {depth_map_save_dir}")

    loss = float("inf")
    if not verbose:
        for n in range(niter):
            loss, _ = run_iteration(n)
        return loss

    with tqdm.tqdm(total=niter) as bar:
        while bar.n < bar.total:
            if bar.n % 500 == 0 and depth_map_save_dir is not None:
                dump_depthmaps(bar.n)
            loss, lr = run_iteration(bar.n)
            bar.set_postfix_str(f"{lr=:g} loss={loss:g}")
            bar.update()
    return loss
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def global_alignment_iter(
    net,
    cur_iter,
    niter,
    lr_base,
    lr_min,
    optimizer,
    schedule,
    temporal_smoothing_weight=0,
):
    """Perform one optimisation step and return ``(loss, lr)``.

    The learning rate is taken from the schedule named by ``schedule``
    (``"cosine"``, ``"linear"`` or ``"cycle<N>"``, with N defaulting to 2 when
    it does not parse) at progress ``cur_iter / niter``. ``net`` is called as
    ``net(epoch=cur_iter)``. ``temporal_smoothing_weight`` is accepted but not
    used in this function.

    Raises:
        ValueError: if ``schedule`` is not one of the recognised names.
    """
    t = cur_iter / niter

    def scheduled_lr():
        if schedule == "cosine":
            return cosine_schedule(t, lr_base, lr_min)
        if schedule == "linear":
            return linear_schedule(t, lr_base, lr_min)
        if not schedule.startswith("cycle"):
            raise ValueError(f"bad lr {schedule=}")
        try:
            num_cycles = int(schedule[5:])
        except ValueError:
            num_cycles = 2
        return cycled_linear_schedule(t, lr_base, lr_min, num_cycles=num_cycles)

    def free_cuda_cache():
        # Re-reads the flag each time, as the original did.
        if net.empty_cache:
            torch.cuda.empty_cache()

    lr = scheduled_lr()
    adjust_learning_rate_by_lr(optimizer, lr)
    optimizer.zero_grad()

    free_cuda_cache()
    loss = net(epoch=cur_iter)
    free_cuda_cache()
    loss.backward()
    free_cuda_cache()

    optimizer.step()
    return float(loss), lr
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/commons.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# utility functions for global alignment
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def edge_str(i, j):
    """Return the string name of the edge ``(i, j)``, formatted as ``"i_j"``."""
    return "{}_{}".format(i, j)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def i_j_ij(ij):
    """Return ``(edge_name, ij)`` for an edge given as a pair ``ij``."""
    name = edge_str(*ij)
    return name, ij
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def edge_conf(conf_i, conf_j, edge):
    """Return the product of the mean confidences stored under ``edge``.

    ``conf_i`` and ``conf_j`` are indexed with ``edge``; the result is a
    Python float.
    """
    mean_i = conf_i[edge].mean()
    mean_j = conf_j[edge].mean()
    return float(mean_i * mean_j)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def compute_edge_scores(edges, conf_i, conf_j):
    """Map every edge ``(i, j)`` to its ``edge_conf`` score.

    ``edges`` must yield ``(edge_name, (i, j))`` pairs (the shape returned by
    ``i_j_ij``); ``edge_name`` is the key used to look up the confidences.
    """
    scores = {}
    for name, (i, j) in edges:
        scores[(i, j)] = edge_conf(conf_i, conf_j, name)
    return scores
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def NoGradParamDict(x):
    """Wrap the dict ``x`` in an ``nn.ParameterDict`` with gradients disabled."""
    assert isinstance(x, dict)
    params = nn.ParameterDict(x)
    params.requires_grad_(False)
    return params
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_imshapes(edges, pred_i, pred_j):
    """Collect the first two dimensions of every image's prediction.

    For edge number ``e`` = ``(i, j)``, ``pred_i[e].shape[:2]`` is recorded
    for image ``i`` and ``pred_j[e].shape[:2]`` for image ``j``. A shape
    already recorded for an image must match the new one.

    Returns:
        A list indexed by image id, of length ``max index + 1``; images that
        appear in no edge stay ``None``.
    """
    image_count = 1 + max(max(pair) for pair in edges)
    imshapes = [None] * image_count
    for edge_idx, (src, dst) in enumerate(edges):
        src_shape = tuple(pred_i[edge_idx].shape[0:2])
        dst_shape = tuple(pred_j[edge_idx].shape[0:2])
        # Validate both endpoints before recording either one.
        for node, shape in ((src, src_shape), (dst, dst_shape)):
            known = imshapes[node]
            assert not known or known == shape, f"incorrect shape for image {node}"
        imshapes[src], imshapes[dst] = src_shape, dst_shape
    return imshapes
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_conf_trf(mode):
    """Return the confidence transform selected by ``mode``.

    ``"log"`` -> ``x.log()``, ``"sqrt"`` -> ``x.sqrt()``, ``"m1"`` -> ``x - 1``,
    ``"id"``/``"none"`` -> identity.

    Raises:
        ValueError: for any other ``mode``.
    """
    if mode == "log":
        return lambda x: x.log()
    if mode == "sqrt":
        return lambda x: x.sqrt()
    if mode == "m1":
        return lambda x: x - 1
    if mode in ("id", "none"):
        return lambda x: x
    raise ValueError(f"bad mode for {mode=}")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def l2_dist(a, b, weight):
    """Return the weighted squared Euclidean distance along the last axis."""
    squared = (a - b).square()
    return squared.sum(dim=-1) * weight
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def l1_dist(a, b, weight):
    """Return the weighted norm of ``a - b`` along the last axis.

    Despite the name, this uses ``Tensor.norm`` with its default order,
    i.e. the (unsquared) Euclidean length of the difference.
    """
    residual = a - b
    return residual.norm(dim=-1) * weight
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# Registry of the distance functions above, keyed by short name ("l1", "l2").
ALL_DISTS = dict(l1=l1_dist, l2=l2_dist)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def signed_log1p(x):
    """Return ``sign(x) * log(1 + |x|)`` element-wise."""
    magnitude = torch.log1p(torch.abs(x))
    return torch.sign(x) * magnitude
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def signed_expm1(x):
    """Return ``sign(x) * (exp(|x|) - 1)`` element-wise."""
    magnitude = torch.expm1(torch.abs(x))
    return torch.sign(x) * magnitude
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def cosine_schedule(t, lr_start, lr_end):
    """Cosine interpolation from ``lr_start`` (t=0) to ``lr_end`` (t=1).

    ``t`` must lie in ``[0, 1]``.
    """
    assert 0 <= t <= 1
    span = lr_start - lr_end
    wave = 1 + np.cos(t * np.pi)
    return lr_end + span * wave / 2
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def linear_schedule(t, lr_start, lr_end):
    """Linear interpolation from ``lr_start`` (t=0) to ``lr_end`` (t=1).

    ``t`` must lie in ``[0, 1]``.
    """
    assert 0 <= t <= 1
    delta = lr_end - lr_start
    return lr_start + delta * t
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/dust3r_opt/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# global alignment optimization wrapper function
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
from enum import Enum
|
| 8 |
+
|
| 9 |
+
from .optimizer import PointCloudOptimizer
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class GlobalAlignerMode(Enum):
    """Names of the global-alignment modes accepted by ``global_aligner``."""

    # The only mode global_aligner builds an optimizer for.
    PointCloudOptimizer = "PointCloudOptimizer"
    # Declared, but global_aligner raises NotImplementedError for these two.
    ModularPointCloudOptimizer = "ModularPointCloudOptimizer"
    PairViewer = "PairViewer"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def global_aligner(
    dust3r_output, device, mode=GlobalAlignerMode.PointCloudOptimizer, **optim_kw
):
    """Build the global-alignment optimizer for ``dust3r_output``.

    Args:
        dust3r_output: mapping with the keys ``"view1"``, ``"view2"``,
            ``"pred1"`` and ``"pred2"``.
        device: device the optimizer is moved to.
        mode: alignment mode; only ``GlobalAlignerMode.PointCloudOptimizer``
            is supported.
        **optim_kw: forwarded to the optimizer constructor.

    Raises:
        NotImplementedError: for any other ``mode``.
    """
    # extract all inputs
    inputs = [dust3r_output[key] for key in ("view1", "view2", "pred1", "pred2")]

    # build the optimizer
    if mode != GlobalAlignerMode.PointCloudOptimizer:
        raise NotImplementedError(f"Unknown mode {mode}")
    return PointCloudOptimizer(*inputs, **optim_kw).to(device)
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/dust3r_opt/base_opt.py
ADDED
|
@@ -0,0 +1,620 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# Base class for the global alignment procedure
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
from copy import deepcopy
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import roma
|
| 13 |
+
from copy import deepcopy
|
| 14 |
+
import tqdm
|
| 15 |
+
import cv2
|
| 16 |
+
from PIL import Image
|
| 17 |
+
from dust3r.utils.geometry import inv, geotrf
|
| 18 |
+
from dust3r.utils.device import to_numpy
|
| 19 |
+
from dust3r.utils.image import rgb
|
| 20 |
+
from dust3r.viz import SceneViz, segment_sky, auto_cam_size
|
| 21 |
+
|
| 22 |
+
from cloud_opt.dust3r_opt.commons import (
|
| 23 |
+
edge_str,
|
| 24 |
+
ALL_DISTS,
|
| 25 |
+
NoGradParamDict,
|
| 26 |
+
get_imshapes,
|
| 27 |
+
signed_expm1,
|
| 28 |
+
signed_log1p,
|
| 29 |
+
cosine_schedule,
|
| 30 |
+
linear_schedule,
|
| 31 |
+
get_conf_trf,
|
| 32 |
+
)
|
| 33 |
+
import cloud_opt.dust3r_opt.init_im_poses as init_fun
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
from scipy.spatial.transform import Rotation
|
| 36 |
+
from evo.core.trajectory import PosePath3D, PoseTrajectory3D
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def adjust_learning_rate_by_lr(optimizer, lr):
    """Set the learning rate of every parameter group of ``optimizer``.

    Groups that carry an ``"lr_scale"`` entry receive ``lr * lr_scale``; all
    other groups receive ``lr`` unchanged.
    """
    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def make_traj(args) -> PoseTrajectory3D:
    """Return a ``PoseTrajectory3D`` built from, or copied from, ``args``.

    A tuple/list is unpacked as ``(traj, tstamps)``: columns ``[:3]`` of
    ``traj`` become the positions and columns ``[3:]`` the wxyz quaternions.
    Anything else must already be a ``PoseTrajectory3D`` and is deep-copied.
    """
    if not isinstance(args, (tuple, list)):
        assert isinstance(args, PoseTrajectory3D), type(args)
        return deepcopy(args)

    traj, tstamps = args
    return PoseTrajectory3D(
        positions_xyz=traj[:, :3],
        orientations_quat_wxyz=traj[:, 3:],
        timestamps=tstamps,
    )
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def save_trajectory_tum_format(traj, filename):
    """Write a trajectory to ``filename``, one pose per line.

    Each line holds the timestamp, the xyz position and the quaternion in the
    order stored in ``orientations_quat_wxyz``, separated by spaces.

    NOTE(review): the standard TUM trajectory format lists the quaternion as
    qx qy qz qw; this writes w first -- confirm downstream readers expect it.
    """
    traj = make_traj(traj)

    def tostr(values):
        return " ".join(map(str, values))

    with Path(filename).open("w") as f:
        for i in range(traj.num_poses):
            stamp = traj.timestamps[i]
            position = tostr(traj.positions_xyz[i])
            quaternion = tostr(traj.orientations_quat_wxyz[i][[0, 1, 2, 3]])
            f.write(f"{stamp} {position} {quaternion}\n")
    print(f"Saved trajectory to {filename}")
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def c2w_to_tumpose(c2w):
    """
    Convert a camera-to-world matrix into a 7-vector pose.

    input: c2w: 4x4 matrix
    output: array of translation and rotation (x y z qw qx qy qz)
    """
    # convert input to numpy
    c2w = to_numpy(c2w)
    translation = c2w[:3, -1]
    # SciPy returns the quaternion scalar-last (x, y, z, w) by default.
    quat_xyzw = Rotation.from_matrix(c2w[:3, :3]).as_quat()
    quat_wxyz = quat_xyzw[[3, 0, 1, 2]]
    return np.concatenate([translation, quat_wxyz])
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class BasePCOptimizer(nn.Module):
|
| 87 |
+
"""Optimize a global scene, given a list of pairwise observations.
|
| 88 |
+
Graph node: images
|
| 89 |
+
Graph edges: observations = (pred1, pred2)
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
def __init__(self, *args, **kwargs):
    """Build the optimizer from views/predictions or by copying another one.

    With exactly one positional argument and no keyword arguments, that
    argument is deep-copied and the attributes listed in ``attrs`` are
    copied into ``self.__dict__``. Otherwise every argument is forwarded to
    ``_init_from_views``.
    """
    if len(args) == 1 and len(kwargs) == 0:
        other = deepcopy(args[0])
        # NOTE(review): "pw_adaptors" is listed twice (harmless for the dict
        # built below) and "conf_thr" is not assigned by _init_from_views --
        # confirm it exists on the copied object.
        attrs = """edges is_symmetrized dist n_imgs pred_i pred_j imshapes
min_conf_thr conf_thr conf_i conf_j im_conf
base_scale norm_pw_scale POSE_DIM pw_poses
pw_adaptors pw_adaptors has_im_poses rand_pose imgs verbose""".split()
        # NOTE(review): other[k] is item access on the copied object --
        # confirm it supports __getitem__ (getattr may have been intended).
        self.__dict__.update({k: other[k] for k in attrs})
    else:
        self._init_from_views(*args, **kwargs)
|
| 102 |
+
|
| 103 |
+
def _init_from_views(
    self,
    view1,
    view2,
    pred1,
    pred2,
    dist="l1",
    conf="log",
    min_conf_thr=3,
    base_scale=0.5,
    allow_pw_adaptors=False,
    pw_break=20,
    rand_pose=torch.randn,
    iterationsCount=None,
    verbose=True,
):
    """Initialise the optimizer state from batched pairwise views/predictions.

    Args:
        view1, view2: mappings with an ``"idx"`` entry (converted to a list
            in place when it is not one already) and optionally ``"img"``.
        pred1: mapping providing ``"pts3d_in_self_view"`` and ``"conf_self"``.
        pred2: mapping providing ``"pts3d_in_other_view"`` and ``"conf"``.
        dist: key into ``ALL_DISTS`` selecting the distance function.
        conf: mode handed to ``get_conf_trf``.
        min_conf_thr, base_scale, pw_break, verbose: stored on the instance.
        allow_pw_adaptors: whether ``pw_adaptors`` requires gradients.
        rand_pose: callable used to draw the initial ``pw_poses``.
        iterationsCount: accepted but not used in this method.
    """
    # Must run first: parameters are registered on the module below.
    super().__init__()
    # NOTE: the caller's view dicts are modified in place here.
    if not isinstance(view1["idx"], list):
        view1["idx"] = view1["idx"].tolist()
    if not isinstance(view2["idx"], list):
        view2["idx"] = view2["idx"].tolist()
    self.edges = [(int(i), int(j)) for i, j in zip(view1["idx"], view2["idx"])]
    # True when the reversed edge set equals the edge set.
    self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges}
    self.dist = ALL_DISTS[dist]
    self.verbose = verbose

    self.n_imgs = self._check_edges()

    # input data
    pred1_pts = pred1["pts3d_in_self_view"]
    pred2_pts = pred2["pts3d_in_other_view"]
    self.pred_i = NoGradParamDict(
        {ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)}
    )
    self.pred_j = NoGradParamDict(
        {ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)}
    )
    self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts)

    # work in log-scale with conf
    pred1_conf = pred1["conf_self"]
    pred2_conf = pred2["conf"]
    self.min_conf_thr = min_conf_thr
    self.conf_trf = get_conf_trf(conf)

    self.conf_i = NoGradParamDict(
        {ij: pred1_conf[n] for n, ij in enumerate(self.str_edges)}
    )
    self.conf_j = NoGradParamDict(
        {ij: pred2_conf[n] for n, ij in enumerate(self.str_edges)}
    )
    # _compute_img_conf reads self.device, which needs at least one
    # registered parameter -- hence it runs after the dicts above.
    self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf)
    for i in range(len(self.im_conf)):
        self.im_conf[i].requires_grad = False

    # pairwise pose parameters
    self.base_scale = base_scale
    self.norm_pw_scale = True
    self.pw_break = pw_break
    self.POSE_DIM = 7
    self.pw_poses = nn.Parameter(
        rand_pose((self.n_edges, 1 + self.POSE_DIM))
    )  # pairwise poses
    self.pw_adaptors = nn.Parameter(
        torch.zeros((self.n_edges, 2))
    )  # slight xy/z adaptation
    self.pw_adaptors.requires_grad_(allow_pw_adaptors)
    self.has_im_poses = False
    self.rand_pose = rand_pose

    # possibly store images for show_pointcloud
    self.imgs = None
    if "img" in view1 and "img" in view2:
        # Placeholder images, overwritten for every index seen in an edge.
        imgs = [torch.zeros((3,) + hw) for hw in self.imshapes]
        for v in range(len(self.edges)):
            idx = view1["idx"][v]
            imgs[idx] = view1["img"][v]
            idx = view2["idx"][v]
            imgs[idx] = view2["img"][v]
        self.imgs = rgb(imgs)
|
| 183 |
+
|
| 184 |
+
@property
|
| 185 |
+
def n_edges(self):
|
| 186 |
+
return len(self.edges)
|
| 187 |
+
|
| 188 |
+
@property
def str_edges(self):
    """Edge names built with ``edge_str``, in the order of ``self.edges``."""
    names = []
    for i, j in self.edges:
        names.append(edge_str(i, j))
    return names
|
| 191 |
+
|
| 192 |
+
@property
|
| 193 |
+
def imsizes(self):
|
| 194 |
+
return [(w, h) for h, w in self.imshapes]
|
| 195 |
+
|
| 196 |
+
@property
|
| 197 |
+
def device(self):
|
| 198 |
+
return next(iter(self.parameters())).device
|
| 199 |
+
|
| 200 |
+
def state_dict(self, trainable=True):
|
| 201 |
+
all_params = super().state_dict()
|
| 202 |
+
return {
|
| 203 |
+
k: v
|
| 204 |
+
for k, v in all_params.items()
|
| 205 |
+
if k.startswith(("_", "pred_i.", "pred_j.", "conf_i.", "conf_j."))
|
| 206 |
+
!= trainable
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
def load_state_dict(self, data):
|
| 210 |
+
return super().load_state_dict(self.state_dict(trainable=False) | data)
|
| 211 |
+
|
| 212 |
+
def _check_edges(self):
|
| 213 |
+
indices = sorted({i for edge in self.edges for i in edge})
|
| 214 |
+
assert indices == list(range(len(indices))), "bad pair indices: missing values "
|
| 215 |
+
return len(indices)
|
| 216 |
+
|
| 217 |
+
@torch.no_grad()
|
| 218 |
+
def _compute_img_conf(self, pred1_conf, pred2_conf):
|
| 219 |
+
im_conf = nn.ParameterList(
|
| 220 |
+
[torch.zeros(hw, device=self.device) for hw in self.imshapes]
|
| 221 |
+
)
|
| 222 |
+
for e, (i, j) in enumerate(self.edges):
|
| 223 |
+
im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e])
|
| 224 |
+
im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e])
|
| 225 |
+
return im_conf
|
| 226 |
+
|
| 227 |
+
def get_adaptors(self):
|
| 228 |
+
adapt = self.pw_adaptors
|
| 229 |
+
adapt = torch.cat(
|
| 230 |
+
(adapt[:, 0:1], adapt), dim=-1
|
| 231 |
+
) # (scale_xy, scale_xy, scale_z)
|
| 232 |
+
if self.norm_pw_scale: # normalize so that the product == 1
|
| 233 |
+
adapt = adapt - adapt.mean(dim=1, keepdim=True)
|
| 234 |
+
return (adapt / self.pw_break).exp()
|
| 235 |
+
|
| 236 |
+
def _get_poses(self, poses):
    """Convert stored pose rows into homogeneous rigid transforms.

    Columns ``[:4]`` are used as a quaternion (normalised by roma) and
    columns ``[4:7]`` are mapped through ``signed_expm1`` to obtain the
    translation.
    """
    quaternions = poses[:, :4]
    translations = signed_expm1(poses[:, 4:7])
    # normalize rotation
    rigid = roma.RigidUnitQuat(quaternions, translations).normalize()
    return rigid.to_homogeneous()
|
| 242 |
+
|
| 243 |
+
def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
|
| 244 |
+
# all poses == cam-to-world
|
| 245 |
+
pose = poses[idx]
|
| 246 |
+
if not (pose.requires_grad or force):
|
| 247 |
+
return pose
|
| 248 |
+
|
| 249 |
+
if R.shape == (4, 4):
|
| 250 |
+
assert T is None
|
| 251 |
+
T = R[:3, 3]
|
| 252 |
+
R = R[:3, :3]
|
| 253 |
+
|
| 254 |
+
if R is not None:
|
| 255 |
+
pose.data[0:4] = roma.rotmat_to_unitquat(R)
|
| 256 |
+
if T is not None:
|
| 257 |
+
pose.data[4:7] = signed_log1p(
|
| 258 |
+
T / (scale or 1)
|
| 259 |
+
) # translation is function of scale
|
| 260 |
+
|
| 261 |
+
if scale is not None:
|
| 262 |
+
assert poses.shape[-1] in (8, 13)
|
| 263 |
+
pose.data[-1] = np.log(float(scale))
|
| 264 |
+
return pose
|
| 265 |
+
|
| 266 |
+
def get_pw_norm_scale_factor(self):
    """Return the factor that renormalizes the pairwise scales.

    When ``self.norm_pw_scale`` is off the factor is the constant 1
    (don't norm scale for known poses).
    """
    if not self.norm_pw_scale:
        return 1
    # normalize scales so that things cannot go south
    # we want that exp(scale) ~= self.base_scale
    mean_log_scale = self.pw_poses[:, -1].mean()
    return (np.log(self.base_scale) - mean_log_scale).exp()
|
| 273 |
+
|
| 274 |
+
def get_pw_scale(self):
    """Return one scale per edge: exp of the last pose slot times the norm factor."""
    raw_scale = self.pw_poses[:, -1].exp()  # (n_edges,)
    return raw_scale * self.get_pw_norm_scale_factor()
|
| 278 |
+
|
| 279 |
+
def get_pw_poses(self):  # cam to world
    """Return the pairwise poses with the per-edge scale baked in.

    The first three rows of every 4x4 matrix (rotation AND translation)
    are multiplied by that edge's scale; the last row is left as is.
    """
    rigid = self._get_poses(self.pw_poses)
    scaled = rigid.clone()
    per_edge_scale = self.get_pw_scale().view(-1, 1, 1)
    scaled[:, :3] = scaled[:, :3] * per_edge_scale
    return scaled
|
| 286 |
+
|
| 287 |
+
def get_masks(self):
    """Return, per image, the boolean mask of pixels whose confidence exceeds ``self.min_conf_thr``."""
    masks = []
    for conf in self.im_conf:
        masks.append(conf > self.min_conf_thr)
    return masks
|
| 289 |
+
|
| 290 |
+
def depth_to_pts3d(self):
    """Abstract hook; this base implementation always raises NotImplementedError."""
    raise NotImplementedError()
|
| 292 |
+
|
| 293 |
+
def get_pts3d(self, raw=False):
    """Return the 3D points produced by ``depth_to_pts3d``.

    With ``raw=True`` the result is returned untouched. Otherwise each entry
    is cut to its first ``h * w`` rows and viewed as ``(h, w, 3)`` using the
    matching shape from ``self.imshapes``.
    """
    flat = self.depth_to_pts3d()
    if raw:
        return flat
    return [
        pts[: h * w].view(h, w, 3) for pts, (h, w) in zip(flat, self.imshapes)
    ]
|
| 298 |
+
|
| 299 |
+
def _set_focal(self, idx, focal, force=False):
|
| 300 |
+
raise NotImplementedError()
|
| 301 |
+
|
| 302 |
+
def get_focals(self):
    """Abstract hook; this base implementation always raises NotImplementedError."""
    raise NotImplementedError()
|
| 304 |
+
|
| 305 |
+
def get_known_focal_mask(self):
    """Abstract hook; this base implementation always raises NotImplementedError."""
    raise NotImplementedError()
|
| 307 |
+
|
| 308 |
+
def get_principal_points(self):
    """Abstract hook; this base implementation always raises NotImplementedError."""
    raise NotImplementedError()
|
| 310 |
+
|
| 311 |
+
def get_conf(self, mode=None):
    """Return the per-image confidence maps after applying a transform.

    ``mode=None`` uses ``self.conf_trf``; any other value is resolved
    through ``get_conf_trf(mode)``.
    """
    if mode is None:
        transform = self.conf_trf
    else:
        transform = get_conf_trf(mode)
    return [transform(conf) for conf in self.im_conf]
|
| 314 |
+
|
| 315 |
+
def get_im_poses(self):
    """Abstract hook; this base implementation always raises NotImplementedError."""
    raise NotImplementedError()
|
| 317 |
+
|
| 318 |
+
def _set_depthmap(self, idx, depth, force=False):
|
| 319 |
+
raise NotImplementedError()
|
| 320 |
+
|
| 321 |
+
def get_depthmaps(self, raw=False):
    """Abstract hook; this base implementation always raises NotImplementedError."""
    raise NotImplementedError()
|
| 323 |
+
|
| 324 |
+
def save_depth_maps(self, path):
    """Write every depth map under ``path`` as PNG, NPY and one animated GIF.

    NOTE(review): a second ``save_depth_maps`` with the same body appears
    later in this class; in Python the later definition is the one bound
    on the class.
    """
    depth_maps = self.get_depthmaps()
    frames = []

    for i, depth_map in enumerate(depth_maps):
        # Apply color map to depth map
        as_uint8 = (depth_map * 255).detach().cpu().numpy().astype(np.uint8)
        colored = cv2.applyColorMap(as_uint8, cv2.COLORMAP_JET)
        img_path = f"{path}/frame_{(i):04d}.png"
        cv2.imwrite(img_path, colored)
        frames.append(Image.open(img_path))
        np.save(f"{path}/frame_{(i):04d}.npy", depth_map.detach().cpu().numpy())

    first_frame, other_frames = frames[0], frames[1:]
    first_frame.save(
        f"{path}/_depth_maps.gif",
        save_all=True,
        append_images=other_frames,
        duration=100,
        loop=0,
    )

    return depth_maps
|
| 348 |
+
|
| 349 |
+
def clean_pointcloud(self, **kw):
    """Run the module-level ``clean_pointcloud`` and store the result in ``self.im_conf``.

    Extra keyword arguments are forwarded unchanged. Returns ``self``.
    """
    world_to_cam = inv(self.get_im_poses())
    intrinsics = self.get_intrinsics()
    depthmaps = self.get_depthmaps()
    all_pts3d = self.get_pts3d()

    # Inside the method body this name resolves to the module-level function.
    cleaned = clean_pointcloud(
        self.im_conf, intrinsics, world_to_cam, depthmaps, all_pts3d, **kw
    )
    for idx, new_conf in enumerate(cleaned):
        # overwrite in place so existing references stay valid
        self.im_conf[idx].data[:] = new_conf
    return self
|
| 361 |
+
|
| 362 |
+
def get_tum_poses(self):
    """Return ``[tum_poses, timestamps]`` for all image poses.

    Each pose is converted with ``c2w_to_tumpose`` and the results are
    stacked; timestamps are simply 0, 1, 2, ... as floats.
    """
    poses = self.get_im_poses()
    timestamps = np.arange(len(poses)).astype(float)
    converted = [c2w_to_tumpose(pose) for pose in poses]
    return [np.stack(converted, 0), timestamps]
|
| 368 |
+
|
| 369 |
+
def save_tum_poses(self, path):
    """Save the trajectory via ``save_trajectory_tum_format`` and return the poses."""
    trajectory = self.get_tum_poses()
    save_trajectory_tum_format(trajectory, path)
    poses, _timestamps = trajectory[0], trajectory[1:]
    return poses  # return the poses
|
| 373 |
+
|
| 374 |
+
def save_focals(self, path):
    """Write the focals to ``path`` as text (6 decimals) and return them."""
    focals = self.get_focals()
    # convert focal to txt
    as_array = focals.detach().cpu().numpy()
    np.savetxt(path, as_array, fmt="%.6f")
    return focals
|
| 379 |
+
|
| 380 |
+
def save_intrinsics(self, path):
    """Write the intrinsics to ``path`` as text, one 9-value row each.

    Returns the intrinsics exactly as ``get_intrinsics`` produced them.
    """
    K_raw = self.get_intrinsics()
    flat_rows = K_raw.reshape(-1, 9).detach().cpu().numpy()
    np.savetxt(path, flat_rows, fmt="%.6f")
    return K_raw
|
| 385 |
+
|
| 386 |
+
def save_conf_maps(self, path):
    """Save each map from ``get_conf()`` as ``conf_<i>.npy`` under ``path`` and return the list."""
    conf = self.get_conf()
    for i, conf_map in enumerate(conf):
        as_array = conf_map.detach().cpu().numpy()
        np.save(f"{path}/conf_{i}.npy", as_array)
    return conf
|
| 391 |
+
|
| 392 |
+
def save_init_conf_maps(self, path):
    """Save each map from ``get_init_conf()`` as ``init_conf_<i>.npy`` under ``path`` and return the list."""
    conf = self.get_init_conf()
    for i, conf_map in enumerate(conf):
        as_array = conf_map.detach().cpu().numpy()
        np.save(f"{path}/init_conf_{i}.npy", as_array)
    return conf
|
| 397 |
+
|
| 398 |
+
def save_rgb_imgs(self, path):
    """Write every image in ``self.imgs`` as ``frame_<i>.png`` under ``path`` and return them.

    The last axis is reversed (RGB -> BGR) and values are multiplied by 255
    before being handed to ``cv2.imwrite``.
    """
    imgs = self.imgs
    for i, rgb in enumerate(imgs):
        # convert from rgb to bgr
        bgr = rgb[..., ::-1]
        cv2.imwrite(f"{path}/frame_{i:04d}.png", bgr * 255)
    return imgs
|
| 405 |
+
|
| 406 |
+
def save_dynamic_masks(self, path):
    """Write the dynamic masks as ``dynamic_mask_<i>.png`` under ``path`` and return them.

    ``self.sam2_dynamic_masks`` is used when that attribute exists and is
    not None; otherwise ``self.dynamic_masks`` is used.
    """
    if getattr(self, "sam2_dynamic_masks", None) is None:
        dynamic_masks = self.dynamic_masks
    else:
        dynamic_masks = self.sam2_dynamic_masks

    for i, dynamic_mask in enumerate(dynamic_masks):
        as_uint8 = (dynamic_mask * 255).detach().cpu().numpy().astype(np.uint8)
        cv2.imwrite(f"{path}/dynamic_mask_{i}.png", as_uint8)
    return dynamic_masks
|
| 418 |
+
|
| 419 |
+
def save_depth_maps(self, path):
    """Write every depth map under ``path`` as PNG, NPY and one animated GIF.

    For each depth map ``i`` this writes ``frame_<i>.png`` (JET color map)
    and ``frame_<i>.npy`` (raw values); all PNG frames are then combined
    into ``_depth_maps.gif``. With no depth maps, no GIF is written
    (previously ``images[0]`` raised IndexError).

    Args:
        path: existing output directory.

    Returns:
        The depth maps as returned by ``get_depthmaps()``.
    """
    depth_maps = self.get_depthmaps()
    images = []

    try:
        for i, depth_map in enumerate(depth_maps):
            # Apply color map to depth map
            # NOTE(review): values are cast to uint8 after * 255 without any
            # normalization, so depths above 1 wrap around - confirm the
            # expected depth range with the caller.
            depth_map_colored = cv2.applyColorMap(
                (depth_map * 255).detach().cpu().numpy().astype(np.uint8),
                cv2.COLORMAP_JET,
            )
            img_path = f"{path}/frame_{(i):04d}.png"
            cv2.imwrite(img_path, depth_map_colored)
            images.append(Image.open(img_path))
            np.save(f"{path}/frame_{(i):04d}.npy", depth_map.detach().cpu().numpy())

        if images:
            images[0].save(
                f"{path}/_depth_maps.gif",
                save_all=True,
                append_images=images[1:],
                duration=100,
                loop=0,
            )
    finally:
        # Image.open keeps the underlying file open; release every handle
        # once the GIF has been written (or if anything above failed).
        for image in images:
            image.close()

    return depth_maps
|
| 443 |
+
|
| 444 |
+
def forward(self, ret_details=False):
    """Compute the alignment loss averaged over all edges.

    For every edge, both pairwise predictions are scaled by the edge
    adaptor, moved with the edge pose, and compared via ``self.dist``
    against the current 3D points of the two images.

    Returns the loss, or ``(loss, details)`` when ``ret_details`` is set,
    where ``details[i, j]`` is that edge's loss and untouched entries are -1.
    """
    pw_poses = self.get_pw_poses()  # cam-to-world
    pw_adapt = self.get_adaptors()
    proj_pts3d = self.get_pts3d()
    # pre-compute pixel weights
    weight_i = {key: self.conf_trf(conf) for key, conf in self.conf_i.items()}
    weight_j = {key: self.conf_trf(conf) for key, conf in self.conf_j.items()}

    details = -torch.ones((self.n_imgs, self.n_imgs)) if ret_details else None
    loss = 0

    for e, (i, j) in enumerate(self.edges):
        key = edge_str(i, j)
        pose, adapt = pw_poses[e], pw_adapt[e]
        # distance in image i and j
        aligned_i = geotrf(pose, adapt * self.pred_i[key])
        aligned_j = geotrf(pose, adapt * self.pred_j[key])
        loss_i = self.dist(proj_pts3d[i], aligned_i, weight=weight_i[key]).mean()
        loss_j = self.dist(proj_pts3d[j], aligned_j, weight=weight_j[key]).mean()
        loss = loss + loss_i + loss_j

        if ret_details:
            details[i, j] = loss_i + loss_j

    loss = loss / self.n_edges  # average over all pairs

    if not ret_details:
        return loss
    return loss, details
|
| 472 |
+
|
| 473 |
+
@torch.cuda.amp.autocast(enabled=False)
def compute_global_alignment(self, init=None, niter_PnP=10, **kw):
    """Optionally initialize the scene, then run ``global_alignment_loop``.

    ``init`` may be None (no initialization), ``"msp"``/``"mst"`` or
    ``"known_poses"``; anything else raises ValueError. Remaining keyword
    arguments go to ``global_alignment_loop``, whose result is returned.
    """
    if init is not None:
        if init == "msp" or init == "mst":
            init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
        elif init == "known_poses":
            init_fun.init_from_known_poses(
                self, min_conf_thr=self.min_conf_thr, niter_PnP=niter_PnP
            )
        else:
            raise ValueError(f"bad value for {init=}")
    return global_alignment_loop(self, **kw)
|
| 486 |
+
|
| 487 |
+
@torch.no_grad()
def mask_sky(self):
    """Return a deep copy of the scene with confidence zeroed on sky pixels.

    The sky mask of each image comes from ``segment_sky``; ``self`` is not
    modified.
    """
    masked = deepcopy(self)
    for idx in range(self.n_imgs):
        sky_mask = segment_sky(self.imgs[idx])
        masked.im_conf[idx][sky_mask] = 0
    return masked
|
| 494 |
+
|
| 495 |
+
def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw):
    """Display the scene (point clouds and cameras) in a ``SceneViz``.

    Args:
        show_pw_cams: also draw the pairwise cameras.
        show_pw_pts3d: also draw the pairwise predictions moved by the
            pairwise poses.
        cam_size: camera glyph size; derived with ``auto_cam_size`` from the
            image poses when None.
        **kw: forwarded to ``viz.show``.

    Returns:
        The ``SceneViz`` instance.
    """
    viz = SceneViz()
    if self.imgs is None:
        colors = np.random.randint(0, 256, size=(self.n_imgs, 3))
        colors = list(map(tuple, colors.tolist()))
        # compute once instead of once per image
        pts3d = self.get_pts3d()
        masks = self.get_masks()
        for n in range(self.n_imgs):
            viz.add_pointcloud(pts3d[n], colors[n], masks[n])
    else:
        viz.add_pointcloud(self.get_pts3d(), self.imgs, self.get_masks())
        colors = np.random.randint(256, size=(self.n_imgs, 3))

    # camera poses
    im_poses = to_numpy(self.get_im_poses())
    if cam_size is None:
        cam_size = auto_cam_size(im_poses)
    viz.add_cameras(
        im_poses,
        self.get_focals(),
        colors=colors,
        images=self.imgs,
        imsizes=self.imsizes,
        cam_size=cam_size,
    )

    # Fetch the pairwise poses whenever either option needs them; previously
    # they were only bound under ``show_pw_cams``, so ``show_pw_pts3d=True``
    # alone raised NameError.
    if show_pw_cams or show_pw_pts3d:
        pw_poses = self.get_pw_poses()

    if show_pw_cams:
        viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size)

    if show_pw_pts3d:
        pts = [
            geotrf(pw_poses[e], self.pred_i[edge_str(i, j)])
            for e, (i, j) in enumerate(self.edges)
        ]
        viz.add_pointcloud(pts, (128, 0, 128))

    viz.show(**kw)
    return viz
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def global_alignment_loop(net, lr=0.01, niter=300, schedule="cosine", lr_min=1e-6):
    """Optimize the trainable parameters of ``net`` with Adam for ``niter`` steps.

    Returns the loss of the last iteration (``inf`` if ``niter`` is 0).
    NOTE(review): when ``net`` has no trainable parameter the function
    returns ``net`` itself rather than a loss.
    """
    trainable = [p for p in net.parameters() if p.requires_grad]
    if not trainable:
        return net

    verbose = net.verbose
    if verbose:
        print("Global alignement - optimizing for:")
        print([name for name, value in net.named_parameters() if value.requires_grad])

    lr_base = lr
    optimizer = torch.optim.Adam(trainable, lr=lr, betas=(0.9, 0.9))
    loss = float("inf")

    if not verbose:
        for step in range(niter):
            loss, _ = global_alignment_iter(
                net, step, niter, lr_base, lr_min, optimizer, schedule
            )
        return loss

    with tqdm.tqdm(total=niter) as bar:
        # the progress bar's own counter doubles as the iteration index
        while bar.n < bar.total:
            loss, lr = global_alignment_iter(
                net, bar.n, niter, lr_base, lr_min, optimizer, schedule
            )
            bar.set_postfix_str(f"{lr=:g} loss={loss:g}")
            bar.update()
    return loss
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
def global_alignment_iter(net, cur_iter, niter, lr_base, lr_min, optimizer, schedule):
    """Run one optimization step and return ``(loss as float, learning rate used)``.

    The learning rate is scheduled from ``lr_base`` to ``lr_min`` using the
    progress ``cur_iter / niter``; ``schedule`` must be ``"cosine"`` or
    ``"linear"``, otherwise ValueError is raised.
    """
    progress = cur_iter / niter
    if schedule == "cosine":
        lr = cosine_schedule(progress, lr_base, lr_min)
    elif schedule == "linear":
        lr = linear_schedule(progress, lr_base, lr_min)
    else:
        raise ValueError(f"bad lr {schedule=}")

    adjust_learning_rate_by_lr(optimizer, lr)
    optimizer.zero_grad()
    loss = net()
    loss.backward()
    optimizer.step()
    return float(loss), lr
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
@torch.no_grad()
def clean_pointcloud(
    im_confs, K, cams, depthmaps, all_pts3d, tol=0.001, bad_conf=0, dbg=()
):
    """Lower the confidence of points that contradict another view's depthmap.

    Method:
    1) express all 3d points in each camera coordinate frame
    2) if they're in front of a depthmap --> then lower their confidence

    Returns a list of new confidence maps; the inputs are not modified.
    """
    n_views = len(all_pts3d)
    assert len(im_confs) == len(cams) == len(K) == len(depthmaps) == n_views
    assert 0 <= tol < 1
    res = [conf.clone() for conf in im_confs]

    # reshape appropriately
    all_pts3d = [pts.view(*conf.shape, 3) for pts, conf in zip(all_pts3d, im_confs)]
    depthmaps = [depth.view(*conf.shape) for depth, conf in zip(depthmaps, im_confs)]

    for i, pts3d in enumerate(all_pts3d):
        for j in range(n_views):
            if i == j:
                continue

            # project 3dpts in other view
            in_cam_j = geotrf(cams[j], pts3d)
            depth_in_j = in_cam_j[:, :, 2]
            pixels = geotrf(K[j], in_cam_j, norm=1, ncol=2).round().long()
            u, v = pixels.unbind(-1)

            # check which points are actually in the visible cone
            H, W = im_confs[j].shape
            visible = (depth_in_j > 0) & (0 <= u) & (u < W) & (0 <= v) & (v < H)
            hit_j = v[visible], u[visible]

            # find bad points = those in front but less confident
            in_front = depth_in_j[visible] < (1 - tol) * depthmaps[j][hit_j]
            less_confident = res[i][visible] < res[j][hit_j]
            bad_points = in_front & less_confident

            bad_mask = visible.clone()
            bad_mask[visible] = bad_points
            res[i][bad_mask] = res[i][bad_mask].clip_(max=bad_conf)

    return res
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/dust3r_opt/commons.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# utility functions for global alignment
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def edge_str(i, j):
    """Return the string key ``"<i>_<j>"`` identifying edge (i, j)."""
    return "{}_{}".format(i, j)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def i_j_ij(ij):
    """Return ``(edge_str(*ij), ij)``: the edge's string key paired with the edge itself."""
    key = edge_str(*ij)
    return key, ij
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def edge_conf(conf_i, conf_j, edge):
    """Score an edge as the product of the mean confidences of its two maps."""
    mean_i = conf_i[edge].mean()
    mean_j = conf_j[edge].mean()
    return float(mean_i * mean_j)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def compute_edge_scores(edges, conf_i, conf_j):
    """Map every edge ``(i, j)`` to its ``edge_conf`` score.

    ``edges`` yields ``(key, (i, j))`` pairs, where ``key`` indexes
    ``conf_i`` / ``conf_j``.
    """
    scores = {}
    for key, (i, j) in edges:
        scores[(i, j)] = edge_conf(conf_i, conf_j, key)
    return scores
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def NoGradParamDict(x):
    """Wrap dict ``x`` in an ``nn.ParameterDict`` with gradients disabled."""
    assert isinstance(x, dict)
    frozen = nn.ParameterDict(x)
    frozen.requires_grad_(False)
    return frozen
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_imshapes(edges, pred_i, pred_j):
    """Collect the (H, W) shape of every image from the pairwise predictions.

    The first two dimensions of ``pred_i[e]`` / ``pred_j[e]`` give the shape
    of the first / second image of edge ``e``. An assertion fails if an
    image is seen with two different shapes. Images that appear in no edge
    keep ``None``.
    """
    n_imgs = 1 + max(max(edge) for edge in edges)
    imshapes = [None] * n_imgs
    for e, (i, j) in enumerate(edges):
        found = (
            (i, tuple(pred_i[e].shape[0:2])),
            (j, tuple(pred_j[e].shape[0:2])),
        )
        # check both images against what is already known, then record both
        for img, shape in found:
            if imshapes[img]:
                assert imshapes[img] == shape, f"incorrect shape for image {img}"
        for img, shape in found:
            imshapes[img] = shape
    return imshapes
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_conf_trf(mode):
    """Return the confidence transform selected by ``mode``.

    Supported modes: ``"log"`` (x.log()), ``"sqrt"`` (x.sqrt()),
    ``"m1"`` (x - 1), and ``"id"`` / ``"none"`` (identity).
    Any other value raises ValueError.
    """

    def _log(x):
        return x.log()

    def _sqrt(x):
        return x.sqrt()

    def _minus_one(x):
        return x - 1

    def _identity(x):
        return x

    if mode == "log":
        return _log
    if mode == "sqrt":
        return _sqrt
    if mode == "m1":
        return _minus_one
    if mode in ("id", "none"):
        return _identity
    raise ValueError(f"bad mode for {mode=}")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def l2_dist(a, b, weight):
    """Weighted squared distance: sum of squared differences over the last axis, times ``weight``."""
    diff = a - b
    squared = diff.square().sum(dim=-1)
    return squared * weight


def l1_dist(a, b, weight):
    """Weighted distance: norm of the difference over the last axis, times ``weight``."""
    diff = a - b
    return diff.norm(dim=-1) * weight


# Distance functions by name.
ALL_DISTS = {"l1": l1_dist, "l2": l2_dist}
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def signed_log1p(x):
    """Return ``sign(x) * log1p(|x|)``, element-wise."""
    magnitude = torch.log1p(torch.abs(x))
    return torch.sign(x) * magnitude
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def signed_expm1(x):
    """Return ``sign(x) * expm1(|x|)``, element-wise."""
    magnitude = torch.expm1(torch.abs(x))
    return torch.sign(x) * magnitude
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def cosine_schedule(t, lr_start, lr_end):
    """Cosine interpolation from ``lr_start`` (t=0) to ``lr_end`` (t=1).

    ``t`` must lie in [0, 1] (checked with an assertion).
    """
    assert 0 <= t <= 1
    span = lr_start - lr_end
    return lr_end + span * (1 + np.cos(t * np.pi)) / 2
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def linear_schedule(t, lr_start, lr_end):
    """Linear interpolation from ``lr_start`` (t=0) to ``lr_end`` (t=1).

    ``t`` must lie in [0, 1] (checked with an assertion).
    """
    assert 0 <= t <= 1
    delta = lr_end - lr_start
    return lr_start + delta * t
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/dust3r_opt/init_im_poses.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# Initialization functions for global alignment
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
from functools import cache
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import scipy.sparse as sp
|
| 11 |
+
import torch
|
| 12 |
+
import cv2
|
| 13 |
+
import roma
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
from dust3r.utils.geometry import geotrf, inv, get_med_dist_between_poses
|
| 17 |
+
from dust3r.post_process import estimate_focal_knowing_depth
|
| 18 |
+
from dust3r.viz import to_numpy
|
| 19 |
+
|
| 20 |
+
from cloud_opt.commons import edge_str, i_j_ij, compute_edge_scores
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@torch.no_grad()
def init_from_known_poses(self, niter_PnP=10, min_conf_thr=3):
    """Initialize pairwise poses and depthmaps when every pose and focal is known.

    Args:
        niter_PnP: RANSAC iteration count forwarded to ``fast_pnp``.
        min_conf_thr: confidence threshold for the PnP point mask.

    Raises:
        AssertionError: if not all poses / focals are known.
        RuntimeError: if PnP fails for an edge, or if an image has no
            usable depthmap (previously an opaque TypeError / KeyError).
    """
    device = self.device

    # indices of known poses
    nkp, known_poses_msk, known_poses = get_known_poses(self)
    assert nkp == self.n_imgs, "not all poses are known"

    # get all focals
    nkf, _, im_focals = get_known_focals(self)
    assert nkf == self.n_imgs
    im_pp = self.get_principal_points()

    best_depthmaps = {}
    # init all pairwise poses
    for e, (i, j) in enumerate(tqdm(self.edges, disable=not self.verbose)):
        i_j = edge_str(i, j)

        # find relative pose for this pair
        P1 = torch.eye(4, device=device)
        # threshold is capped just below the minimum so the mask is never empty
        msk = self.conf_i[i_j] > min(min_conf_thr, self.conf_i[i_j].min() - 0.1)
        pnp_result = fast_pnp(
            self.pred_j[i_j],
            float(im_focals[i].mean()),
            pp=im_pp[i],
            msk=msk,
            device=device,
            niter_PnP=niter_PnP,
        )
        if pnp_result is None:
            # fast_pnp returns None when it has too few points or no solution
            raise RuntimeError(f"PnP failed for edge {i_j}")
        _, P2 = pnp_result

        # align the two predicted camera with the two gt cameras
        s, R, T = align_multiple_poses(torch.stack((P1, P2)), known_poses[[i, j]])
        # normally we have known_poses[i] ~= sRT_to_4x4(s,R,T,device) @ P1
        # and geotrf(sRT_to_4x4(1,R,T,device), s*P2[:3,3])
        self._set_pose(self.pw_poses, e, R, T, scale=s)

        # remember if this is a good depthmap
        score = float(self.conf_i[i_j].mean())
        if score > best_depthmaps.get(i, (0,))[0]:
            best_depthmaps[i] = score, i_j, s

    # init all image poses
    for n in range(self.n_imgs):
        assert known_poses_msk[n]
        if n not in best_depthmaps:
            # happens when image n is never the first image of an edge with
            # a positive mean confidence
            raise RuntimeError(f"no depthmap available to initialize image {n}")
        _, i_j, scale = best_depthmaps[n]
        depth = self.pred_i[i_j][:, :, 2]
        self._set_depthmap(n, depth * scale)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@torch.no_grad()
def init_minimum_spanning_tree(self, **kw):
    """Init all camera poses (image-wise and pairwise poses) given
    an initial set of pairwise estimations.

    Runs ``minimum_spanning_tree`` on the scene's data and hands the
    resulting points, focals and poses to ``init_from_pts3d``.
    """
    mst_result = minimum_spanning_tree(
        self.imshapes,
        self.edges,
        self.pred_i,
        self.pred_j,
        self.conf_i,
        self.conf_j,
        self.im_conf,
        self.min_conf_thr,
        self.device,
        has_im_poses=self.has_im_poses,
        verbose=self.verbose,
        **kw,
    )
    pts3d, _msp_edges, im_focals, im_poses = mst_result
    return init_from_pts3d(self, pts3d, im_focals, im_poses)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def init_from_pts3d(self, pts3d, im_focals, im_poses):
    """Initialize pairwise poses, image poses, depthmaps and focals from 3D points.

    Args:
        pts3d: one point map per image; modified in place.
        im_focals: per-image focals (entries may be None), or None.
        im_poses: stacked image poses, or None when the scene has no image
            poses (as returned by ``minimum_spanning_tree`` with
            ``has_im_poses=False``).

    Raises:
        NotImplementedError: when exactly one pose is known.
    """
    # init poses
    nkp, known_poses_msk, known_poses = get_known_poses(self)
    if nkp == 1:
        raise NotImplementedError(
            "Would be simpler to just align everything afterwards on the single known pose"
        )
    elif nkp > 1:
        # global rigid SE3 alignment
        s, R, T = align_multiple_poses(
            im_poses[known_poses_msk], known_poses[known_poses_msk]
        )
        trf = sRT_to_4x4(s, R, T, device=known_poses.device)

        # rotate everything
        im_poses = trf @ im_poses
        im_poses[:, :3, :3] /= s  # undo scaling on the rotation part
        for img_pts3d in pts3d:
            img_pts3d[:] = geotrf(trf, img_pts3d)

    # set all pairwise poses
    for e, (i, j) in enumerate(self.edges):
        i_j = edge_str(i, j)
        # compute transform that goes from cam to world
        s, R, T = rigid_points_registration(
            self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j]
        )
        self._set_pose(self.pw_poses, e, R, T, scale=s)

    # take into account the scale normalization
    s_factor = self.get_pw_norm_scale_factor()
    if im_poses is not None:
        # im_poses is None for scenes without image poses; indexing it
        # unconditionally used to raise TypeError in that case.
        im_poses[:, :3, 3] *= s_factor  # apply downscaling factor
    for img_pts3d in pts3d:
        img_pts3d *= s_factor

    # init all image poses
    if self.has_im_poses:
        for i in range(self.n_imgs):
            cam2world = im_poses[i]
            depth = geotrf(inv(cam2world), pts3d[i])[..., 2]
            self._set_depthmap(i, depth)
            self._set_pose(self.im_poses, i, cam2world)
            if im_focals[i] is not None:
                self._set_focal(i, im_focals[i])
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def minimum_spanning_tree(
    imshapes,
    edges,
    pred_i,
    pred_j,
    conf_i,
    conf_j,
    im_conf,
    min_conf_thr,
    device,
    has_im_poses=True,
    niter_PnP=10,
    verbose=True,
):
    """Build an initial point cloud by walking the spanning tree of the pair graph.

    Edge scores are negated before calling scipy's minimum spanning tree, so
    the tree keeps the highest-scoring edges. Starting from the strongest
    edge, every further tree edge is attached to the images already placed.

    Args:
        imshapes: per-image shapes; only the length is used here.
        edges: iterable of ``(i, j)`` image pairs.
        pred_i, pred_j: pairwise point maps keyed by ``edge_str(i, j)``.
        conf_i, conf_j: matching confidences, same keys.
        im_conf: per-image confidence, used to mask points for PnP.
        min_conf_thr: confidence threshold for that mask.
        device: device for newly created tensors.
        has_im_poses: also estimate image poses and focals.
        niter_PnP: RANSAC iteration count forwarded to ``fast_pnp``.
        verbose: print the graph and progress.

    Returns:
        ``(pts3d, msp_edges, im_focals, im_poses)``; the last two are None
        when ``has_im_poses`` is False.

    Raises:
        ValueError: if the remaining tree edges cannot be attached to the
            images already placed (disconnected pair graph).
    """
    n_imgs = len(imshapes)
    sparse_graph = -dict_to_sparse_graph(
        compute_edge_scores(map(i_j_ij, edges), conf_i, conf_j)
    )
    if verbose:
        # this dump used to be printed even with verbose=False
        print(sparse_graph)
    msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo()

    # temp variable to store 3d points
    pts3d = [None] * n_imgs

    todo = sorted(zip(-msp.data, msp.row, msp.col))  # sorted edges
    im_poses = [None] * n_imgs
    im_focals = [None] * n_imgs

    # init with strongest edge
    score, i, j = todo.pop()
    if verbose:
        print(f" init edge ({i}*,{j}*) {score=}")
    i_j = edge_str(i, j)
    pts3d[i] = pred_i[i_j].clone()
    pts3d[j] = pred_j[i_j].clone()
    done = {i, j}
    if has_im_poses:
        im_poses[i] = torch.eye(4, device=device)
        im_focals[i] = estimate_focal(pred_i[i_j])

    # set initial pointcloud based on pairwise graph
    msp_edges = [(i, j)]
    deferred = 0  # consecutive edges postponed without attaching anything
    while todo:
        # each time, predict the next one
        score, i, j = todo.pop()
        # Key of the edge being processed. It must be computed before the
        # focal estimate below, which previously reused the key left over
        # from the previous edge (i.e. another image's prediction).
        i_j = edge_str(i, j)

        if im_focals[i] is None:
            im_focals[i] = estimate_focal(pred_i[i_j])

        if i in done:
            if verbose:
                print(f" init edge ({i},{j}*) {score=}")
            assert j not in done
            # align pred[i] with pts3d[i], and then set j accordingly
            s, R, T = rigid_points_registration(pred_i[i_j], pts3d[i], conf=conf_i[i_j])
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[j] = geotrf(trf, pred_j[i_j])
            done.add(j)
            msp_edges.append((i, j))
            deferred = 0

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)

        elif j in done:
            if verbose:
                print(f" init edge ({i}*,{j}) {score=}")
            assert i not in done
            s, R, T = rigid_points_registration(pred_j[i_j], pts3d[j], conf=conf_j[i_j])
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[i] = geotrf(trf, pred_i[i_j])
            done.add(i)
            msp_edges.append((i, j))
            deferred = 0

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)
        else:
            # let's try again later
            todo.insert(0, (score, i, j))
            deferred += 1
            if deferred >= len(todo):
                # every pending edge was postponed in a row: none of them
                # touches a placed image, so the loop would never terminate
                raise ValueError(
                    "cannot attach remaining edges: the pairwise graph is disconnected"
                )

    if has_im_poses:
        # complete all missing informations
        pair_scores = list(
            sparse_graph.values()
        )  # already negative scores: less is best
        edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[
            np.argsort(pair_scores)
        ]
        for i, j in edges_from_best_to_worse.tolist():
            if im_focals[i] is None:
                im_focals[i] = estimate_focal(pred_i[edge_str(i, j)])

        for i in range(n_imgs):
            if im_poses[i] is None:
                msk = im_conf[i] > min_conf_thr
                res = fast_pnp(
                    pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP
                )
                if res:
                    im_focals[i], im_poses[i] = res
            if im_poses[i] is None:
                # PnP failed: fall back to the identity pose
                im_poses[i] = torch.eye(4, device=device)
        im_poses = torch.stack(im_poses)
    else:
        im_poses = im_focals = None

    return pts3d, msp_edges, im_focals, im_poses
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def dict_to_sparse_graph(dic):
    """Build a square sparse (DOK) array from a ``{(i, j): value}`` dict.

    The side length is one more than the largest index found in the keys.
    """
    size = 1 + max(max(edge) for edge in dic)
    graph = sp.dok_array((size, size))
    for (row, col), weight in dic.items():
        graph[row, col] = weight
    return graph
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def rigid_points_registration(pts1, pts2, conf):
    """Register ``pts1`` onto ``pts2`` with ``roma``, weighted by ``conf``, scaling enabled.

    Inputs are flattened to (N, 3) points and N weights. The result is
    returned as ``(s, R, T)``.
    """
    source = pts1.reshape(-1, 3)
    target = pts2.reshape(-1, 3)
    R, T, s = roma.rigid_points_registration(
        source, target, weights=conf.ravel(), compute_scaling=True
    )
    return s, R, T  # return un-scaled (R, T)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def sRT_to_4x4(scale, R, T, device):
    """Assemble a 4x4 transform with ``R * scale`` top-left and ``T`` in the last column."""
    mat = torch.eye(4, device=device)
    mat[:3, 3] = T.ravel()  # doesn't need scaling
    mat[:3, :3] = R * scale
    return mat
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def estimate_focal(pts3d_i, pp=None):
    """Estimate a focal length from one point map via ``estimate_focal_knowing_depth``.

    When ``pp`` is None the principal point defaults to the image center
    ``(W / 2, H / 2)`` taken from the ``(H, W, 3)`` shape of ``pts3d_i``.
    Returns a Python float.
    """
    if pp is None:
        height, width, channels = pts3d_i.shape
        assert channels == 3
        pp = torch.tensor((width / 2, height / 2), device=pts3d_i.device)
    batched_pts = pts3d_i.unsqueeze(0)
    batched_pp = pp.unsqueeze(0)
    focal = estimate_focal_knowing_depth(
        batched_pts, batched_pp, focal_mode="weiszfeld"
    ).ravel()
    return float(focal)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
@cache
def pixel_grid(H, W):
    """Return a cached ``(H, W, 2)`` float32 grid where entry ``[y, x]`` is ``(x, y)``."""
    xs, ys = np.meshgrid(np.arange(W), np.arange(H))
    return np.stack((xs, ys), axis=-1).astype(np.float32)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
    """Recover a camera pose (and optionally a focal) with RANSAC-PnP.

    Args:
        pts3d: ``(H, W, 3)`` pointmap; pixel ``(x, y)`` is matched to ``pts3d[y, x]``.
        focal: known focal length, or ``None`` to scan 21 geometrically spaced
            candidates between ``max(W, H) / 2`` and ``3 * max(W, H)``.
        msk: boolean ``(H, W)`` mask of the points to use.
        device: device on which the returned pose is created.
        pp: principal point; defaults to the image centre ``(W / 2, H / 2)``.
        niter_PnP: RANSAC iteration count passed to OpenCV.

    Returns:
        ``None`` when fewer than 4 points are selected or no candidate yields
        inliers; otherwise ``(best_focal, cam2world)`` where ``cam2world`` is
        the inverse of the world-to-camera transform found by PnP.
    """
    # extract camera poses and focals with RANSAC-PnP
    if msk.sum() < 4:
        return None  # we need at least 4 points for PnP
    pts3d, msk = map(to_numpy, (pts3d, msk))

    H, W, THREE = pts3d.shape
    assert THREE == 3
    pixels = pixel_grid(H, W)

    if focal is None:
        S = max(W, H)
        tentative_focals = np.geomspace(S / 2, S * 3, 21)
    else:
        tentative_focals = [focal]

    if pp is None:
        pp = (W / 2, H / 2)
    else:
        pp = to_numpy(pp)

    best = (0,)
    # Dedicated loop variable so the `focal` argument is not shadowed.
    for cand_focal in tentative_focals:
        K = np.float32([(cand_focal, 0, pp[0]), (0, cand_focal, pp[1]), (0, 0, 1)])
        try:
            success, R, T, inliers = cv2.solvePnPRansac(
                pts3d[msk],
                pixels[msk],
                K,
                None,
                iterationsCount=niter_PnP,
                reprojectionError=5,
                flags=cv2.SOLVEPNP_SQPNP,
            )
        except Exception:
            # Best effort: a failing candidate is skipped. `Exception` (not a
            # bare except) lets KeyboardInterrupt / SystemExit propagate.
            continue
        if not success:
            continue

        score = len(inliers)
        if score > best[0]:
            best = score, R, T, cand_focal

    if not best[0]:
        return None

    _, R, T, best_focal = best
    R = cv2.Rodrigues(R)[0]  # world to cam
    R, T = map(torch.from_numpy, (R, T))
    return best_focal, inv(sRT_to_4x4(1, R, T, device))  # cam to world
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def get_known_poses(self):
    """Report which image poses are frozen (``requires_grad`` is False).

    Returns ``(count, mask, poses)``, or ``(0, None, None)`` when the
    optimizer has no image poses.
    """
    if not self.has_im_poses:
        return 0, None, None
    frozen = torch.tensor([not pose.requires_grad for pose in self.im_poses])
    all_poses = self.get_im_poses()
    return frozen.sum(), frozen, all_poses
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def get_known_focals(self):
    """Report which focals are known, mirroring ``get_known_poses``.

    Returns ``(count, mask, focals)``, or ``(0, None, None)`` when the
    optimizer has no image poses.
    """
    if not self.has_im_poses:
        return 0, None, None
    mask = self.get_known_focal_mask()
    focals = self.get_focals()
    return mask.sum(), mask, focals
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def align_multiple_poses(src_poses, target_poses):
    """Find the similarity transform mapping ``src_poses`` onto ``target_poses``.

    Each ``(N, 4, 4)`` pose set becomes ``2N`` points: the N translations,
    followed by the same translations offset along each pose's third rotation
    column by 1/100 of the median inter-pose distance. ``roma`` registers the
    two point sets with scaling. Returns ``(s, R, T)``.
    """
    count = len(src_poses)
    expected = (count, 4, 4)
    assert src_poses.shape == target_poses.shape == expected

    def _anchor_points(poses):
        step = get_med_dist_between_poses(poses) / 100
        centers = poses[:, :3, 3]
        return torch.cat((centers, poses[:, :3, 3] + step * poses[:, :3, 2]))

    src_pts = _anchor_points(src_poses)
    dst_pts = _anchor_points(target_poses)
    rot, trans, scale = roma.rigid_points_registration(
        src_pts, dst_pts, compute_scaling=True
    )
    return scale, rot, trans
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/dust3r_opt/optimizer.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# Main class for the implementation of the global alignment
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
from cloud_opt.dust3r_opt.base_opt import BasePCOptimizer
|
| 12 |
+
from dust3r.utils.geometry import xy_grid, geotrf
|
| 13 |
+
from dust3r.utils.device import to_cpu, to_numpy
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class PointCloudOptimizer(BasePCOptimizer):
    """Optimize a global scene, given a list of pairwise observations.

    Graph node: images
    Graph edges: observations = (pred1, pred2)

    Per image, this class holds a log-depth map, a pose, a log-focal
    (multiplied by ``focal_break``) and a principal-point offset (in units of
    10 pixels relative to the image centre). Everything is stacked into
    fixed-size tensors padded to the largest image area.
    """

    def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs):
        """Create the per-image parameters and the pre-stacked buffers.

        Args:
            optimize_pp: whether the principal-point offsets receive gradients.
            focal_break: multiplier applied to ``log(focal)`` in the stored
                focal parameter.
            *args, **kwargs: forwarded unchanged to ``BasePCOptimizer``.
        """
        super().__init__(*args, **kwargs)

        self.has_im_poses = True  # by definition of this class
        self.focal_break = focal_break

        # adding thing to optimize
        self.im_depthmaps = nn.ParameterList(
            torch.randn(H, W) / 10 - 3 for H, W in self.imshapes
        )  # log(depth)
        self.im_poses = nn.ParameterList(
            self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs)
        )  # camera poses
        self.im_focals = nn.ParameterList(
            torch.FloatTensor([self.focal_break * np.log(max(H, W))])
            for H, W in self.imshapes
        )  # camera intrinsics
        self.im_pp = nn.ParameterList(
            torch.zeros((2,)) for _ in range(self.n_imgs)
        )  # camera intrinsics
        self.im_pp.requires_grad_(optimize_pp)

        self.imshape = self.imshapes[0]
        im_areas = [h * w for h, w in self.imshapes]
        self.max_area = max(im_areas)

        # adding thing to optimize
        # The per-image lists are replaced by single stacked tensors; depth
        # maps are flattened and zero-padded to `max_area` entries.
        self.im_depthmaps = ParameterStack(
            self.im_depthmaps, is_param=True, fill=self.max_area
        )
        self.im_poses = ParameterStack(self.im_poses, is_param=True)
        self.im_focals = ParameterStack(self.im_focals, is_param=True)
        self.im_pp = ParameterStack(self.im_pp, is_param=True)
        # Default principal point = image centre (x, y) for every image.
        self.register_buffer(
            "_pp", torch.tensor([(w / 2, h / 2) for h, w in self.imshapes])
        )
        self.register_buffer(
            "_grid",
            ParameterStack(
                [xy_grid(W, H, device=self.device) for H, W in self.imshapes],
                fill=self.max_area,
            ),
        )

        # pre-compute pixel weights
        self.register_buffer(
            "_weight_i",
            ParameterStack(
                [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges],
                fill=self.max_area,
            ),
        )
        self.register_buffer(
            "_weight_j",
            ParameterStack(
                [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges],
                fill=self.max_area,
            ),
        )

        # precompute the stacked pairwise predictions and the edge endpoint indices
        self.register_buffer(
            "_stacked_pred_i",
            ParameterStack(self.pred_i, self.str_edges, fill=self.max_area),
        )
        self.register_buffer(
            "_stacked_pred_j",
            ParameterStack(self.pred_j, self.str_edges, fill=self.max_area),
        )
        self.register_buffer("_ei", torch.tensor([i for i, j in self.edges]))
        self.register_buffer("_ej", torch.tensor([j for i, j in self.edges]))
        # Loss normalisers: total pixel count over all edges, per side.
        self.total_area_i = sum([im_areas[i] for i, j in self.edges])
        self.total_area_j = sum([im_areas[j] for i, j in self.edges])

    def _check_all_imgs_are_selected(self, msk):
        """Assert that ``msk`` selects every image, in order."""
        assert np.all(
            self._get_msk_indices(msk) == np.arange(self.n_imgs)
        ), "incomplete mask!"

    def preset_pose(self, known_poses, pose_msk=None):  # cam-to-world
        """Set the image poses from ``known_poses`` and freeze all of them.

        A single 2-D tensor is accepted and treated as a one-element list.
        """
        self._check_all_imgs_are_selected(pose_msk)

        if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:
            known_poses = [known_poses]
        for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
            if self.verbose:
                print(f" (setting pose #{idx} = {pose[:3,3]})")
            self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose)))

        # normalize scale if there's less than 1 known pose
        n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
        self.norm_pw_scale = n_known_poses <= 1

        self.im_poses.requires_grad_(False)
        # NOTE(review): this overwrites the value computed just above, so
        # `norm_pw_scale` is always False after this call — confirm intent.
        self.norm_pw_scale = False

    def preset_focal(self, known_focals, msk=None):
        """Set the focals from ``known_focals`` and freeze all of them."""
        self._check_all_imgs_are_selected(msk)

        for idx, focal in zip(self._get_msk_indices(msk), known_focals):
            if self.verbose:
                print(f" (setting focal #{idx} = {focal})")
            self._no_grad(self._set_focal(idx, focal))

        self.im_focals.requires_grad_(False)

    def preset_principal_point(self, known_pp, msk=None):
        """Set the principal points from ``known_pp`` and freeze all of them."""
        self._check_all_imgs_are_selected(msk)

        for idx, pp in zip(self._get_msk_indices(msk), known_pp):
            if self.verbose:
                print(f" (setting principal point #{idx} = {pp})")
            self._no_grad(self._set_principal_point(idx, pp))

        self.im_pp.requires_grad_(False)

    def _get_msk_indices(self, msk):
        """Normalise a mask specification into an iterable of image indices.

        Accepts ``None`` (all images), an int, a tuple/list (converted to a
        numpy array and re-dispatched), a boolean array (length must equal
        ``n_imgs``) or an integer array (returned as-is). Anything else
        raises ``ValueError``.
        """
        if msk is None:
            return range(self.n_imgs)
        elif isinstance(msk, int):
            return [msk]
        elif isinstance(msk, (tuple, list)):
            return self._get_msk_indices(np.array(msk))
        elif msk.dtype in (bool, torch.bool, np.bool_):
            assert len(msk) == self.n_imgs
            return np.where(msk)[0]
        elif np.issubdtype(msk.dtype, np.integer):
            return msk
        else:
            raise ValueError(f"bad {msk=}")

    def _no_grad(self, tensor):
        """Assert that ``tensor`` still requires grad.

        NOTE(review): despite its name, this only asserts; it does not change
        ``requires_grad``. Freezing is done by the ``preset_*`` callers.
        """
        assert (
            tensor.requires_grad
        ), "it must be True at this point, otherwise no modification occurs"

    def _set_focal(self, idx, focal, force=False):
        """Store ``focal_break * log(focal)`` for image ``idx``; return the parameter.

        The write only happens if the parameter requires grad or ``force`` is set.
        """
        param = self.im_focals[idx]
        if (
            param.requires_grad or force
        ):  # can only init a parameter not already initialized
            param.data[:] = self.focal_break * np.log(focal)
        return param

    def get_focals(self):
        """Return focals decoded from storage: ``exp(param / focal_break)``."""
        log_focals = torch.stack(list(self.im_focals), dim=0)
        return (log_focals / self.focal_break).exp()

    def get_known_focal_mask(self):
        """Return a boolean tensor marking the focal entries that do not require grad."""
        return torch.tensor([not (p.requires_grad) for p in self.im_focals])

    def _set_principal_point(self, idx, pp, force=False):
        """Store the principal point of image ``idx`` as ``(pp - centre) / 10``.

        The write only happens if the parameter requires grad or ``force`` is set.
        """
        param = self.im_pp[idx]
        H, W = self.imshapes[idx]
        if (
            param.requires_grad or force
        ):  # can only init a parameter not already initialized
            param.data[:] = to_cpu(to_numpy(pp) - (W / 2, H / 2)) / 10
        return param

    def get_principal_points(self):
        """Return principal points: image centre plus 10x the stored offset."""
        return self._pp + 10 * self.im_pp

    def get_intrinsics(self):
        """Build the ``(n_imgs, 3, 3)`` intrinsics matrices (fx == fy)."""
        K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
        focals = self.get_focals().flatten()
        K[:, 0, 0] = K[:, 1, 1] = focals
        K[:, :2, 2] = self.get_principal_points()
        K[:, 2, 2] = 1
        return K

    def get_im_poses(self):  # cam to world
        """Return the image poses decoded by the base-class ``_get_poses``."""
        cam2world = self._get_poses(self.im_poses)
        return cam2world

    def _set_depthmap(self, idx, depth, force=False):
        """Store ``log(depth)`` for image ``idx``, flattened and padded.

        ``-inf`` (log of zero) becomes 0 via ``nan_to_num``. The write only
        happens if the parameter requires grad or ``force`` is set.
        """
        depth = _ravel_hw(depth, self.max_area)

        param = self.im_depthmaps[idx]
        if (
            param.requires_grad or force
        ):  # can only init a parameter not already initialized
            param.data[:] = depth.log().nan_to_num(neginf=0)
        return param

    def get_depthmaps(self, raw=False):
        """Return depths (``exp`` of the stored log-depths).

        With ``raw=True`` the padded stacked tensor is returned; otherwise a
        list of ``(h, w)`` views, one per image, with padding removed.
        """
        res = self.im_depthmaps.exp()
        if not raw:
            res = [dm[: h * w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def depth_to_pts3d(self):
        """Back-project all depth maps to world-frame points (padded layout)."""
        # Get depths and projection params if not provided
        focals = self.get_focals()
        pp = self.get_principal_points()
        im_poses = self.get_im_poses()
        depth = self.get_depthmaps(raw=True)

        # get pointmaps in camera frame
        rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp)
        # project to world frame
        return geotrf(im_poses, rel_ptmaps)

    def get_pts3d(self, raw=False):
        """Return world-frame pointmaps.

        With ``raw=True`` the padded stacked tensor is returned; otherwise a
        list of ``(h, w, 3)`` views, one per image.
        """
        res = self.depth_to_pts3d()
        if not raw:
            res = [dm[: h * w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def forward(self):
        """Return the alignment loss between global pointmaps and pairwise predictions.

        The weighted distances on each side of every edge are summed and
        divided by that side's total pixel count, then the two sides are added.
        """
        pw_poses = self.get_pw_poses()  # cam-to-world
        pw_adapt = self.get_adaptors().unsqueeze(1)
        proj_pts3d = self.get_pts3d(raw=True)

        # rotate pairwise prediction according to pw_poses
        aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i)
        aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j)

        # compute the loss
        li = (
            self.dist(proj_pts3d[self._ei], aligned_pred_i, weight=self._weight_i).sum()
            / self.total_area_i
        )
        lj = (
            self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum()
            / self.total_area_j
        )

        return li + lj
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp):
|
| 254 |
+
pp = pp.unsqueeze(1)
|
| 255 |
+
focal = focal.unsqueeze(1)
|
| 256 |
+
assert focal.shape == (len(depth), 1, 1)
|
| 257 |
+
assert pp.shape == (len(depth), 1, 2)
|
| 258 |
+
assert pixel_grid.shape == depth.shape + (2,)
|
| 259 |
+
depth = depth.unsqueeze(-1)
|
| 260 |
+
return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def ParameterStack(params, keys=None, is_param=None, fill=0):
    """Stack a collection of tensors into one float tensor.

    Args:
        params: indexable collection of tensors; with ``keys`` it is indexed
            by each key in turn.
        keys: optional selection and ordering of entries from ``params``.
        is_param: if truthy, wrap the result in ``nn.Parameter``.
        fill: if positive, each entry is first flattened over its two leading
            dims and zero-padded to ``fill`` rows by ``_ravel_hw``.

    All entries must share the same ``requires_grad`` flag. The stacked tensor
    is detached, and becomes an ``nn.Parameter`` carrying that flag when
    ``is_param`` is truthy or the entries required grad.
    """
    if keys is not None:
        params = [params[key] for key in keys]

    if fill > 0:
        params = [_ravel_hw(entry, fill) for entry in params]

    wants_grad = params[0].requires_grad
    assert all(entry.requires_grad == wants_grad for entry in params)

    stacked = torch.stack(list(params)).float().detach()
    if not (is_param or wants_grad):
        return stacked
    stacked = nn.Parameter(stacked)
    stacked.requires_grad_(wants_grad)
    return stacked
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def _ravel_hw(tensor, fill=0):
|
| 281 |
+
# ravel H,W
|
| 282 |
+
tensor = tensor.view((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
|
| 283 |
+
|
| 284 |
+
if len(tensor) < fill:
|
| 285 |
+
tensor = torch.cat(
|
| 286 |
+
(tensor, tensor.new_zeros((fill - len(tensor),) + tensor.shape[1:]))
|
| 287 |
+
)
|
| 288 |
+
return tensor
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def acceptable_focal_range(H, W, minf=0.5, maxf=3.5):
    """Return ``(min_focal, max_focal)`` for an ``H x W`` image.

    The base focal is the one giving a 60-degree field of view along the
    larger image side (size / 1.1547005383792515); the bounds are ``minf``
    and ``maxf`` times that base.
    """
    half_fov = np.deg2rad(60) / 2
    base = max(H, W) / (2 * np.tan(half_fov))
    return minf * base, maxf * base
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def apply_mask(img, msk):
    """Return a copy of ``img`` with the entries selected by ``msk`` set to 0.

    The input array is left unmodified.
    """
    masked = img.copy()
    masked[msk] = 0
    return masked
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/init_all.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import cache
|
| 2 |
+
import numpy as np
|
| 3 |
+
import scipy.sparse as sp
|
| 4 |
+
import torch
|
| 5 |
+
import cv2
|
| 6 |
+
import roma
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
from cloud_opt.utils import *
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def compute_edge_scores(edges, edge2conf_i, edge2conf_j):
    """Score every edge from its two confidence maps.

    edges: iterable of ``('i_j', (i, j))`` pairs.
    Returns a dict mapping ``(i, j)`` to ``edge_conf`` of the two maps stored
    under the ``'i_j'`` key.
    """
    scores = {}
    for key, (i, j) in edges:
        scores[(i, j)] = edge_conf(edge2conf_i[key], edge2conf_j[key])
    return scores
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def dict_to_sparse_graph(dic):
    """Build a square sparse DOK array from an ``{(i, j): value}`` mapping.

    The array side is the largest index appearing in any key, plus one.
    """
    highest = max(max(pair) for pair in dic)
    graph = sp.dok_array((highest + 1, highest + 1))
    for pair in dic:
        graph[pair] = dic[pair]
    return graph
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@torch.no_grad()
def init_minimum_spanning_tree(self, **kw):
    """Init all camera poses (image-wise and pairwise poses) given
    an initial set of pairwise estimations.

    Runs ``minimum_spanning_tree`` on the optimizer's edges and predictions,
    then hands the pointmaps, focals and poses to ``init_from_pts3d``. The
    returned edge list is discarded. Extra keyword arguments are forwarded to
    ``minimum_spanning_tree``.
    """
    device = self.device
    mst_result = minimum_spanning_tree(
        self.imshapes,
        self.edges,
        self.edge2pts_i,
        self.edge2pts_j,
        self.edge2conf_i,
        self.edge2conf_j,
        self.im_conf,
        self.min_conf_thr,
        device,
        has_im_poses=self.has_im_poses,
        verbose=self.verbose,
        **kw,
    )
    pts3d, _, im_focals, im_poses = mst_result

    return init_from_pts3d(self, pts3d, im_focals, im_poses)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def minimum_spanning_tree(
    imshapes,
    edges,
    edge2pred_i,
    edge2pred_j,
    edge2conf_i,
    edge2conf_j,
    im_conf,
    min_conf_thr,
    device,
    has_im_poses=True,
    niter_PnP=10,
    verbose=True,
    save_score_path=None,
):
    """Initialise a global point cloud by walking a maximum-confidence spanning tree.

    Edge scores are negated so that SciPy's *minimum* spanning tree keeps the
    highest-confidence edges. Starting from the strongest edge, each remaining
    tree edge is attached by rigidly registering its prediction for the
    already-placed image onto that image's current points, then transforming
    the prediction for the other image with the same transform.

    Args:
        imshapes: per-image ``(H, W)`` shapes; its length is the image count.
        edges: iterable of ``(i, j)`` image index pairs.
        edge2pred_i, edge2pred_j: pointmaps keyed by the ``"i_j"`` edge string.
        edge2conf_i, edge2conf_j: confidence maps keyed by the same strings.
        im_conf: per-image confidence maps, used to mask points for PnP.
        min_conf_thr: confidence threshold for that mask.
        device: device for the created pose matrices.
        has_im_poses: when false, poses and focals are returned as ``None``.
        niter_PnP: RANSAC iteration count forwarded to ``fast_pnp``.
        verbose: print each edge as it is processed.
        save_score_path: accepted for interface compatibility; unused here.

    Returns:
        ``(pts3d, msp_edges, im_focals, im_poses)``.

    Raises:
        RuntimeError: if the remaining tree edges can never be attached
            (disconnected graph) instead of looping forever.
    """
    n_imgs = len(imshapes)
    edge_and_scores = compute_edge_scores(map(i_j_ij, edges), edge2conf_i, edge2conf_j)
    sparse_graph = -dict_to_sparse_graph(edge_and_scores)
    msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo()

    # temp variable to store 3d points
    pts3d = [None] * len(imshapes)

    todo = sorted(zip(-msp.data, msp.row, msp.col))  # sorted edges
    im_poses = [None] * n_imgs
    im_focals = [None] * n_imgs

    # init with strongest edge
    score, i, j = todo.pop()
    if verbose:
        print(f" init edge ({i}*,{j}*) {score=}")
    i_j = edge_str(i, j)

    pts3d[i] = edge2pred_i[i_j].clone()
    pts3d[j] = edge2pred_j[i_j].clone()
    done = {i, j}
    if has_im_poses:
        im_poses[i] = torch.eye(4, device=device)
        im_focals[i] = estimate_focal(edge2pred_i[i_j])

    # set initial pointcloud based on pairwise graph
    msp_edges = [(i, j)]
    deferred_in_a_row = 0  # consecutive edges pushed back without progress
    while todo:
        # each time, predict the next one
        score, i, j = todo.pop()
        # Key of the edge being processed NOW. The previous code reused the
        # `i_j` left over from the preceding edge when estimating the focal,
        # i.e. it read the pointmap of a different edge.
        i_j = edge_str(i, j)

        if im_focals[i] is None:
            im_focals[i] = estimate_focal(edge2pred_i[i_j])

        if i in done:
            if verbose:
                print(f" init edge ({i},{j}*) {score=}")
            assert j not in done
            # align pred[i] with pts3d[i], and then set j accordingly
            s, R, T = rigid_points_registration(
                edge2pred_i[i_j], pts3d[i], conf=edge2conf_i[i_j]
            )
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[j] = geotrf(trf, edge2pred_j[i_j])
            done.add(j)
            msp_edges.append((i, j))
            deferred_in_a_row = 0

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)

        elif j in done:
            if verbose:
                print(f" init edge ({i}*,{j}) {score=}")
            assert i not in done
            s, R, T = rigid_points_registration(
                edge2pred_j[i_j], pts3d[j], conf=edge2conf_j[i_j]
            )
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[i] = geotrf(trf, edge2pred_i[i_j])
            done.add(i)
            msp_edges.append((i, j))
            deferred_in_a_row = 0

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)
        else:
            # let's try again later
            todo.insert(0, (score, i, j))
            deferred_in_a_row += 1
            # Every remaining edge was deferred back-to-back: none of them
            # touches a placed image, so the loop would never terminate.
            if deferred_in_a_row >= len(todo):
                raise RuntimeError(
                    "minimum_spanning_tree: remaining edges are disconnected "
                    "from the initialised images"
                )

    if has_im_poses:
        # complete all missing informations
        pair_scores = list(
            sparse_graph.values()
        )  # already negative scores: less is best
        edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[
            np.argsort(pair_scores)
        ]
        for i, j in edges_from_best_to_worse.tolist():
            if im_focals[i] is None:
                im_focals[i] = estimate_focal(edge2pred_i[edge_str(i, j)])

        for i in range(n_imgs):
            if im_poses[i] is None:
                msk = im_conf[i] > min_conf_thr
                res = fast_pnp(
                    pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP
                )
                if res:
                    im_focals[i], im_poses[i] = res
            if im_poses[i] is None:
                # PnP failed or was skipped: fall back to the identity pose.
                im_poses[i] = torch.eye(4, device=device)
        im_poses = torch.stack(im_poses)
    else:
        im_poses = im_focals = None

    return pts3d, msp_edges, im_focals, im_poses
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def init_from_pts3d(self, pts3d, im_focals, im_poses):
    """Initialise the optimizer state from global pointmaps, focals and poses.

    Steps: (1) if several poses are already known, rigidly align the whole
    initialisation to them; (2) set every pairwise pose by registering the
    pairwise prediction onto the global points; (3) apply the pairwise scale
    normalisation factor; (4) when the optimizer has image poses, set
    depth maps, poses and focals.

    Args:
        pts3d: list of per-image pointmaps, modified in place.
        im_focals: per-image focals (entries may be ``None``), or ``None``.
        im_poses: ``(n_imgs, 4, 4)`` cam-to-world poses, or ``None`` when the
            optimizer has no image poses.

    Raises:
        NotImplementedError: when exactly one pose is known.
    """
    # init poses
    nkp, known_poses_msk, known_poses = self.get_known_poses()
    if nkp == 1:
        raise NotImplementedError(
            "Would be simpler to just align everything afterwards on the single known pose"
        )
    elif nkp > 1:
        # global rigid SE3 alignment
        s, R, T = align_multiple_poses(
            im_poses[known_poses_msk], known_poses[known_poses_msk]
        )
        trf = sRT_to_4x4(s, R, T, device=known_poses.device)

        # rotate everything
        im_poses = trf @ im_poses
        im_poses[:, :3, :3] /= s  # undo scaling on the rotation part
        for img_pts3d in pts3d:
            img_pts3d[:] = geotrf(trf, img_pts3d)
    else:
        pass  # no known poses

    # set all pairwise poses
    for e, (i, j) in enumerate(self.edges):
        i_j = edge_str(i, j)
        # compute transform that goes from cam to world
        s, R, T = rigid_points_registration(
            self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j]
        )
        self._set_pose(self.pw_poses, e, R, T, scale=s)

    # take into account the scale normalization
    s_factor = self.get_pw_norm_scale_factor()
    # im_poses is None when the optimizer has no image poses; the points
    # still have to be rescaled in that case.
    if im_poses is not None:
        im_poses[:, :3, 3] *= s_factor  # apply downscaling factor
    for img_pts3d in pts3d:
        img_pts3d *= s_factor

    # init all image poses
    if self.has_im_poses:
        for i in range(self.n_imgs):
            cam2world = im_poses[i]
            depth = geotrf(inv(cam2world), pts3d[i])[..., 2]
            self._set_depthmap(i, depth)
            self._set_pose(self.im_poses, i, cam2world)
            if im_focals[i] is not None:
                if not self.shared_focal:
                    self._set_focal(i, im_focals[i])
        if self.shared_focal:
            # NOTE(review): assumes every entry of im_focals is set; a None
            # entry makes this sum fail — confirm against callers.
            self._set_focal(0, sum(im_focals) / self.n_imgs)
        if self.n_imgs > 2:
            self._set_init_depthmap()

    if self.verbose:
        with torch.no_grad():
            print(" init loss =", float(self()))
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/utils.py
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
import roma
|
| 4 |
+
import numpy as np
|
| 5 |
+
import cv2
|
| 6 |
+
from functools import cache
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def todevice(batch, device, callback=None, non_blocking=False):
    """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).

    batch: list, tuple, dict of tensors or other things
    device: pytorch device or 'numpy'
    callback: optional function applied once to ``batch`` itself before any
        conversion. It is not re-applied to nested elements.
    non_blocking: forwarded to ``Tensor.to`` for every tensor, including the
        ones nested inside dicts, lists and tuples.

    Dicts, lists and tuples are rebuilt recursively with the same container
    type. With ``device == "numpy"`` tensors become numpy arrays; otherwise
    numpy arrays become tensors and tensors are moved to ``device``.
    ``None`` and non-tensor leaves are returned unchanged.
    """
    if callback:
        batch = callback(batch)

    # Recurse with `non_blocking` so the flag is not silently dropped for
    # tensors held inside containers.
    if isinstance(batch, dict):
        return {
            k: todevice(v, device, non_blocking=non_blocking)
            for k, v in batch.items()
        }

    if isinstance(batch, (tuple, list)):
        return type(batch)(
            todevice(x, device, non_blocking=non_blocking) for x in batch
        )

    x = batch
    if device == "numpy":
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x


to_device = todevice  # alias
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def to_numpy(x):
    """Convert every tensor in ``x`` (possibly nested) to a numpy array via ``todevice``."""
    target = "numpy"
    return todevice(x, target)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def to_cpu(x):
    """Move every tensor/array in ``x`` (possibly nested) to the CPU via ``todevice``."""
    target = "cpu"
    return todevice(x, target)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def to_cuda(x):
    """Move every tensor/array in ``x`` (possibly nested) to CUDA via ``todevice``."""
    target = "cuda"
    return todevice(x, target)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def signed_log1p(x):
    """Return ``sign(x) * log(1 + |x|)``, elementwise."""
    magnitude = torch.log1p(torch.abs(x))
    return torch.sign(x) * magnitude
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def l2_dist(a, b, weight):
    """Weighted squared Euclidean distance over the last dimension."""
    delta = a - b
    squared = delta.square().sum(dim=-1)
    return squared * weight


def l1_dist(a, b, weight):
    """Weighted (non-squared) Euclidean norm of ``a - b`` over the last dimension."""
    delta = a - b
    return delta.norm(dim=-1) * weight


# Lookup table from distance name to implementation.
ALL_DISTS = {"l1": l1_dist, "l2": l2_dist}
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _check_edges(edges):
|
| 69 |
+
indices = sorted({i for edge in edges for i in edge})
|
| 70 |
+
assert indices == list(range(len(indices))), "bad pair indices: missing values "
|
| 71 |
+
return len(indices)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def NoGradParamDict(x):
    """Wrap dict ``x`` in an ``nn.ParameterDict`` with gradients disabled."""
    assert isinstance(x, dict)
    frozen = nn.ParameterDict(x)
    return frozen.requires_grad_(False)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def edge_str(i, j):
    """Return the string key ``"i_j"`` identifying the edge ``(i, j)``."""
    return "{}_{}".format(i, j)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def i_j_ij(ij):
    """Pair an edge ``(i, j)`` with its string key: returns ``("i_j", ij)``."""
    # inputs are (i, j)
    key = edge_str(*ij)
    return key, ij
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def edge_conf(conf_i, conf_j):
    """Score an edge as the product of the two mean confidences, as a float."""
    mean_i = conf_i.mean()
    mean_j = conf_j.mean()
    return float(mean_i * mean_j)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def get_imshapes(edges, pred_i, pred_j):
    """Collect the (H, W) shape of every image referenced by `edges`.

    For edge number e = (i, j), the shape of image i is read from the first two
    dimensions of ``pred_i[e]["pts3d_is_self_view"]`` and the shape of image j
    from ``pred_j[e]["pts3d_in_other_view"]``.

    Returns a list indexed by image id. Raises AssertionError when an image
    shows up with two different shapes.
    """
    n_imgs = 1 + max(max(edge) for edge in edges)
    imshapes = [None for _ in range(n_imgs)]
    for edge_idx, (i, j) in enumerate(edges):
        hw_i = tuple(pred_i[edge_idx]["pts3d_is_self_view"].shape[:2])
        hw_j = tuple(pred_j[edge_idx]["pts3d_in_other_view"].shape[:2])
        # a shape recorded by an earlier edge must agree with this one
        assert not imshapes[i] or imshapes[i] == hw_i, f"incorrect shape for image {i}"
        assert not imshapes[j] or imshapes[j] == hw_j, f"incorrect shape for image {j}"
        imshapes[i], imshapes[j] = hw_i, hw_j
    return imshapes
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def get_conf_trf(mode):
    """Return the confidence transform selected by `mode`.

    Supported modes: "log" (x.log()), "sqrt" (x.sqrt()), "m1" (x - 1) and
    "id" / "none" (identity). Any other mode raises ValueError.
    """
    if mode == "log":

        def conf_trf(x):
            return x.log()

        return conf_trf

    if mode == "sqrt":

        def conf_trf(x):
            return x.sqrt()

        return conf_trf

    if mode == "m1":

        def conf_trf(x):
            return x - 1

        return conf_trf

    if mode in ("id", "none"):

        def conf_trf(x):
            return x

        return conf_trf

    raise ValueError(f"bad mode for {mode=}")
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@torch.no_grad()
def _compute_img_conf(imshapes, device, edges, edge2conf_i, edge2conf_j):
    """Build per-image confidence maps as the element-wise max over edges.

    One zero map is created per entry of `imshapes`. For each edge (i, j),
    image i's map is maximised with ``edge2conf_i[edge_str(i, j)]`` and
    image j's map with ``edge2conf_j[edge_str(i, j)]``.
    """
    zero_maps = [torch.zeros(hw, device=device) for hw in imshapes]
    im_conf = nn.ParameterList(zero_maps)
    for i, j in edges:
        key = edge_str(i, j)
        im_conf[i] = torch.maximum(im_conf[i], edge2conf_i[key])
        im_conf[j] = torch.maximum(im_conf[j], edge2conf_j[key])
    return im_conf
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def xy_grid(
    W,
    H,
    device=None,
    origin=(0, 0),
    unsqueeze=None,
    cat_dim=-1,
    homogeneous=False,
    **arange_kw,
):
    """Build a pixel-coordinate grid of shape (H, W, 2) with
        output[j,i,0] = i + origin[0]
        output[j,i,1] = j + origin[1]

    Args:
        W, H: grid width and height.
        device: None selects the NumPy backend; anything else is passed to
            torch as the device of the created tensors.
        origin: (x, y) offset added to the coordinates.
        unsqueeze: if not None, axis at which a new dimension is inserted in
            every channel before stacking.
        cat_dim: axis along which the channels are stacked; None returns the
            tuple of separate channels instead.
        homogeneous: if True, append a channel of ones of shape (H, W).
        **arange_kw: forwarded to ``arange`` (e.g. ``dtype``).

    Returns:
        The stacked grid, or a tuple of channels when `cat_dim` is None.
    """
    if device is None:
        # numpy backend
        arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
        expand = np.expand_dims
    else:
        # torch backend
        def arange(*a, **kw):
            return torch.arange(*a, device=device, **kw)

        def ones(*a):
            return torch.ones(*a, device=device)

        meshgrid, stack = torch.meshgrid, torch.stack
        expand = torch.unsqueeze

    tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
    # np.meshgrid returns a list on NumPy < 2.0; normalise to a tuple so that
    # appending the homogeneous channel works on both backends.
    grid = tuple(meshgrid(tw, th, indexing="xy"))
    if homogeneous:
        grid = grid + (ones((H, W)),)
    if unsqueeze is not None:
        # expand every channel (including the homogeneous one, if any)
        grid = tuple(expand(g, unsqueeze) for g in grid)
    if cat_dim is not None:
        grid = stack(grid, cat_dim)
    return grid
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def estimate_focal_knowing_depth(
    pts3d, pp, focal_mode="median", min_focal=0.0, max_focal=np.inf
):
    """Estimate one focal length per batch item from a point map.

    Args:
        pts3d: tensor of shape (B, H, W, 3).
        pp: principal point(s), viewed as (-1, 1, 2) and subtracted from the
            pixel grid.
        focal_mode: "median" takes the nan-median of the per-pixel votes
            u*z/x and v*z/y (computed without gradients); "weiszfeld" starts
            from the least-squares closed form and runs 10 iterations of
            re-weighted least squares on |pixel - focal * (x, y) / z|.
        min_focal, max_focal: clipping bounds, expressed as multiples of
            ``focal_base = max(H, W) / (2 * tan(60deg / 2))``.

    Returns:
        The clipped focal estimate, one value per batch item.

    Raises:
        ValueError: if `focal_mode` is neither "median" nor "weiszfeld".
    """
    B, H, W, THREE = pts3d.shape
    assert THREE == 3

    # centered pixel grid
    pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(
        -1, 1, 2
    )  # B,HW,2
    pts3d = pts3d.flatten(1, 2)  # (B, HW, 3)

    if focal_mode == "median":
        with torch.no_grad():
            # direct estimation of focal
            u, v = pixels.unbind(dim=-1)
            x, y, z = pts3d.unbind(dim=-1)
            fx_votes = (u * z) / x
            fy_votes = (v * z) / y

            # assume square pixels, hence same focal for X and Y
            f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1)
            focal = torch.nanmedian(f_votes, dim=-1).values

    elif focal_mode == "weiszfeld":
        # init focal with l2 closed form
        # we try to find focal = argmin Sum | pixel - focal * (x,y)/z|
        xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(
            posinf=0, neginf=0
        )  # homogeneous (x,y,1)

        dot_xy_px = (xy_over_z * pixels).sum(dim=-1)
        dot_xy_xy = xy_over_z.square().sum(dim=-1)

        focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1)

        # iterative re-weighted least-squares
        for _ in range(10):
            # re-weighting by inverse of distance
            dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1)
            w = dis.clip(min=1e-8).reciprocal()
            # update the scaling with the new weights
            focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1)
    else:
        raise ValueError(f"bad {focal_mode=}")

    focal_base = max(H, W) / (
        2 * np.tan(np.deg2rad(60) / 2)
    )  # size / 1.1547005383792515
    focal = focal.clip(min=min_focal * focal_base, max=max_focal * focal_base)
    return focal
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def estimate_focal(pts3d_i, pp=None):
    """Estimate a single focal length (as a float) from one (H, W, 3) point map.

    When `pp` is None the principal point defaults to the image centre
    (W / 2, H / 2). The estimate comes from `estimate_focal_knowing_depth`
    in "weiszfeld" mode on a batch of one.
    """
    if pp is None:
        H, W, THREE = pts3d_i.shape
        assert THREE == 3
        pp = torch.tensor((W / 2, H / 2), device=pts3d_i.device)
    batched_pts = pts3d_i.unsqueeze(0)
    batched_pp = pp.unsqueeze(0)
    focal = estimate_focal_knowing_depth(
        batched_pts, batched_pp, focal_mode="weiszfeld"
    )
    return float(focal.ravel())
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def rigid_points_registration(pts1, pts2, conf):
    """Register `pts1` onto `pts2` with roma, weighting points by `conf`.

    Both point sets are flattened to (N, 3) and `conf` to (N,). Scaling is
    estimated as well. Note the return order (s, R, T) differs from roma's
    (R, T, s).
    """
    src = pts1.reshape(-1, 3)
    dst = pts2.reshape(-1, 3)
    rotation, translation, scale = roma.rigid_points_registration(
        src, dst, weights=conf.ravel(), compute_scaling=True
    )
    return scale, rotation, translation  # return un-scaled (R, T)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def sRT_to_4x4(scale, R, T, device):
    """Assemble a 4x4 transform on `device` from a scale, rotation and translation.

    The upper-left 3x3 block is ``R * scale``; the last column holds `T`
    unchanged.
    """
    matrix = torch.eye(4, device=device)
    matrix[:3, 3] = T.ravel()  # doesn't need scaling
    matrix[:3, :3] = R * scale
    return matrix
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def geotrf(Trf, pts, ncol=None, norm=False):
    """Apply a geometric transformation to a set of points.

    Trf: numpy array or torch tensor with at least 2 dimensions; the last two
         are the transformation matrix (e.g. 3x3 or 4x4), any leading ones
         are batch dimensions.
    pts: numpy/torch/tuple of coordinates, converted to Trf's array type.
         The last dimension is the coordinate dimension.

    ncol: int. number of columns of the result; defaults to pts.shape[-1].
    norm: float. if != 0, the result is divided by its last coordinate and,
          when norm != 1, multiplied by norm (projection on the z=norm plane).

    Returns the transformed points, keeping the leading shape of `pts` and
    the first `ncol` coordinates.
    """
    assert Trf.ndim >= 2
    # make pts the same array type as Trf
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # adapt shape if necessary
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # optimized code: batched torch matrices (B,d,d) applied to (B,H,W,d) points
    if (
        isinstance(Trf, torch.Tensor)
        and isinstance(pts, torch.Tensor)
        and Trf.ndim == 3
        and pts.ndim == 4
    ):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            # matrix is one larger than the points: linear part + translation column
            pts = (
                torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
                + Trf[:, None, None, :d, d]
            )
        else:
            raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
    else:
        if Trf.ndim >= 3:
            # flatten all batch dimensions of Trf into one
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1] + 1 == Trf.shape[-1]:
            # points are one shorter than the matrix: linear part + translation
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    # restore the leading shape of the input points
    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def inv(mat):
    """Invert a torch or numpy matrix.

    Raises ValueError for any other input type.
    """
    backends = (
        (torch.Tensor, torch.linalg.inv),
        (np.ndarray, np.linalg.inv),
    )
    for array_type, invert in backends:
        if isinstance(mat, array_type):
            return invert(mat)
    raise ValueError(f"bad matrix type = {type(mat)}")
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
@cache
def pixel_grid(H, W):
    """Return a float32 array of shape (H, W, 2) with grid[j, i] == (i, j).

    The result is memoised by the `cache` decorator, so repeated calls with
    the same (H, W) hand back the same array object; callers should not
    modify it in place.
    """
    coords = np.mgrid[0:W, 0:H]  # (2, W, H): x index first, then y index
    return coords.transpose().astype(np.float32)
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
    """Extract a camera pose (and focal) with RANSAC-PnP.

    Args:
        pts3d: (H, W, 3) point map.
        focal: known focal, or None to try 21 geometrically spaced candidates
            between max(W, H) / 2 and max(W, H) * 3.
        msk: mask selecting the points used for PnP.
        device: device of the returned pose matrix.
        pp: principal point; defaults to the image centre (W / 2, H / 2).
        niter_PnP: RANSAC iteration count passed to OpenCV.

    Returns:
        None if fewer than 4 points are selected or no candidate succeeds,
        otherwise ``(best_focal, cam_to_world)`` for the candidate with the
        most inliers.
    """
    # we need at least 4 points for PnP
    if msk.sum() < 4:
        return None
    pts3d, msk = map(to_numpy, (pts3d, msk))

    H, W, THREE = pts3d.shape
    assert THREE == 3
    pixels = pixel_grid(H, W)

    if focal is None:
        S = max(W, H)
        tentative_focals = np.geomspace(S / 2, S * 3, 21)
    else:
        tentative_focals = [focal]

    pp = (W / 2, H / 2) if pp is None else to_numpy(pp)

    best = (0,)
    for candidate in tentative_focals:
        K = np.float32([(candidate, 0, pp[0]), (0, candidate, pp[1]), (0, 0, 1)])

        success, R, T, inliers = cv2.solvePnPRansac(
            pts3d[msk],
            pixels[msk],
            K,
            None,
            iterationsCount=niter_PnP,
            reprojectionError=5,
            flags=cv2.SOLVEPNP_SQPNP,
        )
        if not success:
            continue

        # keep the candidate supported by the largest number of inliers
        score = len(inliers)
        if score > best[0]:
            best = score, R, T, candidate

    if not best[0]:
        return None

    _, R, T, best_focal = best
    R = cv2.Rodrigues(R)[0]  # world to cam
    R, T = map(torch.from_numpy, (R, T))
    world_to_cam = sRT_to_4x4(1, R, T, device)
    return best_focal, inv(world_to_cam)  # cam to world
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def get_med_dist_between_poses(poses):
    """Return the median pairwise distance between the pose translations.

    The translation of each pose is read from ``p[:3, 3]``.
    """
    from scipy.spatial.distance import pdist

    centers = [to_numpy(pose[:3, 3]) for pose in poses]
    pairwise = pdist(centers)
    return np.median(pairwise)
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def align_multiple_poses(src_poses, target_poses):
    """Estimate the similarity (s, R, T) aligning `src_poses` to `target_poses`.

    Both inputs must have shape (N, 4, 4). Each pose contributes two points to
    the registration: its translation, and that translation shifted along the
    pose's third rotation column by 1/100 of the median inter-pose distance.
    """
    N = len(src_poses)
    assert src_poses.shape == target_poses.shape == (N, 4, 4)

    def center_and_z(poses):
        eps = get_med_dist_between_poses(poses) / 100
        centers = poses[:, :3, 3]
        shifted = poses[:, :3, 3] + eps * poses[:, :3, 2]
        return torch.cat((centers, shifted))

    src_pts = center_and_z(src_poses)
    tgt_pts = center_and_z(target_poses)
    R, T, s = roma.rigid_points_registration(src_pts, tgt_pts, compute_scaling=True)
    return s, R, T
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def cosine_schedule(t, lr_start, lr_end):
    """Cosine interpolation from `lr_start` (t=0) to `lr_end` (t=1).

    `t` must lie in [0, 1] (checked with an assert).
    """
    assert 0 <= t <= 1
    span = lr_start - lr_end
    return lr_end + span * (1 + np.cos(t * np.pi)) / 2
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def linear_schedule(t, lr_start, lr_end):
    """Linear interpolation from `lr_start` (t=0) to `lr_end` (t=1).

    `t` must lie in [0, 1] (checked with an assert).
    """
    assert 0 <= t <= 1
    delta = lr_end - lr_start
    return lr_start + delta * t
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def cycled_linear_schedule(t, lr_start, lr_end, num_cycles=2):
    """Repeat `linear_schedule` `num_cycles` times over t in [0, 1].

    The position within the current cycle is the fractional part of
    ``t * num_cycles``. At exactly t == 1 the schedule ends on `lr_end`
    instead of wrapping back to `lr_start`.
    """
    assert 0 <= t <= 1
    if t == 1:
        # final step: finish the last cycle rather than restart a new one
        phase = 1
    else:
        scaled = t * num_cycles
        phase = scaled - int(scaled)
    return linear_schedule(phase, lr_start, lr_end)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def adjust_learning_rate_by_lr(optimizer, lr):
    """Set the learning rate of every param group of `optimizer` to `lr`.

    Groups that define an "lr_scale" entry get ``lr * lr_scale`` instead.
    """
    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/deepspeed_zero3_bf16.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bf16": {
|
| 3 |
+
"enabled": true
|
| 4 |
+
},
|
| 5 |
+
"gradient_accumulation_steps": 1,
|
| 6 |
+
"gradient_clipping": 1.0,
|
| 7 |
+
"train_micro_batch_size_per_gpu": 1,
|
| 8 |
+
"steps_per_print": 2000,
|
| 9 |
+
"wall_clock_breakdown": false,
|
| 10 |
+
"zero_optimization": {
|
| 11 |
+
"stage": 3,
|
| 12 |
+
"overlap_comm": true,
|
| 13 |
+
"contiguous_gradients": true,
|
| 14 |
+
"reduce_bucket_size": 50000000,
|
| 15 |
+
"stage3_prefetch_bucket_size": 5000000,
|
| 16 |
+
"stage3_param_persistence_threshold": 100000,
|
| 17 |
+
"gather_16bit_weights_on_model_save": true
|
| 18 |
+
}
|
| 19 |
+
}
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/finetune.yaml
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accum_iter: 1
|
| 2 |
+
allow_repeat: false
|
| 3 |
+
amp: 1
|
| 4 |
+
batch_size: 1
|
| 5 |
+
benchmark: false
|
| 6 |
+
custom_lr_scale: 1.0
|
| 7 |
+
dataset_arkit: ARKitScenes_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="/var/scratch/qzhang2/SLAM-Former/data/train/processed_arkitscenes/",
|
| 8 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 9 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
|
| 10 |
+
dataset_mvs_synth: MVS_Synth_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="/var/scratch/qzhang2/SLAM-Former/data/train/processed_mvs_synth",
|
| 11 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 12 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
|
| 13 |
+
dataset_scannetpp: ScanNetpp_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="/var/scratch/qzhang2/SLAM-Former/data/train/processed_scannetpp/",
|
| 14 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 15 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
|
| 16 |
+
desc_dim: 128
|
| 17 |
+
detach_frontend_tokens: false
|
| 18 |
+
dist_backend: nccl
|
| 19 |
+
dist_url: env://
|
| 20 |
+
distributed: false
|
| 21 |
+
enable_dynamic_boundary: false
|
| 22 |
+
enable_loop: true
|
| 23 |
+
enable_submap: true
|
| 24 |
+
enable_temporal: false
|
| 25 |
+
epochs: 10
|
| 26 |
+
eval_freq: 1
|
| 27 |
+
exp_name: submap_joint_softall_64_24_v1
|
| 28 |
+
fixed_length: true
|
| 29 |
+
freeze_encoder: true
|
| 30 |
+
gpu: 0
|
| 31 |
+
gradient_checkpointing: true
|
| 32 |
+
gumbel_tau: 5.0
|
| 33 |
+
loop_mask_mode: soft_all
|
| 34 |
+
retain_history_grad: true
|
| 35 |
+
submap_train_mode: full_token
|
| 36 |
+
submap_retrieval_topk: 0
|
| 37 |
+
submap_fetch_source: frontend
|
| 38 |
+
submap_descriptor_source: frontend
|
| 39 |
+
train_submap_modules_only: false
|
| 40 |
+
gumbel_tau_end: 0.1
|
| 41 |
+
gumbel_tau_start: 5.0
|
| 42 |
+
hydra:
|
| 43 |
+
run:
|
| 44 |
+
dir: ${save_dir}/${exp_name}
|
| 45 |
+
verbose: true
|
| 46 |
+
keep_freq: 1
|
| 47 |
+
load_only_encoder: false
|
| 48 |
+
local-rank: -1
|
| 49 |
+
logdir: ${save_dir}/${exp_name}/logs
|
| 50 |
+
long_context: false
|
| 51 |
+
lr: 1e-5
|
| 52 |
+
max_checkpoints: 10
|
| 53 |
+
max_recursive_submaps: 5
|
| 54 |
+
min_lr: 1e-8
|
| 55 |
+
n_corres_test: 0
|
| 56 |
+
n_corres_train: 0
|
| 57 |
+
num_imgs_vis: 4
|
| 58 |
+
num_test_views: 4
|
| 59 |
+
num_views: 24
|
| 60 |
+
num_workers: 4
|
| 61 |
+
output_dir: ${save_dir}/${exp_name}/
|
| 62 |
+
pretrained: /var/scratch/qzhang2/SLAM-Former/ckpt/checkpoint-10.pth.model
|
| 63 |
+
print_freq: 10
|
| 64 |
+
print_img_freq: 50000000
|
| 65 |
+
rank: 0
|
| 66 |
+
resume: null
|
| 67 |
+
retention_ratio: 0.5
|
| 68 |
+
pseudo_gt:
|
| 69 |
+
enable: false
|
| 70 |
+
cache_path: null
|
| 71 |
+
use_soft_targets: true
|
| 72 |
+
min_confidence: 0.65
|
| 73 |
+
min_support_pairs: 1
|
| 74 |
+
topk_pairs: 4
|
| 75 |
+
loss_type: hybrid
|
| 76 |
+
loss_weight_gate: 0.1
|
| 77 |
+
loss_weight_desc: 0.1
|
| 78 |
+
geometric_support_scale: 0.25
|
| 79 |
+
ranking_margin: 0.1
|
| 80 |
+
use_l2m: false
|
| 81 |
+
l2m_min_certainty: 0.0
|
| 82 |
+
l2m_min_inlier_ratio: 0.0
|
| 83 |
+
save_dir: /var/scratch/qzhang2/SLAM-Former/checkpoints
|
| 84 |
+
save_freq: 0.1
|
| 85 |
+
seed: 42
|
| 86 |
+
soft_mask_bias: 0.2
|
| 87 |
+
soft_mask_temperature: 0.25
|
| 88 |
+
start_epoch: 0
|
| 89 |
+
start_step: 0
|
| 90 |
+
submap_size: 6
|
| 91 |
+
task: SLAMFormer_Submap_Finetune
|
| 92 |
+
tbptt_window: 0
|
| 93 |
+
teacher: null
|
| 94 |
+
temporal_embed_mode: learned
|
| 95 |
+
test_criterion: DistillLoss()
|
| 96 |
+
test_dataset: 500 @ ARKitScenes_Multi(split='test', ROOT="/var/scratch/qzhang2/SLAM-Former/data/train/processed_arkitscenes/",
|
| 97 |
+
resolution=(518, 392), num_views=${num_test_views}, seed=42, n_corres=${n_corres_test})
|
| 98 |
+
train_criterion: DistillLoss()
|
| 99 |
+
train_dataset: 2250 @ ${dataset_scannetpp} + 450 @ ${dataset_mvs_synth} + 4500 @ ${dataset_arkit}
|
| 100 |
+
warmup_epochs: 0.5
|
| 101 |
+
weight_decay: 0.05
|
| 102 |
+
world_size: 1
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/finetune_paper_h20.yaml
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accum_iter: 1
|
| 2 |
+
allow_repeat: false
|
| 3 |
+
amp: 1
|
| 4 |
+
batch_size: 1
|
| 5 |
+
benchmark: false
|
| 6 |
+
custom_lr_scale: 1.0
|
| 7 |
+
data_root: /var/scratch/qzhang2/SLAM-Former/data/train
|
| 8 |
+
root_arkit: ${data_root}/processed_arkitscenes
|
| 9 |
+
root_scannetpp: ${data_root}/processed_scannetpp
|
| 10 |
+
root_scannet: ${data_root}/processed_scannet
|
| 11 |
+
root_hypersim: ${data_root}/hypersim
|
| 12 |
+
root_blendedmvs: ${data_root}/processed_blendedmvs
|
| 13 |
+
root_megadepth: ${data_root}/processed_megadepth
|
| 14 |
+
root_mvs_synth: ${data_root}/processed_mvs_synth
|
| 15 |
+
dataset_arkit: ARKitScenes_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_arkit}",
|
| 16 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 17 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views_arkit}, n_corres=${n_corres_train})
|
| 18 |
+
dataset_scannetpp: ScanNetpp_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_scannetpp}",
|
| 19 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 20 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views_scannetpp}, n_corres=${n_corres_train})
|
| 21 |
+
dataset_scannet: ScanNet_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_scannet}",
|
| 22 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 23 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views_scannet}, n_corres=${n_corres_train})
|
| 24 |
+
dataset_hypersim: HyperSim_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_hypersim}",
|
| 25 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 26 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views_hypersim}, n_corres=${n_corres_train})
|
| 27 |
+
dataset_blendedmvs: BlendedMVS_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_blendedmvs}",
|
| 28 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 29 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views_blendedmvs}, n_corres=${n_corres_train})
|
| 30 |
+
dataset_megadepth: MegaDepth_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_megadepth}",
|
| 31 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 32 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views_megadepth}, n_corres=${n_corres_train})
|
| 33 |
+
dataset_mvs_synth: MVS_Synth_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_mvs_synth}",
|
| 34 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 35 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views_mvs_synth}, n_corres=${n_corres_train})
|
| 36 |
+
desc_dim: 128
|
| 37 |
+
detach_frontend_tokens: false
|
| 38 |
+
dist_backend: nccl
|
| 39 |
+
dist_url: env://
|
| 40 |
+
distributed: false
|
| 41 |
+
enable_dynamic_boundary: false
|
| 42 |
+
enable_loop: true
|
| 43 |
+
enable_submap: true
|
| 44 |
+
enable_temporal: false
|
| 45 |
+
epochs: 10
|
| 46 |
+
eval_freq: 1
|
| 47 |
+
exp_name: paper_all_h20_joint
|
| 48 |
+
fixed_length: true
|
| 49 |
+
freeze_encoder: true
|
| 50 |
+
gpu: 0
|
| 51 |
+
gradient_checkpointing: true
|
| 52 |
+
gumbel_tau: 5.0
|
| 53 |
+
loop_mask_mode: soft_all
|
| 54 |
+
retain_history_grad: true
|
| 55 |
+
submap_train_mode: full_token
|
| 56 |
+
submap_retrieval_topk: 0
|
| 57 |
+
submap_fetch_source: frontend
|
| 58 |
+
submap_descriptor_source: frontend
|
| 59 |
+
train_submap_modules_only: false
|
| 60 |
+
gumbel_tau_end: 0.1
|
| 61 |
+
gumbel_tau_start: 5.0
|
| 62 |
+
hydra:
|
| 63 |
+
run:
|
| 64 |
+
dir: ${save_dir}/${exp_name}
|
| 65 |
+
verbose: true
|
| 66 |
+
keep_freq: 1
|
| 67 |
+
load_only_encoder: false
|
| 68 |
+
local-rank: -1
|
| 69 |
+
logdir: ${save_dir}/${exp_name}/logs
|
| 70 |
+
long_context: false
|
| 71 |
+
lr: 1e-5
|
| 72 |
+
max_checkpoints: 10
|
| 73 |
+
max_recursive_submaps: 5
|
| 74 |
+
min_lr: 1e-8
|
| 75 |
+
n_corres_test: 0
|
| 76 |
+
n_corres_train: 0
|
| 77 |
+
num_imgs_vis: 4
|
| 78 |
+
num_test_views: 4
|
| 79 |
+
num_views: 24
|
| 80 |
+
num_views_arkit: 24
|
| 81 |
+
num_views_scannetpp: 24
|
| 82 |
+
num_views_scannet: 24
|
| 83 |
+
num_views_hypersim: 24
|
| 84 |
+
num_views_blendedmvs: 24
|
| 85 |
+
num_views_megadepth: 24
|
| 86 |
+
num_views_mvs_synth: 24
|
| 87 |
+
num_workers: 4
|
| 88 |
+
output_dir: ${save_dir}/${exp_name}/
|
| 89 |
+
pretrained: /var/scratch/qzhang2/SLAM-Former/ckpt/checkpoint-10.pth.model
|
| 90 |
+
print_freq: 10
|
| 91 |
+
print_img_freq: 50000000
|
| 92 |
+
rank: 0
|
| 93 |
+
resume: null
|
| 94 |
+
retention_ratio: 0.5
|
| 95 |
+
pseudo_gt:
|
| 96 |
+
enable: false
|
| 97 |
+
cache_path: null
|
| 98 |
+
use_soft_targets: true
|
| 99 |
+
min_confidence: 0.65
|
| 100 |
+
min_support_pairs: 1
|
| 101 |
+
topk_pairs: 4
|
| 102 |
+
loss_type: hybrid
|
| 103 |
+
loss_weight_gate: 0.1
|
| 104 |
+
loss_weight_desc: 0.1
|
| 105 |
+
geometric_support_scale: 0.25
|
| 106 |
+
ranking_margin: 0.1
|
| 107 |
+
use_l2m: false
|
| 108 |
+
l2m_min_certainty: 0.0
|
| 109 |
+
l2m_min_inlier_ratio: 0.0
|
| 110 |
+
save_dir: /var/scratch/qzhang2/SLAM-Former/checkpoints
|
| 111 |
+
save_freq: 0.1
|
| 112 |
+
seed: 42
|
| 113 |
+
soft_mask_bias: 0.2
|
| 114 |
+
soft_mask_temperature: 0.25
|
| 115 |
+
start_epoch: 0
|
| 116 |
+
start_step: 0
|
| 117 |
+
submap_size: 6
|
| 118 |
+
task: SLAMFormer_Submap_Finetune
|
| 119 |
+
tbptt_window: 0
|
| 120 |
+
teacher: null
|
| 121 |
+
temporal_embed_mode: learned
|
| 122 |
+
test_criterion: DistillLoss()
|
| 123 |
+
test_dataset: 500 @ ARKitScenes_Multi(split='test', ROOT="${root_arkit}",
|
| 124 |
+
resolution=(518, 392), num_views=${num_test_views}, seed=42, n_corres=${n_corres_test})
|
| 125 |
+
train_criterion: DistillLoss()
|
| 126 |
+
train_dataset: 4500 @ ${dataset_arkit} + 2250 @ ${dataset_scannetpp} + 4500 @ ${dataset_scannet} + 1200 @ ${dataset_hypersim} + 2250 @ ${dataset_blendedmvs} + 2250 @ ${dataset_megadepth} + 450 @ ${dataset_mvs_synth}
|
| 127 |
+
warmup_epochs: 0.5
|
| 128 |
+
weight_decay: 0.05
|
| 129 |
+
world_size: 1
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/finetune_pseudo_gt_high_recall.yaml
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accum_iter: 1
|
| 2 |
+
allow_repeat: false
|
| 3 |
+
amp: 1
|
| 4 |
+
batch_size: 1
|
| 5 |
+
benchmark: false
|
| 6 |
+
custom_lr_scale: 1.0
|
| 7 |
+
data_root: /var/scratch/qzhang2/SLAM-Former/data/train
|
| 8 |
+
root_arkit: ${data_root}/processed_arkitscenes
|
| 9 |
+
root_scannetpp: ${data_root}/processed_scannetpp
|
| 10 |
+
root_scannet: ${data_root}/processed_scannetv2
|
| 11 |
+
root_hypersim: ${data_root}/hypersim
|
| 12 |
+
root_blendedmvs: ${data_root}/processed_blendedmvs
|
| 13 |
+
root_megadepth: ${data_root}/processed_megadepth
|
| 14 |
+
root_mvs_synth: ${data_root}/processed_mvs_synth
|
| 15 |
+
dataset_arkit: ARKitScenes_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_arkit}",
|
| 16 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 17 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views_arkit}, n_corres=${n_corres_train})
|
| 18 |
+
dataset_scannetpp: ScanNetpp_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_scannetpp}",
|
| 19 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 20 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views_scannetpp}, n_corres=${n_corres_train})
|
| 21 |
+
dataset_scannet: ScanNet_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_scannet}",
|
| 22 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 23 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views_scannet}, n_corres=${n_corres_train})
|
| 24 |
+
dataset_hypersim: HyperSim_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_hypersim}",
|
| 25 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 26 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views_hypersim}, n_corres=${n_corres_train})
|
| 27 |
+
dataset_blendedmvs: BlendedMVS_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_blendedmvs}",
|
| 28 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 29 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views_blendedmvs}, n_corres=${n_corres_train})
|
| 30 |
+
dataset_megadepth: MegaDepth_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_megadepth}",
|
| 31 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 32 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views_megadepth}, n_corres=${n_corres_train})
|
| 33 |
+
dataset_mvs_synth: MVS_Synth_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_mvs_synth}",
|
| 34 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 35 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views_mvs_synth}, n_corres=${n_corres_train})
|
| 36 |
+
desc_dim: 128
|
| 37 |
+
detach_frontend_tokens: true
|
| 38 |
+
dist_backend: nccl
|
| 39 |
+
dist_url: env://
|
| 40 |
+
distributed: false
|
| 41 |
+
enable_dynamic_boundary: false
|
| 42 |
+
enable_loop: true
|
| 43 |
+
enable_submap: true
|
| 44 |
+
enable_temporal: false
|
| 45 |
+
epochs: 10
|
| 46 |
+
eval_freq: 1
|
| 47 |
+
exp_name: paper_all_h20_pseudo_gt_high_recall
|
| 48 |
+
fixed_length: true
|
| 49 |
+
freeze_encoder: true
|
| 50 |
+
gpu: 0
|
| 51 |
+
gradient_checkpointing: true
|
| 52 |
+
gumbel_tau: 5.0
|
| 53 |
+
loop_mask_mode: soft_all
|
| 54 |
+
retain_history_grad: true
|
| 55 |
+
submap_train_mode: full_token
|
| 56 |
+
submap_retrieval_topk: 0
|
| 57 |
+
submap_fetch_source: frontend
|
| 58 |
+
submap_descriptor_source: frontend
|
| 59 |
+
train_submap_modules_only: true
|
| 60 |
+
gumbel_tau_end: 0.1
|
| 61 |
+
gumbel_tau_start: 5.0
|
| 62 |
+
hydra:
|
| 63 |
+
run:
|
| 64 |
+
dir: ${save_dir}/${exp_name}
|
| 65 |
+
verbose: true
|
| 66 |
+
keep_freq: 1
|
| 67 |
+
load_only_encoder: false
|
| 68 |
+
local-rank: -1
|
| 69 |
+
logdir: ${save_dir}/${exp_name}/logs
|
| 70 |
+
long_context: false
|
| 71 |
+
lr: 1e-5
|
| 72 |
+
max_checkpoints: 10
|
| 73 |
+
max_recursive_submaps: 5
|
| 74 |
+
min_lr: 1e-8
|
| 75 |
+
n_corres_test: 0
|
| 76 |
+
n_corres_train: 0
|
| 77 |
+
num_imgs_vis: 4
|
| 78 |
+
num_test_views: 4
|
| 79 |
+
num_views: 24
|
| 80 |
+
num_views_arkit: 24
|
| 81 |
+
num_views_scannetpp: 24
|
| 82 |
+
num_views_scannet: 24
|
| 83 |
+
num_views_hypersim: 24
|
| 84 |
+
num_views_blendedmvs: 24
|
| 85 |
+
num_views_megadepth: 24
|
| 86 |
+
num_views_mvs_synth: 24
|
| 87 |
+
num_workers: 4
|
| 88 |
+
output_dir: ${save_dir}/${exp_name}/
|
| 89 |
+
pretrained: /var/scratch/qzhang2/SLAM-Former/ckpt/checkpoint-10.pth.model
|
| 90 |
+
print_freq: 10
|
| 91 |
+
print_img_freq: 50000000
|
| 92 |
+
rank: 0
|
| 93 |
+
resume: null
|
| 94 |
+
retention_ratio: 0.5
|
| 95 |
+
pseudo_gt:
|
| 96 |
+
enable: false
|
| 97 |
+
cache_path: null
|
| 98 |
+
use_soft_targets: true
|
| 99 |
+
min_confidence: 0.5
|
| 100 |
+
min_support_pairs: 1
|
| 101 |
+
topk_pairs: 8
|
| 102 |
+
loss_type: hybrid
|
| 103 |
+
loss_weight_gate: 0.05
|
| 104 |
+
loss_weight_desc: 0.15
|
| 105 |
+
geometric_support_scale: 0.5
|
| 106 |
+
ranking_margin: 0.05
|
| 107 |
+
use_l2m: true
|
| 108 |
+
l2m_min_certainty: 0.35
|
| 109 |
+
l2m_min_inlier_ratio: 0.2
|
| 110 |
+
save_dir: /var/scratch/qzhang2/SLAM-Former/checkpoints
|
| 111 |
+
save_freq: 0.1
|
| 112 |
+
seed: 42
|
| 113 |
+
soft_mask_bias: 0.2
|
| 114 |
+
soft_mask_temperature: 0.25
|
| 115 |
+
start_epoch: 0
|
| 116 |
+
start_step: 0
|
| 117 |
+
submap_size: 6
|
| 118 |
+
task: SLAMFormer_Submap_Finetune
|
| 119 |
+
tbptt_window: 0
|
| 120 |
+
teacher: null
|
| 121 |
+
temporal_embed_mode: learned
|
| 122 |
+
test_criterion: DistillLoss()
|
| 123 |
+
test_dataset: 500 @ ARKitScenes_Multi(split='test', ROOT="${root_arkit}",
|
| 124 |
+
resolution=(518, 392), num_views=${num_test_views}, seed=42, n_corres=${n_corres_test})
|
| 125 |
+
train_criterion: DistillLoss()
|
| 126 |
+
train_dataset: 4500 @ ${dataset_arkit} + 2250 @ ${dataset_scannetpp} + 4500 @ ${dataset_scannet} + 1200 @ ${dataset_hypersim} + 2250 @ ${dataset_blendedmvs} + 2250 @ ${dataset_megadepth} + 450 @ ${dataset_mvs_synth}
|
| 127 |
+
warmup_epochs: 0.5
|
| 128 |
+
weight_decay: 0.05
|
| 129 |
+
world_size: 1
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/finetune_sub_only.yaml
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accum_iter: 1
|
| 2 |
+
allow_repeat: false
|
| 3 |
+
amp: 1
|
| 4 |
+
batch_size: 1
|
| 5 |
+
benchmark: false
|
| 6 |
+
custom_lr_scale: 1.0
|
| 7 |
+
data_root: /var/scratch/qzhang2/SLAM-Former/data/train
|
| 8 |
+
root_arkit: ${data_root}/processed_arkitscenes
|
| 9 |
+
root_scannetpp: ${data_root}/processed_scannetpp
|
| 10 |
+
root_scannet: ${data_root}/processed_scannet
|
| 11 |
+
root_hypersim: ${data_root}/hypersim
|
| 12 |
+
root_blendedmvs: ${data_root}/processed_blendedmvs
|
| 13 |
+
root_megadepth: ${data_root}/processed_megadepth
|
| 14 |
+
root_mvs_synth: ${data_root}/processed_mvs_synth
|
| 15 |
+
dataset_arkit: ARKitScenes_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_arkit}",
|
| 16 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 17 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
|
| 18 |
+
dataset_scannet: ScanNet_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_scannet}",
|
| 19 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 20 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
|
| 21 |
+
dataset_hypersim: HyperSim_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_hypersim}",
|
| 22 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 23 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
|
| 24 |
+
dataset_blendedmvs: BlendedMVS_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_blendedmvs}",
|
| 25 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 26 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
|
| 27 |
+
dataset_megadepth: MegaDepth_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_megadepth}",
|
| 28 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 29 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
|
| 30 |
+
dataset_mvs_synth: MVS_Synth_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_mvs_synth}",
|
| 31 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 32 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
|
| 33 |
+
dataset_scannetpp: ScanNetpp_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_scannetpp}",
|
| 34 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
|
| 35 |
+
(518, 154)], transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
|
| 36 |
+
desc_dim: 128
|
| 37 |
+
detach_frontend_tokens: true
|
| 38 |
+
dist_backend: nccl
|
| 39 |
+
dist_url: env://
|
| 40 |
+
distributed: false
|
| 41 |
+
enable_dynamic_boundary: false
|
| 42 |
+
enable_loop: true
|
| 43 |
+
enable_submap: true
|
| 44 |
+
enable_temporal: false
|
| 45 |
+
epochs: 10
|
| 46 |
+
eval_freq: 1
|
| 47 |
+
exp_name: submap_sub_only_softall_64_24_v1
|
| 48 |
+
fixed_length: true
|
| 49 |
+
freeze_encoder: true
|
| 50 |
+
gpu: 0
|
| 51 |
+
gradient_checkpointing: true
|
| 52 |
+
gumbel_tau: 5.0
|
| 53 |
+
loop_mask_mode: soft_all
|
| 54 |
+
retain_history_grad: true
|
| 55 |
+
submap_train_mode: full_token
|
| 56 |
+
submap_retrieval_topk: 0
|
| 57 |
+
submap_fetch_source: frontend
|
| 58 |
+
submap_descriptor_source: frontend
|
| 59 |
+
train_submap_modules_only: true
|
| 60 |
+
gumbel_tau_end: 0.1
|
| 61 |
+
gumbel_tau_start: 5.0
|
| 62 |
+
hydra:
|
| 63 |
+
run:
|
| 64 |
+
dir: ${save_dir}/${exp_name}
|
| 65 |
+
verbose: true
|
| 66 |
+
keep_freq: 1
|
| 67 |
+
load_only_encoder: false
|
| 68 |
+
local-rank: -1
|
| 69 |
+
logdir: ${save_dir}/${exp_name}/logs
|
| 70 |
+
long_context: false
|
| 71 |
+
lr: 1e-5
|
| 72 |
+
max_checkpoints: 10
|
| 73 |
+
max_recursive_submaps: 5
|
| 74 |
+
min_lr: 1e-8
|
| 75 |
+
n_corres_test: 0
|
| 76 |
+
n_corres_train: 0
|
| 77 |
+
num_imgs_vis: 4
|
| 78 |
+
num_test_views: 4
|
| 79 |
+
num_views: 24
|
| 80 |
+
num_views_arkit: 24
|
| 81 |
+
num_views_scannetpp: 24
|
| 82 |
+
num_views_scannet: 24
|
| 83 |
+
num_views_hypersim: 24
|
| 84 |
+
num_views_blendedmvs: 24
|
| 85 |
+
num_views_megadepth: 24
|
| 86 |
+
num_views_mvs_synth: 24
|
| 87 |
+
num_workers: 4
|
| 88 |
+
output_dir: ${save_dir}/${exp_name}/
|
| 89 |
+
pretrained: /var/scratch/qzhang2/SLAM-Former/ckpt/checkpoint-10.pth.model
|
| 90 |
+
print_freq: 10
|
| 91 |
+
print_img_freq: 50000000
|
| 92 |
+
rank: 0
|
| 93 |
+
resume: null
|
| 94 |
+
retention_ratio: 0.5
|
| 95 |
+
pseudo_gt:
|
| 96 |
+
enable: false
|
| 97 |
+
cache_path: null
|
| 98 |
+
use_soft_targets: true
|
| 99 |
+
min_confidence: 0.65
|
| 100 |
+
min_support_pairs: 1
|
| 101 |
+
topk_pairs: 4
|
| 102 |
+
loss_type: hybrid
|
| 103 |
+
loss_weight_gate: 0.1
|
| 104 |
+
loss_weight_desc: 0.1
|
| 105 |
+
geometric_support_scale: 0.25
|
| 106 |
+
ranking_margin: 0.1
|
| 107 |
+
use_l2m: false
|
| 108 |
+
l2m_min_certainty: 0.0
|
| 109 |
+
l2m_min_inlier_ratio: 0.0
|
| 110 |
+
save_dir: /var/scratch/qzhang2/SLAM-Former/checkpoints
|
| 111 |
+
save_freq: 0.1
|
| 112 |
+
seed: 42
|
| 113 |
+
soft_mask_bias: 0.2
|
| 114 |
+
soft_mask_temperature: 0.25
|
| 115 |
+
start_epoch: 0
|
| 116 |
+
start_step: 0
|
| 117 |
+
submap_size: 6
|
| 118 |
+
task: SLAMFormer_Submap_Finetune
|
| 119 |
+
tbptt_window: 0
|
| 120 |
+
teacher: null
|
| 121 |
+
temporal_embed_mode: learned
|
| 122 |
+
test_criterion: DistillLoss()
|
| 123 |
+
test_dataset: 500 @ ARKitScenes_Multi(split='test', ROOT="${root_arkit}",
|
| 124 |
+
resolution=(518, 392), num_views=${num_test_views}, seed=42, n_corres=${n_corres_test})
|
| 125 |
+
train_criterion: DistillLoss()
|
| 126 |
+
train_dataset: 2250 @ ${dataset_scannetpp} + 450 @ ${dataset_mvs_synth} + 4500 @ ${dataset_arkit}
|
| 127 |
+
warmup_epochs: 0.5
|
| 128 |
+
weight_decay: 0.05
|
| 129 |
+
world_size: 1
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/mytrain.yaml
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
teacher: ../ckpt/model.pt # not used
|
| 2 |
+
pretrained: ../ckpt/pi3.pth
|
| 3 |
+
|
| 4 |
+
load_only_encoder: False
|
| 5 |
+
long_context: False
|
| 6 |
+
fixed_length: True
|
| 7 |
+
resume: Null
|
| 8 |
+
benchmark: False
|
| 9 |
+
num_views : 12
|
| 10 |
+
num_test_views : 4
|
| 11 |
+
n_corres_train: 0
|
| 12 |
+
n_corres_test: 0
|
| 13 |
+
|
| 14 |
+
train_criterion: DistillLoss()
|
| 15 |
+
test_criterion: DistillLoss()
|
| 16 |
+
allow_repeat: False
|
| 17 |
+
|
| 18 |
+
dataset3: ARKitScenes_Multi(allow_repeat=${allow_repeat}, split='train', ROOT='../data/train/processed_arkitscenes/',
|
| 19 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)],
|
| 20 |
+
transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
|
| 21 |
+
dataset5: ScanNetpp_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="../data/train/processed_scannetpp/",
|
| 22 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)],
|
| 23 |
+
transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
|
| 24 |
+
dataset6: ScanNet_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="../data/train/processed_scannet/",
|
| 25 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)],
|
| 26 |
+
transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
|
| 27 |
+
dataset7: HyperSim_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="../data/train/hypersim",
|
| 28 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)],
|
| 29 |
+
transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
|
| 30 |
+
dataset8: BlendedMVS_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="../data/train/processed_blendedmvs/",
|
| 31 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)],
|
| 32 |
+
transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
|
| 33 |
+
dataset9: MegaDepth_Multi(allow_repeat=${allow_repeat}, split="train", ROOT="../data/train/processed_megadepth",
|
| 34 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)],
|
| 35 |
+
transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
|
| 36 |
+
|
| 37 |
+
dataset14: MVS_Synth_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="../data/train/processed_mvs_synth",
|
| 38 |
+
aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)],
|
| 39 |
+
transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
train_dataset: 4500 @ ${dataset3} + 2250 @ ${dataset5} + 4500 @ ${dataset6} + 1200 @ ${dataset7} + 2250 @ ${dataset8} + 2250 @ ${dataset9} + 450 @ ${dataset14}
|
| 43 |
+
|
| 44 |
+
test_dataset: 1000 @ ARKitScenes_Multi(split='test', ROOT='../data/train/processed_arkitscenes/', resolution=(518, 392), num_views=${num_test_views}, seed=42, n_corres=${n_corres_test})
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
seed: 0
|
| 48 |
+
batch_size: 1
|
| 49 |
+
accum_iter: 1
|
| 50 |
+
gradient_checkpointing: False
|
| 51 |
+
epochs: 10
|
| 52 |
+
start_epoch: 0
|
| 53 |
+
start_step: 0
|
| 54 |
+
weight_decay: 0.05
|
| 55 |
+
|
| 56 |
+
lr: 1e-5
|
| 57 |
+
min_lr: 1e-8
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
warmup_epochs: 0.5
|
| 61 |
+
amp: 1
|
| 62 |
+
|
| 63 |
+
num_workers: 4 # 12
|
| 64 |
+
world_size: 1
|
| 65 |
+
local-rank: -1
|
| 66 |
+
dist_url: 'env://'
|
| 67 |
+
rank: 0
|
| 68 |
+
gpu: 0
|
| 69 |
+
distributed: False
|
| 70 |
+
dist_backend: 'nccl'
|
| 71 |
+
|
| 72 |
+
eval_freq: 1
|
| 73 |
+
save_freq: 0.1
|
| 74 |
+
max_checkpoints: 10
|
| 75 |
+
keep_freq: 1
|
| 76 |
+
print_freq: 10
|
| 77 |
+
print_img_freq: 50000000
|
| 78 |
+
num_imgs_vis: 4
|
| 79 |
+
save_dir: '../checkpoints'
|
| 80 |
+
|
| 81 |
+
exp_name: 'SLAMFormer_v1'
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
task: 'StreamVGGT'
|
| 87 |
+
logdir: ${save_dir}/${exp_name}/logs
|
| 88 |
+
output_dir: ${save_dir}/${exp_name}/
|
| 89 |
+
hydra:
|
| 90 |
+
verbose: True
|
| 91 |
+
run:
|
| 92 |
+
dir: ${save_dir}/${exp_name}
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/environment.yml
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: SLAM-Former
|
| 2 |
+
channels:
|
| 3 |
+
- defaults
|
| 4 |
+
dependencies:
|
| 5 |
+
- _libgcc_mutex=0.1=main
|
| 6 |
+
- _openmp_mutex=5.1=1_gnu
|
| 7 |
+
- bzip2=1.0.8=h5eee18b_6
|
| 8 |
+
- ca-certificates=2025.12.2=h06a4308_0
|
| 9 |
+
- expat=2.7.4=h7354ed3_0
|
| 10 |
+
- ld_impl_linux-64=2.44=h153f514_2
|
| 11 |
+
- libexpat=2.7.4=h7354ed3_0
|
| 12 |
+
- libffi=3.4.4=h6a678d5_1
|
| 13 |
+
- libgcc=15.2.0=h69a1729_7
|
| 14 |
+
- libgcc-ng=15.2.0=h166f726_7
|
| 15 |
+
- libgomp=15.2.0=h4751f2c_7
|
| 16 |
+
- libnsl=2.0.0=h5eee18b_0
|
| 17 |
+
- libstdcxx=15.2.0=h39759b7_7
|
| 18 |
+
- libstdcxx-ng=15.2.0=hc03a8fd_7
|
| 19 |
+
- libuuid=1.41.5=h5eee18b_0
|
| 20 |
+
- libxcb=1.17.0=h9b100fa_0
|
| 21 |
+
- libzlib=1.3.1=hb25bd0a_0
|
| 22 |
+
- ncurses=6.5=h7934f7d_0
|
| 23 |
+
- openssl=3.5.5=h1b28b03_0
|
| 24 |
+
- packaging=25.0=py311h06a4308_1
|
| 25 |
+
- pip=26.0.1=pyhc872135_0
|
| 26 |
+
- pthread-stubs=0.3=h0ce48e5_1
|
| 27 |
+
- python=3.11.14=h6fa692b_0
|
| 28 |
+
- readline=8.3=hc2a1206_0
|
| 29 |
+
- setuptools=80.10.2=py311h06a4308_0
|
| 30 |
+
- sqlite=3.51.1=h3e8d24a_1
|
| 31 |
+
- tk=8.6.15=h54e0aa7_0
|
| 32 |
+
- tzdata=2026a=he532380_0
|
| 33 |
+
- wheel=0.46.3=py311h06a4308_0
|
| 34 |
+
- xorg-libx11=1.8.12=h9b100fa_1
|
| 35 |
+
- xorg-libxau=1.0.12=h9b100fa_0
|
| 36 |
+
- xorg-libxdmcp=1.1.5=h9b100fa_0
|
| 37 |
+
- xorg-xorgproto=2024.1=h5eee18b_1
|
| 38 |
+
- xz=5.8.2=h448239c_0
|
| 39 |
+
- zlib=1.3.1=hb25bd0a_0
|
| 40 |
+
- pip:
|
| 41 |
+
- absl-py==2.4.0
|
| 42 |
+
- accelerate==1.13.0
|
| 43 |
+
- addict==2.4.0
|
| 44 |
+
- aiofiles==24.1.0
|
| 45 |
+
- annotated-doc==0.0.4
|
| 46 |
+
- annotated-types==0.7.0
|
| 47 |
+
- antlr4-python3-runtime==4.9.3
|
| 48 |
+
- anyio==4.12.1
|
| 49 |
+
- argcomplete==3.6.3
|
| 50 |
+
- asttokens==3.0.1
|
| 51 |
+
- attrs==25.4.0
|
| 52 |
+
- beartype==0.22.9
|
| 53 |
+
- blinker==1.9.0
|
| 54 |
+
- brotli==1.2.0
|
| 55 |
+
- ccimport==0.4.4
|
| 56 |
+
- certifi==2026.2.25
|
| 57 |
+
- charset-normalizer==3.4.5
|
| 58 |
+
- click==8.3.1
|
| 59 |
+
- colorama==0.4.6
|
| 60 |
+
- colorlog==6.10.1
|
| 61 |
+
- comm==0.2.3
|
| 62 |
+
- configargparse==1.7.3
|
| 63 |
+
- contourpy==1.3.3
|
| 64 |
+
- cumm-cu124==0.7.11
|
| 65 |
+
- cycler==0.12.1
|
| 66 |
+
- dacite==1.9.2
|
| 67 |
+
- dash==4.0.0
|
| 68 |
+
- decorator==4.4.2
|
| 69 |
+
- einops==0.8.2
|
| 70 |
+
- embreex==2.17.7.post7
|
| 71 |
+
- evo==1.34.3
|
| 72 |
+
- executing==2.2.1
|
| 73 |
+
- fast-pytorch-kmeans==0.2.2
|
| 74 |
+
- fastapi==0.135.1
|
| 75 |
+
- fastjsonschema==2.21.2
|
| 76 |
+
- ffmpy==1.0.0
|
| 77 |
+
- filelock==3.25.0
|
| 78 |
+
- fire==0.7.1
|
| 79 |
+
- flask==3.1.3
|
| 80 |
+
- fonttools==4.61.1
|
| 81 |
+
- fsspec==2026.2.0
|
| 82 |
+
- gradio==6.9.0
|
| 83 |
+
- gradio-client==2.3.0
|
| 84 |
+
- groovy==0.1.2
|
| 85 |
+
- grpcio==1.78.0
|
| 86 |
+
- gsplat==1.5.3
|
| 87 |
+
- h11==0.16.0
|
| 88 |
+
- h5py==3.16.0
|
| 89 |
+
- hf-xet==1.3.2
|
| 90 |
+
- httpcore==1.0.9
|
| 91 |
+
- httpx==0.28.1
|
| 92 |
+
- huggingface-hub==1.6.0
|
| 93 |
+
- hydra-core==1.3.2
|
| 94 |
+
- idna==3.11
|
| 95 |
+
- imageio==2.37.2
|
| 96 |
+
- imageio-ffmpeg==0.6.0
|
| 97 |
+
- importlib-metadata==8.7.1
|
| 98 |
+
- ipython==9.10.0
|
| 99 |
+
- ipython-pygments-lexers==1.1.1
|
| 100 |
+
- ipywidgets==8.1.8
|
| 101 |
+
- itsdangerous==2.2.0
|
| 102 |
+
- jaxtyping==0.3.9
|
| 103 |
+
- jedi==0.19.2
|
| 104 |
+
- jinja2==3.1.6
|
| 105 |
+
- joblib==1.5.3
|
| 106 |
+
- jsonschema==4.26.0
|
| 107 |
+
- jsonschema-specifications==2025.9.1
|
| 108 |
+
- jupyter-core==5.9.1
|
| 109 |
+
- jupyterlab-widgets==3.0.16
|
| 110 |
+
- kiwisolver==1.4.9
|
| 111 |
+
- kornia==0.8.2
|
| 112 |
+
- kornia-rs==0.1.10
|
| 113 |
+
- lark==1.3.1
|
| 114 |
+
- lpips==0.1.4
|
| 115 |
+
- lxml==6.0.2
|
| 116 |
+
- lz4==4.4.5
|
| 117 |
+
- manifold3d==3.4.0
|
| 118 |
+
- mapbox-earcut==2.0.0
|
| 119 |
+
- markdown==3.10.2
|
| 120 |
+
- markdown-it-py==4.0.0
|
| 121 |
+
- markupsafe==3.0.3
|
| 122 |
+
- matplotlib==3.10.8
|
| 123 |
+
- matplotlib-inline==0.2.1
|
| 124 |
+
- mdurl==0.1.2
|
| 125 |
+
- moviepy==1.0.3
|
| 126 |
+
- mpmath==1.3.0
|
| 127 |
+
- msgspec==0.20.0
|
| 128 |
+
- narwhals==2.17.0
|
| 129 |
+
- natsort==8.4.0
|
| 130 |
+
- nbformat==5.10.4
|
| 131 |
+
- nest-asyncio==1.6.0
|
| 132 |
+
- networkx==3.6.1
|
| 133 |
+
- ninja==1.13.0
|
| 134 |
+
- numexpr==2.14.1
|
| 135 |
+
- numpy==2.2.6
|
| 136 |
+
- nvidia-cublas-cu12==12.8.4.1
|
| 137 |
+
- nvidia-cuda-cupti-cu12==12.8.90
|
| 138 |
+
- nvidia-cuda-nvrtc-cu12==12.8.93
|
| 139 |
+
- nvidia-cuda-runtime-cu12==12.8.90
|
| 140 |
+
- nvidia-cudnn-cu12==9.10.2.21
|
| 141 |
+
- nvidia-cufft-cu12==11.3.3.83
|
| 142 |
+
- nvidia-cufile-cu12==1.13.1.3
|
| 143 |
+
- nvidia-curand-cu12==10.3.9.90
|
| 144 |
+
- nvidia-cusolver-cu12==11.7.3.90
|
| 145 |
+
- nvidia-cusparse-cu12==12.5.8.93
|
| 146 |
+
- nvidia-cusparselt-cu12==0.7.1
|
| 147 |
+
- nvidia-nccl-cu12==2.27.5
|
| 148 |
+
- nvidia-nvjitlink-cu12==12.8.93
|
| 149 |
+
- nvidia-nvshmem-cu12==3.3.20
|
| 150 |
+
- nvidia-nvtx-cu12==12.8.90
|
| 151 |
+
- omegaconf==2.3.0
|
| 152 |
+
- open3d==0.18.0
|
| 153 |
+
- opencv-python==4.13.0.92
|
| 154 |
+
- orjson==3.11.7
|
| 155 |
+
- pandas==3.0.1
|
| 156 |
+
- parso==0.8.6
|
| 157 |
+
- pccm==0.4.16
|
| 158 |
+
- pexpect==4.9.0
|
| 159 |
+
- pillow==12.1.1
|
| 160 |
+
- platformdirs==4.9.4
|
| 161 |
+
- plotly==6.6.0
|
| 162 |
+
- plyfile==1.1.3
|
| 163 |
+
- portalocker==3.2.0
|
| 164 |
+
- proglog==0.1.12
|
| 165 |
+
- prompt-toolkit==3.0.52
|
| 166 |
+
- protobuf==7.34.0
|
| 167 |
+
- psutil==7.2.2
|
| 168 |
+
- ptyprocess==0.7.0
|
| 169 |
+
- pure-eval==0.2.3
|
| 170 |
+
- pyarrow==23.0.1
|
| 171 |
+
- pybind11==3.0.3
|
| 172 |
+
- pycollada==0.9.3
|
| 173 |
+
- pydantic==2.12.5
|
| 174 |
+
- pydantic-core==2.41.5
|
| 175 |
+
- pydub==0.25.1
|
| 176 |
+
- pyglet==1.5.31
|
| 177 |
+
- pygments==2.19.2
|
| 178 |
+
- pyparsing==3.3.2
|
| 179 |
+
- pyquaternion==0.9.9
|
| 180 |
+
- python-dateutil==2.9.0.post0
|
| 181 |
+
- python-multipart==0.0.22
|
| 182 |
+
- pytz==2026.1.post1
|
| 183 |
+
- pyyaml==6.0.3
|
| 184 |
+
- referencing==0.37.0
|
| 185 |
+
- regex==2026.2.28
|
| 186 |
+
- requests==2.32.5
|
| 187 |
+
- rerun-sdk==0.30.1
|
| 188 |
+
- retrying==1.4.2
|
| 189 |
+
- rich==14.3.3
|
| 190 |
+
- roma==1.5.6
|
| 191 |
+
- rosbags==0.11.0
|
| 192 |
+
- rpds-py==0.30.0
|
| 193 |
+
- rtree==1.4.1
|
| 194 |
+
- ruamel-yaml==0.19.1
|
| 195 |
+
- safehttpx==0.1.7
|
| 196 |
+
- safetensors==0.7.0
|
| 197 |
+
- scikit-learn==1.8.0
|
| 198 |
+
- scipy==1.17.1
|
| 199 |
+
- seaborn==0.13.2
|
| 200 |
+
- semantic-version==2.10.0
|
| 201 |
+
- shapely==2.1.2
|
| 202 |
+
- shellingham==1.5.4
|
| 203 |
+
- six==1.17.0
|
| 204 |
+
- slamformer==0.1.0
|
| 205 |
+
- spconv-cu124==2.3.8
|
| 206 |
+
- stack-data==0.6.3
|
| 207 |
+
- starlette==0.52.1
|
| 208 |
+
- svg-path==7.0
|
| 209 |
+
- sympy==1.14.0
|
| 210 |
+
- tensorboard==2.20.0
|
| 211 |
+
- tensorboard-data-server==0.7.2
|
| 212 |
+
- termcolor==3.3.0
|
| 213 |
+
- threadpoolctl==3.6.0
|
| 214 |
+
- timm==1.0.26
|
| 215 |
+
- tokenizers==0.22.2
|
| 216 |
+
- tomlkit==0.13.3
|
| 217 |
+
- torch==2.9.1
|
| 218 |
+
- torch-cluster==1.6.3+pt25cu124
|
| 219 |
+
- torch-scatter==2.1.2+pt25cu124
|
| 220 |
+
- torch-sparse==0.6.18+pt25cu124
|
| 221 |
+
- torch-spline-conv==1.2.2+pt25cu124
|
| 222 |
+
- torchvision==0.24.1
|
| 223 |
+
- tqdm==4.67.3
|
| 224 |
+
- traitlets==5.14.3
|
| 225 |
+
- transformers==5.3.0
|
| 226 |
+
- trimesh==4.11.3
|
| 227 |
+
- triton==3.5.1
|
| 228 |
+
- typer==0.24.1
|
| 229 |
+
- typing-extensions==4.15.0
|
| 230 |
+
- typing-inspection==0.4.2
|
| 231 |
+
- urllib3==2.6.3
|
| 232 |
+
- uvicorn==0.41.0
|
| 233 |
+
- vhacdx==0.0.10
|
| 234 |
+
- viser==1.0.24
|
| 235 |
+
- wadler-lindig==0.1.7
|
| 236 |
+
- wcwidth==0.6.0
|
| 237 |
+
- websockets==15.0.1
|
| 238 |
+
- werkzeug==3.1.6
|
| 239 |
+
- widgetsnbextension==4.0.15
|
| 240 |
+
- xxhash==3.6.0
|
| 241 |
+
- yapf==0.43.0
|
| 242 |
+
- yourdfpy==0.0.60
|
| 243 |
+
- zipp==3.23.0
|
| 244 |
+
- zstandard==0.25.0
|
| 245 |
+
prefix: /var/scratch/qzhang2/miniconda3/envs/SLAM-Former
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/eval_ate_scaled.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
import evo
|
| 7 |
+
from evo.core import metrics
|
| 8 |
+
import evo.main_ape as main_ape
|
| 9 |
+
from evo.core.metrics import PoseRelation
|
| 10 |
+
from evo.core.trajectory import PosePath3D
|
| 11 |
+
from evo.tools import file_interface
|
| 12 |
+
import evo.core.sync as sync
|
| 13 |
+
HAS_EVO = True
|
| 14 |
+
except ImportError:
|
| 15 |
+
HAS_EVO = False
|
| 16 |
+
print("EVO not found. Please install evo: pip install evo")
|
| 17 |
+
exit(1)
|
| 18 |
+
|
| 19 |
+
TUM_DIR = "/var/scratch/qzhang2/Feature-SLAM/datasets/tum"
|
| 20 |
+
RESULTS_DIR = os.environ.get("RESULTS_DIR", "./tum_results")
|
| 21 |
+
|
| 22 |
+
sequences = [d for d in os.listdir(RESULTS_DIR) if os.path.isdir(os.path.join(RESULTS_DIR, d))]
|
| 23 |
+
|
| 24 |
+
print(f"{'Sequence':<40} | {'ATE (m) [Unscaled]':<20} | {'ATE (m) [Scale Aligned]':<20}")
|
| 25 |
+
print("-" * 85)
|
| 26 |
+
|
| 27 |
+
def get_ate(est_file, gt_file, align_scale=False):
    """Compute the translational ATE RMSE between two TUM trajectory files.

    Both files are read with evo's TUM reader, associated with
    ``sync.associate_trajectories``, and the estimate is aligned to the
    reference. Scale is corrected only when ``align_scale`` is true.

    Returns ``result.stats["rmse"]`` on success, or the string
    ``"Error: <exception>"`` if any step raises.
    """
    try:
        reference = file_interface.read_tum_trajectory_file(gt_file)
        estimate = file_interface.read_tum_trajectory_file(est_file)
        reference, estimate = sync.associate_trajectories(reference, estimate)

        # Align the estimate to the reference. Scale is part of the alignment
        # (i.e. Sim(3)) only when align_scale is true.
        # NOTE(review): ape() below is also asked to align; this explicit call
        # may be redundant -- confirm against evo before removing it.
        estimate.align(reference, correct_scale=align_scale)

        ape_result = main_ape.ape(
            reference,
            estimate,
            pose_relation=PoseRelation.translation_part,
            align=True,
            correct_scale=align_scale,
        )
        rmse = ape_result.stats["rmse"]
        return rmse
    except Exception as err:
        return f"Error: {err}"
|
| 40 |
+
|
| 41 |
+
for seq in sorted(sequences):
|
| 42 |
+
est_file = os.path.join(RESULTS_DIR, seq, f"final_traj.txt")
|
| 43 |
+
gt_file = os.path.join(TUM_DIR, seq, "groundtruth.txt")
|
| 44 |
+
|
| 45 |
+
if os.path.exists(est_file) and os.path.exists(gt_file):
|
| 46 |
+
ate_unscaled = get_ate(est_file, gt_file, align_scale=False)
|
| 47 |
+
ate_scaled = get_ate(est_file, gt_file, align_scale=True)
|
| 48 |
+
|
| 49 |
+
unscaled_str = f"{ate_unscaled:.4f}" if isinstance(ate_unscaled, float) else str(ate_unscaled)
|
| 50 |
+
scaled_str = f"{ate_scaled:.4f}" if isinstance(ate_scaled, float) else str(ate_scaled)
|
| 51 |
+
|
| 52 |
+
print(f"{seq:<40} | {unscaled_str:<20} | {scaled_str:<20}")
|
| 53 |
+
else:
|
| 54 |
+
print(f"{seq:<40} | Missing files or still running")
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/get_ate.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
import evo
|
| 7 |
+
from evo.core import metrics
|
| 8 |
+
import evo.main_ape as main_ape
|
| 9 |
+
from evo.core.metrics import PoseRelation
|
| 10 |
+
from evo.core.trajectory import PosePath3D
|
| 11 |
+
from evo.tools import file_interface
|
| 12 |
+
import evo.core.sync as sync
|
| 13 |
+
HAS_EVO = True
|
| 14 |
+
except ImportError:
|
| 15 |
+
HAS_EVO = False
|
| 16 |
+
print("EVO not found, using simple ATE calculation")
|
| 17 |
+
|
| 18 |
+
TUM_DIR = "/var/scratch/qzhang2/Feature-SLAM/datasets/tum"
|
| 19 |
+
RESULTS_DIR = os.environ.get("RESULTS_DIR", "./tum_results")
|
| 20 |
+
|
| 21 |
+
sequences = [d for d in os.listdir(RESULTS_DIR) if os.path.isdir(os.path.join(RESULTS_DIR, d))]
|
| 22 |
+
|
| 23 |
+
print(f"{'Sequence':<40} | {'ATE (m)':<15}")
|
| 24 |
+
print("-" * 58)
|
| 25 |
+
|
| 26 |
+
def get_ate(est_file, gt_file):
    """Compute the translational ATE RMSE between two TUM trajectory files.

    When evo was imported (``HAS_EVO`` is true) the trajectories are
    associated and aligned without scale correction, and evo's APE RMSE is
    returned. Otherwise a simple fallback is used: every estimated pose is
    matched to the ground-truth pose with the closest timestamp (within
    0.1 s) and the RMSE of the raw position differences is returned. The
    fallback performs no alignment.

    Returns a float RMSE on success, ``"No matches"`` when the fallback finds
    no timestamp pair within the threshold, or ``"Error: <exception>"`` if
    any step raises.
    """
    if HAS_EVO:
        try:
            traj_ref = file_interface.read_tum_trajectory_file(gt_file)
            traj_est = file_interface.read_tum_trajectory_file(est_file)
            traj_ref, traj_est = sync.associate_trajectories(traj_ref, traj_est)
            # NOTE(review): ape() below is also asked to align; this explicit
            # call may be redundant -- confirm against evo before removing it.
            traj_est.align(traj_ref, correct_scale=False)

            result = main_ape.ape(
                traj_ref,
                traj_est,
                pose_relation=PoseRelation.translation_part,
                align=True,
                correct_scale=False,
            )
            return result.stats["rmse"]
        except Exception as e:
            return f"Error: {e}"

    # Fallback to simple ATE if evo is not available.
    try:
        # ndmin=2 keeps a single-pose file as a (1, N) array so the row
        # iteration and column indexing below still work.
        est_data = np.loadtxt(est_file, ndmin=2)
        gt_data = np.loadtxt(gt_file, ndmin=2)
        gt_stamps = gt_data[:, 0]

        squared_error_sum = 0.0
        matched = 0
        for est in est_data:
            ts = est[0]
            # Closest ground-truth timestamp to this estimated pose.
            idx = np.argmin(np.abs(gt_stamps - ts))
            if np.abs(gt_stamps[idx] - ts) < 0.1:  # 100 ms threshold
                diff = est[1:4] - gt_data[idx, 1:4]
                squared_error_sum += np.sum(diff ** 2)
                matched += 1

        if matched > 0:
            # Plain float so callers' isinstance(ate, float) formatting applies.
            return float(np.sqrt(squared_error_sum / matched))
        return "No matches"
    except Exception as e:
        # Report the cause, matching the evo branch above.
        return f"Error: {e}"
|
| 62 |
+
|
| 63 |
+
for seq in sorted(sequences):
|
| 64 |
+
est_file = os.path.join(RESULTS_DIR, seq, f"final_traj.txt")
|
| 65 |
+
gt_file = os.path.join(TUM_DIR, seq, "groundtruth.txt")
|
| 66 |
+
|
| 67 |
+
if os.path.exists(est_file) and os.path.exists(gt_file):
|
| 68 |
+
ate = get_ate(est_file, gt_file)
|
| 69 |
+
if isinstance(ate, float):
|
| 70 |
+
print(f"{seq:<40} | {ate:.4f}")
|
| 71 |
+
else:
|
| 72 |
+
print(f"{seq:<40} | {ate}")
|
| 73 |
+
else:
|
| 74 |
+
print(f"{seq:<40} | Missing files or running")
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/publish_submap.sh
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
usage() {
|
| 5 |
+
cat <<'EOF'
|
| 6 |
+
Usage:
|
| 7 |
+
GITHUB_TOKEN=... ./publish_submap.sh [commit message]
|
| 8 |
+
|
| 9 |
+
Optional env vars:
|
| 10 |
+
SOURCE_ROOT Source repository root (default: current git top-level)
|
| 11 |
+
PUBLISH_DIR Export directory (default: /var/scratch/qzhang2/e2e-semantic-SLAM-publish)
|
| 12 |
+
TARGET_REPO GitHub repo in owner/name form (default: SlamMate/e2e-semantic-SLAM)
|
| 13 |
+
BRANCH Git branch to push (default: submap)
|
| 14 |
+
GH_TOKEN Alternative to GITHUB_TOKEN
|
| 15 |
+
GIT_USER_NAME Commit author name override
|
| 16 |
+
GIT_USER_EMAIL Commit author email override
|
| 17 |
+
EOF
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
if [[ "${1:-}" == "-h" || "${1:-}" == "--help" ]]; then
|
| 21 |
+
usage
|
| 22 |
+
exit 0
|
| 23 |
+
fi
|
| 24 |
+
|
| 25 |
+
if [[ $# -gt 0 ]]; then
|
| 26 |
+
COMMIT_MSG="$*"
|
| 27 |
+
else
|
| 28 |
+
COMMIT_MSG="${COMMIT_MSG:-Auto export: $(date +%F_%H%M%S)}"
|
| 29 |
+
fi
|
| 30 |
+
|
| 31 |
+
SOURCE_ROOT="${SOURCE_ROOT:-$(git rev-parse --show-toplevel 2>/dev/null || pwd)}"
|
| 32 |
+
PUBLISH_DIR="${PUBLISH_DIR:-/var/scratch/qzhang2/e2e-semantic-SLAM-publish}"
|
| 33 |
+
TARGET_REPO="${TARGET_REPO:-SlamMate/e2e-semantic-SLAM}"
|
| 34 |
+
BRANCH="${BRANCH:-submap}"
|
| 35 |
+
TOKEN="${GITHUB_TOKEN:-${GH_TOKEN:-}}"
|
| 36 |
+
|
| 37 |
+
if [[ -z "${TOKEN}" ]]; then
|
| 38 |
+
echo "Missing GITHUB_TOKEN (or GH_TOKEN)." >&2
|
| 39 |
+
exit 1
|
| 40 |
+
fi
|
| 41 |
+
|
| 42 |
+
# PUBLISH_DIR is removed with `rm -rf` further down, so accept only a
# non-empty path below a scratch/temp root, and refuse any `..` component
# that could make the path resolve outside that root.
case "${PUBLISH_DIR}" in
  */..|*/../*)
    echo "Refusing to use unsafe PUBLISH_DIR: ${PUBLISH_DIR}" >&2
    exit 1
    ;;
  /tmp/?*|/var/tmp/?*|/var/scratch/?*) ;;
  *)
    echo "Refusing to use unsafe PUBLISH_DIR: ${PUBLISH_DIR}" >&2
    exit 1
    ;;
esac
|
| 49 |
+
|
| 50 |
+
if [[ ! -d "${SOURCE_ROOT}" ]]; then
|
| 51 |
+
echo "SOURCE_ROOT does not exist: ${SOURCE_ROOT}" >&2
|
| 52 |
+
exit 1
|
| 53 |
+
fi
|
| 54 |
+
|
| 55 |
+
rm -rf "${PUBLISH_DIR}"
|
| 56 |
+
mkdir -p "${PUBLISH_DIR}"
|
| 57 |
+
|
| 58 |
+
# Copy one regular file, named relative to SOURCE_ROOT, into the top level of
# PUBLISH_DIR. Anything that is not an existing regular file is skipped.
copy_file() {
  local rel_path="$1"
  local src="${SOURCE_ROOT}/${rel_path}"

  [[ -f "${src}" ]] || return 0
  rsync -a "${src}" "${PUBLISH_DIR}/"
}
|
| 64 |
+
|
| 65 |
+
# Mirror one directory, named relative to SOURCE_ROOT, to the same relative
# path under PUBLISH_DIR, leaving out Python bytecode artifacts. A path that
# is not an existing directory is skipped.
copy_dir() {
  local rel_path="$1"
  local src="${SOURCE_ROOT}/${rel_path}"
  local dest="${PUBLISH_DIR}/${rel_path}"

  [[ -d "${src}" ]] || return 0
  mkdir -p "${dest}"
  rsync -a \
    --exclude '__pycache__/' \
    --exclude '*.pyc' \
    --exclude '*.pyo' \
    "${src}/" "${dest}/"
}
|
| 76 |
+
|
| 77 |
+
for file in \
|
| 78 |
+
.gitignore \
|
| 79 |
+
README.md \
|
| 80 |
+
README_submap.md \
|
| 81 |
+
README_cluster_migration.md \
|
| 82 |
+
requirements.txt \
|
| 83 |
+
setup.py \
|
| 84 |
+
setup_env.sh \
|
| 85 |
+
run_tum.sh \
|
| 86 |
+
run_tum_top5.sh \
|
| 87 |
+
eval_ate_scaled.py \
|
| 88 |
+
get_ate.py \
|
| 89 |
+
submap_handoff.md \
|
| 90 |
+
publish_submap.sh
|
| 91 |
+
do
|
| 92 |
+
copy_file "${file}"
|
| 93 |
+
done
|
| 94 |
+
|
| 95 |
+
for dir in cloud_opt config slam src; do
|
| 96 |
+
copy_dir "${dir}"
|
| 97 |
+
done
|
| 98 |
+
|
| 99 |
+
if [[ ! -f "${PUBLISH_DIR}/README.md" ]]; then
|
| 100 |
+
echo "Export did not produce README.md; aborting." >&2
|
| 101 |
+
exit 1
|
| 102 |
+
fi
|
| 103 |
+
|
| 104 |
+
git -C "${PUBLISH_DIR}" init >/dev/null
|
| 105 |
+
git -C "${PUBLISH_DIR}" checkout -b "${BRANCH}" >/dev/null
|
| 106 |
+
|
| 107 |
+
GIT_USER_NAME="${GIT_USER_NAME:-$(git config --global user.name 2>/dev/null || echo Cascade)}"
|
| 108 |
+
GIT_USER_EMAIL="${GIT_USER_EMAIL:-$(git config --global user.email 2>/dev/null || echo cascade@example.com)}"
|
| 109 |
+
git -C "${PUBLISH_DIR}" config user.name "${GIT_USER_NAME}"
|
| 110 |
+
git -C "${PUBLISH_DIR}" config user.email "${GIT_USER_EMAIL}"
|
| 111 |
+
|
| 112 |
+
git -C "${PUBLISH_DIR}" add .
|
| 113 |
+
if git -C "${PUBLISH_DIR}" diff --cached --quiet; then
|
| 114 |
+
echo "No changes to commit in ${PUBLISH_DIR}."
|
| 115 |
+
else
|
| 116 |
+
git -C "${PUBLISH_DIR}" commit -m "${COMMIT_MSG}"
|
| 117 |
+
fi
|
| 118 |
+
|
| 119 |
+
git -C "${PUBLISH_DIR}" remote remove origin >/dev/null 2>&1 || true
|
| 120 |
+
git -C "${PUBLISH_DIR}" remote add origin "https://github.com/${TARGET_REPO}.git"
|
| 121 |
+
|
| 122 |
+
ASKPASS_HELPER="$(mktemp)"
|
| 123 |
+
cat >"${ASKPASS_HELPER}" <<'EOF'
|
| 124 |
+
#!/usr/bin/env bash
|
| 125 |
+
case "$1" in
|
| 126 |
+
*Username*) printf '%s\n' 'x-access-token' ;;
|
| 127 |
+
*Password*) printf '%s\n' "${GIT_TOKEN}" ;;
|
| 128 |
+
*) printf '%s\n' '' ;;
|
| 129 |
+
esac
|
| 130 |
+
EOF
|
| 131 |
+
chmod 700 "${ASKPASS_HELPER}"
|
| 132 |
+
trap 'rm -f "${ASKPASS_HELPER}"' EXIT
|
| 133 |
+
|
| 134 |
+
GIT_TOKEN="${TOKEN}" GIT_ASKPASS="${ASKPASS_HELPER}" GIT_TERMINAL_PROMPT=0 \
|
| 135 |
+
git -C "${PUBLISH_DIR}" push --force-with-lease -u origin "${BRANCH}"
|
| 136 |
+
|
| 137 |
+
echo "Published ${SOURCE_ROOT} -> ${TARGET_REPO} [${BRANCH}]"
|
| 138 |
+
echo "Export dir: ${PUBLISH_DIR}"
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/requirements.txt
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.9.1
|
| 2 |
+
torchvision==0.24.1
|
| 3 |
+
numpy==2.2.6
|
| 4 |
+
Pillow
|
| 5 |
+
huggingface_hub
|
| 6 |
+
safetensors
|
| 7 |
+
roma
|
| 8 |
+
gradio
|
| 9 |
+
matplotlib
|
| 10 |
+
tqdm
|
| 11 |
+
opencv-python
|
| 12 |
+
scipy
|
| 13 |
+
einops
|
| 14 |
+
trimesh
|
| 15 |
+
tensorboard
|
| 16 |
+
pyglet<2
|
| 17 |
+
viser
|
| 18 |
+
gradio
|
| 19 |
+
lpips
|
| 20 |
+
hydra-core
|
| 21 |
+
h5py
|
| 22 |
+
accelerate
|
| 23 |
+
transformers
|
| 24 |
+
scikit-learn
|
| 25 |
+
gsplat
|
| 26 |
+
evo
|
| 27 |
+
open3d
|
| 28 |
+
rerun-sdk
|
| 29 |
+
kornia
|
| 30 |
+
moviepy==1.0.3
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/run_tum.sh
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
#SBATCH --job-name=slam_tum
#SBATCH --output=slam_tum_%j.out
#SBATCH --partition=defq
#SBATCH --nodes=1
#SBATCH --gpus=1
#SBATCH --time=12:00:00

# Runs slam/demo_submap.py once per freiburg1 sequence found under TUM_DIR.
# Every setting written as ${VAR:-default} can be overridden from the
# environment of the submitting shell; TUM_DIR itself is fixed.

source /var/scratch/qzhang2/miniconda3/etc/profile.d/conda.sh
conda activate SLAM-Former

module load cuda12.1/toolkit/12.1
# Try with a different cudnn or skip it if it's causing issues. Let's not load cudnn module as it might conflict with pytorch's built-in cudnn
# module load cuDNN/cuda12.1/9.1.0.70
export CC=/opt/ohpc/pub/compiler/gcc/9.4.0/bin/gcc
export CXX=/opt/ohpc/pub/compiler/gcc/9.4.0/bin/g++

TUM_DIR="/var/scratch/qzhang2/Feature-SLAM/datasets/tum"
CKPT_PATH="${CKPT_PATH:-/var/scratch/qzhang2/SLAM-Former/ckpt/checkpoint-10.pth.model}"
RESULT_ROOT="${RESULT_ROOT:-./tum_results_aligned}"
SUBMAP_INFERENCE_MODE="${SUBMAP_INFERENCE_MODE:-full}"
LOOP_MASK_MODE="${LOOP_MASK_MODE:-soft_all}"
SUBMAP_TRAIN_MODE="${SUBMAP_TRAIN_MODE:-top5_dual_queue}"
SUBMAP_RETRIEVAL_TOPK="${SUBMAP_RETRIEVAL_TOPK:-5}"
SUBMAP_FETCH_SOURCE="${SUBMAP_FETCH_SOURCE:-frontend}"
SUBMAP_DESCRIPTOR_SOURCE="${SUBMAP_DESCRIPTOR_SOURCE:-frontend}"
MAX_RECURSIVE_SUBMAPS="${MAX_RECURSIVE_SUBMAPS:-5}"
# Checkpoint file name with a trailing ".pth.model" or ".pth" removed.
CKPT_NAME="$(basename "$CKPT_PATH")"
CKPT_NAME="${CKPT_NAME%.pth.model}"
CKPT_NAME="${CKPT_NAME%.pth}"
# Name of the directory that directly contains the checkpoint.
CKPT_PARENT="$(basename "$(dirname "$CKPT_PATH")")"
# First 8 hex characters of the SHA-1 of the checkpoint *path string*
# (not of the file contents), used to keep run tags distinct.
CKPT_HASH="$(printf '%s' "$CKPT_PATH" | sha1sum | cut -c1-8)"
RUN_TAG="${RUN_TAG:-${CKPT_PARENT}__${CKPT_NAME}__${CKPT_HASH}}"
OUT_DIR="${OUT_DIR:-${RESULT_ROOT}/${RUN_TAG}}"
mkdir -p "$OUT_DIR"

# "full" passes no extra flags to the demo; "top5" forwards the submap
# settings above; any other value aborts the job.
case "$SUBMAP_INFERENCE_MODE" in
    full)
        DEMO_ARGS=()
        ;;
    top5)
        DEMO_ARGS=(
            --loop_mask_mode "$LOOP_MASK_MODE"
            --submap_train_mode "$SUBMAP_TRAIN_MODE"
            --submap_retrieval_topk "$SUBMAP_RETRIEVAL_TOPK"
            --submap_fetch_source "$SUBMAP_FETCH_SOURCE"
            --submap_descriptor_source "$SUBMAP_DESCRIPTOR_SOURCE"
            --max_recursive_submaps "$MAX_RECURSIVE_SUBMAPS"
        )
        ;;
    *)
        echo "Unknown SUBMAP_INFERENCE_MODE=$SUBMAP_INFERENCE_MODE (expected full or top5)" >&2
        exit 1
        ;;
esac

echo "Checkpoint: $CKPT_PATH"
echo "Run tag: $RUN_TAG"
echo "Output root: $OUT_DIR"
echo "Submap inference mode: $SUBMAP_INFERENCE_MODE"

# Entries without an rgb/ sub-directory are skipped silently. The exit status
# of each python run is not checked, so the loop continues after a failure.
for seq in "$TUM_DIR"/rgbd_dataset_freiburg1_*; do
    if [ -d "$seq/rgb" ]; then
        seq_name=$(basename "$seq")
        echo "======================================"
        echo "Running on $seq_name..."
        echo "======================================"

        # The demo expects the image folder which contains images. For TUM it is the 'rgb' folder.
        python slam/demo_submap.py \
            --ckpt_path "$CKPT_PATH" \
            --image_folder "$seq/rgb" \
            --output_dir "$OUT_DIR/$seq_name" \
            --target_size 518 \
            "${DEMO_ARGS[@]}"
    fi
done
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/run_tum_top5.sh
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
#SBATCH --job-name=slam_tum_top5
#SBATCH --output=slam_tum_top5_%j.out
#SBATCH --partition=defq
#SBATCH --nodes=1
#SBATCH --gpus=1
#SBATCH --time=12:00:00

# Runs slam/demo_submap.py with the submap flags below on every freiburg1
# sequence found under TUM_DIR. Every setting written as ${VAR:-default} can
# be overridden from the environment of the submitting shell.

source /var/scratch/qzhang2/miniconda3/etc/profile.d/conda.sh
conda activate SLAM-Former

module load cuda12.1/toolkit/12.1
# Try with a different cudnn or skip it if it's causing issues. Let's not load cudnn module as it might conflict with pytorch's built-in cudnn
# module load cuDNN/cuda12.1/9.1.0.70
export CC=/opt/ohpc/pub/compiler/gcc/9.4.0/bin/gcc
export CXX=/opt/ohpc/pub/compiler/gcc/9.4.0/bin/g++

TUM_DIR="/var/scratch/qzhang2/Feature-SLAM/datasets/tum"
CKPT_PATH="${CKPT_PATH:-/var/scratch/qzhang2/SLAM-Former/checkpoints/local_cluster_nv24_sub6/submap_only_pseudo_gt_high_recall_smoke/paper_local_submap_only_pseudo_gt_high_recall_smoke_nv24_sub6/checkpoint-last.pth}"
RESULT_ROOT="${RESULT_ROOT:-./tum_results_aligned_top5}"
# In this script SUBMAP_INFERENCE_MODE only feeds the run tag and the log
# line below; DEMO_ARGS is always populated regardless of its value.
SUBMAP_INFERENCE_MODE="${SUBMAP_INFERENCE_MODE:-top5}"
LOOP_MASK_MODE="${LOOP_MASK_MODE:-soft_all}"
SUBMAP_TRAIN_MODE="${SUBMAP_TRAIN_MODE:-top5_dual_queue}"
SUBMAP_RETRIEVAL_TOPK="${SUBMAP_RETRIEVAL_TOPK:-5}"
SUBMAP_FETCH_SOURCE="${SUBMAP_FETCH_SOURCE:-frontend}"
SUBMAP_DESCRIPTOR_SOURCE="${SUBMAP_DESCRIPTOR_SOURCE:-frontend}"
MAX_RECURSIVE_SUBMAPS="${MAX_RECURSIVE_SUBMAPS:-5}"
# Checkpoint file name with a trailing ".pth.model" or ".pth" removed.
CKPT_NAME="$(basename "$CKPT_PATH")"
CKPT_NAME="${CKPT_NAME%.pth.model}"
CKPT_NAME="${CKPT_NAME%.pth}"
# Name of the directory that directly contains the checkpoint.
CKPT_PARENT="$(basename "$(dirname "$CKPT_PATH")")"
# First 8 hex characters of the SHA-1 of the checkpoint *path string*
# (not of the file contents), used to keep run tags distinct.
CKPT_HASH="$(printf '%s' "$CKPT_PATH" | sha1sum | cut -c1-8)"
RUN_TAG="${RUN_TAG:-${CKPT_PARENT}__${CKPT_NAME}__${SUBMAP_INFERENCE_MODE}__${CKPT_HASH}}"
OUT_DIR="${OUT_DIR:-${RESULT_ROOT}/${RUN_TAG}}"
mkdir -p "$OUT_DIR"

# Extra flags forwarded verbatim to every demo invocation.
DEMO_ARGS=(
    --loop_mask_mode "$LOOP_MASK_MODE"
    --submap_train_mode "$SUBMAP_TRAIN_MODE"
    --submap_retrieval_topk "$SUBMAP_RETRIEVAL_TOPK"
    --submap_fetch_source "$SUBMAP_FETCH_SOURCE"
    --submap_descriptor_source "$SUBMAP_DESCRIPTOR_SOURCE"
    --max_recursive_submaps "$MAX_RECURSIVE_SUBMAPS"
)

echo "Checkpoint: $CKPT_PATH"
echo "Run tag: $RUN_TAG"
echo "Output root: $OUT_DIR"
echo "Submap inference mode: $SUBMAP_INFERENCE_MODE"
echo "Comparative baseline: run_tum.sh with SUBMAP_INFERENCE_MODE=full and the same CKPT_PATH"

# Entries without an rgb/ sub-directory are skipped silently. The exit status
# of each python run is not checked, so the loop continues after a failure.
for seq in "$TUM_DIR"/rgbd_dataset_freiburg1_*; do
    if [ -d "$seq/rgb" ]; then
        seq_name=$(basename "$seq")
        echo "======================================"
        echo "Running on $seq_name..."
        echo "======================================"

        python slam/demo_submap.py \
            --ckpt_path "$CKPT_PATH" \
            --image_folder "$seq/rgb" \
            --output_dir "$OUT_DIR/$seq_name" \
            --target_size 518 \
            "${DEMO_ARGS[@]}"
    fi
done
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/setup.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Packaging metadata for the 'slamformer' distribution.
from setuptools import setup, find_packages

setup(
    name='slamformer',
    version='0.1.0',
    description='SLAM-Former: Putting SLAM into One Transformer.',
    # NOTE(review): setuptools matches `include` patterns against dotted
    # package names (e.g. 'pkg.sub'), so the slash-separated 'src/slamformer'
    # patterns presumably never match any package. Confirm whether
    # 'slamformer' together with a `where='src'` / `package_dir` mapping was
    # intended.
    packages=find_packages(include=['evals', 'evals.*', 'src/slamformer', 'src/slamformer.*']),
)
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/setup_env.sh
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
#SBATCH --job-name=setup_slam
#SBATCH --output=setup_slam.out
#SBATCH --partition=defq
#SBATCH --nodes=1
#SBATCH --gpus=1
#SBATCH --time=02:00:00

# Installs the Python dependencies and the project itself (editable) into the
# SLAM-Former conda environment. Both pip commands use relative paths, so the
# job must start in the directory holding requirements.txt and setup.py.

source /var/scratch/qzhang2/miniconda3/etc/profile.d/conda.sh
conda activate SLAM-Former

# Toolchain modules and compilers exported for any packages that build
# native extensions during installation.
module load cuda12.1/toolkit/12.1
module load cuDNN/cuda12.1/9.1.0.70
export CC=/opt/ohpc/pub/compiler/gcc/9.4.0/bin/gcc
export CXX=/opt/ohpc/pub/compiler/gcc/9.4.0/bin/g++

pip install -r requirements.txt
pip install -e .
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/__init__.py
ADDED
|
File without changes
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/audit_dataset_num_views.py
ADDED
|
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import csv
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Callable
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
import h5py
|
| 13 |
+
except Exception:
|
| 14 |
+
h5py = None
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
README_STATUS = {
|
| 18 |
+
"ARKitScenes": "released",
|
| 19 |
+
"ScanNet": "released",
|
| 20 |
+
"ScanNet++": "coming_soon_in_readme",
|
| 21 |
+
"HyperSim": "released_separate_hf",
|
| 22 |
+
"BlendedMVS": "coming_soon_in_readme",
|
| 23 |
+
"MegaDepth": "coming_soon_in_readme",
|
| 24 |
+
"MVS-Synth": "released",
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
THRESHOLDS = [24, 32, 48, 64]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def retention_str(count: int | None, total: int | None) -> str:
    """Format a retained/total pair as ``"count/total (pct%)"``.

    Returns the placeholder ``"N/A"`` when ``count`` is ``None`` or when
    ``total`` is ``None`` or zero, so the percentage is never computed with a
    zero denominator. The percentage is shown with one decimal place.
    """
    if count is None:
        return "N/A"
    if total in (None, 0):
        return "N/A"
    percentage = 100.0 * count / total
    return f"{count}/{total} ({percentage:.1f}%)"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def summarize_caps(caps: list[int], thresholds: list[int]) -> dict:
    """Summarise a list of per-unit view caps.

    Args:
        caps: one integer cap per audited unit.
        thresholds: values for which the number of caps ``>=`` the value is
            counted.

    Returns:
        A dict with ``unit_count``, ``strict_cap_no_skip`` (minimum),
        ``median_cap`` (median truncated by ``int``), ``max_cap`` and
        ``threshold_counts`` (threshold -> count). For an empty ``caps`` the
        count is 0 and every other statistic is ``None``.
    """
    summary = {
        "unit_count": 0,
        "strict_cap_no_skip": None,
        "median_cap": None,
        "max_cap": None,
        "threshold_counts": dict.fromkeys(thresholds),
    }
    if not caps:
        return summary

    values = np.asarray(caps, dtype=np.int64)
    summary["unit_count"] = int(values.size)
    summary["strict_cap_no_skip"] = int(values.min())
    summary["median_cap"] = int(np.median(values))
    summary["max_cap"] = int(values.max())
    for limit in thresholds:
        summary["threshold_counts"][limit] = int(np.count_nonzero(values >= limit))
    return summary
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def audit_arkit(root: Path, thresholds: list[int]) -> dict:
    """Audit per-scene view caps under ``root / "Training"``.

    The scene list is read from ``Training/all_metadata.npz``. A scene is
    counted only when both its directory and its ``new_scene_metadata.npz``
    exist; its cap is the smaller of the number of images and the longest
    ``image_collection`` group length plus one.

    Returns:
        ``{"available": False, "notes": ...}`` when the index file is
        missing, otherwise the ``summarize_caps`` statistics extended with
        descriptive fields.
    """
    training_root = root / "Training"
    index_path = training_root / "all_metadata.npz"
    if not index_path.is_file():
        return {"available": False, "notes": "missing Training/all_metadata.npz"}

    with np.load(index_path) as index:
        scene_names = index["scenes"]

    caps = []
    for scene_name in scene_names:
        scene_dir = training_root / str(scene_name)
        meta_path = scene_dir / "new_scene_metadata.npz"
        if not (scene_dir.is_dir() and meta_path.is_file()):
            continue
        # allow_pickle is needed because image_collection is unpacked via .item().
        with np.load(meta_path, allow_pickle=True) as meta:
            image_total = len(meta["images"])
            collections = meta["image_collection"].item()
            longest_group = max(
                (len(members) + 1 for members in collections.values()),
                default=0,
            )
        caps.append(min(image_total, longest_group))

    stats = summarize_caps(caps, thresholds)
    stats.update(
        available=True,
        unit_name="scene",
        constraint="min(scene_len, best_image_collection_len_per_scene)",
        notes="strict cap is dominated by a few very short scenes; higher num_views still work with scene skipping",
    )
    return stats
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def audit_scannetpp(root: Path, thresholds: list[int]) -> dict:
    """Audit per-scene view caps for the processed ScanNet++ layout.

    The scene list comes from ``root / "all_metadata.npz"``. A scene is
    examined only when its directory, its ``new_scene_metadata.npz`` and its
    ``images`` directory all exist. Image names are matched against the files
    on disk with their last four characters removed.

    For every ``image_collection`` entry the candidate cap is
    ``min(len(group) + 1, video_len)``; the best candidate per scene is kept
    and scenes whose best cap is 0 are left out.

    Returns:
        ``{"available": False, "notes": ...}`` when the index file is
        missing, otherwise the ``summarize_caps`` statistics extended with
        descriptive fields.
    """
    index_path = root / "all_metadata.npz"
    if not index_path.is_file():
        return {"available": False, "notes": "missing all_metadata.npz"}

    with np.load(index_path) as index:
        scene_names = index["scenes"]

    caps = []
    for scene_name in scene_names:
        scene_dir = root / str(scene_name)
        meta_path = scene_dir / "new_scene_metadata.npz"
        images_dir = scene_dir / "images"
        if not (scene_dir.is_dir() and meta_path.is_file() and images_dir.is_dir()):
            continue

        with np.load(meta_path, allow_pickle=True) as meta:
            images = meta["images"]
            # Strip the last four characters (e.g. a ".jpg"-style suffix).
            stems_on_disk = {entry[:-4] for entry in os.listdir(images_dir)}
            dslr_total = sum(
                1 for name in images if name.startswith("DSC") and name in stems_on_disk
            )
            iphone_total = sum(
                1 for name in images if name.startswith("frame") and name in stems_on_disk
            )

            scene_best = 0
            for ref_id, members in meta["image_collection"].item().items():
                # NOTE(review): a reference named "frame*" is limited by the
                # DSC count and every other reference by the "frame" count;
                # confirm this cross pairing mirrors the dataset loader.
                if images[ref_id].startswith("frame"):
                    video_len = dslr_total
                else:
                    video_len = iphone_total
                candidate = min(len(members) + 1, video_len)
                if candidate > scene_best:
                    scene_best = candidate

        if scene_best > 0:
            caps.append(scene_best)

    stats = summarize_caps(caps, thresholds)
    stats.update(
        available=True,
        unit_name="scene",
        constraint="best min(group_len, paired_video_len) per scene",
        notes="README still marks this split as coming soon, but local processed data is already present",
    )
    return stats
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def audit_scannet(root: Path, thresholds: list[int]) -> dict:
    """Audit per-scene image counts under ``root / "scans_train"``.

    Only entries whose name starts with ``"scene"`` and that contain a
    ``new_scene_metadata.npz`` are counted; the cap of a scene is the length
    of that file's ``images`` array.

    Returns:
        ``{"available": False, "notes": ...}`` when ``scans_train`` is not a
        directory, otherwise the ``summarize_caps`` statistics extended with
        descriptive fields.
    """
    train_dir = root / "scans_train"
    if not train_dir.is_dir():
        return {"available": False, "notes": "missing scans_train"}

    candidates = (
        train_dir / entry / "new_scene_metadata.npz"
        for entry in sorted(os.listdir(train_dir))
        if entry.startswith("scene")
    )
    caps = []
    for meta_path in candidates:
        if meta_path.is_file():
            with np.load(meta_path, allow_pickle=True) as meta:
                caps.append(len(meta["images"]))

    stats = summarize_caps(caps, thresholds)
    stats.update(available=True, unit_name="scene", constraint="scene_len", notes="")
    return stats
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def audit_hypersim(root: Path, thresholds: list[int]) -> dict:
    """Audit per-subscene image counts two directory levels below ``root``.

    Every ``root/<scene>/<subscene>`` directory contributes the number of
    entries whose name ends in ``.png``; subscenes without such entries are
    left out.

    Returns:
        ``{"available": False, "notes": ...}`` when ``root`` is not a
        directory, otherwise the ``summarize_caps`` statistics extended with
        descriptive fields.
    """
    if not root.is_dir():
        return {"available": False, "notes": "missing hypersim root"}

    caps = []
    scene_dirs = (root / entry for entry in sorted(os.listdir(root)))
    for scene_dir in scene_dirs:
        if not scene_dir.is_dir():
            continue
        for entry in sorted(os.listdir(scene_dir)):
            subscene_dir = scene_dir / entry
            if not subscene_dir.is_dir():
                continue
            png_total = sum(
                1 for filename in os.listdir(subscene_dir) if filename.endswith(".png")
            )
            if png_total:
                caps.append(png_total)

    stats = summarize_caps(caps, thresholds)
    stats.update(available=True, unit_name="subscene", constraint="scene_len", notes="")
    return stats
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def audit_mvs_synth(root: Path, thresholds: list[int]) -> dict:
    """Audit per-scene image counts from ``root/<scene>/rgb``.

    The cap of a scene is the number of entries in its ``rgb`` directory
    whose name ends in ``.jpg``; scenes without an ``rgb`` directory or
    without such entries are left out.

    Returns:
        ``{"available": False, "notes": ...}`` when ``root`` is not a
        directory, otherwise the ``summarize_caps`` statistics extended with
        descriptive fields.
    """
    if not root.is_dir():
        return {"available": False, "notes": "missing processed_mvs_synth root"}

    caps = []
    for entry in sorted(os.listdir(root)):
        rgb_dir = root / entry / "rgb"
        if rgb_dir.is_dir():
            jpg_total = sum(
                1 for filename in os.listdir(rgb_dir) if filename.endswith(".jpg")
            )
            if jpg_total > 0:
                caps.append(jpg_total)

    stats = summarize_caps(caps, thresholds)
    stats.update(available=True, unit_name="scene", constraint="scene_len", notes="")
    return stats
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def audit_megadepth(root: Path, thresholds: list[int]) -> dict:
    """Audit the fixed-size image sets stored in ``megadepth_sets_64.npz``.

    A set is counted when the scene index in its first column refers to a
    scene whose name does not start with ``"0015"`` or ``"0022"``. Every
    counted set contributes a cap of 64.

    Args:
        root: directory expected to contain ``megadepth_sets_64.npz``.
        thresholds: thresholds forwarded to ``summarize_caps``.

    Returns:
        ``{"available": False, "notes": ...}`` when the file is missing,
        otherwise the ``summarize_caps`` statistics extended with descriptive
        fields.
    """
    sets_path = root / "megadepth_sets_64.npz"
    if not sets_path.is_file():
        return {"available": False, "notes": "missing megadepth_sets_64.npz; code hard cap is 64"}

    # NOTE(review): allow_pickle=True unpickles objects stored in the archive;
    # only use it on trusted, locally produced files.
    with np.load(sets_path, allow_pickle=True) as data:
        scenes = data["scenes"]
        sets = data["sets"]

    excluded_prefixes = ("0015", "0022")
    valid_scene = np.array(
        [not str(scene).startswith(excluded_prefixes) for scene in scenes]
    )
    valid_scene_ids = np.nonzero(valid_scene)[0]
    # np.isin replaces np.in1d, which is deprecated in NumPy 2.x; for the 1-D
    # first column both return the same boolean mask.
    train_mask = np.isin(sets[:, 0], valid_scene_ids)
    caps = [64] * int(train_mask.sum())

    stats = summarize_caps(caps, thresholds)
    stats.update(
        {
            "available": True,
            "unit_name": "set",
            "constraint": "fixed_64_image_set",
            "notes": "hard code cap is 64 because the loader slices image_idxs = sets[idx][1:65]",
        }
    )
    return stats
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def build_adjacency_list(score_matrix: np.ndarray, thresh: float = 0.2) -> list[list[int]]:
    """Turn a square score matrix into an adjacency list.

    ``thresh`` is subtracted from every score and negative results are set to
    zero; node ``i`` then gets an edge to every column ``j`` whose shifted
    score is non-zero. The subtraction produces a new array, so the input
    matrix is not modified.

    Returns:
        One list of neighbour indices (plain ``int``) per row, in ascending
        column order.
    """
    shifted = score_matrix - thresh
    shifted[shifted < 0] = 0

    adjacency = [[] for _ in range(len(shifted))]
    sources, targets = np.nonzero(shifted)
    for source, target in zip(sources.tolist(), targets.tolist()):
        adjacency[source].append(target)
    return adjacency
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def reachable_count(adjacency: list[list[int]], start_index: int) -> int:
    """Count the nodes reachable from ``start_index``, the start included.

    Performs an iterative depth-first walk over the directed edges in
    ``adjacency`` (one neighbour list per node).
    """
    seen = {start_index}
    frontier = [start_index]
    while frontier:
        current = frontier.pop()
        for neighbor in adjacency[current]:
            if neighbor not in seen:
                seen.add(neighbor)
                frontier.append(neighbor)
    return len(seen)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def audit_blendedmvs(root: Path, thresholds: list[int]) -> dict:
    """Audit per-reference caps from the overlap graph in ``new_overlap.h5``.

    For every group in the HDF5 file a dense score matrix is rebuilt from its
    ``indices`` / ``values`` datasets and ``shape`` attribute, turned into an
    adjacency list with ``build_adjacency_list`` and, for each node, the
    number of reachable nodes is recorded as that reference's cap.

    Returns:
        ``{"available": False, "notes": ...}`` when the file is missing or
        ``h5py`` could not be imported, otherwise the ``summarize_caps``
        statistics extended with descriptive fields.
    """
    overlap_path = root / "new_overlap.h5"
    if not overlap_path.is_file():
        return {"available": False, "notes": "missing new_overlap.h5"}
    if h5py is None:
        return {"available": False, "notes": "h5py is unavailable"}

    caps = []
    with h5py.File(overlap_path, "r") as overlap_file:
        for scene_key in overlap_file.keys():
            scene_group = overlap_file[scene_key]
            pair_indices = scene_group["indices"][:]
            pair_scores = scene_group["values"][:]
            dense_scores = np.zeros(scene_group.attrs["shape"], dtype=np.float32)
            dense_scores[pair_indices[0], pair_indices[1]] = pair_scores
            graph = build_adjacency_list(dense_scores)
            for node in range(len(graph)):
                caps.append(reachable_count(graph, node))

    stats = summarize_caps(caps, thresholds)
    stats.update(
        available=True,
        unit_name="reference",
        constraint="reachable_unique_images_per_reference",
        notes="for allow_repeat=false this is the reachable unique-image cap from the overlap graph",
    )
    return stats
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def build_dataset_specs(data_root: Path) -> list[tuple[str, Path, Callable[[Path, list[int]], dict]]]:
    """List the datasets to audit as ``(name, local path, audit function)``.

    Each local path is the dataset's sub-directory below ``data_root``; the
    audit function is later called with that path and the threshold list.
    """
    layout = (
        ("ARKitScenes", "processed_arkitscenes", audit_arkit),
        ("ScanNet", "processed_scannet", audit_scannet),
        ("ScanNet++", "processed_scannetpp", audit_scannetpp),
        ("HyperSim", "hypersim", audit_hypersim),
        ("BlendedMVS", "processed_blendedmvs", audit_blendedmvs),
        ("MegaDepth", "processed_megadepth", audit_megadepth),
        ("MVS-Synth", "processed_mvs_synth", audit_mvs_synth),
    )
    return [(label, data_root / subdir, auditor) for label, subdir, auditor in layout]
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def render_markdown(rows: list[dict], output_md: Path, thresholds: list[int]) -> None:
    """Write the audit rows as a Markdown report containing a single table.

    The table has one ``>=<threshold>`` column per entry of ``thresholds``
    (any number of thresholds is supported), placed between the ``MaxCap``
    and ``Constraint`` columns.

    Args:
        rows: one dict per dataset. Required keys: ``dataset``,
            ``readme_status``, ``local_available``, ``local_path``,
            ``unit_name``, ``unit_count``, ``strict_cap_no_skip``,
            ``median_cap``, ``max_cap``, ``constraint``, ``notes`` and
            ``ge_<threshold>`` for every threshold.
        output_md: destination file; missing parent directories are created.
        thresholds: thresholds that get a table column.

    Raises:
        KeyError: if a row lacks one of the required keys.
    """
    output_md.parent.mkdir(parents=True, exist_ok=True)

    column_titles = [
        "Dataset",
        "READMEStatus",
        "LocalAvailable",
        "LocalPath",
        "Unit",
        "UnitCount",
        "StrictCapNoSkip",
        "MedianCap",
        "MaxCap",
        *(f">={threshold}" for threshold in thresholds),
        "Constraint",
        "Notes",
    ]
    # Four plain columns, five right-aligned ones, then the threshold,
    # constraint and notes columns.
    alignments = ["---"] * 4 + ["---:"] * 5 + ["---"] * (len(thresholds) + 2)

    header = [
        "# Paper Training Dataset num_views Audit",
        "",
        "This table is aligned with the paper's training dataset list in `refer/arXiv-2509.16909v1/sec/4_exp.tex` and the release status in `README.md`.",
        "",
        "`StrictCapNoSkip` means the largest `num_views` that keeps every current local scene/reference usable under `allow_repeat=false`.",
        "Using a larger `num_views` can still work, but shorter scenes will be skipped by the dataset loader.",
        "",
        "| " + " | ".join(column_titles) + " |",
        "|" + "|".join(alignments) + "|",
    ]

    lines = []
    for row in rows:
        cells = [
            row["dataset"],
            row["readme_status"],
            row["local_available"],
            f"`{row['local_path']}`",
            row["unit_name"],
            row["unit_count"],
            row["strict_cap_no_skip"],
            row["median_cap"],
            row["max_cap"],
            *(row[f"ge_{threshold}"] for threshold in thresholds),
            row["constraint"],
            row["notes"],
        ]
        lines.append("| " + " | ".join(str(cell) for cell in cells) + " |")

    # Explicit encoding keeps the report independent of the platform locale.
    output_md.write_text("\n".join(header + lines) + "\n", encoding="utf-8")
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def render_csv(rows: list[dict], output_csv: Path, thresholds: list[int]) -> None:
    """Write the audit rows as a CSV file with a header line.

    The file has one ``ge_<threshold>_count`` column per entry of
    ``thresholds`` (any number of thresholds is supported), placed between
    the ``max_cap`` and ``constraint`` columns. Keys of a row that are not
    CSV columns are ignored.

    Args:
        rows: one dict per dataset; each must contain every column name.
        output_csv: destination file; missing parent directories are created.
        thresholds: thresholds that get a count column.

    Raises:
        KeyError: if a row lacks one of the column keys.
    """
    output_csv.parent.mkdir(parents=True, exist_ok=True)

    fieldnames = [
        "dataset",
        "readme_status",
        "local_available",
        "local_path",
        "unit_name",
        "unit_count",
        "strict_cap_no_skip",
        "median_cap",
        "max_cap",
        *(f"ge_{threshold}_count" for threshold in thresholds),
        "constraint",
        "notes",
    ]

    # newline="" lets the csv module control line endings; the explicit
    # encoding keeps the output independent of the platform locale.
    with output_csv.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        for row in rows:
            # Select only the CSV columns; rows carry extra display-only keys.
            writer.writerow({name: row[name] for name in fieldnames})
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def main() -> None:
    """Audit every configured dataset and write the Markdown and CSV reports.

    Command-line options ``--data-root``, ``--output-md`` and ``--output-csv``
    default to locations below the repository root (the parent of this file's
    directory). The two output paths are printed once the reports are
    written.
    """
    repo_root = Path(__file__).resolve().parents[1]
    reports_dir = repo_root / "reports"

    parser = argparse.ArgumentParser()
    parser.add_argument("--data-root", type=Path, default=repo_root / "data" / "train")
    parser.add_argument("--output-md", type=Path, default=reports_dir / "paper_dataset_num_views.md")
    parser.add_argument("--output-csv", type=Path, default=reports_dir / "paper_dataset_num_views.csv")
    args = parser.parse_args()

    # Statistics copied into each row; "N/A" is used when an audit omits one.
    copied_fields = ("unit_name", "unit_count", "strict_cap_no_skip", "median_cap", "max_cap", "constraint")

    rows = []
    for dataset_name, local_path, audit_fn in build_dataset_specs(args.data_root):
        stats = audit_fn(local_path, THRESHOLDS)

        row = {
            "dataset": dataset_name,
            "readme_status": README_STATUS[dataset_name],
            "local_available": "yes" if stats.get("available", False) else "no",
            "local_path": str(local_path),
        }
        for field in copied_fields[:-1]:
            row[field] = stats.get(field, "N/A")
        row["constraint"] = stats.get("constraint", "N/A")
        row["notes"] = stats.get("notes", "")

        # Each threshold yields a display string (Markdown) and a raw count (CSV).
        per_threshold = stats.get("threshold_counts", {})
        unit_total = stats.get("unit_count")
        for threshold in THRESHOLDS:
            kept = per_threshold.get(threshold)
            row[f"ge_{threshold}"] = retention_str(kept, unit_total)
            row[f"ge_{threshold}_count"] = "N/A" if kept is None else kept
        rows.append(row)

    render_markdown(rows, args.output_md, THRESHOLDS)
    render_csv(rows, args.output_csv, THRESHOLDS)
    print(args.output_md)
    print(args.output_csv)
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
if __name__ == "__main__":
|
| 412 |
+
main()
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/batched_dynamic_router.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
BatchedDynamicSubmapRouter: Differentiable dynamic submap boundary predictor.
|
| 3 |
+
|
| 4 |
+
Key design decisions:
|
| 5 |
+
- Fully vectorized (Fix #4): NO Python-level conditionals on boundary_flag.
|
| 6 |
+
Every GPU always executes the full compute graph (descriptor + retrieval +
|
| 7 |
+
backendT). The boundary decision is applied as a differentiable gating
|
| 8 |
+
multiplier so DDP AllReduce never deadlocks.
|
| 9 |
+
- Masked pooling (Fix #3): valid_mask [B, max_K] prevents zero-padding from
|
| 10 |
+
injecting dead gradients / NaN into descriptors.
|
| 11 |
+
- Gumbel-Softmax temperature annealing (Fix #5): τ decays from tau_start
|
| 12 |
+
to tau_end over training via cosine schedule.
|
| 13 |
+
|
| 14 |
+
No original source files are modified.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
from typing import Optional
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class BatchedDynamicSubmapRouter(nn.Module):
    """Differentiable boundary predictor with vectorized soft gating.

    At every frame step the router:
      1. Appends the new token to a padded accumulation buffer.
      2. Predicts a boundary flag via an MLP + Gumbel-Softmax
         straight-through estimator.
      3. Always computes the pooled descriptor for the full batch
         (no Python branching on the flag).
      4. Resets the accumulation buffer of every sample whose flag fired.

    Args:
        token_dim: feature dimension of one pooled token (default 2048).
            The boundary MLP consumes ``2 * token_dim`` features.
        boundary_hidden_dim: MLP hidden size (default 512).
        tau_start: initial Gumbel-Softmax temperature.
        tau_end: final Gumbel-Softmax temperature.
        max_K: maximum frames per submap (buffer size).
    """

    def __init__(
        self,
        token_dim: int = 2048,
        boundary_hidden_dim: int = 512,
        tau_start: float = 5.0,
        tau_end: float = 0.1,
        max_K: int = 20,
    ):
        super().__init__()
        self.token_dim = token_dim
        self.tau_start = tau_start
        self.tau_end = tau_end
        self.max_K = max_K

        # Boundary predictor: takes [prev_token || curr_token] -> logit
        self.boundary_predictor = nn.Sequential(
            nn.Linear(2 * token_dim, boundary_hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(boundary_hidden_dim, 1),
        )

    # -- temperature annealing ---------------------------------------
    def get_tau(self, progress: float) -> float:
        """Return the cosine-annealed Gumbel-Softmax temperature.

        Args:
            progress: training progress; values outside [0, 1] are clamped.

        Returns:
            ``tau_start`` at progress 0, ``tau_end`` at progress 1, and a
            half-cosine interpolation in between.
        """
        progress = min(max(progress, 0.0), 1.0)
        return self.tau_end + 0.5 * (self.tau_start - self.tau_end) * (
            1.0 + math.cos(math.pi * progress)
        )

    # -- boundary prediction (vectorized) ----------------------------
    def predict_boundary(
        self,
        prev_token: torch.Tensor,
        curr_token: torch.Tensor,
        tau: float = 1.0,
    ) -> torch.Tensor:
        """Predict a boundary flag for each element in the batch.

        Uses ``F.gumbel_softmax(..., hard=True)``: the forward value is a
        one-hot sample, the backward pass uses the soft relaxation.

        Args:
            prev_token: [B, D] pooled previous-frame token.
            curr_token: [B, D] pooled current-frame token. The concatenated
                width must equal ``2 * token_dim``.
            tau: Gumbel-Softmax temperature.

        Returns:
            boundary_flag: [B, 1], the "cut" column of the one-hot sample.
        """
        combined = torch.cat([prev_token, curr_token], dim=-1)  # [B, 2D]
        logit = self.boundary_predictor(combined)  # [B, 1]

        # Two-class logits [logit, -logit] -> (cut, no-cut)
        logits_2 = torch.cat([logit, -logit], dim=-1)  # [B, 2]
        one_hot = F.gumbel_softmax(logits_2, tau=tau, hard=True, dim=-1)
        # Column 0 == 1 means "cut"; keep the trailing dim so it broadcasts.
        return one_hot[:, :1]

    # -- masked pooling -----------------------------------------------
    @staticmethod
    def masked_pool(
        accum_tokens: torch.Tensor,
        valid_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Masked average pooling over the accumulation buffer.

        Args:
            accum_tokens: [B, max_K, P, C] padded token buffer.
            valid_mask: [B, max_K], True where a frame slot is filled.

        Returns:
            pooled: [B, C] mean over patches, then over valid frames.
            Samples with no valid frame yield zeros (the divisor is clamped
            to 1) instead of NaN.
        """
        # Pool patches within each frame first: [B, max_K, C]
        frame_pooled = accum_tokens.mean(dim=2)
        frame_weight = valid_mask.float().unsqueeze(-1)  # [B, max_K, 1]
        numerator = (frame_pooled * frame_weight).sum(dim=1)  # [B, C]
        denominator = frame_weight.sum(dim=1).clamp(min=1.0)  # [B, 1]
        return numerator / denominator

    # -- vectorized reset ---------------------------------------------
    @staticmethod
    def soft_reset(
        accum_tokens: torch.Tensor,
        valid_mask: torch.Tensor,
        accum_len: torch.Tensor,
        boundary_flag: torch.Tensor,
    ):
        """Reset the buffers of every sample whose boundary flag fired.

        Samples with ``boundary_flag == 1`` are zeroed out; samples with
        ``boundary_flag == 0`` are kept. No Python-level branching.

        Args:
            accum_tokens: [B, max_K, P, C].
            valid_mask: [B, max_K] bool.
            accum_len: [B] long, current length per sequence.
            boundary_flag: [B, 1], 0.0 or 1.0.

        Returns:
            (accum_tokens, valid_mask, accum_len) as new tensors; the
            inputs are not modified.
        """
        keep = 1.0 - boundary_flag  # [B, 1]

        # Tokens: multiply by the (differentiable) gate, broadcast over
        # max_K, P and C.
        accum_tokens = accum_tokens * keep.unsqueeze(-1).unsqueeze(-1)

        # Hard decision for the non-differentiable bookkeeping. Thresholding
        # is robust to the straight-through value being e.g. 0.9999999
        # rather than exactly 1.0, where `.bool()` and `.long()` disagree.
        keep_hard = keep > 0.5  # [B, 1] bool

        # Keep the trailing dim so the gate broadcasts across max_K; a
        # squeezed [B] tensor would be aligned with the max_K axis instead.
        valid_mask = valid_mask & keep_hard

        # Length: reset to 0 where the boundary triggered.
        accum_len = accum_len * keep_hard.squeeze(-1).to(accum_len.dtype)

        return accum_tokens, valid_mask, accum_len

    # -- full vectorized step -----------------------------------------
    def step(
        self,
        new_token: torch.Tensor,
        prev_token: torch.Tensor,
        accum_tokens: torch.Tensor,
        valid_mask: torch.Tensor,
        accum_len: torch.Tensor,
        tau: float = 1.0,
    ):
        """Execute one frame step for the full batch.

        Note that ``accum_tokens`` and ``valid_mask`` are written in place
        when the new token is appended.

        Args:
            new_token: [B, P, C] current frame token.
            prev_token: [B, D] pooled previous-frame representation;
                ``D + C`` must equal ``2 * token_dim``.
            accum_tokens: [B, max_K, P, C] accumulation buffer.
            valid_mask: [B, max_K] bool mask.
            accum_len: [B] long, frames accumulated so far.
            tau: current Gumbel temperature.

        Returns:
            boundary_flag: [B, 1] hard gate.
            curr_desc: [B, C] pooled descriptor, computed before the reset.
            accum_tokens: updated buffer.
            valid_mask: updated mask.
            accum_len: updated lengths.
        """
        B = new_token.shape[0]
        device = new_token.device

        # 1. Append the token to the buffer with a single indexed write
        #    (no per-sample loop, no host synchronisation). The index is
        #    clamped, so a full buffer overwrites its last slot.
        write_idx = accum_len.clamp(max=self.max_K - 1)  # [B]
        batch_idx = torch.arange(B, device=device)
        accum_tokens[batch_idx, write_idx] = new_token.to(accum_tokens.dtype)
        valid_mask[batch_idx, write_idx] = True
        accum_len = (accum_len + 1).clamp(max=self.max_K)

        # 2. Boundary prediction (always for ALL B)
        curr_pooled = new_token.mean(dim=1)  # [B, C]
        boundary_flag = self.predict_boundary(prev_token, curr_pooled, tau=tau)

        # 3. Descriptor (always for ALL B, masked pooling)
        curr_desc = self.masked_pool(accum_tokens, valid_mask)  # [B, C]

        # 4. Reset the buffers of samples whose boundary fired
        accum_tokens, valid_mask, accum_len = self.soft_reset(
            accum_tokens, valid_mask, accum_len, boundary_flag
        )

        return boundary_flag, curr_desc, accum_tokens, valid_mask, accum_len

    # -- buffer initialization helper ---------------------------------
    def init_buffers(
        self,
        batch_size: int,
        P: int,
        C: int,
        device: torch.device,
    ):
        """Create fresh accumulation buffers for a batch.

        Returns:
            accum_tokens: [B, max_K, P, C] zeros.
            valid_mask: [B, max_K] False.
            accum_len: [B] zeros (long).
        """
        accum_tokens = torch.zeros(batch_size, self.max_K, P, C, device=device)
        valid_mask = torch.zeros(batch_size, self.max_K, dtype=torch.bool, device=device)
        accum_len = torch.zeros(batch_size, dtype=torch.long, device=device)
        return accum_tokens, valid_mask, accum_len
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/demo.py
ADDED
|
@@ -0,0 +1,540 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os,sys
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import re
|
| 6 |
+
import cv2
|
| 7 |
+
import glob
|
| 8 |
+
import argparse
|
| 9 |
+
|
| 10 |
+
import time
|
| 11 |
+
import open3d as o3d
|
| 12 |
+
from rich import print
|
| 13 |
+
import matplotlib.pyplot as plt
|
| 14 |
+
from scipy.spatial.transform import Rotation as R
|
| 15 |
+
|
| 16 |
+
import rerun as rr
|
| 17 |
+
import rerun.blueprint as rrb
|
| 18 |
+
|
| 19 |
+
sys.path.append('src')
|
| 20 |
+
from slamformer.models.slamformer import SLAMFormer
|
| 21 |
+
|
| 22 |
+
current_directory = os.path.dirname(os.path.abspath(__file__))
|
| 23 |
+
|
| 24 |
+
sys.path.append(current_directory+'/../')
|
| 25 |
+
|
| 26 |
+
import slam.utils as utils
|
| 27 |
+
from slam.rerun_helper import log_camera, log_window
|
| 28 |
+
|
| 29 |
+
def strip_module(state_dict):
    """
    Return a copy of ``state_dict`` whose keys have no leading 'module.'.

    Args:
        state_dict (dict): mapping whose keys may start with 'module.'.
    Returns:
        OrderedDict: same values in the same order; only a leading
        'module.' is removed, keys without it are kept unchanged.
    """
    prefix = "module."
    cut = len(prefix)
    return OrderedDict(
        (key[cut:] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class SLAM:
    """Keyframe-based SLAM driver around a ``SLAMFormer`` model.

    Each :meth:`step` runs the keyframe detector on the incoming image,
    feeds accepted keyframes to the model frontend and appends the resulting
    tokens to ``self.map``. Every ``bn_every`` keyframes the backend
    re-processes the whole map into ``self.map_opt``.
    """

    def __init__(
        self,
        outdir='output/tmp',
        kf_th=0.1,
        bn_every=10,
        vis=False,
        save_gmem=True,
        ckpt_path='path/to/ckpt.pth',
        target_size=518,
        retention_ratio=0.5
    ):
        """Build the model, load its checkpoint and reset all SLAM state.

        Args:
            outdir: directory prefix used by :meth:`terminate` for outputs.
            kf_th: minimum translation distance for a new keyframe.
            bn_every: run the backend every this many keyframes; also
                forwarded to ``SLAMFormer``.
            vis: log cameras / point clouds to Rerun while running.
            save_gmem: keep the token map on the CPU between steps.
            ckpt_path: checkpoint file read by :meth:`load_model`.
            target_size: forwarded to ``utils.load_image``.
            retention_ratio: forwarded to ``SLAMFormer``.
        """
        self.outdir = outdir
        self.kf_th = kf_th
        self.save_gmem = save_gmem
        self.bn_every = bn_every
        self.vis = vis
        self.ckpt_path = ckpt_path
        self.target_size = target_size

        # Per-stage wall-clock timings, reported by terminate().
        self.times = []
        self.kf_time = []
        self.backend_time = []

        # model params
        self.model = SLAMFormer(retention_ratio=retention_ratio, bn_every=bn_every)
        self.model = self.model.eval()
        self.load_model()
        self.model.eval()
        self.model.to('cuda')

        # SLAM params: fid / kid are the indices of the last frame / keyframe.
        self.fid = -1
        self.kid = -1
        self.kfids = []
        self.last_kfid = 0
        self.kf_timestamps = []
        # frontend
        self.frontend_times = 0
        # Token map (raw) and its backend-optimized counterpart.
        self.map = None
        self.map_opt = None

        self.signal_backend = False
        self.backend_every = self.bn_every

        self.extrins = []
        self.intrins = []
        self.frames = []
        self.kf_frames = []

        self.K = None
        self.update_K = False

        # vis
        # NOTE(review): the original indentation was lost in extraction;
        # Twk / K are assumed to belong to the visualisation branch since
        # only the vis code paths read them -- confirm against upstream.
        if self.vis:
            self.entity = "world"
            rr.init("SLAM", spawn=True)
            rr.log(self.entity, rr.ViewCoordinates.RIGHT_HAND_Z_UP)
            self.Twk = np.eye(4)
            self.K = np.eye(3)

    def load_model(self):
        """Load ``self.ckpt_path`` into ``self.model`` (non-strict).

        Accepts either a bare state dict or a dict wrapping it under the
        ``"model"`` key; ``module.`` prefixes are stripped via
        ``utils.strip_module``.
        """
        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only load checkpoints from a trusted source.
        ckpt_raw = torch.load(self.ckpt_path, map_location='cuda', weights_only=False)

        if isinstance(ckpt_raw, dict) and "model" in ckpt_raw:
            ckpt = ckpt_raw["model"]
            print("Loaded state_dict from 'model' key in checkpoint.")
        else:
            ckpt = ckpt_raw

        ckpt = utils.strip_module(ckpt)
        self.model.load_state_dict(ckpt, strict=False)
        del ckpt, ckpt_raw

    @property
    def time(self):
        """Current ``perf_counter`` reading after a CUDA synchronize."""
        torch.cuda.synchronize()
        return time.perf_counter()

    def kf_detect(self, image):
        """Decide whether ``image`` should become a new keyframe.

        The first frame is always accepted. Otherwise the last keyframe and
        the new frame are passed through ``model.KFT`` and the frame is
        accepted when the translation between the two predicted extrinsics
        exceeds ``self.kf_th``.

        Returns:
            ``True`` for the first frame, otherwise the (tensor) result of
            ``dist > self.kf_th``.

        Raises:
            RuntimeError: if the model returns fewer than two poses.
        """
        if self.kid == -1:
            self.extrins.append(torch.eye(4))
            return True

        frame = utils.load_image(image, self.target_size)
        token = self.model.KFT(torch.stack([self.kf_frames[-1], frame.cuda()]))

        # Pose extraction is identical with and without visualisation.
        res = self.model.extract(token, cam_only=True)
        camera_pose = res['camera_poses']
        extrinsic = torch.inverse(camera_pose)
        if extrinsic.shape[1] < 2:
            raise RuntimeError(
                "KFT returned fewer than two camera poses; "
                "cannot measure the keyframe baseline"
            )
        extrinsic = extrinsic.cpu()
        extrinsic_ref = extrinsic[0, -2]
        extrinsic = extrinsic[0, -1]

        if self.vis:
            # Chain the relative pose onto the last keyframe's world pose.
            Tki = torch.inverse(camera_pose[0, 0]) @ camera_pose[0, 1]
            Tki = Tki.cpu().numpy()
            self.Twi = self.Twk @ Tki
            K44 = np.eye(4)
            K44[:3, :3] = self.K
            log_camera("camera", self.Twi, K44, kfd=True)
            # make the window follow camera
            log_window(f"{self.entity}", np.linalg.inv(self.Twi), K44)
        else:
            self.kft_extrinsic_ref = torch.eye(4)

        dist = torch.sqrt(torch.sum((extrinsic[:3, 3] - extrinsic_ref[:3, 3])**2))
        isKF = dist > self.kf_th

        print(dist)

        if isKF:
            self.extrins.append(extrinsic)
        return isKF

    def frontend(self, image):
        """Process one incoming frame.

        Runs keyframe detection; for an accepted keyframe, runs the model
        frontend, appends the token(s) to the map, raises the backend signal
        every ``backend_every`` keyframes and, with ``vis`` enabled, logs the
        new camera and point cloud.

        Returns:
            ``False`` when the frame is rejected as a keyframe; ``None``
            otherwise (kept as-is for backward compatibility).
        """
        if self.vis:
            rr.log("image", rr.Image(image[:, :, ::-1]))

        self.fid += 1
        print('Frame', self.fid)
        # run kf detector
        st = self.time
        enough_disparity = self.kf_detect(image)
        self.kf_time.append(self.time - st)
        if not enough_disparity:
            return False

        torch.cuda.empty_cache()
        # run T-frontend
        frame = utils.load_image(image, self.target_size)
        # Frames are channel-first (C, H, W), as kf_detect and extract assume.
        _, self.H, self.W = frame.shape
        self.last_kf = frame.cuda()
        self.kf_frames.append(self.last_kf)
        self.last_kfid = self.fid
        self.frames.append(self.last_kf.clone())
        self.kid += 1
        print("[italic purple] # KEYFRAME", self.kid)
        self.kf_timestamps.append(self.cur_timestamp)
        frame = frame.cuda()
        st = self.time

        if self.nkf == 1:
            # A single keyframe produces no tokens yet; the first two
            # keyframes are processed together below.
            pass
        elif self.nkf == 2:
            token = self.model.frontendT(torch.stack([self.kf_frames[0], frame]))
            self.map_add(token)
        else:
            token = self.model.frontendT(frame)
            print(self.time - st)
            self.map_add(token)

        self.kfids.append(self.fid)
        self.times.append(self.time - st)
        torch.cuda.empty_cache()

        # send signal to backend
        self.frontend_times += 1
        if self.frontend_times % self.backend_every == 0:
            self.signal_backend = True

        if self.vis and self.map is not None:
            log_start = time.time()
            if self.nkf == 2:
                ps, cs, confs, poses = self.extract(self.map)
            else:
                ps, cs, confs, poses = self.extract(self.map[-1:])

            self.vis_mem = [ps, cs, confs, poses]

            # Drop the 15% least confident points before logging.
            conf_threshold = np.percentile(confs, 15)
            msk = confs >= conf_threshold

            ps = ps[msk]
            cs = cs[msk]
            K44 = np.eye(4)
            K44[:3, :3] = self.K

            if self.nkf == 2:
                log_camera(f"{self.entity}/camera_kf/0", poses[0], K44)
                log_camera(f"{self.entity}/camera_kf/1", poses[1], K44)

                rr.log(f"{self.entity}/lines/0to1", rr.LineStrips3D([poses[:, :3, 3].tolist()], colors=[0, 0, 255], radii=[0.005]))

                self.last_kf_pose = poses[1]
            else:
                log_camera(f"{self.entity}/camera_kf/{self.nkf-1}", poses.reshape(4, 4), K44)
                rr.log(f"{self.entity}/lines/{self.nkf-2}to{self.nkf-1}", rr.LineStrips3D([np.stack([self.last_kf_pose[:3, 3], poses[0, :3, 3]]).tolist()], colors=[0, 0, 255], radii=[0.005]))

                self.last_kf_pose = poses[0]

            rr.log(
                f"{self.entity}/pointclouds/{self.nkf}",
                rr.Points3D(ps, colors=cs, radii=0.01),
            )

            print('log', time.time() - log_start)

            self.Twk = poses[-1].reshape(4, 4)

    def backend(self, final=False):
        """Re-process the whole token map when the backend signal is set.

        Stores the result on the CPU in ``self.map_opt`` and clears the
        signal. With ``vis`` enabled, all logged cameras, trajectory lines
        and point clouds are refreshed from the optimized map.

        Args:
            final: accepted for API compatibility; currently unused.
        """
        if not self.signal_backend:
            return
        if self.map is None:
            # Fewer than two keyframes: no tokens have been mapped yet.
            self.signal_backend = False
            return

        torch.cuda.empty_cache()

        del self.model.fkv
        torch.cuda.empty_cache()
        print('Backending...', self.nkf, 'KFs')
        st = time.perf_counter()
        map_optimed = self.model.backendT(self.map.cuda())
        self.backend_time.append(time.perf_counter() - st)
        print('backend_take', time.perf_counter() - st)
        torch.cuda.empty_cache()

        if self.map_opt is not None:
            # Release the previous optimized map before storing the new one.
            del self.map_opt
            torch.cuda.empty_cache()
        self.map_opt = map_optimed.cpu()

        self.signal_backend = False
        torch.cuda.empty_cache()

        if self.vis:
            ps, cs, confs, poses = self.extract(self.map_opt)
            self.vis_mem = [ps, cs, confs, poses]
            conf_threshold = np.percentile(confs, 15)
            msk = confs >= conf_threshold

            ps = ps[msk]
            cs = cs[msk]

            # Clear the per-keyframe clouds logged by the frontend.
            for s in range(self.nkf + 1):
                rr.log(f"{self.entity}/pointclouds/{s}", rr.Points3D(np.array([])))

            K44 = np.eye(4)
            K44[:3, :3] = self.K
            for s in range(self.nkf):
                log_camera(f"{self.entity}/camera_kf/{s}", poses[s].reshape(4, 4), K44, update=True)

            for s in range(1, self.nkf):
                rr.log(f"{self.entity}/lines/{s-1}to{s}", rr.LineStrips3D([poses[s-1:s+1, :3, 3].tolist()], colors=[0, 0, 255], radii=[0.005]))

            rr.log(
                f"{self.entity}/pointclouds/{self.nkf}",
                rr.Points3D(ps, colors=cs, radii=0.01),
            )
            self.last_kf_pose = poses[-1]

    def step(self, timestamp, image):
        """Feed one frame through the frontend and, if signalled, the backend.

        Args:
            timestamp: timestamp recorded for the frame if it becomes a
                keyframe; ``None`` falls back to the running frame index.
            image: the frame, passed on to :meth:`frontend`.
        """
        if timestamp is None:
            self.cur_timestamp = self.fid + 1
        else:
            self.cur_timestamp = timestamp

        self.frontend(image)

        self.backend()

    def map_add(self, token_kf):
        """Append keyframe token(s) to ``self.map`` along the first axis.

        Tokens are moved to the CPU first when ``save_gmem`` is set.
        """
        if self.save_gmem:
            token_kf = token_kf.cpu()
        if self.map is None:
            self.map = token_kf
        else:
            self.map = torch.cat([self.map, token_kf], dim=0)  # S,P,C

    @property
    def nkf(self):
        """Number of keyframes accepted so far."""
        return self.kid + 1

    @property
    def nf(self):
        """Number of frames seen so far."""
        return self.fid + 1

    def terminate(self):
        """Run a final backend pass, print timing stats and save results."""
        if self.nkf % self.backend_every != 0:
            # Keyframes arrived since the last scheduled backend run.
            self.signal_backend = True
        self.backend(final=True)

        print(self.kf_time)
        print(self.times)
        print(self.backend_time)
        print('frontend take', np.mean(self.times))
        print('KFT')
        print('total', np.sum(self.kf_time), 'FPS', float(len(self.kf_time))/np.sum(self.kf_time))
        print('FT')
        print('total', np.sum(self.times), 'FPS', float(len(self.times))/np.sum(self.times))
        print('BT')
        print('total', np.sum(self.backend_time), 'FPS', float(len(self.backend_time))/np.sum(self.backend_time))
        print('Summary')
        print('total', np.sum(self.kf_time)+np.sum(self.times)+np.sum(self.backend_time), 'FPS', float(len(self.kf_time))/(np.sum(self.kf_time)+np.sum(self.times)+np.sum(self.backend_time)))
        self.save_result(f'{self.outdir}/final', self.map_opt)

    def extract(self, map_all=None):
        """Decode a token map into flat points, colors, confidences, poses.

        Colors are taken from the last ``S`` stored keyframe images, where
        ``S`` is the number of frames the model returns. Also caches the
        last frame's local depth in ``self.depth_last_kf``.

        Returns:
            (pts [N, 3], colors [N, 3], confs [N], camera_pose [S, 4, 4])
            as numpy arrays.
        """
        result = self.model.extract(map_all.cuda())

        pts = result['points'].cpu().numpy()  # 1,S,H,W,3
        local_pts = result['local_points'].cpu().numpy()  # 1,S,H,W,3
        S = pts.shape[1]
        conf = result['conf'].cpu().numpy()
        # Channel-first frames -> S,H,W,C, flattened; channel order reversed.
        colors = torch.stack(self.frames[-S:]).permute(0, 2, 3, 1).reshape(-1, 3).cpu().numpy()[:, ::-1]
        confs = conf.reshape(-1)

        camera_pose = result['camera_poses'].cpu().numpy()[0]  # S,4,4
        pts = pts.reshape(-1, 3)
        colors = colors.reshape(-1, 3)

        # set depth for the last kf
        self.depth_last_kf = local_pts[0, -1, :, :, -1]

        return pts, colors, confs, camera_pose

    def save_result(self, output_path='output/tmp', map_all=None, traj=True):
        """Write the point cloud, trajectory and per-frame clouds to disk.

        Produces ``{output_path}.ply``, ``{output_path}_traj.txt`` and the
        directory ``{output_path}_pc``.

        Args:
            output_path: path prefix for all outputs.
            map_all: token map to decode; defaults to ``self.map_opt``.
            traj: accepted for API compatibility; currently unused.

        Returns:
            The dict returned by ``model.extract``, or ``None`` when no map
            is available.
        """
        print(self.kfids)

        if map_all is None:
            map_all = self.map_opt
        if map_all is None:
            print('No optimized map available; nothing to save.')
            return None

        result = self.model.extract(map_all.cuda())
        pts = result['points'].cpu().numpy()  # 1,S,H,W,3
        S = pts.shape[1]
        conf = result['conf'].cpu().numpy()
        point_clouds = [pts[0, s] for s in range(S)]
        # Keep points at or above the 15th confidence percentile.
        conf_threshold = np.percentile(conf, 15)
        confs = [conf[0, s] >= conf_threshold for s in range(S)]

        colors = torch.stack(self.frames).permute(0, 2, 3, 1).reshape(-1, 3).cpu().numpy()[:, ::-1]  # S,H,W,C
        msk = np.stack(confs).reshape(-1)
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(pts.reshape(-1, 3).astype(np.float64)[msk])
        pcd.colors = o3d.utility.Vector3dVector(colors.reshape(-1, 3).astype(np.float64)[msk])
        o3d.io.write_point_cloud(f"{output_path}.ply", pcd)
        camera_pose = result['camera_poses'].cpu()  # 1,S,4,4
        poses = camera_pose[0].numpy()

        self.write_poses_to_file(f"{output_path}_traj.txt", poses, self.kf_timestamps)
        self.save_framewise_pointclouds(f"{output_path}_pc", point_clouds, self.kf_timestamps, confs)

        return result

    def write_poses_to_file(self, filename, poses, frame_ids):
        """Write one ``id x y z qx qy qz qw`` line per pose to ``filename``.

        Args:
            filename: output text file.
            poses: sequence of 4x4 pose matrices.
            frame_ids: one id per pose, written as a float.

        Raises:
            AssertionError: if the lengths differ. The check runs before the
                file is opened, so an existing file is not truncated.
        """
        assert len(poses) == len(frame_ids), "Number of provided poses and number of frame ids do not match"

        with open(filename, "w") as f:
            for frame_id, pose in zip(frame_ids, poses):
                x, y, z = pose[0:3, 3]
                rotation_matrix = pose[0:3, 0:3]
                quaternion = R.from_matrix(rotation_matrix).as_quat()  # x, y, z, w
                output = np.array([float(frame_id), x, y, z, *quaternion])
                f.write(" ".join(f"{v:.8f}" for v in output) + "\n")

    def save_framewise_pointclouds(self, filename, pointclouds, frame_ids, conf_masks):
        """Save each frame's cloud and mask as ``{filename}/{frame_id}.npz``."""
        os.makedirs(filename, exist_ok=True)
        for frame_id, pointcloud, conf_mask in zip(frame_ids, pointclouds, conf_masks):
            # save pcd as numpy array
            np.savez(f"{filename}/{frame_id}.npz", pointcloud=pointcloud, mask=conf_mask)
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def get_parser():
    """Define the SLAM-Former demo command line and parse ``sys.argv``.

    Note that, despite the name, the return value is the parsed
    ``argparse.Namespace`` rather than the parser object.
    """
    cli = argparse.ArgumentParser(description="SLAM-Former demo")

    # (flag, add_argument keyword arguments), in the order they appear in --help.
    option_table = (
        ("--ckpt_path", dict(type=str, default="path/to/checkpoint.pth.model", help="Path to the checkpoint")),
        ("--image_folder", dict(type=str, default="path/to/image/folder", help="Path to folder containing images")),
        ("--target_size", dict(type=int, default=518, help="the target size of image(longer side)")),
        ("--output_dir", dict(type=str, default="outputs/tmp", help="Path to save the output")),
        ("--stride", dict(type=int, default=1, help="Frame stride for subsampling the input sequence")),
        ("--kf_th", dict(type=float, default=0.1, help="Keyframe selection threshold (minimum translation distance)")),
        ("--retention_ratio", dict(type=float, default=0.5, help="KV Pruning retention ratio")),
        ("--bn_every", dict(type=int, default=10, help="Run backend optimization every N keyframes")),
        ("--vis", dict(action="store_true", help="Enable real-time visualization with Rerun")),
        ("--resize_rate", dict(type=float, default=1, help="Resize rate for input images before processing")),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
if __name__ == '__main__':
    # NOTE: in this script get_parser() returns the already-parsed arguments.
    args = get_parser()
    image_folder = args.image_folder
    outdir = args.output_dir
    os.makedirs(outdir, exist_ok=True)

    # Hard-coded pinhole intrinsics, chosen by a substring of the folder path.
    if 'tum' in image_folder:
        K = np.eye(3)
        K[0, 0], K[1, 1] = 525.0, 525.0  # focal length x / y
        K[0, 2], K[1, 2] = 319.5, 239.5  # optical center x / y
    elif 'Replica' in image_folder:
        K = np.eye(3)
        K[0, 0], K[1, 1] = 600.0, 600.0  # focal length x / y
        K[0, 2], K[1, 2] = 599.5, 339.5  # optical center x / y
    else:
        K = None

    # Use the provided image folder path; skip entries whose base name
    # contains "depth", "txt" or "db" (case-insensitive).
    print(f"Loading images from {image_folder}...")
    excluded_markers = ("depth", "txt", "db")
    image_names = [
        f for f in glob.glob(os.path.join(image_folder, "*"))
        if not any(marker in os.path.basename(f).lower() for marker in excluded_markers)
    ]
    image_names = utils.sort_images_by_number(image_names)

    # The first number in each file name (integer or decimal) is the frame id.
    frame_ids = []
    for path in image_names:
        filename = os.path.basename(path)
        match = re.search(r'\d+(?:\.\d+)?', filename)  # matches integers and decimals
        if match:
            frame_ids.append(float(match.group()))
        else:
            raise ValueError(f"No number found in image name: {filename}")

    print(f"Found {len(image_names)} images")

    print('resize image', args.resize_rate)

    slam = SLAM(
        outdir=outdir,
        kf_th=args.kf_th,
        bn_every=args.bn_every,
        vis=args.vis,
        ckpt_path=args.ckpt_path,
        target_size=args.target_size,
        retention_ratio=args.retention_ratio
    )

    slam.K = K
    for frame_id, image_name in zip(frame_ids[::args.stride], image_names[::args.stride]):
        img = cv2.imread(image_name)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            print(f"WARNING: could not read image {image_name}, skipping")
            continue

        if args.resize_rate != 1:
            H, W, _ = img.shape
            # The third positional parameter of cv2.resize is `dst`, so the
            # interpolation flag must be passed by keyword to take effect.
            img = cv2.resize(
                img,
                (int(W * args.resize_rate), int(H * args.resize_rate)),
                interpolation=cv2.INTER_CUBIC,
            )
        slam.step(frame_id, img)
    result = slam.terminate()
|
| 540 |
+
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/demo_infinite.py
ADDED
|
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Infinite SLAM-Former — Inference Demo with all CLI toggles.
|
| 4 |
+
|
| 5 |
+
Wraps the original SLAMFormer + SLAM class with the new modules:
|
| 6 |
+
- GraphGatedMemoryManager (submap backend + loop closure)
|
| 7 |
+
- TemporalEmbedWrapper (dual temporal embedding injection)
|
| 8 |
+
- BatchedDynamicSubmapRouter (learned submap boundaries)
|
| 9 |
+
|
| 10 |
+
When all toggles are off, behaviour is identical to slam/demo.py.
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
# Vanilla (same as demo.py):
|
| 14 |
+
python slam/demo_infinite.py --ckpt_path ckpt/checkpoint.pth.model \
|
| 15 |
+
--image_folder /path/to/images --output_dir outputs/tmp
|
| 16 |
+
|
| 17 |
+
# With submap backend + loop closure + temporal embed:
|
| 18 |
+
python slam/demo_infinite.py --ckpt_path ckpt/checkpoint.pth.model \
|
| 19 |
+
--image_folder /path/to/images --output_dir outputs/tmp \
|
| 20 |
+
--enable_submap_backend --enable_loop_closure --enable_temporal_embed
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import os
|
| 24 |
+
import sys
|
| 25 |
+
import re
|
| 26 |
+
import glob
|
| 27 |
+
import argparse
|
| 28 |
+
import time
|
| 29 |
+
from collections import OrderedDict
|
| 30 |
+
|
| 31 |
+
import cv2
|
| 32 |
+
import torch
|
| 33 |
+
import numpy as np
|
| 34 |
+
from scipy.spatial.transform import Rotation as R
|
| 35 |
+
|
| 36 |
+
# ─── Path setup ──────────────────────────────────────────
|
| 37 |
+
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 38 |
+
PROJECT_DIR = os.path.dirname(SCRIPT_DIR)
|
| 39 |
+
SRC_DIR = os.path.join(PROJECT_DIR, "src")
|
| 40 |
+
|
| 41 |
+
if SRC_DIR not in sys.path:
|
| 42 |
+
sys.path.insert(0, SRC_DIR)
|
| 43 |
+
if PROJECT_DIR not in sys.path:
|
| 44 |
+
sys.path.insert(0, PROJECT_DIR)
|
| 45 |
+
|
| 46 |
+
from slamformer.models.slamformer import SLAMFormer
|
| 47 |
+
import slam.utils as utils
|
| 48 |
+
from slam.graph_gated_memory import (
|
| 49 |
+
GraphGatedMemoryManager,
|
| 50 |
+
TemporalEmbedWrapper,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def strip_module(state_dict):
    """Return a copy of ``state_dict`` without a leading ``"module."`` on keys.

    Keys that do not start with ``"module."`` are copied unchanged, and key
    order is preserved. Anything that is not a ``dict`` is returned as-is
    (same object).
    """
    if not isinstance(state_dict, dict):
        return state_dict

    wrapper_prefix = "module."
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(wrapper_prefix):
            key = key[len(wrapper_prefix):]
        cleaned[key] = value
    return cleaned
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def strip_prefix(state_dict, prefix):
    """Return a copy of ``state_dict`` with ``prefix`` cut off matching keys.

    Only a leading occurrence of ``prefix`` is removed; other keys are copied
    unchanged and key order is preserved. Anything that is not a ``dict`` is
    returned as-is (same object).
    """
    if not isinstance(state_dict, dict):
        return state_dict

    return OrderedDict(
        (key[len(prefix):], value) if key.startswith(prefix) else (key, value)
        for key, value in state_dict.items()
    )
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_cfg_value(cfg, key, default=None):
    """Look up ``key`` in a config object of unknown flavour.

    The lookup strategies are tried in this order:

    1. ``cfg.get(key, default)`` when ``cfg`` exposes a ``get`` attribute,
    2. ``cfg[key]``,
    3. ``getattr(cfg, key, default)``.

    ``default`` is returned straight away when ``cfg`` is ``None``. Any
    exception raised by strategy 1 or 2 just moves on to the next strategy.
    """
    if cfg is None:
        return default

    # Strategy 1: mapping-style accessor.
    try:
        supports_get = hasattr(cfg, "get")
        if supports_get:
            return cfg.get(key, default)
    except Exception:
        pass  # fall through to the subscript / attribute strategies

    # Strategy 2 (subscript), then strategy 3 (attribute) as the last resort.
    try:
        value = cfg[key]
    except Exception:
        value = getattr(cfg, key, default)
    return value
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def extract_checkpoint_parts(ckpt_raw):
    """Split a loaded checkpoint into ``(model, train_cfg, memory, temporal)``.

    For a ``dict`` checkpoint:

    * model weights come from ``ckpt_raw["model"]`` if that key exists,
      otherwise the whole dict is treated as the weights;
    * the training config is ``ckpt_raw.get("args")``;
    * the ``memory_mgr`` / ``temporal_wrapper`` states are read from the top
      level first and, failing that, from a nested ``submap_modules`` dict.

    Anything that is not a ``dict`` is returned as the model part, with
    ``None`` for the other three entries.
    """
    if not isinstance(ckpt_raw, dict):
        return ckpt_raw, None, None, None

    cfg = ckpt_raw.get("args")
    weights = ckpt_raw["model"] if "model" in ckpt_raw else ckpt_raw

    # A missing or non-dict "submap_modules" entry contributes no fallbacks.
    bundle = ckpt_raw.get("submap_modules", {})
    if not isinstance(bundle, dict):
        bundle = {}

    memory = ckpt_raw.get("memory_mgr", bundle.get("memory_mgr"))
    temporal = ckpt_raw.get("temporal_wrapper", bundle.get("temporal_wrapper"))
    return weights, cfg, memory, temporal
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class InfiniteSLAM:
    """Extended SLAM pipeline with optional submap backend and loop closure.

    When all enable_* flags are False, this behaves identically to the
    original SLAM class in slam/demo.py.

    Typical use: call :meth:`step` once per input image, then
    :meth:`terminate` to run a final backend pass and write the outputs.
    """

    def __init__(
        self,
        outdir: str = "output/tmp",
        kf_th: float = 0.1,
        bn_every: int = 10,
        ckpt_path: str = "",
        target_size: int = 518,
        retention_ratio: float = 0.5,
        # ── new toggles ──
        enable_submap_backend: bool = False,
        submap_size: int = 10,
        max_recursive_submaps: int = 3,
        enable_loop_closure: bool = False,
        desc_dim: int = 128,
        enable_temporal_embed: bool = False,
        temporal_embed_mode: str = "learned",
    ):
        """Build the model, memory manager and temporal wrapper on CUDA.

        Args:
            outdir: Directory prefix for the files written by ``terminate``.
            kf_th: Translation distance above which a frame becomes a keyframe.
            bn_every: Request a backend pass every ``bn_every`` keyframes.
            ckpt_path: Checkpoint file; an empty string loads no checkpoint.
            target_size: Forwarded to ``utils.load_image``.
            retention_ratio: Forwarded to ``SLAMFormer``.
            enable_submap_backend, submap_size, max_recursive_submaps,
            enable_loop_closure, desc_dim, enable_temporal_embed,
            temporal_embed_mode: Submap / loop-closure / temporal-embedding
                settings.

        Note:
            When the checkpoint carries a training config (its ``"args"``
            entry), values found there take precedence over the
            corresponding constructor arguments.
        """
        self.outdir = outdir
        self.kf_th = kf_th
        self.bn_every = bn_every
        self.target_size = target_size

        ckpt_raw = None
        self.train_cfg = None
        if ckpt_path and os.path.exists(ckpt_path):
            # SECURITY: weights_only=False unpickles arbitrary Python objects;
            # only load checkpoints that come from a trusted source.
            ckpt_raw = torch.load(ckpt_path, map_location="cuda", weights_only=False)
            _, self.train_cfg, _, _ = extract_checkpoint_parts(ckpt_raw)
        elif ckpt_path:
            # Make a mistyped path visible instead of silently continuing.
            print(f"WARNING: checkpoint '{ckpt_path}' not found; no checkpoint weights will be loaded")

        # Checkpoint config (if any) overrides the constructor arguments.
        cfg = self.train_cfg
        retention_ratio = float(get_cfg_value(cfg, "retention_ratio", retention_ratio))
        self.enable_submap = bool(get_cfg_value(cfg, "enable_submap", enable_submap_backend))
        self.enable_loop = bool(get_cfg_value(cfg, "enable_loop", enable_loop_closure))
        self.enable_temporal = bool(get_cfg_value(cfg, "enable_temporal", enable_temporal_embed))
        submap_size = int(get_cfg_value(cfg, "submap_size", submap_size))
        max_recursive_submaps = int(get_cfg_value(cfg, "max_recursive_submaps", max_recursive_submaps))
        desc_dim = int(get_cfg_value(cfg, "desc_dim", desc_dim))
        loop_mask_mode = get_cfg_value(cfg, "loop_mask_mode", "hard_top1")
        soft_mask_temperature = float(get_cfg_value(cfg, "soft_mask_temperature", 0.25))
        soft_mask_bias = float(get_cfg_value(cfg, "soft_mask_bias", 0.2))
        submap_train_mode = get_cfg_value(cfg, "submap_train_mode", "full_token")
        submap_retrieval_topk = int(get_cfg_value(cfg, "submap_retrieval_topk", 0))
        submap_fetch_source = get_cfg_value(cfg, "submap_fetch_source", "frontend")
        submap_descriptor_source = get_cfg_value(cfg, "submap_descriptor_source", "frontend")
        temporal_embed_mode = get_cfg_value(cfg, "temporal_embed_mode", temporal_embed_mode)

        # ── Model (frozen) ───────────────────────────────
        self.model = SLAMFormer(
            retention_ratio=retention_ratio, bn_every=bn_every
        )
        self.model.eval()

        # ── Memory manager ───────────────────────────────
        embed_dim = self.model.dec_embed_dim
        self.memory_mgr = GraphGatedMemoryManager(
            submap_size=submap_size,
            max_recursive_submaps=max_recursive_submaps,
            desc_dim=desc_dim,
            embed_dim=embed_dim,
            loop_mask_mode=loop_mask_mode,
            soft_mask_temperature=soft_mask_temperature,
            soft_mask_bias=soft_mask_bias,
            retain_history_grad=False,
            submap_train_mode=submap_train_mode,
            submap_retrieval_topk=submap_retrieval_topk,
            submap_fetch_source=submap_fetch_source,
            submap_descriptor_source=submap_descriptor_source,
        )

        # ── Temporal wrapper ─────────────────────────────
        self.temporal_wrapper = TemporalEmbedWrapper(
            embed_dim=embed_dim,
            max_frames=5000,
            mode=temporal_embed_mode,
        )

        self._load_model(ckpt_raw)
        self.model.to("cuda")
        self.memory_mgr.to("cuda").eval()
        self.temporal_wrapper.to("cuda").eval()
        self.memory_mgr.reset()

        # ── SLAM state ───────────────────────────────────
        self.fid = -1                # index of the last frame seen
        self.kid = -1                # index of the last keyframe
        self.kfids = []              # frame index of every keyframe
        self.kf_timestamps = []      # timestamp of every keyframe
        self.frames = []             # keyframe images (cloned), used for colors
        self.kf_frames = []          # keyframe images fed to the model
        self.last_kf = None          # most recent keyframe image
        self.map = None              # concatenated frontend tokens (CPU)
        self.map_opt = None          # output of the last backend pass (CPU)
        self.signal_backend = False  # set when a backend pass is due
        self.frontend_times = 0      # number of keyframes processed
        self.times = []              # per-keyframe frontend durations
        self.kf_time = []            # per-frame keyframe-detection durations
        self.backend_time = []       # per-pass backend durations
        self.K = None
        self.cur_timestamp = 0

    def _load_model(self, ckpt_raw):
        """Load model / memory-manager / temporal-wrapper weights, if present.

        All three loads are non-strict, so missing or unexpected keys are
        not reported as errors. Does nothing when ``ckpt_raw`` is ``None``.
        """
        if ckpt_raw is None:
            return
        ckpt, _, memory_state, temporal_state = extract_checkpoint_parts(ckpt_raw)
        self.model.load_state_dict(strip_module(ckpt), strict=False)
        if memory_state is not None:
            memory_state = strip_prefix(strip_module(memory_state), "memory_mgr.")
            self.memory_mgr.load_state_dict(memory_state, strict=False)
        if temporal_state is not None:
            temporal_state = strip_prefix(strip_module(temporal_state), "temporal_wrapper.")
            self.temporal_wrapper.load_state_dict(temporal_state, strict=False)

    @property
    def time(self):
        """Current ``time.perf_counter()`` after a CUDA synchronize."""
        torch.cuda.synchronize()
        return time.perf_counter()

    @property
    def nkf(self):
        """Number of keyframes accepted so far."""
        return self.kid + 1

    def kf_detect(self, image):
        """Return ``True`` when ``image`` should become a new keyframe.

        The first frame is always a keyframe. Afterwards the last keyframe
        and the candidate are passed through the model, and the candidate
        is accepted when the distance between the translation parts of the
        two inverted camera poses exceeds ``kf_th``.
        """
        if self.kid == -1:
            return True

        frame = utils.load_image(image, self.target_size)
        token = self.model.KFT(torch.stack([self.kf_frames[-1], frame.cuda()]))
        res = self.model.extract(token, cam_only=True)
        extrinsic = torch.inverse(res["camera_poses"]).cpu()
        if extrinsic.shape[1] <= 1:
            return True

        extrinsic_ref = extrinsic[0, -2]
        extrinsic_cur = extrinsic[0, -1]
        dist = torch.sqrt(torch.sum((extrinsic_cur[:3, 3] - extrinsic_ref[:3, 3]) ** 2))
        # Plain bool, so every return path has the same type.
        return bool(dist > self.kf_th)

    def frontend(self, image):
        """Process one frame; returns ``True`` iff it was taken as a keyframe."""
        self.fid += 1
        print("Frame", self.fid)

        st = self.time
        enough_disparity = self.kf_detect(image)
        self.kf_time.append(self.time - st)
        if not enough_disparity:
            return False

        torch.cuda.empty_cache()
        frame = utils.load_image(image, self.target_size)
        self.last_kf = frame.cuda()
        self.kf_frames.append(self.last_kf)
        self.frames.append(self.last_kf.clone())
        self.kid += 1
        print(f" # KEYFRAME {self.kid}")
        self.kf_timestamps.append(self.cur_timestamp)

        st = self.time
        if self.nkf == 1:
            # The first keyframe is only stored; the frontend is first run
            # once a second keyframe is available.
            pass
        elif self.nkf == 2:
            token = self.model.frontendT(torch.stack([self.kf_frames[0], frame.cuda()]))
            self._map_add(token)
        else:
            token = self.model.frontendT(frame.cuda())
            self._map_add(token)

        self.kfids.append(self.fid)
        self.times.append(self.time - st)
        torch.cuda.empty_cache()

        # ── Submap accumulation ──────────────────────────
        if self.enable_submap and self.map is not None:
            last_token = self.map[-1:] if self.map.dim() == 3 else self.map
            self.memory_mgr.accumulate(last_token, frame_id=self.fid)

        self.frontend_times += 1
        if self.frontend_times % self.bn_every == 0:
            self.signal_backend = True
        return True

    def backend(self, final=False):
        """Run a pending backend pass and store its output in ``map_opt``.

        Does nothing unless ``signal_backend`` is set. ``final`` is accepted
        for API compatibility and is currently not used.
        """
        if not self.signal_backend:
            return
        if self.map is None:
            # No frontend tokens exist yet (fewer than two keyframes), so
            # there is nothing to optimise; avoid dereferencing None below.
            self.signal_backend = False
            return

        torch.cuda.empty_cache()
        if hasattr(self.model, "fkv"):
            del self.model.fkv
            torch.cuda.empty_cache()

        print("Backending...", self.nkf, "KFs")
        st = time.perf_counter()

        if self.enable_submap and self.memory_mgr.submap_complete:
            # ── Submap-based backend ─────────────────────
            hidden_B, loop_gate, meta = self.memory_mgr.finalize_submap(
                model=self.model,
                device=torch.device("cuda"),
                temporal_wrapper=self.temporal_wrapper if self.enable_temporal else None,
                enable_temporal_embed=self.enable_temporal,
                enable_loop_closure=self.enable_loop,
            )
            # Use the active submap portion as map_opt
            # NOTE(review): map_opt then covers only n_curr entries, while
            # _save_result pairs it with *all* stored frames/timestamps —
            # confirm how these are meant to line up.
            n_prev = meta['n_prev']
            n_curr = meta['n_curr']
            self.map_opt = hidden_B[n_prev:n_prev + n_curr].cpu()
        else:
            # ── Vanilla backend ──────────────────────────
            map_optimed = self.model.backendT(self.map.cuda())
            self.map_opt = map_optimed.cpu()

        self.backend_time.append(time.perf_counter() - st)
        print("backend_take", time.perf_counter() - st)
        self.signal_backend = False
        torch.cuda.empty_cache()

    def step(self, timestamp, image):
        """Feed one image; ``timestamp=None`` falls back to the frame index."""
        self.cur_timestamp = timestamp if timestamp is not None else self.fid + 1
        self.frontend(image)
        self.backend()

    def _map_add(self, token_kf):
        """Append ``token_kf`` (moved to CPU) to ``self.map`` along dim 0."""
        if self.map is None:
            self.map = token_kf.cpu()
        else:
            self.map = torch.cat([self.map, token_kf.cpu()], axis=0)

    def terminate(self):
        """Run a final backend pass, print timing stats and save the results.

        Returns:
            The dict returned by ``model.extract`` for the final map, or
            ``None`` when no optimised map exists (nothing is written then).
        """
        if self.nkf % self.bn_every != 0:
            self.signal_backend = True
        self.backend(final=True)

        print("Frontend times:", self.times)
        print("Backend times:", self.backend_time)
        if self.times:
            print("Frontend avg:", np.mean(self.times))
        if self.backend_time:
            print("Backend avg:", np.mean(self.backend_time))
        print("Summary FPS:", float(len(self.kf_time)) / (
            np.sum(self.kf_time) + np.sum(self.times) + np.sum(self.backend_time) + 1e-9
        ))

        if self.map_opt is None:
            # Happens when fewer than two keyframes were collected.
            print("No optimised map available; nothing to save.")
            return None
        return self._save_result(f"{self.outdir}/final", self.map_opt)

    def _save_result(self, output_path, map_all=None):
        """Decode ``map_all`` and write ``<output_path>.ply`` / ``_traj.txt``.

        Args:
            output_path: Path prefix of the output files.
            map_all: Tokens to decode; defaults to ``self.map_opt``.

        Returns:
            The dict produced by ``self.model.extract``.
        """
        import open3d as o3d

        print(self.kfids)
        if map_all is None:
            map_all = self.map_opt

        map_gpu = map_all.cuda()
        # Ensure shape_ is set correctly for extract
        BN = map_gpu.shape[0]
        # Use shape_ from the last frontendT call for H, W
        _, _, H, W, ph, pw = self.model.shape_
        self.model.shape_ = (1, BN, H, W, ph, pw)
        result = self.model.extract(map_gpu)

        pts = result["points"].cpu().numpy()
        S = pts.shape[1]
        conf = result["conf"].cpu().numpy()
        # Keep points whose confidence is at or above the 15th percentile.
        conf_threshold = np.percentile(conf, 15)
        confs = [conf[0, s] >= conf_threshold for s in range(S)]
        msk = np.stack(confs).reshape(-1)

        # Channels-last layout with the channel order reversed
        # (presumably BGR -> RGB for cv2-loaded images; verify against utils.load_image).
        colors = torch.stack(self.frames).permute(0, 2, 3, 1).reshape(-1, 3).cpu().numpy()[:, ::-1]

        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(pts.reshape(-1, 3).astype(np.float64)[msk])
        pcd.colors = o3d.utility.Vector3dVector(colors.reshape(-1, 3).astype(np.float64)[msk])
        o3d.io.write_point_cloud(f"{output_path}.ply", pcd)

        camera_pose = result["camera_poses"].cpu()
        poses = camera_pose[0].numpy()
        self._write_poses(f"{output_path}_traj.txt", poses, self.kf_timestamps)
        return result

    def _write_poses(self, filename, poses, frame_ids):
        """Write one text line per pose: ``frame_id x y z qx qy qz qw``.

        Raises:
            ValueError: If ``poses`` and ``frame_ids`` differ in length
                (``zip`` would otherwise silently drop the surplus).
        """
        if len(poses) != len(frame_ids):
            raise ValueError("Number of provided poses and number of frame ids do not match")

        # Format first, then write, so a failure cannot leave a partial file.
        lines = []
        for frame_id, pose in zip(frame_ids, poses):
            x, y, z = pose[0:3, 3]
            quat = R.from_matrix(pose[0:3, 0:3]).as_quat()  # x, y, z, w
            output = np.array([float(frame_id), x, y, z, *quat])
            lines.append(" ".join(f"{v:.8f}" for v in output) + "\n")

        with open(filename, "w") as f:
            f.writelines(lines)
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def get_parser():
    """Build and return the argument parser of the Infinite SLAM-Former demo.

    The parser itself is returned; the caller is expected to invoke
    ``parse_args`` on it.
    """
    cli = argparse.ArgumentParser(description="Infinite SLAM-Former Demo")

    # (flag, add_argument keyword arguments), grouped by feature.
    option_groups = (
        # ── Original demo.py args ────────────────────────────
        (
            ("--ckpt_path", dict(type=str, default="")),
            ("--image_folder", dict(type=str, default="")),
            ("--target_size", dict(type=int, default=518)),
            ("--output_dir", dict(type=str, default="outputs/tmp")),
            ("--stride", dict(type=int, default=1)),
            ("--kf_th", dict(type=float, default=0.1)),
            ("--retention_ratio", dict(type=float, default=0.5)),
            ("--bn_every", dict(type=int, default=10)),
            ("--resize_rate", dict(type=float, default=1.0)),
        ),
        # ── Submap backend / loop closure ────────────────────
        (
            ("--enable_submap_backend", dict(action="store_true")),
            ("--submap_size", dict(type=int, default=10)),
            ("--max_recursive_submaps", dict(type=int, default=3)),
            ("--enable_loop_closure", dict(action="store_true")),
            ("--desc_dim", dict(type=int, default=128)),
        ),
        # ── Temporal embedding ───────────────────────────────
        (
            ("--enable_temporal_embed", dict(action="store_true")),
            ("--temporal_embed_mode", dict(type=str, default="learned",
                                           choices=["learned", "sinusoidal"])),
        ),
    )
    for group in option_groups:
        for flag, spec in group:
            cli.add_argument(flag, **spec)

    return cli
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
if __name__ == "__main__":
    args = get_parser().parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # ── Camera intrinsics (same logic as original demo.py) ──
    # Hard-coded values, chosen by a substring of the image folder path.
    if "tum" in args.image_folder:
        K = np.eye(3)
        K[0, 0], K[1, 1] = 525.0, 525.0
        K[0, 2], K[1, 2] = 319.5, 239.5
    elif "Replica" in args.image_folder:
        K = np.eye(3)
        K[0, 0], K[1, 1] = 600.0, 600.0
        K[0, 2], K[1, 2] = 599.5, 339.5
    else:
        K = None

    # ── Load images ──────────────────────────────────────
    # Entries whose base name contains "depth", "txt" or "db"
    # (case-insensitive) are ignored.
    print(f"Loading images from {args.image_folder}...")
    excluded_markers = ("depth", "txt", "db")
    image_names = [
        f for f in glob.glob(os.path.join(args.image_folder, "*"))
        if not any(marker in os.path.basename(f).lower() for marker in excluded_markers)
    ]
    image_names = utils.sort_images_by_number(image_names)

    # The first number in each file name (integer or decimal) is the frame id.
    frame_ids = []
    for path in image_names:
        match = re.search(r"\d+(?:\.\d+)?", os.path.basename(path))
        if match:
            frame_ids.append(float(match.group()))
        else:
            raise ValueError(f"No number found in image name: {path}")

    print(f"Found {len(image_names)} images")

    # ── Create SLAM instance ─────────────────────────────
    slam = InfiniteSLAM(
        outdir=args.output_dir,
        kf_th=args.kf_th,
        bn_every=args.bn_every,
        ckpt_path=args.ckpt_path,
        target_size=args.target_size,
        retention_ratio=args.retention_ratio,
        enable_submap_backend=args.enable_submap_backend,
        submap_size=args.submap_size,
        max_recursive_submaps=args.max_recursive_submaps,
        enable_loop_closure=args.enable_loop_closure,
        desc_dim=args.desc_dim,
        enable_temporal_embed=args.enable_temporal_embed,
        temporal_embed_mode=args.temporal_embed_mode,
    )
    slam.K = K

    # ── Run ──────────────────────────────────────────────
    for frame_id, image_name in zip(
        frame_ids[:: args.stride], image_names[:: args.stride]
    ):
        img = cv2.imread(image_name)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            print(f"WARNING: could not read image {image_name}, skipping")
            continue
        if args.resize_rate != 1:
            H, W, _ = img.shape
            # The third positional parameter of cv2.resize is `dst`, so the
            # interpolation flag must be passed by keyword to take effect.
            img = cv2.resize(
                img, (int(W * args.resize_rate), int(H * args.resize_rate)),
                interpolation=cv2.INTER_CUBIC,
            )
        slam.step(frame_id, img)

    slam.terminate()
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/demo_submap.py
ADDED
|
@@ -0,0 +1,927 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os,sys
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import re
|
| 6 |
+
import cv2
|
| 7 |
+
import glob
|
| 8 |
+
import argparse
|
| 9 |
+
|
| 10 |
+
import time
|
| 11 |
+
import open3d as o3d
|
| 12 |
+
from rich import print
|
| 13 |
+
import matplotlib.pyplot as plt
|
| 14 |
+
from scipy.spatial.transform import Rotation as R
|
| 15 |
+
|
| 16 |
+
import rerun as rr
|
| 17 |
+
import rerun.blueprint as rrb
|
| 18 |
+
|
| 19 |
+
current_directory = os.path.dirname(os.path.abspath(__file__))
|
| 20 |
+
|
| 21 |
+
sys.path.append(current_directory+'/../')
|
| 22 |
+
sys.path.append('src')
|
| 23 |
+
from slamformer.models.slamformer import SLAMFormer
|
| 24 |
+
from slam.graph_gated_memory import GraphGatedMemoryManager, TemporalEmbedWrapper
|
| 25 |
+
|
| 26 |
+
import slam.utils as utils
|
| 27 |
+
from slam.rerun_helper import log_camera, log_window
|
| 28 |
+
|
| 29 |
+
def strip_module(state_dict):
    """
    Return a copy of a state_dict whose keys no longer start with 'module.'.

    Args:
        state_dict (dict): Mapping from parameter names to values; some keys
            may carry a leading 'module.' prefix.

    Returns:
        OrderedDict: The same values in the same order, with one leading
        'module.' removed from every key that had it.
    """
    marker = "module."
    return OrderedDict(
        (key[len(marker):] if key.startswith(marker) else key, value)
        for key, value in state_dict.items()
    )
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def strip_prefix(state_dict, prefix):
    """
    Remove one leading ``prefix`` from the keys of a state_dict.

    Args:
        state_dict: Mapping from names to values. Anything that is not a
            ``dict`` is handed back untouched.
        prefix (str): Prefix to drop from keys that start with it.

    Returns:
        OrderedDict with the renamed keys (order preserved), or the original
        object when it is not a dict.
    """
    if isinstance(state_dict, dict):
        return OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )
    return state_dict
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def get_cfg_value(cfg, key, default=None):
|
| 55 |
+
if cfg is None:
|
| 56 |
+
return default
|
| 57 |
+
try:
|
| 58 |
+
if hasattr(cfg, "get"):
|
| 59 |
+
return cfg.get(key, default)
|
| 60 |
+
except Exception:
|
| 61 |
+
pass
|
| 62 |
+
try:
|
| 63 |
+
return cfg[key]
|
| 64 |
+
except Exception:
|
| 65 |
+
return getattr(cfg, key, default)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def extract_checkpoint_parts(ckpt_raw):
    """
    Split a loaded checkpoint into its components.

    Args:
        ckpt_raw: Object returned by ``torch.load``. When it is a dict, the
            keys "model", "args", "memory_mgr", "temporal_wrapper" and
            "submap_modules" are consulted; anything else is treated as the
            model state itself.

    Returns:
        tuple: ``(ckpt, train_cfg, memory_state, temporal_state)`` where
        ``ckpt`` is ``ckpt_raw["model"]`` if present (else ``ckpt_raw``),
        ``train_cfg`` is ``ckpt_raw.get("args")``, and the two states are read
        from the top level first, falling back to the nested
        "submap_modules" dict. Missing parts are None.
    """
    if not isinstance(ckpt_raw, dict):
        return ckpt_raw, None, None, None

    train_cfg = ckpt_raw.get("args")
    if "model" in ckpt_raw:
        weights = ckpt_raw["model"]
        print("Loaded state_dict from 'model' key in checkpoint.")
    else:
        weights = ckpt_raw

    # Top-level entries win; the nested container is only a fallback and is
    # ignored entirely when it is not a dict.
    nested = ckpt_raw.get("submap_modules", {})
    fallback = nested if isinstance(nested, dict) else {}
    memory_state = ckpt_raw.get("memory_mgr", fallback.get("memory_mgr"))
    temporal_state = ckpt_raw.get("temporal_wrapper", fallback.get("temporal_wrapper"))
    return weights, train_cfg, memory_state, temporal_state
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class KVSubmapManager:
    """Stores per-submap token tensors on the CPU and selects context submaps.

    Tokens of the submap under construction are accumulated frame by frame.
    When a submap is finalized its (refined) tokens, a normalized descriptor
    and its frame ids are stored under the current submap id, and an
    undirected adjacency edge to the best loop candidate is recorded.

    Token tensors are 3-D with frames stacked along dim 0 (2-D inputs are
    unsqueezed on accumulation); the descriptor head averages over dims 0 and
    1, so the last dimension must equal ``2 * embed_dim``.
    """

    def __init__(
        self,
        submap_size=10,
        max_loop_submaps=5,
        loop_similarity_threshold=0.75,
        desc_dim=128,
        embed_dim=1024,
    ):
        """
        Args:
            submap_size: Stored as ``self.K``; not otherwise used by this class.
            max_loop_submaps: Upper bound on loop submaps returned by
                ``select_context_submaps`` (0 disables loop retrieval).
            loop_similarity_threshold: Minimum cosine similarity for the best
                candidate to count as a loop.
            desc_dim: Output size of the descriptor projection.
            embed_dim: The descriptor projection consumes ``2 * embed_dim``
                features.
        """
        self.K = submap_size
        self.max_loop_submaps = max_loop_submaps
        self.loop_similarity_threshold = loop_similarity_threshold
        self.desc_dim = desc_dim
        self.embed_dim = embed_dim

        self.desc_proj = torch.nn.Linear(2 * embed_dim, desc_dim)
        self.submap_tokens_cpu = {}   # submap id -> token tensor (CPU)
        self.submap_descriptors = {}  # submap id -> descriptor (CPU)
        self.submap_frame_ids = {}    # submap id -> list of frame ids
        self.adjacency = {}           # submap id -> set of linked submap ids

        self._curr_tokens = []
        self._curr_frame_ids = []
        self._current_submap_id = 0

    @property
    def current_submap_id(self):
        """Id that the submap currently being accumulated will receive."""
        return self._current_submap_id

    @property
    def current_frame_ids(self):
        """Copy of the frame ids accumulated for the current submap."""
        return list(self._curr_frame_ids)

    def to(self, device):
        """Move the descriptor head to ``device`` and return ``self``."""
        self.desc_proj = self.desc_proj.to(device)
        return self

    def load_memory_state_dict(self, state_dict):
        """Load the descriptor head from a (possibly prefixed) state dict.

        Only keys of the form ``[memory_mgr.]desc_proj.<name>`` are used; all
        other entries are ignored. Nothing happens when no such key exists.
        """
        desc_state = OrderedDict()
        for k, v in state_dict.items():
            key = k
            if key.startswith("memory_mgr."):
                key = key[len("memory_mgr."):]
            if key.startswith("desc_proj."):
                desc_state[key[len("desc_proj."):]] = v
        if desc_state:
            msg = self.desc_proj.load_state_dict(desc_state, strict=False)
            print(f"Loaded submap descriptor head: {msg}")

    def accumulate(self, frame_token, frame_id):
        """Append one frame's tokens (detached, on CPU) to the current submap."""
        if frame_token.dim() == 2:
            frame_token = frame_token.unsqueeze(0)
        self._curr_tokens.append(frame_token.detach().cpu())
        self._curr_frame_ids.append(int(frame_id))

    def has_current_tokens(self):
        """Return True when at least one frame has been accumulated."""
        return len(self._curr_tokens) > 0

    def current_tokens(self, device):
        """Concatenate the accumulated tokens on ``device``; None when empty."""
        if not self._curr_tokens:
            return None
        return torch.cat(self._curr_tokens, dim=0).to(device, non_blocking=True)

    def compute_descriptor(self, tokens):
        """Return the L2-normalized descriptor of a token tensor.

        The descriptor head is moved to ``tokens.device`` if necessary.
        """
        if self.desc_proj.weight.device != tokens.device:
            self.desc_proj = self.desc_proj.to(tokens.device)
        pooled = tokens.mean(dim=(0, 1), keepdim=False).unsqueeze(0).float()
        desc = self.desc_proj(pooled).squeeze(0)
        return torch.nn.functional.normalize(desc, dim=0)

    def select_context_submaps(self, curr_desc):
        """Choose the previous submap and loop submaps for the current one.

        Args:
            curr_desc: Descriptor of the current submap.

        Returns:
            tuple: ``(prev_sid, loop_ids, primary_sid, primary_sim)``.
            ``prev_sid`` is the immediately preceding stored submap or None.
            ``primary_sim`` is the best cosine similarity among the older
            submaps (None when there is no candidate). ``primary_sid`` is that
            best candidate when it reaches the threshold, else None.
            ``loop_ids`` starts with ``primary_sid`` followed by its stored
            adjacency neighbours, capped at ``max_loop_submaps``.
        """
        prev_sid = self._current_submap_id - 1
        if prev_sid not in self.submap_tokens_cpu:
            prev_sid = None

        loop_ids = []
        primary_sid = None
        primary_sim = None
        if self.max_loop_submaps > 0:
            candidates = sorted(
                sid
                for sid in self.submap_tokens_cpu
                if sid < self._current_submap_id and sid != prev_sid
            )
            if candidates:
                curr_desc_cpu = curr_desc.detach().cpu().float()
                scored = []
                for sid in candidates:
                    hist_desc = self.submap_descriptors[sid].float()
                    sim = torch.nn.functional.cosine_similarity(
                        curr_desc_cpu.unsqueeze(0),
                        hist_desc.unsqueeze(0),
                        dim=-1,
                    ).item()
                    scored.append((sim, sid))
                # Highest similarity wins; ties go to the larger submap id.
                primary_sim, primary_sid = max(scored)
                if primary_sim >= self.loop_similarity_threshold:
                    loop_ids.append(primary_sid)
                    for nid in sorted(self.adjacency.get(primary_sid, set())):
                        if len(loop_ids) >= self.max_loop_submaps:
                            break
                        # The previous submap is already part of the context;
                        # listing it again as a loop submap would feed its
                        # tokens twice and overwrite its refined copy.
                        if nid == self._current_submap_id or nid == prev_sid:
                            continue
                        if nid in self.submap_tokens_cpu and nid not in loop_ids:
                            loop_ids.append(nid)
                else:
                    primary_sid = None

        return prev_sid, loop_ids, primary_sid, primary_sim

    def build_backend_tokens(self, device, prev_sid=None, loop_ids=None):
        """Assemble ``[prev | current | loop...]`` tokens along dim 0.

        Args:
            device: Target device for the combined tensor.
            prev_sid: Stored submap to prepend, if any.
            loop_ids: Stored submaps to append after the current tokens.

        Returns:
            tuple: ``(combined, meta)`` where ``meta`` holds ``n_prev``,
            ``n_curr``, ``curr_frame_ids`` and ``loop_token_counts`` — the
            layout needed by ``finalize_current_submap``.

        Raises:
            ValueError: If no tokens were accumulated for the current submap.
        """
        loop_ids = [] if loop_ids is None else list(loop_ids)
        parts = []
        n_prev = 0
        loop_token_counts = []

        if prev_sid is not None and prev_sid in self.submap_tokens_cpu:
            prev_tokens = self.submap_tokens_cpu[prev_sid].to(device, non_blocking=True)
            parts.append(prev_tokens)
            n_prev = prev_tokens.shape[0]

        curr_tokens = self.current_tokens(device)
        if curr_tokens is None:
            raise ValueError(
                "build_backend_tokens() requires at least one accumulated frame "
                "in the current submap."
            )
        n_curr = curr_tokens.shape[0]
        parts.append(curr_tokens)

        for sid in loop_ids:
            loop_tokens = self.submap_tokens_cpu[sid].to(device, non_blocking=True)
            parts.append(loop_tokens)
            loop_token_counts.append(loop_tokens.shape[0])

        combined = torch.cat(parts, dim=0)
        meta = {
            "n_prev": n_prev,
            "n_curr": n_curr,
            "curr_frame_ids": list(self._curr_frame_ids),
            "loop_token_counts": loop_token_counts,
        }
        return combined, meta

    def _store_refined_submap(self, sid, refined_tokens):
        """Store tokens and a freshly computed descriptor for ``sid`` on CPU."""
        self.submap_tokens_cpu[sid] = refined_tokens.detach().cpu()
        self.submap_descriptors[sid] = self.compute_descriptor(refined_tokens).detach().cpu()

    def finalize_current_submap(
        self,
        hidden_B,
        n_prev,
        n_curr,
        prev_sid=None,
        loop_ids=None,
        loop_token_counts=None,
        primary_sid=None,
    ):
        """Write refined tokens back and close the current submap.

        ``hidden_B`` is sliced with the layout produced by
        ``build_backend_tokens``: the previous and loop submaps are overwritten
        with their refined slices, the current slice is stored under the
        current id together with its frame ids, an adjacency edge to
        ``primary_sid`` is added (when that submap is stored), the
        accumulation buffers are cleared and the submap id is advanced.
        """
        loop_ids = [] if loop_ids is None else list(loop_ids)
        loop_token_counts = [] if loop_token_counts is None else list(loop_token_counts)

        offset = 0
        if prev_sid is not None and n_prev > 0:
            self._store_refined_submap(prev_sid, hidden_B[:n_prev])
            offset = n_prev

        refined_curr = hidden_B[offset:offset + n_curr]
        offset += n_curr

        for sid, count in zip(loop_ids, loop_token_counts):
            self._store_refined_submap(sid, hidden_B[offset:offset + count])
            offset += count

        sid = self._current_submap_id
        self._store_refined_submap(sid, refined_curr)
        self.submap_frame_ids[sid] = list(self._curr_frame_ids)

        if primary_sid is not None and primary_sid in self.submap_tokens_cpu:
            self.adjacency.setdefault(sid, set()).add(primary_sid)
            self.adjacency.setdefault(primary_sid, set()).add(sid)

        self._curr_tokens = []
        self._curr_frame_ids = []
        self._current_submap_id += 1
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class SLAM:
|
| 280 |
+
def __init__(
    self,
    outdir='output/tmp',
    kf_th=0.1,
    bn_every=None,
    vis=False,
    save_gmem=True,
    ckpt_path='path/to/ckpt.pth',
    target_size=518,
    retention_ratio=None,
    loop_mask_mode=None,
    submap_train_mode=None,
    submap_retrieval_topk=None,
    submap_fetch_source=None,
    submap_descriptor_source=None,
    max_recursive_submaps=None,
):
    """Load the checkpoint, resolve the inference config and build the modules.

    Arguments left at None (``bn_every``, ``retention_ratio``,
    ``loop_mask_mode``, ``submap_train_mode``, ``submap_retrieval_topk``,
    ``submap_fetch_source``, ``submap_descriptor_source``,
    ``max_recursive_submaps``) are taken from the training config stored under
    the checkpoint's "args" entry, falling back to the hard-coded defaults
    below. The model, memory manager and optional temporal wrapper are moved
    to 'cuda' and put in eval mode.

    Args:
        outdir: Output directory prefix used when exporting results.
        kf_th: Translation distance above which a frame becomes a keyframe.
        vis: When true, a rerun viewer is spawned and visualization state is
            initialized.
        save_gmem: When true, the token map is kept on the CPU.
        ckpt_path: Path passed to ``torch.load``.
        target_size: Size passed to ``utils.load_image``.
    """
    self.outdir = outdir
    self.kf_th = kf_th
    self.save_gmem = save_gmem
    self.vis = vis
    self.ckpt_path = ckpt_path
    self.target_size = target_size

    # Wall-clock samples: frontend, keyframe detection, backend.
    self.times = []
    self.kf_time = []
    self.backend_time = []

    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    ckpt_raw = torch.load(self.ckpt_path, map_location='cpu', weights_only=False)
    self.train_cfg = extract_checkpoint_parts(ckpt_raw)[1]

    def from_cfg(key, default):
        # Value recorded at training time, or the given default.
        return get_cfg_value(self.train_cfg, key, default)

    def override_or_cfg(override, key, default):
        # An explicit constructor argument wins over the training config.
        return override if override is not None else from_cfg(key, default)

    self.bn_every = int(override_or_cfg(bn_every, 'submap_size', 10))
    self.retention_ratio = float(override_or_cfg(retention_ratio, 'retention_ratio', 0.5))
    self.enable_loop = bool(from_cfg('enable_loop', False))
    self.enable_temporal = bool(from_cfg('enable_temporal', False))
    self.tbptt_window = int(from_cfg('tbptt_window', 0))
    self.max_recursive_submaps = int(
        override_or_cfg(max_recursive_submaps, 'max_recursive_submaps', 5)
    )
    self.desc_dim = int(from_cfg('desc_dim', 128))
    self.loop_mask_mode = override_or_cfg(loop_mask_mode, 'loop_mask_mode', 'hard_top1')
    self.soft_mask_temperature = float(from_cfg('soft_mask_temperature', 0.25))
    self.soft_mask_bias = float(from_cfg('soft_mask_bias', 0.2))
    self.submap_train_mode = override_or_cfg(submap_train_mode, 'submap_train_mode', 'full_token')
    self.submap_retrieval_topk = int(
        override_or_cfg(submap_retrieval_topk, 'submap_retrieval_topk', 0)
    )
    self.submap_fetch_source = override_or_cfg(submap_fetch_source, 'submap_fetch_source', 'frontend')
    self.submap_descriptor_source = override_or_cfg(
        submap_descriptor_source, 'submap_descriptor_source', 'frontend'
    )
    self.temporal_embed_mode = from_cfg('temporal_embed_mode', 'learned')

    # model params
    self.model = SLAMFormer(retention_ratio=self.retention_ratio, bn_every=self.bn_every).eval()
    self.memory_mgr = GraphGatedMemoryManager(
        submap_size=self.bn_every,
        max_recursive_submaps=self.max_recursive_submaps,
        desc_dim=self.desc_dim,
        embed_dim=self.model.dec_embed_dim,
        loop_mask_mode=self.loop_mask_mode,
        soft_mask_temperature=self.soft_mask_temperature,
        soft_mask_bias=self.soft_mask_bias,
        retain_history_grad=False,
        submap_train_mode=self.submap_train_mode,
        submap_retrieval_topk=self.submap_retrieval_topk,
        submap_fetch_source=self.submap_fetch_source,
        submap_descriptor_source=self.submap_descriptor_source,
    )
    if self.enable_temporal:
        self.temporal_wrapper = TemporalEmbedWrapper(
            embed_dim=self.model.dec_embed_dim,
            max_frames=5000,
            mode=self.temporal_embed_mode,
        )
    else:
        self.temporal_wrapper = None

    self.load_model(ckpt_raw)
    del ckpt_raw  # drop the raw checkpoint before moving modules to the GPU

    self.model.eval()
    self.model.to('cuda')
    self.memory_mgr.to('cuda')
    self.memory_mgr.eval()
    if self.temporal_wrapper is not None:
        self.temporal_wrapper.to('cuda')
        self.temporal_wrapper.eval()
    self.memory_mgr.reset()
    print(
        f"Resolved inference config: submap_size={self.bn_every}, retention_ratio={self.retention_ratio}, "
        f"enable_loop={self.enable_loop}, enable_temporal={self.enable_temporal}, loop_mask_mode={self.loop_mask_mode}, "
        f"submap_train_mode={self.submap_train_mode}, submap_retrieval_topk={self.submap_retrieval_topk}, "
        f"submap_fetch_source={self.submap_fetch_source}, submap_descriptor_source={self.submap_descriptor_source}, "
        f"max_recursive_submaps={self.max_recursive_submaps}"
    )

    # SLAM params: frame / keyframe counters start at -1 (nothing seen yet).
    self.fid = -1
    self.kid = -1
    self.kfids = []
    self.last_kfid = 0
    self.kf_timestamps = []
    # frontend
    self.frontend_times = 0
    # Token map (raw) and its backend-refined counterpart.
    self.map = None
    self.map_opt = None

    self.signal_backend = False
    self.backend_every = self.bn_every

    self.extrins = []
    self.intrins = []
    self.frames = []
    self.kf_frames = []

    self.K = None
    self.update_K = False

    # vis
    if self.vis:
        self.entity = "world"
        rr.init("SLAM", spawn=True)
        rr.log(self.entity, rr.ViewCoordinates.RIGHT_HAND_Z_UP)
        self.Twk = np.eye(4)
        self.K = np.eye(3)
|
| 418 |
+
|
| 419 |
+
def load_model(self, ckpt_raw):
    """Load model, memory-manager and temporal-wrapper weights from a checkpoint.

    Every module is loaded with ``strict=False``; the result of each load
    (missing / unexpected keys) is printed so that a checkpoint that does not
    match the architecture is visible instead of being silently ignored.

    Args:
        ckpt_raw: Object returned by ``torch.load`` for the checkpoint file.
    """
    ckpt, _, memory_state, temporal_state = extract_checkpoint_parts(ckpt_raw)

    ckpt = utils.strip_module(ckpt)
    # Report key mismatches for the backbone as well, consistent with the
    # memory manager and temporal wrapper below.
    msg = self.model.load_state_dict(ckpt, strict=False)
    print(f"Loaded model: {msg}")

    if memory_state is not None:
        memory_state = strip_prefix(strip_module(memory_state), 'memory_mgr.')
        msg = self.memory_mgr.load_state_dict(memory_state, strict=False)
        print(f"Loaded memory manager: {msg}")

    # The temporal wrapper only exists when temporal embedding is enabled.
    if self.temporal_wrapper is not None and temporal_state is not None:
        temporal_state = strip_prefix(strip_module(temporal_state), 'temporal_wrapper.')
        msg = self.temporal_wrapper.load_state_dict(temporal_state, strict=False)
        print(f"Loaded temporal wrapper: {msg}")
|
| 432 |
+
|
| 433 |
+
@property
def time(self):
    """``time.perf_counter()`` reading taken after ``torch.cuda.synchronize()``."""
    torch.cuda.synchronize()
    now = time.perf_counter()
    return now
|
| 437 |
+
|
| 438 |
+
def _estimate_kf_pose(self, frame_pair, use_amp=True):
|
| 439 |
+
token = self.model.KFT(frame_pair, use_amp=use_amp)
|
| 440 |
+
res = self.model.extract(token, cam_only=True, use_amp=use_amp)
|
| 441 |
+
camera_pose = res['camera_poses']
|
| 442 |
+
extrinsic = torch.linalg.inv(camera_pose)
|
| 443 |
+
return token, camera_pose, extrinsic
|
| 444 |
+
|
| 445 |
+
def kf_detect(self, image):
|
| 446 |
+
if self.kid == -1:
|
| 447 |
+
self.extrins.append(torch.eye(4))
|
| 448 |
+
return True
|
| 449 |
+
|
| 450 |
+
frame = utils.load_image(image, self.target_size)
|
| 451 |
+
_,H,W = frame.shape
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
st = self.time #time.perf_counter()
|
| 455 |
+
frame_pair = torch.stack([self.kf_frames[-1], frame.cuda()])
|
| 456 |
+
token, camera_pose, extrinsic = self._estimate_kf_pose(frame_pair, use_amp=True)
|
| 457 |
+
if (not torch.isfinite(camera_pose).all()) or (not torch.isfinite(extrinsic).all()):
|
| 458 |
+
print("[warning] Non-finite keyframe pose under AMP; retrying keyframe detection in float32.")
|
| 459 |
+
token, camera_pose, extrinsic = self._estimate_kf_pose(frame_pair, use_amp=False)
|
| 460 |
+
if (not torch.isfinite(camera_pose).all()) or (not torch.isfinite(extrinsic).all()):
|
| 461 |
+
print("[warning] Non-finite keyframe pose after float32 retry; skipping this frame for keyframe detection.")
|
| 462 |
+
return False
|
| 463 |
+
if self.vis:
|
| 464 |
+
# scale the pose to global
|
| 465 |
+
#z = res['local_points'][0,0,:,:,-1].cpu().numpy()
|
| 466 |
+
if not hasattr(self,'depth_lask_kf'):
|
| 467 |
+
scale=1
|
| 468 |
+
else:
|
| 469 |
+
scale=1 #np.median(self.depth_last_kf/(z+1e-6))
|
| 470 |
+
if extrinsic.shape[1] > 1:
|
| 471 |
+
extrinsic_ref=extrinsic.cpu()[0,-2]
|
| 472 |
+
extrinsic = extrinsic.cpu()[0,-1]
|
| 473 |
+
Tki = torch.linalg.inv(camera_pose[0,0])@camera_pose[0,1]
|
| 474 |
+
Tki = Tki.cpu().numpy()
|
| 475 |
+
self.Twi = self.Twk@Tki
|
| 476 |
+
K44 = np.eye(4)
|
| 477 |
+
K44[:3,:3] = self.K
|
| 478 |
+
log_camera("camera",self.Twi, K44, kfd=True)
|
| 479 |
+
# make the window follow camera
|
| 480 |
+
log_window(f"{self.entity}",np.linalg.inv(self.Twi), K44)
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
else:
|
| 484 |
+
if extrinsic.shape[1] > 1:
|
| 485 |
+
extrinsic_ref=extrinsic.cpu()[0,-2]
|
| 486 |
+
extrinsic = extrinsic.cpu()[0,-1]
|
| 487 |
+
self.kft_extrinsic_ref = torch.eye(4)#extrinsic_ref
|
| 488 |
+
|
| 489 |
+
dist = torch.sqrt(torch.sum((extrinsic[:3,3] - extrinsic_ref[:3,3])**2))
|
| 490 |
+
if not torch.isfinite(dist):
|
| 491 |
+
print("[warning] Non-finite keyframe distance after pose recovery; skipping this frame.")
|
| 492 |
+
return False
|
| 493 |
+
isKF = dist > self.kf_th
|
| 494 |
+
|
| 495 |
+
print(dist)
|
| 496 |
+
|
| 497 |
+
if isKF:
|
| 498 |
+
self.extrins.append(extrinsic)
|
| 499 |
+
return isKF
|
| 500 |
+
|
| 501 |
+
def frontend(self, image):
    """Process one incoming frame: keyframe test, tokenization, map update.

    Steps: (1) run ``kf_detect`` and record its duration; (2) if the frame is
    a keyframe, store it, run ``self.model.frontendT`` (on the first two
    keyframes jointly once the second arrives, on the single frame
    afterwards) and append the tokens to the map; (3) raise the backend
    signal when the memory manager reports a complete submap; (4) when
    visualization is enabled, log the newest cameras and points to rerun.

    Args:
        image: Raw image array with channels last; channels are reversed
            before being logged to the viewer.

    Returns:
        False when the frame is not a keyframe, otherwise None.
    """
    if self.vis:
        rr.log("image", rr.Image(image[:, :, ::-1]))

    self.fid += 1
    print('Frame', self.fid)

    # run kf detector
    st = self.time
    enough_disparity = self.kf_detect(image)
    self.kf_time.append(self.time - st)
    if not enough_disparity:
        return False

    torch.cuda.empty_cache()

    # run T-frontend
    frame = utils.load_image(image, self.target_size)
    # Frames are channel-first (kf_detect unpacks the same tensor as
    # ``_, H, W`` and extract() permutes (0, 2, 3, 1)).
    _, self.H, self.W = frame.shape
    self.last_kf = frame.cuda()
    self.kf_frames.append(self.last_kf)
    self.last_kfid = self.fid
    self.frames.append(self.last_kf.clone())
    self.kid += 1
    print("[italic purple] # KEYFRAME", self.kid)
    self.kf_timestamps.append(self.cur_timestamp)
    frame = frame.cuda()
    st = self.time

    if self.nkf == 1:
        # A single keyframe is not tokenized on its own; it is processed
        # together with the second one.
        pass
    elif self.nkf == 2:
        token = self.model.frontendT(torch.stack([self.kf_frames[0], frame]))
        self.map_add(token)
    else:
        token = self.model.frontendT(frame)
        print(self.time - st)
        self.map_add(token)

    self.kfids.append(self.fid)
    self.times.append(self.time - st)
    torch.cuda.empty_cache()

    # send signal to backend
    self.frontend_times += 1
    if self.memory_mgr.submap_complete:
        self.signal_backend = True

    if self.vis and self.map is not None:
        st = time.time()
        if self.nkf == 2:
            # Both initial keyframes were tokenized together.
            ps, cs, confs, poses = self.extract(self.map)
        else:
            ps, cs, confs, poses = self.extract(self.map[-1:])

        self.vis_mem = [ps, cs, confs, poses]

        # Drop the 15% least confident points before logging.
        conf_threshold = np.percentile(confs, 15)
        msk = confs >= conf_threshold

        ps = ps[msk]
        cs = cs[msk]
        K44 = np.eye(4)
        K44[:3, :3] = self.K

        if self.nkf == 2:
            log_camera(f"{self.entity}/camera_kf/0", poses[0], K44)
            log_camera(f"{self.entity}/camera_kf/1", poses[1], K44)

            rr.log(f"{self.entity}/lines/0to1", rr.LineStrips3D([poses[:,:3,3].tolist()],colors=[0,0,255],radii=[0.005]))

            self.last_kf_pose = poses[1]
        else:
            log_camera(f"{self.entity}/camera_kf/{self.nkf-1}", poses.reshape(4, 4), K44)
            rr.log(f"{self.entity}/lines/{self.nkf-2}to{self.nkf-1}", rr.LineStrips3D([np.stack([self.last_kf_pose[:3,3],poses[0,:3,3]]).tolist()],colors=[0,0,255],radii=[0.005]))

            self.last_kf_pose = poses[0]

        rr.log(
            f"{self.entity}/pointclouds/{self.nkf}",
            rr.Points3D(ps, colors=cs, radii=0.01),
        )

        print('log', time.time() - st)

        self.Twk = poses[-1].reshape(4, 4)
|
| 601 |
+
|
| 602 |
+
def backend(self, final=False):
|
| 603 |
+
if not self.signal_backend:
|
| 604 |
+
return
|
| 605 |
+
|
| 606 |
+
torch.cuda.empty_cache()
|
| 607 |
+
if not self.memory_mgr._curr_tokens:
|
| 608 |
+
self.signal_backend = False
|
| 609 |
+
return
|
| 610 |
+
|
| 611 |
+
if hasattr(self.model, 'fkv'):
|
| 612 |
+
del self.model.fkv
|
| 613 |
+
self.model.reset_backend_cache = True
|
| 614 |
+
self.model._prune_idx_cache = None
|
| 615 |
+
self.model._prune_idx_cache_N = 0
|
| 616 |
+
self.model._prune_idx_cache_hw = None
|
| 617 |
+
|
| 618 |
+
print('Backending...', self.nkf, 'KFs')
|
| 619 |
+
st = time.perf_counter()
|
| 620 |
+
map_optimed, loop_gate, backend_meta = self.memory_mgr.finalize_submap(
|
| 621 |
+
model=self.model,
|
| 622 |
+
device=torch.device('cuda'),
|
| 623 |
+
temporal_wrapper=self.temporal_wrapper,
|
| 624 |
+
enable_temporal_embed=self.enable_temporal,
|
| 625 |
+
enable_loop_closure=self.enable_loop,
|
| 626 |
+
tbptt_window=self.tbptt_window,
|
| 627 |
+
)
|
| 628 |
+
backend_take = time.perf_counter()-st
|
| 629 |
+
self.backend_time.append(backend_take)
|
| 630 |
+
print(
|
| 631 |
+
f'Submap backend: sid={self.memory_mgr.current_submap_id - 1}, '
|
| 632 |
+
f'prev_tokens={backend_meta["n_prev"]}, retrieved_tokens={backend_meta["n_retrieved"]}, '
|
| 633 |
+
f'loop_gate={float(loop_gate.squeeze().detach().cpu())}'
|
| 634 |
+
)
|
| 635 |
+
print('backend_take', backend_take)
|
| 636 |
+
torch.cuda.empty_cache()
|
| 637 |
+
|
| 638 |
+
map_cpu = self.map.detach().cpu() if self.map.is_cuda else self.map
|
| 639 |
+
if self.map_opt is None:
|
| 640 |
+
self.map_opt = map_cpu.clone()
|
| 641 |
+
elif self.map_opt.shape[0] < map_cpu.shape[0]:
|
| 642 |
+
self.map_opt = torch.cat([self.map_opt, map_cpu[self.map_opt.shape[0]:]], dim=0)
|
| 643 |
+
|
| 644 |
+
for local_idx, frame_id in enumerate(backend_meta["frame_ids"]):
|
| 645 |
+
self.map_opt[int(frame_id)] = map_optimed[local_idx].detach().cpu()
|
| 646 |
+
|
| 647 |
+
self.signal_backend = False
|
| 648 |
+
torch.cuda.empty_cache()
|
| 649 |
+
|
| 650 |
+
if self.vis:
|
| 651 |
+
ps,cs,confs,poses = self.extract(self.map_opt)
|
| 652 |
+
self.vis_mem = [ps,cs,confs,poses]
|
| 653 |
+
conf_threshold = np.percentile(confs, 15)
|
| 654 |
+
msk = confs>=conf_threshold
|
| 655 |
+
|
| 656 |
+
ps = ps[msk]
|
| 657 |
+
cs = cs[msk]
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
for s in range(self.nkf+1):
|
| 661 |
+
rr.log(f"{self.entity}/pointclouds/{s}", rr.Points3D(np.array([])))
|
| 662 |
+
|
| 663 |
+
for s in range(self.nkf):
|
| 664 |
+
K44 = np.eye(4)
|
| 665 |
+
K44[:3,:3] = self.K
|
| 666 |
+
log_camera(f"{self.entity}/camera_kf/{s}",poses[s].reshape(4,4), K44, update=True)
|
| 667 |
+
|
| 668 |
+
for s in range(1, self.nkf):
|
| 669 |
+
rr.log(f"{self.entity}/lines/{s-1}to{s}", rr.LineStrips3D([poses[s-1:s+1,:3,3].tolist()],colors=[0,0,255],radii=[0.005]))
|
| 670 |
+
|
| 671 |
+
rr.log(
|
| 672 |
+
f"{self.entity}/pointclouds/{self.nkf}",
|
| 673 |
+
rr.Points3D(ps, colors=cs, radii=0.01),
|
| 674 |
+
)
|
| 675 |
+
self.last_kf_pose = poses[-1]
|
| 676 |
+
|
| 677 |
+
def step(self, timestamp, image):
|
| 678 |
+
if timestamp is None:
|
| 679 |
+
self.cur_timestamp = self.fid+1
|
| 680 |
+
else:
|
| 681 |
+
self.cur_timestamp = timestamp
|
| 682 |
+
|
| 683 |
+
self.frontend(image)
|
| 684 |
+
|
| 685 |
+
self.backend()
|
| 686 |
+
|
| 687 |
+
def map_add(self, token_kf):
|
| 688 |
+
token_kf = token_kf.detach()
|
| 689 |
+
start_idx = 0 if self.map is None else int(self.map.shape[0])
|
| 690 |
+
for i, tok in enumerate(token_kf):
|
| 691 |
+
self.memory_mgr.accumulate(tok.unsqueeze(0), start_idx + i)
|
| 692 |
+
if self.map is None:
|
| 693 |
+
self.map = token_kf.cpu() if self.save_gmem else token_kf #[tok.cpu() for tok in token_kf]
|
| 694 |
+
else:
|
| 695 |
+
if self.save_gmem:
|
| 696 |
+
self.map = torch.cat([self.map, token_kf.cpu()],axis=0) # S,P,C
|
| 697 |
+
else:
|
| 698 |
+
self.map = torch.cat([self.map, token_kf],axis=0) # S,P,C
|
| 699 |
+
|
| 700 |
+
@property
|
| 701 |
+
def nkf(self):
|
| 702 |
+
return self.kid+1
|
| 703 |
+
|
| 704 |
+
@property
|
| 705 |
+
def nf(self):
|
| 706 |
+
return self.fid+1
|
| 707 |
+
|
| 708 |
+
def terminate(self):
|
| 709 |
+
if self.memory_mgr._curr_tokens:
|
| 710 |
+
self.signal_backend = True
|
| 711 |
+
self.backend(final=True)
|
| 712 |
+
|
| 713 |
+
print(self.kf_time)
|
| 714 |
+
print(self.times)
|
| 715 |
+
print(self.backend_time)
|
| 716 |
+
print('frontend take', np.mean(self.times))
|
| 717 |
+
print('KFT')
|
| 718 |
+
print('total', np.sum(self.kf_time), 'FPS', float(len(self.kf_time))/np.sum(self.kf_time))
|
| 719 |
+
print('FT')
|
| 720 |
+
print('total', np.sum(self.times), 'FPS', float(len(self.times))/np.sum(self.times))
|
| 721 |
+
print('BT')
|
| 722 |
+
print('total', np.sum(self.backend_time), 'FPS', float(len(self.backend_time))/np.sum(self.backend_time))
|
| 723 |
+
print('Summary')
|
| 724 |
+
print('total', np.sum(self.kf_time)+np.sum(self.times)+np.sum(self.backend_time), 'FPS', float(len(self.kf_time))/(np.sum(self.kf_time)+np.sum(self.times)+np.sum(self.backend_time)))
|
| 725 |
+
map_to_save = self.map_opt if self.map_opt is not None else self.map
|
| 726 |
+
if map_to_save is None:
|
| 727 |
+
print(
|
| 728 |
+
f"[warning] No map was built for this sequence (nkf={self.nkf}). "
|
| 729 |
+
"Skipping result export to avoid crashing on a single-keyframe run."
|
| 730 |
+
)
|
| 731 |
+
return None
|
| 732 |
+
self.save_result(f'{self.outdir}/final', map_to_save)
|
| 733 |
+
|
| 734 |
+
def extract(self, map_all=None):
    """Decode map tokens into flattened points, colors, confidences and poses.

    Side effect: stores the last component of the last frame's local points
    in ``self.depth_last_kf``.

    Args:
        map_all: Map tokens to decode (moved to CUDA before decoding). When
            omitted, falls back to ``self.map_opt`` and then ``self.map``.

    Returns:
        Tuple ``(pts, colors, confs, camera_pose)`` of NumPy arrays: points
        reshaped to (-1, 3), colors reshaped to (-1, 3) with the channel order
        reversed, confidences flattened to 1-D, and the poses of batch entry 0.

    Raises:
        ValueError: If no map tokens are given and none are stored.
    """
    # The default used to reach `None.cuda()`; fall back the same way
    # save_result() does.
    if map_all is None:
        map_all = self.map_opt if self.map_opt is not None else self.map
    if map_all is None:
        raise ValueError("extract() called with no map data")

    # Pure inference: the outputs are converted to NumPy right away, so no
    # autograd graph is needed.
    with torch.no_grad():
        result = self.model.extract(map_all.cuda())

    pts = result['points'].cpu().numpy()  # 1,S,H,W,3
    local_pts = result['local_points'].cpu().numpy()  # 1,S,H,W,3
    num_frames = pts.shape[1]
    confs = result['conf'].cpu().numpy().reshape(-1)

    # Colors come from the most recent S stored frames; [:, ::-1] reverses the
    # channel order.
    colors = (
        torch.stack(self.frames[-num_frames:])
        .permute(0, 2, 3, 1)
        .reshape(-1, 3)
        .cpu()
        .numpy()[:, ::-1]
    )

    camera_pose = result['camera_poses'].cpu().numpy()[0]  # S,4,4
    pts = pts.reshape(-1, 3)

    # set depth for the last kf
    self.depth_last_kf = local_pts[0, -1, :, :, -1]

    return pts, colors, confs, camera_pose
|
| 756 |
+
|
| 757 |
+
def save_result(self, output_path = 'output/tmp', map_all=None, traj=True, chunk_size=50):
    """Decode the map in chunks and export point cloud, trajectory and per-frame clouds.

    Writes ``{output_path}.ply``, ``{output_path}_traj.txt`` and the directory
    ``{output_path}_pc``.

    Args:
        output_path: Path prefix for all exported artifacts.
        map_all: Map tokens indexed by frame along dim 0. When omitted, falls
            back to ``self.map_opt`` and then ``self.map``.
        traj: Currently unused; kept so existing callers keep working.
        chunk_size: Number of frames decoded per ``model.extract`` call.

    Returns:
        Dict with ``'points'``, ``'conf'`` and ``'camera_poses'`` tensors, or
        ``None`` when there is no map to export.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    print(self.kfids)

    if map_all is None:
        map_all = self.map_opt if self.map_opt is not None else self.map
    if map_all is None:
        print(f"[warning] save_result() called with no map data; skipping export for {output_path}.")
        return None

    # Chunk-process to avoid OOM on long sequences
    # (our finetuning removed torch.no_grad() from model internals), and
    # disable autograd here so no graph is kept alive across chunks.
    S_total = map_all.shape[0]
    all_pts, all_conf, all_poses = [], [], []

    with torch.no_grad():
        for start in range(0, S_total, chunk_size):
            end = min(start + chunk_size, S_total)
            chunk = map_all[start:end].cuda()
            torch.cuda.empty_cache()
            result_chunk = self.model.extract(chunk)
            all_pts.append(result_chunk['points'].cpu())
            all_conf.append(result_chunk['conf'].cpu())
            all_poses.append(result_chunk['camera_poses'].cpu())
            # Drop GPU references before the next chunk is uploaded.
            del result_chunk, chunk
            torch.cuda.empty_cache()

    pts = torch.cat(all_pts, dim=1).numpy()  # 1,S,H,W,3
    conf = torch.cat(all_conf, dim=1).numpy()
    camera_pose = torch.cat(all_poses, dim=1)  # 1,S,4,4

    S = pts.shape[1]
    point_clouds = [pts[0, s] for s in range(S)]
    # Keep points whose confidence is at or above the 15th percentile.
    conf_threshold = np.percentile(conf, 15)
    confs = [conf[0, s] >= conf_threshold for s in range(S)]

    # NOTE(review): this uses every stored frame, so it assumes
    # len(self.frames) == S; otherwise the mask below will not line up.
    colors = torch.stack(self.frames).permute(0, 2, 3, 1).reshape(-1, 3).cpu().numpy()[:, ::-1]  # S,H,W,C
    msk = np.stack(confs).reshape(-1)
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(pts.reshape(-1, 3).astype(np.float64)[msk])
    pcd.colors = o3d.utility.Vector3dVector(colors.reshape(-1, 3).astype(np.float64)[msk])
    o3d.io.write_point_cloud(f"{output_path}.ply", pcd)
    poses = camera_pose[0].numpy()

    self.write_poses_to_file(f"{output_path}_traj.txt", poses, self.kf_timestamps)
    self.save_framewise_pointclouds(f"{output_path}_pc", point_clouds, self.kf_timestamps, confs)

    return {'points': torch.from_numpy(pts), 'conf': torch.from_numpy(conf), 'camera_poses': camera_pose}
|
| 809 |
+
|
| 810 |
+
def write_poses_to_file(self, filename, poses, frame_ids):
    """Write one trajectory line per pose: ``id x y z qx qy qz qw``.

    Every value is formatted with 8 decimals and separated by single spaces.

    Args:
        filename: Destination text file; overwritten if it exists.
        poses: Sequence of pose matrices; translation is read from
            ``pose[0:3, 3]`` and rotation from ``pose[0:3, 0:3]``.
        frame_ids: One identifier per pose, convertible with ``float()``.

    Raises:
        AssertionError: If ``poses`` and ``frame_ids`` differ in length.
    """
    # Validate before opening so a mismatch cannot truncate an existing file.
    # Raised explicitly (instead of `assert`) so the check survives `python -O`
    # while callers still see the same exception type.
    if len(poses) != len(frame_ids):
        raise AssertionError("Number of provided poses and number of frame ids do not match")

    with open(filename, "w") as f:
        for frame_id, pose in zip(frame_ids, poses):
            x, y, z = pose[0:3, 3]
            quaternion = R.from_matrix(pose[0:3, 0:3]).as_quat()  # x, y, z, w
            output = np.array([float(frame_id), x, y, z, *quaternion])
            f.write(" ".join(f"{v:.8f}" for v in output) + "\n")
|
| 820 |
+
|
| 821 |
+
def save_framewise_pointclouds(self, filename, pointclouds, frame_ids, conf_masks):
    """Save one ``.npz`` file per frame into the directory ``filename``.

    Each file is named ``{frame_id}.npz`` and holds the arrays ``pointcloud``
    and ``mask``.

    Args:
        filename: Output directory; created if missing.
        pointclouds: Per-frame point arrays.
        frame_ids: Per-frame identifiers used as file stems.
        conf_masks: Per-frame confidence masks.
    """
    os.makedirs(filename, exist_ok=True)
    # The loop variable gets its own name; it used to shadow the `conf_masks`
    # parameter while that parameter was still being iterated.
    for frame_id, pointcloud, conf_mask in zip(frame_ids, pointclouds, conf_masks):
        # save pcd as numpy array
        np.savez(f"{filename}/{frame_id}.npz", pointcloud=pointcloud, mask=conf_mask)
|
| 826 |
+
|
| 827 |
+
|
| 828 |
+
def get_parser(argv=None):
    """Build the SLAM-Former demo argument parser and parse the command line.

    Despite the name, this returns the parsed arguments, not the parser.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        ``argparse.Namespace`` with the parsed options.
    """
    parser = argparse.ArgumentParser(description="SLAM-Former demo")
    parser.add_argument("--ckpt_path", type=str, default="path/to/checkpoint.pth.model", help="Path to the checkpoint")
    parser.add_argument("--image_folder", type=str, default="path/to/image/folder", help="Path to folder containing images")
    parser.add_argument("--target_size", type=int, default=518, help="the target size of image(longer side)")
    parser.add_argument("--output_dir", type=str, default="outputs/tmp", help="Path to save the output")
    parser.add_argument("--stride", type=int, default=1, help="Frame stride for subsampling the input sequence")
    parser.add_argument("--kf_th", type=float, default=0.1, help="Keyframe selection threshold (minimum translation distance)")
    parser.add_argument("--retention_ratio", type=float, default=None, help="KV Pruning retention ratio")
    parser.add_argument("--bn_every", type=int, default=None, help="Run backend optimization every N keyframes")
    # The options below default to None, i.e. no override is passed on.
    parser.add_argument("--loop_mask_mode", type=str, default=None, choices=["hard_top1", "soft_all"], help="Override loop retrieval masking mode")
    parser.add_argument("--submap_train_mode", type=str, default=None, choices=["full_token", "top5_dual_queue"], help="Override submap queue mode")
    parser.add_argument("--submap_retrieval_topk", type=int, default=None, help="Override number of historical submaps fetched in soft_all mode")
    parser.add_argument("--submap_fetch_source", type=str, default=None, choices=["frontend", "backend"], help="Override token source used for retrieval")
    parser.add_argument("--submap_descriptor_source", type=str, default=None, choices=["frontend", "backend"], help="Override descriptor source used for retrieval")
    parser.add_argument("--max_recursive_submaps", type=int, default=None, help="Override recursive covisibility fetch limit")
    parser.add_argument("--vis", action="store_true", help="Enable real-time visualization with Rerun")
    parser.add_argument("--resize_rate", type=float, default=1, help="Resize rate for input images before processing")

    return parser.parse_args(argv)
|
| 849 |
+
|
| 850 |
+
|
| 851 |
+
if __name__ == '__main__':
    # Demo driver: load an image sequence, feed it to SLAM frame by frame and
    # export the final result.
    args = get_parser()
    image_folder = args.image_folder
    outdir = args.output_dir
    os.makedirs(outdir, exist_ok=True)

    # Pinhole intrinsics (fx, fy, cx, cy), chosen by a substring of the input path.
    if 'tum' in args.image_folder:
        intrinsics = (525.0, 525.0, 319.5, 239.5)
    elif 'Replica' in args.image_folder:
        intrinsics = (600.0, 600.0, 599.5, 339.5)
    else:
        intrinsics = None

    if intrinsics is None:
        K = None
    else:
        fx, fy, cx, cy = intrinsics
        K = np.eye(3)
        K[0, 0] = fx
        K[1, 1] = fy
        K[0, 2] = cx
        K[1, 2] = cy

    # Use the provided image folder path
    print(f"Loading images from {image_folder}...")
    # Skip entries whose lower-cased file name contains "depth", "txt" or "db".
    image_names = [f for f in glob.glob(os.path.join(image_folder, "*"))
                   if "depth" not in os.path.basename(f).lower() and "txt" not in os.path.basename(f).lower()
                   and "db" not in os.path.basename(f).lower()]
    image_names = utils.sort_images_by_number(image_names)

    # The first number found in each file name becomes that frame's id.
    frame_ids = []
    for path in image_names:
        filename = os.path.basename(path)
        match = re.search(r'\d+(?:\.\d+)?', filename)  # matches integers and decimals
        if match:
            frame_ids.append(float(match.group()))
        else:
            raise ValueError(f"No number found in image name: {filename}")

    print(f"Found {len(image_names)} images")

    print('resize image', args.resize_rate)

    slam = SLAM(
        outdir=outdir,
        kf_th=args.kf_th,
        bn_every=args.bn_every,
        vis=args.vis,
        ckpt_path=args.ckpt_path,
        target_size=args.target_size,
        retention_ratio=args.retention_ratio,
        loop_mask_mode=args.loop_mask_mode,
        submap_train_mode=args.submap_train_mode,
        submap_retrieval_topk=args.submap_retrieval_topk,
        submap_fetch_source=args.submap_fetch_source,
        submap_descriptor_source=args.submap_descriptor_source,
        max_recursive_submaps=args.max_recursive_submaps,
    )

    slam.K = K

    for frame_id, image_name in zip(frame_ids[::args.stride], image_names[::args.stride]):
        img = cv2.imread(image_name)
        # cv2.imread returns None instead of raising when a file cannot be decoded.
        if img is None:
            print(f"[warning] Could not read image {image_name}; skipping.")
            continue

        if args.resize_rate != 1:
            H, W, _ = img.shape
            # The interpolation flag must be passed by keyword: the third
            # positional parameter of cv2.resize is `dst`, so passing
            # INTER_CUBIC positionally never selected cubic interpolation.
            img = cv2.resize(
                img,
                (int(W*args.resize_rate), int(H*args.resize_rate)),
                interpolation=cv2.INTER_CUBIC,
            )
        slam.step(frame_id, img)
    result = slam.terminate()
|
| 927 |
+
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/download_data.sh
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# ──────────────────────────────────────────────────────────
# Local download checklist template for SLAM-Former fine-tuning
#
# README-backed links:
# - ARKitScenes / MVS-Synth / ScanNet -> Hugging Face SLF dataset tree
#   (ScanNet is stored as a split archive: processed_scannetv2.zip.part.aa/.ab/.ac/.ad)
# - HyperSim -> Hugging Face preprocessed_Hypersim tree
# - ScanNet++ / BlendedMVS / MegaDepth -> README says "coming soon" or no direct archive command is listed here
#
# This script is intentionally split into per-dataset toggles so you can fill
# the missing datasets one by one on the local machine without overwriting
# existing extracted folders.
#
# Example:
#   DOWNLOAD_ARKITSCENES=1 DOWNLOAD_SCANNET=1 bash slam/download_data.sh
# ──────────────────────────────────────────────────────────
set -euo pipefail

# Project root. Overridable from the environment like every other knob below;
# the default is the original hard-coded location.
PROJECT_DIR="${PROJECT_DIR:-/var/scratch/qzhang2/SLAM-Former}"
DATA_DIR="$PROJECT_DIR/data/train"
CKPT_DIR="$PROJECT_DIR/ckpt"
LOG_FILE="$PROJECT_DIR/download.log"

mkdir -p "$DATA_DIR" "$CKPT_DIR"

# log MESSAGE...: print a timestamped line and append it to $LOG_FILE.
log() { echo "[$(date '+%Y-%m-%d %H:%M:%S')] $*" | tee -a "$LOG_FILE"; }

# Download sources.
HF_BASE="https://huggingface.co/datasets/KevinConnorLee/SLF/resolve/main"
HF_HYPERSIM_BASE="https://huggingface.co/datasets/KevinConnorLee/preprocessed_Hypersim/resolve/main"
HF_CKPT="https://huggingface.co/Jarrome/SLAM-Former/resolve/main/518/checkpoint-10.pth.model"
HF_ARKITSCENES_REPO="https://huggingface.co/datasets/Pointcept/arkitscenes-compressed"
HF_SCANNETPP_REPO="https://huggingface.co/datasets/Pointcept/scannetpp-compressed"
HF_HYPERSIM_REPO="https://huggingface.co/datasets/geyongtao/hypersim"
BLENDEDMVS_LOWRES_URL="https://1drv.ms/u/s!Ag8Dbz2Aqc81gVDu7FHfbPZwqhIy?e=BHY07t"
BLENDEDMVS_HIGHRES_URL="https://1drv.ms/u/s!Ag8Dbz2Aqc81ezb9OciQ4zKwJ_w?e=afFOTi"
MEGADEPTH_V1_URL="https://www.cs.cornell.edu/projects/megadepth/dataset/Megadepth_v1/MegaDepth_v1.tar.gz"

# Per-dataset toggles ("1" enables). Only the checkpoint is on by default.
DOWNLOAD_CHECKPOINT="${DOWNLOAD_CHECKPOINT:-1}"
# Set to "1" to keep downloaded archives after extraction.
KEEP_ARCHIVES="${KEEP_ARCHIVES:-0}"
DOWNLOAD_ARKITSCENES="${DOWNLOAD_ARKITSCENES:-0}"
DOWNLOAD_SCANNETPP="${DOWNLOAD_SCANNETPP:-0}"
DOWNLOAD_MVS_SYNTH="${DOWNLOAD_MVS_SYNTH:-0}"
DOWNLOAD_SCANNET="${DOWNLOAD_SCANNET:-0}"
DOWNLOAD_HYPERSIM="${DOWNLOAD_HYPERSIM:-0}"
DOWNLOAD_BLENDEDMVS="${DOWNLOAD_BLENDEDMVS:-0}"
DOWNLOAD_MEGADEPTH="${DOWNLOAD_MEGADEPTH:-0}"
HF_BLENDEDMVS_BASE="${HF_BLENDEDMVS_BASE:-}"
HF_MEGADEPTH_BASE="${HF_MEGADEPTH_BASE:-}"

# Marker files describing the state of this run.
RELEASED_COMPLETION_MARKER="$PROJECT_DIR/.download_complete_released_paper_datasets"
FULL_COMPLETION_MARKER="$PROJECT_DIR/.download_complete_all_requested_paper_datasets"
IN_PROGRESS_MARKER="$PROJECT_DIR/.download_in_progress"

# Clear markers left by a previous run, then flag this run as in progress.
rm -f "$PROJECT_DIR/.download_complete" "$RELEASED_COMPLETION_MARKER" "$FULL_COMPLETION_MARKER" "$IN_PROGRESS_MARKER"
touch "$IN_PROGRESS_MARKER"
|
| 56 |
+
|
| 57 |
+
# download_file URL OUTPUT
# Fetch URL into OUTPUT with wget (-c asks wget to continue a partial file).
# wget's output is shown on the terminal and appended to $LOG_FILE.
download_file() {
    local src_url="$1" dest_path="$2"
    if [[ -f "$dest_path" ]]; then
        log "Found existing file $(basename "$dest_path"), attempting resume if incomplete."
    fi
    wget -c --progress=bar:force -O "$dest_path" "$src_url" 2>&1 | tee -a "$LOG_FILE"
}
|
| 65 |
+
|
| 66 |
+
# download_hf_repo_dataset LABEL EXPECTED_DIR REPO_URL
# Shallow-clone REPO_URL into EXPECTED_DIR unless that directory already holds
# at least one entry, then pull LFS objects when git-lfs is installed.
download_hf_repo_dataset() {
    local label="$1" target_dir="$2" repo="$3"

    # A non-empty target directory counts as already downloaded.
    if [ -d "$target_dir" ] && [ -n "$(find "$target_dir" -mindepth 1 -maxdepth 1 2>/dev/null | head -n 1)" ]; then
        log "$label already exists, skipping."
        return
    fi

    # Start from a fresh, empty directory so the clone cannot collide with leftovers.
    rm -rf "$target_dir"
    mkdir -p "$target_dir"
    log "=== Cloning $label from Hugging Face ==="
    git clone --depth 1 "$repo" "$target_dir" 2>&1 | tee -a "$LOG_FILE"

    if ! command -v git-lfs >/dev/null 2>&1; then
        log "WARNING: git-lfs not found; $label may contain LFS pointer files only."
    else
        log "Fetching LFS files for $label..."
        (cd "$target_dir" && git lfs pull) 2>&1 | tee -a "$LOG_FILE"
    fi
}
|
| 85 |
+
|
| 86 |
+
# extract_archive_auto ARCHIVE TARGET_DIR
# Extract ARCHIVE into TARGET_DIR, probing the format by content: zip is tried
# first (unzip -t), then tar (tar -tf). Returns 1 when neither tool accepts it.
extract_archive_auto() {
    local archive_path="$1" dest_dir="$2"
    mkdir -p "$dest_dir"

    if unzip -t "$archive_path" >/dev/null 2>&1; then
        unzip -o "$archive_path" -d "$dest_dir" 2>&1 | tee -a "$LOG_FILE"
    elif tar -tf "$archive_path" >/dev/null 2>&1; then
        tar -xf "$archive_path" -C "$dest_dir" 2>&1 | tee -a "$LOG_FILE"
    else
        log "ERROR: Unsupported archive format for $archive_path"
        return 1
    fi
    return 0
}
|
| 101 |
+
|
| 102 |
+
# download_url_dataset LABEL EXPECTED_DIR URL ARCHIVE_NAME
# Download URL to $DATA_DIR/ARCHIVE_NAME and extract it into EXPECTED_DIR,
# unless EXPECTED_DIR already holds at least one entry. The archive is removed
# afterwards unless KEEP_ARCHIVES=1.
download_url_dataset() {
    local label="$1" target_dir="$2" src_url="$3"
    local archive_path="$DATA_DIR/$4"

    # A non-empty target directory counts as already downloaded.
    if [ -d "$target_dir" ] && [ -n "$(find "$target_dir" -mindepth 1 -maxdepth 1 2>/dev/null | head -n 1)" ]; then
        log "$label already exists, skipping."
        return
    fi

    log "=== Downloading $label ==="
    download_file "$src_url" "$archive_path"
    extract_archive_auto "$archive_path" "$target_dir"

    if [ "$KEEP_ARCHIVES" != "1" ]; then
        rm -f "$archive_path"
    fi

    if [ ! -d "$target_dir" ]; then
        log "WARNING: $label archive extracted, but expected path is still missing: $target_dir"
    else
        log "$label done."
    fi
}
|
| 123 |
+
|
| 124 |
+
# assemble_parts OUTPUT PART...
# Concatenate the PART files, in the order given, into OUTPUT. An existing
# OUTPUT is left untouched. The data is written to OUTPUT.tmp first and then
# renamed, so an interrupted run never leaves a truncated OUTPUT behind.
# Returns 1 when no parts are given.
assemble_parts() {
    local output="$1"
    shift
    if [ -f "$output" ]; then
        log "Found existing assembled archive $(basename "$output"), skipping assembly."
        return
    fi
    # Without operands `cat` would read stdin and block forever.
    if [ "$#" -eq 0 ]; then
        log "ERROR: no archive parts given for $(basename "$output")"
        return 1
    fi
    cat "$@" > "${output}.tmp"
    mv "${output}.tmp" "$output"
}
|
| 134 |
+
|
| 135 |
+
# rename_if_needed EXPECTED CANDIDATE...
# If EXPECTED does not exist, move the first existing CANDIDATE to EXPECTED.
# Does nothing when EXPECTED already exists or no candidate is found.
rename_if_needed() {
    local wanted="$1"
    shift
    [ -e "$wanted" ] && return

    local src
    for src in "$@"; do
        [ -e "$src" ] || continue
        mv "$src" "$wanted"
        return
    done
}
|
| 149 |
+
|
| 150 |
+
# extract_zip ARCHIVE
# Unzip ARCHIVE into $DATA_DIR, overwriting existing files, then delete the
# archive unless KEEP_ARCHIVES=1.
extract_zip() {
    local zip_path="$1"
    unzip -o "$zip_path" -d "$DATA_DIR/" 2>&1 | tee -a "$LOG_FILE"
    [ "$KEEP_ARCHIVES" = "1" ] || rm -f "$zip_path"
}
|
| 157 |
+
|
| 158 |
+
# cleanup_parts FILE...
# Delete the given files unless KEEP_ARCHIVES=1.
cleanup_parts() {
    if [ "$KEEP_ARCHIVES" != "1" ]; then
        rm -f "$@"
    fi
}
|
| 164 |
+
|
| 165 |
+
# download_single_archive_dataset LABEL EXPECTED_DIR ARCHIVE_NAME BASE_URL [ALIAS...]
# Download BASE_URL/ARCHIVE_NAME into $DATA_DIR, unzip it there, and if
# EXPECTED_DIR is still missing, rename the first existing ALIAS to it.
# Skipped when EXPECTED_DIR already holds at least one entry.
download_single_archive_dataset() {
    local label="$1"
    local expected_dir="$2"
    local archive_name="$3"
    local base_url="$4"
    shift 4
    # Same "exists and is non-empty" test the other download helpers use, so an
    # empty directory left by an aborted run no longer counts as complete.
    if [ -d "$expected_dir" ] && [ -n "$(find "$expected_dir" -mindepth 1 -maxdepth 1 2>/dev/null | head -n 1)" ]; then
        log "$label already exists, skipping."
        return
    fi
    # Relative ALIAS arguments are resolved against $DATA_DIR.
    cd "$DATA_DIR"
    log "=== Downloading $label ==="
    download_file "$base_url/$archive_name" "$DATA_DIR/$archive_name"
    log "Extracting $label..."
    extract_zip "$DATA_DIR/$archive_name"
    rename_if_needed "$expected_dir" "$@"
    if [ -d "$expected_dir" ]; then
        log "$label done."
    else
        log "WARNING: $label archive extracted, but expected path is still missing: $expected_dir"
    fi
}
|
| 187 |
+
|
| 188 |
+
# download_split_archive_dataset LABEL EXPECTED_DIR ASSEMBLED_ARCHIVE BASE_URL ALIASES PART...
# Download every PART from BASE_URL into $DATA_DIR, concatenate them into
# ASSEMBLED_ARCHIVE, unzip it, and if EXPECTED_DIR is still missing, rename the
# first existing alias to it. ALIASES is a '|'-separated list of directory
# names relative to $DATA_DIR. Skipped when EXPECTED_DIR already holds at
# least one entry. Returns 1 when no parts are given.
download_split_archive_dataset() {
    local label="$1"
    local expected_dir="$2"
    local assembled_archive="$3"
    local base_url="$4"
    local aliases_string="$5"
    shift 5
    if [ -d "$expected_dir" ] && [ -n "$(find "$expected_dir" -mindepth 1 -maxdepth 1 2>/dev/null | head -n 1)" ]; then
        log "$label already exists, skipping."
        return
    fi
    if [ "$#" -eq 0 ]; then
        log "ERROR: no archive parts listed for $label"
        return 1
    fi
    cd "$DATA_DIR"
    log "=== Downloading $label ==="
    local downloaded_parts=()
    local part
    # Iterate "$@" directly: unlike an array copy it is safe under `set -u`
    # on every bash version.
    for part in "$@"; do
        download_file "$base_url/$part" "$DATA_DIR/$part"
        downloaded_parts+=("$DATA_DIR/$part")
    done
    log "Assembling $label archive..."
    assemble_parts "$DATA_DIR/$assembled_archive" "${downloaded_parts[@]}"
    log "Extracting $label..."
    extract_zip "$DATA_DIR/$assembled_archive"

    # Declared local so the alias list does not leak into the global scope.
    local alias_candidates=()
    IFS='|' read -r -a alias_candidates <<< "$aliases_string"
    local alias_paths=()
    local alias
    # The ${arr[@]+...} form expands to nothing for an empty array instead of
    # raising "unbound variable" under `set -u` on bash older than 4.4.
    for alias in ${alias_candidates[@]+"${alias_candidates[@]}"}; do
        alias_paths+=("$DATA_DIR/$alias")
    done
    rename_if_needed "$expected_dir" ${alias_paths[@]+"${alias_paths[@]}"}
    cleanup_parts "${downloaded_parts[@]}"
    if [ -d "$expected_dir" ]; then
        log "$label done."
    else
        log "WARNING: $label archive extracted, but expected path is still missing: $expected_dir"
    fi
}
|
| 226 |
+
|
| 227 |
+
# ── Per-target downloads; each one runs only when its toggle is "1" ──

if [ "$DOWNLOAD_CHECKPOINT" = "1" ]; then
    log "=== Downloading pretrained checkpoint (3.84 GB) ==="
    if [ ! -f "$CKPT_DIR/checkpoint-10.pth.model" ]; then
        download_file "$HF_CKPT" "$CKPT_DIR/checkpoint-10.pth.model"
        log "Checkpoint downloaded."
    else
        log "Checkpoint already exists, skipping."
    fi
fi

if [ "$DOWNLOAD_ARKITSCENES" = "1" ]; then
    download_hf_repo_dataset \
        "ARKitScenes (HF repo: Pointcept/arkitscenes-compressed)" \
        "$DATA_DIR/processed_arkitscenes" \
        "$HF_ARKITSCENES_REPO"
fi

if [ "$DOWNLOAD_SCANNETPP" = "1" ]; then
    download_hf_repo_dataset \
        "ScanNet++ (HF repo: Pointcept/scannetpp-compressed)" \
        "$DATA_DIR/processed_scannetpp" \
        "$HF_SCANNETPP_REPO"
fi

if [ "$DOWNLOAD_MVS_SYNTH" = "1" ]; then
    download_single_archive_dataset \
        "MVS-Synth (README-backed HF tree)" \
        "$DATA_DIR/processed_mvs_synth" \
        "processed_mvs_synth.zip" \
        "$HF_BASE" \
        "processed_mvs_synth" \
        "$DATA_DIR/processed_mvs_synth"
fi

if [ "$DOWNLOAD_SCANNET" = "1" ]; then
    download_split_archive_dataset \
        "ScanNet (HF split archive: KevinConnorLee/SLF)" \
        "$DATA_DIR/processed_scannet" \
        "processed_scannetv2.zip" \
        "$HF_BASE" \
        "processed_scannetv2|processed_scannet" \
        "processed_scannetv2.zip.part.aa" \
        "processed_scannetv2.zip.part.ab" \
        "processed_scannetv2.zip.part.ac" \
        "processed_scannetv2.zip.part.ad"
fi

if [ "$DOWNLOAD_HYPERSIM" = "1" ]; then
    download_hf_repo_dataset \
        "HyperSim (HF repo: geyongtao/hypersim)" \
        "$DATA_DIR/hypersim" \
        "$HF_HYPERSIM_REPO"
fi

if [ "$DOWNLOAD_BLENDEDMVS" = "1" ]; then
    # Pick the per-variant defaults in one place. Previously the low-res archive
    # name was assigned first, so the high-res default name could never apply.
    if [ "${BLENDEDMVS_VARIANT:-lowres}" = "highres" ]; then
        BLENDEDMVS_URL="$BLENDEDMVS_HIGHRES_URL"
        BLENDEDMVS_ARCHIVE_NAME="${BLENDEDMVS_ARCHIVE_NAME:-blendedmvs_highres.download}"
    else
        BLENDEDMVS_URL="${BLENDEDMVS_URL:-$BLENDEDMVS_LOWRES_URL}"
        BLENDEDMVS_ARCHIVE_NAME="${BLENDEDMVS_ARCHIVE_NAME:-blendedmvs_lowres.download}"
    fi
    download_url_dataset \
        "BlendedMVS (${BLENDEDMVS_VARIANT:-lowres})" \
        "$DATA_DIR/processed_blendedmvs" \
        "$BLENDEDMVS_URL" \
        "$BLENDEDMVS_ARCHIVE_NAME"
fi

if [ "$DOWNLOAD_MEGADEPTH" = "1" ] && [ ! -d "$DATA_DIR/processed_megadepth" ]; then
    download_url_dataset \
        "MegaDepth v1 (Cornell official archive)" \
        "$DATA_DIR/processed_megadepth" \
        "$MEGADEPTH_V1_URL" \
        "MegaDepth_v1.tar.gz"
fi
|
| 302 |
+
|
| 303 |
+
# ── Summary: report disk usage and write the completion markers ──

log "=== Download complete ==="
log "Disk usage:"
# du exits non-zero when a glob matches nothing (e.g. an empty ckpt dir). With
# `set -e` and `pipefail` that would abort the script before any marker is
# written, so its failure is tolerated here.
{ du -sh "$DATA_DIR"/* "$CKPT_DIR"/* 2>/dev/null || true; } | tee -a "$LOG_FILE"
log "Total:"
{ du -sh "$DATA_DIR" "$CKPT_DIR" || true; } | tee -a "$LOG_FILE"

missing_released=()
missing_all_requested=()

if [ "$DOWNLOAD_CHECKPOINT" = "1" ] && [ ! -f "$CKPT_DIR/checkpoint-10.pth.model" ]; then
    missing_released+=("ckpt/checkpoint-10.pth.model")
fi
# NOTE(review): the next three datasets count as missing even when their
# DOWNLOAD_* toggle is off, unlike the toggled checks further down -- confirm
# whether they are meant to be mandatory.
if [ ! -d "$DATA_DIR/processed_scannetpp" ]; then
    missing_released+=("data/train/processed_scannetpp")
fi
if [ ! -d "$DATA_DIR/processed_mvs_synth" ]; then
    missing_released+=("data/train/processed_mvs_synth")
fi
if [ ! -d "$DATA_DIR/processed_arkitscenes" ]; then
    missing_released+=("data/train/processed_arkitscenes")
fi
if [ "$DOWNLOAD_SCANNET" = "1" ] && [ ! -d "$DATA_DIR/processed_scannet" ]; then
    missing_released+=("data/train/processed_scannet")
fi
if [ "$DOWNLOAD_HYPERSIM" = "1" ] && [ ! -d "$DATA_DIR/hypersim" ]; then
    missing_released+=("data/train/hypersim")
fi
if [ "$DOWNLOAD_BLENDEDMVS" = "1" ] && [ ! -d "$DATA_DIR/processed_blendedmvs" ]; then
    missing_released+=("data/train/processed_blendedmvs")
fi

# Copy the list. The ${arr[@]+...} form keeps an empty array from raising
# "unbound variable" under `set -u` on bash older than 4.4, which would
# otherwise kill the script exactly when nothing is missing.
missing_all_requested=(${missing_released[@]+"${missing_released[@]}"})
if [ "$DOWNLOAD_MEGADEPTH" = "1" ] && [ ! -d "$DATA_DIR/processed_megadepth" ]; then
    missing_all_requested+=("data/train/processed_megadepth")
fi

rm -f "$IN_PROGRESS_MARKER"

if [ "${#missing_released[@]}" -eq 0 ]; then
    touch "$RELEASED_COMPLETION_MARKER"
    log "Released paper datasets are complete. Marker created at $RELEASED_COMPLETION_MARKER"
else
    log "WARNING: Missing released download targets: ${missing_released[*]}"
fi

if [ "${#missing_all_requested[@]}" -eq 0 ]; then
    touch "$FULL_COMPLETION_MARKER"
    touch "$PROJECT_DIR/.download_complete"
    log "All requested datasets are complete. Markers created at $FULL_COMPLETION_MARKER and $PROJECT_DIR/.download_complete"
else
    log "WARNING: Missing requested targets: ${missing_all_requested[*]}"
fi
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/exp_joint_freeze_frontend_fsdp_8gpu.sh
ADDED
|
@@ -0,0 +1,438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
#SBATCH --job-name=sf_smoke_joint_freeze_frontend_fsdp_2gpu
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=12
#SBATCH --gres=gpu:2
#SBATCH --mem=24G
#SBATCH --time=24:00:00
#SBATCH --output=%x_%j.out
#SBATCH --error=%x_%j.err

# Smoke-test launcher for the joint backend+submap fine-tuning run.
# Every setting below can be overridden from the environment; the value after
# ":-" is only the default.

set -euo pipefail

# Project root: an explicit PROJECT_DIR wins, then the SLURM submit directory,
# then the parent of the directory holding this script.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_DIR="${PROJECT_DIR:-${SLURM_SUBMIT_DIR:-$(dirname "$SCRIPT_DIR")}}"
SRC_DIR="$PROJECT_DIR/src"

# When porting to another cluster, adjust these first: GPU count / port,
# conda init script, data directories, pretrained weights and save directory.
MASTER_PORT="${MASTER_PORT:-29662}"
# The SBATCH header above requests 2 GPUs, so under SLURM follow the actual
# allocation (SLURM_GPUS_ON_NODE) instead of unconditionally spawning 8
# processes. Outside SLURM the default remains 8.
NUM_GPUS="${NUM_GPUS:-${SLURM_GPUS_ON_NODE:-8}}"
CONDA_SH="${CONDA_SH:-/home/23068142r/miniconda3/etc/profile.d/conda.sh}"
CONDA_ENV_NAME="${CONDA_ENV_NAME:-SLAM-Former}"

# Dataset roots (each defaults to a sub-directory of DATA_ROOT).
DATA_ROOT="${DATA_ROOT:-/home/23068142r/work_dir/data}"
ROOT_ARKIT="${ROOT_ARKIT:-$DATA_ROOT/processed_arkitscenes}"
ROOT_SCANNETPP="${ROOT_SCANNETPP:-$DATA_ROOT/preprocessed_scannetpp}"
ROOT_SCANNET="${ROOT_SCANNET:-$DATA_ROOT/processed_scannet}"
ROOT_SCANNET_FALLBACK="${ROOT_SCANNET_FALLBACK:-$DATA_ROOT/processed_scannetv2}"
ROOT_HYPERSIM="${ROOT_HYPERSIM:-$DATA_ROOT/preprocessed_Hypersim}"
ROOT_BLENDEDMVS="${ROOT_BLENDEDMVS:-$DATA_ROOT/processed_blendedmvs}"
ROOT_MEGADEPTH="${ROOT_MEGADEPTH:-$DATA_ROOT/processed_megadepth}"
ROOT_MVS_SYNTH="${ROOT_MVS_SYNTH:-$DATA_ROOT/processed_mvs_synth}"

# Experiment naming and checkpoint locations.
EXPERIMENT_ROOT="${EXPERIMENT_ROOT:-paper_smoke_local_8gpu}"
VARIANT_NAME="${VARIANT_NAME:-joint_freeze_frontend_fsdp_sub12}"
EXP_NAME="${EXP_NAME:-paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12}"
SAVE_DIR="${SAVE_DIR:-$PROJECT_DIR/checkpoints/$EXPERIMENT_ROOT/$VARIANT_NAME}"
PRETRAINED="${PRETRAINED:-/home/23068142r/work_dir/projects/e2e-semantic-SLAM-submap/ckpt/checkpoint-10.pth.model}"
RESUME="${RESUME:-null}"

# Training configuration switches ("1" means enabled for the 0/1 flags).
CONFIG_NAME="${CONFIG_NAME:-finetune_paper_h20.yaml}"
DIST_STRATEGY="${DIST_STRATEGY:-fsdp}"
AUTO_DISABLE_MISSING="${AUTO_DISABLE_MISSING:-1}"
TRAIN_SUBMAP_MODULES_ONLY="${TRAIN_SUBMAP_MODULES_ONLY:-0}"
DETACH_FRONTEND_TOKENS="${DETACH_FRONTEND_TOKENS:-1}"

# Per-dataset number of views.
NUM_VIEWS_ALL="${NUM_VIEWS_ALL:-64}"
NUM_VIEWS_ARKIT="${NUM_VIEWS_ARKIT:-64}"
NUM_VIEWS_SCANNETPP="${NUM_VIEWS_SCANNETPP:-24}"
NUM_VIEWS_SCANNET="${NUM_VIEWS_SCANNET:-64}"
NUM_VIEWS_HYPERSIM="${NUM_VIEWS_HYPERSIM:-24}"
NUM_VIEWS_BLENDEDMVS="${NUM_VIEWS_BLENDEDMVS:-64}"
NUM_VIEWS_MEGADEPTH="${NUM_VIEWS_MEGADEPTH:-64}"
NUM_VIEWS_MVS_SYNTH="${NUM_VIEWS_MVS_SYNTH:-24}"

# Submap / pseudo-GT settings.
SUBMAP_SIZE="${SUBMAP_SIZE:-12}"
SUBMAP_TRAIN_MODE="${SUBMAP_TRAIN_MODE:-full_token}"
SUBMAP_RETRIEVAL_TOPK="${SUBMAP_RETRIEVAL_TOPK:-0}"
SUBMAP_FETCH_SOURCE="${SUBMAP_FETCH_SOURCE:-frontend}"
SUBMAP_DESCRIPTOR_SOURCE="${SUBMAP_DESCRIPTOR_SOURCE:-frontend}"
ENABLE_PSEUDO_GT="${ENABLE_PSEUDO_GT:-0}"
PSEUDO_GT_CACHE_PATH="${PSEUDO_GT_CACHE_PATH:-}"
SKIP_TEST="${SKIP_TEST:-1}"
EPOCHS="${EPOCHS:-2}"

# Per-dataset sample weights; a weight of 0 leaves the dataset out of the mix.
SAMPLES_ARKIT="${SAMPLES_ARKIT:-0}"
SAMPLES_SCANNETPP="${SAMPLES_SCANNETPP:-16}"
SAMPLES_SCANNET="${SAMPLES_SCANNET:-0}"
SAMPLES_HYPERSIM="${SAMPLES_HYPERSIM:-16}"
SAMPLES_BLENDEDMVS="${SAMPLES_BLENDEDMVS:-0}"
SAMPLES_MEGADEPTH="${SAMPLES_MEGADEPTH:-0}"
SAMPLES_MVS_SYNTH="${SAMPLES_MVS_SYNTH:-16}"
# Empty means "derive from the enabled datasets" later in the script.
GLOBAL_NUM_VIEWS="${GLOBAL_NUM_VIEWS:-}"
|
| 69 |
+
|
| 70 |
+
# Fail fast: pseudo-GT needs a cache path when enabled.
if [ "$ENABLE_PSEUDO_GT" = "1" ]; then
    if [ -z "$PSEUDO_GT_CACHE_PATH" ] || [ "$PSEUDO_GT_CACHE_PATH" = "null" ]; then
        echo "ERROR: ENABLE_PSEUDO_GT=1 requires PSEUDO_GT_CACHE_PATH to be set."
        exit 1
    fi
fi

if [ ! -f "$PRETRAINED" ]; then
    echo "ERROR: Missing pretrained checkpoint: $PRETRAINED"
    exit 1
fi

if [ ! -f "$CONDA_SH" ]; then
    echo "ERROR: Missing conda init script: $CONDA_SH"
    exit 1
fi

# Conda's init/activation scripts (and per-package activate.d hooks) may read
# variables that are unset, which would abort the job under `set -u`.
# Relax nounset only for the activation and restore it right after.
set +u
source "$CONDA_SH"
conda activate "$CONDA_ENV_NAME"
set -u
export PATH="$CONDA_PREFIX/bin:$PATH"

# Environment modules are optional; a failed load is tolerated on purpose.
if command -v module >/dev/null 2>&1; then
    module load cuda12.1/toolkit || true
fi

export PYTHONPATH="$PROJECT_DIR/src:$PROJECT_DIR:${PYTHONPATH:-}"
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-4}"
export HYDRA_FULL_ERROR="${HYDRA_FULL_ERROR:-1}"
export PYTORCH_ALLOC_CONF="${PYTORCH_ALLOC_CONF:-expandable_segments:True}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"

# Make the resolved settings visible to child processes.
export CONDA_SH CONFIG_NAME EXP_NAME MASTER_PORT DIST_STRATEGY AUTO_DISABLE_MISSING
export SAVE_DIR PRETRAINED RESUME NUM_GPUS
export TRAIN_SUBMAP_MODULES_ONLY DETACH_FRONTEND_TOKENS
export NUM_VIEWS_ALL NUM_VIEWS_ARKIT NUM_VIEWS_SCANNETPP NUM_VIEWS_SCANNET
export NUM_VIEWS_HYPERSIM NUM_VIEWS_BLENDEDMVS NUM_VIEWS_MEGADEPTH NUM_VIEWS_MVS_SYNTH
export SUBMAP_TRAIN_MODE SUBMAP_RETRIEVAL_TOPK SUBMAP_FETCH_SOURCE SUBMAP_DESCRIPTOR_SOURCE
export ENABLE_PSEUDO_GT PSEUDO_GT_CACHE_PATH
export SAMPLES_ARKIT SAMPLES_SCANNETPP SAMPLES_SCANNET SAMPLES_HYPERSIM
export SAMPLES_BLENDEDMVS SAMPLES_MEGADEPTH SAMPLES_MVS_SYNTH
export DATA_ROOT ROOT_ARKIT ROOT_SCANNETPP ROOT_SCANNET ROOT_SCANNET_FALLBACK
export ROOT_HYPERSIM ROOT_BLENDEDMVS ROOT_MEGADEPTH ROOT_MVS_SYNTH
export SKIP_TEST GLOBAL_NUM_VIEWS
|
| 142 |
+
|
| 143 |
+
dataset_is_ready() {
    # Succeed when the dataset root ($2) contains the marker file or directory
    # expected for the dataset named by $1. Unknown names only need the root
    # path itself to exist.
    local dataset="$1"
    local root="$2"
    case "$dataset" in
        "ARKitScenes")
            [[ -f "$root/Training/all_metadata.npz" ]]
            ;;
        "ScanNet++")
            [[ -f "$root/all_metadata.npz" ]]
            ;;
        "ScanNet")
            # At least one scene directory must carry its metadata file.
            [[ -d "$root/scans_train" && -n "$(find "$root/scans_train" -mindepth 2 -maxdepth 2 -type f -name 'new_scene_metadata.npz' -print -quit 2>/dev/null)" ]]
            ;;
        "HyperSim")
            # Look for any rendered RGB frame three levels below the root.
            [[ -d "$root" && -n "$(find "$root" -mindepth 3 -maxdepth 3 -type f -name '*rgb.png' -print -quit 2>/dev/null)" ]]
            ;;
        "BlendedMVS")
            [[ -f "$root/new_overlap.h5" ]]
            ;;
        "MegaDepth")
            [[ -f "$root/megadepth_sets_64.npz" ]]
            ;;
        "MVS-Synth")
            # Any "<sequence>/cam" directory counts as evidence of a processed set.
            [[ -d "$root" && -n "$(find "$root" -mindepth 2 -maxdepth 2 -type d -name 'cam' -print -quit 2>/dev/null)" ]]
            ;;
        *)
            [[ -e "$root" ]]
            ;;
    esac
}
|
| 173 |
+
|
| 174 |
+
dataset_probe_hint() {
    # Print a human-readable description of the marker that dataset_is_ready
    # probes for the dataset named by $1 (used in warning/error messages).
    local hint
    case "$1" in
        "ARKitScenes") hint="Training/all_metadata.npz" ;;
        "ScanNet++")   hint="all_metadata.npz" ;;
        "ScanNet")     hint="scans_train/*/new_scene_metadata.npz" ;;
        "HyperSim")    hint="scene/subscene/*rgb.png" ;;
        "BlendedMVS")  hint="new_overlap.h5" ;;
        "MegaDepth")   hint="megadepth_sets_64.npz" ;;
        "MVS-Synth")   hint="*/cam" ;;
        *)             hint="required dataset files" ;;
    esac
    echo "$hint"
}
|
| 203 |
+
|
| 204 |
+
resolve_scannet_root() {
    # Print the ScanNet root to use: the preferred root ($1) when it is ready,
    # otherwise the fallback ($2) when that one is ready. If neither is ready
    # the preferred root is printed unchanged so later checks can report it.
    local primary="$1"
    local secondary="$2"
    local chosen="$primary"
    if ! dataset_is_ready "ScanNet" "$primary"; then
        if [ "$secondary" != "$primary" ] && dataset_is_ready "ScanNet" "$secondary"; then
            # Diagnostics go to stderr: stdout carries the resolved path.
            echo "INFO: ScanNet root $primary is incomplete; falling back to $secondary" >&2
            chosen="$secondary"
        fi
    fi
    echo "$chosen"
}
|
| 218 |
+
|
| 219 |
+
handle_missing_dataset() {
    # Check one dataset whose sample weight is positive.
    #   $1  dataset label (as understood by dataset_is_ready)
    #   $2  dataset root directory
    #   $3  NAME of the variable holding the sample weight
    # When the root is not ready: with AUTO_DISABLE_MISSING=1 the weight
    # variable is reset to 0 and warnings are printed; otherwise the script
    # exits with status 1.
    local name="$1"
    local root="$2"
    local weight_var="$3"
    local weight="${!weight_var}"

    # Datasets that are already disabled need no check.
    if [ "$weight" -le 0 ]; then
        return
    fi
    if dataset_is_ready "$name" "$root"; then
        return
    fi

    local probe_hint
    probe_hint="$(dataset_probe_hint "$name")"

    local severity="ERROR"
    if [ "$AUTO_DISABLE_MISSING" = "1" ]; then
        severity="WARNING"
    fi
    echo "${severity}: Missing or incomplete ${name} dataset root: ${root}"
    echo "${severity}: Expected ${probe_hint} under ${root}"

    if [ "$severity" = "ERROR" ]; then
        exit 1
    fi
    echo "WARNING: Disabling ${name} by setting ${weight_var}=0"
    # Assign through the variable name so the caller's weight is updated.
    printf -v "$weight_var" '0'
}
|
| 242 |
+
|
| 243 |
+
append_dataset() {
    # Append "<weight> @ ${<token>}" to the global DATASET_PARTS array unless
    # the weight ($1) is <= 0. The "${...}" is emitted literally: it is a
    # placeholder for the config system, not a shell expansion.
    local count="$1"
    local config_key="$2"
    if ! [ "$count" -le 0 ]; then
        DATASET_PARTS+=("${count} @ \${${config_key}}")
    fi
}
|
| 251 |
+
|
| 252 |
+
mkdir -p "$SAVE_DIR/$EXP_NAME"

# Pick the usable ScanNet root (preferred vs. fallback) before validating.
ROOT_SCANNET="$(resolve_scannet_root "$ROOT_SCANNET" "$ROOT_SCANNET_FALLBACK")"

# Validate every dataset root; each row is label / root / weight-variable name.
DATASET_LABELS=("ARKitScenes" "ScanNet++" "ScanNet" "HyperSim" "BlendedMVS" "MegaDepth" "MVS-Synth")
DATASET_ROOTS=("$ROOT_ARKIT" "$ROOT_SCANNETPP" "$ROOT_SCANNET" "$ROOT_HYPERSIM" "$ROOT_BLENDEDMVS" "$ROOT_MEGADEPTH" "$ROOT_MVS_SYNTH")
DATASET_WEIGHT_VARS=(SAMPLES_ARKIT SAMPLES_SCANNETPP SAMPLES_SCANNET SAMPLES_HYPERSIM SAMPLES_BLENDEDMVS SAMPLES_MEGADEPTH SAMPLES_MVS_SYNTH)
for dataset_idx in "${!DATASET_LABELS[@]}"; do
    handle_missing_dataset \
        "${DATASET_LABELS[$dataset_idx]}" \
        "${DATASET_ROOTS[$dataset_idx]}" \
        "${DATASET_WEIGHT_VARS[$dataset_idx]}"
done

[[ -f "$PRETRAINED" ]] || {
    echo "ERROR: Missing pretrained checkpoint: $PRETRAINED"
    exit 1
}

# Pseudo-GT overrides: disabled unless explicitly enabled with a cache path.
PSEUDO_GT_OVERRIDES=(
    "pseudo_gt.enable=false"
    "pseudo_gt.cache_path=null"
)
if [ "$ENABLE_PSEUDO_GT" = "1" ]; then
    case "$PSEUDO_GT_CACHE_PATH" in
        ""|"null")
            echo "ERROR: ENABLE_PSEUDO_GT=1 requires PSEUDO_GT_CACHE_PATH to be set."
            exit 1
            ;;
    esac
    PSEUDO_GT_OVERRIDES=(
        "pseudo_gt.enable=true"
        "pseudo_gt.cache_path=${PSEUDO_GT_CACHE_PATH}"
    )
fi

# Translate the 0/1 shell flags into the true/false literals passed as overrides.
case "$TRAIN_SUBMAP_MODULES_ONLY" in
    1) TRAIN_SUBMAP_MODULES_ONLY_HYDRA="true" ;;
    *) TRAIN_SUBMAP_MODULES_ONLY_HYDRA="false" ;;
esac

case "$DETACH_FRONTEND_TOKENS" in
    1) DETACH_FRONTEND_TOKENS_HYDRA="true" ;;
    *) DETACH_FRONTEND_TOKENS_HYDRA="false" ;;
esac
|
| 297 |
+
|
| 298 |
+
# Build the weighted dataset mix from every dataset whose weight is > 0.
# The two arrays are parallel: variable-name suffix and config token.
DATASET_SUFFIXES=(ARKIT SCANNETPP SCANNET HYPERSIM BLENDEDMVS MEGADEPTH MVS_SYNTH)
DATASET_TOKENS=(dataset_arkit dataset_scannetpp dataset_scannet dataset_hypersim dataset_blendedmvs dataset_megadepth dataset_mvs_synth)

DATASET_PARTS=()
for mix_idx in "${!DATASET_SUFFIXES[@]}"; do
    samples_var="SAMPLES_${DATASET_SUFFIXES[$mix_idx]}"
    append_dataset "${!samples_var}" "${DATASET_TOKENS[$mix_idx]}"
done

if (( ${#DATASET_PARTS[@]} == 0 )); then
    echo "ERROR: No training dataset remains after weight filtering."
    exit 1
fi

# Unless given explicitly, the global view count is the largest per-dataset
# view count among the enabled datasets.
if [ -z "$GLOBAL_NUM_VIEWS" ]; then
    GLOBAL_NUM_VIEWS=0
    for suffix in "${DATASET_SUFFIXES[@]}"; do
        samples_var="SAMPLES_${suffix}"
        views_var="NUM_VIEWS_${suffix}"
        if [ "${!samples_var}" -gt 0 ] && [ "${!views_var}" -gt "$GLOBAL_NUM_VIEWS" ]; then
            GLOBAL_NUM_VIEWS="${!views_var}"
        fi
    done
fi

# Join the parts with " + " into a single dataset expression.
TRAIN_DATASET="$(printf ' + %s' "${DATASET_PARTS[@]}")"
TRAIN_DATASET="${TRAIN_DATASET# + }"
|
| 341 |
+
|
| 342 |
+
# Assemble the command-line overrides handed to finetune.py.
# NOTE(review): SUBMAP_SIZE is defined and printed in the launch banner, but
# no override for it is passed here — confirm whether the config expects one.

# Run identity and checkpoints.
HYDRA_ARGS=(
    "--config-name" "$CONFIG_NAME"
    "exp_name=$EXP_NAME"
    "save_dir=$SAVE_DIR"
    "pretrained=$PRETRAINED"
    "resume=$RESUME"
)

# Dataset locations.
HYDRA_ARGS+=(
    "data_root=$DATA_ROOT"
    "root_arkit=$ROOT_ARKIT"
    "root_scannetpp=$ROOT_SCANNETPP"
    "root_scannet=$ROOT_SCANNET"
    "root_hypersim=$ROOT_HYPERSIM"
    "root_blendedmvs=$ROOT_BLENDEDMVS"
    "root_megadepth=$ROOT_MEGADEPTH"
    "root_mvs_synth=$ROOT_MVS_SYNTH"
)

# View counts (global first, then per dataset).
HYDRA_ARGS+=(
    "num_views=$GLOBAL_NUM_VIEWS"
    "num_views_arkit=$NUM_VIEWS_ARKIT"
    "num_views_scannetpp=$NUM_VIEWS_SCANNETPP"
    "num_views_scannet=$NUM_VIEWS_SCANNET"
    "num_views_hypersim=$NUM_VIEWS_HYPERSIM"
    "num_views_blendedmvs=$NUM_VIEWS_BLENDEDMVS"
    "num_views_megadepth=$NUM_VIEWS_MEGADEPTH"
    "num_views_mvs_synth=$NUM_VIEWS_MVS_SYNTH"
)

# Submap training behaviour, pseudo-GT, dataset mix and schedule.
HYDRA_ARGS+=(
    "train_submap_modules_only=$TRAIN_SUBMAP_MODULES_ONLY_HYDRA"
    "detach_frontend_tokens=$DETACH_FRONTEND_TOKENS_HYDRA"
    "submap_train_mode=$SUBMAP_TRAIN_MODE"
    "submap_retrieval_topk=$SUBMAP_RETRIEVAL_TOPK"
    "submap_fetch_source=$SUBMAP_FETCH_SOURCE"
    "submap_descriptor_source=$SUBMAP_DESCRIPTOR_SOURCE"
    "${PSEUDO_GT_OVERRIDES[@]}"
    "train_dataset=$TRAIN_DATASET"
    "epochs=$EPOCHS"
)

# An empty test_dataset override is passed when testing is skipped.
if [ "$SKIP_TEST" = "1" ]; then
    HYDRA_ARGS+=("test_dataset=")
fi

# Extra arguments given to this script come last.
HYDRA_ARGS+=("$@")
|
| 378 |
+
|
| 379 |
+
# Print the resolved launch configuration. The GPU count in the headline is
# taken from NUM_GPUS so it always matches the number of processes launched
# (it used to be hard-coded to "2GPU").
echo "=== Starting ${NUM_GPUS}GPU local smoke: joint backend+submap with frozen frontend tokens ==="
echo " Project dir : $PROJECT_DIR"
echo " Launch entry : inline"
echo " Config : $CONFIG_NAME"
echo " Experiment : $EXP_NAME"
echo " Save dir : $SAVE_DIR"
echo " Pretrained : $PRETRAINED"
echo " Resume : $RESUME"
echo " Num GPUs : $NUM_GPUS"
echo " Distributed strategy : $DIST_STRATEGY"
echo " Train submap only : $TRAIN_SUBMAP_MODULES_ONLY"
echo " Detach frontend tokens: $DETACH_FRONTEND_TOKENS"
echo " Num views : arkit=$NUM_VIEWS_ARKIT scannetpp=$NUM_VIEWS_SCANNETPP scannet=$NUM_VIEWS_SCANNET hypersim=$NUM_VIEWS_HYPERSIM blendedmvs=$NUM_VIEWS_BLENDEDMVS megadepth=$NUM_VIEWS_MEGADEPTH mvs_synth=$NUM_VIEWS_MVS_SYNTH"
echo " Submap size : $SUBMAP_SIZE"
echo " Samples : arkit=$SAMPLES_ARKIT scannetpp=$SAMPLES_SCANNETPP scannet=$SAMPLES_SCANNET hypersim=$SAMPLES_HYPERSIM blendedmvs=$SAMPLES_BLENDEDMVS megadepth=$SAMPLES_MEGADEPTH mvs_synth=$SAMPLES_MVS_SYNTH"
echo " Epochs override : ${EPOCHS:-<config default>}"
echo " Skip test : $SKIP_TEST"
echo " Train dataset : $TRAIN_DATASET"
echo " Global num views : $GLOBAL_NUM_VIEWS"
echo
|
| 399 |
+
|
| 400 |
+
# Launcher options shared by every distributed strategy.
COMMON_ARGS=(
    --num_machines 1
    --num_processes "$NUM_GPUS"
    --main_process_port "$MASTER_PORT"
    --dynamo_backend no
    --mixed_precision bf16
)

# Choose the strategy-specific launcher options, then launch once.
case "$DIST_STRATEGY" in
    fsdp)
        LAUNCH_ARGS=(
            "${COMMON_ARGS[@]}"
            --use_fsdp
            --fsdp_sharding_strategy FULL_SHARD
            --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP
            --fsdp_transformer_layer_cls_to_wrap BlockRope
            --fsdp_state_dict_type FULL_STATE_DICT
            --fsdp_backward_prefetch BACKWARD_PRE
            --fsdp_use_orig_params true
            --fsdp_sync_module_states true
            --fsdp_activation_checkpointing true
        )
        ;;
    ddp)
        LAUNCH_ARGS=()
        # --multi_gpu is only passed when more than one process is requested.
        if [ "$NUM_GPUS" -gt 1 ]; then
            LAUNCH_ARGS+=(--multi_gpu)
        fi
        LAUNCH_ARGS+=("${COMMON_ARGS[@]}")
        ;;
    *)
        echo "ERROR: Unsupported DIST_STRATEGY=$DIST_STRATEGY"
        exit 1
        ;;
esac

accelerate launch \
    "${LAUNCH_ARGS[@]}" \
    "$SRC_DIR/finetune.py" \
    "${HYDRA_ARGS[@]}"
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/graph_gated_memory.py
ADDED
|
@@ -0,0 +1,850 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GraphGatedMemoryManager: Sparse Windowed & Recursively Retrieved Submap Backend.
|
| 3 |
+
|
| 4 |
+
Key components:
|
| 5 |
+
- SubMapBuffer: CPU-side storage for historical submap tokens + descriptors
|
| 6 |
+
- GraphGatedMemoryManager: differentiable loop closure with [NO_LOOP] gating,
|
| 7 |
+
recursive covisibility fetching, GPU active workspace management
|
| 8 |
+
- TemporalEmbedWrapper: dual-injection temporal embedding (Fix #7)
|
| 9 |
+
- _safe_oom_retry / _build_temporal_mask: helpers
|
| 10 |
+
|
| 11 |
+
All features are toggled via CLI arguments; when disabled, behaviour is
|
| 12 |
+
identical to the original SLAM-Former pipeline. No original source files
|
| 13 |
+
are modified.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import math
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
import numpy as np
|
| 21 |
+
from dataclasses import dataclass, field
|
| 22 |
+
from typing import Dict, List, Optional, Set, Tuple
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ═══════════════════════════════════════════════════════════
|
| 26 |
+
# Helpers
|
| 27 |
+
# ═══════════════════════════════════════════════════════════
|
| 28 |
+
|
| 29 |
+
def _safe_oom_retry(fn, *args, **kwargs):
|
| 30 |
+
"""Run *fn*; on CUDA OOM, free cache and retry once."""
|
| 31 |
+
try:
|
| 32 |
+
return fn(*args, **kwargs)
|
| 33 |
+
except RuntimeError as e:
|
| 34 |
+
if "out of memory" in str(e):
|
| 35 |
+
torch.cuda.empty_cache()
|
| 36 |
+
return fn(*args, **kwargs)
|
| 37 |
+
raise
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _build_temporal_mask(frame_id_map: torch.Tensor, P: int) -> torch.Tensor:
|
| 41 |
+
"""Build a causal attention mask from a non-contiguous frame-id tensor.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
frame_id_map: [L] int tensor — true temporal frame index for every token.
|
| 45 |
+
P: tokens per frame (patch_h*patch_w + register_tokens).
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
attn_mask: [L, L] float tensor where future-frame positions are -inf.
|
| 49 |
+
"""
|
| 50 |
+
L = frame_id_map.shape[0]
|
| 51 |
+
fids = frame_id_map.unsqueeze(1) # [L, 1]
|
| 52 |
+
fids_t = frame_id_map.unsqueeze(0) # [1, L]
|
| 53 |
+
# A token at position i may NOT attend to a token at position j
|
| 54 |
+
# if j belongs to a strictly later frame than i (causal).
|
| 55 |
+
future = fids < fids_t # [L, L] bool
|
| 56 |
+
mask = torch.zeros(L, L, device=frame_id_map.device, dtype=torch.float32)
|
| 57 |
+
mask.masked_fill_(future, float("-inf"))
|
| 58 |
+
return mask
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _build_sinusoidal_pe(max_len: int, dim: int) -> torch.Tensor:
|
| 62 |
+
"""Fixed sinusoidal position encoding [max_len, dim]."""
|
| 63 |
+
pe = torch.zeros(max_len, dim)
|
| 64 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
| 65 |
+
div_term = torch.exp(
|
| 66 |
+
torch.arange(0, dim, 2, dtype=torch.float) * (-math.log(10000.0) / dim)
|
| 67 |
+
)
|
| 68 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 69 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 70 |
+
return pe
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# ═══════════════════════════════════════════════════════════
|
| 74 |
+
# SubMapBuffer — CPU-side historical storage
|
| 75 |
+
# ═══════════════════════════════════════════════════════════
|
| 76 |
+
|
| 77 |
+
@dataclass
|
| 78 |
+
class SubMapBuffer:
|
| 79 |
+
"""Lightweight CPU buffer for completed submaps.
|
| 80 |
+
|
| 81 |
+
Each entry is keyed by ``submap_id`` (int, monotonically increasing).
|
| 82 |
+
|
| 83 |
+
Attributes:
|
| 84 |
+
cpu_frontend_token_buffer: submap_id → [K, P, 2C] frontend token tensor.
|
| 85 |
+
cpu_backend_token_buffer: submap_id → [K, P, 2C] backend-refined token tensor.
|
| 86 |
+
cpu_frontend_descriptor_buffer: submap_id → [desc_dim] frontend descriptor.
|
| 87 |
+
cpu_backend_descriptor_buffer: submap_id → [desc_dim] backend descriptor.
|
| 88 |
+
cpu_frame_ids: submap_id → list of original frame indices.
|
| 89 |
+
"""
|
| 90 |
+
cpu_frontend_token_buffer: Dict[int, torch.Tensor] = field(default_factory=dict)
|
| 91 |
+
cpu_backend_token_buffer: Dict[int, torch.Tensor] = field(default_factory=dict)
|
| 92 |
+
cpu_frontend_descriptor_buffer: Dict[int, torch.Tensor] = field(default_factory=dict)
|
| 93 |
+
cpu_backend_descriptor_buffer: Dict[int, torch.Tensor] = field(default_factory=dict)
|
| 94 |
+
cpu_frame_ids: Dict[int, List[int]] = field(default_factory=dict)
|
| 95 |
+
store_on_cpu: bool = True
|
| 96 |
+
detach_stored: bool = True
|
| 97 |
+
default_token_source: str = "frontend"
|
| 98 |
+
default_descriptor_source: str = "frontend"
|
| 99 |
+
default_writeback_token_source: str = "frontend"
|
| 100 |
+
default_writeback_descriptor_source: str = "frontend"
|
| 101 |
+
|
| 102 |
+
# ── convenience ──────────────────────────────────────
|
| 103 |
+
@property
|
| 104 |
+
def cpu_token_buffer(self) -> Dict[int, torch.Tensor]:
|
| 105 |
+
return self._get_token_buffer(self.default_token_source)
|
| 106 |
+
|
| 107 |
+
@property
|
| 108 |
+
def cpu_descriptor_buffer(self) -> Dict[int, torch.Tensor]:
|
| 109 |
+
return self._get_descriptor_buffer(self.default_descriptor_source)
|
| 110 |
+
|
| 111 |
+
@property
|
| 112 |
+
def num_submaps(self) -> int:
|
| 113 |
+
return len(self.cpu_descriptor_buffer)
|
| 114 |
+
|
| 115 |
+
def _get_token_buffer(self, source: str) -> Dict[int, torch.Tensor]:
|
| 116 |
+
if source == "frontend":
|
| 117 |
+
return self.cpu_frontend_token_buffer
|
| 118 |
+
if source == "backend":
|
| 119 |
+
return self.cpu_backend_token_buffer
|
| 120 |
+
raise ValueError(f"Unsupported token source: {source}")
|
| 121 |
+
|
| 122 |
+
def _get_descriptor_buffer(self, source: str) -> Dict[int, torch.Tensor]:
|
| 123 |
+
if source == "frontend":
|
| 124 |
+
return self.cpu_frontend_descriptor_buffer
|
| 125 |
+
if source == "backend":
|
| 126 |
+
return self.cpu_backend_descriptor_buffer
|
| 127 |
+
raise ValueError(f"Unsupported descriptor source: {source}")
|
| 128 |
+
|
| 129 |
+
def _resolve_token_source(self, source: Optional[str], for_writeback: bool = False) -> str:
|
| 130 |
+
if source is not None:
|
| 131 |
+
return source
|
| 132 |
+
return self.default_writeback_token_source if for_writeback else self.default_token_source
|
| 133 |
+
|
| 134 |
+
def _resolve_descriptor_source(self, source: Optional[str], for_writeback: bool = False) -> str:
|
| 135 |
+
if source is not None:
|
| 136 |
+
return source
|
| 137 |
+
return self.default_writeback_descriptor_source if for_writeback else self.default_descriptor_source
|
| 138 |
+
|
| 139 |
+
def _prepare_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
|
| 140 |
+
if self.detach_stored:
|
| 141 |
+
tensor = tensor.detach()
|
| 142 |
+
if self.store_on_cpu:
|
| 143 |
+
tensor = tensor.cpu()
|
| 144 |
+
return tensor
|
| 145 |
+
|
| 146 |
+
def _move_to_device(self, tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
|
| 147 |
+
if tensor.device == device:
|
| 148 |
+
return tensor
|
| 149 |
+
return tensor.to(device, non_blocking=self.store_on_cpu)
|
| 150 |
+
|
| 151 |
+
def store(
|
| 152 |
+
self,
|
| 153 |
+
submap_id: int,
|
| 154 |
+
frame_ids: List[int],
|
| 155 |
+
frontend_tokens: Optional[torch.Tensor] = None,
|
| 156 |
+
frontend_descriptor: Optional[torch.Tensor] = None,
|
| 157 |
+
backend_tokens: Optional[torch.Tensor] = None,
|
| 158 |
+
backend_descriptor: Optional[torch.Tensor] = None,
|
| 159 |
+
):
|
| 160 |
+
"""Store a completed submap in the configured history banks."""
|
| 161 |
+
if frontend_tokens is not None:
|
| 162 |
+
self.cpu_frontend_token_buffer[submap_id] = self._prepare_tensor(frontend_tokens)
|
| 163 |
+
if frontend_descriptor is not None:
|
| 164 |
+
self.cpu_frontend_descriptor_buffer[submap_id] = self._prepare_tensor(frontend_descriptor)
|
| 165 |
+
if backend_tokens is not None:
|
| 166 |
+
self.cpu_backend_token_buffer[submap_id] = self._prepare_tensor(backend_tokens)
|
| 167 |
+
if backend_descriptor is not None:
|
| 168 |
+
self.cpu_backend_descriptor_buffer[submap_id] = self._prepare_tensor(backend_descriptor)
|
| 169 |
+
self.cpu_frame_ids[submap_id] = list(frame_ids)
|
| 170 |
+
|
| 171 |
+
def fetch_tokens(
|
| 172 |
+
self,
|
| 173 |
+
submap_id: int,
|
| 174 |
+
device: torch.device,
|
| 175 |
+
source: Optional[str] = None,
|
| 176 |
+
) -> torch.Tensor:
|
| 177 |
+
"""Move a submap's tokens to *device*."""
|
| 178 |
+
source = self._resolve_token_source(source)
|
| 179 |
+
return self._move_to_device(self._get_token_buffer(source)[submap_id], device)
|
| 180 |
+
|
| 181 |
+
def fetch_frame_ids(self, submap_id: int) -> List[int]:
|
| 182 |
+
return self.cpu_frame_ids[submap_id]
|
| 183 |
+
|
| 184 |
+
def get_all_descriptors(
    self,
    device: torch.device,
    source: Optional[str] = None,
) -> torch.Tensor:
    """Stack every stored descriptor, ordered by ascending submap id.

    Returns the stacked descriptors on *device*.  When the selected bank is
    empty the result is ``torch.empty(0, device=device)``, i.e. a 1-D tensor
    of length zero rather than a 2-D matrix.
    """
    resolved_source = self._resolve_descriptor_source(source)
    bank = self._get_descriptor_buffer(resolved_source)
    if bank:
        # Row i of the result belongs to the i-th smallest submap id, which
        # is the ordering ``id_at_index`` relies on.
        ordered = [bank[sid] for sid in sorted(bank.keys())]
        return self._move_to_device(torch.stack(ordered), device)
    return torch.empty(0, device=device)
|
| 197 |
+
|
| 198 |
+
def id_at_index(self, index: int, source: Optional[str] = None) -> int:
    """Translate a row index of the descriptor matrix into its submap id.

    Rows follow ascending submap id, so the answer is the *index*-th
    smallest key of the selected descriptor bank.
    """
    resolved_source = self._resolve_descriptor_source(source)
    ordered_ids = sorted(self._get_descriptor_buffer(resolved_source).keys())
    return ordered_ids[index]
|
| 202 |
+
|
| 203 |
+
def update_descriptor(
    self,
    submap_id: int,
    descriptor: torch.Tensor,
    source: Optional[str] = None,
):
    """Overwrite the descriptor stored for *submap_id*.

    The target bank is chosen by ``_resolve_descriptor_source`` in
    write-back mode, and the new value goes through ``_prepare_tensor``
    before it is stored.
    """
    target_source = self._resolve_descriptor_source(source, for_writeback=True)
    target_bank = self._get_descriptor_buffer(target_source)
    target_bank[submap_id] = self._prepare_tensor(descriptor)
|
| 211 |
+
|
| 212 |
+
def update_tokens(
    self,
    submap_id: int,
    tokens: torch.Tensor,
    source: Optional[str] = None,
):
    """Write refined tokens for *submap_id* back into the history bank.

    The target bank is chosen by ``_resolve_token_source`` in write-back
    mode, and the tokens go through ``_prepare_tensor`` before storage.
    """
    target_source = self._resolve_token_source(source, for_writeback=True)
    target_bank = self._get_token_buffer(target_source)
    target_bank[submap_id] = self._prepare_tensor(tokens)
|
| 221 |
+
|
| 222 |
+
def detach_all(self):
|
| 223 |
+
for buffer_dict in (
|
| 224 |
+
self.cpu_frontend_token_buffer,
|
| 225 |
+
self.cpu_backend_token_buffer,
|
| 226 |
+
self.cpu_frontend_descriptor_buffer,
|
| 227 |
+
self.cpu_backend_descriptor_buffer,
|
| 228 |
+
):
|
| 229 |
+
for sid in list(buffer_dict.keys()):
|
| 230 |
+
buffer_dict[sid] = buffer_dict[sid].detach()
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
# ═══════════════════════════════════════════════════════════
|
| 234 |
+
# TemporalEmbedWrapper — Dual injection (Fix #7)
|
| 235 |
+
# ═══════════════════════════════════════════════════════════
|
| 236 |
+
|
| 237 |
+
class TemporalEmbedWrapper(nn.Module):
    """Dual temporal embedding: input injection + output injection.

    Phase 1 (input):  Add t_emb to hidden_F[:,:,:C] BEFORE backendT.
                      → Temporal info participates in all 36 layers of attention.
    Phase 2 (output): Add projected t_emb to BOTH halves of hidden_B AFTER
                      backendT.
                      → Guarantees Layer 35 (geometry) and Layer 36 (semantics) both
                        carry temporal information for downstream heads.

    Args:
        embed_dim: feature dimension C (default 1024).
        max_frames: maximum temporal index supported.
        mode: 'learned' | 'sinusoidal'.

    Raises:
        ValueError: if *mode* is neither 'learned' nor 'sinusoidal'.
    """

    def __init__(self, embed_dim: int = 1024, max_frames: int = 2000,
                 mode: str = "learned"):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_frames = max_frames
        self.mode = mode

        if mode == "learned":
            self.temporal_embed = nn.Embedding(max_frames, embed_dim)
        elif mode == "sinusoidal":
            pe = _build_sinusoidal_pe(max_frames, embed_dim)
            self.register_buffer("temporal_embed_fixed", pe)
        else:
            raise ValueError(f"Unknown temporal embed mode: {mode}")

        # Separate projections for the two output halves
        self.output_proj_layer35 = nn.Linear(embed_dim, embed_dim, bias=False)
        self.output_proj_layer36 = nn.Linear(embed_dim, embed_dim, bias=False)

        # Learnable gates — initialised at 0 (sigmoid → 0.5 at start)
        self.gate_layer35 = nn.Parameter(torch.zeros(1))
        self.gate_layer36 = nn.Parameter(torch.zeros(1))

    # ── core ─────────────────────────────────────────────
    def get_temporal_embed(self, frame_ids: torch.Tensor) -> torch.Tensor:
        """Return [N, C] temporal embedding for given frame indices.

        Indices are clamped into ``[0, max_frames - 1]`` before the lookup.
        Without the lower bound a negative id would raise inside
        ``nn.Embedding`` in 'learned' mode and would silently wrap to the end
        of the table in 'sinusoidal' mode.
        """
        frame_ids = frame_ids.clamp(min=0, max=self.max_frames - 1)
        if self.mode == "learned":
            return self.temporal_embed(frame_ids)
        return self.temporal_embed_fixed[frame_ids]

    def inject_input(self, hidden_F: torch.Tensor,
                     frame_ids: torch.Tensor) -> torch.Tensor:
        """Phase 1: add temporal embedding to first C dims of hidden_F.

        Args:
            hidden_F: [N, P, 2C] frontend token map.
            frame_ids: [N] long tensor of temporal frame indices.
        Returns:
            A copy of hidden_F with temporal embedding added to [:, :, :C];
            the input tensor is not modified.
        """
        # The 3-way unpack also rejects inputs that are not rank 3.
        _, _, C2 = hidden_F.shape
        C = C2 // 2
        t_emb = self.get_temporal_embed(frame_ids)  # [N, C]
        # Clone so the in-place slice write never touches the caller's tensor.
        hidden_F = hidden_F.clone()
        hidden_F[:, :, :C] = hidden_F[:, :, :C] + t_emb.unsqueeze(1)
        return hidden_F

    def inject_output(self, hidden_B: torch.Tensor,
                      frame_ids: torch.Tensor) -> torch.Tensor:
        """Phase 2: add projected temporal embedding to BOTH halves of hidden_B.

        Args:
            hidden_B: [N, P, 2C] backend output (Layer35 || Layer36).
            frame_ids: [N] long tensor.
        Returns:
            A copy of hidden_B with gated temporal embedding added to both
            halves; the input tensor is not modified.
        """
        # The 3-way unpack also rejects inputs that are not rank 3.
        _, _, C2 = hidden_B.shape
        C = C2 // 2
        t_emb = self.get_temporal_embed(frame_ids)  # [N, C]

        t_emb_35 = self.output_proj_layer35(t_emb)  # [N, C]
        t_emb_36 = self.output_proj_layer36(t_emb)  # [N, C]

        # Scalar gates in (0, 1) scale how much temporal signal each half gets.
        g35 = torch.sigmoid(self.gate_layer35)
        g36 = torch.sigmoid(self.gate_layer36)

        hidden_B = hidden_B.clone()
        hidden_B[:, :, :C] = hidden_B[:, :, :C] + g35 * t_emb_35.unsqueeze(1)
        hidden_B[:, :, C:] = hidden_B[:, :, C:] + g36 * t_emb_36.unsqueeze(1)
        return hidden_B
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
# ═══════════════════════════════════════════════════════════
|
| 329 |
+
# GraphGatedMemoryManager
|
| 330 |
+
# ═══════════════════════════════════════════════════════════
|
| 331 |
+
|
| 332 |
+
class GraphGatedMemoryManager(nn.Module):
|
| 333 |
+
"""Sparse Windowed & Recursively Retrieved Submap Backend.
|
| 334 |
+
|
| 335 |
+
Manages:
|
| 336 |
+
* GPU active workspace (S_prev + S_curr, up to 2K frames).
|
| 337 |
+
* Differentiable loop closure with a [NO_LOOP] dummy descriptor.
|
| 338 |
+
* Recursive covisibility fetching via an adjacency graph.
|
| 339 |
+
* CPU ↔ GPU memory offloading.
|
| 340 |
+
|
| 341 |
+
All operations are designed to be DDP-safe (Fix #4): the full
|
| 342 |
+
descriptor / retrieval / backendT path is always executed for every
|
| 343 |
+
batch element, gated by a differentiable multiplier.
|
| 344 |
+
|
| 345 |
+
Args:
|
| 346 |
+
submap_size: K — number of frames per submap.
|
| 347 |
+
max_recursive_submaps: cap on historical submaps fetched at once.
|
| 348 |
+
desc_dim: global descriptor dimension (default 128).
|
| 349 |
+
embed_dim: token feature dimension C (default 1024).
|
| 350 |
+
gumbel_tau: initial Gumbel-Softmax temperature.
|
| 351 |
+
loop_mask_mode: "hard_top1" or "soft_all".
|
| 352 |
+
soft_mask_temperature: temperature for soft_all mode.
|
| 353 |
+
soft_mask_bias: bias for soft_all mode.
|
| 354 |
+
"""
|
| 355 |
+
|
| 356 |
+
def __init__(
|
| 357 |
+
self,
|
| 358 |
+
submap_size: int = 10,
|
| 359 |
+
max_recursive_submaps: int = 5,
|
| 360 |
+
desc_dim: int = 128,
|
| 361 |
+
embed_dim: int = 1024,
|
| 362 |
+
gumbel_tau: float = 1.0,
|
| 363 |
+
loop_mask_mode: str = "hard_top1",
|
| 364 |
+
soft_mask_temperature: float = 0.25,
|
| 365 |
+
soft_mask_bias: float = 0.2,
|
| 366 |
+
retain_history_grad: bool = False,
|
| 367 |
+
submap_train_mode: str = "full_token",
|
| 368 |
+
submap_retrieval_topk: int = 5,
|
| 369 |
+
submap_fetch_source: str = "frontend",
|
| 370 |
+
submap_descriptor_source: str = "frontend",
|
| 371 |
+
):
|
| 372 |
+
super().__init__()
|
| 373 |
+
self.K = submap_size
|
| 374 |
+
self.max_recursive = max_recursive_submaps
|
| 375 |
+
self.desc_dim = desc_dim
|
| 376 |
+
self.embed_dim = embed_dim
|
| 377 |
+
self.gumbel_tau = gumbel_tau
|
| 378 |
+
self.loop_mask_mode = loop_mask_mode
|
| 379 |
+
self.soft_mask_temperature = soft_mask_temperature
|
| 380 |
+
self.soft_mask_bias = soft_mask_bias
|
| 381 |
+
self.retain_history_grad = retain_history_grad
|
| 382 |
+
self.submap_train_mode = submap_train_mode
|
| 383 |
+
self.submap_retrieval_topk = int(submap_retrieval_topk)
|
| 384 |
+
self.submap_fetch_source = submap_fetch_source
|
| 385 |
+
self.submap_descriptor_source = submap_descriptor_source
|
| 386 |
+
|
| 387 |
+
valid_modes = {"full_token", "top5_dual_queue"}
|
| 388 |
+
if self.submap_train_mode not in valid_modes:
|
| 389 |
+
raise ValueError(
|
| 390 |
+
f"Unsupported submap_train_mode: {self.submap_train_mode}. "
|
| 391 |
+
f"Expected one of {sorted(valid_modes)}."
|
| 392 |
+
)
|
| 393 |
+
valid_sources = {"frontend", "backend"}
|
| 394 |
+
if self.submap_fetch_source not in valid_sources:
|
| 395 |
+
raise ValueError(f"Unsupported submap_fetch_source: {self.submap_fetch_source}")
|
| 396 |
+
if self.submap_descriptor_source not in valid_sources:
|
| 397 |
+
raise ValueError(f"Unsupported submap_descriptor_source: {self.submap_descriptor_source}")
|
| 398 |
+
|
| 399 |
+
self.use_dual_queue = self.submap_train_mode == "top5_dual_queue"
|
| 400 |
+
if self.submap_retrieval_topk <= 0:
|
| 401 |
+
self.submap_retrieval_topk = 5 if self.use_dual_queue else 0
|
| 402 |
+
|
| 403 |
+
# ── learnable parameters ─────────────────────────
|
| 404 |
+
# [NO_LOOP] dummy descriptor at index 0
|
| 405 |
+
self.no_loop_descriptor = nn.Parameter(
|
| 406 |
+
torch.randn(1, desc_dim) * 0.02
|
| 407 |
+
)
|
| 408 |
+
# Project pooled tokens → global descriptor
|
| 409 |
+
self.desc_proj = nn.Linear(2 * embed_dim, desc_dim)
|
| 410 |
+
|
| 411 |
+
# ── non-parameter state ──────────────────────────
|
| 412 |
+
self.buffer = self._build_buffer()
|
| 413 |
+
self.adjacency: Dict[int, Set[int]] = {}
|
| 414 |
+
|
| 415 |
+
# Frame-level accumulation for the *current* submap
|
| 416 |
+
self._curr_tokens: List[torch.Tensor] = [] # list of [1, P, 2C]
|
| 417 |
+
self._curr_frame_ids: List[int] = []
|
| 418 |
+
# Tokens for the *previous* submap (kept on GPU for sliding window)
|
| 419 |
+
self._prev_tokens: Optional[torch.Tensor] = None # [K, P, 2C]
|
| 420 |
+
self._prev_frame_ids: List[int] = []
|
| 421 |
+
|
| 422 |
+
self._current_submap_id: int = 0
|
| 423 |
+
self._global_frame_counter: int = 0
|
| 424 |
+
|
| 425 |
+
def _build_buffer(self) -> SubMapBuffer:
|
| 426 |
+
return SubMapBuffer(
|
| 427 |
+
store_on_cpu=not self.retain_history_grad,
|
| 428 |
+
detach_stored=not self.retain_history_grad,
|
| 429 |
+
default_token_source=self.submap_fetch_source,
|
| 430 |
+
default_descriptor_source=self.submap_descriptor_source,
|
| 431 |
+
default_writeback_token_source=(
|
| 432 |
+
"backend" if self.use_dual_queue else self.submap_fetch_source
|
| 433 |
+
),
|
| 434 |
+
default_writeback_descriptor_source=(
|
| 435 |
+
"backend"
|
| 436 |
+
if (self.use_dual_queue or self.submap_descriptor_source == "backend")
|
| 437 |
+
else self.submap_descriptor_source
|
| 438 |
+
),
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
# ── properties ───────────────────────────────────────
|
| 442 |
+
@property
|
| 443 |
+
def current_submap_id(self) -> int:
|
| 444 |
+
return self._current_submap_id
|
| 445 |
+
|
| 446 |
+
@property
|
| 447 |
+
def submap_complete(self) -> bool:
|
| 448 |
+
return len(self._curr_tokens) >= self.K
|
| 449 |
+
|
| 450 |
+
# ── accumulate ───────────────────────────────────────
|
| 451 |
+
def accumulate(self, frame_token: torch.Tensor, frame_id: Optional[int] = None):
|
| 452 |
+
"""Append a single frame's token to the current submap.
|
| 453 |
+
|
| 454 |
+
Args:
|
| 455 |
+
frame_token: [1, P, 2C] or [P, 2C] — output of frontendT.
|
| 456 |
+
frame_id: original temporal frame index (auto-incremented if None).
|
| 457 |
+
"""
|
| 458 |
+
if frame_token.dim() == 2:
|
| 459 |
+
frame_token = frame_token.unsqueeze(0)
|
| 460 |
+
self._curr_tokens.append(frame_token)
|
| 461 |
+
fid = frame_id if frame_id is not None else self._global_frame_counter
|
| 462 |
+
self._curr_frame_ids.append(fid)
|
| 463 |
+
self._global_frame_counter += 1
|
| 464 |
+
|
| 465 |
+
# ── descriptor computation ───────────────────────────
|
| 466 |
+
def compute_descriptor(self, tokens: torch.Tensor) -> torch.Tensor:
|
| 467 |
+
"""Pool submap tokens → global descriptor.
|
| 468 |
+
|
| 469 |
+
Args:
|
| 470 |
+
tokens: [K, P, 2C] on GPU.
|
| 471 |
+
Returns:
|
| 472 |
+
descriptor: [1, desc_dim].
|
| 473 |
+
"""
|
| 474 |
+
# Mean-pool over frames and patches → [1, 2C]
|
| 475 |
+
device_type = tokens.device.type if tokens.is_cuda else "cpu"
|
| 476 |
+
with torch.amp.autocast(device_type=device_type, enabled=False):
|
| 477 |
+
pooled = tokens.float().mean(dim=(0, 1), keepdim=False).unsqueeze(0)
|
| 478 |
+
pooled = torch.nan_to_num(pooled, nan=0.0, posinf=0.0, neginf=0.0)
|
| 479 |
+
desc = self.desc_proj(pooled.to(dtype=self.desc_proj.weight.dtype)).float()
|
| 480 |
+
return torch.nan_to_num(desc, nan=0.0, posinf=0.0, neginf=0.0)
|
| 481 |
+
|
| 482 |
+
# ── loop retrieval (differentiable, DDP-safe) ────────
|
| 483 |
+
def retrieve(
|
| 484 |
+
self,
|
| 485 |
+
curr_desc: torch.Tensor,
|
| 486 |
+
device: torch.device,
|
| 487 |
+
) -> Tuple[torch.Tensor, torch.Tensor, List[int], int, int, List[int], List[int], Optional[torch.Tensor]]:
|
| 488 |
+
"""Differentiable loop closure retrieval (Fix B2: no Python branching).
|
| 489 |
+
|
| 490 |
+
ALL code paths always fetch tokens so every GPU executes the same
|
| 491 |
+
compute graph. The gate multiplier controls whether retrieved
|
| 492 |
+
tokens contribute to the output.
|
| 493 |
+
|
| 494 |
+
Returns:
|
| 495 |
+
gate: [1, 1] — differentiable: 0 = no loop, 1 = loop.
|
| 496 |
+
retrieved_tokens: [R, P, 2C] on GPU (R may be 0 if no history).
|
| 497 |
+
retrieved_fids: list[int] matching dim-0 of retrieved_tokens.
|
| 498 |
+
n_valid_retrieved: int — how many tokens are real.
|
| 499 |
+
primary_sid: submap id of the primary retrieved submap (-1 if none).
|
| 500 |
+
fetch_ids: list[int] — submap IDs that were fetched (for write-back).
|
| 501 |
+
fetch_token_counts: list[int] — number of tokens per fetched submap.
|
| 502 |
+
retrieval_weights: [R] weights for soft_all mode.
|
| 503 |
+
"""
|
| 504 |
+
num_hist = self.buffer.num_submaps
|
| 505 |
+
|
| 506 |
+
if num_hist == 0:
|
| 507 |
+
# No history at all — return empty (no padding waste: A5 fix)
|
| 508 |
+
P = self._curr_tokens[0].shape[1] if self._curr_tokens else 1
|
| 509 |
+
C2 = self._curr_tokens[0].shape[2] if self._curr_tokens else 2 * self.embed_dim
|
| 510 |
+
gate = torch.zeros(1, 1, device=device)
|
| 511 |
+
return gate, torch.empty(0, P, C2, device=device), [], 0, -1, [], [], None
|
| 512 |
+
|
| 513 |
+
if self.loop_mask_mode == "soft_all":
|
| 514 |
+
hist_ids = sorted(self.buffer.cpu_descriptor_buffer.keys())
|
| 515 |
+
prev_sid = self._current_submap_id - 1 if self._current_submap_id > 0 else None
|
| 516 |
+
fetch_ids = [sid for sid in hist_ids if sid != prev_sid]
|
| 517 |
+
P = self._curr_tokens[0].shape[1]
|
| 518 |
+
C2 = self._curr_tokens[0].shape[2]
|
| 519 |
+
if not fetch_ids:
|
| 520 |
+
gate = torch.zeros(1, 1, device=device)
|
| 521 |
+
return gate, torch.empty(0, P, C2, device=device), [], 0, -1, [], [], None
|
| 522 |
+
|
| 523 |
+
hist_descs = torch.stack(
|
| 524 |
+
[self.buffer.cpu_descriptor_buffer[sid] for sid in fetch_ids]
|
| 525 |
+
).to(device, non_blocking=True).float()
|
| 526 |
+
curr_desc_safe = torch.nan_to_num(curr_desc.float(), nan=0.0, posinf=0.0, neginf=0.0)
|
| 527 |
+
hist_descs = torch.nan_to_num(hist_descs, nan=0.0, posinf=0.0, neginf=0.0)
|
| 528 |
+
sim = F.cosine_similarity(
|
| 529 |
+
curr_desc_safe.unsqueeze(1),
|
| 530 |
+
hist_descs.unsqueeze(0),
|
| 531 |
+
dim=-1,
|
| 532 |
+
)
|
| 533 |
+
sim = torch.nan_to_num(sim, nan=-1.0, posinf=1.0, neginf=-1.0).clamp(min=-1.0, max=1.0)
|
| 534 |
+
|
| 535 |
+
selected_fetch_ids = list(fetch_ids)
|
| 536 |
+
selected_sim = sim
|
| 537 |
+
if self.submap_retrieval_topk > 0 and len(fetch_ids) > self.submap_retrieval_topk:
|
| 538 |
+
topk = min(self.submap_retrieval_topk, len(fetch_ids))
|
| 539 |
+
top_scores, top_indices = torch.topk(sim.squeeze(0), k=topk, dim=-1)
|
| 540 |
+
if topk == 1:
|
| 541 |
+
selected_index_list = [int(top_indices.item())]
|
| 542 |
+
else:
|
| 543 |
+
selected_index_list = [int(idx) for idx in top_indices.tolist()]
|
| 544 |
+
selected_fetch_ids = [fetch_ids[idx] for idx in selected_index_list]
|
| 545 |
+
selected_sim = top_scores.unsqueeze(0)
|
| 546 |
+
|
| 547 |
+
tau = max(float(self.soft_mask_temperature), 1e-6)
|
| 548 |
+
weights = torch.sigmoid((selected_sim - self.soft_mask_bias) / tau)
|
| 549 |
+
weights = torch.nan_to_num(weights, nan=0.0, posinf=1.0, neginf=0.0).clamp(min=0.0, max=1.0)
|
| 550 |
+
gate = weights.max(dim=-1, keepdim=True).values
|
| 551 |
+
selected_idx = int(selected_sim.argmax(dim=-1).item())
|
| 552 |
+
primary_submap_id = selected_fetch_ids[selected_idx]
|
| 553 |
+
|
| 554 |
+
retrieved_list: List[torch.Tensor] = []
|
| 555 |
+
retrieved_fids: List[int] = []
|
| 556 |
+
fetch_token_counts: List[int] = []
|
| 557 |
+
for sid in selected_fetch_ids:
|
| 558 |
+
t = self.buffer.fetch_tokens(sid, device)
|
| 559 |
+
retrieved_list.append(t)
|
| 560 |
+
fetch_token_counts.append(t.shape[0])
|
| 561 |
+
retrieved_fids.extend(self.buffer.fetch_frame_ids(sid))
|
| 562 |
+
|
| 563 |
+
retrieved = torch.cat(retrieved_list, dim=0) if retrieved_list else torch.empty(0, P, C2, device=device)
|
| 564 |
+
n_valid = retrieved.shape[0]
|
| 565 |
+
return gate, retrieved, retrieved_fids, n_valid, primary_submap_id, selected_fetch_ids, fetch_token_counts, weights.squeeze(0)
|
| 566 |
+
|
| 567 |
+
# Build descriptor bank: [NO_LOOP] + all historical
|
| 568 |
+
hist_descs = torch.nan_to_num(
|
| 569 |
+
self.buffer.get_all_descriptors(device).float(),
|
| 570 |
+
nan=0.0,
|
| 571 |
+
posinf=0.0,
|
| 572 |
+
neginf=0.0,
|
| 573 |
+
) # [H, D]
|
| 574 |
+
bank = torch.cat([self.no_loop_descriptor.float(), hist_descs], dim=0) # [H+1, D]
|
| 575 |
+
curr_desc_safe = torch.nan_to_num(curr_desc.float(), nan=0.0, posinf=0.0, neginf=0.0)
|
| 576 |
+
|
| 577 |
+
# Cosine similarity → Gumbel-Softmax
|
| 578 |
+
sim = F.cosine_similarity(
|
| 579 |
+
curr_desc_safe.unsqueeze(1), # [1, 1, D]
|
| 580 |
+
bank.unsqueeze(0), # [1, H+1, D]
|
| 581 |
+
dim=-1,
|
| 582 |
+
) # [1, H+1]
|
| 583 |
+
sim = torch.nan_to_num(sim, nan=-1.0, posinf=1.0, neginf=-1.0).clamp(min=-1.0, max=1.0)
|
| 584 |
+
selection = F.gumbel_softmax(sim, tau=max(float(self.gumbel_tau), 1e-6), hard=True, dim=-1)
|
| 585 |
+
selection = torch.nan_to_num(selection, nan=0.0, posinf=1.0, neginf=0.0)
|
| 586 |
+
|
| 587 |
+
selected_idx = selection.argmax(dim=-1).item() # int
|
| 588 |
+
|
| 589 |
+
# Gate: differentiable sum of non-NO_LOOP probabilities
|
| 590 |
+
gate = selection[:, 1:].sum(dim=-1, keepdim=True).clamp(min=0.0, max=1.0) # [1, 1]
|
| 591 |
+
|
| 592 |
+
# ── ALWAYS fetch primary + recursive neighbours (Fix B2) ──
|
| 593 |
+
# When NO_LOOP is selected (idx 0), we still fetch the *best*
|
| 594 |
+
# historical submap's tokens so the compute graph is identical
|
| 595 |
+
# across GPUs; the gate multiplier zeros them out.
|
| 596 |
+
P = self._curr_tokens[0].shape[1]
|
| 597 |
+
C2 = self._curr_tokens[0].shape[2]
|
| 598 |
+
retrieved_list: List[torch.Tensor] = []
|
| 599 |
+
retrieved_fids: List[int] = []
|
| 600 |
+
fetch_token_counts: List[int] = []
|
| 601 |
+
|
| 602 |
+
# Determine which submap to fetch (always pick one)
|
| 603 |
+
if selected_idx > 0:
|
| 604 |
+
primary_submap_id = self.buffer.id_at_index(selected_idx - 1)
|
| 605 |
+
else:
|
| 606 |
+
# NO_LOOP selected — still fetch the highest-similarity submap
|
| 607 |
+
# so the compute graph is the same across GPUs
|
| 608 |
+
non_noloop_sim = sim[0, 1:] # [H]
|
| 609 |
+
fallback_idx = non_noloop_sim.argmax().item()
|
| 610 |
+
primary_submap_id = self.buffer.id_at_index(fallback_idx)
|
| 611 |
+
|
| 612 |
+
fetch_ids = [primary_submap_id]
|
| 613 |
+
neighbours = self.adjacency.get(primary_submap_id, set())
|
| 614 |
+
for nid in sorted(neighbours):
|
| 615 |
+
if len(fetch_ids) >= self.max_recursive:
|
| 616 |
+
break
|
| 617 |
+
if nid != self._current_submap_id and nid in self.buffer.cpu_token_buffer:
|
| 618 |
+
fetch_ids.append(nid)
|
| 619 |
+
|
| 620 |
+
for sid in fetch_ids:
|
| 621 |
+
t = self.buffer.fetch_tokens(sid, device)
|
| 622 |
+
retrieved_list.append(t)
|
| 623 |
+
fetch_token_counts.append(t.shape[0])
|
| 624 |
+
retrieved_fids.extend(self.buffer.fetch_frame_ids(sid))
|
| 625 |
+
|
| 626 |
+
# No fixed-size padding (Fix A5): return actual tokens only
|
| 627 |
+
if retrieved_list:
|
| 628 |
+
retrieved = torch.cat(retrieved_list, dim=0) # [R, P, C2]
|
| 629 |
+
else:
|
| 630 |
+
retrieved = torch.empty(0, P, C2, device=device)
|
| 631 |
+
|
| 632 |
+
n_valid = retrieved.shape[0]
|
| 633 |
+
return gate, retrieved, retrieved_fids, n_valid, primary_submap_id, fetch_ids, fetch_token_counts, None
|
| 634 |
+
|
| 635 |
+
# ── finalize submap ──────────────────────────────────
|
| 636 |
+
def finalize_submap(
|
| 637 |
+
self,
|
| 638 |
+
model,
|
| 639 |
+
device: torch.device,
|
| 640 |
+
temporal_wrapper: Optional[TemporalEmbedWrapper] = None,
|
| 641 |
+
enable_temporal_embed: bool = False,
|
| 642 |
+
enable_loop_closure: bool = False,
|
| 643 |
+
tbptt_window: int = 10,
|
| 644 |
+
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
| 645 |
+
"""Finalize the current submap: pool, retrieve, run backend, slide window.
|
| 646 |
+
|
| 647 |
+
DDP-safe: always executes the full compute graph (Fix #4).
|
| 648 |
+
A4 fix: tokens within ``tbptt_window`` recent submaps keep gradients;
|
| 649 |
+
older ones are detached to cap memory.
|
| 650 |
+
A5 fix: retrieved tokens are NOT padded to a fixed size; only real
|
| 651 |
+
tokens enter backendT. The gate multiplier zeros out retrieved
|
| 652 |
+
contributions when NO_LOOP is selected.
|
| 653 |
+
|
| 654 |
+
Args:
|
| 655 |
+
model: SLAMFormer instance.
|
| 656 |
+
device: GPU device.
|
| 657 |
+
temporal_wrapper: TemporalEmbedWrapper (or None).
|
| 658 |
+
enable_temporal_embed: whether to inject temporal embeddings.
|
| 659 |
+
enable_loop_closure: whether to attempt differentiable loop retrieval.
|
| 660 |
+
tbptt_window: number of recent submaps whose stored tokens
|
| 661 |
+
keep gradients (A4 fix). Older submaps are
|
| 662 |
+
detached.
|
| 663 |
+
|
| 664 |
+
Returns:
|
| 665 |
+
backend_out: [N_total, P, 2C] — refined tokens from backendT.
|
| 666 |
+
loop_gate: [1, 1] — differentiable gate (0 = no loop, 1 = loop).
|
| 667 |
+
meta: dict with keys 'n_prev', 'n_curr', 'n_retrieved',
|
| 668 |
+
'frame_ids' (full list), 'curr_frame_ids'.
|
| 669 |
+
"""
|
| 670 |
+
# ── 1. Stack current submap tokens ───────────────
|
| 671 |
+
curr_tokens = torch.cat(self._curr_tokens, dim=0).to(device) # [K, P, 2C]
|
| 672 |
+
curr_desc = self.compute_descriptor(curr_tokens) # [1, D]
|
| 673 |
+
|
| 674 |
+
# ── 2. Loop retrieval (always executed, DDP-safe: Fix B2) ─
|
| 675 |
+
if enable_loop_closure:
|
| 676 |
+
loop_gate, retrieved_tokens, retrieved_fids, n_valid_ret, primary_sid, \
|
| 677 |
+
fetch_ids, fetch_token_counts, retrieval_weights = self.retrieve(curr_desc, device)
|
| 678 |
+
else:
|
| 679 |
+
P, C2 = curr_tokens.shape[1], curr_tokens.shape[2]
|
| 680 |
+
loop_gate = torch.zeros(1, 1, device=device)
|
| 681 |
+
retrieved_tokens = torch.empty(0, P, C2, device=device)
|
| 682 |
+
retrieved_fids = []
|
| 683 |
+
n_valid_ret = 0
|
| 684 |
+
primary_sid = -1
|
| 685 |
+
fetch_ids = []
|
| 686 |
+
fetch_token_counts = []
|
| 687 |
+
retrieval_weights = None
|
| 688 |
+
|
| 689 |
+
# ── 3. Build combined token tensor ───────────────
|
| 690 |
+
parts = []
|
| 691 |
+
fid_parts: List[int] = []
|
| 692 |
+
|
| 693 |
+
if self._prev_tokens is not None:
|
| 694 |
+
parts.append(self._prev_tokens.to(device))
|
| 695 |
+
fid_parts.extend(self._prev_frame_ids)
|
| 696 |
+
|
| 697 |
+
parts.append(curr_tokens)
|
| 698 |
+
fid_parts.extend(self._curr_frame_ids)
|
| 699 |
+
|
| 700 |
+
# A5 fix: only append retrieved tokens if there are any (no zero-padding)
|
| 701 |
+
n_retrieved = 0
|
| 702 |
+
if retrieved_tokens.shape[0] > 0:
|
| 703 |
+
if retrieval_weights is not None and len(fetch_token_counts) == len(retrieval_weights):
|
| 704 |
+
gated_chunks = []
|
| 705 |
+
offset = 0
|
| 706 |
+
for weight, count in zip(retrieval_weights, fetch_token_counts):
|
| 707 |
+
gated_chunks.append(
|
| 708 |
+
retrieved_tokens[offset: offset + count] * weight.reshape(1, 1, 1)
|
| 709 |
+
)
|
| 710 |
+
offset += count
|
| 711 |
+
gated_retrieved = torch.cat(gated_chunks, dim=0) if gated_chunks else retrieved_tokens
|
| 712 |
+
else:
|
| 713 |
+
gated_retrieved = retrieved_tokens * loop_gate.unsqueeze(-1)
|
| 714 |
+
parts.append(gated_retrieved)
|
| 715 |
+
fid_parts.extend(retrieved_fids)
|
| 716 |
+
n_retrieved = gated_retrieved.shape[0]
|
| 717 |
+
|
| 718 |
+
combined = torch.cat(parts, dim=0) if parts else curr_tokens # [N_total, P, 2C]
|
| 719 |
+
frame_ids_tensor = torch.tensor(fid_parts, dtype=torch.long, device=device)
|
| 720 |
+
|
| 721 |
+
# ── 4. Temporal embedding — Phase 1 (input) ──────
|
| 722 |
+
if enable_temporal_embed and temporal_wrapper is not None:
|
| 723 |
+
combined = temporal_wrapper.inject_input(combined, frame_ids_tensor)
|
| 724 |
+
|
| 725 |
+
# ── 5. Run backendT (always, DDP-safe) ───────────
|
| 726 |
+
hidden_B = _safe_oom_retry(model.backendT, combined)
|
| 727 |
+
|
| 728 |
+
# ── 6. Temporal embedding — Phase 2 (output) ─────
|
| 729 |
+
if enable_temporal_embed and temporal_wrapper is not None:
|
| 730 |
+
hidden_B = temporal_wrapper.inject_output(hidden_B, frame_ids_tensor)
|
| 731 |
+
|
| 732 |
+
# ── 7. Update adjacency graph (always, no Python branching) ──
|
| 733 |
+
if enable_loop_closure and primary_sid >= 0:
|
| 734 |
+
cid = self._current_submap_id
|
| 735 |
+
self.adjacency.setdefault(cid, set()).add(primary_sid)
|
| 736 |
+
self.adjacency.setdefault(primary_sid, set()).add(cid)
|
| 737 |
+
|
| 738 |
+
# ── 8. Slice refined tokens for each part ────────
|
| 739 |
+
n_prev = self._prev_tokens.shape[0] if self._prev_tokens is not None else 0
|
| 740 |
+
n_curr = curr_tokens.shape[0]
|
| 741 |
+
completed_submap_id = self._current_submap_id
|
| 742 |
+
prev_sid = completed_submap_id - 1 if n_prev > 0 else None
|
| 743 |
+
|
| 744 |
+
curr_backend_tokens = hidden_B[n_prev:n_prev + n_curr]
|
| 745 |
+
should_store_backend_tokens = self.use_dual_queue or self.submap_fetch_source == "backend"
|
| 746 |
+
should_store_backend_desc = self.use_dual_queue or self.submap_descriptor_source == "backend"
|
| 747 |
+
curr_backend_desc = (
|
| 748 |
+
self.compute_descriptor(curr_backend_tokens) if should_store_backend_desc else None
|
| 749 |
+
)
|
| 750 |
+
|
| 751 |
+
if self.use_dual_queue or should_store_backend_tokens or should_store_backend_desc:
|
| 752 |
+
if prev_sid is not None:
|
| 753 |
+
refined_prev = hidden_B[:n_prev]
|
| 754 |
+
if refined_prev.shape[0] > 0:
|
| 755 |
+
if should_store_backend_tokens:
|
| 756 |
+
self.buffer.update_tokens(prev_sid, refined_prev, source="backend")
|
| 757 |
+
if should_store_backend_desc:
|
| 758 |
+
self.buffer.update_descriptor(
|
| 759 |
+
prev_sid,
|
| 760 |
+
self.compute_descriptor(refined_prev).squeeze(0),
|
| 761 |
+
source="backend",
|
| 762 |
+
)
|
| 763 |
+
|
| 764 |
+
offset = n_prev + n_curr
|
| 765 |
+
for sid, count in zip(fetch_ids, fetch_token_counts):
|
| 766 |
+
refined_ret = hidden_B[offset: offset + count]
|
| 767 |
+
if refined_ret.shape[0] > 0:
|
| 768 |
+
if should_store_backend_tokens:
|
| 769 |
+
self.buffer.update_tokens(sid, refined_ret, source="backend")
|
| 770 |
+
if should_store_backend_desc:
|
| 771 |
+
self.buffer.update_descriptor(
|
| 772 |
+
sid,
|
| 773 |
+
self.compute_descriptor(refined_ret).squeeze(0),
|
| 774 |
+
source="backend",
|
| 775 |
+
)
|
| 776 |
+
offset += count
|
| 777 |
+
|
| 778 |
+
# ── 9. Store current submap
|
| 779 |
+
self.buffer.store(
|
| 780 |
+
submap_id=completed_submap_id,
|
| 781 |
+
frame_ids=self._curr_frame_ids,
|
| 782 |
+
frontend_tokens=curr_tokens,
|
| 783 |
+
frontend_descriptor=curr_desc.squeeze(0),
|
| 784 |
+
backend_tokens=curr_backend_tokens if should_store_backend_tokens else None,
|
| 785 |
+
backend_descriptor=(
|
| 786 |
+
curr_backend_desc.squeeze(0) if curr_backend_desc is not None else None
|
| 787 |
+
),
|
| 788 |
+
)
|
| 789 |
+
|
| 790 |
+
# ── 10. Release retrieved tensors ─────────────────
|
| 791 |
+
del retrieved_tokens
|
| 792 |
+
|
| 793 |
+
# ── 11. TBPTT: detach memory states periodically ──
|
| 794 |
+
# When submap_id crosses a tbptt_window boundary, detach ALL
|
| 795 |
+
# stored tokens in the buffer and _prev_tokens. This cuts the
|
| 796 |
+
# backward graph at memory-state level (not loss level), so
|
| 797 |
+
# gradients flow within the window but not across.
|
| 798 |
+
next_id = self._current_submap_id + 1
|
| 799 |
+
should_detach_history = (
|
| 800 |
+
tbptt_window is not None and tbptt_window > 0 and next_id % tbptt_window == 0 and next_id > 0
|
| 801 |
+
)
|
| 802 |
+
if should_detach_history:
|
| 803 |
+
self.buffer.detach_all()
|
| 804 |
+
|
| 805 |
+
# ── 12. Slide window ─────────────────────────────
|
| 806 |
+
next_prev_tokens = curr_tokens
|
| 807 |
+
if self.submap_fetch_source == "backend" and should_store_backend_tokens:
|
| 808 |
+
next_prev_tokens = curr_backend_tokens
|
| 809 |
+
|
| 810 |
+
if self.retain_history_grad and not should_detach_history:
|
| 811 |
+
self._prev_tokens = next_prev_tokens
|
| 812 |
+
elif self.retain_history_grad:
|
| 813 |
+
self._prev_tokens = next_prev_tokens.detach()
|
| 814 |
+
else:
|
| 815 |
+
self._prev_tokens = next_prev_tokens.detach().cpu()
|
| 816 |
+
self._prev_frame_ids = list(self._curr_frame_ids)
|
| 817 |
+
self._curr_tokens = []
|
| 818 |
+
self._curr_frame_ids = []
|
| 819 |
+
self._current_submap_id += 1
|
| 820 |
+
|
| 821 |
+
active_curr_desc = curr_desc
|
| 822 |
+
if self.submap_descriptor_source == "backend" and curr_backend_desc is not None:
|
| 823 |
+
active_curr_desc = curr_backend_desc
|
| 824 |
+
|
| 825 |
+
meta = {
|
| 826 |
+
'submap_id': completed_submap_id,
|
| 827 |
+
'n_prev': n_prev,
|
| 828 |
+
'n_curr': n_curr,
|
| 829 |
+
'n_retrieved': n_retrieved,
|
| 830 |
+
'frame_ids': fid_parts,
|
| 831 |
+
'curr_frame_ids': list(self._prev_frame_ids), # after slide, prev = old curr
|
| 832 |
+
'curr_descriptor': active_curr_desc,
|
| 833 |
+
'curr_frontend_descriptor': curr_desc,
|
| 834 |
+
'curr_backend_descriptor': curr_backend_desc,
|
| 835 |
+
'submap_train_mode': self.submap_train_mode,
|
| 836 |
+
'submap_descriptor_source': self.submap_descriptor_source,
|
| 837 |
+
}
|
| 838 |
+
return hidden_B, loop_gate, meta
|
| 839 |
+
|
| 840 |
+
# ── reset ────────────────────────────────────────────
|
| 841 |
+
def reset(self):
    """Drop all accumulated state so a new sequence can start from scratch.

    The buffer is replaced by a fresh one from ``_build_buffer``; the
    adjacency structure and the frame/token accumulators are emptied in
    place (``.clear()``), the previous-submap tokens are discarded, and
    both counters restart at zero.
    """
    self.buffer = self._build_buffer()
    # These containers are cleared in place rather than rebound.
    for container in (
        self.adjacency,
        self._curr_tokens,
        self._curr_frame_ids,
        self._prev_frame_ids,
    ):
        container.clear()
    self._prev_tokens = None
    self._current_submap_id = 0
    self._global_frame_counter = 0
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/mine_pseudo_gt.py
ADDED
|
@@ -0,0 +1,588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import cv2
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
def canonical_view_key_from_values(dataset: Optional[str], label: Optional[str]) -> Optional[str]:
    """Join a dataset name and a frame label into a ``"<dataset>::<label>"`` key.

    Returns ``None`` when either component is ``None``.
    """
    if any(part is None for part in (dataset, label)):
        return None
    return "{}::{}".format(dataset, label)
|
| 15 |
+
|
| 16 |
+
def load_payload(path: Optional[str]):
    """Read a cache file and return its deserialized content.

    * falsy ``path``  -> ``None``
    * ``*.json``      -> the parsed JSON document
    * ``*.jsonl``     -> ``{"records": [...]}`` with one entry per non-blank line
    * anything else   -> whatever ``torch.load`` returns (tensors mapped to CPU)

    ``~`` in ``path`` is expanded before the extension is inspected.
    """
    if not path:
        return None
    resolved = os.path.expanduser(path)
    if resolved.endswith(".json"):
        with open(resolved, "r", encoding="utf-8") as handle:
            return json.load(handle)
    if resolved.endswith(".jsonl"):
        with open(resolved, "r", encoding="utf-8") as handle:
            stripped_lines = (raw.strip() for raw in handle)
            return {"records": [json.loads(text) for text in stripped_lines if text]}
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects; only point this at cache files from a trusted source.
    return torch.load(resolved, map_location="cpu", weights_only=False)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _scalar(value, default=None):
    """Reduce ``value`` to a single scalar, falling back to ``default``.

    Lists/tuples are unwrapped by repeatedly taking their first element;
    a tensor yields its first element (in flattened order) as a Python
    number. Empty containers, empty tensors and ``None`` give ``default``;
    any other object is returned unchanged.
    """
    # Peel nested sequences down to their leading element.
    while isinstance(value, (list, tuple)):
        if not value:
            return default
        value = value[0]
    if torch.is_tensor(value):
        if value.numel() == 0:
            return default
        return value.detach().reshape(-1)[0].cpu().item()
    return default if value is None else value
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _float(value, default=0.0):
    """Coerce ``value`` to ``float``, using ``default`` when that is impossible.

    ``value`` is first collapsed with ``_scalar`` (which already substitutes
    ``default`` for missing/empty input); if the result still cannot be
    converted, ``float(default)`` is returned instead.
    """
    candidate = _scalar(value, default)
    try:
        converted = float(candidate)
    except (TypeError, ValueError):
        converted = float(default)
    return converted
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _bool(record: Optional[Dict[str, Any]], score_threshold: float = 0.0) -> Optional[bool]:
    """Derive a tri-state verdict (``True`` / ``False`` / ``None``) from a cache record.

    Resolution order:
      1. the first present key among ``is_positive``, ``accepted``, ``match``,
         ``loop`` (its truthiness is the verdict);
      2. a recognised ``tag`` string (case-insensitive, surrounding
         whitespace ignored);
      3. ``score`` (or ``confidence`` as fallback) compared against
         ``score_threshold`` -- only when the threshold is positive.

    ``None`` means "no opinion": the record is missing or none of the
    rules above applied.
    """
    if record is None:
        return None

    explicit_key = next(
        (name for name in ("is_positive", "accepted", "match", "loop") if name in record),
        None,
    )
    if explicit_key is not None:
        return bool(record[explicit_key])

    label = str(record.get("tag", "")).strip().lower()
    if label in {"positive", "pos", "match", "loop", "true", "1"}:
        return True
    if label in {"negative", "neg", "false", "0"}:
        return False

    fallback = _float(record.get("confidence"), 0.0)
    score = _float(record.get("score"), fallback)
    if score_threshold > 0:
        return score >= score_threshold
    return None
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def load_pair_cache(path: Optional[str]) -> Dict[Tuple[str, str], Dict[str, Any]]:
    """Load a pair-level cache file and index its records by an ordered key pair.

    The payload may be a sequence of record dicts, a dict wrapping such a
    sequence under ``"records"`` or ``"pairs"``, or a dict mapping a
    ``"a||b"`` style pair key to each record. For every dict record the two
    view keys are resolved from, in order: ``key_a``/``key_b``,
    ``frame_key_a``/``frame_key_b``, the ``dataset_*``/``label_*`` fields,
    and finally the ``pair_key`` string split on ``"||"``. Records whose keys
    cannot be resolved, or whose two keys are equal, are skipped. Later
    records overwrite earlier ones with the same pair.

    Returns an empty dict when ``path`` is falsy.
    """
    payload = load_payload(path)
    if payload is None:
        return {}

    if isinstance(payload, dict):
        entries = payload.get("records", payload.get("pairs", payload))
    else:
        entries = payload

    if isinstance(entries, dict):
        # Mapping form: fold the mapping key into each record as "pair_key".
        candidates = []
        for mapping_key, body in entries.items():
            if isinstance(body, dict):
                candidates.append({**body, "pair_key": mapping_key})
            else:
                candidates.append(body)
    else:
        candidates = entries

    indexed = {}
    for entry in candidates:
        if not isinstance(entry, dict):
            continue
        first = entry.get("key_a") or entry.get("frame_key_a")
        second = entry.get("key_b") or entry.get("frame_key_b")
        if first is None:
            first = canonical_view_key_from_values(entry.get("dataset_a"), entry.get("label_a"))
        if second is None:
            second = canonical_view_key_from_values(entry.get("dataset_b"), entry.get("label_b"))
        combined = entry.get("pair_key")
        unresolved = first is None or second is None
        if unresolved and isinstance(combined, str) and "||" in combined:
            first, second = combined.split("||", 1)
        if first is None or second is None or first == second:
            continue
        # Store under a direction-independent (sorted) key.
        if first <= second:
            ordered = (first, second)
        else:
            ordered = (second, first)
        indexed[ordered] = entry
    return indexed
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def inverse_pose(pose: np.ndarray) -> np.ndarray:
    """Invert a 4x4 pose in closed form and return it as ``float32``.

    The upper-left 3x3 block is inverted by transposition, so it is treated
    as a rotation; the translation becomes ``-R^T t``. The bottom row of the
    result is always ``[0, 0, 0, 1]``.
    """
    matrix = np.asarray(pose, dtype=np.float32)
    rotation_t = matrix[:3, :3].T
    inverted = np.eye(4, dtype=np.float32)
    inverted[:3, :3] = rotation_t
    inverted[:3, 3] = -rotation_t @ matrix[:3, 3]
    return inverted
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def pose_distance(pose_a: np.ndarray, pose_b: np.ndarray) -> float:
    """Return the Euclidean distance between the translation columns of two poses."""
    offset = np.asarray(pose_a[:3, 3]) - np.asarray(pose_b[:3, 3])
    return float(np.linalg.norm(offset))
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def heading_angle_deg(pose_a: np.ndarray, pose_b: np.ndarray) -> float:
    """Angle in degrees between the third rotation columns of two poses.

    The third column (``pose[:3, 2]``) is used as each pose's forward
    direction. If either vector is (numerically) zero the maximum angle,
    180 degrees, is returned.
    """
    forward_a, forward_b = (
        np.asarray(pose[:3, 2], dtype=np.float32) for pose in (pose_a, pose_b)
    )
    length_a = np.linalg.norm(forward_a)
    length_b = np.linalg.norm(forward_b)
    if length_a <= 1e-8 or length_b <= 1e-8:
        return 180.0
    cosine = np.dot(forward_a / length_a, forward_b / length_b)
    # Clamp so rounding noise cannot push arccos outside its domain.
    cosine = float(np.clip(cosine, -1.0, 1.0))
    return float(np.degrees(np.arccos(cosine)))
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def sample_pixels(depth: np.ndarray, sample_points: int, rng: np.random.Generator) -> Optional[np.ndarray]:
    """Pick up to ``sample_points`` pixels with positive depth.

    Returns an ``(N, 2)`` float32 array of ``(u, v)`` = (column, row)
    coordinates shifted by +0.5 (pixel centres), or ``None`` when the map
    has no positive depth. When more candidates exist than requested, a
    subset is drawn without replacement from ``rng``.
    """
    candidates = np.argwhere(depth > 0)
    available = len(candidates)
    if available == 0:
        return None
    if available > sample_points:
        chosen = rng.choice(available, size=sample_points, replace=False)
        candidates = candidates[chosen]
    # argwhere yields (row, col); swap to (u, v) before moving to pixel centres.
    centres = candidates[:, [1, 0]].astype(np.float32)
    centres += 0.5
    return centres
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def project_points(uv: np.ndarray, depth: np.ndarray, intrinsics: np.ndarray, pose_src: np.ndarray, pose_dst: np.ndarray):
    """Lift source pixels to 3D and express them in the destination camera frame.

    ``uv`` holds pixel-centre coordinates (index + 0.5); the depth of each
    sample is read at the corresponding integer pixel, clamped to the image
    bounds. Points are back-projected with the pinhole ``intrinsics``, moved
    to world space with ``pose_src`` and then into the destination frame
    with the inverse of ``pose_dst``.

    Returns an ``(N, 3)`` array of points in the destination camera frame.
    """
    focal_x = float(intrinsics[0, 0])
    focal_y = float(intrinsics[1, 1])
    centre_x = float(intrinsics[0, 2])
    centre_y = float(intrinsics[1, 2])

    u, v = uv[:, 0], uv[:, 1]
    # Undo the +0.5 pixel-centre shift to recover array indices.
    rows = np.clip((v - 0.5).astype(np.int64), 0, depth.shape[0] - 1)
    cols = np.clip((u - 0.5).astype(np.int64), 0, depth.shape[1] - 1)
    z = depth[rows, cols]

    # Guard against a zero focal length.
    x = (u - centre_x) * z / max(focal_x, 1e-6)
    y = (v - centre_y) * z / max(focal_y, 1e-6)
    camera_points = np.stack([x, y, z], axis=-1)

    world_points = camera_points @ pose_src[:3, :3].T + pose_src[:3, 3]
    dst_from_world = inverse_pose(pose_dst)
    return world_points @ dst_from_world[:3, :3].T + dst_from_world[:3, 3]
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def directed_overlap(
    src_depth: np.ndarray,
    dst_depth: np.ndarray,
    intrinsics_src: np.ndarray,
    intrinsics_dst: np.ndarray,
    pose_src: np.ndarray,
    pose_dst: np.ndarray,
    sample_points: int,
    depth_tolerance_ratio: float,
    rng: np.random.Generator,
) -> Dict[str, float]:
    """Measure how much of the source view is seen, consistently, by the destination view.

    Up to ``sample_points`` valid-depth pixels of the source are projected
    into the destination camera.

    Returns a dict with:
      * ``frustum``    -- fraction of samples in front of the destination
        camera and inside its image bounds;
      * ``depth``      -- fraction of *all* samples whose projected depth
        agrees with the destination depth map within
        ``depth_tolerance_ratio`` (relative to the destination depth);
      * ``consistent`` -- the number of such agreeing samples;
      * ``count``      -- the number of samples drawn (0 when the source has
        no valid depth).
    """
    uv = sample_pixels(src_depth, sample_points, rng)
    if uv is None:
        return {"frustum": 0.0, "depth": 0.0, "consistent": 0.0, "count": 0.0}

    points_dst = project_points(uv, src_depth, intrinsics_src, pose_src, pose_dst)
    z_dst = points_dst[:, 2]
    fx = float(intrinsics_dst[0, 0])
    fy = float(intrinsics_dst[1, 1])
    cx = float(intrinsics_dst[0, 2])
    cy = float(intrinsics_dst[1, 2])

    # Clamp the divisor so points at/behind the camera do not divide by
    # zero; they are excluded by the ``z_dst > 1e-6`` test below anyway.
    safe_z = np.clip(z_dst, 1e-6, None)
    u_dst = fx * points_dst[:, 0] / safe_z + cx
    v_dst = fy * points_dst[:, 1] / safe_z + cy

    height, width = dst_depth.shape[0], dst_depth.shape[1]
    inside = (
        (z_dst > 1e-6)
        & (u_dst >= 0.0)
        & (u_dst < width)
        & (v_dst >= 0.0)
        & (v_dst < height)
    )
    frustum = float(inside.mean())
    sample_count = float(len(uv))
    if not inside.any():
        return {"frustum": frustum, "depth": 0.0, "consistent": 0.0, "count": sample_count}

    # Source samples are pixel centres (index + 0.5), so pixel ``i`` covers
    # [i, i + 1) and the containing pixel is floor(coordinate). Rounding
    # instead would shift the lookup by half a pixel and, for coordinates
    # sitting exactly on a centre, pick a neighbour via half-to-even rounding.
    # ``inside`` already guarantees 0 <= floor(.) <= size - 1.
    u_idx = np.floor(u_dst[inside]).astype(np.int64)
    v_idx = np.floor(v_dst[inside]).astype(np.int64)
    sampled_dst = dst_depth[v_idx, u_idx]

    valid_dst = sampled_dst > 0
    relative_error = np.abs(sampled_dst - z_dst[inside]) / np.clip(sampled_dst, 1e-6, None)
    consistent = valid_dst & (relative_error <= depth_tolerance_ratio)
    consistent_count = consistent.sum()
    return {
        "frustum": frustum,
        "depth": float(consistent_count / max(1, len(uv))),
        "consistent": float(consistent_count),
        "count": sample_count,
    }
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def symmetric_overlap(frame_a, frame_b, depth_a, depth_b, args, rng: np.random.Generator) -> Dict[str, float]:
    """Combine the A->B and B->A directed overlaps of a frame pair.

    Frustum and depth overlaps are averaged over the two directions; the
    consistent-sample and drawn-sample counts are summed and rounded to
    integers. Sampling parameters come from ``args.sample_points`` and
    ``args.depth_tolerance_ratio``.
    """

    def one_way(src, dst, src_depth, dst_depth):
        return directed_overlap(
            src_depth,
            dst_depth,
            src["intrinsics"],
            dst["intrinsics"],
            src["pose"],
            dst["pose"],
            args.sample_points,
            args.depth_tolerance_ratio,
            rng,
        )

    # Order matters: both calls draw from the same random generator.
    forward = one_way(frame_a, frame_b, depth_a, depth_b)
    backward = one_way(frame_b, frame_a, depth_b, depth_a)

    support = forward["consistent"] + backward["consistent"]
    drawn = forward["count"] + backward["count"]
    return {
        "frustum_overlap": 0.5 * (forward["frustum"] + backward["frustum"]),
        "depth_overlap": 0.5 * (forward["depth"] + backward["depth"]),
        "geometric_support_count": int(round(support)),
        "sample_count": int(round(drawn)),
    }
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def load_arkitscenes(root: str, split: str) -> List[Dict[str, Any]]:
    """Index the frames of an ARKitScenes split that have a low-res depth map.

    Scene names are read from ``<root>/<Training|Test>/all_metadata.npz``
    (``"Training"`` when ``split`` is ``"train"``, case-insensitively,
    otherwise ``"Test"``). Scenes without ``new_scene_metadata.npz`` and
    frames without a file under ``lowres_depth`` are skipped, as are scenes
    left with no frames.

    Returns a list of ``{"scene": name, "frames": [...]}`` dicts; each frame
    carries ``dataset``, ``scene``, ``label``, ``pose``, ``intrinsics`` (3x3),
    ``depth_path`` and ``image_path``.
    """
    split_dir = "Training" if split.lower() == "train" else "Test"
    meta_root = Path(root) / split_dir

    # Read the scene list eagerly so the archive handle is closed right away
    # (np.load on an .npz keeps the file open until it is closed).
    with np.load(meta_root / "all_metadata.npz") as all_metadata:
        scene_names = [str(name) for name in all_metadata["scenes"]]

    scenes = []
    for scene_name in scene_names:
        scene_dir = meta_root / scene_name
        meta_path = scene_dir / "new_scene_metadata.npz"
        if not meta_path.is_file():
            continue
        with np.load(meta_path, allow_pickle=True) as meta:
            images = meta["images"]
            intrinsics = meta["intrinsics"]
            trajectories = meta["trajectories"]

        frames = []
        for basename, intri, pose in zip(images, intrinsics, trajectories):
            basename = str(basename)
            depth_path = scene_dir / "lowres_depth" / basename
            if not depth_path.is_file():
                continue
            # Entries 2..5 of the per-frame intrinsics row are used as
            # fx, fy, cx, cy (entries 0..1 are presumably the image size --
            # TODO confirm against the metadata writer).
            K = np.eye(3, dtype=np.float32)
            K[0, 0] = intri[2]
            K[1, 1] = intri[3]
            K[0, 2] = intri[4]
            K[1, 2] = intri[5]
            image_path = scene_dir / "vga_wide" / basename.replace(".png", ".jpg")
            frames.append({
                "dataset": "arkitscenes",
                "scene": scene_name,
                "label": f"{scene_name}_{basename}",
                "pose": np.asarray(pose, dtype=np.float32),
                "intrinsics": K,
                "depth_path": str(depth_path),
                "image_path": str(image_path),
            })
        if frames:
            scenes.append({"scene": scene_name, "frames": frames})
    return scenes
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def load_scannetpp(root: str) -> List[Dict[str, Any]]:
    """Index the frames of a ScanNet++ root that have a depth image.

    Scene names are read from ``<root>/all_metadata.npz``. Scenes without
    ``new_scene_metadata.npz`` and frames without ``depth/<name>.png`` are
    skipped, as are scenes left with no frames.

    Returns a list of ``{"scene": name, "frames": [...]}`` dicts; each frame
    carries ``dataset``, ``scene``, ``label``, ``pose``, ``intrinsics``,
    ``depth_path`` and ``image_path``.
    """
    root_dir = Path(root)

    # Read the scene list eagerly so the archive handle is closed right away
    # (np.load on an .npz keeps the file open until it is closed).
    with np.load(root_dir / "all_metadata.npz") as all_metadata:
        scene_names = [str(name) for name in all_metadata["scenes"]]

    scenes = []
    for scene_name in scene_names:
        scene_dir = root_dir / scene_name
        meta_path = scene_dir / "new_scene_metadata.npz"
        if not meta_path.is_file():
            continue
        with np.load(meta_path, allow_pickle=True) as meta:
            images = meta["images"]
            intrinsics = meta["intrinsics"]
            trajectories = meta["trajectories"]

        frames = []
        for basename, intri, pose in zip(images, intrinsics, trajectories):
            basename = str(basename)
            depth_path = scene_dir / "depth" / f"{basename}.png"
            if not depth_path.is_file():
                continue
            image_path = scene_dir / "images" / f"{basename}.jpg"
            frames.append({
                "dataset": "ScanNet++",
                "scene": scene_name,
                "label": f"{scene_name}_{basename}",
                "pose": np.asarray(pose, dtype=np.float32),
                "intrinsics": np.asarray(intri, dtype=np.float32),
                "depth_path": str(depth_path),
                "image_path": str(image_path),
            })
        if frames:
            scenes.append({"scene": scene_name, "frames": frames})
    return scenes
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def load_mvs_synth(root: str) -> List[Dict[str, Any]]:
    """Index the frames of an MVS-Synth root.

    Every sub-directory of ``root`` containing ``rgb``, ``depth`` and ``cam``
    folders is a scene (visited in sorted order). A frame is kept when, for an
    ``rgb/<name>.jpg``, both ``cam/<name>.npz`` and ``depth/<name>.npy``
    exist; its pose and intrinsics are read from the camera archive.

    Returns a list of ``{"scene": name, "frames": [...]}`` dicts; scenes
    without any usable frame are omitted.
    """
    scenes = []
    for scene_name in sorted(os.listdir(root)):
        scene_dir = Path(root) / scene_name
        rgb_dir = scene_dir / "rgb"
        depth_dir = scene_dir / "depth"
        cam_dir = scene_dir / "cam"
        if not rgb_dir.is_dir() or not depth_dir.is_dir() or not cam_dir.is_dir():
            continue

        frames = []
        for basename in sorted(path.stem for path in rgb_dir.glob("*.jpg")):
            cam_path = cam_dir / f"{basename}.npz"
            depth_path = depth_dir / f"{basename}.npy"
            if not cam_path.is_file() or not depth_path.is_file():
                continue
            # Close each camera archive as soon as its arrays are copied out;
            # otherwise one file handle per frame stays open.
            with np.load(cam_path) as cam:
                pose = np.asarray(cam["pose"], dtype=np.float32)
                intrinsics = np.asarray(cam["intrinsics"], dtype=np.float32)
            frames.append({
                "dataset": "MVS_Synth",
                "scene": scene_name,
                "label": f"{scene_name}_{basename}",
                "pose": pose,
                "intrinsics": intrinsics,
                "depth_path": str(depth_path),
                "image_path": str(rgb_dir / f"{basename}.jpg"),
            })
        if frames:
            scenes.append({"scene": scene_name, "frames": frames})
    return scenes
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def load_depth(frame: Dict[str, Any]) -> np.ndarray:
    """Load the depth map referenced by ``frame["depth_path"]`` as float32.

    * ``.npy`` files: positive values above their own 98th percentile, and
      anything above 1000, are zeroed out.
    * any other file: read with OpenCV unchanged and divided by 1000
      (presumably millimetres to metres -- TODO confirm per dataset).

    Non-finite values are replaced by 0 in both cases, so 0 consistently
    marks "no depth".

    Raises:
        OSError: if OpenCV cannot read or decode the image file.
    """
    depth_path = frame["depth_path"]
    if depth_path.endswith(".npy"):
        depth = np.load(depth_path).astype(np.float32)
        valid = depth > 0
        if valid.any():
            threshold = np.percentile(depth[valid], 98)
            depth[depth > threshold] = 0.0
        depth[depth > 1000.0] = 0.0
    else:
        raw = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
        if raw is None:
            # cv2.imread signals failure by returning None instead of
            # raising; surface the offending path explicitly.
            raise OSError(f"Failed to read depth image: {depth_path}")
        depth = raw.astype(np.float32) / 1000.0
    depth[~np.isfinite(depth)] = 0.0
    return depth
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def load_scenes(dataset: str, root: str, split: str) -> List[Dict[str, Any]]:
    """Dispatch to the loader for ``dataset`` (matched case-insensitively).

    ``split`` is only forwarded to the ARKitScenes loader.

    Raises:
        ValueError: if ``dataset`` is not one of ``arkitscenes``,
            ``scannetpp`` or ``mvs_synth``.
    """
    name = dataset.lower()
    loaders = {
        "arkitscenes": lambda: load_arkitscenes(root, split),
        "scannetpp": lambda: load_scannetpp(root),
        "mvs_synth": lambda: load_mvs_synth(root),
    }
    loader = loaders.get(name)
    if loader is None:
        raise ValueError(f"Unsupported dataset: {name}")
    return loader()
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def limit_frames(frames: List[Dict[str, Any]], args) -> List[Dict[str, Any]]:
    """Subsample ``frames`` by ``args.frame_stride`` and cap the result.

    A stride below 1 is treated as 1. The cap ``args.max_frames_per_scene``
    is applied after striding and only when it is positive.
    """
    stride = max(1, args.frame_stride)
    strided = frames[::stride]
    cap = args.max_frames_per_scene
    return strided[:cap] if cap > 0 else strided
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def pair_key(frame_a: Dict[str, Any], frame_b: Dict[str, Any]) -> Tuple[str, str]:
    """Return the two frames' canonical view keys as an ordered (smaller, larger) tuple.

    Ordering the keys makes the result independent of argument order.
    """
    first, second = (
        canonical_view_key_from_values(frame["dataset"], frame["label"])
        for frame in (frame_a, frame_b)
    )
    if first <= second:
        return (first, second)
    return (second, first)
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def lookup_cache(cache: Dict[Tuple[str, str], Dict[str, Any]], frame_a: Dict[str, Any], frame_b: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Fetch the cached record for a frame pair, or ``None`` if there is none."""
    ordered_key = pair_key(frame_a, frame_b)
    return cache.get(ordered_key)
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def l2m_positive(record: Optional[Dict[str, Any]], args) -> bool:
    """Decide whether an L2M cache record passes all three acceptance thresholds.

    Each metric is read from its ``l2m_``-prefixed key, falling back to the
    unprefixed key and finally to 0. The record passes only if match count,
    mean certainty and inlier ratio all reach ``args.l2m_min_match_count``,
    ``args.l2m_min_certainty`` and ``args.l2m_min_inlier_ratio``. A missing
    record never passes.
    """
    if record is None:
        return False

    raw_count = _float(record.get("l2m_match_count"), _float(record.get("match_count"), 0.0))
    match_count = int(round(raw_count))
    certainty = _float(record.get("l2m_mean_certainty"), _float(record.get("mean_certainty"), 0.0))
    inlier_ratio = _float(record.get("l2m_inlier_ratio"), _float(record.get("inlier_ratio"), 0.0))

    if not match_count >= args.l2m_min_match_count:
        return False
    if not certainty >= args.l2m_min_certainty:
        return False
    return inlier_ratio >= args.l2m_min_inlier_ratio
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def mine_scene(scene: Dict[str, Any], args, sage_cache, l2m_cache, rng: np.random.Generator):
    """Mine positive and hard-negative frame pairs from one scene.

    Frames are first thinned with ``limit_frames``. Every pair ``(i, j)``
    with ``j >= i + args.min_frame_gap`` (stepping ``j`` by
    ``args.pair_step``) goes through:

      1. a coarse pose gate -- translation and heading must be within the
         positive limits or the (separate) hard-negative limits;
      2. symmetric frustum/depth overlap from the two depth maps;
      3. positive branch: overlap at or above the minimum thresholds. The
         pair is accepted on geometry alone, vetoed by a negative SAGE
         verdict, and re-accepted ("rescued") by a passing L2M record;
      4. otherwise, hard-negative branch: pose within the hard-negative
         limits but overlap at or below the negative thresholds.

    Only pairs that pass the coarse gate count towards
    ``args.max_pairs_per_scene`` (0 disables the cap).

    Returns the list of record dicts (possibly empty).
    """
    frames = limit_frames(scene["frames"], args)
    if len(frames) <= args.min_frame_gap:
        return []

    depth_cache: Dict[str, np.ndarray] = {}

    def depth_for(frame):
        # Each depth map is decoded at most once per scene.
        path = frame["depth_path"]
        if path not in depth_cache:
            depth_cache[path] = load_depth(frame)
        return depth_cache[path]

    def make_record(
        frame_a, frame_b, dist, heading, frustum_overlap, depth_overlap,
        tag, is_positive, score, overlap_target, support_count,
        verification_path, match_count, mean_certainty, inlier_ratio,
    ):
        # Both record kinds share one schema and key order.
        confidence = float(max(0.05, min(1.0, score)))
        return {
            "scene": scene["scene"],
            "dataset_a": frame_a["dataset"],
            "dataset_b": frame_b["dataset"],
            "label_a": frame_a["label"],
            "label_b": frame_b["label"],
            "key_a": canonical_view_key_from_values(frame_a["dataset"], frame_a["label"]),
            "key_b": canonical_view_key_from_values(frame_b["dataset"], frame_b["label"]),
            "tag": tag,
            "is_positive": is_positive,
            "score": float(score),
            "soft_overlap_target": overlap_target,
            "overlap": overlap_target,
            "pair_confidence_weight": confidence,
            "weight": confidence,
            "pose_distance": float(dist),
            "heading_deg": float(heading),
            "frustum_overlap": float(frustum_overlap),
            "depth_overlap": float(depth_overlap),
            "geometric_support_count": support_count,
            "verification_path": verification_path,
            "l2m_match_count": match_count,
            "l2m_mean_certainty": mean_certainty,
            "l2m_inlier_ratio": inlier_ratio,
            "image_path_a": frame_a["image_path"],
            "image_path_b": frame_b["image_path"],
        }

    records = []
    num_pairs = 0
    frame_total = len(frames)
    for i, frame_a in enumerate(frames):
        for j in range(i + args.min_frame_gap, frame_total, max(1, args.pair_step)):
            if args.max_pairs_per_scene > 0 and num_pairs >= args.max_pairs_per_scene:
                return records
            frame_b = frames[j]

            # 1. Coarse pose gate (cheap; avoids touching depth maps).
            dist = pose_distance(frame_a["pose"], frame_b["pose"])
            heading = heading_angle_deg(frame_a["pose"], frame_b["pose"])
            positive_coarse = dist <= args.max_translation and heading <= args.max_heading_deg
            negative_coarse = (
                dist <= args.hard_negative_max_translation
                and heading <= args.hard_negative_max_heading_deg
            )
            if not (positive_coarse or negative_coarse):
                continue

            # 2. Geometric overlap from depth.
            depth_a = depth_for(frame_a)
            depth_b = depth_for(frame_b)
            overlap = symmetric_overlap(frame_a, frame_b, depth_a, depth_b, args, rng)
            frustum_overlap = overlap["frustum_overlap"]
            depth_overlap = overlap["depth_overlap"]
            geometric_support_count = overlap["geometric_support_count"]
            num_pairs += 1

            sage_record = lookup_cache(sage_cache, frame_a, frame_b)
            l2m_record = lookup_cache(l2m_cache, frame_a, frame_b)
            sage_pass = _bool(sage_record, args.sage_min_score)
            l2m_pass = l2m_positive(l2m_record, args)

            geometry_positive = (
                positive_coarse
                and frustum_overlap >= args.min_frustum_overlap
                and depth_overlap >= args.min_depth_overlap
            )
            if geometry_positive:
                # 3. Positive candidate: apply the optional verifiers.
                accepted = True
                verification_path = "geometry_only"
                if sage_cache:
                    if sage_pass is False:
                        accepted = False
                    elif sage_pass is True:
                        verification_path = "geometry+sage"
                if l2m_cache and l2m_pass:
                    if verification_path == "geometry+sage":
                        verification_path = "geometry+sage+l2m"
                    else:
                        # L2M overrides a missing or negative SAGE verdict.
                        verification_path = "geometry+l2m_rescue"
                        accepted = True
                if not accepted:
                    continue

                l2m_fields = l2m_record or {}
                l2m_match_count = int(round(_float(l2m_fields.get("l2m_match_count"), _float(l2m_fields.get("match_count"), 0.0))))
                l2m_mean_certainty = _float(l2m_fields.get("l2m_mean_certainty"), _float(l2m_fields.get("mean_certainty"), 0.0))
                l2m_inlier_ratio = _float(l2m_fields.get("l2m_inlier_ratio"), _float(l2m_fields.get("inlier_ratio"), 0.0))

                score = max(depth_overlap, frustum_overlap)
                if sage_pass is True:
                    sage_fields = sage_record or {}
                    score = max(score, _float(sage_fields.get("score"), _float(sage_fields.get("confidence"), score)))
                if l2m_pass:
                    score = max(score, l2m_mean_certainty)

                records.append(make_record(
                    frame_a, frame_b, dist, heading, frustum_overlap, depth_overlap,
                    "positive", True, score, float(depth_overlap),
                    int(geometric_support_count), verification_path,
                    l2m_match_count, float(l2m_mean_certainty), float(l2m_inlier_ratio),
                ))
            elif (
                negative_coarse
                and frustum_overlap <= args.negative_max_frustum_overlap
                and depth_overlap <= args.negative_max_depth_overlap
            ):
                # 4. Hard negative: close in pose yet (almost) no shared content.
                score = max(1.0 - max(frustum_overlap, depth_overlap), 0.0)
                records.append(make_record(
                    frame_a, frame_b, dist, heading, frustum_overlap, depth_overlap,
                    "hard_negative", False, score, 0.0,
                    0, "geometry_negative",
                    0, 0.0, 0.0,
                ))
    return records
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def build_argparser():
    """Create the command-line parser for the pseudo-GT mining script.

    ``--dataset``, ``--root`` and ``--output`` are required; every other
    option has a default. Options are registered in the order listed below.
    """
    option_table = [
        # Input / output.
        ("--dataset", {"required": True, "choices": ["arkitscenes", "scannetpp", "mvs_synth"]}),
        ("--root", {"required": True}),
        ("--output", {"required": True}),
        ("--split", {"default": "train"}),
        ("--seed", {"type": int, "default": 42}),
        # Workload limits and pair enumeration.
        ("--max-scenes", {"type": int, "default": 0}),
        ("--max-frames-per-scene", {"type": int, "default": 0}),
        ("--frame-stride", {"type": int, "default": 1}),
        ("--pair-step", {"type": int, "default": 1}),
        ("--min-frame-gap", {"type": int, "default": 12}),
        ("--max-pairs-per-scene", {"type": int, "default": 0}),
        # Geometric overlap estimation.
        ("--sample-points", {"type": int, "default": 2048}),
        ("--depth-tolerance-ratio", {"type": float, "default": 0.05}),
        # Coarse pose gates.
        ("--max-translation", {"type": float, "default": 2.5}),
        ("--max-heading-deg", {"type": float, "default": 45.0}),
        ("--hard-negative-max-translation", {"type": float, "default": 3.5}),
        ("--hard-negative-max-heading-deg", {"type": float, "default": 75.0}),
        # Overlap thresholds.
        ("--min-frustum-overlap", {"type": float, "default": 0.2}),
        ("--min-depth-overlap", {"type": float, "default": 0.1}),
        ("--negative-max-frustum-overlap", {"type": float, "default": 0.05}),
        ("--negative-max-depth-overlap", {"type": float, "default": 0.02}),
        # Optional external verifiers.
        ("--sage-cache", {"default": None}),
        ("--sage-min-score", {"type": float, "default": 0.5}),
        ("--l2m-cache", {"default": None}),
        ("--l2m-min-match-count", {"type": int, "default": 64}),
        ("--l2m-min-certainty", {"type": float, "default": 0.5}),
        ("--l2m-min-inlier-ratio", {"type": float, "default": 0.3}),
    ]
    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def main():
    """Command-line entry point: mine every scene and write one JSON file.

    Parses the arguments, loads the dataset's scenes (optionally capped by
    ``--max-scenes``) and the optional SAGE/L2M pair caches, mines each
    scene with a single seeded random generator, prints per-scene progress
    and finally writes ``{"metadata": ..., "records": ...}`` to
    ``--output``, creating parent directories as needed.
    """
    args = build_argparser().parse_args()
    rng = np.random.default_rng(args.seed)

    scenes = load_scenes(args.dataset, args.root, args.split)
    if args.max_scenes > 0:
        scenes = scenes[: args.max_scenes]
    sage_cache = load_pair_cache(args.sage_cache)
    l2m_cache = load_pair_cache(args.l2m_cache)

    all_records = []
    scene_stats = []
    positive_total = 0
    scene_total = len(scenes)
    for position, scene in enumerate(scenes, start=1):
        mined = mine_scene(scene, args, sage_cache, l2m_cache, rng)
        all_records += mined
        # Every record is either positive or not, so negatives follow.
        positives = sum(1 for entry in mined if entry.get("is_positive", False))
        negatives = len(mined) - positives
        positive_total += positives
        scene_stats.append({
            "scene": scene["scene"],
            "num_frames": len(scene["frames"]),
            "num_records": len(mined),
            "num_positives": positives,
            "num_negatives": negatives,
        })
        print(f"[{position}/{scene_total}] {scene['scene']}: {len(mined)} records ({positives} pos / {negatives} neg)")

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    metadata = {
        "dataset": args.dataset,
        "root": args.root,
        "split": args.split,
        "num_scenes": len(scenes),
        "num_records": len(all_records),
        "num_positive_records": positive_total,
        "num_negative_records": len(all_records) - positive_total,
        "args": vars(args),
        "scene_stats": scene_stats,
    }
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump({"metadata": metadata, "records": all_records}, f, indent=2)
    print(f"Wrote {len(all_records)} records to {output_path}")
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
if __name__ == "__main__":
|
| 588 |
+
main()
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/pseudo_gt.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
_POSITIVE_TAGS = {"positive", "pos", "loop", "match", "1", "true"}
|
| 10 |
+
_NEGATIVE_TAGS = {"negative", "neg", "hard_negative", "hard-neg", "0", "false"}
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _cfg_get(cfg, key, default=None):
    """Read ``key`` from a config object, falling back to ``default``.

    ``cfg`` may be ``None`` (always yields ``default``), a dict or any other
    object exposing ``.get(key, default)``, or a plain object whose settings
    are stored as attributes.
    """
    if cfg is None:
        return default
    # Dicts expose ``.get`` as well, so both mapping-style cases share a path.
    mapping_like = isinstance(cfg, dict) or hasattr(cfg, "get")
    return cfg.get(key, default) if mapping_like else getattr(cfg, key, default)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _first_scalar(value):
    """Unwrap ``value`` down to its first scalar element.

    Lists and tuples are descended through their first item (however deeply
    nested); tensors yield their first element as a Python scalar. An empty
    list, tuple or tensor yields ``None``. Anything else is returned as-is.
    """
    # Peel nested list/tuple wrappers iteratively instead of recursing.
    while isinstance(value, (list, tuple)):
        if len(value) == 0:
            return None
        value = value[0]
    if not torch.is_tensor(value):
        return value
    if value.numel() == 0:
        return None
    return value.detach().reshape(-1)[0].cpu().item()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _to_str(value) -> Optional[str]:
    """Return the first scalar of ``value`` as a string, or ``None`` if absent."""
    scalar = _first_scalar(value)
    return None if scalar is None else str(scalar)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _to_float(value, default: float = 0.0) -> float:
    """Return the first scalar of ``value`` as a float.

    ``default`` (coerced to float) is returned when there is no scalar or
    when the scalar cannot be converted by ``float()``.
    """
    scalar = _first_scalar(value)
    if scalar is not None:
        try:
            return float(scalar)
        except (TypeError, ValueError):
            # Unparseable value: fall through to the default below.
            pass
    return float(default)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def canonical_view_key_from_values(dataset: Optional[str], label: Optional[str]) -> Optional[str]:
    """Build the ``"<dataset>::<label>"`` key, or ``None`` if either part is missing."""
    if dataset is not None and label is not None:
        return f"{dataset}::{label}"
    return None
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def canonical_view_key(view: Dict[str, Any]) -> Optional[str]:
    """Return the canonical key of a view dict from its ``dataset`` and ``label``.

    Non-dict inputs, and views lacking either entry, yield ``None``.
    """
    if isinstance(view, dict):
        dataset = _to_str(view.get("dataset"))
        label = _to_str(view.get("label"))
        return canonical_view_key_from_values(dataset, label)
    return None
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class FramePairPseudoGTDatabase:
    """In-memory table of pseudo-ground-truth records for unordered frame pairs.

    Records are keyed by a lexicographically sorted ``(key_a, key_b)`` tuple,
    so a pair can be looked up with its two frame keys in either order.
    """

    def __init__(self, records: Dict[Tuple[str, str], Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None):
        self.records = records
        self.metadata = metadata or {}

    def __len__(self) -> int:
        return len(self.records)

    @classmethod
    def from_file(cls, path: str) -> "FramePairPseudoGTDatabase":
        """Load a pseudo-GT cache from ``.json``, ``.jsonl`` or a torch file.

        The payload may be a list of records, or a dict holding the records
        under ``"records"`` / ``"pairs"`` (optionally with ``"metadata"``), or
        a dict mapping pair keys to records. When several records normalize
        to the same pair, the one with the highest ``score`` is kept.

        Raises:
            FileNotFoundError: if ``path`` is not an existing file.
            TypeError: if the payload does not contain an iterable of records.
        """
        path = os.path.expanduser(path)
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Pseudo-GT cache not found: {path}")
        if path.endswith(".json"):
            with open(path, "r", encoding="utf-8") as f:
                payload = json.load(f)
        elif path.endswith(".jsonl"):
            with open(path, "r", encoding="utf-8") as f:
                stripped = (line.strip() for line in f)
                # Blank lines are skipped; every other line is one JSON record.
                payload = {"records": [json.loads(line) for line in stripped if line]}
        else:
            # NOTE(security): weights_only=False unpickles arbitrary objects,
            # so only point this at cache files from a trusted source.
            payload = torch.load(path, map_location="cpu", weights_only=False)

        if isinstance(payload, dict):
            metadata = payload.get("metadata", {})
            raw_records = payload.get("records", payload.get("pairs", payload))
        else:
            metadata = {}
            raw_records = payload

        if isinstance(raw_records, dict):
            # Mapping form: the dict key doubles as the record's pair_key.
            candidates = [
                ({**value, "pair_key": key} if isinstance(value, dict) else value)
                for key, value in raw_records.items()
            ]
        else:
            try:
                candidates = iter(raw_records)
            except TypeError as err:
                raise TypeError(
                    "Pseudo-GT cache must contain a list or dict of records, "
                    f"got {type(raw_records).__name__}: {path}"
                ) from err

        records: Dict[Tuple[str, str], Dict[str, Any]] = {}
        for raw_record in candidates:
            record = cls._normalize_record(raw_record)
            if record is None:
                continue
            pair = record["pair"]
            previous = records.get(pair)
            if previous is None or record["score"] > previous["score"]:
                records[pair] = record
        return cls(records, metadata=metadata)

    @classmethod
    def _normalize_record(cls, raw_record: Any) -> Optional[Dict[str, Any]]:
        """Convert one raw record into the canonical dict, or ``None`` to drop it.

        A record is dropped when it is not a dict, when either frame key
        cannot be resolved, or when both keys are identical.
        """
        if not isinstance(raw_record, dict):
            return None

        # Frame keys: explicit keys first, then dataset/label parts, and
        # finally a "a||b" pair_key as the last resort.
        key_a = _to_str(raw_record.get("key_a") or raw_record.get("frame_key_a"))
        key_b = _to_str(raw_record.get("key_b") or raw_record.get("frame_key_b"))
        if key_a is None:
            key_a = canonical_view_key_from_values(
                _to_str(raw_record.get("dataset_a")),
                _to_str(raw_record.get("label_a")),
            )
        if key_b is None:
            key_b = canonical_view_key_from_values(
                _to_str(raw_record.get("dataset_b")),
                _to_str(raw_record.get("label_b")),
            )
        pair_key = _to_str(raw_record.get("pair_key"))
        if (key_a is None or key_b is None) and pair_key and "||" in pair_key:
            key_a, key_b = pair_key.split("||", 1)
        if key_a is None or key_b is None or key_a == key_b:
            return None
        pair = (key_a, key_b) if key_a <= key_b else (key_b, key_a)

        tag = (_to_str(raw_record.get("tag")) or "").strip().lower()
        raw_flag = raw_record.get("is_positive", False)
        if isinstance(raw_flag, str) and raw_flag.strip().lower() in _NEGATIVE_TAGS:
            # bool("false") / bool("0") would be True; honour the textual negative.
            is_positive = False
        else:
            is_positive = bool(raw_flag)
        # An explicit tag overrides the is_positive flag.
        if tag in _POSITIVE_TAGS:
            is_positive = True
        elif tag in _NEGATIVE_TAGS:
            is_positive = False

        # Each value falls back through alternative field names, then a default.
        score = _to_float(
            raw_record.get("score"),
            _to_float(raw_record.get("confidence"), _to_float(raw_record.get("pair_confidence_weight"), 0.0)),
        )
        overlap = _to_float(
            raw_record.get("soft_overlap_target"),
            _to_float(raw_record.get("overlap"), score if is_positive else 0.0),
        )
        weight = _to_float(raw_record.get("weight"), _to_float(raw_record.get("pair_confidence_weight"), max(score, overlap)))
        geometric_support_count = int(round(_to_float(raw_record.get("geometric_support_count"), 0.0)))
        l2m_match_count = int(round(_to_float(raw_record.get("l2m_match_count"), 0.0)))
        l2m_mean_certainty = _to_float(raw_record.get("l2m_mean_certainty"), _to_float(raw_record.get("l2m_certainty"), 0.0))
        l2m_inlier_ratio = _to_float(raw_record.get("l2m_inlier_ratio"), 0.0)

        return {
            "pair": pair,
            "is_positive": is_positive,
            "score": float(score),
            "overlap": float(overlap),
            "weight": float(weight),
            "geometric_support_count": geometric_support_count,
            "l2m_match_count": l2m_match_count,
            "l2m_mean_certainty": float(l2m_mean_certainty),
            "l2m_inlier_ratio": float(l2m_inlier_ratio),
        }

    def lookup(self, key_a: Optional[str], key_b: Optional[str]) -> Optional[Dict[str, Any]]:
        """Return the record for the unordered pair, or ``None`` if unknown.

        ``None`` is also returned when either key is ``None`` or both match.
        """
        if key_a is None or key_b is None or key_a == key_b:
            return None
        pair = (key_a, key_b) if key_a <= key_b else (key_b, key_a)
        return self.records.get(pair)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class PseudoGTLoopSupervisor:
|
| 179 |
+
def __init__(self, frame_db: FramePairPseudoGTDatabase, pseudo_gt_cfg):
|
| 180 |
+
self.frame_db = frame_db
|
| 181 |
+
self.use_soft_targets = bool(_cfg_get(pseudo_gt_cfg, "use_soft_targets", True))
|
| 182 |
+
self.min_confidence = float(_cfg_get(pseudo_gt_cfg, "min_confidence", 0.65))
|
| 183 |
+
self.min_support_pairs = max(1, int(_cfg_get(pseudo_gt_cfg, "min_support_pairs", 1)))
|
| 184 |
+
self.topk_pairs = max(1, int(_cfg_get(pseudo_gt_cfg, "topk_pairs", 4)))
|
| 185 |
+
self.loss_weight_gate = float(_cfg_get(pseudo_gt_cfg, "loss_weight_gate", 0.1))
|
| 186 |
+
self.loss_weight_desc = float(_cfg_get(pseudo_gt_cfg, "loss_weight_desc", 0.1))
|
| 187 |
+
self.loss_type = str(_cfg_get(pseudo_gt_cfg, "loss_type", "hybrid"))
|
| 188 |
+
self.geometric_support_scale = float(_cfg_get(pseudo_gt_cfg, "geometric_support_scale", 0.25))
|
| 189 |
+
self.ranking_margin = float(_cfg_get(pseudo_gt_cfg, "ranking_margin", 0.1))
|
| 190 |
+
self.use_l2m = bool(_cfg_get(pseudo_gt_cfg, "use_l2m", False))
|
| 191 |
+
self.l2m_min_certainty = float(_cfg_get(pseudo_gt_cfg, "l2m_min_certainty", 0.0))
|
| 192 |
+
self.l2m_min_inlier_ratio = float(_cfg_get(pseudo_gt_cfg, "l2m_min_inlier_ratio", 0.0))
|
| 193 |
+
|
| 194 |
+
@classmethod
|
| 195 |
+
def from_config(cls, pseudo_gt_cfg) -> Optional["PseudoGTLoopSupervisor"]:
|
| 196 |
+
if pseudo_gt_cfg is None or not bool(_cfg_get(pseudo_gt_cfg, "enable", False)):
|
| 197 |
+
return None
|
| 198 |
+
cache_path = _cfg_get(pseudo_gt_cfg, "cache_path", None)
|
| 199 |
+
if cache_path in (None, "", "null"):
|
| 200 |
+
raise ValueError("`pseudo_gt.enable=true` requires `pseudo_gt.cache_path`.")
|
| 201 |
+
return cls(FramePairPseudoGTDatabase.from_file(cache_path), pseudo_gt_cfg)
|
| 202 |
+
|
| 203 |
+
def _frame_keys(self, batch: Sequence[Dict[str, Any]], frame_ids: Sequence[int]) -> List[str]:
|
| 204 |
+
keys: List[str] = []
|
| 205 |
+
if batch is None:
|
| 206 |
+
return keys
|
| 207 |
+
num_views = len(batch)
|
| 208 |
+
for frame_id in frame_ids:
|
| 209 |
+
if num_views <= 0:
|
| 210 |
+
continue
|
| 211 |
+
index = min(max(int(frame_id), 0), num_views - 1)
|
| 212 |
+
key = canonical_view_key(batch[index])
|
| 213 |
+
if key is not None:
|
| 214 |
+
keys.append(key)
|
| 215 |
+
return keys
|
| 216 |
+
|
| 217 |
+
def _has_geometry(self, record: Dict[str, Any]) -> bool:
|
| 218 |
+
support = int(record.get("geometric_support_count", 0)) > 0 or int(record.get("l2m_match_count", 0)) > 0
|
| 219 |
+
if not support:
|
| 220 |
+
return False
|
| 221 |
+
if self.use_l2m and record.get("l2m_match_count", 0) > 0:
|
| 222 |
+
if float(record.get("l2m_mean_certainty", 0.0)) < self.l2m_min_certainty:
|
| 223 |
+
return False
|
| 224 |
+
if float(record.get("l2m_inlier_ratio", 0.0)) < self.l2m_min_inlier_ratio:
|
| 225 |
+
return False
|
| 226 |
+
return True
|
| 227 |
+
|
| 228 |
+
def build_submap_targets(self, batch, current_frame_ids, history_frame_ids_by_submap):
|
| 229 |
+
current_keys = self._frame_keys(batch, current_frame_ids)
|
| 230 |
+
if not current_keys:
|
| 231 |
+
return []
|
| 232 |
+
targets = []
|
| 233 |
+
for submap_id, history_frame_ids in sorted(history_frame_ids_by_submap.items()):
|
| 234 |
+
history_keys = self._frame_keys(batch, history_frame_ids)
|
| 235 |
+
if not history_keys:
|
| 236 |
+
continue
|
| 237 |
+
records = []
|
| 238 |
+
for current_key in current_keys:
|
| 239 |
+
for history_key in history_keys:
|
| 240 |
+
record = self.frame_db.lookup(current_key, history_key)
|
| 241 |
+
if record is not None:
|
| 242 |
+
records.append(record)
|
| 243 |
+
if not records:
|
| 244 |
+
continue
|
| 245 |
+
positives = [record for record in records if record.get("is_positive", False) and record.get("score", 0.0) >= self.min_confidence]
|
| 246 |
+
negatives = [record for record in records if (not record.get("is_positive", False)) and record.get("score", 0.0) >= self.min_confidence]
|
| 247 |
+
if len(positives) >= self.min_support_pairs:
|
| 248 |
+
ranked = sorted(positives, key=lambda record: max(record.get("weight", 0.0), record.get("overlap", 0.0), record.get("score", 0.0)), reverse=True)[: self.topk_pairs]
|
| 249 |
+
soft_target = sum(record.get("overlap", record.get("score", 0.0)) for record in ranked) / max(1, len(ranked))
|
| 250 |
+
confidence = sum(max(record.get("weight", 0.0), record.get("score", 0.0), record.get("overlap", 0.0)) for record in ranked) / max(1, len(ranked))
|
| 251 |
+
geometry = sum(1 for record in positives if self._has_geometry(record))
|
| 252 |
+
targets.append({
|
| 253 |
+
"submap_id": int(submap_id),
|
| 254 |
+
"binary": 1.0,
|
| 255 |
+
"soft": float(soft_target if self.use_soft_targets else 1.0),
|
| 256 |
+
"weight": float(max(0.05, min(1.0, confidence))),
|
| 257 |
+
"geometry": int(geometry),
|
| 258 |
+
})
|
| 259 |
+
elif negatives and not positives:
|
| 260 |
+
ranked = sorted(negatives, key=lambda record: max(record.get("weight", 0.0), record.get("score", 0.0)), reverse=True)[: self.topk_pairs]
|
| 261 |
+
confidence = sum(max(record.get("weight", 0.0), record.get("score", 0.0)) for record in ranked) / max(1, len(ranked))
|
| 262 |
+
targets.append({
|
| 263 |
+
"submap_id": int(submap_id),
|
| 264 |
+
"binary": 0.0,
|
| 265 |
+
"soft": 0.0,
|
| 266 |
+
"weight": float(max(0.05, min(1.0, confidence))),
|
| 267 |
+
"geometry": 0,
|
| 268 |
+
})
|
| 269 |
+
return targets
|
| 270 |
+
|
| 271 |
+
def compute_loss(self, memory_mgr, batch, hidden_B, meta, loop_gate):
|
| 272 |
+
current_submap_id = int(meta.get("submap_id", -1))
|
| 273 |
+
if current_submap_id <= 0:
|
| 274 |
+
return None, {}
|
| 275 |
+
history_frame_ids_by_submap = {
|
| 276 |
+
int(submap_id): list(frame_ids)
|
| 277 |
+
for submap_id, frame_ids in memory_mgr.buffer.cpu_frame_ids.items()
|
| 278 |
+
if int(submap_id) < current_submap_id
|
| 279 |
+
}
|
| 280 |
+
if not history_frame_ids_by_submap:
|
| 281 |
+
return None, {}
|
| 282 |
+
targets = self.build_submap_targets(batch, meta.get("curr_frame_ids", []), history_frame_ids_by_submap)
|
| 283 |
+
if not targets:
|
| 284 |
+
return None, {}
|
| 285 |
+
|
| 286 |
+
current_desc = meta.get("curr_descriptor")
|
| 287 |
+
if current_desc is None:
|
| 288 |
+
n_prev = int(meta.get("n_prev", 0))
|
| 289 |
+
n_curr = int(meta.get("n_curr", 0))
|
| 290 |
+
current_tokens = hidden_B[n_prev:n_prev + n_curr]
|
| 291 |
+
if current_tokens.numel() == 0:
|
| 292 |
+
return None, {}
|
| 293 |
+
current_desc = memory_mgr.compute_descriptor(current_tokens)
|
| 294 |
+
current_desc = current_desc.float()
|
| 295 |
+
|
| 296 |
+
valid_targets = []
|
| 297 |
+
history_descs = []
|
| 298 |
+
for target in targets:
|
| 299 |
+
history_desc = memory_mgr.buffer.cpu_descriptor_buffer.get(target["submap_id"])
|
| 300 |
+
if history_desc is None:
|
| 301 |
+
continue
|
| 302 |
+
history_desc = history_desc.reshape(-1).to(current_desc.device, non_blocking=True).float()
|
| 303 |
+
history_descs.append(history_desc)
|
| 304 |
+
valid_targets.append(target)
|
| 305 |
+
if not history_descs:
|
| 306 |
+
return None, {}
|
| 307 |
+
|
| 308 |
+
history_descs = torch.stack(history_descs, dim=0)
|
| 309 |
+
predicted_cosine = F.cosine_similarity(current_desc.expand(history_descs.shape[0], -1), history_descs, dim=-1).clamp(min=-1.0, max=1.0)
|
| 310 |
+
predicted_similarity = 0.5 * (predicted_cosine + 1.0)
|
| 311 |
+
target_binary = torch.tensor([target["binary"] for target in valid_targets], device=predicted_similarity.device, dtype=predicted_similarity.dtype)
|
| 312 |
+
target_soft = torch.tensor([target["soft"] for target in valid_targets], device=predicted_similarity.device, dtype=predicted_similarity.dtype)
|
| 313 |
+
weights = torch.tensor([target["weight"] for target in valid_targets], device=predicted_similarity.device, dtype=predicted_similarity.dtype)
|
| 314 |
+
geometry = torch.tensor([target["geometry"] for target in valid_targets], device=predicted_similarity.device, dtype=predicted_similarity.dtype)
|
| 315 |
+
if self.geometric_support_scale > 0:
|
| 316 |
+
weights = (weights * (1.0 + self.geometric_support_scale * geometry.clamp(max=1.0))).clamp(max=1.5)
|
| 317 |
+
|
| 318 |
+
gate_loss = predicted_similarity.new_zeros(())
|
| 319 |
+
if self.loss_weight_gate > 0 and torch.is_tensor(loop_gate) and loop_gate.requires_grad:
|
| 320 |
+
gate_target = target_binary.max()
|
| 321 |
+
gate_weight = weights[target_binary.argmax()] if target_binary.numel() > 0 else weights.new_ones(())
|
| 322 |
+
gate_pred = loop_gate.reshape(-1)[0].clamp(min=1e-6, max=1.0 - 1e-6)
|
| 323 |
+
gate_loss = F.binary_cross_entropy(gate_pred, gate_target.clamp(min=0.0, max=1.0), reduction="none") * gate_weight
|
| 324 |
+
gate_loss = gate_loss.mean() * self.loss_weight_gate
|
| 325 |
+
|
| 326 |
+
descriptor_loss = predicted_similarity.new_zeros(())
|
| 327 |
+
if self.loss_weight_desc > 0:
|
| 328 |
+
regression = F.smooth_l1_loss(predicted_similarity, target_soft.clamp(min=0.0, max=1.0), reduction="none")
|
| 329 |
+
descriptor_loss = (regression * weights).sum() / weights.sum().clamp(min=1e-6)
|
| 330 |
+
if self.loss_type in {"hybrid", "ranking"}:
|
| 331 |
+
pos_mask = target_binary > 0.5
|
| 332 |
+
neg_mask = target_binary < 0.5
|
| 333 |
+
if pos_mask.any() and neg_mask.any():
|
| 334 |
+
pos_score = predicted_similarity[pos_mask].max()
|
| 335 |
+
neg_score = predicted_similarity[neg_mask].max()
|
| 336 |
+
descriptor_loss = descriptor_loss + F.relu(self.ranking_margin - pos_score + neg_score)
|
| 337 |
+
descriptor_loss = descriptor_loss * self.loss_weight_desc
|
| 338 |
+
|
| 339 |
+
total_loss = gate_loss + descriptor_loss
|
| 340 |
+
details = {
|
| 341 |
+
"pseudo_gt_gate_loss": float(gate_loss.detach()),
|
| 342 |
+
"pseudo_gt_desc_loss": float(descriptor_loss.detach()),
|
| 343 |
+
"pseudo_gt_total": float(total_loss.detach()),
|
| 344 |
+
"pseudo_gt_pairs": float(len(valid_targets)),
|
| 345 |
+
"pseudo_gt_positive_pairs": float((target_binary > 0.5).sum().detach()),
|
| 346 |
+
"pseudo_gt_negative_pairs": float((target_binary < 0.5).sum().detach()),
|
| 347 |
+
}
|
| 348 |
+
return total_loss, details
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/rerun_helper/__init__.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pickle
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import trimesh
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from .geometry_utils import NormalGenerator
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
import rerun as rr
|
| 16 |
+
from .visualization_utils import reverse_imagenet_normalize, colormap_image
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
from typing import Dict, Any
|
| 20 |
+
|
| 21 |
+
# depth prediction normals computer
|
| 22 |
+
#PRED_FORMAT_SIZE = [480,640]#[192, 256]
|
| 23 |
+
PRED_FORMAT_SIZE = [680,1200]#[192, 256]
|
| 24 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 25 |
+
compute_normals = NormalGenerator(PRED_FORMAT_SIZE[0], PRED_FORMAT_SIZE[1]).to(device)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def to_device(input_dict, key_ignores=None, device="cuda"):
    """Move the values of ``input_dict`` to ``device`` as float tensors.

    The dict is updated in place and also returned. Entries whose key is in
    ``key_ignores`` are left untouched; every other value must support
    ``.to(device).float()``.

    Args:
        input_dict: mapping whose values are converted.
        key_ignores: optional collection of keys to skip (default: none).
        device: target device passed to ``.to``.
    """
    # None instead of a shared mutable ``[]`` default.
    skipped = () if key_ignores is None else key_ignores
    for k, v in input_dict.items():
        if k not in skipped:
            input_dict[k] = v.to(device).float()
    return input_dict
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def log_source_data(src_entity_path: str, src_data: Dict[str, Any]) -> None:
    """Log every source camera of the first batch item, with its image.

    Reads ``image_b3hw``, ``world_T_cam_b44`` and ``K_s0_b44`` from
    ``src_data`` (index 0 of each) and logs camera ``i`` under
    ``"{src_entity_path}/{i}"``.
    """
    # detach().clone() is the supported way to copy a tensor; wrapping an
    # existing tensor in torch.tensor() copies too but emits a UserWarning.
    src_images_k3hw = reverse_imagenet_normalize(
        src_data["image_b3hw"][0].to(device).detach().clone()
    )
    num_src_cameras = src_data["world_T_cam_b44"][0].shape[0]
    for src_idx in range(num_src_cameras):
        src_cam_path = f"{src_entity_path}/{src_idx}"
        world_T_cam_44 = src_data["world_T_cam_b44"][0][src_idx].squeeze().cpu().numpy()
        K_44 = src_data["K_s0_b44"][0][src_idx].squeeze().cpu().numpy()
        log_camera(src_cam_path, world_T_cam_44, K_44)
        # Images were already de-normalized above, so skip it in log_image.
        log_image(src_cam_path, src_images_k3hw[src_idx], denormalize=False)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def log_camera(
    entity_path: str, world_T_cam_44: torch.Tensor, K_44: torch.Tensor, kfd=False, update=False,
) -> None:
    """Log a camera pose and, unless ``update`` is set, its frustum.

    Args:
        entity_path: rerun entity path for the camera transform; the frustum
            is logged under ``entity_path + '/frustum'``.
        world_T_cam_44: 4x4 camera-to-world matrix (array or tensor).
        K_44: 4x4 intrinsics matrix (array or tensor). It is not modified.
        kfd: when true, log a pinhole built from the intrinsics scaled to a
            quarter of ``PRED_FORMAT_SIZE``; otherwise log a fixed-FOV green
            frustum.
        update: when true only the transform is logged.
    """
    assert world_T_cam_44.shape == (4, 4)
    assert K_44.shape == (4, 4)
    # Convert and log camera parameters
    Rot, trans = world_T_cam_44[:3, :3], world_T_cam_44[:3, 3]

    # Slicing yields a view, so copy before the in-place rescale; otherwise
    # the caller's K_44 would be divided by 4 again on every call.
    K_33 = K_44[:3, :3]
    K_33 = K_33.clone() if torch.is_tensor(K_33) else np.array(K_33, copy=True)
    K_33[:2] /= 4

    rr.log(entity_path, rr.Transform3D(translation=trans, mat3x3=Rot))
    if update:
        # Backend update: the pose is refreshed but no frustum is re-logged.
        return

    if kfd:
        pinhole = rr.Pinhole(
            image_from_camera=K_33,
            width=PRED_FORMAT_SIZE[1]/4,
            height=PRED_FORMAT_SIZE[0]/4,
        )
    else:
        pinhole = rr.Pinhole(
            fov_y=0.7853982,
            aspect_ratio=1.7777778,
            camera_xyz=None,
            image_plane_distance=0.1,
            color=[0, 255, 0],
            line_width=0.003,
        )
    rr.log(entity_path+'/frustum', pinhole)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def log_window(
    entity_path: str, world_T_cam_44: torch.Tensor, K_44: torch.Tensor
) -> None:
    """Log only the pose of a camera as a 3D transform at ``entity_path``.

    ``K_44`` is shape-checked but otherwise unused.
    """
    assert world_T_cam_44.shape == (4, 4)
    assert K_44.shape == (4, 4)
    rotation = world_T_cam_44[:3, :3]
    translation = world_T_cam_44[:3, 3]
    pose = rr.Transform3D(translation=translation, mat3x3=rotation)
    rr.log(entity_path, pose)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def log_image(
    entity_path: str, color_frame_b3hw: torch.Tensor, denormalize=True
) -> None:
    """Log a colour frame to ``"{entity_path}/image/rgb"``.

    The frame's leading dimension is squeezed, ImageNet normalization is
    optionally reversed, and the result is scaled by 255, cast to uint8 and
    resized to ``PRED_FORMAT_SIZE`` before logging.
    """
    frame_3hw = color_frame_b3hw.squeeze(0)
    rgb_3hw = reverse_imagenet_normalize(frame_3hw) if denormalize else frame_3hw
    rgb_hw3 = rgb_3hw.permute(1, 2, 0).cpu().detach().numpy()
    # PIL's resize takes (width, height); PRED_FORMAT_SIZE is indexed [1], [0].
    target_size = (PRED_FORMAT_SIZE[1], PRED_FORMAT_SIZE[0])
    pil_image = Image.fromarray(np.uint8(rgb_hw3 * 255)).resize(target_size)
    rr.log(f"{entity_path}/image/rgb", rr.Image(pil_image))
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def log_rerun(
    entity_path: str,
    cur_data: Dict[str, Any],
    src_data: Dict[str, Any],
    outputs: Dict[str, Any],
    scene_trimesh_mesh: trimesh.Trimesh,
    should_log_source_cams: bool = True,
) -> None:
    """
    Log the current camera, depth, normals, RGB, cost map and mesh to rerun.

    Source cameras are logged first when ``should_log_source_cams`` is set.
    The lowest-cost map goes to the fixed path ``"lowest_cost_volume"``;
    everything else is logged beneath ``entity_path``.
    """

    def _as_pil(image_3hw):
        # CHW float image -> HWC uint8 PIL image (values scaled by 255).
        array_hw3 = image_3hw.permute(1, 2, 0).cpu().detach().numpy()
        return Image.fromarray(np.uint8(array_hw3 * 255))

    # PIL's resize takes (width, height).
    display_size = (PRED_FORMAT_SIZE[1], PRED_FORMAT_SIZE[0])
    curr_entity_path = f"{entity_path}/current_cam"
    src_entity_path = f"{entity_path}/source_cam"

    if should_log_source_cams:
        log_source_data(src_entity_path, src_data)

    # Current camera pose / intrinsics
    pose_44 = cur_data["world_T_cam_b44"].squeeze().cpu().numpy()
    intrinsics_44 = cur_data["K_s0_b44"].squeeze().cpu().numpy()
    log_camera(curr_entity_path, pose_44, intrinsics_44)

    # Predicted depth
    depth_pred = outputs["depth_pred_s0_b1hw"]
    depth_hw1 = depth_pred.squeeze(0).permute(1, 2, 0)
    rr.log(
        f"{curr_entity_path}/image/depth",
        rr.DepthImage(depth_hw1.numpy(force=True)),
    )

    # Normals derived from the predicted depth, remapped from [-1, 1] to [0, 1]
    inv_intrinsics = cur_data["invK_s0_b44"].to(device)
    normals_b3hw = compute_normals(depth_pred, inv_intrinsics)
    normals_3hw = 0.5 * (1 + normals_b3hw).squeeze(0)
    rr.log(f"{curr_entity_path}/image/normal", rr.Image(_as_pil(normals_3hw)))

    # RGB: use the high-resolution frame when the batch provides one
    if "high_res_color_b3hw" in cur_data:
        color_b3hw = cur_data["high_res_color_b3hw"]
    else:
        color_b3hw = cur_data["image_b3hw"]
    color_3hw = reverse_imagenet_normalize(color_b3hw.squeeze(0))
    rr.log(
        f"{curr_entity_path}/image/rgb",
        rr.Image(_as_pil(color_3hw).resize(display_size)),
    )

    # Lowest cost guess from the cost volume
    cost_3hw = colormap_image(
        outputs["lowest_cost_bhw"],
        vmin=0,
        vmax=5,
    )
    rr.log("lowest_cost_volume", rr.Image(_as_pil(cost_3hw).resize(display_size)))

    # Fused mesh
    mesh = rr.Mesh3D(
        vertex_positions=scene_trimesh_mesh.vertices,
        triangle_indices=scene_trimesh_mesh.faces,
        vertex_colors=scene_trimesh_mesh.visual.vertex_colors,
    )
    rr.log(f"{entity_path}/mesh", mesh)
|
| 196 |
+
|
| 197 |
+
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/rerun_helper/generic_utils.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import pickle
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import kornia
|
| 7 |
+
import torch
|
| 8 |
+
import torchvision.transforms.functional as TF
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from torch import nn
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def copy_code_state(path):
    """Copy the current working directory into ``path`` using rsync.

    The destination directory is created if needed. ``.git`` is always
    excluded; when ``./.gitignore`` exists its patterns are excluded as
    well. rsync runs with ``-art`` (archive, recursive, preserve times).
    A missing rsync binary or a non-zero exit status is logged as a warning
    rather than raised.
    """
    import subprocess

    # create dir
    Path(path).mkdir(parents=True, exist_ok=True)

    # Build an argv list rather than a shell string so destinations with
    # spaces, quotes or shell metacharacters are passed through verbatim.
    rsync_command = ["rsync", "-art"]
    if os.path.exists("./.gitignore"):
        # use .gitignore to remove junk
        rsync_command.append("--exclude-from=./.gitignore")
    else:
        print("WARNING: no .gitignore found so can't use that to exlcude large "
              "files when making a back up of files in copy_code_state.")
    rsync_command += ["--exclude", ".git", ".", os.fspath(path)]

    try:
        completed = subprocess.run(rsync_command, check=False)
    except OSError as err:
        # e.g. rsync is not installed; the backup is best-effort.
        logger.warning("copy_code_state: could not run rsync: %s", err)
        return
    if completed.returncode != 0:
        logger.warning("copy_code_state: rsync exited with status %s", completed.returncode)
|
| 35 |
+
|
| 36 |
+
def readlines(filepath):
    """Return the lines of the text file at ``filepath``, without line endings."""
    with open(filepath, 'r') as handle:
        return handle.read().splitlines()
|
| 41 |
+
|
| 42 |
+
def normalize_depth_single(depth_11hw, mask_11hw, robust=False):
    """Shift and scale one depth map using statistics of its valid pixels.

    The valid values (those selected by ``mask_11hw``, or every value when
    the mask is None) are sorted and the lowest and highest tenth are
    dropped before computing the statistics. If that trimming leaves nothing
    (which is what happens with fewer than ten valid values, since the slice
    becomes ``[0:-0]``), all valid values are used instead.

    Args:
        depth_11hw: depth tensor to normalize.
        mask_11hw: boolean mask passed to ``masked_select``, or None.
        robust: when True use median / mean absolute deviation, otherwise
            mean / standard deviation.

    Returns:
        ``(depth_11hw - shift) / scale``, or ``depth_11hw`` unchanged when
        there are no valid values.
    """
    if mask_11hw is None:
        valid_vals_N = torch.flatten(depth_11hw)
    else:
        valid_vals_N = depth_11hw.masked_select(mask_11hw)

    num_valid = valid_vals_N.nelement()
    if num_valid == 0:
        return depth_11hw

    # number of values to drop from each end of the sorted list
    trim = num_valid // 10
    sorted_vals_N, _ = torch.sort(valid_vals_N)
    stats_vals_N = sorted_vals_N[trim:-trim]
    if stats_vals_N.nelement() == 0:
        stats_vals_N = valid_vals_N

    if robust:
        shift = stats_vals_N.median()
        scale = torch.mean(torch.abs(stats_vals_N - shift))
    else:
        shift = stats_vals_N.mean()
        scale = stats_vals_N.std()

    return (depth_11hw - shift) / scale
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def normalize_depth(depth_b1hw: torch.Tensor,
                    mask_b1hw: torch.Tensor = None,
                    robust: bool = False):
    """Normalize every depth map in a batch independently.

    The batch is split along dim 0 into single-element chunks, each chunk
    is passed to ``normalize_depth_single`` together with its mask chunk
    (or None when no mask is given), and the results are concatenated back
    along dim 0.
    """
    depth_chunks = torch.split(depth_b1hw, 1, 0)
    if mask_b1hw is None:
        mask_chunks = [None] * len(depth_chunks)
    else:
        mask_chunks = torch.split(mask_b1hw, 1, 0)

    normalized_chunks = []
    for depth_11hw, mask_11hw in zip(depth_chunks, mask_chunks):
        normalized_chunks.append(
            normalize_depth_single(depth_11hw, mask_11hw, robust))

    return torch.cat(normalized_chunks, dim=0)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def upsample(x):
    """Bilinearly upsample ``x`` by a factor of 2 (``align_corners=False``)."""
    upsampled = nn.functional.interpolate(
        x, scale_factor=2, mode="bilinear", align_corners=False)
    return upsampled
|
| 98 |
+
|
| 99 |
+
def batched_trace(mat_bNN):
    """Sum the diagonal taken over the last two dims of ``mat_bNN``."""
    diagonal_bN = torch.diagonal(mat_bNN, offset=0, dim1=-1, dim2=-2)
    return diagonal_bN.sum(-1)
|
| 101 |
+
|
| 102 |
+
def tensor_B_to_bM(tensor_BS, batch_size, num_views):
    """View a flattened tensor (BS) as (b, M, S).

    The leading dim is split into ``batch_size`` x ``num_views``; all
    remaining dims (S, any number of them) are kept as they are. Because
    ``view`` is used, no data is copied.
    """
    trailing_dims = list(tensor_BS.shape[1:])
    return tensor_BS.view([batch_size, num_views] + trailing_dims)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def tensor_bM_to_B(tensor_bMS):
    """View an inflated tensor (b, M, S) as a flattened one (B, S).

    The first two dims are merged into a single dim of size ``b * M``; all
    remaining dims (S, any number of them) are kept as they are.
    """
    num_views = tensor_bMS.shape[1]
    num_batches = tensor_bMS.shape[0]
    flat_shape = [num_views * num_batches, *tensor_bMS.shape[2:]]
    return tensor_bMS.view(flat_shape)
|
| 123 |
+
|
| 124 |
+
def combine_dims(x, dim_begin, dim_end):
    """View ``x`` with dims ``dim_begin`` up to (not including) ``dim_end``
    merged into a single dim."""
    leading_dims = list(x.shape[:dim_begin])
    trailing_dims = list(x.shape[dim_end:])
    return x.view(leading_dims + [-1] + trailing_dims)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def to_gpu(input_dict, key_ignores=None):
    """Move the values of ``input_dict`` to the GPU as float tensors.

    Every value whose key is not listed in ``key_ignores`` is replaced in
    place by ``value.cuda().float()``.

    Args:
        input_dict: dict to update in place.
        key_ignores: optional collection of keys whose values are left
            untouched. None (the default) means no key is ignored.

    Returns:
        The same ``input_dict`` object.
    """
    # None instead of a mutable [] default, which would be shared by all calls.
    if key_ignores is None:
        key_ignores = ()

    for key, value in input_dict.items():
        if key not in key_ignores:
            input_dict[key] = value.cuda().float()
    return input_dict
|
| 138 |
+
|
| 139 |
+
def imagenet_normalize(image):
    """Normalize an image tensor with the ImageNet per-channel mean and std."""
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    return TF.normalize(tensor=image, mean=imagenet_mean, std=imagenet_std)
|
| 144 |
+
|
| 145 |
+
def reverse_imagenet_normalize(image):
    """Undo ImageNet normalization on an image tensor.

    Implemented as another ``TF.normalize`` call whose constants are
    ``-mean / std`` and ``1 / std`` of the ImageNet statistics.
    """
    inverse_mean = (-2.11790393, -2.03571429, -1.80444444)
    inverse_std = (4.36681223, 4.46428571, 4.44444444)
    return TF.normalize(tensor=image, mean=inverse_mean, std=inverse_std)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def read_image_file(filepath,
                    height=None,
                    width=None,
                    value_scale_factor=1.0,
                    resampling_mode=Image.BILINEAR,
                    disable_warning=False,
                    target_aspect_ratio=None):
    """Reads an image file using PIL, then optionally crops and resizes the
    image, with selective resampling, scales values, and returns the image
    as a tensor.

    Args:
        filepath: path to the image.
        height, width: resolution to resize the image to. Both must not be
            None for scaling to take place.
        value_scale_factor: value to scale image values with, default is 1.0
        resampling_mode: resampling method when resizing using PIL. Default
            is PIL.Image.BILINEAR
        disable_warning: when True, suppresses the warning logged when the
            target size is larger than the input in either dimension.
        target_aspect_ratio: if not None, will crop the image to match this
            aspect ratio. Default is None

    Returns:
        img: tensor with (optionally) scaled and resized image data.

    """
    img = Image.open(filepath)

    if target_aspect_ratio:
        # crop_image_to_target_ratio returns the cropped image instead of
        # modifying its argument, so the result has to be kept; discarding
        # it silently disables the crop.
        img = crop_image_to_target_ratio(img, target_aspect_ratio)

    # resize if both width and height are not none.
    if height is not None and width is not None:
        img_width, img_height = img.size
        # do we really need to resize? If not, skip.
        if (img_width, img_height) != (width, height):
            # warn if it doesn't make sense.
            if ((width > img_width or height > img_height) and
                    not disable_warning):
                logger.warning(
                    f"WARNING: target size ({width}, {height}) has a "
                    f"dimension larger than input size ({img_width}, "
                    f"{img_height}).")
            img = img.resize((width, height), resample=resampling_mode)

    img = TF.to_tensor(img).float() * value_scale_factor

    return img
|
| 201 |
+
|
| 202 |
+
def crop_image_to_target_ratio(image, target_aspect_ratio=4.0/3.0):
    """Center-crop ``image`` so that width / height matches a target ratio.

    A too-wide image loses columns symmetrically on the left and right; a
    too-tall image loses rows symmetrically at the top and bottom. An image
    that already has the target ratio is returned as is.

    Returns:
        The result of ``image.crop(...)``, or ``image`` itself when no crop
        is needed.
    """
    current_ratio = image.width/image.height

    if current_ratio > target_aspect_ratio:
        # too wide: keep full height, trim the width around the center
        kept_width = image.height * target_aspect_ratio
        crop_box = (
            (image.width - kept_width)/2,
            0,
            (image.width + kept_width)/2,
            image.height,
        )
    elif current_ratio < target_aspect_ratio:
        # too tall: keep full width, trim the height around the center
        kept_height = image.width/target_aspect_ratio
        crop_box = (
            0,
            (image.height - kept_height)/2,
            image.width,
            (image.height + kept_height)/2,
        )
    else:
        return image

    return image.crop(crop_box)
|
| 232 |
+
|
| 233 |
+
def cache_model_outputs(
    output_path,
    outputs,
    cur_data,
    src_data,
    batch_ind,
    batch_size,
):
    """Pickle the model outputs of a batch, one file per batch element.

    For each element a dict is written to ``<output_path>/<frame_id>.pickle``
    containing that element's slice of every entry in ``outputs`` (kept with
    a leading dim of 1, or None where the output is None), the matching
    ``K_full_depth_b44`` and ``K_s0_b44`` slices from ``cur_data``, the
    frame id, and the source frame ids from ``src_data``.

    Args:
        output_path: directory the pickle files are written to.
        outputs: dict of batched outputs; must contain
            ``"depth_pred_s0_b1hw"``, whose first dim sets the element count.
        cur_data: dict of batched data for the current frames.
        src_data: dict of batched data for the source frames; its
            ``"frame_id_string"`` entry is iterated as a list of per-view
            id lists.
        batch_ind: index of the batch, used for fallback frame ids.
        batch_size: batch size, used for fallback frame ids.
    """
    for elem_ind in range(outputs["depth_pred_s0_b1hw"].shape[0]):
        if "frame_id_string" in cur_data:
            frame_id = cur_data["frame_id_string"][elem_ind]
        else:
            # The "d" format code only accepts integers, so format the int
            # index directly (formatting str(index) raises ValueError).
            # Zero padding keeps the file names free of spaces.
            frame_index = (batch_ind * batch_size) + elem_ind
            frame_id = f"{frame_index:06d}"

        elem_filepath = os.path.join(output_path, f"{frame_id}.pickle")

        elem_output_dict = {
            key: None if value is None else value[elem_ind].unsqueeze(0)
            for key, value in outputs.items()
        }

        # include some auxiliary information
        elem_output_dict["K_full_depth_b44"] = (
            cur_data["K_full_depth_b44"][elem_ind].unsqueeze(0))
        elem_output_dict["K_s0_b44"] = (
            cur_data["K_s0_b44"][elem_ind].unsqueeze(0))

        # Reuse the id computed above so the fallback path does not fail on
        # a missing "frame_id_string" key.
        elem_output_dict["frame_id"] = frame_id
        elem_output_dict["src_ids"] = [
            src_id_list[elem_ind] for src_id_list in src_data["frame_id_string"]
        ]

        with open(elem_filepath, 'wb') as handle:
            pickle.dump(elem_output_dict, handle)
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/rerun_helper/geometry_utils.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import kornia
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.jit as jit
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@torch.jit.script
def to_homogeneous(input_tensor: Tensor, dim: int = 0) -> Tensor:
    """
    Appends a slice of ones to input_tensor along dim, which turns
    coordinates stored along that dimension into homogeneous coordinates.
    """
    # a size-1 slice along dim gives the shape, dtype and device of the ones
    slice_template = input_tensor.select(dim, 0).unsqueeze(dim)
    ones = torch.ones_like(slice_template)
    return torch.cat([input_tensor, ones], dim=dim)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class BackprojectDepth(jit.ScriptModule):
    """
    Layer that lifts the pixels of a depth map to 3D points, returned in
    homogeneous coordinates.
    """

    def __init__(self, height: int, width: int):
        super().__init__()

        self.height = height
        self.width = width

        cols, rows = torch.meshgrid(
            torch.arange(self.width),
            torch.arange(self.height),
            indexing='xy',
        )
        # the +0.5 offset places each coordinate at the pixel center
        pixel_centers_2hw = torch.stack((cols, rows), axis=0) + 0.5

        homogeneous_3hw = to_homogeneous(pixel_centers_2hw, dim=0)
        pix_coords_13N = homogeneous_3hw.flatten(1).unsqueeze(0)

        # registered as a buffer so it follows the module across devices
        self.register_buffer("pix_coords_13N", pix_coords_13N)

    @jit.script_method
    def forward(self, depth_b1hw: Tensor, invK_b44: Tensor) -> Tensor:
        """
        Multiplies the homogeneous pixel grid by the top-left 3x3 of
        invK_b44, scales the result by the depths in depth_b1hw and returns
        the points with a row of ones appended (b, 4, N). Only the inverse
        intrinsics are applied here; no camera pose is involved.
        """
        rays_b3N = torch.matmul(invK_b44[:, :3, :3], self.pix_coords_13N)
        depth_b1N = depth_b1hw.flatten(start_dim=2)
        cam_points_b3N = depth_b1N * rays_b3N
        return to_homogeneous(cam_points_b3N, dim=1)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class Project3D(jit.ScriptModule):
    """
    Layer that projects homogeneous 3D points into a 2D camera.
    """
    def __init__(self, eps: float = 1e-8):
        super().__init__()

        eps_111 = torch.tensor(eps).view(1, 1, 1)
        self.register_buffer("eps", eps_111)

    @jit.script_method
    def forward(self, points_b4N: Tensor,
                K_b44: Tensor, cam_T_world_b44: Tensor) -> Tensor:
        """
        Projects points_b4N with the matrix K_b44 @ cam_T_world_b44 and
        returns a (b, 3, N) tensor holding the two pixel coordinates
        followed by the depth (with eps added).

        Points whose absolute depth is not larger than eps are not divided
        by their depth; a scale of 1.0 is used for them instead.
        """
        P_b44 = torch.matmul(K_b44, cam_T_world_b44)
        cam_points_b3N = torch.matmul(P_b44[:, :3], points_b4N)

        # same guard as Kornia / OpenCV convert_points_from_homogeneous,
        # https://kornia.readthedocs.io/en/latest/_modules/kornia/geometry/conversions.html#convert_points_from_homogeneous
        z_b1N = cam_points_b3N[:, 2:]
        mask = torch.abs(z_b1N) > self.eps
        depth_b1N = (z_b1N + self.eps)
        scale = torch.where(mask, 1.0 / depth_b1N, torch.tensor(1.0, device=depth_b1N.device))

        pix_coords_b2N = cam_points_b3N[:, :2] * scale

        return torch.cat([pix_coords_b2N, depth_b1N], dim=1)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class NormalGenerator(jit.ScriptModule):
    def __init__(self, height: int, width: int,
                 smoothing_kernel_size: int=5, smoothing_kernel_std: float=2.0):
        """
        Estimates normals from depth maps of a fixed height and width.
        """
        super().__init__()
        self.height = height
        self.width = width

        self.backproject = BackprojectDepth(self.height, self.width)

        self.kernel_size = smoothing_kernel_size
        self.std = smoothing_kernel_std

    def forward(self, depth_b1hw: Tensor, invK_b44: Tensor) -> Tensor:
        """
        Smooths the incoming depth maps with a gaussian blur, backprojects
        them to 3D points (see BackprojectDepth), takes the spatial gradient
        of the resulting point map, and returns the normalized cross product
        of the two gradient components as the normal at each location.
        """
        blur_kernel = (self.kernel_size, self.kernel_size)
        blur_sigma = (self.std, self.std)
        depth_smooth_b1hw = kornia.filters.gaussian_blur2d(
            depth_b1hw, blur_kernel, blur_sigma)

        cam_points_b4N = self.backproject(depth_smooth_b1hw, invK_b44)
        # drop the homogeneous row and restore the image layout
        cam_points_b3hw = cam_points_b4N[:, :3].view(-1, 3, self.height, self.width)

        gradients_b32hw = kornia.filters.spatial_gradient(cam_points_b3hw)
        grad_first_b3hw = gradients_b32hw[:, :, 0]
        grad_second_b3hw = gradients_b32hw[:, :, 1]

        normals_b3hw = torch.cross(grad_first_b3hw, grad_second_b3hw, dim=1)
        return F.normalize(normals_b3hw, dim=1)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def get_camera_rays(
    world_T_cam_b44,
    world_points_b3N,
    in_camera_frame,
    cam_T_world_b44=None,
    eps=1e-4,
):
    """
    Computes unit-length rays towards the given points.

    With in_camera_frame False the rays are the points minus the
    translation column of world_T_cam_b44. With in_camera_frame True the
    points are first transformed by cam_T_world_b44 (which must then be
    provided) and the transformed points themselves are used as rays.
    In both cases the rays are normalized along dim 1.

    The eps argument is accepted but not used by this function.
    """
    if not in_camera_frame:
        cam_center_b31 = world_T_cam_b44[:, 0:3, 3][:, :, None]
        rays_b3N = world_points_b3N - cam_center_b31.expand(
            world_points_b3N.shape
        )
    else:
        batch_size = world_points_b3N.shape[0]
        num_points = world_points_b3N.shape[2]
        ones_b1N = torch.ones(batch_size, 1, num_points).to(
            world_points_b3N.device)
        world_points_b4N = torch.cat([world_points_b3N, ones_b1N], 1)
        rays_b3N = torch.matmul(cam_T_world_b44[:, :3, :4],
                                world_points_b4N)

    return torch.nn.functional.normalize(rays_b3N, dim=1)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def pose_distance(pose_b44):
    """
    DVMVS frame pose distance.

    Returns a tuple (combined, rotation, translation) of per-pose measures:
    the rotation measure is sqrt(2 * (1 - min(3, trace(R)) / 3)), the
    translation measure is the norm of t, and the combined measure is the
    square root of the sum of their squares.
    """
    rotation_b33 = pose_b44[:, :3, :3]
    translation_b3 = pose_b44[:, :3, 3]

    trace_b = torch.diagonal(rotation_b33, offset=0, dim1=-1, dim2=-2).sum(-1)
    # cap the trace at 3 so the value under the square root cannot go negative
    capped_trace_b = torch.minimum(torch.ones_like(trace_b)*3.0, trace_b)
    R_measure = torch.sqrt(2 * (1 - capped_trace_b / 3))

    t_measure = torch.norm(translation_b3, dim=1)
    combined_measure = torch.sqrt(t_measure ** 2 + R_measure ** 2)

    return combined_measure, R_measure, t_measure
|
| 183 |
+
|
| 184 |
+
def qvec2rotmat(qvec):
    """
    Quaternion to 3x3 rotation matrix. The scalar part is read from
    qvec[0] and the vector part from qvec[1:4].
    """
    w, x, y, z = qvec[0], qvec[1], qvec[2], qvec[3]

    row_0 = [
        1 - 2 * y**2 - 2 * z**2,
        2 * x * y - 2 * w * z,
        2 * z * x + 2 * w * y,
    ]
    row_1 = [
        2 * x * y + 2 * w * z,
        1 - 2 * x**2 - 2 * z**2,
        2 * y * z - 2 * w * x,
    ]
    row_2 = [
        2 * z * x - 2 * w * y,
        2 * y * z + 2 * w * x,
        1 - 2 * x**2 - 2 * y**2,
    ]
    return np.array([row_0, row_1, row_2])
|
| 203 |
+
|
| 204 |
+
def rotx(t):
    """
    3x3 matrix for a rotation by angle t (radians) about the x-axis.
    """
    cos_t, sin_t = np.cos(t), np.sin(t)
    rows = [
        [1, 0, 0],
        [0, cos_t, -sin_t],
        [0, sin_t, cos_t],
    ]
    return np.array(rows)
|
| 213 |
+
|
| 214 |
+
def roty(t):
    """
    3x3 matrix for a rotation by angle t (radians) about the y-axis.
    """
    cos_t, sin_t = np.cos(t), np.sin(t)
    rows = [
        [cos_t, 0, sin_t],
        [0, 1, 0],
        [-sin_t, 0, cos_t],
    ]
    return np.array(rows)
|
| 223 |
+
|
| 224 |
+
def rotz(t):
    """
    3x3 matrix for a rotation by angle t (radians) about the z-axis.
    """
    cos_t, sin_t = np.cos(t), np.sin(t)
    rows = [
        [cos_t, -sin_t, 0],
        [sin_t, cos_t, 0],
        [0, 0, 1],
    ]
    return np.array(rows)
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/rerun_helper/tmp.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import rerun as rr
|
| 3 |
+
|
| 4 |
+
# Initialize Rerun and spawn the viewer.
rr.init("Multi-Camera Pose Example",spawn=True)

# Example data: a 4x4 pose and a 3x3 intrinsics matrix for two cameras.
# Camera 1 pose and intrinsics.
pose1 = np.array([
    [0.99, -0.10, 0.10, 1.0],
    [0.10, 0.99, -0.10, 2.0],
    [-0.10, 0.10, 0.99, 3.0],
    [0.0, 0.0, 0.0, 1.0],
])
intrinsic1 = np.array([
    [500, 0, 320],
    [0, 500, 240],
    [0, 0, 1],
])

# Camera 2 pose and intrinsics.
pose2 = np.array([
    [0.99, 0.10, -0.10, -1.0],
    [-0.10, 0.99, 0.10, -2.0],
    [0.10, -0.10, 0.99, -3.0],
    [0.0, 0.0, 0.0, 1.0],
])
intrinsic2 = np.array([
    [500, 0, 320],
    [0, 500, 240],
    [0, 0, 1],
])

# Log camera 1, then camera 2.
# NOTE(review): confirm that the installed rerun version provides
# rr.log_camera with pose= / intrinsic= keyword arguments.
for name, pose, intrinsic in (
    ("camera1", pose1, intrinsic1),
    ("camera2", pose2, intrinsic2),
):
    rr.log_camera(name, pose=pose, intrinsic=intrinsic)
|
| 39 |
+
|
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/rerun_helper/visualization_utils.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import moviepy.editor as mpy
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
from .generic_utils import reverse_imagenet_normalize
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def colormap_image(
    image_1hw,
    mask_1hw=None,
    invalid_color=(0.0, 0, 0.0),
    flip=True,
    vmin=None,
    vmax=None,
    return_vminvmax=False,
    colormap="turbo",
):
    """
    Colormaps a one channel tensor using a matplotlib colormap.

    Args:
        image_1hw: the tensor to colormap.
        mask_1hw: an optional float mask where 1.0 denotes valid pixels.
        colormap: the colormap to use. Default is turbo.
        invalid_color: the color to use for invalid pixels.
        flip: should we flip the colormap? True by default.
        vmin: if provided uses this as the minimum when normalizing the tensor.
        vmax: if provided uses this as the maximum when normalizing the tensor.
            When either of vmin or vmax are None, they are computed from the
            valid values of the tensor.
        return_vminvmax: when true, returns vmin and vmax.

    Returns:
        image_cm_3hw: image of the colormapped tensor.
        vmin, vmax: returned when return_vminvmax is true.

    """
    valid_vals = image_1hw if mask_1hw is None else image_1hw[mask_1hw.bool()]
    if vmin is None:
        vmin = valid_vals.min()
    if vmax is None:
        vmax = valid_vals.max()

    # plt.get_cmap is used instead of plt.cm.get_cmap, which newer
    # matplotlib releases no longer provide.
    # Sample the colormap at 256 positions and keep RGB only (drop alpha).
    cmap = torch.Tensor(
        plt.get_cmap(colormap)(
            torch.linspace(0, 1, 256)
        )[:, :3]
    ).to(image_1hw.device)
    if flip:
        cmap = torch.flip(cmap, (0,))

    h, w = image_1hw.shape[1:]

    # map values to integer indices into the 256-entry color table
    image_norm_1hw = (image_1hw - vmin) / (vmax - vmin)
    image_int_1hw = (torch.clamp(image_norm_1hw * 255, 0, 255)).byte().long()

    image_cm_3hw = cmap[image_int_1hw.flatten(start_dim=1)
                        ].permute([0, 2, 1]).view([-1, h, w])

    if mask_1hw is not None:
        # blend: valid pixels keep their color, the rest get invalid_color
        invalid_color = torch.Tensor(invalid_color).view(3, 1, 1).to(image_1hw.device)
        image_cm_3hw = image_cm_3hw * mask_1hw + invalid_color * (1 - mask_1hw)

    if return_vminvmax:
        return image_cm_3hw, vmin, vmax
    else:
        return image_cm_3hw
|
| 73 |
+
|
| 74 |
+
def save_viz_video_frames(frame_list, path, fps=30):
    """
    Writes the frames in frame_list to a video file at path, played back at
    fps frames per second. moviepy's own progress output is turned off.
    """
    video_clip = mpy.ImageSequenceClip(frame_list, fps=fps)
    video_clip.write_videofile(path, verbose=False, logger=None)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _save_float_image_3hw(image_3hw, filepath):
    """Saves a (3, h, w) float tensor as an image after scaling by 255 and
    casting to uint8."""
    image_hw3 = image_3hw.permute(1, 2, 0).cpu().detach().numpy() * 255
    Image.fromarray(np.uint8(image_hw3)).save(filepath)


def quick_viz_export(
        output_path,
        outputs,
        cur_data,
        batch_ind,
        valid_mask_b,
        batch_size):
    """ Helper function for quickly exporting depth maps during inference.

    For every element in the batch this writes, into output_path:
        <frame_id>_gt_depth.png: colormapped ground truth depth, only when
            the element's valid depths are not all the same value.
        <frame_id>_lowest_cost_pred.png: colormapped lowest cost output.
        <frame_id>_pred_depth.png: colormapped predicted depth.
        <frame_id>_color.png: the color image with ImageNet normalization
            reversed.

    All depth images share one colormap range: the min/max of the valid
    ground truth depths of the whole batch, or 0.0 to 5.0 when there are no
    valid depths or they are all equal.

    Args:
        output_path: directory the images are written to.
        outputs: dict with "depth_pred_s0_b1hw" and "lowest_cost_bhw".
        cur_data: dict with "full_res_depth_b1hw", "high_res_color_b3hw"
            and optionally "frame_id_string".
        batch_ind: index of the batch, used for fallback frame ids.
        valid_mask_b: mask used to index the valid ground truth depths.
        batch_size: batch size, used for fallback frame ids.
    """

    if valid_mask_b.sum() == 0:
        batch_vmin = 0.0
        batch_vmax = 5.0
    else:
        batch_vmin = cur_data["full_res_depth_b1hw"][valid_mask_b].min()
        batch_vmax = cur_data["full_res_depth_b1hw"][valid_mask_b].max()

    if batch_vmax == batch_vmin:
        batch_vmin = 0.0
        batch_vmax = 5.0

    for elem_ind in range(outputs["depth_pred_s0_b1hw"].shape[0]):
        if "frame_id_string" in cur_data:
            frame_id = cur_data["frame_id_string"][elem_ind]
        else:
            # The "d" format code only accepts integers, so format the int
            # index directly (formatting str(index) raises ValueError).
            # Zero padding keeps the file names free of spaces.
            frame_index = (batch_ind * batch_size) + elem_ind
            frame_id = f"{frame_index:06d}"

        # check for valid depths from dataloader
        if valid_mask_b[elem_ind].sum() == 0:
            sample_vmin = 0.0
            sample_vmax = 0.0
        else:
            # these will be the same when the depth map is all ones.
            valid_depths = cur_data["full_res_depth_b1hw"][elem_ind][
                valid_mask_b[elem_ind]]
            sample_vmin = valid_depths.min()
            sample_vmax = valid_depths.max()

        # if no meaningful gt depth in dataloader, don't viz gt and
        # set vmin/max to default
        if sample_vmax != sample_vmin:
            full_res_depth_3hw = colormap_image(
                cur_data["full_res_depth_b1hw"][elem_ind],
                vmin=batch_vmin, vmax=batch_vmax
            )
            _save_float_image_3hw(
                full_res_depth_3hw,
                os.path.join(output_path, f"{frame_id}_gt_depth.png"),
            )

        lowest_cost_3hw = colormap_image(
            outputs["lowest_cost_bhw"][elem_ind].unsqueeze(0),
            vmin=batch_vmin, vmax=batch_vmax
        )
        _save_float_image_3hw(
            lowest_cost_3hw,
            os.path.join(output_path, f"{frame_id}_lowest_cost_pred.png"),
        )

        depth_3hw = colormap_image(
            outputs["depth_pred_s0_b1hw"][elem_ind],
            vmin=batch_vmin, vmax=batch_vmax)
        _save_float_image_3hw(
            depth_3hw,
            os.path.join(output_path, f"{frame_id}_pred_depth.png"),
        )

        main_color_3hw = cur_data["high_res_color_b3hw"][elem_ind]
        main_color_3hw = reverse_imagenet_normalize(main_color_3hw)
        _save_float_image_3hw(
            main_color_3hw,
            os.path.join(output_path, f"{frame_id}_color.png"),
        )
|