larryshaw0079 commited on
Commit
873de4e
·
verified ·
1 Parent(s): 4543b45

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.
Files changed (50) hide show
  1. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/.hydra/config.yaml +125 -0
  2. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/.hydra/hydra.yaml +186 -0
  3. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/.hydra/overrides.yaml +31 -0
  4. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/checkpoint-1.pth +3 -0
  5. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/checkpoint-2.pth +3 -0
  6. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/checkpoint-final.pth +3 -0
  7. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/checkpoint-last.pth +3 -0
  8. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/CHANGELOG.md +19 -0
  9. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/README.md +373 -0
  10. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/README_submap.md +225 -0
  11. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/base_opt.py +301 -0
  12. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/commons.py +102 -0
  13. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/dust3r_opt/__init__.py +31 -0
  14. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/dust3r_opt/base_opt.py +620 -0
  15. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/dust3r_opt/commons.py +102 -0
  16. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/dust3r_opt/init_im_poses.py +378 -0
  17. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/dust3r_opt/optimizer.py +301 -0
  18. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/init_all.py +222 -0
  19. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/utils.py +443 -0
  20. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/deepspeed_zero3_bf16.json +19 -0
  21. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/finetune.yaml +102 -0
  22. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/finetune_paper_h20.yaml +129 -0
  23. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/finetune_pseudo_gt_high_recall.yaml +129 -0
  24. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/finetune_sub_only.yaml +129 -0
  25. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/mytrain.yaml +92 -0
  26. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/environment.yml +245 -0
  27. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/eval_ate_scaled.py +54 -0
  28. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/get_ate.py +74 -0
  29. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/publish_submap.sh +138 -0
  30. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/requirements.txt +30 -0
  31. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/run_tum.sh +77 -0
  32. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/run_tum_top5.sh +66 -0
  33. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/setup.py +8 -0
  34. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/setup_env.sh +18 -0
  35. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/__init__.py +0 -0
  36. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/audit_dataset_num_views.py +412 -0
  37. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/batched_dynamic_router.py +243 -0
  38. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/demo.py +540 -0
  39. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/demo_infinite.py +493 -0
  40. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/demo_submap.py +927 -0
  41. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/download_data.sh +354 -0
  42. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/exp_joint_freeze_frontend_fsdp_8gpu.sh +438 -0
  43. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/graph_gated_memory.py +850 -0
  44. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/mine_pseudo_gt.py +588 -0
  45. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/pseudo_gt.py +348 -0
  46. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/rerun_helper/__init__.py +197 -0
  47. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/rerun_helper/generic_utils.py +274 -0
  48. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/rerun_helper/geometry_utils.py +232 -0
  49. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/rerun_helper/tmp.py +39 -0
  50. checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/rerun_helper/visualization_utils.py +167 -0
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/.hydra/config.yaml ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accum_iter: 1
2
+ allow_repeat: false
3
+ amp: 1
4
+ batch_size: 1
5
+ benchmark: false
6
+ custom_lr_scale: 1.0
7
+ data_root: /home/23068142r/work_dir/data
8
+ root_arkit: /home/23068142r/work_dir/data/processed_arkitscenes
9
+ root_scannetpp: /home/23068142r/work_dir/data/preprocessed_scannetpp
10
+ root_scannet: /home/23068142r/work_dir/data/processed_scannet
11
+ root_hypersim: /home/23068142r/work_dir/data/preprocessed_Hypersim
12
+ root_blendedmvs: /home/23068142r/work_dir/data/processed_blendedmvs
13
+ root_megadepth: /home/23068142r/work_dir/data/processed_megadepth
14
+ root_mvs_synth: /home/23068142r/work_dir/data/processed_mvs_synth
15
+ dataset_arkit: ARKitScenes_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_arkit}",
16
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
17
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views_arkit}, n_corres=${n_corres_train})
18
+ dataset_scannetpp: ScanNetpp_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_scannetpp}",
19
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
20
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views_scannetpp}, n_corres=${n_corres_train})
21
+ dataset_scannet: ScanNet_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_scannet}",
22
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
23
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views_scannet}, n_corres=${n_corres_train})
24
+ dataset_hypersim: HyperSim_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_hypersim}",
25
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
26
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views_hypersim}, n_corres=${n_corres_train})
27
+ dataset_blendedmvs: BlendedMVS_Multi(allow_repeat=${allow_repeat}, split='train',
28
+ ROOT="${root_blendedmvs}", aug_crop=16, resolution=[(518, 392), (518, 336), (518,
29
+ 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=${num_views_blendedmvs},
30
+ n_corres=${n_corres_train})
31
+ dataset_megadepth: MegaDepth_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_megadepth}",
32
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
33
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views_megadepth}, n_corres=${n_corres_train})
34
+ dataset_mvs_synth: MVS_Synth_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_mvs_synth}",
35
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
36
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views_mvs_synth}, n_corres=${n_corres_train})
37
+ desc_dim: 128
38
+ detach_frontend_tokens: true
39
+ dist_backend: nccl
40
+ dist_url: env://
41
+ distributed: false
42
+ enable_dynamic_boundary: false
43
+ enable_loop: true
44
+ enable_submap: true
45
+ enable_temporal: false
46
+ epochs: 2
47
+ eval_freq: 1
48
+ exp_name: paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12
49
+ fixed_length: true
50
+ freeze_encoder: true
51
+ gpu: 0
52
+ gradient_checkpointing: true
53
+ gumbel_tau: 5.0
54
+ loop_mask_mode: soft_all
55
+ retain_history_grad: true
56
+ submap_train_mode: full_token
57
+ submap_retrieval_topk: 0
58
+ submap_fetch_source: frontend
59
+ submap_descriptor_source: frontend
60
+ train_submap_modules_only: false
61
+ gumbel_tau_end: 0.1
62
+ gumbel_tau_start: 5.0
63
+ keep_freq: 1
64
+ load_only_encoder: false
65
+ local-rank: -1
66
+ logdir: ${save_dir}/${exp_name}/logs
67
+ long_context: false
68
+ lr: 1.0e-05
69
+ max_checkpoints: 10
70
+ max_recursive_submaps: 5
71
+ min_lr: 1.0e-08
72
+ n_corres_test: 0
73
+ n_corres_train: 0
74
+ num_imgs_vis: 4
75
+ num_test_views: 4
76
+ num_views: 24
77
+ num_views_arkit: 64
78
+ num_views_scannetpp: 24
79
+ num_views_scannet: 64
80
+ num_views_hypersim: 24
81
+ num_views_blendedmvs: 64
82
+ num_views_megadepth: 64
83
+ num_views_mvs_synth: 24
84
+ num_workers: 4
85
+ output_dir: ${save_dir}/${exp_name}/
86
+ pretrained: /home/23068142r/work_dir/projects/e2e-semantic-SLAM-submap/ckpt/checkpoint-10.pth.model
87
+ print_freq: 10
88
+ print_img_freq: 50000000
89
+ rank: 0
90
+ resume: null
91
+ retention_ratio: 0.5
92
+ pseudo_gt:
93
+ enable: false
94
+ cache_path: null
95
+ use_soft_targets: true
96
+ min_confidence: 0.65
97
+ min_support_pairs: 1
98
+ topk_pairs: 4
99
+ loss_type: hybrid
100
+ loss_weight_gate: 0.1
101
+ loss_weight_desc: 0.1
102
+ geometric_support_scale: 0.25
103
+ ranking_margin: 0.1
104
+ use_l2m: false
105
+ l2m_min_certainty: 0.0
106
+ l2m_min_inlier_ratio: 0.0
107
+ save_dir: /home/23068142r/work_dir/projects/e2e-semantic-SLAM/checkpoints/paper_smoke_local_8gpu/joint_freeze_frontend_fsdp_sub12
108
+ save_freq: 0.1
109
+ seed: 42
110
+ soft_mask_bias: 0.2
111
+ soft_mask_temperature: 0.25
112
+ start_epoch: 0
113
+ start_step: 0
114
+ submap_size: 6
115
+ task: SLAMFormer_Submap_Finetune
116
+ tbptt_window: 0
117
+ teacher: null
118
+ temporal_embed_mode: learned
119
+ test_criterion: DistillLoss()
120
+ test_dataset: ''
121
+ train_criterion: DistillLoss()
122
+ train_dataset: 16 @ ${dataset_scannetpp} + 16 @ ${dataset_hypersim} + 16 @ ${dataset_mvs_synth}
123
+ warmup_epochs: 0.5
124
+ weight_decay: 0.05
125
+ world_size: 1
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/.hydra/hydra.yaml ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ hydra:
2
+ run:
3
+ dir: ${save_dir}/${exp_name}
4
+ sweep:
5
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
6
+ subdir: ${hydra.job.num}
7
+ launcher:
8
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
9
+ sweeper:
10
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
11
+ max_batch_size: null
12
+ params: null
13
+ help:
14
+ app_name: ${hydra.job.name}
15
+ header: '${hydra.help.app_name} is powered by Hydra.
16
+
17
+ '
18
+ footer: 'Powered by Hydra (https://hydra.cc)
19
+
20
+ Use --hydra-help to view Hydra specific help
21
+
22
+ '
23
+ template: '${hydra.help.header}
24
+
25
+ == Configuration groups ==
26
+
27
+ Compose your configuration from those groups (group=option)
28
+
29
+
30
+ $APP_CONFIG_GROUPS
31
+
32
+
33
+ == Config ==
34
+
35
+ Override anything in the config (foo.bar=value)
36
+
37
+
38
+ $CONFIG
39
+
40
+
41
+ ${hydra.help.footer}
42
+
43
+ '
44
+ hydra_help:
45
+ template: 'Hydra (${hydra.runtime.version})
46
+
47
+ See https://hydra.cc for more info.
48
+
49
+
50
+ == Flags ==
51
+
52
+ $FLAGS_HELP
53
+
54
+
55
+ == Configuration groups ==
56
+
57
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
58
+ to command line)
59
+
60
+
61
+ $HYDRA_CONFIG_GROUPS
62
+
63
+
64
+ Use ''--cfg hydra'' to Show the Hydra config.
65
+
66
+ '
67
+ hydra_help: ???
68
+ hydra_logging:
69
+ version: 1
70
+ formatters:
71
+ simple:
72
+ format: '[%(asctime)s][HYDRA] %(message)s'
73
+ handlers:
74
+ console:
75
+ class: logging.StreamHandler
76
+ formatter: simple
77
+ stream: ext://sys.stdout
78
+ root:
79
+ level: INFO
80
+ handlers:
81
+ - console
82
+ loggers:
83
+ logging_example:
84
+ level: DEBUG
85
+ disable_existing_loggers: false
86
+ job_logging:
87
+ version: 1
88
+ formatters:
89
+ simple:
90
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
91
+ handlers:
92
+ console:
93
+ class: logging.StreamHandler
94
+ formatter: simple
95
+ stream: ext://sys.stdout
96
+ file:
97
+ class: logging.FileHandler
98
+ formatter: simple
99
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
100
+ root:
101
+ level: INFO
102
+ handlers:
103
+ - console
104
+ - file
105
+ disable_existing_loggers: false
106
+ env: {}
107
+ mode: RUN
108
+ searchpath: []
109
+ callbacks: {}
110
+ output_subdir: .hydra
111
+ overrides:
112
+ hydra:
113
+ - hydra.mode=RUN
114
+ task:
115
+ - exp_name=paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12
116
+ - save_dir=/home/23068142r/work_dir/projects/e2e-semantic-SLAM/checkpoints/paper_smoke_local_8gpu/joint_freeze_frontend_fsdp_sub12
117
+ - pretrained=/home/23068142r/work_dir/projects/e2e-semantic-SLAM-submap/ckpt/checkpoint-10.pth.model
118
+ - resume=null
119
+ - data_root=/home/23068142r/work_dir/data
120
+ - root_arkit=/home/23068142r/work_dir/data/processed_arkitscenes
121
+ - root_scannetpp=/home/23068142r/work_dir/data/preprocessed_scannetpp
122
+ - root_scannet=/home/23068142r/work_dir/data/processed_scannet
123
+ - root_hypersim=/home/23068142r/work_dir/data/preprocessed_Hypersim
124
+ - root_blendedmvs=/home/23068142r/work_dir/data/processed_blendedmvs
125
+ - root_megadepth=/home/23068142r/work_dir/data/processed_megadepth
126
+ - root_mvs_synth=/home/23068142r/work_dir/data/processed_mvs_synth
127
+ - num_views=24
128
+ - num_views_arkit=64
129
+ - num_views_scannetpp=24
130
+ - num_views_scannet=64
131
+ - num_views_hypersim=24
132
+ - num_views_blendedmvs=64
133
+ - num_views_megadepth=64
134
+ - num_views_mvs_synth=24
135
+ - train_submap_modules_only=false
136
+ - detach_frontend_tokens=true
137
+ - submap_train_mode=full_token
138
+ - submap_retrieval_topk=0
139
+ - submap_fetch_source=frontend
140
+ - submap_descriptor_source=frontend
141
+ - pseudo_gt.enable=false
142
+ - pseudo_gt.cache_path=null
143
+ - train_dataset=16 @ ${dataset_scannetpp} + 16 @ ${dataset_hypersim} + 16 @ ${dataset_mvs_synth}
144
+ - epochs=2
145
+ - test_dataset=
146
+ job:
147
+ name: finetune
148
+ chdir: null
149
+ override_dirname: data_root=/home/23068142r/work_dir/data,detach_frontend_tokens=true,epochs=2,exp_name=paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12,num_views=24,num_views_arkit=64,num_views_blendedmvs=64,num_views_hypersim=24,num_views_megadepth=64,num_views_mvs_synth=24,num_views_scannet=64,num_views_scannetpp=24,pretrained=/home/23068142r/work_dir/projects/e2e-semantic-SLAM-submap/ckpt/checkpoint-10.pth.model,pseudo_gt.cache_path=null,pseudo_gt.enable=false,resume=null,root_arkit=/home/23068142r/work_dir/data/processed_arkitscenes,root_blendedmvs=/home/23068142r/work_dir/data/processed_blendedmvs,root_hypersim=/home/23068142r/work_dir/data/preprocessed_Hypersim,root_megadepth=/home/23068142r/work_dir/data/processed_megadepth,root_mvs_synth=/home/23068142r/work_dir/data/processed_mvs_synth,root_scannet=/home/23068142r/work_dir/data/processed_scannet,root_scannetpp=/home/23068142r/work_dir/data/preprocessed_scannetpp,save_dir=/home/23068142r/work_dir/projects/e2e-semantic-SLAM/checkpoints/paper_smoke_local_8gpu/joint_freeze_frontend_fsdp_sub12,submap_descriptor_source=frontend,submap_fetch_source=frontend,submap_retrieval_topk=0,submap_train_mode=full_token,test_dataset=,train_dataset=16
150
+ @ ${dataset_scannetpp} + 16 @ ${dataset_hypersim} + 16 @ ${dataset_mvs_synth},train_submap_modules_only=false
151
+ id: ???
152
+ num: ???
153
+ config_name: finetune_paper_h20.yaml
154
+ env_set: {}
155
+ env_copy: []
156
+ config:
157
+ override_dirname:
158
+ kv_sep: '='
159
+ item_sep: ','
160
+ exclude_keys: []
161
+ runtime:
162
+ version: 1.3.2
163
+ version_base: '1.3'
164
+ cwd: /home/23068142r/work_dir/projects/e2e-semantic-SLAM
165
+ config_sources:
166
+ - path: hydra.conf
167
+ schema: pkg
168
+ provider: hydra
169
+ - path: /home/23068142r/work_dir/projects/e2e-semantic-SLAM/src/../config
170
+ schema: file
171
+ provider: main
172
+ - path: ''
173
+ schema: structured
174
+ provider: schema
175
+ output_dir: /home/23068142r/work_dir/projects/e2e-semantic-SLAM/checkpoints/paper_smoke_local_8gpu/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12
176
+ choices:
177
+ hydra/env: default
178
+ hydra/callbacks: null
179
+ hydra/job_logging: default
180
+ hydra/hydra_logging: default
181
+ hydra/hydra_help: default
182
+ hydra/help: default
183
+ hydra/sweeper: basic
184
+ hydra/launcher: basic
185
+ hydra/output: default
186
+ verbose: true
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/.hydra/overrides.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ - exp_name=paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12
2
+ - save_dir=/home/23068142r/work_dir/projects/e2e-semantic-SLAM/checkpoints/paper_smoke_local_8gpu/joint_freeze_frontend_fsdp_sub12
3
+ - pretrained=/home/23068142r/work_dir/projects/e2e-semantic-SLAM-submap/ckpt/checkpoint-10.pth.model
4
+ - resume=null
5
+ - data_root=/home/23068142r/work_dir/data
6
+ - root_arkit=/home/23068142r/work_dir/data/processed_arkitscenes
7
+ - root_scannetpp=/home/23068142r/work_dir/data/preprocessed_scannetpp
8
+ - root_scannet=/home/23068142r/work_dir/data/processed_scannet
9
+ - root_hypersim=/home/23068142r/work_dir/data/preprocessed_Hypersim
10
+ - root_blendedmvs=/home/23068142r/work_dir/data/processed_blendedmvs
11
+ - root_megadepth=/home/23068142r/work_dir/data/processed_megadepth
12
+ - root_mvs_synth=/home/23068142r/work_dir/data/processed_mvs_synth
13
+ - num_views=24
14
+ - num_views_arkit=64
15
+ - num_views_scannetpp=24
16
+ - num_views_scannet=64
17
+ - num_views_hypersim=24
18
+ - num_views_blendedmvs=64
19
+ - num_views_megadepth=64
20
+ - num_views_mvs_synth=24
21
+ - train_submap_modules_only=false
22
+ - detach_frontend_tokens=true
23
+ - submap_train_mode=full_token
24
+ - submap_retrieval_topk=0
25
+ - submap_fetch_source=frontend
26
+ - submap_descriptor_source=frontend
27
+ - pseudo_gt.enable=false
28
+ - pseudo_gt.cache_path=null
29
+ - train_dataset=16 @ ${dataset_scannetpp} + 16 @ ${dataset_hypersim} + 16 @ ${dataset_mvs_synth}
30
+ - epochs=2
31
+ - test_dataset=
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/checkpoint-1.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c96e1ed2d9231223d02406ef66794e5f9f39bae990354bd5bc4101ab2396afb6
3
+ size 4516140233
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/checkpoint-2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:465a4b460ad29b4d539348adebfb85a01845b0f905cad11206d260cf1b95f76f
3
+ size 4516140233
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/checkpoint-final.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a618c80dcef89b1a41580d858134bac0d79b39130586e8d036a26ac9582c9f2
3
+ size 3873507717
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/checkpoint-last.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eb7f68ba708455c27ac4b78128cf8a34f0f9511caf47074e3eea5431d5dfdfdb
3
+ size 4516145306
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/CHANGELOG.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Changelog
2
+
3
+ ## Hardware
4
+
5
+ | Component | Specification |
6
+ |-----------|---------------|
7
+ | GPU | 8 x NVIDIA L40S |
8
+ | CPU | AMD EPYC 7763 64-Core Processor (112 vCPUs) |
9
+ | Memory | 755 GiB |
10
+
11
+ ## 2026-04-03
12
+
13
+ ### Added
14
+
15
+ - new script `slam/exp_joint_freeze_frontend_fsdp_8gpu.sh`
16
+ - `SKIP_TEST` flag in `slam/exp_joint_freeze_frontend_fsdp_8gpu.sh` (default `0`). When set to `1`, the test dataset Hydra override is cleared and no test data loaders are built.
17
+ - Guard in `src/finetune.py` to skip test dataset construction when `test_dataset` is empty, setting `data_loader_test = {}`.
18
+ - Guard in `src/finetune.py` (`train_one_epoch`) against `ZeroDivisionError` when `int(save_freq * len(data_loader))` truncates to 0 (e.g. small dataset on many GPUs). Intra-epoch checkpoint saving is now skipped gracefully in this case.
19
+ - Adjusted dataset paths to match the local machine's data layout
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/README.md ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+ <h1>SLAM-Former Submap Release Guide</h1>
3
+ <h3>GitHub handoff, FSDP training, and full-sequence TUM runbook</h3>
4
+ </div>
5
+
6
+ ## What this upload is for
7
+
8
+ This branch is the submap-oriented training and inference release prepared for GitHub handoff.
9
+
10
+ The release is built around the H20 launchers and the current submap memory path, with the following fixed requirements:
11
+
12
+ - **Distributed strategy**: `FSDP`
13
+ - **Submap length**: `submap_size=12`
14
+ - **Descriptor source**: `frontend`
15
+ - **Historical token fetch source**: `frontend`
16
+ - **External epoch control**: set `EPOCHS` outside the script
17
+ - **Comparison modes**:
18
+ - `submap-only`
19
+ - `backend + submap joint training with detached frontend tokens`
20
+
21
+ If you want the original upstream project README, see `README_ori.md`.
22
+ If you want more implementation detail on the submap system, see `README_submap.md` and `submap_handoff.md`.
23
+
24
+ ## Clone the published branch
25
+
26
+ ```bash
27
+ git clone -b submap https://github.com/SlamMate/e2e-semantic-SLAM.git
28
+ cd e2e-semantic-SLAM
29
+ ```
30
+
31
+ ## Environment
32
+
33
+ ```bash
34
+ conda env create -f environment.yml
35
+ conda activate SLAM-Former
36
+ ```
37
+
38
+ The exported `environment.yml` is a full snapshot of the current Conda environment, including CUDA-related Python packages. If you want to move it between machines, you can keep the file as-is or delete the final `prefix:` line for a cleaner portable export.
39
+
40
+ If you prefer the original manual install flow:
41
+
42
+ ```bash
43
+ conda create -n SLAM-Former python=3.11
44
+ conda activate SLAM-Former
45
+
46
+ pip install -r requirements.txt
47
+ pip install -e .
48
+ ```
49
+
50
+ ## Paths and launch parameters to set on a new cluster
51
+
52
+ | Variable | Meaning | Typical example |
53
+ | --- | --- | --- |
54
+ | `PROJECT_DIR` | repository root | `/path/to/e2e-semantic-SLAM` |
55
+ | `DATA_ROOT` | training data root | `/path/to/data/train` |
56
+ | `PRETRAINED` | pretrained SLAM-Former checkpoint | `/path/to/ckpt/checkpoint-10.pth.model` |
57
+ | `SAVE_DIR` | checkpoint output root | `/path/to/checkpoints` |
58
+ | `CONDA_SH` | conda init script | `/path/to/miniconda3/etc/profile.d/conda.sh` |
59
+ | `CONDA_ENV_NAME` | conda env name | `SLAM-Former` |
60
+ | `MASTER_ADDR` | rank-0 node hostname or IP for multi-node bash launch | `node0` |
61
+ | `MACHINE_RANK` | machine rank for multi-node bash launch | `0..7` |
62
+ | `NUM_MACHINES` | machine count for remote long run | `8` |
63
+ | `GPUS_PER_NODE` | visible GPUs per node | `8` |
64
+ | `MASTER_PORT` | communication port | `29661` / `29671` etc. |
65
+
66
+ For migration to a new machine, the entries above are the first ones to review. If the data layout is unchanged and each dataset still lives under `DATA_ROOT/<dataset>`, you can usually keep the `ROOT_*` defaults.
67
+
68
+ The default data layout expected by the launchers is:
69
+
70
+ - `processed_arkitscenes`
71
+ - `processed_scannetpp`
72
+ - `processed_scannet` or `processed_scannetv2`
73
+ - `hypersim`
74
+ - `processed_blendedmvs`
75
+ - `processed_megadepth`
76
+ - `processed_mvs_synth`
77
+
78
+ ## Fixed training invariants in this release
79
+
80
+ - `CONFIG_NAME=finetune_paper_h20.yaml`
81
+ - `DIST_STRATEGY=fsdp`
82
+ - `SUBMAP_SIZE=12`
83
+ - `SUBMAP_TRAIN_MODE=full_token`
84
+ - `SUBMAP_RETRIEVAL_TOPK=0`
85
+ - `SUBMAP_FETCH_SOURCE=frontend`
86
+ - `SUBMAP_DESCRIPTOR_SOURCE=frontend`
87
+ - `freeze_encoder=true` stays inherited from `config/finetune_paper_h20.yaml`
88
+ - `DETACH_FRONTEND_TOKENS=1` in both comparison modes
89
+ - `TRAIN_SUBMAP_MODULES_ONLY=1` for strict submap-only training
90
+ - `TRAIN_SUBMAP_MODULES_ONLY=0` for backend + submap joint training while frontend tokens stay detached
91
+
92
+ ## `num_views` reference table
93
+
94
+ The remote 8-node scripts use aggressive long-sequence defaults.
95
+ These are **reference upper bounds**, not guarantees that every scene can be used without skipping.
96
+ For ARKitScenes, the local audit shows a very small strict no-skip cap because of a few short scenes, but much longer clips are still possible if short scenes are skipped.
97
+
98
+ | Dataset | Remote 8x8 default | StrictCapNoSkip | MedianCap | MaxCap | Notes |
99
+ | --- | ---: | ---: | ---: | ---: | --- |
100
+ | ARKitScenes | 478 | 2 | 92 | 478 | strict cap is dominated by short scenes; long clips work with scene skipping |
101
+ | ScanNet++ | 150 | 45 | 143 | 150 | local processed data already supports long clips |
102
+ | ScanNet | 64 | N/A | N/A | N/A | local audit unavailable; rerun audit on target cluster if needed |
103
+ | HyperSim | 64 | N/A | N/A | N/A | local audit unavailable |
104
+ | BlendedMVS | 64 | N/A | N/A | N/A | local audit unavailable |
105
+ | MegaDepth | 64 | N/A | N/A | 64 | loader-side hard cap is 64 |
106
+ | MVS-Synth | 100 | 69 | 100 | 100 | long clips are supported on all local scenes |
107
+
108
+ If your target cluster has different processed data, rerun the audit and override `NUM_VIEWS_*` as needed.
109
+
110
+ ## Scripts provided in this release
111
+
112
+ | Script | Launch mode | Purpose |
113
+ | --- | --- | --- |
114
+ | `slam/sbatch_smoke_submap_only_fsdp_2gpu.sh` | local `sbatch` | 1-node 2-GPU smoke validation for strict submap-only mode |
115
+ | `slam/sbatch_smoke_joint_freeze_frontend_fsdp_2gpu.sh` | local `sbatch` | 1-node 2-GPU smoke validation for backend + submap with detached frontend |
116
+ | `slam/train_remote_submap_only_fsdp_8node8gpu.sh` | remote `bash` | 8-node 8-GPU-per-node long-sequence submap-only training |
117
+ | `slam/train_remote_joint_freeze_frontend_fsdp_8node8gpu.sh` | remote `bash` | 8-node 8-GPU-per-node long-sequence backend + submap training with detached frontend |
118
+
119
+ ## Migration guide: run the smoke scripts on another machine
120
+
121
+ These two smoke launchers are self-contained. They no longer rely on the old wrapper scripts. Each launcher:
122
+
123
+ - resolves `PROJECT_DIR` from `PROJECT_DIR`, `SLURM_SUBMIT_DIR`, or the script location;
124
+ - requires `CONDA_SH` to point to a valid `conda.sh` file and activates `CONDA_ENV_NAME` from that shell;
125
+ - loads `cuda12.1/toolkit` when the `module` command is available;
126
+ - sets `PYTHONPATH` to include `src/`;
127
+ - launches `src/finetune.py` through `accelerate launch`.
128
+
129
+ ### 1. Minimum checklist before the first run
130
+
131
+ 1. Clone or copy the repository onto the new machine.
132
+ 2. Create the environment with `conda env create -f environment.yml`, and make sure `CONDA_SH` points at the correct `conda.sh` on the new machine.
133
+ 3. Make sure the pretrained checkpoint exists at `PRETRAINED` or override `PRETRAINED=...`.
134
+ 4. Point `DATA_ROOT` to the processed training data on the new machine.
135
+ 5. Verify the dataset roots under `DATA_ROOT`, or override the individual `ROOT_*` variables.
136
+ 6. Check that the new machine has a compatible CUDA setup. If the cluster uses a different module name, edit the `module load cuda12.1/toolkit` line.
137
+ 7. If you submit with Slurm, keep the `#SBATCH` resource requests consistent with the machine’s GPU, CPU, and memory limits.
138
+
139
+ ### 2. Variables you usually need to change
140
+
141
+ | Variable | Meaning | Typical reason to change |
142
+ | --- | --- | --- |
143
+ | `PROJECT_DIR` | repository root used for `PYTHONPATH`, output paths, and the default checkpoint path | the repo lives in a different directory |
144
+ | `CONDA_SH` | path to `conda.sh` used to initialize Conda | the machine uses a different Miniconda install, or the default path does not exist |
145
+ | `CONDA_ENV_NAME` | environment name to activate | you created the env under a different name |
146
+ | `DATA_ROOT` | top-level directory that contains the processed datasets | the training data is mounted elsewhere |
147
+ | `ROOT_ARKIT`, `ROOT_SCANNETPP`, `ROOT_SCANNET`, `ROOT_SCANNET_FALLBACK`, `ROOT_HYPERSIM`, `ROOT_BLENDEDMVS`, `ROOT_MEGADEPTH`, `ROOT_MVS_SYNTH` | per-dataset roots used by the launchers | the dataset folders do not live directly under `DATA_ROOT` or ScanNet needs a fallback root |
148
+ | `PRETRAINED` | checkpoint loaded before training starts | the pretrained model is stored somewhere else |
149
+ | `SAVE_DIR` | root directory for checkpoints and logs | you want outputs on a different disk |
150
+ | `MASTER_PORT` | port used by `accelerate` to rendezvous the worker processes | another job is already using the default port |
151
+ | `NUM_GPUS` | number of processes / GPUs launched | the target node exposes a different GPU count |
152
+ | `AUTO_DISABLE_MISSING` | auto-disable a dataset whose root is missing or incomplete | set to `0` if you want the job to fail fast instead of silently skipping data |
153
+ | `EXPERIMENT_ROOT`, `VARIANT_NAME`, `EXP_NAME` | folder and experiment naming used for outputs | you want a different output namespace or to avoid collisions with old runs |
154
+ | `RESUME` | checkpoint path to resume from | you are continuing an interrupted run |
155
+
156
+ The following Slurm header lines also need cluster-specific tuning:
157
+
158
+ - `#SBATCH --gres=gpu:2` requests two GPUs.
159
+ - `#SBATCH --cpus-per-task` controls CPU cores.
160
+ - `#SBATCH --mem` controls RAM.
161
+ - `#SBATCH --time` is the wall-time limit.
162
+
163
+ ### 3. Variables that control the training recipe
164
+
165
+ | Variable | Meaning | Notes |
166
+ | --- | --- | --- |
167
+ | `CONFIG_NAME` | Hydra config file passed to `src/finetune.py` | both smoke scripts use `finetune_paper_h20.yaml` |
168
+ | `DIST_STRATEGY` | distributed backend (`fsdp` or `ddp`) | both smoke scripts use `fsdp` |
169
+ | `TRAIN_SUBMAP_MODULES_ONLY` | `1` = strict submap-only training; `0` = joint backend+submap training | `1` for `slam/sbatch_smoke_submap_only_fsdp_2gpu.sh`, `0` for `slam/sbatch_smoke_joint_freeze_frontend_fsdp_2gpu.sh` |
170
+ | `DETACH_FRONTEND_TOKENS` | detach frontend tokens from gradients | `1` in both scripts |
171
+ | `SUBMAP_SIZE` | number of frames / tokens per submap | default `12` in both scripts |
172
+ | `SUBMAP_TRAIN_MODE` | submap training mode | default `full_token` |
173
+ | `SUBMAP_RETRIEVAL_TOPK` | retrieval top-k setting | default `0` disables retrieval |
174
+ | `SUBMAP_FETCH_SOURCE` | source used to fetch submap features | default `frontend` |
175
+ | `SUBMAP_DESCRIPTOR_SOURCE` | source used to build submap descriptors | default `frontend` |
176
+ | `ENABLE_PSEUDO_GT` | enable pseudo-GT cache usage | keep `0` unless you have a valid cache |
177
+ | `PSEUDO_GT_CACHE_PATH` | path to the pseudo-GT cache | required when `ENABLE_PSEUDO_GT=1` |
178
+ | `EPOCHS` | number of epochs passed to `src/finetune.py` | smoke scripts default to `2` |
179
+
180
+ ### 4. Dataset-mixture knobs
181
+
182
+ | Variable | Meaning | Notes |
183
+ | --- | --- | --- |
184
+ | `SAMPLES_ARKIT`, `SAMPLES_SCANNETPP`, `SAMPLES_SCANNET`, `SAMPLES_HYPERSIM`, `SAMPLES_BLENDEDMVS`, `SAMPLES_MEGADEPTH`, `SAMPLES_MVS_SYNTH` | per-dataset sampling weights in the training mixture | increase or decrease them to rebalance the dataset mix |
185
+ | `NUM_VIEWS_ARKIT`, `NUM_VIEWS_SCANNETPP`, `NUM_VIEWS_SCANNET`, `NUM_VIEWS_HYPERSIM`, `NUM_VIEWS_BLENDEDMVS`, `NUM_VIEWS_MEGADEPTH`, `NUM_VIEWS_MVS_SYNTH` | per-dataset view caps | change these if the processed data on the new machine has different sequence lengths or hard caps |
186
+ | `GLOBAL_NUM_VIEWS` | optional global cap; if unset, the scripts derive it from the active datasets’ `NUM_VIEWS_*` values | set it when you want a single global value for all datasets |
187
+ | `NUM_VIEWS_ALL` | compatibility placeholder kept by the scripts | usually leave it at the default; the launcher mainly uses the per-dataset values above |
188
+
189
+ ### 5. The two scripts differ only in these defaults
190
+
191
+ | Item | Submap-only script | Joint + frozen frontend script | Meaning |
192
+ | --- | --- | --- | --- |
193
+ | `TRAIN_SUBMAP_MODULES_ONLY` | `1` | `0` | whether to train only the submap modules or the joint backend + submap stack |
194
+ | `MASTER_PORT` | `29661` | `29662` | keep the two smoke jobs from colliding on the same node |
195
+ | `VARIANT_NAME` | `submap_only_fsdp_sub12` | `joint_freeze_frontend_fsdp_sub12` | output subdirectory name under `SAVE_DIR` |
196
+ | `EXP_NAME` | `paper_smoke_submap_only_fsdp_2gpu_sub12` | `paper_smoke_joint_freeze_frontend_fsdp_2gpu_sub12` | experiment name written into logs and Hydra config |
197
+ | `#SBATCH --cpus-per-task` | `24` | `12` | CPU reservation for the smoke job |
198
+ | `#SBATCH --mem` | `120G` | `24G` | memory reservation for the smoke job |
199
+
200
+ ### 6. Direct launch examples on a new machine
201
+
202
+ To run on another Slurm machine, set the machine-specific variables inline and submit the launcher from the repo root.
203
+
204
+ Submap-only smoke:
205
+
206
+ ```bash
207
+ PROJECT_DIR=/path/to/e2e-semantic-SLAM \
208
+ CONDA_SH=/path/to/miniconda3/etc/profile.d/conda.sh \
209
+ CONDA_ENV_NAME=SLAM-Former \
210
+ DATA_ROOT=/path/to/data/train \
211
+ PRETRAINED=/path/to/ckpt/checkpoint-10.pth.model \
212
+ SAVE_DIR=/path/to/checkpoints \
213
+ MASTER_PORT=29661 \
214
+ sbatch slam/sbatch_smoke_submap_only_fsdp_2gpu.sh
215
+ ```
216
+
217
+ Joint + frozen frontend smoke:
218
+
219
+ ```bash
220
+ PROJECT_DIR=/path/to/e2e-semantic-SLAM \
221
+ CONDA_SH=/path/to/miniconda3/etc/profile.d/conda.sh \
222
+ CONDA_ENV_NAME=SLAM-Former \
223
+ DATA_ROOT=/path/to/data/train \
224
+ PRETRAINED=/path/to/ckpt/checkpoint-10.pth.model \
225
+ SAVE_DIR=/path/to/checkpoints \
226
+ MASTER_PORT=29662 \
227
+ sbatch slam/sbatch_smoke_joint_freeze_frontend_fsdp_2gpu.sh
228
+ ```
229
+
230
+ If the new machine does not have Slurm, you can still run the scripts with `bash ...` as long as the same two GPUs are visible to the shell and `accelerate` / CUDA are available; the `#SBATCH` lines are then ignored by Bash.
231
+
232
+ ## Local smoke validation on the current cluster
233
+
234
+ The two local `sbatch` scripts are intended to validate:
235
+
236
+ - the launcher path
237
+ - the FSDP wiring
238
+ - the README instructions
239
+ - the two comparison modes
240
+
241
+ They intentionally use a small sample budget and `64` views so they can be checked quickly on 1 node and 2 GPUs.
242
+
243
+ ### 1. Submap-only smoke
244
+
245
+ ```bash
246
+ EPOCHS=1 \
247
+ bash slam/sbatch_smoke_submap_only_fsdp_2gpu.sh
248
+ ```
249
+
250
+ ### 2. Backend + submap smoke with detached frontend
251
+
252
+ ```bash
253
+ EPOCHS=1 \
254
+ bash slam/sbatch_smoke_joint_freeze_frontend_fsdp_2gpu.sh
255
+ ```
256
+
257
+ Local smoke defaults:
258
+
259
+ - `SAMPLES_ARKIT=32`
260
+ - `SAMPLES_SCANNETPP=16`
261
+ - all other local smoke sample weights default to `0`
262
+ - `NUM_VIEWS_* = 64`
263
+ - `SUBMAP_SIZE = 12`
264
+
265
+ This keeps the smoke validation focused on code path correctness rather than final throughput.
266
+
267
+ ## Remote long-sequence training on 8 nodes x 8 GPUs
268
+
269
+ The remote scripts are the actual release launchers for long-sequence comparison.
270
+ Run the **same script on every node**, and only change `MACHINE_RANK`.
271
+
272
+ Remote defaults:
273
+
274
+ - `AUTO_DISABLE_MISSING=0`
275
+ - full paper dataset sample mix from `finetune_paper_h20.yaml`
276
+ - aggressive per-dataset `NUM_VIEWS_*` values from the table above
277
+ - `SUBMAP_SIZE=12`
278
+ - `SUBMAP_FETCH_SOURCE=frontend`
279
+ - `SUBMAP_DESCRIPTOR_SOURCE=frontend`
280
+
281
+ ### 1. Remote 8x8 submap-only run
282
+
283
+ On rank 0:
284
+
285
+ ```bash
286
+ MASTER_ADDR=node0 \
287
+ MACHINE_RANK=0 \
288
+ NUM_MACHINES=8 \
289
+ GPUS_PER_NODE=8 \
290
+ DATA_ROOT=/path/to/data/train \
291
+ SAVE_DIR=/path/to/checkpoints \
292
+ PRETRAINED=/path/to/ckpt/checkpoint-10.pth.model \
293
+ CONDA_SH=/path/to/miniconda3/etc/profile.d/conda.sh \
294
+ CONDA_ENV_NAME=SLAM-Former \
295
+ EPOCHS=10 \
296
+ bash slam/train_remote_submap_only_fsdp_8node8gpu.sh
297
+ ```
298
+
299
+ On the remaining nodes, rerun the same command with `MACHINE_RANK=1` through `7`.
300
+
301
+ ### 2. Remote 8x8 backend + submap run with detached frontend
302
+
303
+ On rank 0:
304
+
305
+ ```bash
306
+ MASTER_ADDR=node0 \
307
+ MACHINE_RANK=0 \
308
+ NUM_MACHINES=8 \
309
+ GPUS_PER_NODE=8 \
310
+ DATA_ROOT=/path/to/data/train \
311
+ SAVE_DIR=/path/to/checkpoints \
312
+ PRETRAINED=/path/to/ckpt/checkpoint-10.pth.model \
313
+ CONDA_SH=/path/to/miniconda3/etc/profile.d/conda.sh \
314
+ CONDA_ENV_NAME=SLAM-Former \
315
+ EPOCHS=10 \
316
+ bash slam/train_remote_joint_freeze_frontend_fsdp_8node8gpu.sh
317
+ ```
318
+
319
+ On the remaining nodes, rerun the same command with `MACHINE_RANK=1` through `7`.
320
+
321
+ ## Whole-sequence TUM inference after training
322
+
323
+ After training, use the saved checkpoint to run the full Freiburg1 TUM sequences.
324
+
325
+ ```bash
326
+ CKPT_PATH=/path/to/checkpoint-last.pth \
327
+ RUN_TAG=release_eval \
328
+ SUBMAP_INFERENCE_MODE=full \
329
+ SUBMAP_TRAIN_MODE=full_token \
330
+ SUBMAP_FETCH_SOURCE=frontend \
331
+ SUBMAP_DESCRIPTOR_SOURCE=frontend \
332
+ sbatch run_tum.sh
333
+ ```
334
+
335
+ This writes outputs to:
336
+
337
+ ```bash
338
+ tum_results_aligned/<RUN_TAG>/rgbd_dataset_freiburg1_*/
339
+ ```
340
+
341
+ Each sequence directory contains at least:
342
+
343
+ - `final_traj.txt`
344
+ - `final.ply`
345
+ - `final_pc/`
346
+
347
+ To evaluate ATE after the sequence finishes:
348
+
349
+ ```bash
350
+ evo_ape tum <ground_truth.txt> <final_traj.txt> -a --t_max_diff 0.02
351
+ ```
352
+
353
+ ## Output layout
354
+
355
+ Training outputs are written under:
356
+
357
+ ```bash
358
+ $SAVE_DIR/$EXP_NAME/
359
+ ```
360
+
361
+ Typical files include:
362
+
363
+ - `checkpoint-last.pth`
364
+ - `model.pth`
365
+ - `logs/`
366
+ - launcher stdout / stderr logs
367
+
368
+ ## Practical notes
369
+
370
+ - The local `sbatch` scripts are smoke validators.
371
+ - The remote `bash` scripts are the actual long-sequence release launchers.
372
+ - If the target cluster has a different sequence-cap profile, rerun `slam/audit_dataset_num_views.py` there and override `NUM_VIEWS_*`.
373
+ - If a remote run fails on missing dataset roots, keep `AUTO_DISABLE_MISSING=0` and fix the paths instead of silently training on a partial dataset mix.
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/README_submap.md ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+ <h1>SLAM-Former Submap Companion / 子图系统说明</h1>
3
+ </div>
4
+
5
+ ### Updates
6
+ * [Mar 30, 2026] Added a submap-native training and inference path with switchable `full_token` / `top5_dual_queue` modes.
7
+ * [Mar 30, 2026] Added dedicated TUM top5 inference support and a standalone `run_tum_top5.sh` launcher.
8
+ * [Mar 26, 2026] Submap-only pseudo-GT training was exercised with the high-recall config and FSDP launchers.
9
+
10
+ ### Getting Started
11
+
12
+ The environment setup is the same as the original `README.md`:
13
+
14
+ ```bash
15
+ git clone https://github.com/Tsinghua-MARS-Lab/SLAM-Former.git
16
+ cd SLAM-Former
17
+
18
+ conda create -n SLAM-Former python=3.11
19
+ conda activate SLAM-Former
20
+
21
+ pip install -r requirements.txt
22
+ pip install -e .
23
+ ```
24
+
25
+ ### Submap System Overview
26
+
27
+ This companion branch keeps the native SLAM-Former pipeline as close as possible to the original implementation, while adding a submap-oriented memory backend and switchable queue semantics.
28
+
29
+ Core ideas:
30
+
31
+ * **`GraphGatedMemoryManager`** stores historical submaps on CPU, keeps descriptors, and performs loop retrieval.
32
+ * **`slam/demo_submap.py`** is the submap-aware inference entrypoint.
33
+ * **`src/forward_pass.py`** and **`src/forward_pass_submap.py`** slice the backend output so supervision is applied to the current submap while still letting the backend see `prev + curr + retrieved` context.
34
+ * **`src/finetune.py`** and the H20 launchers expose the submap configuration through script variables.
35
+ * **`run_tum.sh`** and **`run_tum_top5.sh`** drive TUM inference with either the full-token path or the top5 path.
36
+
37
+ Important submap switches:
38
+
39
+ * `TRAIN_SUBMAP_MODULES_ONLY`
40
+ * `1` freezes the main SLAMFormer parameters and trains the submap-side modules.
41
+ * `0` keeps the joint training path available.
42
+ * `SUBMAP_TRAIN_MODE`
43
+ * `full_token`: keep the submap path close to the native full-token behavior.
44
+ * `top5_dual_queue`: use two queues, where the frontend queue is read-only and the backend queue receives write-back for retrieved historical submaps.
45
+ * `SUBMAP_RETRIEVAL_TOPK`
46
+ * number of historical submaps fetched in the soft retrieval mode.
47
+ * `SUBMAP_FETCH_SOURCE` / `SUBMAP_DESCRIPTOR_SOURCE`
48
+ * choose whether retrieval and descriptor storage read from `frontend` or `backend` banks.
49
+ * `SUBMAP_INFERENCE_MODE`
50
+ * `full` or `top5` for the TUM launch scripts.
51
+
52
+ ### Training Modes
53
+
54
+ There are three common training setups in this branch:
55
+
56
+ #### 1. Joint baseline
57
+
58
+ Use the original joint configuration when you want the closest comparison to the official training branch:
59
+
60
+ * `config/finetune.yaml`
61
+ * `TRAIN_SUBMAP_MODULES_ONLY=0`
62
+ * `SUBMAP_TRAIN_MODE=full_token`
63
+ * `SUBMAP_RETRIEVAL_TOPK=0`
64
+
65
+ #### 2. Submap-only full-token training
66
+
67
+ This is the first submap stage: descriptors and historical submaps are trained with full-token submaps so historical submaps can still receive gradients.
68
+
69
+ Recommended config:
70
+
71
+ * `config/finetune_sub_only.yaml`
72
+ * or `config/finetune_pseudo_gt_high_recall.yaml`
73
+
74
+ Typical launch knobs:
75
+
76
+ * `TRAIN_SUBMAP_MODULES_ONLY=1`
77
+ * `SUBMAP_TRAIN_MODE=full_token`
78
+ * `SUBMAP_RETRIEVAL_TOPK=0`
79
+
80
+ Example:
81
+
82
+ ```bash
83
+ ENABLE_PSEUDO_GT=1 \
84
+ PSEUDO_GT_CACHE_PATH=/var/scratch/qzhang2/SLAM-Former/data/train/pseudo_gt/arkitscenes_smoke_test.json \
85
+ CONFIG_NAME=finetune_pseudo_gt_high_recall.yaml \
86
+ TRAIN_SUBMAP_MODULES_ONLY=1 \
87
+ SUBMAP_TRAIN_MODE=full_token \
88
+ SUBMAP_RETRIEVAL_TOPK=0 \
89
+ sbatch slam/sbatch_finetune.sh
90
+ ```
91
+
92
+ #### 3. Submap top5 fine-tuning
93
+
94
+ This is the second stage: the backend sees the top5 historical submaps, and the system uses the dual-queue semantics.
95
+
96
+ Typical launch knobs:
97
+
98
+ * `TRAIN_SUBMAP_MODULES_ONLY=1`
99
+ * `SUBMAP_TRAIN_MODE=top5_dual_queue`
100
+ * `SUBMAP_RETRIEVAL_TOPK=5`
101
+ * `SUBMAP_FETCH_SOURCE=frontend`
102
+ * `SUBMAP_DESCRIPTOR_SOURCE=frontend`
103
+
104
+ Example:
105
+
106
+ ```bash
107
+ ENABLE_PSEUDO_GT=1 \
108
+ PSEUDO_GT_CACHE_PATH=/var/scratch/qzhang2/SLAM-Former/data/train/pseudo_gt/arkitscenes_smoke_test.json \
109
+ CONFIG_NAME=finetune_pseudo_gt_high_recall.yaml \
110
+ TRAIN_SUBMAP_MODULES_ONLY=1 \
111
+ SUBMAP_TRAIN_MODE=top5_dual_queue \
112
+ SUBMAP_RETRIEVAL_TOPK=5 \
113
+ sbatch slam/sbatch_finetune.sh
114
+ ```
115
+
116
+ ### Inference Modes
117
+
118
+ #### Official baseline comparison
119
+
120
+ For the native baseline, keep using the original demo path from `README.md`:
121
+
122
+ ```bash
123
+ python slam/demo.py \
124
+ --ckpt_path ckpt/checkpoint.pth.model \
125
+ --image_folder /path/to/your/images/ \
126
+ --output_dir /output/result \
127
+ --target_size 518 \
128
+ --retention_ratio 0.5
129
+ ```
130
+
131
+ #### Submap TUM inference: full mode
132
+
133
+ This path uses the submap-aware demo but keeps the inference behavior in `full` mode.
134
+
135
+ ```bash
136
+ SUBMAP_INFERENCE_MODE=full \
137
+ CKPT_PATH=/var/scratch/qzhang2/SLAM-Former/ckpt/checkpoint-10.pth.model \
138
+ sbatch run_tum.sh
139
+ ```
140
+
141
+ #### Submap TUM inference: top5 mode
142
+
143
+ This is the dedicated top5 launcher for comparing against the full mode.
144
+
145
+ ```bash
146
+ sbatch run_tum_top5.sh
147
+ ```
148
+
149
+ You can also override the checkpoint and output root explicitly:
150
+
151
+ ```bash
152
+ CKPT_PATH=/var/scratch/qzhang2/SLAM-Former/checkpoints/local_cluster_nv24_sub6/submap_only_pseudo_gt_high_recall_smoke/paper_local_submap_only_pseudo_gt_high_recall_smoke_nv24_sub6/checkpoint-last.pth \
153
+ SUBMAP_INFERENCE_MODE=top5 \
154
+ RUN_TAG=my_top5_compare \
155
+ OUT_DIR=/var/scratch/qzhang2/SLAM-Former/tum_results_aligned_top5/my_top5_compare \
156
+ sbatch run_tum_top5.sh
157
+ ```
158
+
159
+ ### Launch Scripts
160
+
161
+ | Script | Purpose | Notes |
162
+ | --- | --- | --- |
163
+ | `slam/sbatch_finetune.sh` | Local 3-GPU FSDP finetune launcher | Uses environment variables to select config, submap mode, and pseudo-GT cache. |
164
+ | `slam/run_train_h20_single.sh` | Single-node H20-style training launcher | Good when you want to run the training job directly without the wrapper. |
165
+ | `slam/run_train_h20_multi.sh` | Multi-node H20-style training launcher | Keeps the same submap knobs, but launches across multiple machines. |
166
+ | `run_tum.sh` | TUM inference launcher | Supports `SUBMAP_INFERENCE_MODE=full\|top5`. |
167
+ | `run_tum_top5.sh` | Dedicated TUM top5 launcher | Defaults to the latest submap checkpoint and top5 mode. |
168
+ | `slam/demo_submap.py` | Manual inference entrypoint | Accepts `--submap_train_mode`, `--submap_retrieval_topk`, `--loop_mask_mode`, `--submap_fetch_source`, `--submap_descriptor_source`, and `--max_recursive_submaps`. |
169
+
170
+ ### Configuration Files
171
+
172
+ | Config | Role |
173
+ | --- | --- |
174
+ | `config/finetune.yaml` | Joint training baseline. |
175
+ | `config/finetune_sub_only.yaml` | Submap-only training with full-token semantics. |
176
+ | `config/finetune_pseudo_gt_high_recall.yaml` | Submap-only training with higher-recall pseudo-GT settings. |
177
+
178
+ ### Data and Checkpoints
179
+
180
+ The data layout is the same as the original project. The TUM root used by the launch scripts is:
181
+
182
+ ```bash
183
+ /var/scratch/qzhang2/Feature-SLAM/datasets/tum
184
+ ```
185
+
186
+ Typical checkpoint locations are under:
187
+
188
+ ```bash
189
+ /var/scratch/qzhang2/SLAM-Former/checkpoints/
190
+ ```
191
+
192
+ For inference, the scripts usually read `checkpoint-last.pth` from the latest experiment folder.
193
+
194
+ ### Output Layout
195
+
196
+ Submap TUM inference writes one folder per sequence, for example:
197
+
198
+ ```bash
199
+ tum_results_aligned_top5/<run_tag>/rgbd_dataset_freiburg1_360/
200
+ ```
201
+
202
+ Inside each sequence directory you should expect:
203
+
204
+ * `final_traj.txt`
205
+ * `final.ply`
206
+ * `final_pc/`
207
+
208
+ ### Visualization
209
+
210
+ Static visualization is the same as the original branch:
211
+
212
+ ```bash
213
+ python slam/visualize_results.py --result_dir /path/to/output_dir
214
+ ```
215
+
216
+ For TUM results generated by the submap scripts, point `--result_dir` to the corresponding output folder.
217
+
218
+ ### Notes
219
+
220
+ * Keep `README.md` as the official baseline reference.
221
+ * Use this file when you want to describe or run the submap branch.
222
+ * The main comparison axes are:
223
+ * official baseline vs submap branch
224
+ * full-token submap mode vs top5 dual-queue mode
225
+ * joint training vs submap-only training
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/base_opt.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import roma
7
+ from copy import deepcopy
8
+ import tqdm
9
+ import os
10
+ import matplotlib.pyplot as plt
11
+
12
+ from cloud_opt.utils import *
13
+ from cloud_opt.utils import _check_edges, _compute_img_conf
14
+ import cloud_opt.init_all as init_fun
15
+
16
+
17
class BaseOptimizer(nn.Module):
    """Optimize a global scene, given a graph-organized observations.

    Graph node: images
    Graph edges: observations = (pred1, pred2), pred2 is in pred1's coordinate

    The class stores per-edge 3D point predictions and confidences, plus
    learnable pairwise poses/adaptors, and exposes a forward() that scores
    how well the global point cloud agrees with every pairwise prediction.
    """

    def __init__(self, *args, **kwargs):
        # Intentionally empty: subclasses are expected to invoke
        # _init_from_views(), which itself calls super().__init__() and
        # performs the real nn.Module initialization.
        pass

    def _init_from_views(
        self,
        view1s,
        view2s,
        pred1s,
        pred2s,  # whatever predictions, they should be organized into pairwise for graph optimization
        dist="l1",
        conf="log",
        min_conf_thr=3,
        thr_for_init_conf=False,
        base_scale=0.5,
        allow_pw_adaptors=False,
        pw_break=20,
        rand_pose=torch.randn,
        empty_cache=False,
        verbose=True,
    ):
        """Populate optimizer state from pairwise view predictions.

        view1s/view2s: per-pair view dicts; each must carry an integer "idx".
        pred1s: per-pair dicts with "pts3d_is_self_view" and "conf_self".
        pred2s: per-pair dicts with "pts3d_in_other_view" and "conf".
        dist: key into ALL_DISTS selecting the pointwise distance.
        conf: confidence-transform mode passed to get_conf_trf().
        rand_pose: factory used to initialize the pairwise pose parameters.

        NOTE(review): `empty_cache` and `verbose` are accepted but never
        stored here, yet global_alignment_loop() reads net.empty_cache and
        net.verbose — presumably a subclass assigns them; confirm.
        """
        super().__init__()
        # One (i, j) tuple per observation pair, read from each view's "idx".
        self.edges = [
            (int(view1["idx"]), int(view2["idx"]))
            for view1, view2 in zip(view1s, view2s)
        ]
        self.dist = ALL_DISTS[dist]
        self.n_imgs = _check_edges(self.edges)

        # Per-edge points/confidences stored as frozen (no-grad) parameter
        # dicts, keyed by the "i_j" edge-name string.
        self.edge2pts_i = NoGradParamDict(
            {ij: pred1s[n]["pts3d_is_self_view"] for n, ij in enumerate(self.str_edges)}
        )  # ij: the name of the edge
        self.edge2pts_j = NoGradParamDict(
            {
                ij: pred2s[n]["pts3d_in_other_view"]
                for n, ij in enumerate(self.str_edges)
            }
        )
        self.edge2conf_i = NoGradParamDict(
            {ij: pred1s[n]["conf_self"] for n, ij in enumerate(self.str_edges)}
        )
        self.edge2conf_j = NoGradParamDict(
            {ij: pred2s[n]["conf"] for n, ij in enumerate(self.str_edges)}
        )

        self.imshapes = get_imshapes(self.edges, pred1s, pred2s)
        self.min_conf_thr = min_conf_thr
        self.thr_for_init_conf = thr_for_init_conf
        self.conf_trf = get_conf_trf(conf)

        # Per-image aggregated confidence maps, explicitly frozen.
        # NOTE(review): `self.device` is not defined in this class — it is
        # presumably a property provided by a subclass; confirm.
        self.im_conf = _compute_img_conf(
            self.imshapes, self.device, self.edges, self.edge2conf_i, self.edge2conf_j
        )
        for i in range(len(self.im_conf)):
            self.im_conf[i].requires_grad = False

        # Snapshot of the initial confidence maps (before any thresholding).
        self.init_conf_maps = [c.clone() for c in self.im_conf]

        self.base_scale = base_scale
        self.norm_pw_scale = True
        self.pw_break = pw_break
        self.POSE_DIM = 7  # unit quaternion (4) + translation (3)
        self.pw_poses = nn.Parameter(
            rand_pose((self.n_edges, 1 + self.POSE_DIM))
        )  # pairwise poses
        self.pw_adaptors = nn.Parameter(
            torch.zeros((self.n_edges, 2))
        )  # slight xy/z adaptation
        self.pw_adaptors.requires_grad_(allow_pw_adaptors)
        self.has_im_poses = False
        self.rand_pose = rand_pose

    def get_known_poses(self):
        """Return (count, mask, poses) of poses frozen by the caller.

        A pose counts as "known" when its parameter does not require grad.
        Returns (0, None, None) when per-image poses are not enabled.
        """
        if self.has_im_poses:
            known_poses_msk = torch.tensor(
                [not (p.requires_grad) for p in self.im_poses]
            )
            known_poses = self.get_im_poses()
            return known_poses_msk.sum(), known_poses_msk, known_poses
        else:
            return 0, None, None

    def get_pw_norm_scale_factor(self):
        """Global multiplier applied to pairwise scales.

        Chosen so the mean pairwise log-scale maps to self.base_scale.
        """
        if self.norm_pw_scale:
            # normalize scales so that things cannot go south
            # we want that exp(scale) ~= self.base_scale
            return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp()
        else:
            return 1  # don't norm scale for known poses

    def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
        """Write a cam-to-world pose into poses[idx] in-place.

        R: 3x3 rotation, or a 4x4 matrix carrying both R and T (then T must
           be None). T: translation. scale: optional log-scale slot.
        A frozen (requires_grad=False) pose is left untouched unless force.
        """
        # all poses == cam-to-world
        pose = poses[idx]
        if not (pose.requires_grad or force):
            return pose

        if R.shape == (4, 4):
            assert T is None
            T = R[:3, 3]
            R = R[:3, :3]

        if R is not None:
            pose.data[0:4] = roma.rotmat_to_unitquat(R)
        if T is not None:
            pose.data[4:7] = signed_log1p(
                T / (scale or 1)
            )  # translation is function of scale

        if scale is not None:
            assert poses.shape[-1] in (8, 13)
            pose.data[-1] = np.log(float(scale))
        return pose

    def forward(self, ret_details=False):
        """Compute the alignment loss over all edges.

        For each edge, the pairwise predictions are mapped into world space
        and compared (confidence-weighted) to the global point cloud.
        Returns the mean per-edge loss, plus an (n_imgs, n_imgs) matrix of
        per-edge losses when ret_details is True.

        NOTE(review): self.pred_i / self.pred_j / self.conf_i / self.conf_j
        and get_pw_poses/get_adaptors/get_pts3d are not defined in this
        class — presumably aliases/methods supplied by a subclass; confirm.
        """
        pw_poses = self.get_pw_poses()  # cam-to-world
        pw_adapt = self.get_adaptors()
        proj_pts3d = self.get_pts3d()
        # pre-compute pixel weights
        weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()}
        weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()}

        loss = 0
        if ret_details:
            # -1 marks edges that were never scored.
            details = -torch.ones((self.n_imgs, self.n_imgs))

        for e, (i, j) in enumerate(self.edges):
            i_j = edge_str(i, j)
            # distance in image i and j
            aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j])
            aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j])
            li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean()
            lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean()
            loss = loss + li + lj

            if ret_details:
                details[i, j] = li + lj
        loss /= self.n_edges  # average over all pairs

        if ret_details:
            return loss, details
        return loss

    # Run global alignment in full precision regardless of any outer
    # autocast context.
    # NOTE(review): torch.cuda.amp.autocast is deprecated in recent torch
    # releases in favor of torch.amp.autocast("cuda", ...); confirm target
    # torch version before changing.
    @torch.cuda.amp.autocast(enabled=False)
    def compute_global_alignment(self, init=None, niter_PnP=10, **kw):
        """Optionally initialize the scene, then run the optimization loop.

        init: None (no init), "msp"/"mst" (minimum-spanning-tree init), or
        "known_poses" (currently not implemented).
        """
        if init is None:
            pass
        elif init == "msp" or init == "mst":
            init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
        elif init == "known_poses":
            raise NotImplementedError
            # NOTE(review): the two statements below are unreachable dead
            # code (they follow an unconditional raise); kept as-is.
            self.preset_pose(known_poses=self.camera_poses, requires_grad=True)
            init_fun.init_from_known_poses(
                self, min_conf_thr=self.min_conf_thr, niter_PnP=niter_PnP
            )
        else:
            raise ValueError(f"bad value for {init=}")

        return global_alignment_loop(self, **kw)

    @property
    def str_edges(self):
        # Edge names in the same order as self.edges.
        return [edge_str(i, j) for i, j in self.edges]

    @property
    def n_edges(self):
        return len(self.edges)
188
+
189
+
190
def global_alignment_loop(
    net,
    lr=0.01,
    niter=300,
    schedule="cosine",
    lr_min=1e-3,
    temporal_smoothing_weight=0,
    depth_map_save_dir=None,
):
    """Run the iterative global-alignment optimization on ``net``.

    net: an optimizer module; must expose .parameters(), .verbose,
        .empty_cache, optionally .get_depthmaps(), and be callable as
        net(epoch=...) returning a scalar loss.
    lr / lr_min: initial and final learning rate for the schedule.
    niter: number of optimization iterations.
    schedule: "cosine", "linear", or "cycleN" (see global_alignment_iter).
    temporal_smoothing_weight: forwarded to every iteration.
    depth_map_save_dir: when set (verbose path only), depth maps are dumped
        as PNGs every 500 iterations for debugging.

    Returns the final loss value — except when nothing requires gradients,
    in which case ``net`` itself is returned (see NOTE below).
    """
    params = [p for p in net.parameters() if p.requires_grad]
    if not params:
        # NOTE(review): inconsistent return type — this early exit returns
        # the module, while the normal path returns a float loss; callers
        # should not rely on the return type here.
        return net

    verbose = net.verbose
    if verbose:
        print("Global alignement - optimizing for:")
        print([name for name, value in net.named_parameters() if value.requires_grad])

    lr_base = lr
    # betas=(0.9, 0.9): lower second-moment decay than Adam's default 0.999.
    optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.9))

    loss = float("inf")
    if verbose:
        with tqdm.tqdm(total=niter) as bar:
            while bar.n < bar.total:
                # Periodically dump depth maps for visual debugging
                # (includes iteration 0).
                if bar.n % 500 == 0 and depth_map_save_dir is not None:
                    if not os.path.exists(depth_map_save_dir):
                        os.makedirs(depth_map_save_dir)
                    # visualize the depthmaps
                    depth_maps = net.get_depthmaps()
                    for i, depth_map in enumerate(depth_maps):
                        depth_map_save_path = os.path.join(
                            depth_map_save_dir, f"depthmaps_{i}_iter_{bar.n}.png"
                        )
                        plt.imsave(
                            depth_map_save_path,
                            depth_map.detach().cpu().numpy(),
                            cmap="jet",
                        )
                        print(
                            f"Saved depthmaps at iteration {bar.n} to {depth_map_save_dir}"
                        )
                loss, lr = global_alignment_iter(
                    net,
                    bar.n,
                    niter,
                    lr_base,
                    lr_min,
                    optimizer,
                    schedule,
                    temporal_smoothing_weight=temporal_smoothing_weight,
                )
                bar.set_postfix_str(f"{lr=:g} loss={loss:g}")
                bar.update()
    else:
        # Silent path: same iterations, no progress bar and no depth dumps.
        for n in range(niter):
            loss, _ = global_alignment_iter(
                net,
                n,
                niter,
                lr_base,
                lr_min,
                optimizer,
                schedule,
                temporal_smoothing_weight=temporal_smoothing_weight,
            )
    return loss
257
+
258
+
259
def global_alignment_iter(
    net,
    cur_iter,
    niter,
    lr_base,
    lr_min,
    optimizer,
    schedule,
    temporal_smoothing_weight=0,
):
    """Run one optimization step: set lr, forward, backward, step.

    schedule: "cosine", "linear", or "cycleN" where N is the number of
    linear cycles (defaults to 2 when the suffix is not an integer,
    e.g. plain "cycle").

    Returns (loss as float, learning rate used this step).

    NOTE(review): temporal_smoothing_weight is accepted but unused in this
    function body — presumably consumed elsewhere or vestigial; confirm.
    """
    # Normalized progress in [0, 1) used by all schedules.
    t = cur_iter / niter
    if schedule == "cosine":
        lr = cosine_schedule(t, lr_base, lr_min)
    elif schedule == "linear":
        lr = linear_schedule(t, lr_base, lr_min)
    elif schedule.startswith("cycle"):
        try:
            num_cycles = int(schedule[5:])
        except ValueError:
            num_cycles = 2
        lr = cycled_linear_schedule(t, lr_base, lr_min, num_cycles=num_cycles)
    else:
        raise ValueError(f"bad lr {schedule=}")

    adjust_learning_rate_by_lr(optimizer, lr)
    optimizer.zero_grad()

    # Optional aggressive cache clearing around forward/backward/step,
    # for memory-constrained runs (trades speed for peak-memory headroom).
    if net.empty_cache:
        torch.cuda.empty_cache()

    loss = net(epoch=cur_iter)

    if net.empty_cache:
        torch.cuda.empty_cache()

    loss.backward()

    if net.empty_cache:
        torch.cuda.empty_cache()

    optimizer.step()

    return float(loss), lr
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/commons.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utility functions for global alignment
6
+ # --------------------------------------------------------
7
+ import torch
8
+ import torch.nn as nn
9
+ import numpy as np
10
+
11
+
12
def edge_str(i, j):
    """Canonical string key for the directed edge (i, j), e.g. "3_7"."""
    return "_".join((str(i), str(j)))
14
+
15
+
16
def i_j_ij(ij):
    """Pair an (i, j) tuple with its string key: ((i, j)) -> ("i_j", (i, j))."""
    key = edge_str(ij[0], ij[1])
    return key, ij
18
+
19
+
20
def edge_conf(conf_i, conf_j, edge):
    """Scalar confidence of one edge: product of the two mean confidence maps."""
    mean_i = conf_i[edge].mean()
    mean_j = conf_j[edge].mean()
    return float(mean_i * mean_j)
22
+
23
+
24
def compute_edge_scores(edges, conf_i, conf_j):
    """Score every edge; *edges* yields (edge_key, (i, j)) items (see i_j_ij)."""
    scores = {}
    for key, (i, j) in edges:
        scores[(i, j)] = edge_conf(conf_i, conf_j, key)
    return scores
26
+
27
+
28
def NoGradParamDict(x):
    """Wrap a plain dict of tensors into an nn.ParameterDict with grads disabled."""
    assert isinstance(x, dict)
    params = nn.ParameterDict(x)
    return params.requires_grad_(False)
31
+
32
+
33
def get_imshapes(edges, pred_i, pred_j):
    """Recover each image's (H, W) shape from per-edge predictions.

    Asserts that every edge touching the same image agrees on its shape.
    """
    n_imgs = 1 + max(max(edge) for edge in edges)
    imshapes = [None] * n_imgs
    for e, (i, j) in enumerate(edges):
        shape_i = tuple(pred_i[e].shape[:2])
        shape_j = tuple(pred_j[e].shape[:2])
        # a previously recorded shape must match what this edge reports
        if imshapes[i]:
            assert imshapes[i] == shape_i, f"incorrect shape for image {i}"
        if imshapes[j]:
            assert imshapes[j] == shape_j, f"incorrect shape for image {j}"
        imshapes[i] = shape_i
        imshapes[j] = shape_j
    return imshapes
46
+
47
+
48
def get_conf_trf(mode):
    """Return the confidence-transform callable for *mode*.

    Supported modes: "log", "sqrt", "m1" (subtract 1), "id"/"none" (identity).
    Raises ValueError for anything else.
    """
    transforms = {
        "log": lambda x: x.log(),
        "sqrt": lambda x: x.sqrt(),
        "m1": lambda x: x - 1,
        "id": lambda x: x,
        "none": lambda x: x,
    }
    try:
        return transforms[mode]
    except KeyError:
        raise ValueError(f"bad mode for {mode=}") from None
72
+
73
+
74
def l2_dist(a, b, weight):
    """Weighted squared-L2 distance along the last axis."""
    diff = a - b
    return diff.square().sum(dim=-1) * weight
76
+
77
+
78
def l1_dist(a, b, weight):
    """Weighted vector-norm distance along the last axis.

    NOTE: despite the name, .norm() computes the (unsquared) L2 norm here.
    """
    diff = a - b
    return diff.norm(dim=-1) * weight
80
+
81
+
82
+ ALL_DISTS = dict(l1=l1_dist, l2=l2_dist)
83
+
84
+
85
def signed_log1p(x):
    """Odd-symmetric log1p: sign(x) * log(1 + |x|)."""
    return torch.sign(x) * torch.log1p(x.abs())
88
+
89
+
90
def signed_expm1(x):
    """Inverse of signed_log1p: sign(x) * (exp(|x|) - 1)."""
    return torch.sign(x) * torch.expm1(x.abs())
93
+
94
+
95
def cosine_schedule(t, lr_start, lr_end):
    """Cosine interpolation: lr_start at t=0 decaying smoothly to lr_end at t=1."""
    assert 0 <= t <= 1
    cos_weight = (1 + np.cos(np.pi * t)) / 2
    return lr_end + (lr_start - lr_end) * cos_weight
98
+
99
+
100
def linear_schedule(t, lr_start, lr_end):
    """Linear interpolation: lr_start at t=0 to lr_end at t=1."""
    assert 0 <= t <= 1
    return lr_start + (lr_end - lr_start) * t
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/dust3r_opt/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # global alignment optimization wrapper function
6
+ # --------------------------------------------------------
7
+ from enum import Enum
8
+
9
+ from .optimizer import PointCloudOptimizer
10
+
11
+
12
class GlobalAlignerMode(Enum):
    """Global-aligner flavours; only PointCloudOptimizer is implemented in this module."""

    PointCloudOptimizer = "PointCloudOptimizer"
    ModularPointCloudOptimizer = "ModularPointCloudOptimizer"
    PairViewer = "PairViewer"
16
+
17
+
18
def global_aligner(
    dust3r_output, device, mode=GlobalAlignerMode.PointCloudOptimizer, **optim_kw
):
    """Build a global-alignment optimizer from raw pairwise DUSt3R output.

    Raises NotImplementedError for any mode other than PointCloudOptimizer.
    """
    # unpack the four prediction dicts from the inference output
    view1 = dust3r_output["view1"]
    view2 = dust3r_output["view2"]
    pred1 = dust3r_output["pred1"]
    pred2 = dust3r_output["pred2"]

    if mode != GlobalAlignerMode.PointCloudOptimizer:
        raise NotImplementedError(f"Unknown mode {mode}")

    return PointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device)
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/dust3r_opt/base_opt.py ADDED
@@ -0,0 +1,620 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Base class for the global alignement procedure
6
+ # --------------------------------------------------------
7
+ from copy import deepcopy
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import roma
13
+ from copy import deepcopy
14
+ import tqdm
15
+ import cv2
16
+ from PIL import Image
17
+ from dust3r.utils.geometry import inv, geotrf
18
+ from dust3r.utils.device import to_numpy
19
+ from dust3r.utils.image import rgb
20
+ from dust3r.viz import SceneViz, segment_sky, auto_cam_size
21
+
22
+ from cloud_opt.dust3r_opt.commons import (
23
+ edge_str,
24
+ ALL_DISTS,
25
+ NoGradParamDict,
26
+ get_imshapes,
27
+ signed_expm1,
28
+ signed_log1p,
29
+ cosine_schedule,
30
+ linear_schedule,
31
+ get_conf_trf,
32
+ )
33
+ import cloud_opt.dust3r_opt.init_im_poses as init_fun
34
+ from pathlib import Path
35
+ from scipy.spatial.transform import Rotation
36
+ from evo.core.trajectory import PosePath3D, PoseTrajectory3D
37
+
38
+
39
def adjust_learning_rate_by_lr(optimizer, lr):
    """Set every param group's learning rate, honouring a per-group 'lr_scale'."""
    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr
45
+
46
+
47
def make_traj(args) -> PoseTrajectory3D:
    """Coerce a (poses, timestamps) pair into a PoseTrajectory3D; deep-copy an
    existing trajectory.

    *poses* rows are laid out as [x y z | qw qx qy qz].
    """
    if isinstance(args, (tuple, list)):
        poses, stamps = args
        return PoseTrajectory3D(
            positions_xyz=poses[:, :3],
            orientations_quat_wxyz=poses[:, 3:],
            timestamps=stamps,
        )
    assert isinstance(args, PoseTrajectory3D), type(args)
    return deepcopy(args)
57
+
58
+
59
def save_trajectory_tum_format(traj, filename):
    """Write *traj* to *filename* in TUM format.

    Each line is: "timestamp x y z qw qx qy qz".

    Fix: the completion message was the broken literal
    f"Saved trajectory to (unknown)" (an f-string with no placeholder); it now
    reports the actual *filename*. The no-op quaternion index [[0,1,2,3]] was
    also dropped.
    """
    traj = make_traj(traj)
    tostr = lambda a: " ".join(map(str, a))
    with Path(filename).open("w") as f:
        for i in range(traj.num_poses):
            f.write(
                f"{traj.timestamps[i]} {tostr(traj.positions_xyz[i])} {tostr(traj.orientations_quat_wxyz[i])}\n"
            )
    print(f"Saved trajectory to {filename}")
68
+
69
+
70
def c2w_to_tumpose(c2w):
    """
    Convert a camera-to-world matrix to a tuple of translation and rotation

    input: c2w: 4x4 matrix
    output: tuple of translation and rotation (x y z qw qx qy qz)
    """
    # work on a numpy copy of the pose
    c2w = to_numpy(c2w)
    translation = c2w[:3, -1]
    # scipy returns quaternions as (qx, qy, qz, qw); reorder to w-first
    qx, qy, qz, qw = Rotation.from_matrix(c2w[:3, :3]).as_quat()
    return np.concatenate([translation, [qw, qx, qy, qz]])
84
+
85
+
86
+ class BasePCOptimizer(nn.Module):
87
+ """Optimize a global scene, given a list of pairwise observations.
88
+ Graph node: images
89
+ Graph edges: observations = (pred1, pred2)
90
+ """
91
+
92
+ def __init__(self, *args, **kwargs):
93
+ if len(args) == 1 and len(kwargs) == 0:
94
+ other = deepcopy(args[0])
95
+ attrs = """edges is_symmetrized dist n_imgs pred_i pred_j imshapes
96
+ min_conf_thr conf_thr conf_i conf_j im_conf
97
+ base_scale norm_pw_scale POSE_DIM pw_poses
98
+ pw_adaptors pw_adaptors has_im_poses rand_pose imgs verbose""".split()
99
+ self.__dict__.update({k: other[k] for k in attrs})
100
+ else:
101
+ self._init_from_views(*args, **kwargs)
102
+
103
+ def _init_from_views(
104
+ self,
105
+ view1,
106
+ view2,
107
+ pred1,
108
+ pred2,
109
+ dist="l1",
110
+ conf="log",
111
+ min_conf_thr=3,
112
+ base_scale=0.5,
113
+ allow_pw_adaptors=False,
114
+ pw_break=20,
115
+ rand_pose=torch.randn,
116
+ iterationsCount=None,
117
+ verbose=True,
118
+ ):
119
+ super().__init__()
120
+ if not isinstance(view1["idx"], list):
121
+ view1["idx"] = view1["idx"].tolist()
122
+ if not isinstance(view2["idx"], list):
123
+ view2["idx"] = view2["idx"].tolist()
124
+ self.edges = [(int(i), int(j)) for i, j in zip(view1["idx"], view2["idx"])]
125
+ self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges}
126
+ self.dist = ALL_DISTS[dist]
127
+ self.verbose = verbose
128
+
129
+ self.n_imgs = self._check_edges()
130
+
131
+ # input data
132
+ pred1_pts = pred1["pts3d_in_self_view"]
133
+ pred2_pts = pred2["pts3d_in_other_view"]
134
+ self.pred_i = NoGradParamDict(
135
+ {ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)}
136
+ )
137
+ self.pred_j = NoGradParamDict(
138
+ {ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)}
139
+ )
140
+ self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts)
141
+
142
+ # work in log-scale with conf
143
+ pred1_conf = pred1["conf_self"]
144
+ pred2_conf = pred2["conf"]
145
+ self.min_conf_thr = min_conf_thr
146
+ self.conf_trf = get_conf_trf(conf)
147
+
148
+ self.conf_i = NoGradParamDict(
149
+ {ij: pred1_conf[n] for n, ij in enumerate(self.str_edges)}
150
+ )
151
+ self.conf_j = NoGradParamDict(
152
+ {ij: pred2_conf[n] for n, ij in enumerate(self.str_edges)}
153
+ )
154
+ self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf)
155
+ for i in range(len(self.im_conf)):
156
+ self.im_conf[i].requires_grad = False
157
+
158
+ # pairwise pose parameters
159
+ self.base_scale = base_scale
160
+ self.norm_pw_scale = True
161
+ self.pw_break = pw_break
162
+ self.POSE_DIM = 7
163
+ self.pw_poses = nn.Parameter(
164
+ rand_pose((self.n_edges, 1 + self.POSE_DIM))
165
+ ) # pairwise poses
166
+ self.pw_adaptors = nn.Parameter(
167
+ torch.zeros((self.n_edges, 2))
168
+ ) # slight xy/z adaptation
169
+ self.pw_adaptors.requires_grad_(allow_pw_adaptors)
170
+ self.has_im_poses = False
171
+ self.rand_pose = rand_pose
172
+
173
+ # possibly store images for show_pointcloud
174
+ self.imgs = None
175
+ if "img" in view1 and "img" in view2:
176
+ imgs = [torch.zeros((3,) + hw) for hw in self.imshapes]
177
+ for v in range(len(self.edges)):
178
+ idx = view1["idx"][v]
179
+ imgs[idx] = view1["img"][v]
180
+ idx = view2["idx"][v]
181
+ imgs[idx] = view2["img"][v]
182
+ self.imgs = rgb(imgs)
183
+
184
+ @property
185
+ def n_edges(self):
186
+ return len(self.edges)
187
+
188
+ @property
189
+ def str_edges(self):
190
+ return [edge_str(i, j) for i, j in self.edges]
191
+
192
+ @property
193
+ def imsizes(self):
194
+ return [(w, h) for h, w in self.imshapes]
195
+
196
+ @property
197
+ def device(self):
198
+ return next(iter(self.parameters())).device
199
+
200
+ def state_dict(self, trainable=True):
201
+ all_params = super().state_dict()
202
+ return {
203
+ k: v
204
+ for k, v in all_params.items()
205
+ if k.startswith(("_", "pred_i.", "pred_j.", "conf_i.", "conf_j."))
206
+ != trainable
207
+ }
208
+
209
+ def load_state_dict(self, data):
210
+ return super().load_state_dict(self.state_dict(trainable=False) | data)
211
+
212
+ def _check_edges(self):
213
+ indices = sorted({i for edge in self.edges for i in edge})
214
+ assert indices == list(range(len(indices))), "bad pair indices: missing values "
215
+ return len(indices)
216
+
217
+ @torch.no_grad()
218
+ def _compute_img_conf(self, pred1_conf, pred2_conf):
219
+ im_conf = nn.ParameterList(
220
+ [torch.zeros(hw, device=self.device) for hw in self.imshapes]
221
+ )
222
+ for e, (i, j) in enumerate(self.edges):
223
+ im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e])
224
+ im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e])
225
+ return im_conf
226
+
227
+ def get_adaptors(self):
228
+ adapt = self.pw_adaptors
229
+ adapt = torch.cat(
230
+ (adapt[:, 0:1], adapt), dim=-1
231
+ ) # (scale_xy, scale_xy, scale_z)
232
+ if self.norm_pw_scale: # normalize so that the product == 1
233
+ adapt = adapt - adapt.mean(dim=1, keepdim=True)
234
+ return (adapt / self.pw_break).exp()
235
+
236
+ def _get_poses(self, poses):
237
+ # normalize rotation
238
+ Q = poses[:, :4]
239
+ T = signed_expm1(poses[:, 4:7])
240
+ RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous()
241
+ return RT
242
+
243
+ def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
244
+ # all poses == cam-to-world
245
+ pose = poses[idx]
246
+ if not (pose.requires_grad or force):
247
+ return pose
248
+
249
+ if R.shape == (4, 4):
250
+ assert T is None
251
+ T = R[:3, 3]
252
+ R = R[:3, :3]
253
+
254
+ if R is not None:
255
+ pose.data[0:4] = roma.rotmat_to_unitquat(R)
256
+ if T is not None:
257
+ pose.data[4:7] = signed_log1p(
258
+ T / (scale or 1)
259
+ ) # translation is function of scale
260
+
261
+ if scale is not None:
262
+ assert poses.shape[-1] in (8, 13)
263
+ pose.data[-1] = np.log(float(scale))
264
+ return pose
265
+
266
+ def get_pw_norm_scale_factor(self):
267
+ if self.norm_pw_scale:
268
+ # normalize scales so that things cannot go south
269
+ # we want that exp(scale) ~= self.base_scale
270
+ return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp()
271
+ else:
272
+ return 1 # don't norm scale for known poses
273
+
274
+ def get_pw_scale(self):
275
+ scale = self.pw_poses[:, -1].exp() # (n_edges,)
276
+ scale = scale * self.get_pw_norm_scale_factor()
277
+ return scale
278
+
279
+ def get_pw_poses(self): # cam to world
280
+ RT = self._get_poses(self.pw_poses)
281
+ scaled_RT = RT.clone()
282
+ scaled_RT[:, :3] *= self.get_pw_scale().view(
283
+ -1, 1, 1
284
+ ) # scale the rotation AND translation
285
+ return scaled_RT
286
+
287
+ def get_masks(self):
288
+ return [(conf > self.min_conf_thr) for conf in self.im_conf]
289
+
290
+ def depth_to_pts3d(self):
291
+ raise NotImplementedError()
292
+
293
+ def get_pts3d(self, raw=False):
294
+ res = self.depth_to_pts3d()
295
+ if not raw:
296
+ res = [dm[: h * w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
297
+ return res
298
+
299
+ def _set_focal(self, idx, focal, force=False):
300
+ raise NotImplementedError()
301
+
302
+ def get_focals(self):
303
+ raise NotImplementedError()
304
+
305
+ def get_known_focal_mask(self):
306
+ raise NotImplementedError()
307
+
308
+ def get_principal_points(self):
309
+ raise NotImplementedError()
310
+
311
+ def get_conf(self, mode=None):
312
+ trf = self.conf_trf if mode is None else get_conf_trf(mode)
313
+ return [trf(c) for c in self.im_conf]
314
+
315
+ def get_im_poses(self):
316
+ raise NotImplementedError()
317
+
318
+ def _set_depthmap(self, idx, depth, force=False):
319
+ raise NotImplementedError()
320
+
321
+ def get_depthmaps(self, raw=False):
322
+ raise NotImplementedError()
323
+
324
    def save_depth_maps(self, path):
        """Colorize each depth map, save a PNG and raw .npy per frame, then an
        animated GIF of all frames; returns the depth maps.

        NOTE(review): this method is defined twice in this class with an
        identical body; the later definition silently shadows this one -- one
        copy should be removed.
        """
        depth_maps = self.get_depthmaps()
        images = []

        for i, depth_map in enumerate(depth_maps):
            # Apply color map to depth map
            # (assumes depth values are roughly in [0, 1] before *255 -- TODO confirm)
            depth_map_colored = cv2.applyColorMap(
                (depth_map * 255).detach().cpu().numpy().astype(np.uint8),
                cv2.COLORMAP_JET,
            )
            img_path = f"{path}/frame_{(i):04d}.png"
            cv2.imwrite(img_path, depth_map_colored)
            images.append(Image.open(img_path))
            # also dump the raw (un-colorized) depth values for later reuse
            np.save(f"{path}/frame_{(i):04d}.npy", depth_map.detach().cpu().numpy())

        # assemble all colorized frames into a looping GIF (100 ms per frame)
        images[0].save(
            f"{path}/_depth_maps.gif",
            save_all=True,
            append_images=images[1:],
            duration=100,
            loop=0,
        )

        return depth_maps
348
+
349
+ def clean_pointcloud(self, **kw):
350
+ cams = inv(self.get_im_poses())
351
+ K = self.get_intrinsics()
352
+ depthmaps = self.get_depthmaps()
353
+ all_pts3d = self.get_pts3d()
354
+
355
+ new_im_confs = clean_pointcloud(
356
+ self.im_conf, K, cams, depthmaps, all_pts3d, **kw
357
+ )
358
+ for i, new_conf in enumerate(new_im_confs):
359
+ self.im_conf[i].data[:] = new_conf
360
+ return self
361
+
362
+ def get_tum_poses(self):
363
+ poses = self.get_im_poses()
364
+ tt = np.arange(len(poses)).astype(float)
365
+ tum_poses = [c2w_to_tumpose(p) for p in poses]
366
+ tum_poses = np.stack(tum_poses, 0)
367
+ return [tum_poses, tt]
368
+
369
+ def save_tum_poses(self, path):
370
+ traj = self.get_tum_poses()
371
+ save_trajectory_tum_format(traj, path)
372
+ return traj[0] # return the poses
373
+
374
+ def save_focals(self, path):
375
+ # convert focal to txt
376
+ focals = self.get_focals()
377
+ np.savetxt(path, focals.detach().cpu().numpy(), fmt="%.6f")
378
+ return focals
379
+
380
+ def save_intrinsics(self, path):
381
+ K_raw = self.get_intrinsics()
382
+ K = K_raw.reshape(-1, 9)
383
+ np.savetxt(path, K.detach().cpu().numpy(), fmt="%.6f")
384
+ return K_raw
385
+
386
+ def save_conf_maps(self, path):
387
+ conf = self.get_conf()
388
+ for i, c in enumerate(conf):
389
+ np.save(f"{path}/conf_{i}.npy", c.detach().cpu().numpy())
390
+ return conf
391
+
392
+ def save_init_conf_maps(self, path):
393
+ conf = self.get_init_conf()
394
+ for i, c in enumerate(conf):
395
+ np.save(f"{path}/init_conf_{i}.npy", c.detach().cpu().numpy())
396
+ return conf
397
+
398
+ def save_rgb_imgs(self, path):
399
+ imgs = self.imgs
400
+ for i, img in enumerate(imgs):
401
+ # convert from rgb to bgr
402
+ img = img[..., ::-1]
403
+ cv2.imwrite(f"{path}/frame_{i:04d}.png", img * 255)
404
+ return imgs
405
+
406
+ def save_dynamic_masks(self, path):
407
+ dynamic_masks = (
408
+ self.dynamic_masks
409
+ if getattr(self, "sam2_dynamic_masks", None) is None
410
+ else self.sam2_dynamic_masks
411
+ )
412
+ for i, dynamic_mask in enumerate(dynamic_masks):
413
+ cv2.imwrite(
414
+ f"{path}/dynamic_mask_{i}.png",
415
+ (dynamic_mask * 255).detach().cpu().numpy().astype(np.uint8),
416
+ )
417
+ return dynamic_masks
418
+
419
+ def save_depth_maps(self, path):
420
+ depth_maps = self.get_depthmaps()
421
+ images = []
422
+
423
+ for i, depth_map in enumerate(depth_maps):
424
+ # Apply color map to depth map
425
+ depth_map_colored = cv2.applyColorMap(
426
+ (depth_map * 255).detach().cpu().numpy().astype(np.uint8),
427
+ cv2.COLORMAP_JET,
428
+ )
429
+ img_path = f"{path}/frame_{(i):04d}.png"
430
+ cv2.imwrite(img_path, depth_map_colored)
431
+ images.append(Image.open(img_path))
432
+ np.save(f"{path}/frame_{(i):04d}.npy", depth_map.detach().cpu().numpy())
433
+
434
+ images[0].save(
435
+ f"{path}/_depth_maps.gif",
436
+ save_all=True,
437
+ append_images=images[1:],
438
+ duration=100,
439
+ loop=0,
440
+ )
441
+
442
+ return depth_maps
443
+
444
    def forward(self, ret_details=False):
        """Total alignment loss: confidence-weighted distance between each
        image's global point cloud and the per-edge predictions transformed
        into world frame, averaged over all edges.

        Args:
            ret_details: if True, also return an (n_imgs, n_imgs) matrix whose
                [i, j] entry holds that edge's loss (-1 where no edge exists).
        """
        pw_poses = self.get_pw_poses()  # cam-to-world
        pw_adapt = self.get_adaptors()
        proj_pts3d = self.get_pts3d()
        # pre-compute pixel weights
        weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()}
        weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()}

        loss = 0
        if ret_details:
            details = -torch.ones((self.n_imgs, self.n_imgs))

        for e, (i, j) in enumerate(self.edges):
            i_j = edge_str(i, j)
            # distance in image i and j
            aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j])
            aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j])
            li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean()
            lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean()
            loss = loss + li + lj

            if ret_details:
                details[i, j] = li + lj
        loss /= self.n_edges  # average over all pairs

        if ret_details:
            return loss, details
        return loss
472
+
473
+ @torch.cuda.amp.autocast(enabled=False)
474
+ def compute_global_alignment(self, init=None, niter_PnP=10, **kw):
475
+ if init is None:
476
+ pass
477
+ elif init == "msp" or init == "mst":
478
+ init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
479
+ elif init == "known_poses":
480
+ init_fun.init_from_known_poses(
481
+ self, min_conf_thr=self.min_conf_thr, niter_PnP=niter_PnP
482
+ )
483
+ else:
484
+ raise ValueError(f"bad value for {init=}")
485
+ return global_alignment_loop(self, **kw)
486
+
487
+ @torch.no_grad()
488
+ def mask_sky(self):
489
+ res = deepcopy(self)
490
+ for i in range(self.n_imgs):
491
+ sky = segment_sky(self.imgs[i])
492
+ res.im_conf[i][sky] = 0
493
+ return res
494
+
495
+ def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw):
496
+ viz = SceneViz()
497
+ if self.imgs is None:
498
+ colors = np.random.randint(0, 256, size=(self.n_imgs, 3))
499
+ colors = list(map(tuple, colors.tolist()))
500
+ for n in range(self.n_imgs):
501
+ viz.add_pointcloud(self.get_pts3d()[n], colors[n], self.get_masks()[n])
502
+ else:
503
+ viz.add_pointcloud(self.get_pts3d(), self.imgs, self.get_masks())
504
+ colors = np.random.randint(256, size=(self.n_imgs, 3))
505
+
506
+ # camera poses
507
+ im_poses = to_numpy(self.get_im_poses())
508
+ if cam_size is None:
509
+ cam_size = auto_cam_size(im_poses)
510
+ viz.add_cameras(
511
+ im_poses,
512
+ self.get_focals(),
513
+ colors=colors,
514
+ images=self.imgs,
515
+ imsizes=self.imsizes,
516
+ cam_size=cam_size,
517
+ )
518
+ if show_pw_cams:
519
+ pw_poses = self.get_pw_poses()
520
+ viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size)
521
+
522
+ if show_pw_pts3d:
523
+ pts = [
524
+ geotrf(pw_poses[e], self.pred_i[edge_str(i, j)])
525
+ for e, (i, j) in enumerate(self.edges)
526
+ ]
527
+ viz.add_pointcloud(pts, (128, 0, 128))
528
+
529
+ viz.show(**kw)
530
+ return viz
531
+
532
+
533
def global_alignment_loop(net, lr=0.01, niter=300, schedule="cosine", lr_min=1e-6):
    """Optimize *net*'s trainable parameters for *niter* Adam steps.

    Returns the final loss as a float. NOTE(review): when *net* has no
    trainable parameters it returns *net* itself instead of a loss -- callers
    must cope with both return types.
    """
    params = [p for p in net.parameters() if p.requires_grad]
    if not params:
        return net

    verbose = net.verbose
    if verbose:
        print("Global alignement - optimizing for:")
        print([name for name, value in net.named_parameters() if value.requires_grad])

    lr_base = lr
    # non-default betas (0.9, 0.9)
    optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.9))

    loss = float("inf")
    if verbose:
        # progress-bar variant: same loop, but reports lr and loss per step
        with tqdm.tqdm(total=niter) as bar:
            while bar.n < bar.total:
                loss, lr = global_alignment_iter(
                    net, bar.n, niter, lr_base, lr_min, optimizer, schedule
                )
                bar.set_postfix_str(f"{lr=:g} loss={loss:g}")
                bar.update()
    else:
        for n in range(niter):
            loss, _ = global_alignment_iter(
                net, n, niter, lr_base, lr_min, optimizer, schedule
            )
    return loss
561
+
562
+
563
def global_alignment_iter(net, cur_iter, niter, lr_base, lr_min, optimizer, schedule):
    """One optimization step at the scheduled learning rate; returns (loss, lr)."""
    progress = cur_iter / niter
    if schedule == "cosine":
        lr = cosine_schedule(progress, lr_base, lr_min)
    elif schedule == "linear":
        lr = linear_schedule(progress, lr_base, lr_min)
    else:
        raise ValueError(f"bad lr {schedule=}")

    adjust_learning_rate_by_lr(optimizer, lr)
    optimizer.zero_grad()
    loss = net()
    loss.backward()
    optimizer.step()
    return float(loss), lr
578
+
579
+
580
@torch.no_grad()
def clean_pointcloud(
    im_confs, K, cams, depthmaps, all_pts3d, tol=0.001, bad_conf=0, dbg=()
):
    """Method:
    1) express all 3d points in each camera coordinate frame
    2) if they're in front of a depthmap --> then lower their confidence

    Args:
        im_confs: per-image confidence maps (list of HxW tensors); not mutated.
        K: per-camera intrinsics.
        cams: world-to-camera transforms.
        depthmaps: per-image depth maps.
        all_pts3d: per-image 3d points in world frame.
        tol: relative depth margin before a point counts as occluding.
        bad_conf: confidence value clipped onto occluding-but-less-confident points.
        dbg: unused debug hook.

    Returns:
        list of cleaned (cloned) confidence maps.
    """
    assert len(im_confs) == len(cams) == len(K) == len(depthmaps) == len(all_pts3d)
    assert 0 <= tol < 1
    res = [c.clone() for c in im_confs]

    # reshape appropriately
    all_pts3d = [p.view(*c.shape, 3) for p, c in zip(all_pts3d, im_confs)]
    depthmaps = [d.view(*c.shape) for d, c in zip(depthmaps, im_confs)]

    for i, pts3d in enumerate(all_pts3d):
        for j in range(len(all_pts3d)):
            if i == j:
                continue

            # project 3dpts in other view
            proj = geotrf(cams[j], pts3d)
            proj_depth = proj[:, :, 2]
            u, v = geotrf(K[j], proj, norm=1, ncol=2).round().long().unbind(-1)

            # check which points are actually in the visible cone
            H, W = im_confs[j].shape
            msk_i = (proj_depth > 0) & (0 <= u) & (u < W) & (0 <= v) & (v < H)
            msk_j = v[msk_i], u[msk_i]

            # find bad points = those in front but less confident
            bad_points = (proj_depth[msk_i] < (1 - tol) * depthmaps[j][msk_j]) & (
                res[i][msk_i] < res[j][msk_j]
            )

            # scatter the per-point verdict back into the full-image mask
            bad_msk_i = msk_i.clone()
            bad_msk_i[msk_i] = bad_points
            res[i][bad_msk_i] = res[i][bad_msk_i].clip_(max=bad_conf)

    return res
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/dust3r_opt/commons.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utility functions for global alignment
6
+ # --------------------------------------------------------
7
+ import torch
8
+ import torch.nn as nn
9
+ import numpy as np
10
+
11
+
12
def edge_str(i, j):
    """Canonical string key for the directed edge (i, j), e.g. "3_7"."""
    return "_".join((str(i), str(j)))
14
+
15
+
16
def i_j_ij(ij):
    """Pair an (i, j) tuple with its string key: ((i, j)) -> ("i_j", (i, j))."""
    key = edge_str(ij[0], ij[1])
    return key, ij
18
+
19
+
20
def edge_conf(conf_i, conf_j, edge):
    """Scalar confidence of one edge: product of the two mean confidence maps."""
    mean_i = conf_i[edge].mean()
    mean_j = conf_j[edge].mean()
    return float(mean_i * mean_j)
22
+
23
+
24
def compute_edge_scores(edges, conf_i, conf_j):
    """Score every edge; *edges* yields (edge_key, (i, j)) items (see i_j_ij)."""
    scores = {}
    for key, (i, j) in edges:
        scores[(i, j)] = edge_conf(conf_i, conf_j, key)
    return scores
26
+
27
+
28
def NoGradParamDict(x):
    """Wrap a plain dict of tensors into an nn.ParameterDict with grads disabled."""
    assert isinstance(x, dict)
    params = nn.ParameterDict(x)
    return params.requires_grad_(False)
31
+
32
+
33
def get_imshapes(edges, pred_i, pred_j):
    """Recover each image's (H, W) shape from per-edge predictions.

    Asserts that every edge touching the same image agrees on its shape.
    """
    n_imgs = 1 + max(max(edge) for edge in edges)
    imshapes = [None] * n_imgs
    for e, (i, j) in enumerate(edges):
        shape_i = tuple(pred_i[e].shape[:2])
        shape_j = tuple(pred_j[e].shape[:2])
        # a previously recorded shape must match what this edge reports
        if imshapes[i]:
            assert imshapes[i] == shape_i, f"incorrect shape for image {i}"
        if imshapes[j]:
            assert imshapes[j] == shape_j, f"incorrect shape for image {j}"
        imshapes[i] = shape_i
        imshapes[j] = shape_j
    return imshapes
46
+
47
+
48
def get_conf_trf(mode):
    """Return the confidence-transform callable for *mode*.

    Supported modes: "log", "sqrt", "m1" (subtract 1), "id"/"none" (identity).
    Raises ValueError for anything else.
    """
    transforms = {
        "log": lambda x: x.log(),
        "sqrt": lambda x: x.sqrt(),
        "m1": lambda x: x - 1,
        "id": lambda x: x,
        "none": lambda x: x,
    }
    try:
        return transforms[mode]
    except KeyError:
        raise ValueError(f"bad mode for {mode=}") from None
72
+
73
+
74
def l2_dist(a, b, weight):
    """Weighted squared-L2 distance along the last axis."""
    diff = a - b
    return diff.square().sum(dim=-1) * weight
76
+
77
+
78
def l1_dist(a, b, weight):
    """Weighted vector-norm distance along the last axis.

    NOTE: despite the name, .norm() computes the (unsquared) L2 norm here.
    """
    diff = a - b
    return diff.norm(dim=-1) * weight
80
+
81
+
82
+ ALL_DISTS = dict(l1=l1_dist, l2=l2_dist)
83
+
84
+
85
def signed_log1p(x):
    """Odd-symmetric log1p: sign(x) * log(1 + |x|)."""
    return torch.sign(x) * torch.log1p(x.abs())
88
+
89
+
90
def signed_expm1(x):
    """Inverse of signed_log1p: sign(x) * (exp(|x|) - 1)."""
    return torch.sign(x) * torch.expm1(x.abs())
93
+
94
+
95
def cosine_schedule(t, lr_start, lr_end):
    """Cosine interpolation: lr_start at t=0 decaying smoothly to lr_end at t=1."""
    assert 0 <= t <= 1
    cos_weight = (1 + np.cos(np.pi * t)) / 2
    return lr_end + (lr_start - lr_end) * cos_weight
98
+
99
+
100
def linear_schedule(t, lr_start, lr_end):
    """Linear interpolation: lr_start at t=0 to lr_end at t=1."""
    assert 0 <= t <= 1
    return lr_start + (lr_end - lr_start) * t
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/dust3r_opt/init_im_poses.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Initialization functions for global alignment
6
+ # --------------------------------------------------------
7
+ from functools import cache
8
+
9
+ import numpy as np
10
+ import scipy.sparse as sp
11
+ import torch
12
+ import cv2
13
+ import roma
14
+ from tqdm import tqdm
15
+
16
+ from dust3r.utils.geometry import geotrf, inv, get_med_dist_between_poses
17
+ from dust3r.post_process import estimate_focal_knowing_depth
18
+ from dust3r.viz import to_numpy
19
+
20
+ from cloud_opt.commons import edge_str, i_j_ij, compute_edge_scores
21
+
22
+
23
@torch.no_grad()
def init_from_known_poses(self, niter_PnP=10, min_conf_thr=3):
    """Initialize pairwise poses and depth maps when ALL image poses are known.

    For every edge, a relative pose is recovered via PnP and rigidly aligned to
    the two ground-truth cameras; each image's depth is then taken from its
    highest-confidence edge, rescaled by that alignment's scale.

    Args:
        niter_PnP: iteration count forwarded to fast_pnp.
        min_conf_thr: confidence threshold for the PnP inlier mask.
    """
    device = self.device

    # indices of known poses
    nkp, known_poses_msk, known_poses = get_known_poses(self)
    assert nkp == self.n_imgs, "not all poses are known"

    # get all focals
    nkf, _, im_focals = get_known_focals(self)
    assert nkf == self.n_imgs
    im_pp = self.get_principal_points()

    best_depthmaps = {}
    # init all pairwise poses
    for e, (i, j) in enumerate(tqdm(self.edges, disable=not self.verbose)):
        i_j = edge_str(i, j)

        # find relative pose for this pair
        P1 = torch.eye(4, device=device)
        # mask keeps at least something even when all confidences are below the threshold
        msk = self.conf_i[i_j] > min(min_conf_thr, self.conf_i[i_j].min() - 0.1)
        _, P2 = fast_pnp(
            self.pred_j[i_j],
            float(im_focals[i].mean()),
            pp=im_pp[i],
            msk=msk,
            device=device,
            niter_PnP=niter_PnP,
        )

        # align the two predicted camera with the two gt cameras
        s, R, T = align_multiple_poses(torch.stack((P1, P2)), known_poses[[i, j]])
        # normally we have known_poses[i] ~= sRT_to_4x4(s,R,T,device) @ P1
        # and geotrf(sRT_to_4x4(1,R,T,device), s*P2[:3,3])
        self._set_pose(self.pw_poses, e, R, T, scale=s)

        # remember if this is a good depthmap
        score = float(self.conf_i[i_j].mean())
        if score > best_depthmaps.get(i, (0,))[0]:
            best_depthmaps[i] = score, i_j, s

    # init all image poses
    for n in range(self.n_imgs):
        assert known_poses_msk[n]
        _, i_j, scale = best_depthmaps[n]
        depth = self.pred_i[i_j][:, :, 2]
        self._set_depthmap(n, depth * scale)
71
+
72
@torch.no_grad()
def init_minimum_spanning_tree(self, **kw):
    """Init all camera poses (image-wise and pairwise poses) given
    an initial set of pairwise estimations.

    Builds a confidence-weighted spanning tree over the view graph, recovers
    per-image points/focals/poses from it, then delegates to init_from_pts3d.
    """
    device = self.device
    pts3d, _, im_focals, im_poses = minimum_spanning_tree(
        self.imshapes,
        self.edges,
        self.pred_i,
        self.pred_j,
        self.conf_i,
        self.conf_j,
        self.im_conf,
        self.min_conf_thr,
        device,
        has_im_poses=self.has_im_poses,
        verbose=self.verbose,
        **kw,
    )

    return init_from_pts3d(self, pts3d, im_focals, im_poses)
94
+
95
+
96
def init_from_pts3d(self, pts3d, im_focals, im_poses):
    """Write globally-aligned pointmaps, poses and focals into the optimizer.

    Args:
        pts3d: list of per-image (H, W, 3) pointmaps in a common world frame
            (modified in place below).
        im_focals: per-image focal estimates (entries may be None).
        im_poses: (n_imgs, 4, 4) cam-to-world matrices.
    """
    # init poses: if some poses are frozen, rigidly align everything onto them
    nkp, known_poses_msk, known_poses = get_known_poses(self)
    if nkp == 1:
        raise NotImplementedError(
            "Would be simpler to just align everything afterwards on the single known pose"
        )
    elif nkp > 1:
        # global rigid SE3 alignment
        s, R, T = align_multiple_poses(
            im_poses[known_poses_msk], known_poses[known_poses_msk]
        )
        trf = sRT_to_4x4(s, R, T, device=known_poses.device)

        # rotate everything
        im_poses = trf @ im_poses
        im_poses[:, :3, :3] /= s  # undo scaling on the rotation part
        for img_pts3d in pts3d:
            img_pts3d[:] = geotrf(trf, img_pts3d)

    # set all pairwise poses
    for e, (i, j) in enumerate(self.edges):
        i_j = edge_str(i, j)
        # compute transform that goes from cam to world
        s, R, T = rigid_points_registration(
            self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j]
        )
        self._set_pose(self.pw_poses, e, R, T, scale=s)

    # take into account the scale normalization
    s_factor = self.get_pw_norm_scale_factor()
    im_poses[:, :3, 3] *= s_factor  # apply downscaling factor
    for img_pts3d in pts3d:
        img_pts3d *= s_factor

    # init all image poses
    if self.has_im_poses:
        for i in range(self.n_imgs):
            cam2world = im_poses[i]
            # per-image depth = z of the world points re-expressed in camera frame
            depth = geotrf(inv(cam2world), pts3d[i])[..., 2]
            self._set_depthmap(i, depth)
            self._set_pose(self.im_poses, i, cam2world)
            if im_focals[i] is not None:
                self._set_focal(i, im_focals[i])

    if self.verbose:
        pass  # NOTE: init-loss printout intentionally disabled
        # print(' init loss =', float(self()))
144
+
145
+
146
def minimum_spanning_tree(
    imshapes,
    edges,
    pred_i,
    pred_j,
    conf_i,
    conf_j,
    im_conf,
    min_conf_thr,
    device,
    has_im_poses=True,
    niter_PnP=10,
    verbose=True,
):
    """Chain pairwise predictions into one globally-aligned pointcloud.

    Builds a maximum-confidence spanning tree over the pairwise graph (edge
    score = product of mean confidences), seeds the world frame with the
    strongest edge, then attaches the remaining edges one at a time by rigid
    registration. Missing focals/poses are completed with Weiszfeld focal
    estimation and RANSAC-PnP.

    Returns:
        (pts3d, msp_edges, im_focals, im_poses) where im_focals/im_poses are
        None when has_im_poses is False.
    """
    n_imgs = len(imshapes)
    # negate the confidence scores so the *minimum* spanning tree keeps the
    # *most* confident edges
    sparse_graph = -dict_to_sparse_graph(
        compute_edge_scores(map(i_j_ij, edges), conf_i, conf_j)
    )
    # FIX: removed leftover debug `print(sparse_graph)`
    msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo()

    # temp variable to store 3d points
    pts3d = [None] * len(imshapes)

    todo = sorted(zip(-msp.data, msp.row, msp.col))  # sorted edges, weakest first
    im_poses = [None] * n_imgs
    im_focals = [None] * n_imgs

    # init with strongest edge: it defines the world frame
    score, i, j = todo.pop()
    if verbose:
        print(f" init edge ({i}*,{j}*) {score=}")
    i_j = edge_str(i, j)
    pts3d[i] = pred_i[i_j].clone()
    pts3d[j] = pred_j[i_j].clone()
    done = {i, j}
    if has_im_poses:
        im_poses[i] = torch.eye(4, device=device)
        im_focals[i] = estimate_focal(pred_i[i_j])

    # set initial pointcloud based on pairwise graph
    msp_edges = [(i, j)]
    while todo:
        # each time, predict the next one
        score, i, j = todo.pop()
        # FIX: compute the edge key for THIS edge before using it; previously
        # the focal below was estimated with the stale i_j of the previous
        # iteration (possibly an edge not even containing image i)
        i_j = edge_str(i, j)

        if im_focals[i] is None:
            im_focals[i] = estimate_focal(pred_i[i_j])

        if i in done:
            if verbose:
                print(f" init edge ({i},{j}*) {score=}")
            assert j not in done
            # align pred[i] with pts3d[i], and then set j accordingly
            s, R, T = rigid_points_registration(pred_i[i_j], pts3d[i], conf=conf_i[i_j])
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[j] = geotrf(trf, pred_j[i_j])
            done.add(j)
            msp_edges.append((i, j))

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)

        elif j in done:
            if verbose:
                print(f" init edge ({i}*,{j}) {score=}")
            assert i not in done
            s, R, T = rigid_points_registration(pred_j[i_j], pts3d[j], conf=conf_j[i_j])
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[i] = geotrf(trf, pred_i[i_j])
            done.add(i)
            msp_edges.append((i, j))

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)
        else:
            # neither endpoint anchored yet: let's try again later
            todo.insert(0, (score, i, j))

    if has_im_poses:
        # complete all missing informations
        pair_scores = list(
            sparse_graph.values()
        )  # already negative scores: less is best
        edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[
            np.argsort(pair_scores)
        ]
        for i, j in edges_from_best_to_worse.tolist():
            if im_focals[i] is None:
                im_focals[i] = estimate_focal(pred_i[edge_str(i, j)])

        for i in range(n_imgs):
            if im_poses[i] is None:
                # recover the pose from the global pointmap with RANSAC-PnP
                msk = im_conf[i] > min_conf_thr
                res = fast_pnp(
                    pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP
                )
                if res:
                    im_focals[i], im_poses[i] = res
            if im_poses[i] is None:
                im_poses[i] = torch.eye(4, device=device)
        im_poses = torch.stack(im_poses)
    else:
        im_poses = im_focals = None

    return pts3d, msp_edges, im_focals, im_poses
254
+
255
+
256
def dict_to_sparse_graph(dic):
    """Convert a {(i, j): score} mapping into a scipy DOK adjacency matrix."""
    n_nodes = 1 + max(max(edge) for edge in dic)
    graph = sp.dok_array((n_nodes, n_nodes))
    for (row, col), score in dic.items():
        graph[row, col] = score
    return graph
262
+
263
+
264
def rigid_points_registration(pts1, pts2, conf):
    """Confidence-weighted similarity registration of pts1 onto pts2.

    Returns (s, R, T) with the rotation/translation left un-scaled.
    """
    flat1 = pts1.reshape(-1, 3)
    flat2 = pts2.reshape(-1, 3)
    R, T, s = roma.rigid_points_registration(
        flat1, flat2, weights=conf.ravel(), compute_scaling=True
    )
    return s, R, T  # note: (R, T) are un-scaled
272
+
273
+
274
def sRT_to_4x4(scale, R, T, device):
    """Pack a scaled rotation and a translation into a homogeneous 4x4 matrix."""
    mat = torch.eye(4, device=device)
    mat[:3, :3] = scale * R
    mat[:3, 3] = T.ravel()  # the translation is not scaled
    return mat
279
+
280
+
281
def estimate_focal(pts3d_i, pp=None):
    """Robustly estimate a scalar focal length from one (H, W, 3) pointmap.

    If no principal point is given, the image center is used.
    """
    if pp is None:
        H, W, THREE = pts3d_i.shape
        assert THREE == 3
        pp = torch.tensor((W / 2, H / 2), device=pts3d_i.device)
    batched_pts = pts3d_i.unsqueeze(0)
    batched_pp = pp.unsqueeze(0)
    focal = estimate_focal_knowing_depth(
        batched_pts, batched_pp, focal_mode="weiszfeld"
    )
    return float(focal.ravel())
290
+
291
+
292
@cache
def pixel_grid(H, W):
    """(H, W, 2) float32 grid of (x, y) pixel coordinates, cached per size."""
    grid_x, grid_y = np.meshgrid(np.arange(W), np.arange(H))  # 'xy' indexing
    return np.stack((grid_x, grid_y), axis=-1).astype(np.float32)
295
+
296
+
297
def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
    """Recover a cam-to-world pose (and optionally the focal) with RANSAC-PnP.

    Args:
        pts3d: (H, W, 3) pointmap.
        focal: known focal, or None to sweep a log-spaced range of candidates.
        msk: boolean (H, W) mask of points to use.
        device: torch device for the returned pose.
        pp: principal point; defaults to the image center.
        niter_PnP: RANSAC iteration count.

    Returns:
        (focal, 4x4 cam-to-world pose) or None on failure.
    """
    # extract camera poses and focals with RANSAC-PnP
    if msk.sum() < 4:
        return None  # we need at least 4 points for PnP
    pts3d, msk = map(to_numpy, (pts3d, msk))

    H, W, THREE = pts3d.shape
    assert THREE == 3
    pixels = pixel_grid(H, W)

    if focal is None:
        # unknown focal: try a log-spaced range and keep the best inlier count
        S = max(W, H)
        tentative_focals = np.geomspace(S / 2, S * 3, 21)
    else:
        tentative_focals = [focal]

    if pp is None:
        pp = (W / 2, H / 2)
    else:
        pp = to_numpy(pp)

    best = (0,)
    for focal in tentative_focals:
        K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])
        try:
            success, R, T, inliers = cv2.solvePnPRansac(
                pts3d[msk],
                pixels[msk],
                K,
                None,
                iterationsCount=niter_PnP,
                reprojectionError=5,
                flags=cv2.SOLVEPNP_SQPNP,
            )
        except cv2.error:
            # FIX: was a bare `except:` that also swallowed KeyboardInterrupt
            # and genuine programming errors; only OpenCV failures are expected
            continue
        if not success:
            continue

        score = len(inliers)
        if score > best[0]:  # FIX: dropped redundant `success and`
            best = score, R, T, focal

    if not best[0]:
        return None

    _, R, T, best_focal = best
    R = cv2.Rodrigues(R)[0]  # world to cam
    R, T = map(torch.from_numpy, (R, T))
    return best_focal, inv(sRT_to_4x4(1, R, T, device))  # cam to world
347
+
348
+
349
def get_known_poses(self):
    """Return (num_known, frozen-mask, cam-to-world poses) for the images.

    A pose counts as "known" when its parameter is frozen (requires_grad off).
    Returns (0, None, None) when the optimizer has no image poses at all.
    """
    if not self.has_im_poses:
        return 0, None, None
    frozen = torch.tensor([not p.requires_grad for p in self.im_poses])
    return frozen.sum(), frozen, self.get_im_poses()
356
+
357
+
358
def get_known_focals(self):
    """Return (num_known, frozen-mask, focals) for the image intrinsics.

    Returns (0, None, None) when the optimizer has no image poses at all.
    """
    if not self.has_im_poses:
        return 0, None, None
    frozen = self.get_known_focal_mask()
    return frozen.sum(), frozen, self.get_focals()
365
+
366
+
367
def align_multiple_poses(src_poses, target_poses):
    """Estimate the similarity (s, R, T) mapping src_poses onto target_poses.

    Both inputs are (N, 4, 4) cam-to-world matrices. Each pose is converted to
    two 3D points (camera center plus a point slightly along its optical axis)
    so that the point registration constrains orientation as well as position.
    """
    N = len(src_poses)
    assert src_poses.shape == target_poses.shape == (N, 4, 4)

    def as_point_cloud(poses):
        # step size chosen relative to the typical inter-camera distance
        step = get_med_dist_between_poses(poses) / 100
        centers = poses[:, :3, 3]
        z_axes = poses[:, :3, 2]
        return torch.cat((centers, centers + step * z_axes))

    R, T, s = roma.rigid_points_registration(
        as_point_cloud(src_poses), as_point_cloud(target_poses), compute_scaling=True
    )
    return s, R, T
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/dust3r_opt/optimizer.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Main class for the implementation of the global alignment
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from cloud_opt.dust3r_opt.base_opt import BasePCOptimizer
12
+ from dust3r.utils.geometry import xy_grid, geotrf
13
+ from dust3r.utils.device import to_cpu, to_numpy
14
+
15
+
16
class PointCloudOptimizer(BasePCOptimizer):
    """Optimize a global scene, given a list of pairwise observations.
    Graph node: images
    Graph edges: observations = (pred1, pred2)

    In addition to the pairwise poses of BasePCOptimizer, this class
    optimizes per-image parameters: log-depthmaps, camera poses, focals and
    (optionally) principal points. All per-image parameters are re-packed
    into stacked, zero-padded tensors so that forward() is fully vectorized.
    """

    def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs):
        """optimize_pp: also optimize the principal points when True.
        focal_break: scale applied to log-focals so their gradient magnitude
        is commensurate with the other parameters."""
        super().__init__(*args, **kwargs)

        self.has_im_poses = True  # by definition of this class
        self.focal_break = focal_break

        # adding thing to optimize: one parameter group per image
        self.im_depthmaps = nn.ParameterList(
            torch.randn(H, W) / 10 - 3 for H, W in self.imshapes
        )  # log(depth)
        self.im_poses = nn.ParameterList(
            self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs)
        )  # camera poses
        self.im_focals = nn.ParameterList(
            torch.FloatTensor([self.focal_break * np.log(max(H, W))])
            for H, W in self.imshapes
        )  # camera intrinsics: focal stored as focal_break * log(focal)
        self.im_pp = nn.ParameterList(
            torch.zeros((2,)) for _ in range(self.n_imgs)
        )  # camera intrinsics: principal-point offset, stored in units of 10 px
        self.im_pp.requires_grad_(optimize_pp)

        self.imshape = self.imshapes[0]
        im_areas = [h * w for h, w in self.imshapes]
        self.max_area = max(im_areas)

        # re-pack the per-image parameters as stacked tensors padded to
        # max_area, so the whole scene can be processed in one batch
        self.im_depthmaps = ParameterStack(
            self.im_depthmaps, is_param=True, fill=self.max_area
        )
        self.im_poses = ParameterStack(self.im_poses, is_param=True)
        self.im_focals = ParameterStack(self.im_focals, is_param=True)
        self.im_pp = ParameterStack(self.im_pp, is_param=True)
        self.register_buffer(
            "_pp", torch.tensor([(w / 2, h / 2) for h, w in self.imshapes])
        )  # default principal points (image centers)
        self.register_buffer(
            "_grid",
            ParameterStack(
                [xy_grid(W, H, device=self.device) for H, W in self.imshapes],
                fill=self.max_area,
            ),
        )  # flattened per-image pixel grids, used for back-projection

        # pre-compute pixel weights (confidence-derived loss weights)
        self.register_buffer(
            "_weight_i",
            ParameterStack(
                [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges],
                fill=self.max_area,
            ),
        )
        self.register_buffer(
            "_weight_j",
            ParameterStack(
                [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges],
                fill=self.max_area,
            ),
        )

        # pre-compute the stacked pairwise predictions and per-edge indices
        self.register_buffer(
            "_stacked_pred_i",
            ParameterStack(self.pred_i, self.str_edges, fill=self.max_area),
        )
        self.register_buffer(
            "_stacked_pred_j",
            ParameterStack(self.pred_j, self.str_edges, fill=self.max_area),
        )
        self.register_buffer("_ei", torch.tensor([i for i, j in self.edges]))
        self.register_buffer("_ej", torch.tensor([j for i, j in self.edges]))
        self.total_area_i = sum([im_areas[i] for i, j in self.edges])
        self.total_area_j = sum([im_areas[j] for i, j in self.edges])

    def _check_all_imgs_are_selected(self, msk):
        # the preset_* methods below only support setting ALL images at once
        assert np.all(
            self._get_msk_indices(msk) == np.arange(self.n_imgs)
        ), "incomplete mask!"

    def preset_pose(self, known_poses, pose_msk=None):  # cam-to-world
        """Freeze the image poses to the given cam-to-world matrices."""
        self._check_all_imgs_are_selected(pose_msk)

        if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:
            known_poses = [known_poses]
        for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
            if self.verbose:
                print(f" (setting pose #{idx} = {pose[:3,3]})")
            self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose)))

        # normalize scale if there's less than 1 known pose
        n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
        self.norm_pw_scale = n_known_poses <= 1

        # NOTE(review): the two lines below unconditionally freeze all poses
        # and disable pairwise-scale normalization, overriding the logic just
        # above — confirm this is intentional.
        self.im_poses.requires_grad_(False)
        self.norm_pw_scale = False

    def preset_focal(self, known_focals, msk=None):
        """Freeze the focal lengths to the given values."""
        self._check_all_imgs_are_selected(msk)

        for idx, focal in zip(self._get_msk_indices(msk), known_focals):
            if self.verbose:
                print(f" (setting focal #{idx} = {focal})")
            self._no_grad(self._set_focal(idx, focal))

        self.im_focals.requires_grad_(False)

    def preset_principal_point(self, known_pp, msk=None):
        """Freeze the principal points to the given values."""
        self._check_all_imgs_are_selected(msk)

        for idx, pp in zip(self._get_msk_indices(msk), known_pp):
            if self.verbose:
                print(f" (setting principal point #{idx} = {pp})")
            self._no_grad(self._set_principal_point(idx, pp))

        self.im_pp.requires_grad_(False)

    def _get_msk_indices(self, msk):
        # normalize a mask (None / int / sequence / bool array / int array)
        # into an iterable of image indices
        if msk is None:
            return range(self.n_imgs)
        elif isinstance(msk, int):
            return [msk]
        elif isinstance(msk, (tuple, list)):
            return self._get_msk_indices(np.array(msk))
        elif msk.dtype in (bool, torch.bool, np.bool_):
            assert len(msk) == self.n_imgs
            return np.where(msk)[0]
        elif np.issubdtype(msk.dtype, np.integer):
            return msk
        else:
            raise ValueError(f"bad {msk=}")

    def _no_grad(self, tensor):
        # sanity check used by the preset_* methods: a parameter can only be
        # preset while it is still trainable (otherwise _set_* is a no-op)
        assert (
            tensor.requires_grad
        ), "it must be True at this point, otherwise no modification occurs"

    def _set_focal(self, idx, focal, force=False):
        param = self.im_focals[idx]
        if (
            param.requires_grad or force
        ):  # can only init a parameter not already initialized
            param.data[:] = self.focal_break * np.log(focal)
        return param

    def get_focals(self):
        # invert the focal_break * log(focal) parameterization
        log_focals = torch.stack(list(self.im_focals), dim=0)
        return (log_focals / self.focal_break).exp()

    def get_known_focal_mask(self):
        # True where the focal is frozen (preset)
        return torch.tensor([not (p.requires_grad) for p in self.im_focals])

    def _set_principal_point(self, idx, pp, force=False):
        param = self.im_pp[idx]
        H, W = self.imshapes[idx]
        if (
            param.requires_grad or force
        ):  # can only init a parameter not already initialized
            param.data[:] = to_cpu(to_numpy(pp) - (W / 2, H / 2)) / 10
        return param

    def get_principal_points(self):
        # image center plus the (10x-scaled) learned offset
        return self._pp + 10 * self.im_pp

    def get_intrinsics(self):
        """Return the (n_imgs, 3, 3) pinhole intrinsic matrices."""
        K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
        focals = self.get_focals().flatten()
        K[:, 0, 0] = K[:, 1, 1] = focals
        K[:, :2, 2] = self.get_principal_points()
        K[:, 2, 2] = 1
        return K

    def get_im_poses(self):  # cam to world
        cam2world = self._get_poses(self.im_poses)
        return cam2world

    def _set_depthmap(self, idx, depth, force=False):
        # depthmaps are stored flattened and zero-padded to max_area
        depth = _ravel_hw(depth, self.max_area)

        param = self.im_depthmaps[idx]
        if (
            param.requires_grad or force
        ):  # can only init a parameter not already initialized
            # parameters store log(depth); zero-depth maps to -inf, clamp to 0
            param.data[:] = depth.log().nan_to_num(neginf=0)
        return param

    def get_depthmaps(self, raw=False):
        res = self.im_depthmaps.exp()
        if not raw:
            # un-pad and reshape each depthmap back to (H, W)
            res = [dm[: h * w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def depth_to_pts3d(self):
        """Back-project the current depthmaps to world-frame 3D points."""
        # Get depths and projection params if not provided
        focals = self.get_focals()
        pp = self.get_principal_points()
        im_poses = self.get_im_poses()
        depth = self.get_depthmaps(raw=True)

        # get pointmaps in camera frame
        rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp)
        # project to world frame
        return geotrf(im_poses, rel_ptmaps)

    def get_pts3d(self, raw=False):
        res = self.depth_to_pts3d()
        if not raw:
            # un-pad and reshape each pointmap back to (H, W, 3)
            res = [dm[: h * w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def forward(self):
        """Global alignment loss: confidence-weighted 3D distance between the
        current scene points and the (pairwise-posed) network predictions."""
        pw_poses = self.get_pw_poses()  # cam-to-world
        pw_adapt = self.get_adaptors().unsqueeze(1)
        proj_pts3d = self.get_pts3d(raw=True)

        # rotate pairwise prediction according to pw_poses
        aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i)
        aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j)

        # compute the loss, normalized by the total image area on each side
        li = (
            self.dist(proj_pts3d[self._ei], aligned_pred_i, weight=self._weight_i).sum()
            / self.total_area_i
        )
        lj = (
            self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum()
            / self.total_area_j
        )

        return li + lj
251
+
252
+
253
+ def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp):
254
+ pp = pp.unsqueeze(1)
255
+ focal = focal.unsqueeze(1)
256
+ assert focal.shape == (len(depth), 1, 1)
257
+ assert pp.shape == (len(depth), 1, 2)
258
+ assert pixel_grid.shape == depth.shape + (2,)
259
+ depth = depth.unsqueeze(-1)
260
+ return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1)
261
+
262
+
263
def ParameterStack(params, keys=None, is_param=None, fill=0):
    """Stack a collection of tensors into one detached float tensor.

    If `keys` is given, `params` is treated as a mapping and the listed keys
    are selected in order. With `fill` > 0, each tensor is flattened over its
    leading (H, W) dims and zero-padded to `fill` rows. The result is wrapped
    in an nn.Parameter when requested (or when the inputs require grad),
    preserving the inputs' requires_grad flag.
    """
    if keys is not None:
        params = [params[k] for k in keys]
    if fill > 0:
        params = [_ravel_hw(p, fill) for p in params]

    requires_grad = params[0].requires_grad
    assert all(p.requires_grad == requires_grad for p in params)

    stacked = torch.stack(list(params)).float().detach()
    if is_param or requires_grad:
        stacked = nn.Parameter(stacked)
        stacked.requires_grad_(requires_grad)
    return stacked
278
+
279
+
280
+ def _ravel_hw(tensor, fill=0):
281
+ # ravel H,W
282
+ tensor = tensor.view((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
283
+
284
+ if len(tensor) < fill:
285
+ tensor = torch.cat(
286
+ (tensor, tensor.new_zeros((fill - len(tensor),) + tensor.shape[1:]))
287
+ )
288
+ return tensor
289
+
290
+
291
+ def acceptable_focal_range(H, W, minf=0.5, maxf=3.5):
292
+ focal_base = max(H, W) / (
293
+ 2 * np.tan(np.deg2rad(60) / 2)
294
+ ) # size / 1.1547005383792515
295
+ return minf * focal_base, maxf * focal_base
296
+
297
+
298
def apply_mask(img, msk):
    """Return a copy of `img` with the masked pixels zeroed (input untouched)."""
    out = img.copy()
    out[msk] = 0
    return out
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/init_all.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import cache
2
+ import numpy as np
3
+ import scipy.sparse as sp
4
+ import torch
5
+ import cv2
6
+ import roma
7
+ from tqdm import tqdm
8
+
9
+ from cloud_opt.utils import *
10
+
11
+
12
def compute_edge_scores(edges, edge2conf_i, edge2conf_j):
    """Score every edge by the product of its two mean confidences.

    edges yields ('i_j', (i, j)) pairs; returns {(i, j): score}.
    """
    scores = {}
    for key, (i, j) in edges:
        scores[(i, j)] = edge_conf(edge2conf_i[key], edge2conf_j[key])
    return scores
+
21
+
22
def dict_to_sparse_graph(dic):
    """Build a scipy DOK adjacency matrix from a {(i, j): score} mapping."""
    size = 1 + max(max(pair) for pair in dic)
    adjacency = sp.dok_array((size, size))
    for pair, score in dic.items():
        adjacency[pair] = score
    return adjacency
+
29
+
30
@torch.no_grad()
def init_minimum_spanning_tree(self, **kw):
    """Init all camera poses (image-wise and pairwise poses) given
    an initial set of pairwise estimations.

    Builds a maximum-confidence spanning tree over the pairwise graph,
    chains per-edge registrations along it into a global frame, then writes
    the result into the optimizer state via init_from_pts3d.
    """
    device = self.device
    pts3d, _, im_focals, im_poses = minimum_spanning_tree(
        self.imshapes,
        self.edges,
        self.edge2pts_i,
        self.edge2pts_j,
        self.edge2conf_i,
        self.edge2conf_j,
        self.im_conf,
        self.min_conf_thr,
        device,
        has_im_poses=self.has_im_poses,
        verbose=self.verbose,
        **kw,
    )

    return init_from_pts3d(self, pts3d, im_focals, im_poses)
+ return init_from_pts3d(self, pts3d, im_focals, im_poses)
52
+
53
+
54
def minimum_spanning_tree(
    imshapes,
    edges,
    edge2pred_i,
    edge2pred_j,
    edge2conf_i,
    edge2conf_j,
    im_conf,
    min_conf_thr,
    device,
    has_im_poses=True,
    niter_PnP=10,
    verbose=True,
    save_score_path=None,
):
    """Chain pairwise predictions into one globally-aligned pointcloud.

    Builds a maximum-confidence spanning tree over the pairwise graph, seeds
    the world frame with the strongest edge, then attaches the remaining
    edges one at a time by rigid registration. Missing focals/poses are
    completed by focal estimation and RANSAC-PnP.

    Args:
        save_score_path: accepted for API compatibility; currently unused.

    Returns:
        (pts3d, msp_edges, im_focals, im_poses); the last two are None when
        has_im_poses is False.
    """
    n_imgs = len(imshapes)
    edge_and_scores = compute_edge_scores(map(i_j_ij, edges), edge2conf_i, edge2conf_j)
    # negate scores so the *minimum* spanning tree keeps the *most* confident edges
    sparse_graph = -dict_to_sparse_graph(edge_and_scores)
    msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo()

    # temp variable to store 3d points
    pts3d = [None] * len(imshapes)

    todo = sorted(zip(-msp.data, msp.row, msp.col))  # sorted edges, weakest first
    im_poses = [None] * n_imgs
    im_focals = [None] * n_imgs

    # init with strongest edge: it defines the world frame
    score, i, j = todo.pop()
    if verbose:
        print(f" init edge ({i}*,{j}*) {score=}")
    i_j = edge_str(i, j)

    pts3d[i] = edge2pred_i[i_j].clone()
    pts3d[j] = edge2pred_j[i_j].clone()
    done = {i, j}
    if has_im_poses:
        im_poses[i] = torch.eye(4, device=device)
        im_focals[i] = estimate_focal(edge2pred_i[i_j])

    # set initial pointcloud based on pairwise graph
    msp_edges = [(i, j)]
    while todo:
        # each time, predict the next one
        score, i, j = todo.pop()
        # FIX: compute the key for THIS edge before using it; previously the
        # focal below was estimated with the stale i_j of the previous
        # iteration (possibly an edge not even containing image i)
        i_j = edge_str(i, j)

        if im_focals[i] is None:
            im_focals[i] = estimate_focal(edge2pred_i[i_j])

        if i in done:
            if verbose:
                print(f" init edge ({i},{j}*) {score=}")
            assert j not in done
            # align pred[i] with pts3d[i], and then set j accordingly
            s, R, T = rigid_points_registration(
                edge2pred_i[i_j], pts3d[i], conf=edge2conf_i[i_j]
            )
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[j] = geotrf(trf, edge2pred_j[i_j])
            done.add(j)
            msp_edges.append((i, j))

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)

        elif j in done:
            if verbose:
                print(f" init edge ({i}*,{j}) {score=}")
            assert i not in done
            s, R, T = rigid_points_registration(
                edge2pred_j[i_j], pts3d[j], conf=edge2conf_j[i_j]
            )
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[i] = geotrf(trf, edge2pred_i[i_j])
            done.add(i)
            msp_edges.append((i, j))

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)
        else:
            # neither endpoint anchored yet: let's try again later
            todo.insert(0, (score, i, j))

    if has_im_poses:
        # complete all missing informations
        pair_scores = list(
            sparse_graph.values()
        )  # already negative scores: less is best
        edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[
            np.argsort(pair_scores)
        ]
        for i, j in edges_from_best_to_worse.tolist():
            if im_focals[i] is None:
                im_focals[i] = estimate_focal(edge2pred_i[edge_str(i, j)])

        for i in range(n_imgs):
            if im_poses[i] is None:
                # recover the pose from the global pointmap with RANSAC-PnP
                msk = im_conf[i] > min_conf_thr
                res = fast_pnp(
                    pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP
                )
                if res:
                    im_focals[i], im_poses[i] = res
            if im_poses[i] is None:
                im_poses[i] = torch.eye(4, device=device)
        im_poses = torch.stack(im_poses)
    else:
        im_poses = im_focals = None

    return pts3d, msp_edges, im_focals, im_poses
+
167
+
168
def init_from_pts3d(self, pts3d, im_focals, im_poses):
    """Write globally-aligned pointmaps, poses and focals into the optimizer.

    Args:
        pts3d: list of per-image (H, W, 3) pointmaps in a common world frame
            (modified in place below).
        im_focals: per-image focal estimates (entries may be None).
        im_poses: (n_imgs, 4, 4) cam-to-world matrices.
    """
    # init poses: if some poses are frozen, rigidly align everything onto them
    nkp, known_poses_msk, known_poses = self.get_known_poses()
    if nkp == 1:
        raise NotImplementedError(
            "Would be simpler to just align everything afterwards on the single known pose"
        )
    elif nkp > 1:
        # global rigid SE3 alignment
        s, R, T = align_multiple_poses(
            im_poses[known_poses_msk], known_poses[known_poses_msk]
        )
        trf = sRT_to_4x4(s, R, T, device=known_poses.device)

        # rotate everything
        im_poses = trf @ im_poses
        im_poses[:, :3, :3] /= s  # undo scaling on the rotation part
        for img_pts3d in pts3d:
            img_pts3d[:] = geotrf(trf, img_pts3d)
    else:
        pass  # no known poses

    # set all pairwise poses
    # NOTE(review): reads self.pred_i / self.conf_i here, while the MST init
    # in this module reads self.edge2pts_i / self.edge2conf_i — confirm both
    # attribute families exist on the optimizer.
    for e, (i, j) in enumerate(self.edges):
        i_j = edge_str(i, j)
        # compute transform that goes from cam to world
        s, R, T = rigid_points_registration(
            self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j]
        )
        self._set_pose(self.pw_poses, e, R, T, scale=s)

    # take into account the scale normalization
    s_factor = self.get_pw_norm_scale_factor()
    im_poses[:, :3, 3] *= s_factor  # apply downscaling factor
    for img_pts3d in pts3d:
        img_pts3d *= s_factor

    # init all image poses
    if self.has_im_poses:
        for i in range(self.n_imgs):
            cam2world = im_poses[i]
            # per-image depth = z of the world points re-expressed in camera frame
            depth = geotrf(inv(cam2world), pts3d[i])[..., 2]
            self._set_depthmap(i, depth)
            self._set_pose(self.im_poses, i, cam2world)
            if im_focals[i] is not None:
                if not self.shared_focal:
                    self._set_focal(i, im_focals[i])
        if self.shared_focal:
            # single shared focal: use the average of all per-image estimates
            self._set_focal(0, sum(im_focals) / self.n_imgs)
        if self.n_imgs > 2:
            self._set_init_depthmap()

    if self.verbose:
        with torch.no_grad():
            print(" init loss =", float(self()))
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/cloud_opt/utils.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ import roma
4
+ import numpy as np
5
+ import cv2
6
+ from functools import cache
7
+
8
+
9
def todevice(batch, device, callback=None, non_blocking=False):
    """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).

    batch: list, tuple, dict of tensors or other things
    device: pytorch device or 'numpy'
    callback: function that would be called on every sub-elements.
    non_blocking: forwarded to Tensor.to for asynchronous copies.
    """
    if callback:
        batch = callback(batch)

    if isinstance(batch, dict):
        # FIX: propagate callback and non_blocking into the recursion — they
        # were silently dropped before, contradicting the documented contract
        return {
            k: todevice(v, device, callback=callback, non_blocking=non_blocking)
            for k, v in batch.items()
        }

    if isinstance(batch, (tuple, list)):
        return type(batch)(
            todevice(x, device, callback=callback, non_blocking=non_blocking)
            for x in batch
        )

    x = batch
    if device == "numpy":
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x
35
+
36
+
37
# Convenience wrappers around todevice() for the common targets.
to_device = todevice  # alias


def to_numpy(x):
    """Recursively convert tensors to numpy arrays (other values pass through)."""
    return todevice(x, "numpy")


def to_cpu(x):
    """Recursively move all tensors to the CPU."""
    return todevice(x, "cpu")


def to_cuda(x):
    """Recursively move all tensors to the default CUDA device."""
    return todevice(x, "cuda")
50
+
51
+
52
+ def signed_log1p(x):
53
+ sign = torch.sign(x)
54
+ return sign * torch.log1p(torch.abs(x))
55
+
56
+
57
def l2_dist(a, b, weight):
    """Weighted squared-L2 distance along the last axis."""
    diff = a - b
    return weight * (diff * diff).sum(dim=-1)


def l1_dist(a, b, weight):
    """Weighted Euclidean norm of the residual along the last axis.

    Note: despite the name, this is an L2 norm of (a - b), not an L1 distance.
    """
    return weight * torch.linalg.vector_norm(a - b, dim=-1)


ALL_DISTS = {"l1": l1_dist, "l2": l2_dist}
66
+
67
+
68
+ def _check_edges(edges):
69
+ indices = sorted({i for edge in edges for i in edge})
70
+ assert indices == list(range(len(indices))), "bad pair indices: missing values "
71
+ return len(indices)
72
+
73
+
74
def NoGradParamDict(x):
    """Wrap a dict of tensors into an nn.ParameterDict with gradients disabled."""
    assert isinstance(x, dict)
    param_dict = nn.ParameterDict(x)
    return param_dict.requires_grad_(False)
77
+
78
+
79
def edge_str(i, j):
    """Canonical dictionary key for the directed edge i -> j."""
    return "{}_{}".format(i, j)
81
+
82
+
83
def i_j_ij(ij):
    """Map an (i, j) pair to (its string key, the pair itself)."""
    return f"{ij[0]}_{ij[1]}", ij
86
+
87
+
88
def edge_conf(conf_i, conf_j):
    """Scalar confidence of an edge: product of the two mean confidences."""
    return (conf_i.mean() * conf_j.mean()).item()
91
+
92
+
93
def get_imshapes(edges, pred_i, pred_j):
    """Recover each image's (H, W) from the per-edge predictions.

    Asserts that every edge agrees on the shape of the images it touches.
    """
    n_imgs = 1 + max(max(e) for e in edges)
    imshapes = [None] * n_imgs
    for e, (i, j) in enumerate(edges):
        shape_i = tuple(pred_i[e]["pts3d_is_self_view"].shape[:2])
        shape_j = tuple(pred_j[e]["pts3d_in_other_view"].shape[:2])
        if imshapes[i]:
            assert imshapes[i] == shape_i, f"incorrect shape for image {i}"
        if imshapes[j]:
            assert imshapes[j] == shape_j, f"incorrect shape for image {j}"
        imshapes[i], imshapes[j] = shape_i, shape_j
    return imshapes
106
+
107
+
108
def get_conf_trf(mode):
    """Return the confidence-weighting transform for the given mode.

    Supported modes: 'log', 'sqrt', 'm1' (minus one), 'id'/'none' (identity).
    Raises ValueError for any other mode.
    """
    transforms = {
        "log": lambda x: x.log(),
        "sqrt": lambda x: x.sqrt(),
        "m1": lambda x: x - 1,
        "id": lambda x: x,
        "none": lambda x: x,
    }
    try:
        return transforms[mode]
    except KeyError:
        raise ValueError(f"bad mode for {mode=}") from None
132
+
133
+
134
+ @torch.no_grad()
135
+ def _compute_img_conf(imshapes, device, edges, edge2conf_i, edge2conf_j):
136
+ im_conf = nn.ParameterList([torch.zeros(hw, device=device) for hw in imshapes])
137
+ for e, (i, j) in enumerate(edges):
138
+ im_conf[i] = torch.maximum(im_conf[i], edge2conf_i[edge_str(i, j)])
139
+ im_conf[j] = torch.maximum(im_conf[j], edge2conf_j[edge_str(i, j)])
140
+ return im_conf
141
+
142
+
143
def xy_grid(
    W,
    H,
    device=None,
    origin=(0, 0),
    unsqueeze=None,
    cat_dim=-1,
    homogeneous=False,
    **arange_kw,
):
    """Output a (H,W,2) array of int32
    with output[j,i,0] = i + origin[0]
         output[j,i,1] = j + origin[1]

    device: None selects the numpy backend, otherwise torch on that device.
    origin: (x0, y0) offset added to the pixel coordinates.
    unsqueeze: if given, unsqueeze each coordinate plane at this dim.
    cat_dim: dim along which the x/y planes are stacked (None keeps a tuple).
    homogeneous: append a plane of ones -> homogeneous (x, y, 1) coords.
    arange_kw: forwarded to arange (e.g. dtype).
    """
    # Pick the backend functions once, so the logic below is backend-agnostic.
    if device is None:
        # numpy
        arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
    else:
        # torch
        arange = lambda *a, **kw: torch.arange(*a, device=device, **kw)
        meshgrid, stack = torch.meshgrid, torch.stack
        ones = lambda *a: torch.ones(*a, device=device)

    # One coordinate axis per image dimension, shifted by the origin.
    tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
    grid = meshgrid(tw, th, indexing="xy")
    if homogeneous:
        grid = grid + (ones((H, W)),)
    if unsqueeze is not None:
        grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))
    if cat_dim is not None:
        grid = stack(grid, cat_dim)
    return grid
175
+
176
+
177
def estimate_focal_knowing_depth(
    pts3d, pp, focal_mode="median", min_focal=0.0, max_focal=np.inf
):
    """Reprojection method, for when the absolute depth is known:
    1) estimate the camera focal using a robust estimator
    2) reproject points onto true rays, minimizing a certain error

    pts3d: (B, H, W, 3) per-pixel 3-D points in the camera frame.
    pp: (B, 2) principal point in pixels.
    focal_mode: "median" (per-pixel focal votes) or "weiszfeld" (IRLS).
    min_focal / max_focal: clip bounds, expressed relative to the focal of
        a 60-degree FoV camera of the same image size.
    Returns a (B,) tensor of focal lengths in pixels.
    """
    B, H, W, THREE = pts3d.shape
    assert THREE == 3

    # centered pixel grid
    pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(
        -1, 1, 2
    )  # B,HW,2
    pts3d = pts3d.flatten(1, 2)  # (B, HW, 3)

    if focal_mode == "median":
        with torch.no_grad():
            # direct estimation of focal: each pixel votes f = u*z/x (and v*z/y)
            u, v = pixels.unbind(dim=-1)
            x, y, z = pts3d.unbind(dim=-1)
            fx_votes = (u * z) / x
            fy_votes = (v * z) / y

            # assume square pixels, hence same focal for X and Y
            f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1)
            # nanmedian is robust to the NaN/inf votes from x==0 or y==0 pixels
            focal = torch.nanmedian(f_votes, dim=-1).values

    elif focal_mode == "weiszfeld":
        # init focal with l2 closed form
        # we try to find focal = argmin Sum | pixel - focal * (x,y)/z|
        xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(
            posinf=0, neginf=0
        )  # homogeneous (x,y,1)

        dot_xy_px = (xy_over_z * pixels).sum(dim=-1)
        dot_xy_xy = xy_over_z.square().sum(dim=-1)

        focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1)

        # iterative re-weighted least-squares (Weiszfeld's algorithm)
        for iter in range(10):
            # re-weighting by inverse of distance
            dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1)
            # print(dis.nanmean(-1))
            w = dis.clip(min=1e-8).reciprocal()
            # update the scaling with the new weights
            focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1)
    else:
        raise ValueError(f"bad {focal_mode=}")

    # Reference focal of a 60-degree FoV camera; min/max bounds are relative to it.
    focal_base = max(H, W) / (
        2 * np.tan(np.deg2rad(60) / 2)
    )  # size / 1.1547005383792515
    focal = focal.clip(min=min_focal * focal_base, max=max_focal * focal_base)
    # print(focal)
    return focal
234
+
235
+
236
def estimate_focal(pts3d_i, pp=None):
    """Estimate a single focal length (in pixels) from one (H, W, 3) point map.

    Defaults the principal point to the image center when ``pp`` is None.
    """
    if pp is None:
        H, W, THREE = pts3d_i.shape
        assert THREE == 3
        pp = torch.tensor((W / 2, H / 2), device=pts3d_i.device)
    focal = estimate_focal_knowing_depth(
        pts3d_i[None], pp[None], focal_mode="weiszfeld"
    ).ravel()
    return float(focal)
245
+
246
+
247
def rigid_points_registration(pts1, pts2, conf):
    """Weighted similarity registration of ``pts1`` onto ``pts2`` via roma.

    pts1, pts2: point sets, reshaped to (-1, 3) before registration.
    conf: per-point weights (flattened).
    Returns (s, R, T): scale, rotation, translation.
    """
    R, T, s = roma.rigid_points_registration(
        pts1.reshape(-1, 3),
        pts2.reshape(-1, 3),
        weights=conf.ravel(),
        compute_scaling=True,
    )
    return s, R, T  # return un-scaled (R, T)
255
+
256
+
257
def sRT_to_4x4(scale, R, T, device):
    """Assemble a 4x4 similarity transform from scale, rotation R and translation T."""
    trf = torch.eye(4, device=device)
    trf[:3, :3] = scale * R  # fold the scale into the linear part
    trf[:3, 3] = T.ravel()  # translation is not scaled
    return trf
262
+
263
+
264
def geotrf(Trf, pts, ncol=None, norm=False):
    """Apply a geometric transformation to a list of 3-D points.

    H: 3x3 or 4x4 projection matrix (typically a Homography)
    p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)

    ncol: int. number of columns of the result (2 or 3)
    norm: float. if != 0, the resut is projected on the z=norm plane.

    Returns an array of projected 2d points.
    """
    assert Trf.ndim >= 2
    # Coerce pts to the same framework (and dtype) as Trf.
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # adapt shape if necessary
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # optimized code: batched torch case (B,4,4) applied to (B,H,W,3) maps
    if (
        isinstance(Trf, torch.Tensor)
        and isinstance(pts, torch.Tensor)
        and Trf.ndim == 3
        and pts.ndim == 4
    ):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            # pure linear map, no translation
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            # affine map: linear part + broadcast translation column
            pts = (
                torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
                + Trf[:, None, None, :d, d]
            )
        else:
            raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
    else:
        # generic fallback path (numpy or non-batched torch)
        if Trf.ndim >= 3:
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1] + 1 == Trf.shape[-1]:
            # homogeneous transform of non-homogeneous points
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        # perspective division onto the z=norm plane
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res
333
+
334
+
335
def inv(mat):
    """Invert a torch or numpy matrix; raises ValueError for other types."""
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    raise ValueError(f"bad matrix type = {type(mat)}")
342
+
343
+
344
@cache
def pixel_grid(H, W):
    """(H, W, 2) float32 grid with grid[y, x] == (x, y); memoised per size."""
    xs, ys = np.meshgrid(np.arange(W), np.arange(H))  # xs[y, x] = x, ys[y, x] = y
    return np.stack((xs, ys), axis=-1).astype(np.float32)
347
+
348
+
349
def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
    """Recover a camera-to-world pose (and focal, if unknown) with RANSAC-PnP.

    pts3d: (H, W, 3) per-pixel 3-D points in world coordinates.
    focal: known focal in pixels, or None to search a geometric range.
    msk: (H, W) boolean mask of valid pixels.
    pp: principal point; defaults to the image center.
    Returns (focal, 4x4 cam-to-world pose) or None on failure.
    """
    # extract camera poses and focals with RANSAC-PnP
    if msk.sum() < 4:
        return None  # we need at least 4 points for PnP
    pts3d, msk = map(to_numpy, (pts3d, msk))

    H, W, THREE = pts3d.shape
    assert THREE == 3
    pixels = pixel_grid(H, W)

    if focal is None:
        # unknown focal: try 21 candidates spanning [S/2, 3S] geometrically
        S = max(W, H)
        tentative_focals = np.geomspace(S / 2, S * 3, 21)
    else:
        tentative_focals = [focal]

    if pp is None:
        pp = (W / 2, H / 2)
    else:
        pp = to_numpy(pp)

    # best = (inlier_count, R, T, focal); starts with 0 inliers
    best = (0,)
    for focal in tentative_focals:
        K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])

        success, R, T, inliers = cv2.solvePnPRansac(
            pts3d[msk],
            pixels[msk],
            K,
            None,
            iterationsCount=niter_PnP,
            reprojectionError=5,
            flags=cv2.SOLVEPNP_SQPNP,
        )
        if not success:
            continue

        # score each focal candidate by its RANSAC inlier count
        score = len(inliers)
        if success and score > best[0]:
            best = score, R, T, focal

    if not best[0]:
        return None

    _, R, T, best_focal = best
    R = cv2.Rodrigues(R)[0]  # world to cam
    R, T = map(torch.from_numpy, (R, T))
    return best_focal, inv(sRT_to_4x4(1, R, T, device))  # cam to world
397
+
398
+
399
def get_med_dist_between_poses(poses):
    """Median pairwise distance between the camera centers of ``poses``."""
    from scipy.spatial.distance import pdist

    centers = [to_numpy(p[:3, 3]) for p in poses]
    return np.median(pdist(centers))
403
+
404
+
405
def align_multiple_poses(src_poses, target_poses):
    """Find the similarity (s, R, T) aligning src poses onto target poses.

    src_poses, target_poses: (N, 4, 4) camera-to-world matrices.
    Returns (s, R, T) from roma's scaled rigid registration.
    """
    N = len(src_poses)
    assert src_poses.shape == target_poses.shape == (N, 4, 4)

    def center_and_z(poses):
        # Represent each pose by its center plus a point slightly offset along
        # its z-axis, so the registration also constrains orientation.
        eps = get_med_dist_between_poses(poses) / 100
        return torch.cat((poses[:, :3, 3], poses[:, :3, 3] + eps * poses[:, :3, 2]))

    R, T, s = roma.rigid_points_registration(
        center_and_z(src_poses), center_and_z(target_poses), compute_scaling=True
    )
    return s, R, T
417
+
418
+
419
def cosine_schedule(t, lr_start, lr_end):
    """Cosine-annealed interpolation from lr_start (t=0) to lr_end (t=1)."""
    assert 0 <= t <= 1
    weight = 0.5 * (1 + np.cos(np.pi * t))  # 1 at t=0, 0 at t=1
    return lr_end + (lr_start - lr_end) * weight
422
+
423
+
424
def linear_schedule(t, lr_start, lr_end):
    """Linear interpolation from lr_start (t=0) to lr_end (t=1)."""
    assert 0 <= t <= 1
    return lr_start + t * (lr_end - lr_start)
427
+
428
+
429
def cycled_linear_schedule(t, lr_start, lr_end, num_cycles=2):
    """Linear schedule repeated ``num_cycles`` times over t in [0, 1]."""
    assert 0 <= t <= 1
    phase = t * num_cycles
    phase -= int(phase)  # fractional position inside the current cycle
    if t == 1:
        phase = 1  # pin the very end of training to lr_end
    return lr_start + (lr_end - lr_start) * phase
436
+
437
+
438
def adjust_learning_rate_by_lr(optimizer, lr):
    """Set every param group's lr, honouring an optional per-group "lr_scale"."""
    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/deepspeed_zero3_bf16.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bf16": {
3
+ "enabled": true
4
+ },
5
+ "gradient_accumulation_steps": 1,
6
+ "gradient_clipping": 1.0,
7
+ "train_micro_batch_size_per_gpu": 1,
8
+ "steps_per_print": 2000,
9
+ "wall_clock_breakdown": false,
10
+ "zero_optimization": {
11
+ "stage": 3,
12
+ "overlap_comm": true,
13
+ "contiguous_gradients": true,
14
+ "reduce_bucket_size": 50000000,
15
+ "stage3_prefetch_bucket_size": 5000000,
16
+ "stage3_param_persistence_threshold": 100000,
17
+ "gather_16bit_weights_on_model_save": true
18
+ }
19
+ }
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/finetune.yaml ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accum_iter: 1
2
+ allow_repeat: false
3
+ amp: 1
4
+ batch_size: 1
5
+ benchmark: false
6
+ custom_lr_scale: 1.0
7
+ dataset_arkit: ARKitScenes_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="/var/scratch/qzhang2/SLAM-Former/data/train/processed_arkitscenes/",
8
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
9
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
10
+ dataset_mvs_synth: MVS_Synth_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="/var/scratch/qzhang2/SLAM-Former/data/train/processed_mvs_synth",
11
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
12
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
13
+ dataset_scannetpp: ScanNetpp_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="/var/scratch/qzhang2/SLAM-Former/data/train/processed_scannetpp/",
14
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
15
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
16
+ desc_dim: 128
17
+ detach_frontend_tokens: false
18
+ dist_backend: nccl
19
+ dist_url: env://
20
+ distributed: false
21
+ enable_dynamic_boundary: false
22
+ enable_loop: true
23
+ enable_submap: true
24
+ enable_temporal: false
25
+ epochs: 10
26
+ eval_freq: 1
27
+ exp_name: submap_joint_softall_64_24_v1
28
+ fixed_length: true
29
+ freeze_encoder: true
30
+ gpu: 0
31
+ gradient_checkpointing: true
32
+ gumbel_tau: 5.0
33
+ loop_mask_mode: soft_all
34
+ retain_history_grad: true
35
+ submap_train_mode: full_token
36
+ submap_retrieval_topk: 0
37
+ submap_fetch_source: frontend
38
+ submap_descriptor_source: frontend
39
+ train_submap_modules_only: false
40
+ gumbel_tau_end: 0.1
41
+ gumbel_tau_start: 5.0
42
+ hydra:
43
+ run:
44
+ dir: ${save_dir}/${exp_name}
45
+ verbose: true
46
+ keep_freq: 1
47
+ load_only_encoder: false
48
+ local-rank: -1
49
+ logdir: ${save_dir}/${exp_name}/logs
50
+ long_context: false
51
+ lr: 1e-5
52
+ max_checkpoints: 10
53
+ max_recursive_submaps: 5
54
+ min_lr: 1e-8
55
+ n_corres_test: 0
56
+ n_corres_train: 0
57
+ num_imgs_vis: 4
58
+ num_test_views: 4
59
+ num_views: 24
60
+ num_workers: 4
61
+ output_dir: ${save_dir}/${exp_name}/
62
+ pretrained: /var/scratch/qzhang2/SLAM-Former/ckpt/checkpoint-10.pth.model
63
+ print_freq: 10
64
+ print_img_freq: 50000000
65
+ rank: 0
66
+ resume: null
67
+ retention_ratio: 0.5
68
+ pseudo_gt:
69
+ enable: false
70
+ cache_path: null
71
+ use_soft_targets: true
72
+ min_confidence: 0.65
73
+ min_support_pairs: 1
74
+ topk_pairs: 4
75
+ loss_type: hybrid
76
+ loss_weight_gate: 0.1
77
+ loss_weight_desc: 0.1
78
+ geometric_support_scale: 0.25
79
+ ranking_margin: 0.1
80
+ use_l2m: false
81
+ l2m_min_certainty: 0.0
82
+ l2m_min_inlier_ratio: 0.0
83
+ save_dir: /var/scratch/qzhang2/SLAM-Former/checkpoints
84
+ save_freq: 0.1
85
+ seed: 42
86
+ soft_mask_bias: 0.2
87
+ soft_mask_temperature: 0.25
88
+ start_epoch: 0
89
+ start_step: 0
90
+ submap_size: 6
91
+ task: SLAMFormer_Submap_Finetune
92
+ tbptt_window: 0
93
+ teacher: null
94
+ temporal_embed_mode: learned
95
+ test_criterion: DistillLoss()
96
+ test_dataset: 500 @ ARKitScenes_Multi(split='test', ROOT="/var/scratch/qzhang2/SLAM-Former/data/train/processed_arkitscenes/",
97
+ resolution=(518, 392), num_views=${num_test_views}, seed=42, n_corres=${n_corres_test})
98
+ train_criterion: DistillLoss()
99
+ train_dataset: 2250 @ ${dataset_scannetpp} + 450 @ ${dataset_mvs_synth} + 4500 @ ${dataset_arkit}
100
+ warmup_epochs: 0.5
101
+ weight_decay: 0.05
102
+ world_size: 1
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/finetune_paper_h20.yaml ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accum_iter: 1
2
+ allow_repeat: false
3
+ amp: 1
4
+ batch_size: 1
5
+ benchmark: false
6
+ custom_lr_scale: 1.0
7
+ data_root: /var/scratch/qzhang2/SLAM-Former/data/train
8
+ root_arkit: ${data_root}/processed_arkitscenes
9
+ root_scannetpp: ${data_root}/processed_scannetpp
10
+ root_scannet: ${data_root}/processed_scannet
11
+ root_hypersim: ${data_root}/hypersim
12
+ root_blendedmvs: ${data_root}/processed_blendedmvs
13
+ root_megadepth: ${data_root}/processed_megadepth
14
+ root_mvs_synth: ${data_root}/processed_mvs_synth
15
+ dataset_arkit: ARKitScenes_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_arkit}",
16
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
17
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views_arkit}, n_corres=${n_corres_train})
18
+ dataset_scannetpp: ScanNetpp_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_scannetpp}",
19
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
20
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views_scannetpp}, n_corres=${n_corres_train})
21
+ dataset_scannet: ScanNet_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_scannet}",
22
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
23
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views_scannet}, n_corres=${n_corres_train})
24
+ dataset_hypersim: HyperSim_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_hypersim}",
25
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
26
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views_hypersim}, n_corres=${n_corres_train})
27
+ dataset_blendedmvs: BlendedMVS_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_blendedmvs}",
28
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
29
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views_blendedmvs}, n_corres=${n_corres_train})
30
+ dataset_megadepth: MegaDepth_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_megadepth}",
31
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
32
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views_megadepth}, n_corres=${n_corres_train})
33
+ dataset_mvs_synth: MVS_Synth_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_mvs_synth}",
34
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
35
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views_mvs_synth}, n_corres=${n_corres_train})
36
+ desc_dim: 128
37
+ detach_frontend_tokens: false
38
+ dist_backend: nccl
39
+ dist_url: env://
40
+ distributed: false
41
+ enable_dynamic_boundary: false
42
+ enable_loop: true
43
+ enable_submap: true
44
+ enable_temporal: false
45
+ epochs: 10
46
+ eval_freq: 1
47
+ exp_name: paper_all_h20_joint
48
+ fixed_length: true
49
+ freeze_encoder: true
50
+ gpu: 0
51
+ gradient_checkpointing: true
52
+ gumbel_tau: 5.0
53
+ loop_mask_mode: soft_all
54
+ retain_history_grad: true
55
+ submap_train_mode: full_token
56
+ submap_retrieval_topk: 0
57
+ submap_fetch_source: frontend
58
+ submap_descriptor_source: frontend
59
+ train_submap_modules_only: false
60
+ gumbel_tau_end: 0.1
61
+ gumbel_tau_start: 5.0
62
+ hydra:
63
+ run:
64
+ dir: ${save_dir}/${exp_name}
65
+ verbose: true
66
+ keep_freq: 1
67
+ load_only_encoder: false
68
+ local-rank: -1
69
+ logdir: ${save_dir}/${exp_name}/logs
70
+ long_context: false
71
+ lr: 1e-5
72
+ max_checkpoints: 10
73
+ max_recursive_submaps: 5
74
+ min_lr: 1e-8
75
+ n_corres_test: 0
76
+ n_corres_train: 0
77
+ num_imgs_vis: 4
78
+ num_test_views: 4
79
+ num_views: 24
80
+ num_views_arkit: 24
81
+ num_views_scannetpp: 24
82
+ num_views_scannet: 24
83
+ num_views_hypersim: 24
84
+ num_views_blendedmvs: 24
85
+ num_views_megadepth: 24
86
+ num_views_mvs_synth: 24
87
+ num_workers: 4
88
+ output_dir: ${save_dir}/${exp_name}/
89
+ pretrained: /var/scratch/qzhang2/SLAM-Former/ckpt/checkpoint-10.pth.model
90
+ print_freq: 10
91
+ print_img_freq: 50000000
92
+ rank: 0
93
+ resume: null
94
+ retention_ratio: 0.5
95
+ pseudo_gt:
96
+ enable: false
97
+ cache_path: null
98
+ use_soft_targets: true
99
+ min_confidence: 0.65
100
+ min_support_pairs: 1
101
+ topk_pairs: 4
102
+ loss_type: hybrid
103
+ loss_weight_gate: 0.1
104
+ loss_weight_desc: 0.1
105
+ geometric_support_scale: 0.25
106
+ ranking_margin: 0.1
107
+ use_l2m: false
108
+ l2m_min_certainty: 0.0
109
+ l2m_min_inlier_ratio: 0.0
110
+ save_dir: /var/scratch/qzhang2/SLAM-Former/checkpoints
111
+ save_freq: 0.1
112
+ seed: 42
113
+ soft_mask_bias: 0.2
114
+ soft_mask_temperature: 0.25
115
+ start_epoch: 0
116
+ start_step: 0
117
+ submap_size: 6
118
+ task: SLAMFormer_Submap_Finetune
119
+ tbptt_window: 0
120
+ teacher: null
121
+ temporal_embed_mode: learned
122
+ test_criterion: DistillLoss()
123
+ test_dataset: 500 @ ARKitScenes_Multi(split='test', ROOT="${root_arkit}",
124
+ resolution=(518, 392), num_views=${num_test_views}, seed=42, n_corres=${n_corres_test})
125
+ train_criterion: DistillLoss()
126
+ train_dataset: 4500 @ ${dataset_arkit} + 2250 @ ${dataset_scannetpp} + 4500 @ ${dataset_scannet} + 1200 @ ${dataset_hypersim} + 2250 @ ${dataset_blendedmvs} + 2250 @ ${dataset_megadepth} + 450 @ ${dataset_mvs_synth}
127
+ warmup_epochs: 0.5
128
+ weight_decay: 0.05
129
+ world_size: 1
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/finetune_pseudo_gt_high_recall.yaml ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accum_iter: 1
2
+ allow_repeat: false
3
+ amp: 1
4
+ batch_size: 1
5
+ benchmark: false
6
+ custom_lr_scale: 1.0
7
+ data_root: /var/scratch/qzhang2/SLAM-Former/data/train
8
+ root_arkit: ${data_root}/processed_arkitscenes
9
+ root_scannetpp: ${data_root}/processed_scannetpp
10
+ root_scannet: ${data_root}/processed_scannetv2
11
+ root_hypersim: ${data_root}/hypersim
12
+ root_blendedmvs: ${data_root}/processed_blendedmvs
13
+ root_megadepth: ${data_root}/processed_megadepth
14
+ root_mvs_synth: ${data_root}/processed_mvs_synth
15
+ dataset_arkit: ARKitScenes_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_arkit}",
16
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
17
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views_arkit}, n_corres=${n_corres_train})
18
+ dataset_scannetpp: ScanNetpp_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_scannetpp}",
19
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
20
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views_scannetpp}, n_corres=${n_corres_train})
21
+ dataset_scannet: ScanNet_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_scannet}",
22
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
23
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views_scannet}, n_corres=${n_corres_train})
24
+ dataset_hypersim: HyperSim_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_hypersim}",
25
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
26
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views_hypersim}, n_corres=${n_corres_train})
27
+ dataset_blendedmvs: BlendedMVS_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_blendedmvs}",
28
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
29
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views_blendedmvs}, n_corres=${n_corres_train})
30
+ dataset_megadepth: MegaDepth_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_megadepth}",
31
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
32
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views_megadepth}, n_corres=${n_corres_train})
33
+ dataset_mvs_synth: MVS_Synth_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_mvs_synth}",
34
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
35
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views_mvs_synth}, n_corres=${n_corres_train})
36
+ desc_dim: 128
37
+ detach_frontend_tokens: true
38
+ dist_backend: nccl
39
+ dist_url: env://
40
+ distributed: false
41
+ enable_dynamic_boundary: false
42
+ enable_loop: true
43
+ enable_submap: true
44
+ enable_temporal: false
45
+ epochs: 10
46
+ eval_freq: 1
47
+ exp_name: paper_all_h20_pseudo_gt_high_recall
48
+ fixed_length: true
49
+ freeze_encoder: true
50
+ gpu: 0
51
+ gradient_checkpointing: true
52
+ gumbel_tau: 5.0
53
+ loop_mask_mode: soft_all
54
+ retain_history_grad: true
55
+ submap_train_mode: full_token
56
+ submap_retrieval_topk: 0
57
+ submap_fetch_source: frontend
58
+ submap_descriptor_source: frontend
59
+ train_submap_modules_only: true
60
+ gumbel_tau_end: 0.1
61
+ gumbel_tau_start: 5.0
62
+ hydra:
63
+ run:
64
+ dir: ${save_dir}/${exp_name}
65
+ verbose: true
66
+ keep_freq: 1
67
+ load_only_encoder: false
68
+ local-rank: -1
69
+ logdir: ${save_dir}/${exp_name}/logs
70
+ long_context: false
71
+ lr: 1e-5
72
+ max_checkpoints: 10
73
+ max_recursive_submaps: 5
74
+ min_lr: 1e-8
75
+ n_corres_test: 0
76
+ n_corres_train: 0
77
+ num_imgs_vis: 4
78
+ num_test_views: 4
79
+ num_views: 24
80
+ num_views_arkit: 24
81
+ num_views_scannetpp: 24
82
+ num_views_scannet: 24
83
+ num_views_hypersim: 24
84
+ num_views_blendedmvs: 24
85
+ num_views_megadepth: 24
86
+ num_views_mvs_synth: 24
87
+ num_workers: 4
88
+ output_dir: ${save_dir}/${exp_name}/
89
+ pretrained: /var/scratch/qzhang2/SLAM-Former/ckpt/checkpoint-10.pth.model
90
+ print_freq: 10
91
+ print_img_freq: 50000000
92
+ rank: 0
93
+ resume: null
94
+ retention_ratio: 0.5
95
+ pseudo_gt:
96
+ enable: false
97
+ cache_path: null
98
+ use_soft_targets: true
99
+ min_confidence: 0.5
100
+ min_support_pairs: 1
101
+ topk_pairs: 8
102
+ loss_type: hybrid
103
+ loss_weight_gate: 0.05
104
+ loss_weight_desc: 0.15
105
+ geometric_support_scale: 0.5
106
+ ranking_margin: 0.05
107
+ use_l2m: true
108
+ l2m_min_certainty: 0.35
109
+ l2m_min_inlier_ratio: 0.2
110
+ save_dir: /var/scratch/qzhang2/SLAM-Former/checkpoints
111
+ save_freq: 0.1
112
+ seed: 42
113
+ soft_mask_bias: 0.2
114
+ soft_mask_temperature: 0.25
115
+ start_epoch: 0
116
+ start_step: 0
117
+ submap_size: 6
118
+ task: SLAMFormer_Submap_Finetune
119
+ tbptt_window: 0
120
+ teacher: null
121
+ temporal_embed_mode: learned
122
+ test_criterion: DistillLoss()
123
+ test_dataset: 500 @ ARKitScenes_Multi(split='test', ROOT="${root_arkit}",
124
+ resolution=(518, 392), num_views=${num_test_views}, seed=42, n_corres=${n_corres_test})
125
+ train_criterion: DistillLoss()
126
+ train_dataset: 4500 @ ${dataset_arkit} + 2250 @ ${dataset_scannetpp} + 4500 @ ${dataset_scannet} + 1200 @ ${dataset_hypersim} + 2250 @ ${dataset_blendedmvs} + 2250 @ ${dataset_megadepth} + 450 @ ${dataset_mvs_synth}
127
+ warmup_epochs: 0.5
128
+ weight_decay: 0.05
129
+ world_size: 1
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/finetune_sub_only.yaml ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accum_iter: 1
2
+ allow_repeat: false
3
+ amp: 1
4
+ batch_size: 1
5
+ benchmark: false
6
+ custom_lr_scale: 1.0
7
+ data_root: /var/scratch/qzhang2/SLAM-Former/data/train
8
+ root_arkit: ${data_root}/processed_arkitscenes
9
+ root_scannetpp: ${data_root}/processed_scannetpp
10
+ root_scannet: ${data_root}/processed_scannet
11
+ root_hypersim: ${data_root}/hypersim
12
+ root_blendedmvs: ${data_root}/processed_blendedmvs
13
+ root_megadepth: ${data_root}/processed_megadepth
14
+ root_mvs_synth: ${data_root}/processed_mvs_synth
15
+ dataset_arkit: ARKitScenes_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_arkit}",
16
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
17
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
18
+ dataset_scannet: ScanNet_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_scannet}",
19
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
20
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
21
+ dataset_hypersim: HyperSim_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_hypersim}",
22
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
23
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
24
+ dataset_blendedmvs: BlendedMVS_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_blendedmvs}",
25
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
26
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
27
+ dataset_megadepth: MegaDepth_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_megadepth}",
28
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
29
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
30
+ dataset_mvs_synth: MVS_Synth_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_mvs_synth}",
31
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
32
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
33
+ dataset_scannetpp: ScanNetpp_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_scannetpp}",
34
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210),
35
+ (518, 154)], transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
36
+ desc_dim: 128
37
+ detach_frontend_tokens: true
38
+ dist_backend: nccl
39
+ dist_url: env://
40
+ distributed: false
41
+ enable_dynamic_boundary: false
42
+ enable_loop: true
43
+ enable_submap: true
44
+ enable_temporal: false
45
+ epochs: 10
46
+ eval_freq: 1
47
+ exp_name: submap_sub_only_softall_64_24_v1
48
+ fixed_length: true
49
+ freeze_encoder: true
50
+ gpu: 0
51
+ gradient_checkpointing: true
52
+ gumbel_tau: 5.0
53
+ loop_mask_mode: soft_all
54
+ retain_history_grad: true
55
+ submap_train_mode: full_token
56
+ submap_retrieval_topk: 0
57
+ submap_fetch_source: frontend
58
+ submap_descriptor_source: frontend
59
+ train_submap_modules_only: true
60
+ gumbel_tau_end: 0.1
61
+ gumbel_tau_start: 5.0
62
+ hydra:
63
+ run:
64
+ dir: ${save_dir}/${exp_name}
65
+ verbose: true
66
+ keep_freq: 1
67
+ load_only_encoder: false
68
+ local-rank: -1
69
+ logdir: ${save_dir}/${exp_name}/logs
70
+ long_context: false
71
+ lr: 1e-5
72
+ max_checkpoints: 10
73
+ max_recursive_submaps: 5
74
+ min_lr: 1e-8
75
+ n_corres_test: 0
76
+ n_corres_train: 0
77
+ num_imgs_vis: 4
78
+ num_test_views: 4
79
+ num_views: 24
80
+ num_views_arkit: 24
81
+ num_views_scannetpp: 24
82
+ num_views_scannet: 24
83
+ num_views_hypersim: 24
84
+ num_views_blendedmvs: 24
85
+ num_views_megadepth: 24
86
+ num_views_mvs_synth: 24
87
+ num_workers: 4
88
+ output_dir: ${save_dir}/${exp_name}/
89
+ pretrained: /var/scratch/qzhang2/SLAM-Former/ckpt/checkpoint-10.pth.model
90
+ print_freq: 10
91
+ print_img_freq: 50000000
92
+ rank: 0
93
+ resume: null
94
+ retention_ratio: 0.5
95
+ pseudo_gt:
96
+ enable: false
97
+ cache_path: null
98
+ use_soft_targets: true
99
+ min_confidence: 0.65
100
+ min_support_pairs: 1
101
+ topk_pairs: 4
102
+ loss_type: hybrid
103
+ loss_weight_gate: 0.1
104
+ loss_weight_desc: 0.1
105
+ geometric_support_scale: 0.25
106
+ ranking_margin: 0.1
107
+ use_l2m: false
108
+ l2m_min_certainty: 0.0
109
+ l2m_min_inlier_ratio: 0.0
110
+ save_dir: /var/scratch/qzhang2/SLAM-Former/checkpoints
111
+ save_freq: 0.1
112
+ seed: 42
113
+ soft_mask_bias: 0.2
114
+ soft_mask_temperature: 0.25
115
+ start_epoch: 0
116
+ start_step: 0
117
+ submap_size: 6
118
+ task: SLAMFormer_Submap_Finetune
119
+ tbptt_window: 0
120
+ teacher: null
121
+ temporal_embed_mode: learned
122
+ test_criterion: DistillLoss()
123
+ test_dataset: 500 @ ARKitScenes_Multi(split='test', ROOT="${root_arkit}",
124
+ resolution=(518, 392), num_views=${num_test_views}, seed=42, n_corres=${n_corres_test})
125
+ train_criterion: DistillLoss()
126
+ train_dataset: 2250 @ ${dataset_scannetpp} + 450 @ ${dataset_mvs_synth} + 4500 @ ${dataset_arkit}
127
+ warmup_epochs: 0.5
128
+ weight_decay: 0.05
129
+ world_size: 1
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/config/mytrain.yaml ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ teacher: ../ckpt/model.pt # unused placeholder; kept for config compatibility
2
+ pretrained: ../ckpt/pi3.pth
3
+
4
+ load_only_encoder: False
5
+ long_context: False
6
+ fixed_length: True
7
+ resume: Null
8
+ benchmark: False
9
+ num_views : 12
10
+ num_test_views : 4
11
+ n_corres_train: 0
12
+ n_corres_test: 0
13
+
14
+ train_criterion: DistillLoss()
15
+ test_criterion: DistillLoss()
16
+ allow_repeat: False
17
+
18
+ dataset3: ARKitScenes_Multi(allow_repeat=${allow_repeat}, split='train', ROOT='../data/train/processed_arkitscenes/',
19
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)],
20
+ transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
21
+ dataset5: ScanNetpp_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="../data/train/processed_scannetpp/",
22
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)],
23
+ transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
24
+ dataset6: ScanNet_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="../data/train/processed_scannet/",
25
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)],
26
+ transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
27
+ dataset7: HyperSim_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="../data/train/hypersim",
28
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)],
29
+ transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
30
+ dataset8: BlendedMVS_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="../data/train/processed_blendedmvs/",
31
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)],
32
+ transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
33
+ dataset9: MegaDepth_Multi(allow_repeat=${allow_repeat}, split="train", ROOT="../data/train/processed_megadepth",
34
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)],
35
+ transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
36
+
37
+ dataset14: MVS_Synth_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="../data/train/processed_mvs_synth",
38
+ aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)],
39
+ transform=SeqColorJitter, num_views=${num_views}, n_corres=${n_corres_train})
40
+
41
+
42
+ train_dataset: 4500 @ ${dataset3} + 2250 @ ${dataset5} + 4500 @ ${dataset6} + 1200 @ ${dataset7} + 2250 @ ${dataset8} + 2250 @ ${dataset9} + 450 @ ${dataset14}
43
+
44
+ test_dataset: 1000 @ ARKitScenes_Multi(split='test', ROOT='../data/train/processed_arkitscenes/', resolution=(518, 392), num_views=${num_test_views}, seed=42, n_corres=${n_corres_test})
45
+
46
+
47
+ seed: 0
48
+ batch_size: 1
49
+ accum_iter: 1
50
+ gradient_checkpointing: False
51
+ epochs: 10
52
+ start_epoch: 0
53
+ start_step: 0
54
+ weight_decay: 0.05
55
+
56
+ lr: 1e-5
57
+ min_lr: 1e-8
58
+
59
+
60
+ warmup_epochs: 0.5
61
+ amp: 1
62
+
63
+ num_workers: 4 # 12
64
+ world_size: 1
65
+ local-rank: -1
66
+ dist_url: 'env://'
67
+ rank: 0
68
+ gpu: 0
69
+ distributed: False
70
+ dist_backend: 'nccl'
71
+
72
+ eval_freq: 1
73
+ save_freq: 0.1
74
+ max_checkpoints: 10
75
+ keep_freq: 1
76
+ print_freq: 10
77
+ print_img_freq: 50000000
78
+ num_imgs_vis: 4
79
+ save_dir: '../checkpoints'
80
+
81
+ exp_name: 'SLAMFormer_v1'
82
+
83
+
84
+
85
+
86
+ task: 'StreamVGGT'
87
+ logdir: ${save_dir}/${exp_name}/logs
88
+ output_dir: ${save_dir}/${exp_name}/
89
+ hydra:
90
+ verbose: True
91
+ run:
92
+ dir: ${save_dir}/${exp_name}
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/environment.yml ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: SLAM-Former
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - _libgcc_mutex=0.1=main
6
+ - _openmp_mutex=5.1=1_gnu
7
+ - bzip2=1.0.8=h5eee18b_6
8
+ - ca-certificates=2025.12.2=h06a4308_0
9
+ - expat=2.7.4=h7354ed3_0
10
+ - ld_impl_linux-64=2.44=h153f514_2
11
+ - libexpat=2.7.4=h7354ed3_0
12
+ - libffi=3.4.4=h6a678d5_1
13
+ - libgcc=15.2.0=h69a1729_7
14
+ - libgcc-ng=15.2.0=h166f726_7
15
+ - libgomp=15.2.0=h4751f2c_7
16
+ - libnsl=2.0.0=h5eee18b_0
17
+ - libstdcxx=15.2.0=h39759b7_7
18
+ - libstdcxx-ng=15.2.0=hc03a8fd_7
19
+ - libuuid=1.41.5=h5eee18b_0
20
+ - libxcb=1.17.0=h9b100fa_0
21
+ - libzlib=1.3.1=hb25bd0a_0
22
+ - ncurses=6.5=h7934f7d_0
23
+ - openssl=3.5.5=h1b28b03_0
24
+ - packaging=25.0=py311h06a4308_1
25
+ - pip=26.0.1=pyhc872135_0
26
+ - pthread-stubs=0.3=h0ce48e5_1
27
+ - python=3.11.14=h6fa692b_0
28
+ - readline=8.3=hc2a1206_0
29
+ - setuptools=80.10.2=py311h06a4308_0
30
+ - sqlite=3.51.1=h3e8d24a_1
31
+ - tk=8.6.15=h54e0aa7_0
32
+ - tzdata=2026a=he532380_0
33
+ - wheel=0.46.3=py311h06a4308_0
34
+ - xorg-libx11=1.8.12=h9b100fa_1
35
+ - xorg-libxau=1.0.12=h9b100fa_0
36
+ - xorg-libxdmcp=1.1.5=h9b100fa_0
37
+ - xorg-xorgproto=2024.1=h5eee18b_1
38
+ - xz=5.8.2=h448239c_0
39
+ - zlib=1.3.1=hb25bd0a_0
40
+ - pip:
41
+ - absl-py==2.4.0
42
+ - accelerate==1.13.0
43
+ - addict==2.4.0
44
+ - aiofiles==24.1.0
45
+ - annotated-doc==0.0.4
46
+ - annotated-types==0.7.0
47
+ - antlr4-python3-runtime==4.9.3
48
+ - anyio==4.12.1
49
+ - argcomplete==3.6.3
50
+ - asttokens==3.0.1
51
+ - attrs==25.4.0
52
+ - beartype==0.22.9
53
+ - blinker==1.9.0
54
+ - brotli==1.2.0
55
+ - ccimport==0.4.4
56
+ - certifi==2026.2.25
57
+ - charset-normalizer==3.4.5
58
+ - click==8.3.1
59
+ - colorama==0.4.6
60
+ - colorlog==6.10.1
61
+ - comm==0.2.3
62
+ - configargparse==1.7.3
63
+ - contourpy==1.3.3
64
+ - cumm-cu124==0.7.11
65
+ - cycler==0.12.1
66
+ - dacite==1.9.2
67
+ - dash==4.0.0
68
+ - decorator==4.4.2
69
+ - einops==0.8.2
70
+ - embreex==2.17.7.post7
71
+ - evo==1.34.3
72
+ - executing==2.2.1
73
+ - fast-pytorch-kmeans==0.2.2
74
+ - fastapi==0.135.1
75
+ - fastjsonschema==2.21.2
76
+ - ffmpy==1.0.0
77
+ - filelock==3.25.0
78
+ - fire==0.7.1
79
+ - flask==3.1.3
80
+ - fonttools==4.61.1
81
+ - fsspec==2026.2.0
82
+ - gradio==6.9.0
83
+ - gradio-client==2.3.0
84
+ - groovy==0.1.2
85
+ - grpcio==1.78.0
86
+ - gsplat==1.5.3
87
+ - h11==0.16.0
88
+ - h5py==3.16.0
89
+ - hf-xet==1.3.2
90
+ - httpcore==1.0.9
91
+ - httpx==0.28.1
92
+ - huggingface-hub==1.6.0
93
+ - hydra-core==1.3.2
94
+ - idna==3.11
95
+ - imageio==2.37.2
96
+ - imageio-ffmpeg==0.6.0
97
+ - importlib-metadata==8.7.1
98
+ - ipython==9.10.0
99
+ - ipython-pygments-lexers==1.1.1
100
+ - ipywidgets==8.1.8
101
+ - itsdangerous==2.2.0
102
+ - jaxtyping==0.3.9
103
+ - jedi==0.19.2
104
+ - jinja2==3.1.6
105
+ - joblib==1.5.3
106
+ - jsonschema==4.26.0
107
+ - jsonschema-specifications==2025.9.1
108
+ - jupyter-core==5.9.1
109
+ - jupyterlab-widgets==3.0.16
110
+ - kiwisolver==1.4.9
111
+ - kornia==0.8.2
112
+ - kornia-rs==0.1.10
113
+ - lark==1.3.1
114
+ - lpips==0.1.4
115
+ - lxml==6.0.2
116
+ - lz4==4.4.5
117
+ - manifold3d==3.4.0
118
+ - mapbox-earcut==2.0.0
119
+ - markdown==3.10.2
120
+ - markdown-it-py==4.0.0
121
+ - markupsafe==3.0.3
122
+ - matplotlib==3.10.8
123
+ - matplotlib-inline==0.2.1
124
+ - mdurl==0.1.2
125
+ - moviepy==1.0.3
126
+ - mpmath==1.3.0
127
+ - msgspec==0.20.0
128
+ - narwhals==2.17.0
129
+ - natsort==8.4.0
130
+ - nbformat==5.10.4
131
+ - nest-asyncio==1.6.0
132
+ - networkx==3.6.1
133
+ - ninja==1.13.0
134
+ - numexpr==2.14.1
135
+ - numpy==2.2.6
136
+ - nvidia-cublas-cu12==12.8.4.1
137
+ - nvidia-cuda-cupti-cu12==12.8.90
138
+ - nvidia-cuda-nvrtc-cu12==12.8.93
139
+ - nvidia-cuda-runtime-cu12==12.8.90
140
+ - nvidia-cudnn-cu12==9.10.2.21
141
+ - nvidia-cufft-cu12==11.3.3.83
142
+ - nvidia-cufile-cu12==1.13.1.3
143
+ - nvidia-curand-cu12==10.3.9.90
144
+ - nvidia-cusolver-cu12==11.7.3.90
145
+ - nvidia-cusparse-cu12==12.5.8.93
146
+ - nvidia-cusparselt-cu12==0.7.1
147
+ - nvidia-nccl-cu12==2.27.5
148
+ - nvidia-nvjitlink-cu12==12.8.93
149
+ - nvidia-nvshmem-cu12==3.3.20
150
+ - nvidia-nvtx-cu12==12.8.90
151
+ - omegaconf==2.3.0
152
+ - open3d==0.18.0
153
+ - opencv-python==4.13.0.92
154
+ - orjson==3.11.7
155
+ - pandas==3.0.1
156
+ - parso==0.8.6
157
+ - pccm==0.4.16
158
+ - pexpect==4.9.0
159
+ - pillow==12.1.1
160
+ - platformdirs==4.9.4
161
+ - plotly==6.6.0
162
+ - plyfile==1.1.3
163
+ - portalocker==3.2.0
164
+ - proglog==0.1.12
165
+ - prompt-toolkit==3.0.52
166
+ - protobuf==7.34.0
167
+ - psutil==7.2.2
168
+ - ptyprocess==0.7.0
169
+ - pure-eval==0.2.3
170
+ - pyarrow==23.0.1
171
+ - pybind11==3.0.3
172
+ - pycollada==0.9.3
173
+ - pydantic==2.12.5
174
+ - pydantic-core==2.41.5
175
+ - pydub==0.25.1
176
+ - pyglet==1.5.31
177
+ - pygments==2.19.2
178
+ - pyparsing==3.3.2
179
+ - pyquaternion==0.9.9
180
+ - python-dateutil==2.9.0.post0
181
+ - python-multipart==0.0.22
182
+ - pytz==2026.1.post1
183
+ - pyyaml==6.0.3
184
+ - referencing==0.37.0
185
+ - regex==2026.2.28
186
+ - requests==2.32.5
187
+ - rerun-sdk==0.30.1
188
+ - retrying==1.4.2
189
+ - rich==14.3.3
190
+ - roma==1.5.6
191
+ - rosbags==0.11.0
192
+ - rpds-py==0.30.0
193
+ - rtree==1.4.1
194
+ - ruamel-yaml==0.19.1
195
+ - safehttpx==0.1.7
196
+ - safetensors==0.7.0
197
+ - scikit-learn==1.8.0
198
+ - scipy==1.17.1
199
+ - seaborn==0.13.2
200
+ - semantic-version==2.10.0
201
+ - shapely==2.1.2
202
+ - shellingham==1.5.4
203
+ - six==1.17.0
204
+ - slamformer==0.1.0
205
+ - spconv-cu124==2.3.8
206
+ - stack-data==0.6.3
207
+ - starlette==0.52.1
208
+ - svg-path==7.0
209
+ - sympy==1.14.0
210
+ - tensorboard==2.20.0
211
+ - tensorboard-data-server==0.7.2
212
+ - termcolor==3.3.0
213
+ - threadpoolctl==3.6.0
214
+ - timm==1.0.26
215
+ - tokenizers==0.22.2
216
+ - tomlkit==0.13.3
217
+ - torch==2.9.1
218
+ - torch-cluster==1.6.3+pt25cu124
219
+ - torch-scatter==2.1.2+pt25cu124
220
+ - torch-sparse==0.6.18+pt25cu124
221
+ - torch-spline-conv==1.2.2+pt25cu124
222
+ - torchvision==0.24.1
223
+ - tqdm==4.67.3
224
+ - traitlets==5.14.3
225
+ - transformers==5.3.0
226
+ - trimesh==4.11.3
227
+ - triton==3.5.1
228
+ - typer==0.24.1
229
+ - typing-extensions==4.15.0
230
+ - typing-inspection==0.4.2
231
+ - urllib3==2.6.3
232
+ - uvicorn==0.41.0
233
+ - vhacdx==0.0.10
234
+ - viser==1.0.24
235
+ - wadler-lindig==0.1.7
236
+ - wcwidth==0.6.0
237
+ - websockets==15.0.1
238
+ - werkzeug==3.1.6
239
+ - widgetsnbextension==4.0.15
240
+ - xxhash==3.6.0
241
+ - yapf==0.43.0
242
+ - yourdfpy==0.0.60
243
+ - zipp==3.23.0
244
+ - zstandard==0.25.0
245
+ prefix: /var/scratch/qzhang2/miniconda3/envs/SLAM-Former
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/eval_ate_scaled.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import numpy as np
4
+
5
+ try:
6
+ import evo
7
+ from evo.core import metrics
8
+ import evo.main_ape as main_ape
9
+ from evo.core.metrics import PoseRelation
10
+ from evo.core.trajectory import PosePath3D
11
+ from evo.tools import file_interface
12
+ import evo.core.sync as sync
13
+ HAS_EVO = True
14
+ except ImportError:
15
+ HAS_EVO = False
16
+ print("EVO not found. Please install evo: pip install evo")
17
+ exit(1)
18
+
19
# Root of the TUM RGB-D sequences and the directory holding per-sequence results.
TUM_DIR = "/var/scratch/qzhang2/Feature-SLAM/datasets/tum"
RESULTS_DIR = os.environ.get("RESULTS_DIR", "./tum_results")

# Every sub-directory of RESULTS_DIR is treated as one evaluated sequence.
sequences = [d for d in os.listdir(RESULTS_DIR) if os.path.isdir(os.path.join(RESULTS_DIR, d))]

# Table header: unscaled (SE3-aligned) vs scale-aligned (Sim3) ATE columns.
print(f"{'Sequence':<40} | {'ATE (m) [Unscaled]':<20} | {'ATE (m) [Scale Aligned]':<20}")
print("-" * 85)
26
+
27
def get_ate(est_file, gt_file, align_scale=False):
    """Compute the translation ATE RMSE (meters) between two TUM-format trajectories.

    Parameters
    ----------
    est_file : str
        Path to the estimated trajectory (TUM format: timestamp tx ty tz qx qy qz qw).
    gt_file : str
        Path to the ground-truth trajectory in the same format.
    align_scale : bool
        If True, align with Sim(3) (rotation, translation, and scale);
        otherwise SE(3) only.

    Returns
    -------
    float | str
        The ATE RMSE on success, or an ``"Error: ..."`` string on failure —
        the caller formats either case into the results table.
    """
    try:
        traj_ref = file_interface.read_tum_trajectory_file(gt_file)
        traj_est = file_interface.read_tum_trajectory_file(est_file)
        # Match poses by timestamp before computing any metric.
        traj_ref, traj_est = sync.associate_trajectories(traj_ref, traj_est)

        # FIX: main_ape.ape(align=True) already performs the Umeyama (Sim3/SE3)
        # alignment internally, so the previous explicit traj_est.align(...)
        # call was redundant work and has been removed; the reported RMSE is
        # unchanged because the optimal alignment does not depend on the
        # trajectory's starting pose.
        result = main_ape.ape(
            traj_ref,
            traj_est,
            pose_relation=PoseRelation.translation_part,
            align=True,
            correct_scale=align_scale,
        )
        return result.stats["rmse"]
    except Exception as e:
        # Best-effort: surface the failure as a printable string.
        return f"Error: {e}"
40
+
41
# Evaluate every sequence twice: without and with scale alignment.
for seq in sorted(sequences):
    est_file = os.path.join(RESULTS_DIR, seq, f"final_traj.txt")
    gt_file = os.path.join(TUM_DIR, seq, "groundtruth.txt")

    if os.path.exists(est_file) and os.path.exists(gt_file):
        ate_unscaled = get_ate(est_file, gt_file, align_scale=False)
        ate_scaled = get_ate(est_file, gt_file, align_scale=True)

        # get_ate returns a float on success or an error string on failure.
        unscaled_str = f"{ate_unscaled:.4f}" if isinstance(ate_unscaled, float) else str(ate_unscaled)
        scaled_str = f"{ate_scaled:.4f}" if isinstance(ate_scaled, float) else str(ate_scaled)

        print(f"{seq:<40} | {unscaled_str:<20} | {scaled_str:<20}")
    else:
        print(f"{seq:<40} | Missing files or still running")
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/get_ate.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import numpy as np
4
+
5
+ try:
6
+ import evo
7
+ from evo.core import metrics
8
+ import evo.main_ape as main_ape
9
+ from evo.core.metrics import PoseRelation
10
+ from evo.core.trajectory import PosePath3D
11
+ from evo.tools import file_interface
12
+ import evo.core.sync as sync
13
+ HAS_EVO = True
14
+ except ImportError:
15
+ HAS_EVO = False
16
+ print("EVO not found, using simple ATE calculation")
17
+
18
# Root of the TUM RGB-D sequences and the directory holding per-sequence results.
TUM_DIR = "/var/scratch/qzhang2/Feature-SLAM/datasets/tum"
RESULTS_DIR = os.environ.get("RESULTS_DIR", "./tum_results")

# Every sub-directory of RESULTS_DIR is treated as one evaluated sequence.
sequences = [d for d in os.listdir(RESULTS_DIR) if os.path.isdir(os.path.join(RESULTS_DIR, d))]

print(f"{'Sequence':<40} | {'ATE (m)':<15}")
print("-" * 58)
25
+
26
def get_ate(est_file, gt_file):
    """Compute the translation ATE RMSE (meters) between TUM-format trajectories.

    Uses evo (SE(3) alignment, no scale correction) when available; otherwise
    falls back to a simple nearest-timestamp RMSE with no spatial alignment.

    Returns a float on success, or a human-readable string on failure /
    when no timestamp matches are found — the caller formats either case.
    """
    if HAS_EVO:
        try:
            traj_ref = file_interface.read_tum_trajectory_file(gt_file)
            traj_est = file_interface.read_tum_trajectory_file(est_file)
            # Match poses by timestamp, then SE(3)-align the estimate.
            traj_ref, traj_est = sync.associate_trajectories(traj_ref, traj_est)
            traj_est.align(traj_ref, correct_scale=False)

            result = main_ape.ape(traj_ref, traj_est, pose_relation=PoseRelation.translation_part, align=True, correct_scale=False)
            return result.stats["rmse"]
        except Exception as e:
            return f"Error: {e}"
    else:
        # Fallback to simple ATE if evo is not available
        try:
            # ndmin=2 keeps the arrays 2-D even when a file holds a single pose,
            # so the column indexing below does not break.
            est_data = np.loadtxt(est_file, ndmin=2)
            gt_data = np.loadtxt(gt_file, ndmin=2)

            # Simple timestamp matching (no spatial alignment).
            ate_sum = 0.0
            count = 0

            for est in est_data:
                ts = est[0]
                # Find closest gt timestamp
                idx = np.argmin(np.abs(gt_data[:, 0] - ts))
                if np.abs(gt_data[idx, 0] - ts) < 0.1:  # 100ms threshold
                    diff = est[1:4] - gt_data[idx, 1:4]
                    ate_sum += np.sum(diff**2)
                    count += 1

            if count > 0:
                return np.sqrt(ate_sum / count)
            return "No matches"
        except Exception as e:
            # BUG FIX: previously returned the bare string f"Error" (placeholder-less
            # f-string, exception discarded); now report the detail like the evo branch.
            return f"Error: {e}"
62
+
63
# Print one ATE row per sequence; get_ate returns a float or an error string.
for seq in sorted(sequences):
    est_file = os.path.join(RESULTS_DIR, seq, f"final_traj.txt")
    gt_file = os.path.join(TUM_DIR, seq, "groundtruth.txt")

    if os.path.exists(est_file) and os.path.exists(gt_file):
        ate = get_ate(est_file, gt_file)
        if isinstance(ate, float):
            print(f"{seq:<40} | {ate:.4f}")
        else:
            print(f"{seq:<40} | {ate}")
    else:
        print(f"{seq:<40} | Missing files or running")
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/publish_submap.sh ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Export a curated subset of this repository into PUBLISH_DIR and force-push it
# to TARGET_REPO:BRANCH on GitHub, authenticating via a token-based askpass helper.
set -euo pipefail

usage() {
    cat <<'EOF'
Usage:
  GITHUB_TOKEN=... ./publish_submap.sh [commit message]

Optional env vars:
  SOURCE_ROOT     Source repository root (default: current git top-level)
  PUBLISH_DIR     Export directory (default: /var/scratch/qzhang2/e2e-semantic-SLAM-publish)
  TARGET_REPO     GitHub repo in owner/name form (default: SlamMate/e2e-semantic-SLAM)
  BRANCH          Git branch to push (default: submap)
  GH_TOKEN        Alternative to GITHUB_TOKEN
  GIT_USER_NAME   Commit author name override
  GIT_USER_EMAIL  Commit author email override
EOF
}

if [[ "${1:-}" == "-h" || "${1:-}" == "--help" ]]; then
    usage
    exit 0
fi

# Commit message precedence: CLI args, then $COMMIT_MSG, then timestamped default.
if [[ $# -gt 0 ]]; then
    COMMIT_MSG="$*"
else
    COMMIT_MSG="${COMMIT_MSG:-Auto export: $(date +%F_%H%M%S)}"
fi

SOURCE_ROOT="${SOURCE_ROOT:-$(git rev-parse --show-toplevel 2>/dev/null || pwd)}"
PUBLISH_DIR="${PUBLISH_DIR:-/var/scratch/qzhang2/e2e-semantic-SLAM-publish}"
TARGET_REPO="${TARGET_REPO:-SlamMate/e2e-semantic-SLAM}"
BRANCH="${BRANCH:-submap}"
TOKEN="${GITHUB_TOKEN:-${GH_TOKEN:-}}"

if [[ -z "${TOKEN}" ]]; then
    echo "Missing GITHUB_TOKEN (or GH_TOKEN)." >&2
    exit 1
fi

# PUBLISH_DIR is rm -rf'ed below, so only allow disposable tmp/scratch locations.
case "${PUBLISH_DIR}" in
    /tmp/*|/var/tmp/*|/var/scratch/*) ;;
    *)
        echo "Refusing to use unsafe PUBLISH_DIR: ${PUBLISH_DIR}" >&2
        exit 1
        ;;
esac

if [[ ! -d "${SOURCE_ROOT}" ]]; then
    echo "SOURCE_ROOT does not exist: ${SOURCE_ROOT}" >&2
    exit 1
fi

rm -rf "${PUBLISH_DIR}"
mkdir -p "${PUBLISH_DIR}"

# Copy a single file into the export root; silently skipped when missing.
copy_file() {
    local rel="$1"
    if [[ -f "${SOURCE_ROOT}/${rel}" ]]; then
        rsync -a "${SOURCE_ROOT}/${rel}" "${PUBLISH_DIR}/"
    fi
}

# Copy a directory tree, excluding Python bytecode caches.
copy_dir() {
    local rel="$1"
    if [[ -d "${SOURCE_ROOT}/${rel}" ]]; then
        mkdir -p "${PUBLISH_DIR}/${rel}"
        rsync -a \
            --exclude '__pycache__/' \
            --exclude '*.pyc' \
            --exclude '*.pyo' \
            "${SOURCE_ROOT}/${rel}/" "${PUBLISH_DIR}/${rel}/"
    fi
}

# Allow-list of files and directories that make up the public export.
for file in \
    .gitignore \
    README.md \
    README_submap.md \
    README_cluster_migration.md \
    requirements.txt \
    setup.py \
    setup_env.sh \
    run_tum.sh \
    run_tum_top5.sh \
    eval_ate_scaled.py \
    get_ate.py \
    submap_handoff.md \
    publish_submap.sh
do
    copy_file "${file}"
done

for dir in cloud_opt config slam src; do
    copy_dir "${dir}"
done

# Sanity check: a README.md must have been exported before anything is pushed.
if [[ ! -f "${PUBLISH_DIR}/README.md" ]]; then
    echo "Export did not produce README.md; aborting." >&2
    exit 1
fi

git -C "${PUBLISH_DIR}" init >/dev/null
git -C "${PUBLISH_DIR}" checkout -b "${BRANCH}" >/dev/null

GIT_USER_NAME="${GIT_USER_NAME:-$(git config --global user.name 2>/dev/null || echo Cascade)}"
GIT_USER_EMAIL="${GIT_USER_EMAIL:-$(git config --global user.email 2>/dev/null || echo cascade@example.com)}"
git -C "${PUBLISH_DIR}" config user.name "${GIT_USER_NAME}"
git -C "${PUBLISH_DIR}" config user.email "${GIT_USER_EMAIL}"

git -C "${PUBLISH_DIR}" add .
if git -C "${PUBLISH_DIR}" diff --cached --quiet; then
    echo "No changes to commit in ${PUBLISH_DIR}."
else
    git -C "${PUBLISH_DIR}" commit -m "${COMMIT_MSG}"
fi

git -C "${PUBLISH_DIR}" remote remove origin >/dev/null 2>&1 || true
git -C "${PUBLISH_DIR}" remote add origin "https://github.com/${TARGET_REPO}.git"

# Temporary GIT_ASKPASS helper supplies the token so it never appears in the
# remote URL or on a process command line.
ASKPASS_HELPER="$(mktemp)"
cat >"${ASKPASS_HELPER}" <<'EOF'
#!/usr/bin/env bash
case "$1" in
    *Username*) printf '%s\n' 'x-access-token' ;;
    *Password*) printf '%s\n' "${GIT_TOKEN}" ;;
    *) printf '%s\n' '' ;;
esac
EOF
chmod 700 "${ASKPASS_HELPER}"
trap 'rm -f "${ASKPASS_HELPER}"' EXIT

GIT_TOKEN="${TOKEN}" GIT_ASKPASS="${ASKPASS_HELPER}" GIT_TERMINAL_PROMPT=0 \
    git -C "${PUBLISH_DIR}" push --force-with-lease -u origin "${BRANCH}"

echo "Published ${SOURCE_ROOT} -> ${TARGET_REPO} [${BRANCH}]"
echo "Export dir: ${PUBLISH_DIR}"
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/requirements.txt ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.9.1
2
+ torchvision==0.24.1
3
+ numpy==2.2.6
4
+ Pillow
5
+ huggingface_hub
6
+ safetensors
7
+ roma
8
+ gradio
9
+ matplotlib
10
+ tqdm
11
+ opencv-python
12
+ scipy
13
+ einops
14
+ trimesh
15
+ tensorboard
16
+ pyglet<2
17
+ viser
18
+ # gradio  (duplicate entry commented out; gradio is already listed above)
19
+ lpips
20
+ hydra-core
21
+ h5py
22
+ accelerate
23
+ transformers
24
+ scikit-learn
25
+ gsplat
26
+ evo
27
+ open3d
28
+ rerun-sdk
29
+ kornia
30
+ moviepy==1.0.3
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/run_tum.sh ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
#SBATCH --job-name=slam_tum
#SBATCH --output=slam_tum_%j.out
#SBATCH --partition=defq
#SBATCH --nodes=1
#SBATCH --gpus=1
#SBATCH --time=12:00:00

# Run slam/demo_submap.py over every TUM freiburg1 sequence with a given
# checkpoint; per-sequence results land under RESULT_ROOT/<run tag>/<sequence>.
source /var/scratch/qzhang2/miniconda3/etc/profile.d/conda.sh
conda activate SLAM-Former

module load cuda12.1/toolkit/12.1
# Try with a different cudnn or skip it if it's causing issues. Let's not load cudnn module as it might conflict with pytorch's built-in cudnn
# module load cuDNN/cuda12.1/9.1.0.70
export CC=/opt/ohpc/pub/compiler/gcc/9.4.0/bin/gcc
export CXX=/opt/ohpc/pub/compiler/gcc/9.4.0/bin/g++

# All knobs are env-overridable; the defaults reproduce the "full" inference mode.
TUM_DIR="/var/scratch/qzhang2/Feature-SLAM/datasets/tum"
CKPT_PATH="${CKPT_PATH:-/var/scratch/qzhang2/SLAM-Former/ckpt/checkpoint-10.pth.model}"
RESULT_ROOT="${RESULT_ROOT:-./tum_results_aligned}"
SUBMAP_INFERENCE_MODE="${SUBMAP_INFERENCE_MODE:-full}"
LOOP_MASK_MODE="${LOOP_MASK_MODE:-soft_all}"
SUBMAP_TRAIN_MODE="${SUBMAP_TRAIN_MODE:-top5_dual_queue}"
SUBMAP_RETRIEVAL_TOPK="${SUBMAP_RETRIEVAL_TOPK:-5}"
SUBMAP_FETCH_SOURCE="${SUBMAP_FETCH_SOURCE:-frontend}"
SUBMAP_DESCRIPTOR_SOURCE="${SUBMAP_DESCRIPTOR_SOURCE:-frontend}"
MAX_RECURSIVE_SUBMAPS="${MAX_RECURSIVE_SUBMAPS:-5}"
# Stable run tag: checkpoint parent dir + basename (extensions stripped) +
# a short sha1 of the full checkpoint path to disambiguate same-named files.
CKPT_NAME="$(basename "$CKPT_PATH")"
CKPT_NAME="${CKPT_NAME%.pth.model}"
CKPT_NAME="${CKPT_NAME%.pth}"
CKPT_PARENT="$(basename "$(dirname "$CKPT_PATH")")"
CKPT_HASH="$(printf '%s' "$CKPT_PATH" | sha1sum | cut -c1-8)"
RUN_TAG="${RUN_TAG:-${CKPT_PARENT}__${CKPT_NAME}__${CKPT_HASH}}"
OUT_DIR="${OUT_DIR:-${RESULT_ROOT}/${RUN_TAG}}"
mkdir -p "$OUT_DIR"

# "full" passes no extra flags; "top5" forwards the submap-retrieval options.
case "$SUBMAP_INFERENCE_MODE" in
    full)
        DEMO_ARGS=()
        ;;
    top5)
        DEMO_ARGS=(
            --loop_mask_mode "$LOOP_MASK_MODE"
            --submap_train_mode "$SUBMAP_TRAIN_MODE"
            --submap_retrieval_topk "$SUBMAP_RETRIEVAL_TOPK"
            --submap_fetch_source "$SUBMAP_FETCH_SOURCE"
            --submap_descriptor_source "$SUBMAP_DESCRIPTOR_SOURCE"
            --max_recursive_submaps "$MAX_RECURSIVE_SUBMAPS"
        )
        ;;
    *)
        echo "Unknown SUBMAP_INFERENCE_MODE=$SUBMAP_INFERENCE_MODE (expected full or top5)" >&2
        exit 1
        ;;
esac

echo "Checkpoint: $CKPT_PATH"
echo "Run tag: $RUN_TAG"
echo "Output root: $OUT_DIR"
echo "Submap inference mode: $SUBMAP_INFERENCE_MODE"

for seq in "$TUM_DIR"/rgbd_dataset_freiburg1_*; do
    if [ -d "$seq/rgb" ]; then
        seq_name=$(basename "$seq")
        echo "======================================"
        echo "Running on $seq_name..."
        echo "======================================"

        # The demo expects the image folder which contains images. For TUM it is the 'rgb' folder.
        python slam/demo_submap.py \
            --ckpt_path "$CKPT_PATH" \
            --image_folder "$seq/rgb" \
            --output_dir "$OUT_DIR/$seq_name" \
            --target_size 518 \
            "${DEMO_ARGS[@]}"
    fi
done
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/run_tum_top5.sh ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
#SBATCH --job-name=slam_tum_top5
#SBATCH --output=slam_tum_top5_%j.out
#SBATCH --partition=defq
#SBATCH --nodes=1
#SBATCH --gpus=1
#SBATCH --time=12:00:00

# Top-5 submap-retrieval variant of run_tum.sh: runs slam/demo_submap.py over
# every TUM freiburg1 sequence with the retrieval flags always enabled.
source /var/scratch/qzhang2/miniconda3/etc/profile.d/conda.sh
conda activate SLAM-Former

module load cuda12.1/toolkit/12.1
# Try with a different cudnn or skip it if it's causing issues. Let's not load cudnn module as it might conflict with pytorch's built-in cudnn
# module load cuDNN/cuda12.1/9.1.0.70
export CC=/opt/ohpc/pub/compiler/gcc/9.4.0/bin/gcc
export CXX=/opt/ohpc/pub/compiler/gcc/9.4.0/bin/g++

# All knobs are env-overridable.
TUM_DIR="/var/scratch/qzhang2/Feature-SLAM/datasets/tum"
CKPT_PATH="${CKPT_PATH:-/var/scratch/qzhang2/SLAM-Former/checkpoints/local_cluster_nv24_sub6/submap_only_pseudo_gt_high_recall_smoke/paper_local_submap_only_pseudo_gt_high_recall_smoke_nv24_sub6/checkpoint-last.pth}"
RESULT_ROOT="${RESULT_ROOT:-./tum_results_aligned_top5}"
SUBMAP_INFERENCE_MODE="${SUBMAP_INFERENCE_MODE:-top5}"
LOOP_MASK_MODE="${LOOP_MASK_MODE:-soft_all}"
SUBMAP_TRAIN_MODE="${SUBMAP_TRAIN_MODE:-top5_dual_queue}"
SUBMAP_RETRIEVAL_TOPK="${SUBMAP_RETRIEVAL_TOPK:-5}"
SUBMAP_FETCH_SOURCE="${SUBMAP_FETCH_SOURCE:-frontend}"
SUBMAP_DESCRIPTOR_SOURCE="${SUBMAP_DESCRIPTOR_SOURCE:-frontend}"
MAX_RECURSIVE_SUBMAPS="${MAX_RECURSIVE_SUBMAPS:-5}"
# Stable run tag: checkpoint parent dir + basename (extensions stripped) +
# inference mode + a short sha1 of the full checkpoint path.
CKPT_NAME="$(basename "$CKPT_PATH")"
CKPT_NAME="${CKPT_NAME%.pth.model}"
CKPT_NAME="${CKPT_NAME%.pth}"
CKPT_PARENT="$(basename "$(dirname "$CKPT_PATH")")"
CKPT_HASH="$(printf '%s' "$CKPT_PATH" | sha1sum | cut -c1-8)"
RUN_TAG="${RUN_TAG:-${CKPT_PARENT}__${CKPT_NAME}__${SUBMAP_INFERENCE_MODE}__${CKPT_HASH}}"
OUT_DIR="${OUT_DIR:-${RESULT_ROOT}/${RUN_TAG}}"
mkdir -p "$OUT_DIR"

# Retrieval flags forwarded to slam/demo_submap.py on every sequence.
DEMO_ARGS=(
    --loop_mask_mode "$LOOP_MASK_MODE"
    --submap_train_mode "$SUBMAP_TRAIN_MODE"
    --submap_retrieval_topk "$SUBMAP_RETRIEVAL_TOPK"
    --submap_fetch_source "$SUBMAP_FETCH_SOURCE"
    --submap_descriptor_source "$SUBMAP_DESCRIPTOR_SOURCE"
    --max_recursive_submaps "$MAX_RECURSIVE_SUBMAPS"
)

echo "Checkpoint: $CKPT_PATH"
echo "Run tag: $RUN_TAG"
echo "Output root: $OUT_DIR"
echo "Submap inference mode: $SUBMAP_INFERENCE_MODE"
echo "Comparative baseline: run_tum.sh with SUBMAP_INFERENCE_MODE=full and the same CKPT_PATH"

for seq in "$TUM_DIR"/rgbd_dataset_freiburg1_*; do
    if [ -d "$seq/rgb" ]; then
        seq_name=$(basename "$seq")
        echo "======================================"
        echo "Running on $seq_name..."
        echo "======================================"

        # demo_submap.py takes the sequence's 'rgb' image folder directly.
        python slam/demo_submap.py \
            --ckpt_path "$CKPT_PATH" \
            --image_folder "$seq/rgb" \
            --output_dir "$OUT_DIR/$seq_name" \
            --target_size 518 \
            "${DEMO_ARGS[@]}"
    fi
done
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/setup.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
from setuptools import setup, find_packages

# Minimal packaging shim so `pip install -e .` exposes the repo as 'slamformer'.
setup(
    name='slamformer',
    version='0.1.0',
    description='SLAM-Former: Putting SLAM into One Transformer.',
    # NOTE(review): find_packages() include patterns match dotted package
    # names; 'src/slamformer' contains a path separator and likely never
    # matches — confirm whether package_dir={'': 'src'} plus 'slamformer*'
    # was intended.
    packages=find_packages(include=['evals', 'evals.*', 'src/slamformer', 'src/slamformer.*']),
)
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/setup_env.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
#SBATCH --job-name=setup_slam
#SBATCH --output=setup_slam.out
#SBATCH --partition=defq
#SBATCH --nodes=1
#SBATCH --gpus=1
#SBATCH --time=02:00:00

# One-off environment bootstrap job: activate the SLAM-Former conda env,
# load the cluster CUDA/cuDNN toolchain, then install this repo editable.
source /var/scratch/qzhang2/miniconda3/etc/profile.d/conda.sh
conda activate SLAM-Former

module load cuda12.1/toolkit/12.1
module load cuDNN/cuda12.1/9.1.0.70
# Build any native extensions with the cluster's GCC 9.4 toolchain.
export CC=/opt/ohpc/pub/compiler/gcc/9.4.0/bin/gcc
export CXX=/opt/ohpc/pub/compiler/gcc/9.4.0/bin/g++

pip install -r requirements.txt
pip install -e .
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/__init__.py ADDED
File without changes
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/audit_dataset_num_views.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import csv
5
+ import os
6
+ from pathlib import Path
7
+ from typing import Callable
8
+
9
+ import numpy as np
10
+
11
+ try:
12
+ import h5py
13
+ except Exception:
14
+ h5py = None
15
+
16
+
17
+ README_STATUS = {
18
+ "ARKitScenes": "released",
19
+ "ScanNet": "released",
20
+ "ScanNet++": "coming_soon_in_readme",
21
+ "HyperSim": "released_separate_hf",
22
+ "BlendedMVS": "coming_soon_in_readme",
23
+ "MegaDepth": "coming_soon_in_readme",
24
+ "MVS-Synth": "released",
25
+ }
26
+
27
+ THRESHOLDS = [24, 32, 48, 64]
28
+
29
+
30
def retention_str(count: int | None, total: int | None) -> str:
    """Render ``count/total`` with a one-decimal percentage, or "N/A" when undefined."""
    undefined = count is None or total is None or total == 0
    if undefined:
        return "N/A"
    pct = 100.0 * count / total
    return f"{count}/{total} ({pct:.1f}%)"
34
+
35
+
36
def summarize_caps(caps: list[int], thresholds: list[int]) -> dict:
    """Summarize per-unit view-count caps.

    Returns a dict with the number of units, the min ("strict cap" — the cap
    usable without skipping any unit), median and max caps, and per-threshold
    counts of units whose cap meets or exceeds that threshold. Empty input
    yields zero units and ``None`` placeholders throughout.
    """
    if not caps:
        empty_counts = {t: None for t in thresholds}
        return {
            "unit_count": 0,
            "strict_cap_no_skip": None,
            "median_cap": None,
            "max_cap": None,
            "threshold_counts": empty_counts,
        }
    values = np.asarray(caps, dtype=np.int64)
    meets = {t: int(np.count_nonzero(values >= t)) for t in thresholds}
    return {
        "unit_count": int(values.size),
        "strict_cap_no_skip": int(values.min()),
        "median_cap": int(np.median(values)),
        "max_cap": int(values.max()),
        "threshold_counts": meets,
    }
53
+
54
+
55
def audit_arkit(root: Path, thresholds: list[int]) -> dict:
    """Audit ARKitScenes: per-scene num_views caps from local metadata.

    A scene's cap is ``min(scene_len, largest image-collection group + 1)``,
    i.e. the biggest multi-view tuple the scene can serve without repeats.

    Args:
        root: processed ARKitScenes root (expects Training/all_metadata.npz).
        thresholds: num_views thresholds to count retention for.

    Returns:
        summarize_caps() stats plus availability/description fields, or an
        unavailable marker when the index file is missing.
    """
    meta_root = root / "Training"
    all_meta = meta_root / "all_metadata.npz"
    if not all_meta.is_file():
        return {"available": False, "notes": "missing Training/all_metadata.npz"}
    scene_caps = []
    with np.load(all_meta) as index:
        scenes = index["scenes"]
    for scene in scenes:
        scene_dir = meta_root / str(scene)
        meta_path = scene_dir / "new_scene_metadata.npz"
        if not scene_dir.is_dir() or not meta_path.is_file():
            continue
        # Distinct name: the original shadowed the outer npz handle with a
        # second `data`, which is confusing and error-prone.
        with np.load(meta_path, allow_pickle=True) as meta:
            scene_len = len(meta["images"])
            # +1 accounts for the reference image on top of its group.
            best_group_len = max(
                (len(group) + 1 for group in meta["image_collection"].item().values()),
                default=0,
            )
        scene_caps.append(min(scene_len, best_group_len))
    stats = summarize_caps(scene_caps, thresholds)
    stats.update(
        {
            "available": True,
            "unit_name": "scene",
            "constraint": "min(scene_len, best_image_collection_len_per_scene)",
            "notes": "strict cap is dominated by a few very short scenes; higher num_views still work with scene skipping",
        }
    )
    return stats
85
+
86
+
87
def audit_scannetpp(root: Path, thresholds: list[int]) -> dict:
    """Audit ScanNet++: per-scene caps limited by pairing DSLR/iPhone videos.

    For each reference image, the usable group size is capped by the frame
    count of the paired capture (a "frame*" reference is paired against the
    DSLR list and any other against the iPhone list); the scene cap is the
    best such value across references.

    Args:
        root: processed ScanNet++ root (expects all_metadata.npz).
        thresholds: num_views thresholds to count retention for.

    Returns:
        summarize_caps() stats plus availability/description fields.
    """
    all_meta = root / "all_metadata.npz"
    if not all_meta.is_file():
        return {"available": False, "notes": "missing all_metadata.npz"}
    with np.load(all_meta) as index:
        scenes = index["scenes"]
    scene_caps = []
    for scene in scenes:
        scene_dir = root / str(scene)
        meta_path = scene_dir / "new_scene_metadata.npz"
        images_dir = scene_dir / "images"
        if not scene_dir.is_dir() or not meta_path.is_file() or not images_dir.is_dir():
            continue
        # Distinct name: the original shadowed the outer npz handle with a
        # second `data`.
        with np.load(meta_path, allow_pickle=True) as meta:
            images = meta["images"]
            # Metadata stores extension-less names; strip the 4-char suffix.
            imgs_on_disk = {name[:-4] for name in os.listdir(images_dir)}
            dslr_ids = [
                i
                for i in range(len(images))
                if images[i].startswith("DSC") and images[i] in imgs_on_disk
            ]
            iphone_ids = [
                i
                for i in range(len(images))
                if images[i].startswith("frame") and images[i] in imgs_on_disk
            ]
            best_cap = 0
            for ref_id, group in meta["image_collection"].item().items():
                group_len = len(group) + 1
                # NOTE(review): an iPhone reference ("frame*") is paired
                # against the DSLR video and vice versa — cross-pairing kept
                # exactly as written originally; confirm it is intentional.
                video_len = len(dslr_ids) if images[ref_id].startswith("frame") else len(iphone_ids)
                best_cap = max(best_cap, min(group_len, video_len))
        if best_cap > 0:
            scene_caps.append(best_cap)
    stats = summarize_caps(scene_caps, thresholds)
    stats.update(
        {
            "available": True,
            "unit_name": "scene",
            "constraint": "best min(group_len, paired_video_len) per scene",
            "notes": "README still marks this split as coming soon, but local processed data is already present",
        }
    )
    return stats
130
+
131
+
132
def audit_scannet(root: Path, thresholds: list[int]) -> dict:
    """Audit ScanNet: each training scene is capped only by its image count."""
    scans_train = root / "scans_train"
    if not scans_train.is_dir():
        return {"available": False, "notes": "missing scans_train"}
    scene_caps = []
    for scene_name in sorted(os.listdir(scans_train)):
        if not scene_name.startswith("scene"):
            continue
        meta_path = scans_train / scene_name / "new_scene_metadata.npz"
        if not meta_path.is_file():
            continue
        with np.load(meta_path, allow_pickle=True) as meta:
            scene_caps.append(len(meta["images"]))
    stats = summarize_caps(scene_caps, thresholds)
    stats.update(
        {
            "available": True,
            "unit_name": "scene",
            "constraint": "scene_len",
            "notes": "",
        }
    )
    return stats
155
+
156
+
157
def audit_hypersim(root: Path, thresholds: list[int]) -> dict:
    """Audit HyperSim: each camera subscene is capped by its PNG frame count."""
    if not root.is_dir():
        return {"available": False, "notes": "missing hypersim root"}
    scene_caps = []
    for scene_name in sorted(os.listdir(root)):
        scene_dir = root / scene_name
        if not scene_dir.is_dir():
            continue
        for subscene_name in sorted(os.listdir(scene_dir)):
            subscene_dir = scene_dir / subscene_name
            if not subscene_dir.is_dir():
                continue
            frame_count = sum(1 for name in os.listdir(subscene_dir) if name.endswith(".png"))
            if frame_count:
                scene_caps.append(frame_count)
    stats = summarize_caps(scene_caps, thresholds)
    stats.update(
        {
            "available": True,
            "unit_name": "subscene",
            "constraint": "scene_len",
            "notes": "",
        }
    )
    return stats
182
+
183
+
184
def audit_mvs_synth(root: Path, thresholds: list[int]) -> dict:
    """Audit MVS-Synth: each scene is capped by its JPEG frame count."""
    if not root.is_dir():
        return {"available": False, "notes": "missing processed_mvs_synth root"}
    scene_caps = []
    for scene_name in sorted(os.listdir(root)):
        rgb_dir = root / scene_name / "rgb"
        if not rgb_dir.is_dir():
            continue
        frame_count = sum(1 for name in os.listdir(rgb_dir) if name.endswith(".jpg"))
        if frame_count > 0:
            scene_caps.append(frame_count)
    stats = summarize_caps(scene_caps, thresholds)
    stats.update(
        {
            "available": True,
            "unit_name": "scene",
            "constraint": "scene_len",
            "notes": "",
        }
    )
    return stats
205
+
206
+
207
def audit_megadepth(root: Path, thresholds: list[int]) -> dict:
    """Audit MegaDepth: the loader serves fixed 64-image sets.

    Every training set contributes a constant cap of 64 because the loader
    slices exactly 64 image ids per set. Sets whose scene name starts with
    "0015" or "0022" are excluded from the training count.

    Args:
        root: processed MegaDepth root (expects megadepth_sets_64.npz).
        thresholds: num_views thresholds to count retention for.

    Returns:
        summarize_caps() stats plus availability/description fields.
    """
    sets_path = root / "megadepth_sets_64.npz"
    if not sets_path.is_file():
        return {"available": False, "notes": "missing megadepth_sets_64.npz; code hard cap is 64"}
    with np.load(sets_path, allow_pickle=True) as data:
        scenes = data["scenes"]
        sets = data["sets"]
    valid_scene = np.array([not str(scene).startswith(("0015", "0022")) for scene in scenes])
    valid_scene_ids = np.nonzero(valid_scene)[0]
    # np.isin replaces np.in1d, which is deprecated and removed in NumPy 2.0.
    train_mask = np.isin(sets[:, 0], valid_scene_ids)
    caps = [64] * int(train_mask.sum())
    stats = summarize_caps(caps, thresholds)
    stats.update(
        {
            "available": True,
            "unit_name": "set",
            "constraint": "fixed_64_image_set",
            "notes": "hard code cap is 64 because the loader slices image_idxs = sets[idx][1:65]",
        }
    )
    return stats
228
+
229
+
230
def build_adjacency_list(score_matrix: np.ndarray, thresh: float = 0.2) -> list[list[int]]:
    """Build a directed adjacency list from a pairwise score matrix.

    An edge row -> col exists wherever ``score_matrix[row, col] > thresh``
    (strictly greater). The input matrix is not modified.
    """
    edges = np.asarray(score_matrix) > thresh
    return [[int(col) for col in np.nonzero(row)[0]] for row in edges]
238
+
239
+
240
def reachable_count(adjacency: list[list[int]], start_index: int) -> int:
    """Count nodes reachable from ``start_index`` (inclusive) via iterative DFS."""
    seen = {start_index}
    frontier = [start_index]
    while frontier:
        current = frontier.pop()
        for neighbor in adjacency[current]:
            if neighbor not in seen:
                seen.add(neighbor)
                frontier.append(neighbor)
    return len(seen)
250
+
251
+
252
def audit_blendedmvs(root: Path, thresholds: list[int]) -> dict:
    """Audit BlendedMVS: per-reference caps from the sparse overlap graph.

    Each reference image's cap is the number of unique images reachable in
    the thresholded overlap graph (what allow_repeat=false can sample).
    """
    overlap_path = root / "new_overlap.h5"
    if not overlap_path.is_file():
        return {"available": False, "notes": "missing new_overlap.h5"}
    if h5py is None:
        return {"available": False, "notes": "h5py is unavailable"}
    ref_caps = []
    with h5py.File(overlap_path, "r") as handle:
        for scene_key in handle.keys():
            group = handle[scene_key]
            indices = group["indices"][:]
            values = group["values"][:]
            # Densify the sparse (row, col, value) overlap scores.
            dense = np.zeros(group.attrs["shape"], dtype=np.float32)
            dense[indices[0], indices[1]] = values
            adjacency = build_adjacency_list(dense)
            ref_caps.extend(reachable_count(adjacency, node) for node in range(len(adjacency)))
    stats = summarize_caps(ref_caps, thresholds)
    stats.update(
        {
            "available": True,
            "unit_name": "reference",
            "constraint": "reachable_unique_images_per_reference",
            "notes": "for allow_repeat=false this is the reachable unique-image cap from the overlap graph",
        }
    )
    return stats
279
+
280
+
281
def build_dataset_specs(data_root: Path) -> list[tuple[str, Path, Callable[[Path, list[int]], dict]]]:
    """Map each paper training dataset to its local path and audit function."""
    specs = [
        ("ARKitScenes", "processed_arkitscenes", audit_arkit),
        ("ScanNet", "processed_scannet", audit_scannet),
        ("ScanNet++", "processed_scannetpp", audit_scannetpp),
        ("HyperSim", "hypersim", audit_hypersim),
        ("BlendedMVS", "processed_blendedmvs", audit_blendedmvs),
        ("MegaDepth", "processed_megadepth", audit_megadepth),
        ("MVS-Synth", "processed_mvs_synth", audit_mvs_synth),
    ]
    return [(name, data_root / subdir, audit_fn) for name, subdir, audit_fn in specs]
291
+
292
+
293
def render_markdown(rows: list[dict], output_md: Path, thresholds: list[int]) -> None:
    """Write the audit rows as a Markdown table to ``output_md``.

    Parent directories are created as needed. Exactly the first four
    thresholds are rendered as ">=N" retention columns.
    """
    output_md.parent.mkdir(parents=True, exist_ok=True)
    preamble = [
        "# Paper Training Dataset num_views Audit",
        "",
        "This table is aligned with the paper's training dataset list in `refer/arXiv-2509.16909v1/sec/4_exp.tex` and the release status in `README.md`.",
        "",
        "`StrictCapNoSkip` means the largest `num_views` that keeps every current local scene/reference usable under `allow_repeat=false`.",
        "Using a larger `num_views` can still work, but shorter scenes will be skipped by the dataset loader.",
        "",
        "| Dataset | READMEStatus | LocalAvailable | LocalPath | Unit | UnitCount | StrictCapNoSkip | MedianCap | MaxCap | >=24 | >=32 | >=48 | >=64 | Constraint | Notes |",
        "|---|---|---|---|---:|---:|---:|---:|---:|---|---|---|---|---|---|",
    ]
    body = []
    for row in rows:
        cells = [
            row["dataset"],
            row["readme_status"],
            row["local_available"],
            f"`{row['local_path']}`",
            row["unit_name"],
            row["unit_count"],
            row["strict_cap_no_skip"],
            row["median_cap"],
            row["max_cap"],
            row[f"ge_{thresholds[0]}"],
            row[f"ge_{thresholds[1]}"],
            row[f"ge_{thresholds[2]}"],
            row[f"ge_{thresholds[3]}"],
            row["constraint"],
            row["notes"],
        ]
        body.append("| " + " | ".join(str(cell) for cell in cells) + " |")
    output_md.write_text("\n".join(preamble + body) + "\n")
328
+
329
+
330
def render_csv(rows: list[dict], output_csv: Path, thresholds: list[int]) -> None:
    """Write the audit rows as CSV to ``output_csv``.

    Parent directories are created as needed. Exactly the first four
    thresholds contribute ``ge_{t}_count`` columns, mirroring the Markdown
    table layout.
    """
    output_csv.parent.mkdir(parents=True, exist_ok=True)
    fieldnames = [
        "dataset",
        "readme_status",
        "local_available",
        "local_path",
        "unit_name",
        "unit_count",
        "strict_cap_no_skip",
        "median_cap",
        "max_cap",
        f"ge_{thresholds[0]}_count",
        f"ge_{thresholds[1]}_count",
        f"ge_{thresholds[2]}_count",
        f"ge_{thresholds[3]}_count",
        "constraint",
        "notes",
    ]
    with output_csv.open("w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        for row in rows:
            writer.writerow({field: row[field] for field in fieldnames})
372
+
373
+
374
def main() -> None:
    """CLI entry point: audit every paper training dataset and write reports."""
    repo_root = Path(__file__).resolve().parents[1]
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-root", type=Path, default=repo_root / "data" / "train")
    parser.add_argument("--output-md", type=Path, default=repo_root / "reports" / "paper_dataset_num_views.md")
    parser.add_argument("--output-csv", type=Path, default=repo_root / "reports" / "paper_dataset_num_views.csv")
    args = parser.parse_args()

    rows = []
    for dataset_name, local_path, audit_fn in build_dataset_specs(args.data_root):
        stats = audit_fn(local_path, THRESHOLDS)
        row = {
            "dataset": dataset_name,
            "readme_status": README_STATUS[dataset_name],
            "local_available": "yes" if stats.get("available", False) else "no",
            "local_path": str(local_path),
        }
        # Stat fields fall back to N/A when the dataset is unavailable locally.
        for field in ("unit_name", "unit_count", "strict_cap_no_skip", "median_cap", "max_cap", "constraint"):
            row[field] = stats.get(field, "N/A")
        row["notes"] = stats.get("notes", "")
        threshold_counts = stats.get("threshold_counts", {})
        for threshold in THRESHOLDS:
            count = threshold_counts.get(threshold)
            row[f"ge_{threshold}"] = retention_str(count, stats.get("unit_count"))
            row[f"ge_{threshold}_count"] = "N/A" if count is None else count
        rows.append(row)

    render_markdown(rows, args.output_md, THRESHOLDS)
    render_csv(rows, args.output_csv, THRESHOLDS)
    print(args.output_md)
    print(args.output_csv)


if __name__ == "__main__":
    main()
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/batched_dynamic_router.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BatchedDynamicSubmapRouter: Differentiable dynamic submap boundary predictor.
3
+
4
+ Key design decisions:
5
+ - Fully vectorized (Fix #4): NO Python-level conditionals on boundary_flag.
6
+ Every GPU always executes the full compute graph (descriptor + retrieval +
7
+ backendT). The boundary decision is applied as a differentiable gating
8
+ multiplier so DDP AllReduce never deadlocks.
9
+ - Masked pooling (Fix #3): valid_mask [B, max_K] prevents zero-padding from
10
+ injecting dead gradients / NaN into descriptors.
11
+ - Gumbel-Softmax temperature annealing (Fix #5): τ decays from tau_start
12
+ to tau_end over training via cosine schedule.
13
+
14
+ No original source files are modified.
15
+ """
16
+
17
+ import math
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from typing import Optional
22
+
23
+
24
class BatchedDynamicSubmapRouter(nn.Module):
    """Differentiable boundary predictor with vectorized soft gating.

    At every frame step the router:
      1. Appends the new token to a padded accumulation buffer.
      2. Predicts a boundary probability via an MLP + Gumbel-Softmax STE.
      3. **Always** computes descriptors and runs the backend for the full
         batch (no Python branching on the flag), so DDP AllReduce never
         deadlocks.
      4. Soft-gates the result: ``final = flag * backend + (1-flag) * frontend``.
      5. Soft-resets the accumulation buffer using the gate.

    Args:
        token_dim: 2C — full token feature dimension (default 2048).
        boundary_hidden_dim: MLP hidden size (default 512).
        tau_start: initial Gumbel-Softmax temperature.
        tau_end: final Gumbel-Softmax temperature.
        max_K: maximum frames per submap (buffer size).
    """

    def __init__(
        self,
        token_dim: int = 2048,
        boundary_hidden_dim: int = 512,
        tau_start: float = 5.0,
        tau_end: float = 0.1,
        max_K: int = 20,
    ):
        super().__init__()
        self.token_dim = token_dim
        self.tau_start = tau_start
        self.tau_end = tau_end
        self.max_K = max_K

        # Boundary predictor: takes [prev_token || curr_token] -> logit.
        self.boundary_predictor = nn.Sequential(
            nn.Linear(2 * token_dim, boundary_hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(boundary_hidden_dim, 1),
        )

    # ── temperature annealing ─────────────────────────────
    def get_tau(self, progress: float) -> float:
        """Cosine-annealed Gumbel-Softmax temperature.

        Args:
            progress: training progress in [0, 1] (clamped to that range).

        Returns:
            Current τ: tau_start at progress=0, tau_end at progress=1.
        """
        progress = min(max(progress, 0.0), 1.0)
        return self.tau_end + 0.5 * (self.tau_start - self.tau_end) * (
            1.0 + math.cos(math.pi * progress)
        )

    # ── boundary prediction (vectorized) ─────────────────
    def predict_boundary(
        self,
        prev_token: torch.Tensor,
        curr_token: torch.Tensor,
        tau: float = 1.0,
    ) -> torch.Tensor:
        """Predict a hard boundary flag for each element in the batch.

        Uses the Straight-Through Estimator via ``F.gumbel_softmax``:
        forward pass yields hard 0/1, backward pass gets soft gradients.

        Args:
            prev_token: [B, D] — pooled previous-frame token.
            curr_token: [B, D] — pooled current-frame token.
            tau: Gumbel-Softmax temperature.

        Returns:
            boundary_flag: [B, 1] — 0.0 or 1.0 (differentiable via STE).
        """
        combined = torch.cat([prev_token, curr_token], dim=-1)  # [B, 2D]
        logit = self.boundary_predictor(combined)               # [B, 1]

        # Stack [logit, -logit] -> [B, 2] (cut / no-cut classes).
        logits_2 = torch.cat([logit, -logit], dim=-1)           # [B, 2]
        one_hot = F.gumbel_softmax(logits_2, tau=tau, hard=True, dim=-1)
        # one_hot[:, 0] == 1 means "cut"; one_hot[:, 1] == 1 means "no cut".
        boundary_flag = one_hot[:, :1]                          # [B, 1]
        return boundary_flag

    # ── masked pooling ───────────────────────────────────
    @staticmethod
    def masked_pool(
        accum_tokens: torch.Tensor,
        valid_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Safe masked average pooling over the accumulation buffer.

        Zero-padded (invalid) frames are excluded so they cannot inject
        dead gradients or NaNs into the descriptor.

        Args:
            accum_tokens: [B, max_K, P, C] — padded token buffer.
            valid_mask: [B, max_K] — True where frames are valid.

        Returns:
            pooled: [B, C] — mean over valid frames of the per-frame patch
            mean; all-invalid rows pool to zeros (denominator clamped to 1).
        """
        # Pool patches within each frame first: [B, max_K, C].
        frame_pooled = accum_tokens.mean(dim=2)
        mask_f_1d = valid_mask.float().unsqueeze(-1)            # [B, max_K, 1]
        numerator = (frame_pooled * mask_f_1d).sum(dim=1)       # [B, C]
        # clamp avoids 0/0 for sequences with no valid frame yet.
        denominator = mask_f_1d.sum(dim=1).clamp(min=1.0)       # [B, 1]
        return numerator / denominator                          # [B, C]

    # ── vectorized soft reset ────────────────────────────
    @staticmethod
    def soft_reset(
        accum_tokens: torch.Tensor,
        valid_mask: torch.Tensor,
        accum_len: torch.Tensor,
        boundary_flag: torch.Tensor,
    ):
        """Vectorized buffer reset driven by the boundary gate.

        Elements where ``boundary_flag == 1`` are zeroed out; elements where
        it is 0 are kept unchanged. No Python-level branching.

        Args:
            accum_tokens: [B, max_K, P, C].
            valid_mask: [B, max_K] bool.
            accum_len: [B] long — current length per sequence.
            boundary_flag: [B, 1] — 0.0 or 1.0.

        Returns:
            (accum_tokens, valid_mask, accum_len) — updated tensors.
        """
        keep = 1.0 - boundary_flag                               # [B, 1]

        # Tokens: multiply by keep (broadcasts over max_K, P, C).
        accum_tokens = accum_tokens * keep.unsqueeze(-1).unsqueeze(-1)

        # BUGFIX: keep is [B, 1], which broadcasts correctly against
        # valid_mask [B, max_K]. The previous keep.squeeze(-1) produced [B],
        # which broadcasts along the *last* axis and raises a shape error
        # (or silently misaligns) whenever B != max_K.
        valid_mask = valid_mask & keep.bool()

        # Length: reset to 0 where the boundary triggered.
        accum_len = accum_len * keep.squeeze(-1).long()

        return accum_tokens, valid_mask, accum_len

    # ── full vectorized step ─────────────────────────────
    def step(
        self,
        new_token: torch.Tensor,
        prev_token: torch.Tensor,
        accum_tokens: torch.Tensor,
        valid_mask: torch.Tensor,
        accum_len: torch.Tensor,
        tau: float = 1.0,
    ):
        """Execute one frame step for the full batch (vectorized).

        Args:
            new_token: [B, P, C] — current frame token.
            prev_token: [B, D] — pooled previous-frame representation.
            accum_tokens: [B, max_K, P, C] — accumulation buffer.
            valid_mask: [B, max_K] — bool mask.
            accum_len: [B] long — frames accumulated so far.
            tau: current Gumbel temperature.

        Returns:
            boundary_flag: [B, 1] — hard gate.
            curr_desc: [B, C] — pooled descriptor (for loop retrieval).
            accum_tokens / valid_mask / accum_len: updated buffers.
        """
        B = new_token.shape[0]

        # 1. Append the token to the buffer.
        # Clamp the write index to stay in-bounds when the buffer is full.
        write_idx = accum_len.clamp(max=self.max_K - 1)          # [B]
        for b in range(B):
            # NOTE: this loop runs over a *fixed* batch size (identical on
            # all GPUs), so it does NOT cause DDP divergence.
            idx = write_idx[b].item()
            accum_tokens[b, idx] = new_token[b]
            valid_mask[b, idx] = True
        accum_len = (accum_len + 1).clamp(max=self.max_K)

        # 2. Boundary prediction (always for ALL B).
        curr_pooled = new_token.mean(dim=1)                      # [B, C]
        boundary_flag = self.predict_boundary(prev_token, curr_pooled, tau=tau)

        # 3. Descriptor (always for ALL B, masked pooling).
        curr_desc = self.masked_pool(accum_tokens, valid_mask)   # [B, C]

        # 4. Soft reset (vectorized).
        accum_tokens, valid_mask, accum_len = self.soft_reset(
            accum_tokens, valid_mask, accum_len, boundary_flag
        )

        return boundary_flag, curr_desc, accum_tokens, valid_mask, accum_len

    # ── buffer initialization helper ─────────────────────
    def init_buffers(
        self,
        batch_size: int,
        P: int,
        C: int,
        device: torch.device,
    ):
        """Create fresh accumulation buffers for a batch.

        Returns:
            accum_tokens: [B, max_K, P, C] zeros.
            valid_mask: [B, max_K] False.
            accum_len: [B] zeros (long).
        """
        accum_tokens = torch.zeros(batch_size, self.max_K, P, C, device=device)
        valid_mask = torch.zeros(batch_size, self.max_K, dtype=torch.bool, device=device)
        accum_len = torch.zeros(batch_size, dtype=torch.long, device=device)
        return accum_tokens, valid_mask, accum_len
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/demo.py ADDED
@@ -0,0 +1,540 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os,sys
2
+ from collections import OrderedDict
3
+ import torch
4
+ import numpy as np
5
+ import re
6
+ import cv2
7
+ import glob
8
+ import argparse
9
+
10
+ import time
11
+ import open3d as o3d
12
+ from rich import print
13
+ import matplotlib.pyplot as plt
14
+ from scipy.spatial.transform import Rotation as R
15
+
16
+ import rerun as rr
17
+ import rerun.blueprint as rrb
18
+
19
+ sys.path.append('src')
20
+ from slamformer.models.slamformer import SLAMFormer
21
+
22
+ current_directory = os.path.dirname(os.path.abspath(__file__))
23
+
24
+ sys.path.append(current_directory+'/../')
25
+
26
+ import slam.utils as utils
27
+ from slam.rerun_helper import log_camera, log_window
28
+
29
def strip_module(state_dict):
    """Remove a leading 'module.' prefix (typically added by DDP) from keys.

    Args:
        state_dict (dict): Checkpoint state_dict whose keys may carry the
            'module.' prefix from a DistributedDataParallel-wrapped model.

    Returns:
        OrderedDict: Same values, with any leading 'module.' removed.
    """
    prefix = "module."
    stripped = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        stripped[key] = value
    return stripped
42
+
43
+
44
+ class SLAM:
45
    def __init__(
        self,
        outdir='output/tmp',
        kf_th=0.1,
        bn_every=10,
        vis=False,
        save_gmem=True,
        ckpt_path='path/to/ckpt.pth',
        target_size=518,
        retention_ratio=0.5
    ):
        """Build the SLAM system: load the SLAMFormer model and reset state.

        Args:
            outdir: output directory for results.
            kf_th: keyframe threshold on inter-frame camera translation
                (see kf_detect).
            bn_every: run the backend once every `bn_every` keyframes.
            vis: enable live rerun.io visualization.
            save_gmem: flag stored here; consumed elsewhere — TODO confirm.
            ckpt_path: checkpoint file loaded into the model.
            target_size: resize target passed to utils.load_image.
            retention_ratio: token retention ratio passed to SLAMFormer.

        NOTE(review): indentation in this method was reconstructed from a
        diff rendering; grouping of the trailing vis attributes is
        best-effort.
        """
        self.outdir = outdir
        self.kf_th=kf_th
        self.save_gmem = save_gmem
        self.bn_every=bn_every
        self.vis = vis
        self.ckpt_path = ckpt_path
        self.target_size = target_size

        # Timing accumulators (seconds): per-keyframe frontend, keyframe
        # detection, and backend optimization.
        self.times = []
        self.kf_time = []
        self.backend_time = []

        # model params
        self.model = SLAMFormer(retention_ratio=retention_ratio, bn_every=bn_every)
        self.model = self.model.eval()
        self.load_model()
        self.model.eval()
        self.model.to('cuda')

        # SLAM params: frame id / keyframe id counters start at -1 so the
        # first processed frame becomes frame 0 / keyframe 0.
        self.fid = -1
        self.kid = -1
        self.kfids = []
        self.last_kfid = 0
        self.kf_timestamps = []
        # frontend
        self.frontend_times = 0
        # Token map: raw tokens and their backend-optimized counterpart.
        self.map = None
        self.map_opt = None

        self.signal_backend = False
        self.backend_every = self.bn_every #10
        # Per-keyframe camera parameters and cached frames.
        self.extrins = []
        self.intrins = []
        self.frames = []
        self.kf_frames = []

        # Camera intrinsics (set properly later; eye(3) placeholder in vis).
        self.K = None
        self.update_K = False

        # vis
        if self.vis:
            self.entity="world"
            rr.init("SLAM", spawn=True)
            rr.log(self.entity, rr.ViewCoordinates.RIGHT_HAND_Z_UP)
            # World-from-last-keyframe pose and placeholder intrinsics used
            # by the live visualization. NOTE(review): placement inside this
            # `if self.vis` block is assumed (indentation lost in the diff);
            # both attributes appear to be read only on vis code paths —
            # confirm.
            self.Twk = np.eye(4)
            self.K = np.eye(3)
108
+
109
+ def load_model(self):
110
+ ckpt_raw = torch.load(self.ckpt_path, map_location='cuda', weights_only=False)
111
+
112
+ if isinstance(ckpt_raw, dict):
113
+ if "model" in ckpt_raw:
114
+ ckpt = ckpt_raw["model"]
115
+ print("Loaded state_dict from 'model' key in checkpoint.")
116
+ else:
117
+ ckpt = ckpt_raw
118
+ else:
119
+ ckpt = ckpt_raw
120
+
121
+ ckpt = utils.strip_module(ckpt)
122
+ self.model.load_state_dict(ckpt, strict=False)
123
+ del ckpt, ckpt_raw
124
+
125
    @property
    def time(self):
        """Wall-clock time (seconds) after draining pending CUDA work.

        Synchronizing first makes the frontend/backend timing measurements
        reflect actual GPU execution rather than just kernel launch.
        """
        torch.cuda.synchronize()
        return time.perf_counter()
129
+
130
    def kf_detect(self, image):
        """Decide whether `image` should become a new keyframe.

        The first frame is always a keyframe. Otherwise the model estimates
        the relative pose between the last keyframe and this frame, and the
        frame is promoted when the camera-center translation exceeds
        self.kf_th.

        NOTE(review): indentation in this method was reconstructed from a
        diff rendering; branch boundaries marked below are best-effort.

        Args:
            image: raw input frame as consumed by utils.load_image —
                presumably HxWxC numpy; confirm against the caller.

        Returns:
            True (bool) for the bootstrap frame, otherwise a 0-dim torch
            bool tensor. As a side effect the keyframe's extrinsic is
            appended to self.extrins.
        """
        # Bootstrap: the very first frame becomes keyframe 0 at the origin.
        if self.kid == -1:
            self.extrins.append(torch.eye(4))
            return True

        frame = utils.load_image(image, self.target_size)
        _,H,W = frame.shape

        # NOTE(review): `st` is captured but never used below in this method.
        st = self.time #time.perf_counter()
        # Two-view pass over [last keyframe, candidate frame].
        token = self.model.KFT(torch.stack([self.kf_frames[-1],frame.cuda()]))
        if self.vis:
            # scale the pose to global
            res = self.model.extract(token, cam_only=True)
            #z = res['local_points'][0,0,:,:,-1].cpu().numpy()
            # NOTE(review): 'depth_lask_kf' looks like a typo of
            # 'depth_last_kf'; both branches currently assign scale=1, so
            # the typo has no effect while the median-scale line stays
            # commented out. `scale` itself is unused below.
            if not hasattr(self,'depth_lask_kf'):
                scale=1
            else:
                scale=1 #np.median(self.depth_last_kf/(z+1e-6))
            camera_pose = res['camera_poses']

            extrinsic = torch.inverse(camera_pose)
            if extrinsic.shape[1] > 1:
                extrinsic_ref=extrinsic.cpu()[0,-2]
                extrinsic = extrinsic.cpu()[0,-1]
            # Relative motion last-keyframe -> current, lifted into the
            # world frame via the cached world-from-keyframe pose Twk.
            Tki = torch.inverse(camera_pose[0,0])@camera_pose[0,1]
            Tki = Tki.cpu().numpy()
            self.Twi = self.Twk@Tki
            K44 = np.eye(4)
            K44[:3,:3] = self.K
            log_camera("camera",self.Twi, K44, kfd=True)
            # make the window follow camera
            log_window(f"{self.entity}",np.linalg.inv(self.Twi), K44)

        else:
            res = self.model.extract(token, cam_only=True)
            camera_pose = res['camera_poses']
            extrinsic = torch.inverse(camera_pose)
            if extrinsic.shape[1] > 1:
                extrinsic_ref=extrinsic.cpu()[0,-2]
                extrinsic = extrinsic.cpu()[0,-1]
        self.kft_extrinsic_ref = torch.eye(4)#extrinsic_ref

        # Keyframe test: Euclidean distance between the two camera centers.
        dist = torch.sqrt(torch.sum((extrinsic[:3,3] - extrinsic_ref[:3,3])**2))
        isKF = dist > self.kf_th

        print(dist)

        if isKF:
            self.extrins.append(extrinsic)
        return isKF
182
+
183
    def frontend(self, image):
        """Ingest one frame: keyframe gating, tokenization, map growth.

        Pipeline: (1) keyframe test via kf_detect — non-keyframes return
        False immediately; (2) the keyframe is resized, cached, and encoded
        by the transformer frontend (two-view bootstrap for the second
        keyframe, single-view afterwards); (3) tokens are appended to the
        map, and every `backend_every` keyframes self.signal_backend is
        raised so backend() runs a global optimization. When vis is on, the
        current reconstruction is streamed to rerun.

        NOTE(review): indentation in this method was reconstructed from a
        diff rendering; branch boundaries are best-effort.

        Args:
            image: raw HxWxC frame (indexed [:,:,::-1] for display, i.e.
                presumably BGR in, RGB out — confirm).

        Returns:
            False when the frame is rejected as a keyframe; otherwise None.
        """
        if self.vis:
            rr.log("image", rr.Image(image[:,:,::-1]))#,static=True)

        self.fid += 1
        print('Frame', self.fid)
        # run kf detector
        st = self.time
        enough_disparity = self.kf_detect(image)
        self.kf_time.append(self.time-st)
        if not enough_disparity:
            return False

        torch.cuda.empty_cache()
        # run T-frontend
        H_,W_,_ = image.shape
        frame = utils.load_image(image, self.target_size)
        self.H,self.W,_ = frame.shape
        st = self.time
        self.last_kf = frame.cuda()
        self.kf_frames.append(self.last_kf)
        self.last_kfid = self.fid
        self.frames.append(self.last_kf.clone())
        self.kid += 1
        print("[italic purple] # KEYFRAME", self.kid)
        # self.cur_timestamp is set outside this method (not visible here).
        self.kf_timestamps.append(self.cur_timestamp)
        frame = frame.cuda()
        st = self.time

        # self.nkf (keyframe count) and self.map_add are defined elsewhere.
        if self.nkf == 1:
            # First keyframe: nothing to pair with yet; tokens are produced
            # once the second keyframe arrives.
            pass
        elif self.nkf == 2:
            # Bootstrap: encode keyframes 0 and 1 jointly.
            token = self.model.frontendT(torch.stack([self.kf_frames[0],frame]))
            self.map_add(token)
        else:
            # Steady state: encode only the new keyframe.
            token = self.model.frontendT(frame)
            print(self.time-st)

            self.map_add(token)

        self.kfids.append(self.fid)
        self.times.append(self.time-st)
        torch.cuda.empty_cache()

        # send signal to backend
        self.frontend_times += 1
        if self.frontend_times % self.backend_every == 0:
            self.signal_backend = True

        if self.vis and self.map is not None:
            st = time.time()
            # Merge already-optimized tokens with the yet-unoptimized tail.
            # NOTE(review): map_before_bn is computed but not used below.
            map_before_bn = None
            if self.map_opt is None:
                map_before_bn = self.map
            else:
                S = self.map.shape[0]
                S_oldopt = self.map_opt.shape[0]

                map_before_bn = torch.cat([self.map_opt, self.map[S_oldopt:]],axis=0)
            # Extract points/colors/confidences/poses: both keyframes on the
            # bootstrap step, only the newest token afterwards.
            if self.nkf == 2:
                ps,cs,confs,poses = self.extract(self.map)
            else:
                ps,cs,confs,poses = self.extract(self.map[-1:])

            self.vis_mem = [ps,cs,confs,poses]

            # Drop the least confident 15% of points for display.
            conf_threshold = np.percentile(confs, 15)
            msk = confs>=conf_threshold

            ps = ps[msk]
            cs = cs[msk]
            K44 = np.eye(4)
            K44[:3,:3] = self.K

            if self.nkf == 2:
                log_camera(f"{self.entity}/camera_kf/0",poses[0], K44)
                log_camera(f"{self.entity}/camera_kf/1",poses[1], K44)

                rr.log(f"{self.entity}/lines/0to1", rr.LineStrips3D([poses[:,:3,3].tolist()],colors=[0,0,255],radii=[0.005]))

                self.last_kf_pose = poses[1]
            else:
                log_camera(f"{self.entity}/camera_kf/{self.nkf-1}",poses.reshape(4,4), K44)
                rr.log(f"{self.entity}/lines/{self.nkf-2}to{self.nkf-1}", rr.LineStrips3D([np.stack([self.last_kf_pose[:3,3],poses[0,:3,3]]).tolist()],colors=[0,0,255],radii=[0.005]))

                self.last_kf_pose = poses[0]

            rr.log(
                f"{self.entity}/pointclouds/{self.nkf}",
                rr.Points3D(ps, colors=cs, radii=0.01),
            )

            print('log', time.time()-st)

            # Cache world-from-latest-keyframe pose for kf_detect's vis path.
            self.Twk = poses[-1].reshape(4,4)
283
+
284
def backend(self, final=False):
    """Run a global backend optimization pass over all keyframe tokens.

    Only runs when ``self.signal_backend`` has been raised (every
    ``backend_every`` keyframes, or forced once at termination). Stores the
    optimized token map in ``self.map_opt`` and, when visualization is on,
    re-logs the full reconstruction (cameras, trajectory lines, point
    cloud) to Rerun.

    Args:
        final: accepted for interface parity with ``terminate``; not used
            inside this method.
    """
    if not self.signal_backend:
        return

    torch.cuda.empty_cache()

    # Free the frontend KV cache before the heavy backend pass.
    # Fix: guard with hasattr — `fkv` does not exist until the frontend has
    # cached it (e.g. a run with a single keyframe), and a second backend
    # pass with no intervening frontend would hit an already-deleted attr.
    if hasattr(self.model, 'fkv'):
        del self.model.fkv
    torch.cuda.empty_cache()
    print('Backending...', self.nkf, 'KFs')
    st = time.perf_counter()
    map_optimed = self.model.backendT(self.map.cuda())
    self.backend_time.append(time.perf_counter()-st)
    print('backend_take', time.perf_counter()-st)
    torch.cuda.empty_cache()

    # Replace the previous optimized map; keep it on CPU to bound VRAM.
    if self.map_opt is not None:
        del self.map_opt
        torch.cuda.empty_cache()
    self.map_opt = map_optimed.cpu()

    self.signal_backend = False
    torch.cuda.empty_cache()

    if self.vis:
        ps,cs,confs,poses = self.extract(self.map_opt)
        self.vis_mem = [ps,cs,confs,poses]
        # Keep only the top-85% most confident points for display.
        conf_threshold = np.percentile(confs, 15)
        msk = confs>=conf_threshold

        ps = ps[msk]
        cs = cs[msk]

        # Clear the incremental per-keyframe clouds logged by the frontend.
        for s in range(self.nkf+1):
            rr.log(f"{self.entity}/pointclouds/{s}", rr.Points3D(np.array([])))

        # Redraw every keyframe camera with its optimized pose.
        for s in range(self.nkf):
            K44 = np.eye(4)
            K44[:3,:3] = self.K
            log_camera(f"{self.entity}/camera_kf/{s}",poses[s].reshape(4,4), K44, update=True)

        # Redraw the trajectory as consecutive-keyframe segments.
        for s in range(1, self.nkf):
            rr.log(f"{self.entity}/lines/{s-1}to{s}", rr.LineStrips3D([poses[s-1:s+1,:3,3].tolist()],colors=[0,0,255],radii=[0.005]))

        rr.log(
            f"{self.entity}/pointclouds/{self.nkf}",
            rr.Points3D(ps, colors=cs, radii=0.01),
        )
        self.last_kf_pose = poses[-1]
334
def step(self, timestamp, image):
    """Process one incoming frame: run the frontend, then maybe the backend.

    A missing timestamp falls back to the 1-based frame index.
    """
    self.cur_timestamp = self.fid + 1 if timestamp is None else timestamp
    self.frontend(image)
    self.backend()
344
def map_add(self, token_kf):
    """Append keyframe token(s) to the global token map along dim 0 (S,P,C).

    Tokens are moved to CPU first when ``self.save_gmem`` is set, bounding
    GPU memory growth as the map accumulates.
    """
    stored = token_kf.cpu() if self.save_gmem else token_kf
    if self.map is None:
        self.map = stored
    else:
        self.map = torch.cat([self.map, stored], axis=0)
353
@property
def nkf(self):
    """Number of keyframes accepted so far (``kid`` is a 0-based index)."""
    return 1 + self.kid
357
@property
def nf(self):
    """Number of frames processed so far (``fid`` is a 0-based index)."""
    return 1 + self.fid
361
def terminate(self):
    """Finish the run: flush a final backend pass, print timing stats, save.

    Prints raw timing lists plus per-stage totals and FPS for keyframe
    detection (KFT), frontend (FT), backend (BT), and the combined summary,
    then writes the final reconstruction under ``{outdir}/final``.
    """
    # Force one last backend pass if the keyframe count did not land on a
    # multiple of the backend interval (otherwise the tail keyframes were
    # never optimized).
    if self.nkf % self.backend_every != 0:
        self.signal_backend = True
    self.backend(final=True)

    print(self.kf_time)
    print(self.times)
    print(self.backend_time)
    print('frontend take', np.mean(self.times))
    print('KFT')
    print('total', np.sum(self.kf_time), 'FPS', float(len(self.kf_time))/np.sum(self.kf_time))
    print('FT')
    print('total', np.sum(self.times), 'FPS', float(len(self.times))/np.sum(self.times))
    print('BT')
    print('total', np.sum(self.backend_time), 'FPS', float(len(self.backend_time))/np.sum(self.backend_time))
    print('Summary')
    print('total', np.sum(self.kf_time)+np.sum(self.times)+np.sum(self.backend_time), 'FPS', float(len(self.kf_time))/(np.sum(self.kf_time)+np.sum(self.times)+np.sum(self.backend_time)))
    self.save_result(f'{self.outdir}/final', self.map_opt)
380
def extract(self, map_all=None):
    """Decode a keyframe token map into geometry for visualization.

    Args:
        map_all: token map to decode. NOTE(review): despite the default,
            ``None`` is never handled inside this method — callers always
            pass a map; confirm before relying on the default.

    Returns:
        pts: flattened (S*H*W, 3) world-space points.
        colors: flattened (S*H*W, 3) per-point colors taken from the last
            S stored frames, channel order reversed by ``[:, ::-1]``
            (presumably BGR->RGB, since frames come from cv2 — confirm).
        confs: flattened per-point confidence values.
        camera_pose: (S,4,4) camera poses from the model head.

    Side effect: caches the last keyframe's depth channel in
    ``self.depth_last_kf``.
    """
    result = self.model.extract(map_all.cuda())

    pts = result['points'].cpu().numpy() # 1,S,H,W,3
    local_pts = result['local_points'].cpu().numpy() # 1,S,H,W,3
    _,S,H,W,_ = pts.shape
    conf = result['conf'].cpu().numpy()
    point_clouds = [pts[0,s] for s in range(S)]
    #conf_threshold = np.percentile(conf, 15)
    #confs = [conf[0,s]>=conf_threshold for s in range(S)]
    # Colors come from the most recent S frames, matching the S decoded maps.
    colors = torch.stack(self.frames[-S:]).permute(0,2,3,1).reshape(-1,3).cpu().numpy()[:,::-1] # S,H,W,C
    confs = conf.reshape(-1)


    camera_pose = result['camera_poses'].cpu().numpy()[0] # S,4,4
    pts = pts.reshape(-1,3)
    colors = colors.reshape(-1,3)

    # set depth for the last kf
    self.depth_last_kf = local_pts[0,-1,:,:,-1]

    return pts, colors, confs, camera_pose
403
def save_result(self, output_path = 'output/tmp', map_all=None, traj=True):
    """Export the final reconstruction to disk.

    Writes:
      - ``{output_path}.ply``: confidence-filtered colored point cloud.
      - ``{output_path}_traj.txt``: keyframe trajectory, one pose per line.
      - ``{output_path}_pc/``: per-keyframe point clouds + masks (.npz).

    Args:
        output_path: output file prefix (directories must already exist).
        map_all: token map to decode; falls back to the last optimized map.
        traj: NOTE(review): unused in this method — confirm whether the
            trajectory export was meant to be conditional on it.

    Returns:
        The raw extraction result dict from the model.
    """
    print(self.kfids)

    if map_all is None:
        map_all = self.map_opt

    result = self.model.extract(map_all.cuda())
    pts = result['points'].cpu().numpy() # 1,S,H,W,3
    _,S,H,W,_ = pts.shape
    conf = result['conf'].cpu().numpy()
    point_clouds = [pts[0,s] for s in range(S)]
    # Keep only the top-85% most confident points.
    conf_threshold = np.percentile(conf, 15)
    confs = [conf[0,s]>=conf_threshold for s in range(S)]

    colors = torch.stack(self.frames).permute(0,2,3,1).reshape(-1,3).cpu().numpy()[:,::-1] # S,H,W,C
    msk = np.stack(confs).reshape(-1)
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(pts.reshape(-1,3).astype(np.float64)[msk])
    pcd.colors = o3d.utility.Vector3dVector(colors.reshape(-1,3).astype(np.float64)[msk])
    #downpcd = pcd.voxel_down_sample(voxel_size=0.005)
    o3d.io.write_point_cloud(f"{output_path}.ply", pcd)
    camera_pose = result['camera_poses'].cpu() #torch.Size([1, 14, 4,4])
    poses = camera_pose[0].numpy()

    self.write_poses_to_file(f"{output_path}_traj.txt", poses, self.kf_timestamps)
    self.save_framewise_pointclouds(f"{output_path}_pc", point_clouds, self.kf_timestamps, confs)

    return result
436
def write_poses_to_file(self, filename, poses, frame_ids):
    """Write a trajectory file: ``t x y z qx qy qz qw`` per keyframe line.

    Each 4x4 pose matrix is split into a translation and a quaternion
    (scalar-last, as produced by scipy) and written with 8 decimals.
    """
    with open(filename, "w") as f:
        assert len(poses) == len(frame_ids), "Number of provided poses and number of frame ids do not match"
        for frame_id, pose in zip(frame_ids, poses):
            translation = pose[0:3, 3]
            quaternion = R.from_matrix(pose[0:3, 0:3]).as_quat()  # x, y, z, w
            row = np.array([float(frame_id), *translation, *quaternion])
            f.write(" ".join(f"{v:.8f}" for v in row) + "\n")
447
def save_framewise_pointclouds(self, filename, pointclouds, frame_ids, conf_masks):
    """Save one ``.npz`` (points + confidence mask) per keyframe.

    Args:
        filename: output directory; created if missing.
        pointclouds: per-keyframe point arrays.
        frame_ids: timestamps/ids used as file names.
        conf_masks: per-keyframe boolean confidence masks.
    """
    os.makedirs(filename, exist_ok=True)
    # Fix: the save path previously used a dead literal instead of the
    # directory created above; also the loop variable shadowed `conf_masks`.
    for frame_id, pointcloud, conf_mask in zip(frame_ids, pointclouds, conf_masks):
        # save pcd as numpy array
        np.savez(f"{filename}/{frame_id}.npz", pointcloud=pointcloud, mask=conf_mask)
453
+
454
def get_parser():
    """Parse CLI arguments for the SLAM-Former demo.

    Note: despite the name, this returns the *parsed* namespace (it calls
    ``parse_args`` on ``sys.argv``), not the parser object.
    """
    parser = argparse.ArgumentParser(description="SLAM-Former demo")
    add = parser.add_argument
    add("--ckpt_path", type=str, default="path/to/checkpoint.pth.model", help="Path to the checkpoint")
    add("--image_folder", type=str, default="path/to/image/folder", help="Path to folder containing images")
    add("--target_size", type=int, default=518, help="the target size of image(longer side)")
    add("--output_dir", type=str, default="outputs/tmp", help="Path to save the output")
    add("--stride", type=int, default=1, help="Frame stride for subsampling the input sequence")
    add("--kf_th", type=float, default=0.1, help="Keyframe selection threshold (minimum translation distance)")
    add("--retention_ratio", type=float, default=0.5, help="KV Pruning retention ratio")
    add("--bn_every", type=int, default=10, help="Run backend optimization every N keyframes")
    add("--vis", action="store_true", help="Enable real-time visualization with Rerun")
    add("--resize_rate", type=float, default=1, help="Resize rate for input images before processing")
    return parser.parse_args()
469
+
470
+
471
if __name__ == '__main__':
    args = get_parser()
    image_folder = args.image_folder
    outdir = args.output_dir
    os.makedirs(outdir, exist_ok=True)

    # Hard-coded pinhole intrinsics for the two benchmark datasets;
    # unknown datasets run with K=None (no camera frustum visualization).
    if 'tum' in args.image_folder:
        fx = 525.0  # focal length x
        fy = 525.0  # focal length y
        cx = 319.5  # optical center x
        cy = 239.5  # optical center y
        K = np.eye(3)
        K[0,0] = fx
        K[1,1] = fy
        K[0,2] = cx
        K[1,2] = cy
    elif 'Replica' in args.image_folder:
        fx = 600.  # focal length x
        fy = 600.0  # focal length y
        cx = 599.5  # optical center x
        cy = 339.5  # optical center y
        K = np.eye(3)
        K[0,0] = fx
        K[1,1] = fy
        K[0,2] = cx
        K[1,2] = cy
    else:
        K = None


    # Use the provided image folder path, skipping depth maps and metadata.
    print(f"Loading images from {image_folder}...")
    image_names = [f for f in glob.glob(os.path.join(image_folder, "*"))
                   if "depth" not in os.path.basename(f).lower() and "txt" not in os.path.basename(f).lower()
                   and "db" not in os.path.basename(f).lower()]
    image_names = utils.sort_images_by_number(image_names)

    frame_ids = []
    for path in image_names:
        filename = os.path.basename(path)
        match = re.search(r'\d+(?:\.\d+)?', filename)  # matches integers and decimals
        if match:
            frame_ids.append(float(match.group()))
        else:
            # Fix: the message contained a dead placeholder; include the
            # offending path so the failure is actionable.
            raise ValueError(f"No number found in image name: {path}")

    print(f"Found {len(image_names)} images")

    print('resize image', args.resize_rate)

    slam = SLAM(
        outdir=outdir,
        kf_th=args.kf_th,
        bn_every=args.bn_every,
        vis=args.vis,
        ckpt_path=args.ckpt_path,
        target_size=args.target_size,
        retention_ratio=args.retention_ratio
    )

    slam.K = K
    for frame_id, image_name in zip(frame_ids[::args.stride], image_names[::args.stride]):
        img = cv2.imread(image_name)

        if args.resize_rate != 1:
            H,W,_ = img.shape
            # Fix: interpolation must be passed by keyword — the third
            # positional argument of cv2.resize is `dst`, not interpolation.
            img = cv2.resize(img, (int(W*args.resize_rate), int(H*args.resize_rate)), interpolation=cv2.INTER_CUBIC)
        slam.step(frame_id, img)
    result = slam.terminate()
540
+
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/demo_infinite.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Infinite SLAM-Former — Inference Demo with all CLI toggles.
4
+
5
+ Wraps the original SLAMFormer + SLAM class with the new modules:
6
+ - GraphGatedMemoryManager (submap backend + loop closure)
7
+ - TemporalEmbedWrapper (dual temporal embedding injection)
8
+ - BatchedDynamicSubmapRouter (learned submap boundaries)
9
+
10
+ When all toggles are off, behaviour is identical to slam/demo.py.
11
+
12
+ Usage:
13
+ # Vanilla (same as demo.py):
14
+ python slam/demo_infinite.py --ckpt_path ckpt/checkpoint.pth.model \
15
+ --image_folder /path/to/images --output_dir outputs/tmp
16
+
17
+ # With submap backend + loop closure + temporal embed:
18
+ python slam/demo_infinite.py --ckpt_path ckpt/checkpoint.pth.model \
19
+ --image_folder /path/to/images --output_dir outputs/tmp \
20
+ --enable_submap_backend --enable_loop_closure --enable_temporal_embed
21
+ """
22
+
23
+ import os
24
+ import sys
25
+ import re
26
+ import glob
27
+ import argparse
28
+ import time
29
+ from collections import OrderedDict
30
+
31
+ import cv2
32
+ import torch
33
+ import numpy as np
34
+ from scipy.spatial.transform import Rotation as R
35
+
36
+ # ─── Path setup ──────────────────────────────────────────
37
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
38
+ PROJECT_DIR = os.path.dirname(SCRIPT_DIR)
39
+ SRC_DIR = os.path.join(PROJECT_DIR, "src")
40
+
41
+ if SRC_DIR not in sys.path:
42
+ sys.path.insert(0, SRC_DIR)
43
+ if PROJECT_DIR not in sys.path:
44
+ sys.path.insert(0, PROJECT_DIR)
45
+
46
+ from slamformer.models.slamformer import SLAMFormer
47
+ import slam.utils as utils
48
+ from slam.graph_gated_memory import (
49
+ GraphGatedMemoryManager,
50
+ TemporalEmbedWrapper,
51
+ )
52
+
53
+
54
def strip_module(state_dict):
    """Strip the ``module.`` (DataParallel) prefix from state-dict keys.

    Non-dict inputs are returned unchanged.
    """
    if not isinstance(state_dict, dict):
        return state_dict
    return OrderedDict(
        (key[len("module."):] if key.startswith("module.") else key, value)
        for key, value in state_dict.items()
    )
62
+
63
+
64
def strip_prefix(state_dict, prefix):
    """Remove ``prefix`` from every state-dict key that carries it.

    Non-dict inputs are returned unchanged; unprefixed keys are kept as-is.
    """
    if not isinstance(state_dict, dict):
        return state_dict
    n = len(prefix)
    return OrderedDict(
        (key[n:] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
72
+
73
+
74
def get_cfg_value(cfg, key, default=None):
    """Best-effort lookup of ``key`` on a dict-like, indexable, or
    attribute-style config object; ``default`` on any failure.
    """
    if cfg is None:
        return default
    # Prefer a mapping-style .get when available.
    try:
        getter = cfg.get
    except Exception:
        getter = None
    if getter is not None:
        try:
            return getter(key, default)
        except Exception:
            pass
    # Fall back to indexing, then attribute access.
    try:
        return cfg[key]
    except Exception:
        return getattr(cfg, key, default)
86
+
87
+
88
def extract_checkpoint_parts(ckpt_raw):
    """Split a raw checkpoint into (model state, train cfg, memory-manager
    state, temporal-wrapper state).

    Supports both flat checkpoints and ones nesting the auxiliary module
    states under a ``submap_modules`` dict; non-dict inputs are treated as
    the model state itself.
    """
    train_cfg = None
    memory_state = None
    temporal_state = None
    if not isinstance(ckpt_raw, dict):
        return ckpt_raw, train_cfg, memory_state, temporal_state

    train_cfg = ckpt_raw.get("args")
    ckpt = ckpt_raw["model"] if "model" in ckpt_raw else ckpt_raw
    nested = ckpt_raw.get("submap_modules", {})
    if isinstance(nested, dict):
        # Top-level keys win over the nested ones.
        memory_state = ckpt_raw.get("memory_mgr", nested.get("memory_mgr"))
        temporal_state = ckpt_raw.get("temporal_wrapper", nested.get("temporal_wrapper"))
    else:
        memory_state = ckpt_raw.get("memory_mgr")
        temporal_state = ckpt_raw.get("temporal_wrapper")
    return ckpt, train_cfg, memory_state, temporal_state
108
+
109
+
110
class InfiniteSLAM:
    """Extended SLAM pipeline with optional submap backend and loop closure.

    When all enable_* flags are False, this behaves identically to the
    original SLAM class in slam/demo.py.
    """

    def __init__(
        self,
        outdir: str = "output/tmp",
        kf_th: float = 0.1,
        bn_every: int = 10,
        ckpt_path: str = "",
        target_size: int = 518,
        retention_ratio: float = 0.5,
        # ── new toggles ──
        enable_submap_backend: bool = False,
        submap_size: int = 10,
        max_recursive_submaps: int = 3,
        enable_loop_closure: bool = False,
        desc_dim: int = 128,
        enable_temporal_embed: bool = False,
        temporal_embed_mode: str = "learned",
    ):
        """Build the model, memory manager and temporal wrapper, then load
        the checkpoint. Config values recorded in the checkpoint's ``args``
        override the constructor arguments.
        """
        self.outdir = outdir
        self.kf_th = kf_th
        self.bn_every = bn_every
        self.target_size = target_size
        ckpt_raw = None
        self.train_cfg = None
        if ckpt_path and os.path.exists(ckpt_path):
            ckpt_raw = torch.load(ckpt_path, map_location="cuda", weights_only=False)
            _, self.train_cfg, _, _ = extract_checkpoint_parts(ckpt_raw)

        # Checkpoint-recorded training config takes precedence over CLI args.
        retention_ratio = float(get_cfg_value(self.train_cfg, "retention_ratio", retention_ratio))
        self.enable_submap = bool(get_cfg_value(self.train_cfg, "enable_submap", enable_submap_backend))
        self.enable_loop = bool(get_cfg_value(self.train_cfg, "enable_loop", enable_loop_closure))
        self.enable_temporal = bool(get_cfg_value(self.train_cfg, "enable_temporal", enable_temporal_embed))
        submap_size = int(get_cfg_value(self.train_cfg, "submap_size", submap_size))
        max_recursive_submaps = int(get_cfg_value(self.train_cfg, "max_recursive_submaps", max_recursive_submaps))
        desc_dim = int(get_cfg_value(self.train_cfg, "desc_dim", desc_dim))
        loop_mask_mode = get_cfg_value(self.train_cfg, "loop_mask_mode", "hard_top1")
        soft_mask_temperature = float(get_cfg_value(self.train_cfg, "soft_mask_temperature", 0.25))
        soft_mask_bias = float(get_cfg_value(self.train_cfg, "soft_mask_bias", 0.2))
        submap_train_mode = get_cfg_value(self.train_cfg, "submap_train_mode", "full_token")
        submap_retrieval_topk = int(get_cfg_value(self.train_cfg, "submap_retrieval_topk", 0))
        submap_fetch_source = get_cfg_value(self.train_cfg, "submap_fetch_source", "frontend")
        submap_descriptor_source = get_cfg_value(self.train_cfg, "submap_descriptor_source", "frontend")
        temporal_embed_mode = get_cfg_value(self.train_cfg, "temporal_embed_mode", temporal_embed_mode)

        # ── Model (frozen) ───────────────────────────────
        self.model = SLAMFormer(
            retention_ratio=retention_ratio, bn_every=bn_every
        )
        self.model.eval()

        # ── Memory manager ───────────────────────────────
        embed_dim = self.model.dec_embed_dim
        self.memory_mgr = GraphGatedMemoryManager(
            submap_size=submap_size,
            max_recursive_submaps=max_recursive_submaps,
            desc_dim=desc_dim,
            embed_dim=embed_dim,
            loop_mask_mode=loop_mask_mode,
            soft_mask_temperature=soft_mask_temperature,
            soft_mask_bias=soft_mask_bias,
            retain_history_grad=False,
            submap_train_mode=submap_train_mode,
            submap_retrieval_topk=submap_retrieval_topk,
            submap_fetch_source=submap_fetch_source,
            submap_descriptor_source=submap_descriptor_source,
        )

        # ── Temporal wrapper ─────────────────────────────
        self.temporal_wrapper = TemporalEmbedWrapper(
            embed_dim=embed_dim,
            max_frames=5000,
            mode=temporal_embed_mode,
        )

        self._load_model(ckpt_raw)
        self.model.to("cuda")
        self.memory_mgr.to("cuda").eval()
        self.temporal_wrapper.to("cuda").eval()
        self.memory_mgr.reset()

        # ── SLAM state ───────────────────────────────────
        self.fid = -1                 # 0-based index of last processed frame
        self.kid = -1                 # 0-based index of last accepted keyframe
        self.kfids = []               # frame ids of accepted keyframes
        self.kf_timestamps = []
        self.frames = []              # stored keyframe images (for colors)
        self.kf_frames = []
        self.map = None               # raw (unoptimized) keyframe token map
        self.map_opt = None           # last backend-optimized token map
        self.signal_backend = False   # raised every bn_every keyframes
        self.frontend_times = 0
        self.times = []               # per-keyframe frontend timings
        self.kf_time = []             # per-frame keyframe-detection timings
        self.backend_time = []        # per-pass backend timings
        self.K = None                 # camera intrinsics (set by caller)
        self.cur_timestamp = 0

    def _load_model(self, ckpt_raw):
        """Load model / memory-manager / temporal-wrapper weights from a raw
        checkpoint dict; silently skips parts that are absent.
        """
        if ckpt_raw is None:
            return
        ckpt, _, memory_state, temporal_state = extract_checkpoint_parts(ckpt_raw)
        ckpt = strip_module(ckpt)
        self.model.load_state_dict(ckpt, strict=False)
        if memory_state is not None:
            memory_state = strip_prefix(strip_module(memory_state), "memory_mgr.")
            self.memory_mgr.load_state_dict(memory_state, strict=False)
        if temporal_state is not None:
            temporal_state = strip_prefix(strip_module(temporal_state), "temporal_wrapper.")
            self.temporal_wrapper.load_state_dict(temporal_state, strict=False)
        del ckpt

    @property
    def time(self):
        # CUDA-synchronized wall clock so timings include pending GPU work.
        torch.cuda.synchronize()
        return time.perf_counter()

    @property
    def nkf(self):
        # Number of keyframes accepted so far (kid is 0-based).
        return self.kid + 1

    def kf_detect(self, image):
        """Return True when `image` moved far enough from the last keyframe
        (translation distance > kf_th) to become a new keyframe.
        """
        if self.kid == -1:
            return True
        frame = utils.load_image(image, self.target_size)
        token = self.model.KFT(torch.stack([self.kf_frames[-1], frame.cuda()]))
        res = self.model.extract(token, cam_only=True)
        camera_pose = res["camera_poses"]
        extrinsic = torch.inverse(camera_pose)
        if extrinsic.shape[1] > 1:
            extrinsic_ref = extrinsic.cpu()[0, -2]
            extrinsic = extrinsic.cpu()[0, -1]
        else:
            # Degenerate single-pose output: accept the frame as a keyframe.
            return True
        dist = torch.sqrt(torch.sum((extrinsic[:3, 3] - extrinsic_ref[:3, 3]) ** 2))
        return dist > self.kf_th

    def frontend(self, image):
        """Track one frame: detect keyframe, tokenize it, grow the token map
        and (optionally) feed the submap accumulator.

        Returns False for rejected (non-keyframe) frames; None otherwise.
        """
        self.fid += 1
        print("Frame", self.fid)

        st = self.time
        enough_disparity = self.kf_detect(image)
        self.kf_time.append(self.time - st)
        if not enough_disparity:
            return False

        torch.cuda.empty_cache()
        frame = utils.load_image(image, self.target_size)
        st = self.time
        self.last_kf = frame.cuda()
        self.kf_frames.append(self.last_kf)
        self.frames.append(self.last_kf.clone())
        self.kid += 1
        print(f" # KEYFRAME {self.kid}")
        self.kf_timestamps.append(self.cur_timestamp)

        st = self.time
        # The very first keyframe cannot be tokenized alone; the second is
        # processed as a pair with the first, later ones individually.
        if self.nkf == 1:
            pass
        elif self.nkf == 2:
            token = self.model.frontendT(torch.stack([self.kf_frames[0], frame.cuda()]))
            self._map_add(token)
        else:
            token = self.model.frontendT(frame.cuda())
            self._map_add(token)

        self.kfids.append(self.fid)
        self.times.append(self.time - st)
        torch.cuda.empty_cache()

        # ── Submap accumulation ──────────────────────────
        if self.enable_submap and self.map is not None:
            last_token = self.map[-1:] if self.map.dim() == 3 else self.map
            self.memory_mgr.accumulate(last_token, frame_id=self.fid)

        self.frontend_times += 1
        if self.frontend_times % self.bn_every == 0:
            self.signal_backend = True

    def backend(self, final=False):
        """Run one backend optimization pass (submap-based when enabled and a
        submap is complete; vanilla full-map otherwise). No-op unless
        ``signal_backend`` is set.
        """
        if not self.signal_backend:
            return

        torch.cuda.empty_cache()
        # Drop the frontend KV cache before the memory-heavy backend pass.
        if hasattr(self.model, "fkv"):
            del self.model.fkv
        torch.cuda.empty_cache()

        print("Backending...", self.nkf, "KFs")
        st = time.perf_counter()

        if self.enable_submap and self.memory_mgr.submap_complete:
            # ── Submap-based backend ─────────────────────
            hidden_B, loop_gate, meta = self.memory_mgr.finalize_submap(
                model=self.model,
                device=torch.device("cuda"),
                temporal_wrapper=self.temporal_wrapper if self.enable_temporal else None,
                enable_temporal_embed=self.enable_temporal,
                enable_loop_closure=self.enable_loop,
            )
            # Use the active submap portion as map_opt
            n_prev = meta['n_prev']
            n_curr = meta['n_curr']
            self.map_opt = hidden_B[n_prev:n_prev + n_curr].cpu()
        else:
            # ── Vanilla backend ──────────────────────────
            map_optimed = self.model.backendT(self.map.cuda())
            self.map_opt = map_optimed.cpu()

        self.backend_time.append(time.perf_counter() - st)
        print("backend_take", time.perf_counter() - st)
        self.signal_backend = False
        torch.cuda.empty_cache()

    def step(self, timestamp, image):
        """Process one incoming frame: frontend then (maybe) backend. A
        missing timestamp falls back to the 1-based frame index.
        """
        self.cur_timestamp = timestamp if timestamp is not None else self.fid + 1
        self.frontend(image)
        self.backend()

    def _map_add(self, token_kf):
        # Append tokens to the global map along dim 0, kept on CPU.
        if self.map is None:
            self.map = token_kf.cpu()
        else:
            self.map = torch.cat([self.map, token_kf.cpu()], axis=0)

    def terminate(self):
        """Flush a final backend pass, print timing stats, and save results."""
        # Force a last pass if the KF count is not aligned with bn_every.
        if self.nkf % self.bn_every != 0:
            self.signal_backend = True
        self.backend(final=True)

        print("Frontend times:", self.times)
        print("Backend times:", self.backend_time)
        if self.times:
            print("Frontend avg:", np.mean(self.times))
        if self.backend_time:
            print("Backend avg:", np.mean(self.backend_time))
        print("Summary FPS:", float(len(self.kf_time)) / (
            np.sum(self.kf_time) + np.sum(self.times) + np.sum(self.backend_time) + 1e-9
        ))

        self._save_result(f"{self.outdir}/final", self.map_opt)

    def _save_result(self, output_path, map_all=None):
        """Export a confidence-filtered colored point cloud (.ply) and the
        keyframe trajectory (.txt) under the `output_path` prefix.
        """
        import open3d as o3d

        print(self.kfids)
        if map_all is None:
            map_all = self.map_opt

        map_gpu = map_all.cuda()
        # Ensure shape_ is set correctly for extract
        BN = map_gpu.shape[0]
        # Use shape_ from the last frontendT call for H, W
        _, _, H, W, ph, pw = self.model.shape_
        self.model.shape_ = (1, BN, H, W, ph, pw)
        result = self.model.extract(map_gpu)
        pts = result["points"].cpu().numpy()
        _, S, H, W, _ = pts.shape
        conf = result["conf"].cpu().numpy()
        # Keep only the top-85% most confident points.
        conf_threshold = np.percentile(conf, 15)
        confs = [conf[0, s] >= conf_threshold for s in range(S)]

        # Channel order reversed — presumably BGR->RGB from cv2; confirm.
        colors = torch.stack(self.frames).permute(0, 2, 3, 1).reshape(-1, 3).cpu().numpy()[:, ::-1]
        msk = np.stack(confs).reshape(-1)
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(pts.reshape(-1, 3).astype(np.float64)[msk])
        pcd.colors = o3d.utility.Vector3dVector(colors.reshape(-1, 3).astype(np.float64)[msk])
        o3d.io.write_point_cloud(f"{output_path}.ply", pcd)

        camera_pose = result["camera_poses"].cpu()
        poses = camera_pose[0].numpy()
        self._write_poses(f"{output_path}_traj.txt", poses, self.kf_timestamps)

    def _write_poses(self, filename, poses, frame_ids):
        # Trajectory file format: `t x y z qx qy qz qw` per line (scalar-last
        # quaternion, as produced by scipy).
        with open(filename, "w") as f:
            for frame_id, pose in zip(frame_ids, poses):
                x, y, z = pose[0:3, 3]
                quat = R.from_matrix(pose[0:3, 0:3]).as_quat()
                output = np.array([float(frame_id), x, y, z, *quat])
                f.write(" ".join(f"{v:.8f}" for v in output) + "\n")
396
+
397
+
398
def get_parser():
    """Build (but do not run) the CLI parser: original demo.py args plus
    the submap-backend and temporal-embedding toggles.
    """
    parser = argparse.ArgumentParser(description="Infinite SLAM-Former Demo")
    add = parser.add_argument

    # ── Original demo.py args ────────────────────────────
    add("--ckpt_path", type=str, default="")
    add("--image_folder", type=str, default="")
    add("--target_size", type=int, default=518)
    add("--output_dir", type=str, default="outputs/tmp")
    add("--stride", type=int, default=1)
    add("--kf_th", type=float, default=0.1)
    add("--retention_ratio", type=float, default=0.5)
    add("--bn_every", type=int, default=10)
    add("--resize_rate", type=float, default=1.0)

    # ── Task 1: submap backend ───────────────────────────
    add("--enable_submap_backend", action="store_true")
    add("--submap_size", type=int, default=10)
    add("--max_recursive_submaps", type=int, default=3)
    add("--enable_loop_closure", action="store_true")
    add("--desc_dim", type=int, default=128)

    # ── Fix #7: temporal embedding ───────────────────────
    add("--enable_temporal_embed", action="store_true")
    add("--temporal_embed_mode", type=str, default="learned",
        choices=["learned", "sinusoidal"])

    return parser
424
+
425
+
426
if __name__ == "__main__":
    args = get_parser().parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # ── Camera intrinsics (same logic as original demo.py) ──
    if "tum" in args.image_folder:
        K = np.eye(3)
        K[0, 0], K[1, 1] = 525.0, 525.0
        K[0, 2], K[1, 2] = 319.5, 239.5
    elif "Replica" in args.image_folder:
        K = np.eye(3)
        K[0, 0], K[1, 1] = 600.0, 600.0
        K[0, 2], K[1, 2] = 599.5, 339.5
    else:
        K = None

    # ── Load images (skip depth maps and metadata files) ──
    print(f"Loading images from {args.image_folder}...")
    image_names = [
        f for f in glob.glob(os.path.join(args.image_folder, "*"))
        if "depth" not in os.path.basename(f).lower()
        and "txt" not in os.path.basename(f).lower()
        and "db" not in os.path.basename(f).lower()
    ]
    image_names = utils.sort_images_by_number(image_names)

    frame_ids = []
    for path in image_names:
        match = re.search(r"\d+(?:\.\d+)?", os.path.basename(path))
        if match:
            frame_ids.append(float(match.group()))
        else:
            raise ValueError(f"No number found in image name: {path}")

    print(f"Found {len(image_names)} images")

    # ── Create SLAM instance ─────────────────────────────
    slam = InfiniteSLAM(
        outdir=args.output_dir,
        kf_th=args.kf_th,
        bn_every=args.bn_every,
        ckpt_path=args.ckpt_path,
        target_size=args.target_size,
        retention_ratio=args.retention_ratio,
        enable_submap_backend=args.enable_submap_backend,
        submap_size=args.submap_size,
        max_recursive_submaps=args.max_recursive_submaps,
        enable_loop_closure=args.enable_loop_closure,
        desc_dim=args.desc_dim,
        enable_temporal_embed=args.enable_temporal_embed,
        temporal_embed_mode=args.temporal_embed_mode,
    )
    slam.K = K

    # ── Run ──────────────────────────────────────────────
    for frame_id, image_name in zip(
        frame_ids[:: args.stride], image_names[:: args.stride]
    ):
        img = cv2.imread(image_name)
        if args.resize_rate != 1:
            H, W, _ = img.shape
            # Fix: interpolation must be passed by keyword — the third
            # positional argument of cv2.resize is `dst`, not interpolation.
            img = cv2.resize(
                img, (int(W * args.resize_rate), int(H * args.resize_rate)),
                interpolation=cv2.INTER_CUBIC,
            )
        slam.step(frame_id, img)

    slam.terminate()
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/demo_submap.py ADDED
@@ -0,0 +1,927 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os,sys
2
+ from collections import OrderedDict
3
+ import torch
4
+ import numpy as np
5
+ import re
6
+ import cv2
7
+ import glob
8
+ import argparse
9
+
10
+ import time
11
+ import open3d as o3d
12
+ from rich import print
13
+ import matplotlib.pyplot as plt
14
+ from scipy.spatial.transform import Rotation as R
15
+
16
+ import rerun as rr
17
+ import rerun.blueprint as rrb
18
+
19
+ current_directory = os.path.dirname(os.path.abspath(__file__))
20
+
21
+ sys.path.append(current_directory+'/../')
22
+ sys.path.append('src')
23
+ from slamformer.models.slamformer import SLAMFormer
24
+ from slam.graph_gated_memory import GraphGatedMemoryManager, TemporalEmbedWrapper
25
+
26
+ import slam.utils as utils
27
+ from slam.rerun_helper import log_camera, log_window
28
+
29
def strip_module(state_dict):
    """Drop a leading ``module.`` (DataParallel wrapper) prefix from every key.

    Args:
        state_dict (dict): Mapping whose keys may carry a ``module.`` prefix.

    Returns:
        OrderedDict: Same values, keyed without the ``module.`` prefix.
    """
    return OrderedDict(
        (key[len("module."):] if key.startswith("module.") else key, value)
        for key, value in state_dict.items()
    )
42
+
43
+
44
def strip_prefix(state_dict, prefix):
    """Remove *prefix* from the front of every key in *state_dict*.

    Non-dict inputs (e.g. ``None`` for a missing checkpoint section) are
    returned unchanged so callers need not pre-check.
    """
    if not isinstance(state_dict, dict):
        return state_dict
    cut = len(prefix)
    renamed = OrderedDict()
    for key, value in state_dict.items():
        renamed[key[cut:] if key.startswith(prefix) else key] = value
    return renamed
52
+
53
+
54
def get_cfg_value(cfg, key, default=None):
    """Best-effort lookup of *key* on a heterogeneous config object.

    Tries, in order: a mapping-style ``get``, item access (``cfg[key]``),
    then attribute access, returning *default* when *cfg* is ``None`` or
    every route fails. Works for dicts, OmegaConf nodes, and namespaces.
    """
    if cfg is None:
        return default
    getter = getattr(cfg, "get", None)
    if getter is not None:
        try:
            return getter(key, default)
        except Exception:
            pass
    try:
        return cfg[key]
    except Exception:
        return getattr(cfg, key, default)
66
+
67
+
68
def extract_checkpoint_parts(ckpt_raw):
    """Split a loaded checkpoint into its usable pieces.

    Returns:
        tuple: ``(state_dict, train_cfg, memory_state, temporal_state)``.
        A non-dict input is treated as a bare state_dict. For dict inputs,
        the model weights may live under ``'model'``, the training config
        under ``'args'``, and the submap module states either at top level
        or nested under ``'submap_modules'`` (top level wins).
    """
    if not isinstance(ckpt_raw, dict):
        return ckpt_raw, None, None, None

    train_cfg = ckpt_raw.get("args")
    if "model" in ckpt_raw:
        ckpt = ckpt_raw["model"]
        print("Loaded state_dict from 'model' key in checkpoint.")
    else:
        ckpt = ckpt_raw

    nested = ckpt_raw.get("submap_modules", {})
    if isinstance(nested, dict):
        memory_state = ckpt_raw.get("memory_mgr", nested.get("memory_mgr"))
        temporal_state = ckpt_raw.get("temporal_wrapper", nested.get("temporal_wrapper"))
    else:
        memory_state = ckpt_raw.get("memory_mgr")
        temporal_state = ckpt_raw.get("temporal_wrapper")
    return ckpt, train_cfg, memory_state, temporal_state
89
+
90
+
91
class KVSubmapManager:
    """CPU-resident store of per-submap tokens with loop-closure retrieval.

    Frames accumulate into a "current" submap. Each finalized submap keeps
    its refined tokens on CPU plus a normalized descriptor (linear projection
    of mean-pooled tokens) used for cosine-similarity retrieval of loop
    candidates. An adjacency graph links submaps retrieved together so that
    neighbours of a loop hit can be fetched as extra context.

    NOTE(review): this class is not referenced by the SLAM class below, which
    uses GraphGatedMemoryManager instead — presumably an older/alternative
    implementation; confirm before relying on it.
    """

    def __init__(
        self,
        submap_size=10,
        max_loop_submaps=5,
        loop_similarity_threshold=0.75,
        desc_dim=128,
        embed_dim=1024,
    ):
        # Max frames per submap.
        self.K = submap_size
        # Max number of historical submaps fetched as loop context.
        self.max_loop_submaps = max_loop_submaps
        # Cosine-similarity gate for accepting a loop candidate.
        self.loop_similarity_threshold = loop_similarity_threshold
        self.desc_dim = desc_dim
        self.embed_dim = embed_dim

        # Projects mean-pooled submap tokens to a compact descriptor.
        # NOTE(review): input width is 2*embed_dim — assumes incoming frame
        # tokens carry concatenated features of that width; confirm against
        # the model's token layout.
        self.desc_proj = torch.nn.Linear(2 * embed_dim, desc_dim)
        self.submap_tokens_cpu = {}      # sid -> token tensor (CPU)
        self.submap_descriptors = {}     # sid -> normalized descriptor (CPU)
        self.submap_frame_ids = {}       # sid -> list of global frame ids
        self.adjacency = {}              # sid -> set of covisible submap ids

        # Accumulator for the submap currently being built.
        self._curr_tokens = []
        self._curr_frame_ids = []
        self._current_submap_id = 0

    @property
    def current_submap_id(self):
        """Id that will be assigned to the submap currently accumulating."""
        return self._current_submap_id

    @property
    def current_frame_ids(self):
        """Copy of the frame ids accumulated into the current submap."""
        return list(self._curr_frame_ids)

    def to(self, device):
        """Move the descriptor head to *device*; returns self for chaining."""
        self.desc_proj = self.desc_proj.to(device)
        return self

    def load_memory_state_dict(self, state_dict):
        """Load only the ``desc_proj`` weights from a (possibly prefixed) state dict."""
        desc_state = OrderedDict()
        for k, v in state_dict.items():
            key = k
            if key.startswith("memory_mgr."):
                key = key[len("memory_mgr."):]
            if key.startswith("desc_proj."):
                desc_state[key[len("desc_proj."):]] = v
        if desc_state:
            # strict=False: tolerate partial/extra keys from older checkpoints.
            msg = self.desc_proj.load_state_dict(desc_state, strict=False)
            print(f"Loaded submap descriptor head: {msg}")

    def accumulate(self, frame_token, frame_id):
        """Append one frame's tokens (stored on CPU) to the current submap."""
        if frame_token.dim() == 2:
            # Normalize to a leading frame axis: (P, C) -> (1, P, C).
            frame_token = frame_token.unsqueeze(0)
        self._curr_tokens.append(frame_token.detach().cpu())
        self._curr_frame_ids.append(int(frame_id))

    def has_current_tokens(self):
        """True when the in-progress submap has at least one frame."""
        return len(self._curr_tokens) > 0

    def current_tokens(self, device):
        """Concatenate the in-progress submap's tokens onto *device*, or None."""
        if not self._curr_tokens:
            return None
        return torch.cat(self._curr_tokens, dim=0).to(device, non_blocking=True)

    def compute_descriptor(self, tokens):
        """L2-normalized descriptor of a submap: mean-pool then project."""
        if self.desc_proj.weight.device != tokens.device:
            # Lazily follow the tokens' device.
            self.desc_proj = self.desc_proj.to(tokens.device)
        # Pool over frame and token axes, keep a batch dim for the Linear.
        pooled = tokens.mean(dim=(0, 1), keepdim=False).unsqueeze(0).float()
        desc = self.desc_proj(pooled).squeeze(0)
        return torch.nn.functional.normalize(desc, dim=0)

    def select_context_submaps(self, curr_desc):
        """Pick history submaps to attend to for the current submap.

        Returns:
            tuple: ``(prev_sid, loop_ids, primary_sid, primary_sim)`` —
            the immediately preceding submap (or None), the deduped loop
            candidate ids, the best-matching loop submap (or None) and its
            cosine similarity (or None).
        """
        prev_sid = self._current_submap_id - 1
        if prev_sid not in self.submap_tokens_cpu:
            prev_sid = None

        hist_ids = sorted([sid for sid in self.submap_tokens_cpu.keys() if sid < self._current_submap_id])
        loop_ids = []

        primary_sid = None
        primary_sim = None
        if self.max_loop_submaps > 0:
            # The previous submap is always context, so exclude it here.
            candidates = [sid for sid in hist_ids if sid != prev_sid]
            if candidates:
                curr_desc_cpu = curr_desc.detach().cpu().float()
                scored = []
                for sid in candidates:
                    hist_desc = self.submap_descriptors[sid].float()
                    sim = torch.nn.functional.cosine_similarity(
                        curr_desc_cpu.unsqueeze(0),
                        hist_desc.unsqueeze(0),
                        dim=-1,
                    ).item()
                    scored.append((sim, sid))
                scored.sort(reverse=True)
                if scored:
                    primary_sim, primary_sid = scored[0]
                    if primary_sim >= self.loop_similarity_threshold:
                        loop_ids.append(primary_sid)
                        # Expand to covisible neighbours of the primary hit,
                        # up to the loop budget.
                        for nid in sorted(self.adjacency.get(primary_sid, set())):
                            if len(loop_ids) >= self.max_loop_submaps:
                                break
                            if nid != self._current_submap_id and nid in self.submap_tokens_cpu and nid not in loop_ids:
                                loop_ids.append(nid)
                    else:
                        # Below threshold: no accepted loop closure.
                        primary_sid = None

        # Preserve order while removing duplicates.
        deduped = []
        seen = set()
        for sid in loop_ids:
            if sid not in seen:
                deduped.append(sid)
                seen.add(sid)
        return prev_sid, deduped, primary_sid, primary_sim

    def build_backend_tokens(self, device, prev_sid=None, loop_ids=None):
        """Assemble [prev | current | loops] tokens on *device* for the backend.

        Returns:
            tuple: ``(combined_tokens, meta)`` where *meta* records the token
            counts needed by :meth:`finalize_current_submap` to split the
            refined output back into its source submaps.
        """
        loop_ids = [] if loop_ids is None else list(loop_ids)
        parts = []
        n_prev = 0
        loop_token_counts = []

        if prev_sid is not None and prev_sid in self.submap_tokens_cpu:
            prev_tokens = self.submap_tokens_cpu[prev_sid].to(device, non_blocking=True)
            parts.append(prev_tokens)
            n_prev = prev_tokens.shape[0]

        curr_tokens = self.current_tokens(device)
        n_curr = curr_tokens.shape[0]
        parts.append(curr_tokens)

        for sid in loop_ids:
            loop_tokens = self.submap_tokens_cpu[sid].to(device, non_blocking=True)
            parts.append(loop_tokens)
            loop_token_counts.append(loop_tokens.shape[0])

        combined = torch.cat(parts, dim=0)
        meta = {
            "n_prev": n_prev,
            "n_curr": n_curr,
            "curr_frame_ids": list(self._curr_frame_ids),
            "loop_token_counts": loop_token_counts,
        }
        return combined, meta

    def _store_refined_submap(self, sid, refined_tokens):
        """Overwrite a stored submap's tokens and descriptor (CPU copies)."""
        self.submap_tokens_cpu[sid] = refined_tokens.detach().cpu()
        self.submap_descriptors[sid] = self.compute_descriptor(refined_tokens).detach().cpu()

    def finalize_current_submap(
        self,
        hidden_B,
        n_prev,
        n_curr,
        prev_sid=None,
        loop_ids=None,
        loop_token_counts=None,
        primary_sid=None,
    ):
        """Split refined backend output back into submaps and commit them.

        *hidden_B* is expected in the same [prev | current | loops] order
        produced by :meth:`build_backend_tokens`. Also records adjacency for
        an accepted loop (*primary_sid*) and resets the accumulator, bumping
        the current submap id.
        """
        loop_ids = [] if loop_ids is None else list(loop_ids)
        loop_token_counts = [] if loop_token_counts is None else list(loop_token_counts)

        offset = 0
        if prev_sid is not None and n_prev > 0:
            refined_prev = hidden_B[:n_prev]
            self._store_refined_submap(prev_sid, refined_prev)
            offset = n_prev

        refined_curr = hidden_B[offset:offset + n_curr]
        offset += n_curr

        for sid, count in zip(loop_ids, loop_token_counts):
            refined_loop = hidden_B[offset:offset + count]
            self._store_refined_submap(sid, refined_loop)
            offset += count

        sid = self._current_submap_id
        self.submap_tokens_cpu[sid] = refined_curr.detach().cpu()
        self.submap_descriptors[sid] = self.compute_descriptor(refined_curr).detach().cpu()
        self.submap_frame_ids[sid] = list(self._curr_frame_ids)

        if primary_sid is not None and primary_sid in self.submap_tokens_cpu:
            # Undirected covisibility edge between current and loop submap.
            self.adjacency.setdefault(sid, set()).add(primary_sid)
            self.adjacency.setdefault(primary_sid, set()).add(sid)

        # Start the next submap.
        self._curr_tokens = []
        self._curr_frame_ids = []
        self._current_submap_id += 1
277
+
278
+
279
class SLAM:
    """Online SLAM driver around a SLAMFormer model.

    Per input frame (see :meth:`step`):
      1. :meth:`kf_detect` — pairwise pose against the last keyframe; the
         frame is kept only if the translation exceeds ``kf_th``.
      2. :meth:`frontend` — encode the keyframe into tokens and append them
         to the token map and the memory manager.
      3. :meth:`backend` — once the memory manager reports a complete
         submap, jointly refine its tokens (loop closure / temporal embed).

    Inference hyper-parameters come from the training config stored in the
    checkpoint, with explicit constructor arguments taking precedence.
    Requires CUDA: model and managers are moved to ``'cuda'``.
    """

    def __init__(
        self,
        outdir='output/tmp',
        kf_th=0.1,
        bn_every=None,
        vis=False,
        save_gmem=True,
        ckpt_path='path/to/ckpt.pth',
        target_size=518,
        retention_ratio=None,
        loop_mask_mode=None,
        submap_train_mode=None,
        submap_retrieval_topk=None,
        submap_fetch_source=None,
        submap_descriptor_source=None,
        max_recursive_submaps=None,
    ):
        """Load the checkpoint, build model + memory manager, move to GPU.

        Args:
            outdir: Directory for the final point cloud / trajectory export.
            kf_th: Minimum translation distance to accept a keyframe.
            bn_every: Submap size / backend period; None -> checkpoint config.
            vis: Enable live Rerun visualization.
            save_gmem: Keep the token map on CPU (True) or GPU (False).
            ckpt_path: Checkpoint path (loaded with weights_only=False).
            target_size: Longer-side image size fed to the model.
            retention_ratio..max_recursive_submaps: override the matching
                training-config entries when not None.
        """

        self.outdir = outdir
        self.kf_th=kf_th
        self.save_gmem = save_gmem
        self.bn_every=bn_every
        self.vis = vis
        self.ckpt_path = ckpt_path
        self.target_size = target_size

        # Per-stage timing logs (seconds); summarized in terminate().
        self.times = []
        self.kf_time = []
        self.backend_time = []

        # NOTE: weights_only=False executes arbitrary pickled code — only
        # load trusted checkpoints.
        ckpt_raw = torch.load(self.ckpt_path, map_location='cpu', weights_only=False)
        _, self.train_cfg, _, _ = extract_checkpoint_parts(ckpt_raw)

        # Explicit arguments win over the checkpoint's training config.
        self.bn_every = int(bn_every if bn_every is not None else get_cfg_value(self.train_cfg, 'submap_size', 10))
        self.retention_ratio = float(
            retention_ratio if retention_ratio is not None else get_cfg_value(self.train_cfg, 'retention_ratio', 0.5)
        )
        self.enable_loop = bool(get_cfg_value(self.train_cfg, 'enable_loop', False))
        self.enable_temporal = bool(get_cfg_value(self.train_cfg, 'enable_temporal', False))
        self.tbptt_window = int(get_cfg_value(self.train_cfg, 'tbptt_window', 0))
        self.max_recursive_submaps = int(
            max_recursive_submaps if max_recursive_submaps is not None else get_cfg_value(self.train_cfg, 'max_recursive_submaps', 5)
        )
        self.desc_dim = int(get_cfg_value(self.train_cfg, 'desc_dim', 128))
        self.loop_mask_mode = loop_mask_mode if loop_mask_mode is not None else get_cfg_value(self.train_cfg, 'loop_mask_mode', 'hard_top1')
        self.soft_mask_temperature = float(get_cfg_value(self.train_cfg, 'soft_mask_temperature', 0.25))
        self.soft_mask_bias = float(get_cfg_value(self.train_cfg, 'soft_mask_bias', 0.2))
        self.submap_train_mode = (
            submap_train_mode if submap_train_mode is not None else get_cfg_value(self.train_cfg, 'submap_train_mode', 'full_token')
        )
        self.submap_retrieval_topk = int(
            submap_retrieval_topk if submap_retrieval_topk is not None else get_cfg_value(self.train_cfg, 'submap_retrieval_topk', 0)
        )
        self.submap_fetch_source = (
            submap_fetch_source if submap_fetch_source is not None else get_cfg_value(self.train_cfg, 'submap_fetch_source', 'frontend')
        )
        self.submap_descriptor_source = (
            submap_descriptor_source
            if submap_descriptor_source is not None
            else get_cfg_value(self.train_cfg, 'submap_descriptor_source', 'frontend')
        )
        self.temporal_embed_mode = get_cfg_value(self.train_cfg, 'temporal_embed_mode', 'learned')

        # model params
        self.model = SLAMFormer(retention_ratio=self.retention_ratio, bn_every=self.bn_every)
        self.model = self.model.eval()
        self.memory_mgr = GraphGatedMemoryManager(
            submap_size=self.bn_every,
            max_recursive_submaps=self.max_recursive_submaps,
            desc_dim=self.desc_dim,
            embed_dim=self.model.dec_embed_dim,
            loop_mask_mode=self.loop_mask_mode,
            soft_mask_temperature=self.soft_mask_temperature,
            soft_mask_bias=self.soft_mask_bias,
            retain_history_grad=False,
            submap_train_mode=self.submap_train_mode,
            submap_retrieval_topk=self.submap_retrieval_topk,
            submap_fetch_source=self.submap_fetch_source,
            submap_descriptor_source=self.submap_descriptor_source,
        )
        self.temporal_wrapper = (
            TemporalEmbedWrapper(
                embed_dim=self.model.dec_embed_dim,
                max_frames=5000,
                mode=self.temporal_embed_mode,
            )
            if self.enable_temporal else None
        )
        self.load_model(ckpt_raw)
        del ckpt_raw
        self.model.eval()
        self.model.to('cuda')
        self.memory_mgr.to('cuda')
        self.memory_mgr.eval()
        if self.temporal_wrapper is not None:
            self.temporal_wrapper.to('cuda')
            self.temporal_wrapper.eval()
        self.memory_mgr.reset()
        print(
            f"Resolved inference config: submap_size={self.bn_every}, retention_ratio={self.retention_ratio}, "
            f"enable_loop={self.enable_loop}, enable_temporal={self.enable_temporal}, loop_mask_mode={self.loop_mask_mode}, "
            f"submap_train_mode={self.submap_train_mode}, submap_retrieval_topk={self.submap_retrieval_topk}, "
            f"submap_fetch_source={self.submap_fetch_source}, submap_descriptor_source={self.submap_descriptor_source}, "
            f"max_recursive_submaps={self.max_recursive_submaps}"
        )

        # SLAM params
        self.fid = -1      # last processed frame index (all frames)
        self.kid = -1      # last accepted keyframe index
        self.kfids = []    # frame ids of accepted keyframes
        self.last_kfid = 0
        self.kf_timestamps = []
        # frontend
        self.frontend_times = 0
        # Token map: per-keyframe tokens, optionally backend-refined copy.
        self.map = None
        self.map_opt = None

        self.signal_backend = False
        self.backend_every = self.bn_every #10
        #
        self.extrins = []
        self.intrins = []
        self.frames = []
        self.kf_frames = []

        #
        self.K = None
        self.update_K = False

        # vis
        if self.vis:
            self.entity="world"
            rr.init("SLAM", spawn=True)
            rr.log(self.entity, rr.ViewCoordinates.RIGHT_HAND_Z_UP)
            self.Twk = np.eye(4)
            self.K = np.eye(3)

    def load_model(self, ckpt_raw):
        """Load model / memory-manager / temporal-wrapper weights from a raw checkpoint."""
        ckpt, _, memory_state, temporal_state = extract_checkpoint_parts(ckpt_raw)
        ckpt = utils.strip_module(ckpt)
        # strict=False tolerates missing/extra keys across training variants.
        self.model.load_state_dict(ckpt, strict=False)
        if memory_state is not None:
            memory_state = strip_prefix(strip_module(memory_state), 'memory_mgr.')
            msg = self.memory_mgr.load_state_dict(memory_state, strict=False)
            print(f"Loaded memory manager: {msg}")
        if self.temporal_wrapper is not None and temporal_state is not None:
            temporal_state = strip_prefix(strip_module(temporal_state), 'temporal_wrapper.')
            msg = self.temporal_wrapper.load_state_dict(temporal_state, strict=False)
            print(f"Loaded temporal wrapper: {msg}")
        del ckpt

    @property
    def time(self):
        """Wall-clock time after a CUDA sync, so GPU work is fully counted."""
        torch.cuda.synchronize()
        return time.perf_counter()

    def _estimate_kf_pose(self, frame_pair, use_amp=True):
        """Run the keyframe tracker on a (last_kf, frame) pair.

        Returns the raw tokens, the predicted camera poses, and their
        inverses (world-to-camera extrinsics).
        """
        token = self.model.KFT(frame_pair, use_amp=use_amp)
        res = self.model.extract(token, cam_only=True, use_amp=use_amp)
        camera_pose = res['camera_poses']
        extrinsic = torch.linalg.inv(camera_pose)
        return token, camera_pose, extrinsic

    def kf_detect(self, image):
        """Decide whether *image* should become a new keyframe.

        The very first frame is always a keyframe. Otherwise the relative
        translation to the last keyframe must exceed ``kf_th``. Non-finite
        poses under AMP trigger one float32 retry; persistent non-finite
        results skip the frame (returns False).
        """
        if self.kid == -1:
            self.extrins.append(torch.eye(4))
            return True

        frame = utils.load_image(image, self.target_size)
        _,H,W = frame.shape

        st = self.time #time.perf_counter()
        frame_pair = torch.stack([self.kf_frames[-1], frame.cuda()])
        token, camera_pose, extrinsic = self._estimate_kf_pose(frame_pair, use_amp=True)
        if (not torch.isfinite(camera_pose).all()) or (not torch.isfinite(extrinsic).all()):
            print("[warning] Non-finite keyframe pose under AMP; retrying keyframe detection in float32.")
            token, camera_pose, extrinsic = self._estimate_kf_pose(frame_pair, use_amp=False)
            if (not torch.isfinite(camera_pose).all()) or (not torch.isfinite(extrinsic).all()):
                print("[warning] Non-finite keyframe pose after float32 retry; skipping this frame for keyframe detection.")
                return False
        if self.vis:
            # scale the pose to global
            #z = res['local_points'][0,0,:,:,-1].cpu().numpy()
            # NOTE(review): 'depth_lask_kf' looks like a typo for
            # 'depth_last_kf' (set in extract()); scale is hard-coded to 1
            # either way, so the branch is currently a no-op.
            if not hasattr(self,'depth_lask_kf'):
                scale=1
            else:
                scale=1 #np.median(self.depth_last_kf/(z+1e-6))
            if extrinsic.shape[1] > 1:
                extrinsic_ref=extrinsic.cpu()[0,-2]
            extrinsic = extrinsic.cpu()[0,-1]
            # Chain the pairwise relative motion onto the last keyframe pose.
            Tki = torch.linalg.inv(camera_pose[0,0])@camera_pose[0,1]
            Tki = Tki.cpu().numpy()
            self.Twi = self.Twk@Tki
            K44 = np.eye(4)
            K44[:3,:3] = self.K
            log_camera("camera",self.Twi, K44, kfd=True)
            # make the window follow camera
            log_window(f"{self.entity}",np.linalg.inv(self.Twi), K44)

        else:
            if extrinsic.shape[1] > 1:
                extrinsic_ref=extrinsic.cpu()[0,-2]
            extrinsic = extrinsic.cpu()[0,-1]
            self.kft_extrinsic_ref = torch.eye(4)#extrinsic_ref

        # Translation between the candidate frame and the reference frame.
        dist = torch.sqrt(torch.sum((extrinsic[:3,3] - extrinsic_ref[:3,3])**2))
        if not torch.isfinite(dist):
            print("[warning] Non-finite keyframe distance after pose recovery; skipping this frame.")
            return False
        isKF = dist > self.kf_th

        print(dist)

        if isKF:
            self.extrins.append(extrinsic)
        return isKF

    def frontend(self, image):
        """Encode *image* as a keyframe (if it passes kf_detect) into the token map."""

        if self.vis:
            # BGR -> RGB for display.
            rr.log("image", rr.Image(image[:,:,::-1]))#,static=True)

        self.fid += 1
        print('Frame', self.fid)
        # run kf detector
        st = self.time
        enough_disparity = self.kf_detect(image)
        self.kf_time.append(self.time-st)
        if not enough_disparity:
            return False

        torch.cuda.empty_cache()
        # run T-frontend
        H_,W_,_ = image.shape
        frame = utils.load_image(image, self.target_size)
        # NOTE(review): kf_detect unpacks frame.shape as (_, H, W) but here
        # it is unpacked as (H, W, _) — one of the two is wrong; confirm the
        # layout returned by utils.load_image.
        self.H,self.W,_ = frame.shape
        st = self.time
        self.last_kf = frame.cuda()
        self.kf_frames.append(self.last_kf)
        self.last_kfid = self.fid
        self.frames.append(self.last_kf.clone())
        self.kid += 1
        print("[italic purple] # KEYFRAME", self.kid)
        self.kf_timestamps.append(self.cur_timestamp)
        frame = frame.cuda()
        st = self.time

        if self.nkf == 1:
            # First keyframe: tokens are produced together with the second.
            pass
        elif self.nkf == 2:
            token = self.model.frontendT(torch.stack([self.kf_frames[0],frame]))
            self.map_add(token)
        else:
            token = self.model.frontendT(frame)
            print(self.time-st)

            self.map_add(token)

        self.kfids.append(self.fid)
        self.times.append(self.time-st)
        torch.cuda.empty_cache()

        # send signal to backend
        self.frontend_times += 1
        if self.memory_mgr.submap_complete:
            self.signal_backend = True

        if self.vis and self.map is not None:
            st = time.time()
            # Merge refined (map_opt) and fresh tokens for display.
            map_before_bn = None
            if self.map_opt is None:
                map_before_bn = self.map
            else:
                S = self.map.shape[0]
                S_oldopt = self.map_opt.shape[0]

                map_before_bn = torch.cat([self.map_opt, self.map[S_oldopt:]],axis=0)
            if self.nkf == 2:
                ps,cs,confs,poses = self.extract(self.map)

            else:
                ps,cs,confs,poses = self.extract(self.map[-1:])

            self.vis_mem = [ps,cs,confs,poses]

            # Drop the least-confident 15% of points for display.
            conf_threshold = np.percentile(confs, 15)
            msk = confs>=conf_threshold

            ps = ps[msk]
            cs = cs[msk]
            K44 = np.eye(4)
            K44[:3,:3] = self.K

            if self.nkf == 2:
                log_camera(f"{self.entity}/camera_kf/0",poses[0], K44)
                log_camera(f"{self.entity}/camera_kf/1",poses[1], K44)

                rr.log(f"{self.entity}/lines/0to1", rr.LineStrips3D([poses[:,:3,3].tolist()],colors=[0,0,255],radii=[0.005]))

                self.last_kf_pose = poses[1]
            else:
                log_camera(f"{self.entity}/camera_kf/{self.nkf-1}",poses.reshape(4,4), K44)
                rr.log(f"{self.entity}/lines/{self.nkf-2}to{self.nkf-1}", rr.LineStrips3D([np.stack([self.last_kf_pose[:3,3],poses[0,:3,3]]).tolist()],colors=[0,0,255],radii=[0.005]))

                self.last_kf_pose = poses[0]

            rr.log(
                f"{self.entity}/pointclouds/{self.nkf}",
                rr.Points3D(ps, colors=cs, radii=0.01),
            )

            print('log', time.time()-st)

            self.Twk = poses[-1].reshape(4,4)

    def backend(self, final=False):
        """Refine the just-completed submap's tokens and merge into map_opt.

        No-ops unless frontend() raised ``signal_backend``. *final* is
        accepted for the terminate() flush but not otherwise used here.
        """
        if not self.signal_backend:
            return

        torch.cuda.empty_cache()
        if not self.memory_mgr._curr_tokens:
            self.signal_backend = False
            return

        # Drop stale backend caches before re-running joint refinement.
        if hasattr(self.model, 'fkv'):
            del self.model.fkv
        self.model.reset_backend_cache = True
        self.model._prune_idx_cache = None
        self.model._prune_idx_cache_N = 0
        self.model._prune_idx_cache_hw = None

        print('Backending...', self.nkf, 'KFs')
        st = time.perf_counter()
        map_optimed, loop_gate, backend_meta = self.memory_mgr.finalize_submap(
            model=self.model,
            device=torch.device('cuda'),
            temporal_wrapper=self.temporal_wrapper,
            enable_temporal_embed=self.enable_temporal,
            enable_loop_closure=self.enable_loop,
            tbptt_window=self.tbptt_window,
        )
        backend_take = time.perf_counter()-st
        self.backend_time.append(backend_take)
        print(
            f'Submap backend: sid={self.memory_mgr.current_submap_id - 1}, '
            f'prev_tokens={backend_meta["n_prev"]}, retrieved_tokens={backend_meta["n_retrieved"]}, '
            f'loop_gate={float(loop_gate.squeeze().detach().cpu())}'
        )
        print('backend_take', backend_take)
        torch.cuda.empty_cache()

        # Grow map_opt to cover all frontend tokens, then overwrite the
        # refined frames in place.
        map_cpu = self.map.detach().cpu() if self.map.is_cuda else self.map
        if self.map_opt is None:
            self.map_opt = map_cpu.clone()
        elif self.map_opt.shape[0] < map_cpu.shape[0]:
            self.map_opt = torch.cat([self.map_opt, map_cpu[self.map_opt.shape[0]:]], dim=0)

        for local_idx, frame_id in enumerate(backend_meta["frame_ids"]):
            self.map_opt[int(frame_id)] = map_optimed[local_idx].detach().cpu()

        self.signal_backend = False
        torch.cuda.empty_cache()

        if self.vis:
            ps,cs,confs,poses = self.extract(self.map_opt)
            self.vis_mem = [ps,cs,confs,poses]
            conf_threshold = np.percentile(confs, 15)
            msk = confs>=conf_threshold

            ps = ps[msk]
            cs = cs[msk]

            # Clear per-keyframe incremental clouds, then re-log everything
            # from the refined map.
            for s in range(self.nkf+1):
                rr.log(f"{self.entity}/pointclouds/{s}", rr.Points3D(np.array([])))

            for s in range(self.nkf):
                K44 = np.eye(4)
                K44[:3,:3] = self.K
                log_camera(f"{self.entity}/camera_kf/{s}",poses[s].reshape(4,4), K44, update=True)

            for s in range(1, self.nkf):
                rr.log(f"{self.entity}/lines/{s-1}to{s}", rr.LineStrips3D([poses[s-1:s+1,:3,3].tolist()],colors=[0,0,255],radii=[0.005]))

            rr.log(
                f"{self.entity}/pointclouds/{self.nkf}",
                rr.Points3D(ps, colors=cs, radii=0.01),
            )
            self.last_kf_pose = poses[-1]

    def step(self, timestamp, image):
        """Process one input frame: frontend then (possibly) backend."""
        if timestamp is None:
            # Fall back to a running frame counter as the timestamp.
            self.cur_timestamp = self.fid+1
        else:
            self.cur_timestamp = timestamp

        self.frontend(image)

        self.backend()

    def map_add(self, token_kf):
        """Append keyframe tokens to the token map and the memory manager."""
        token_kf = token_kf.detach()
        start_idx = 0 if self.map is None else int(self.map.shape[0])
        for i, tok in enumerate(token_kf):
            self.memory_mgr.accumulate(tok.unsqueeze(0), start_idx + i)
        if self.map is None:
            self.map = token_kf.cpu() if self.save_gmem else token_kf #[tok.cpu() for tok in token_kf]
        else:
            if self.save_gmem:
                self.map = torch.cat([self.map, token_kf.cpu()],axis=0) # S,P,C
            else:
                self.map = torch.cat([self.map, token_kf],axis=0) # S,P,C

    @property
    def nkf(self):
        """Number of accepted keyframes so far."""
        return self.kid+1

    @property
    def nf(self):
        """Number of processed input frames so far."""
        return self.fid+1

    def terminate(self):
        """Flush any pending submap, print timing stats, export results."""
        if self.memory_mgr._curr_tokens:
            self.signal_backend = True
            self.backend(final=True)

        print(self.kf_time)
        print(self.times)
        print(self.backend_time)
        print('frontend take', np.mean(self.times))
        print('KFT')
        print('total', np.sum(self.kf_time), 'FPS', float(len(self.kf_time))/np.sum(self.kf_time))
        print('FT')
        print('total', np.sum(self.times), 'FPS', float(len(self.times))/np.sum(self.times))
        print('BT')
        print('total', np.sum(self.backend_time), 'FPS', float(len(self.backend_time))/np.sum(self.backend_time))
        print('Summary')
        print('total', np.sum(self.kf_time)+np.sum(self.times)+np.sum(self.backend_time), 'FPS', float(len(self.kf_time))/(np.sum(self.kf_time)+np.sum(self.times)+np.sum(self.backend_time)))
        # Prefer the backend-refined map when it exists.
        map_to_save = self.map_opt if self.map_opt is not None else self.map
        if map_to_save is None:
            print(
                f"[warning] No map was built for this sequence (nkf={self.nkf}). "
                "Skipping result export to avoid crashing on a single-keyframe run."
            )
            return None
        self.save_result(f'{self.outdir}/final', map_to_save)

    def extract(self, map_all=None):
        """Decode tokens into (points, colors, confidences, camera poses).

        Used for visualization; colors come from the last S cached frames
        (BGR -> RGB flip). Also caches the last keyframe's depth map.
        """
        result = self.model.extract(map_all.cuda())

        pts = result['points'].cpu().numpy() # 1,S,H,W,3
        local_pts = result['local_points'].cpu().numpy() # 1,S,H,W,3
        _,S,H,W,_ = pts.shape
        conf = result['conf'].cpu().numpy()
        point_clouds = [pts[0,s] for s in range(S)]
        #conf_threshold = np.percentile(conf, 15)
        #confs = [conf[0,s]>=conf_threshold for s in range(S)]
        colors = torch.stack(self.frames[-S:]).permute(0,2,3,1).reshape(-1,3).cpu().numpy()[:,::-1] # S,H,W,C
        confs = conf.reshape(-1)

        camera_pose = result['camera_poses'].cpu().numpy()[0] # S,4,4
        pts = pts.reshape(-1,3)
        colors = colors.reshape(-1,3)

        # set depth for the last kf
        self.depth_last_kf = local_pts[0,-1,:,:,-1]

        return pts, colors, confs, camera_pose

    def save_result(self, output_path = 'output/tmp', map_all=None, traj=True):
        """Decode the token map in chunks and export cloud + trajectory.

        Writes ``<output_path>.ply`` (confidence-filtered cloud),
        ``<output_path>_traj.txt`` (TUM-style poses) and per-frame clouds
        under ``<output_path>_pc``. Falls back to the best available map
        when *map_all* is None; returns the decoded tensors or None.
        """
        print(self.kfids)

        if map_all is None:
            map_all = self.map_opt if self.map_opt is not None else self.map
        if map_all is None:
            print(f"[warning] save_result() called with no map data; skipping export for {output_path}.")
            return None

        # Chunk-process to avoid OOM on long sequences
        # (our finetuning removed torch.no_grad() from model internals)
        S_total = map_all.shape[0]
        chunk_size = 50 # process 50 frames at a time
        all_pts, all_conf, all_poses = [], [], []

        for start in range(0, S_total, chunk_size):
            end = min(start + chunk_size, S_total)
            chunk = map_all[start:end].cuda()
            torch.cuda.empty_cache()
            result_chunk = self.model.extract(chunk)
            all_pts.append(result_chunk['points'].cpu())
            all_conf.append(result_chunk['conf'].cpu())
            all_poses.append(result_chunk['camera_poses'].cpu())
            del result_chunk, chunk
            torch.cuda.empty_cache()

        pts = torch.cat(all_pts, dim=1).numpy() # 1,S,H,W,3
        conf = torch.cat(all_conf, dim=1).numpy()
        camera_pose = torch.cat(all_poses, dim=1) # 1,S,4,4

        _,S,H,W,_ = pts.shape
        point_clouds = [pts[0,s] for s in range(S)]
        # Keep points above the global 15th confidence percentile.
        conf_threshold = np.percentile(conf, 15)
        confs = [conf[0,s]>=conf_threshold for s in range(S)]

        colors = torch.stack(self.frames).permute(0,2,3,1).reshape(-1,3).cpu().numpy()[:,::-1] # S,H,W,C
        msk = np.stack(confs).reshape(-1)
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(pts.reshape(-1,3).astype(np.float64)[msk])
        pcd.colors = o3d.utility.Vector3dVector(colors.reshape(-1,3).astype(np.float64)[msk])
        #downpcd = pcd.voxel_down_sample(voxel_size=0.005)
        o3d.io.write_point_cloud(f"{output_path}.ply", pcd)
        poses = camera_pose[0].numpy()

        self.write_poses_to_file(f"{output_path}_traj.txt", poses, self.kf_timestamps)
        self.save_framewise_pointclouds(f"{output_path}_pc", point_clouds, self.kf_timestamps, confs)

        return {'points': torch.from_numpy(pts), 'conf': torch.from_numpy(conf), 'camera_poses': camera_pose}

    def write_poses_to_file(self, filename, poses, frame_ids):
        """Write poses as TUM lines: ``timestamp tx ty tz qx qy qz qw``."""

        with open(filename, "w") as f:
            # NOTE(review): assert is stripped under `python -O`; an explicit
            # raise would be more robust for this validation.
            assert len(poses) == len(frame_ids), "Number of provided poses and number of frame ids do not match"
            for frame_id, pose in zip(frame_ids, poses):
                x, y, z = pose[0:3, 3]
                rotation_matrix = pose[0:3, 0:3]
                quaternion = R.from_matrix(rotation_matrix).as_quat() # x, y, z, w
                output = np.array([float(frame_id), x, y, z, *quaternion])
                f.write(" ".join(f"{v:.8f}" for v in output) + "\n")

    def save_framewise_pointclouds(self, filename, pointclouds, frame_ids, conf_masks):
        """Save one ``.npz`` (pointcloud + confidence mask) per frame under *filename*/."""
        os.makedirs(filename, exist_ok=True)
        # NOTE(review): the loop variable shadows the `conf_masks` parameter
        # — harmless here but worth renaming.
        for frame_id, pointcloud, conf_masks in zip(frame_ids, pointclouds, conf_masks):
            # save pcd as numpy array
            # NOTE(review): "(unknown)" below looks like a redacted
            # placeholder — the original almost certainly interpolated
            # {filename} (the directory created above); confirm and restore.
            np.savez(f"(unknown)/{frame_id}.npz", pointcloud=pointcloud, mask=conf_masks)
826
+
827
+
828
def get_parser():
    """Build the CLI for the SLAM-Former demo and parse ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed command-line arguments. Options that
        default to None fall back to the values stored in the checkpoint's
        training config inside SLAM.__init__.
    """
    p = argparse.ArgumentParser(description="SLAM-Former demo")
    p.add_argument("--ckpt_path", type=str, default="path/to/checkpoint.pth.model", help="Path to the checkpoint")
    p.add_argument("--image_folder", type=str, default="path/to/image/folder", help="Path to folder containing images")
    p.add_argument("--target_size", type=int, default=518, help="the target size of image(longer side)")
    p.add_argument("--output_dir", type=str, default="outputs/tmp", help="Path to save the output")
    p.add_argument("--stride", type=int, default=1, help="Frame stride for subsampling the input sequence")
    p.add_argument("--kf_th", type=float, default=0.1, help="Keyframe selection threshold (minimum translation distance)")
    p.add_argument("--retention_ratio", type=float, default=None, help="KV Pruning retention ratio")
    p.add_argument("--bn_every", type=int, default=None, help="Run backend optimization every N keyframes")
    p.add_argument("--loop_mask_mode", type=str, default=None, choices=["hard_top1", "soft_all"], help="Override loop retrieval masking mode")
    p.add_argument("--submap_train_mode", type=str, default=None, choices=["full_token", "top5_dual_queue"], help="Override submap queue mode")
    p.add_argument("--submap_retrieval_topk", type=int, default=None, help="Override number of historical submaps fetched in soft_all mode")
    p.add_argument("--submap_fetch_source", type=str, default=None, choices=["frontend", "backend"], help="Override token source used for retrieval")
    p.add_argument("--submap_descriptor_source", type=str, default=None, choices=["frontend", "backend"], help="Override descriptor source used for retrieval")
    p.add_argument("--max_recursive_submaps", type=int, default=None, help="Override recursive covisibility fetch limit")
    p.add_argument("--vis", action="store_true", help="Enable real-time visualization with Rerun")
    p.add_argument("--resize_rate", type=float, default=1, help="Resize rate for input images before processing")
    return p.parse_args()
849
+
850
+
851
if __name__ == '__main__':
    args = get_parser()
    image_folder = args.image_folder
    outdir = args.output_dir
    os.makedirs(outdir, exist_ok=True)

    # Hard-coded pinhole intrinsics for the two known dataset families;
    # other datasets run without a calibrated K (vis overlays only).
    if 'tum' in args.image_folder:
        fx = 525.0  # focal length x
        fy = 525.0  # focal length y
        cx = 319.5  # optical center x
        cy = 239.5  # optical center y
        K = np.eye(3)
        K[0, 0] = fx
        K[1, 1] = fy
        K[0, 2] = cx
        K[1, 2] = cy
    elif 'Replica' in args.image_folder:
        fx = 600.  # focal length x
        fy = 600.0  # focal length y
        cx = 599.5  # optical center x
        cy = 339.5  # optical center y
        K = np.eye(3)
        K[0, 0] = fx
        K[1, 1] = fy
        K[0, 2] = cx
        K[1, 2] = cy
    else:
        K = None

    # Collect image files, skipping depth maps, text files and database files.
    print(f"Loading images from {image_folder}...")
    image_names = [f for f in glob.glob(os.path.join(image_folder, "*"))
                   if "depth" not in os.path.basename(f).lower() and "txt" not in os.path.basename(f).lower()
                   and "db" not in os.path.basename(f).lower()]
    image_names = utils.sort_images_by_number(image_names)

    # Derive a numeric frame id / timestamp from each file name.
    frame_ids = []
    for path in image_names:
        filename = os.path.basename(path)
        match = re.search(r'\d+(?:\.\d+)?', filename)  # matches integers and decimals
        if match:
            frame_ids.append(float(match.group()))
        else:
            # Fix: include the offending file name in the error message
            # (the placeholder had been lost from the format string).
            raise ValueError(f"No number found in image name: {filename}")

    print(f"Found {len(image_names)} images")

    print('resize image', args.resize_rate)

    slam = SLAM(
        outdir=outdir,
        kf_th=args.kf_th,
        bn_every=args.bn_every,
        vis=args.vis,
        ckpt_path=args.ckpt_path,
        target_size=args.target_size,
        retention_ratio=args.retention_ratio,
        loop_mask_mode=args.loop_mask_mode,
        submap_train_mode=args.submap_train_mode,
        submap_retrieval_topk=args.submap_retrieval_topk,
        submap_fetch_source=args.submap_fetch_source,
        submap_descriptor_source=args.submap_descriptor_source,
        max_recursive_submaps=args.max_recursive_submaps,
    )

    slam.K = K

    # Feed every stride-th frame through the pipeline, then flush/export.
    for frame_id, image_name in zip(frame_ids[::args.stride], image_names[::args.stride]):
        img = cv2.imread(image_name)

        if args.resize_rate != 1:
            H, W, _ = img.shape
            # Fix: the flag must be passed as the keyword `interpolation` —
            # as a third positional argument it lands in cv2.resize's `dst`
            # parameter and the interpolation mode is silently ignored.
            img = cv2.resize(img, (int(W*args.resize_rate), int(H*args.resize_rate)), interpolation=cv2.INTER_CUBIC)
        slam.step(frame_id, img)
    result = slam.terminate()
927
+
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/download_data.sh ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# ──────────────────────────────────────────────────────────
# Local download checklist template for SLAM-Former fine-tuning
#
# README-backed links:
#   - ARKitScenes / MVS-Synth / ScanNet  -> Hugging Face SLF dataset tree
#     (ScanNet ships as a split archive: processed_scannetv2.zip.part.aa/.ab/.ac/.ad)
#   - HyperSim                           -> Hugging Face preprocessed_Hypersim tree
#   - ScanNet++ / BlendedMVS / MegaDepth -> README says "coming soon" or lists no
#     direct archive command here
#
# The script is split into per-dataset toggles so missing datasets can be
# filled in one at a time without overwriting already-extracted folders, e.g.:
#   DOWNLOAD_ARKITSCENES=1 DOWNLOAD_SCANNET=1 bash slam/download_data.sh
# ──────────────────────────────────────────────────────────
set -euo pipefail

# Fixed local layout.
PROJECT_DIR="/var/scratch/qzhang2/SLAM-Former"
DATA_DIR="$PROJECT_DIR/data/train"
CKPT_DIR="$PROJECT_DIR/ckpt"
LOG_FILE="$PROJECT_DIR/download.log"

mkdir -p "$DATA_DIR" "$CKPT_DIR"

# Timestamped logger; every message is also appended to $LOG_FILE.
log() { echo "[$(date '+%Y-%m-%d %H:%M:%S')] $*" | tee -a "$LOG_FILE"; }

# Remote source locations.
HF_BASE="https://huggingface.co/datasets/KevinConnorLee/SLF/resolve/main"
HF_HYPERSIM_BASE="https://huggingface.co/datasets/KevinConnorLee/preprocessed_Hypersim/resolve/main"
HF_CKPT="https://huggingface.co/Jarrome/SLAM-Former/resolve/main/518/checkpoint-10.pth.model"
HF_ARKITSCENES_REPO="https://huggingface.co/datasets/Pointcept/arkitscenes-compressed"
HF_SCANNETPP_REPO="https://huggingface.co/datasets/Pointcept/scannetpp-compressed"
HF_HYPERSIM_REPO="https://huggingface.co/datasets/geyongtao/hypersim"
BLENDEDMVS_LOWRES_URL="https://1drv.ms/u/s!Ag8Dbz2Aqc81gVDu7FHfbPZwqhIy?e=BHY07t"
BLENDEDMVS_HIGHRES_URL="https://1drv.ms/u/s!Ag8Dbz2Aqc81ezb9OciQ4zKwJ_w?e=afFOTi"
MEGADEPTH_V1_URL="https://www.cs.cornell.edu/projects/megadepth/dataset/Megadepth_v1/MegaDepth_v1.tar.gz"

# Per-dataset toggles; only the checkpoint is fetched by default.
DOWNLOAD_CHECKPOINT="${DOWNLOAD_CHECKPOINT:-1}"
KEEP_ARCHIVES="${KEEP_ARCHIVES:-0}"
DOWNLOAD_ARKITSCENES="${DOWNLOAD_ARKITSCENES:-0}"
DOWNLOAD_SCANNETPP="${DOWNLOAD_SCANNETPP:-0}"
DOWNLOAD_MVS_SYNTH="${DOWNLOAD_MVS_SYNTH:-0}"
DOWNLOAD_SCANNET="${DOWNLOAD_SCANNET:-0}"
DOWNLOAD_HYPERSIM="${DOWNLOAD_HYPERSIM:-0}"
DOWNLOAD_BLENDEDMVS="${DOWNLOAD_BLENDEDMVS:-0}"
DOWNLOAD_MEGADEPTH="${DOWNLOAD_MEGADEPTH:-0}"
HF_BLENDEDMVS_BASE="${HF_BLENDEDMVS_BASE:-}"
HF_MEGADEPTH_BASE="${HF_MEGADEPTH_BASE:-}"

# Completion markers polled by outside tooling.
RELEASED_COMPLETION_MARKER="$PROJECT_DIR/.download_complete_released_paper_datasets"
FULL_COMPLETION_MARKER="$PROJECT_DIR/.download_complete_all_requested_paper_datasets"
IN_PROGRESS_MARKER="$PROJECT_DIR/.download_in_progress"

# Reset marker state for this run.
rm -f "$PROJECT_DIR/.download_complete" "$RELEASED_COMPLETION_MARKER" "$FULL_COMPLETION_MARKER" "$IN_PROGRESS_MARKER"
touch "$IN_PROGRESS_MARKER"
56
+
57
+ download_file() {
58
+ local url="$1"
59
+ local output="$2"
60
+ if [ -f "$output" ]; then
61
+ log "Found existing file $(basename "$output"), attempting resume if incomplete."
62
+ fi
63
+ wget -c --progress=bar:force -O "$output" "$url" 2>&1 | tee -a "$LOG_FILE"
64
+ }
65
+
66
+ download_hf_repo_dataset() {
67
+ local label="$1"
68
+ local expected_dir="$2"
69
+ local repo_url="$3"
70
+ if [ -d "$expected_dir" ] && [ -n "$(find "$expected_dir" -mindepth 1 -maxdepth 1 2>/dev/null | head -n 1)" ]; then
71
+ log "$label already exists, skipping."
72
+ return
73
+ fi
74
+ rm -rf "$expected_dir"
75
+ mkdir -p "$expected_dir"
76
+ log "=== Cloning $label from Hugging Face ==="
77
+ git clone --depth 1 "$repo_url" "$expected_dir" 2>&1 | tee -a "$LOG_FILE"
78
+ if command -v git-lfs >/dev/null 2>&1; then
79
+ log "Fetching LFS files for $label..."
80
+ (cd "$expected_dir" && git lfs pull) 2>&1 | tee -a "$LOG_FILE"
81
+ else
82
+ log "WARNING: git-lfs not found; $label may contain LFS pointer files only."
83
+ fi
84
+ }
85
+
86
+ extract_archive_auto() {
87
+ local archive="$1"
88
+ local target_dir="$2"
89
+ mkdir -p "$target_dir"
90
+ if unzip -t "$archive" >/dev/null 2>&1; then
91
+ unzip -o "$archive" -d "$target_dir" 2>&1 | tee -a "$LOG_FILE"
92
+ return 0
93
+ fi
94
+ if tar -tf "$archive" >/dev/null 2>&1; then
95
+ tar -xf "$archive" -C "$target_dir" 2>&1 | tee -a "$LOG_FILE"
96
+ return 0
97
+ fi
98
+ log "ERROR: Unsupported archive format for $archive"
99
+ return 1
100
+ }
101
+
102
+ download_url_dataset() {
103
+ local label="$1"
104
+ local expected_dir="$2"
105
+ local url="$3"
106
+ local archive_name="$4"
107
+ if [ -d "$expected_dir" ] && [ -n "$(find "$expected_dir" -mindepth 1 -maxdepth 1 2>/dev/null | head -n 1)" ]; then
108
+ log "$label already exists, skipping."
109
+ return
110
+ fi
111
+ log "=== Downloading $label ==="
112
+ download_file "$url" "$DATA_DIR/$archive_name"
113
+ extract_archive_auto "$DATA_DIR/$archive_name" "$expected_dir"
114
+ if [ "$KEEP_ARCHIVES" != "1" ]; then
115
+ rm -f "$DATA_DIR/$archive_name"
116
+ fi
117
+ if [ -d "$expected_dir" ]; then
118
+ log "$label done."
119
+ else
120
+ log "WARNING: $label archive extracted, but expected path is still missing: $expected_dir"
121
+ fi
122
+ }
123
+
124
+ assemble_parts() {
125
+ local output="$1"
126
+ shift
127
+ if [ -f "$output" ]; then
128
+ log "Found existing assembled archive $(basename "$output"), skipping assembly."
129
+ return
130
+ fi
131
+ cat "$@" > "${output}.tmp"
132
+ mv "${output}.tmp" "$output"
133
+ }
134
+
135
+ rename_if_needed() {
136
+ local expected="$1"
137
+ shift
138
+ if [ -e "$expected" ]; then
139
+ return
140
+ fi
141
+ local candidate
142
+ for candidate in "$@"; do
143
+ if [ -e "$candidate" ]; then
144
+ mv "$candidate" "$expected"
145
+ return
146
+ fi
147
+ done
148
+ }
149
+
150
+ extract_zip() {
151
+ local archive="$1"
152
+ unzip -o "$archive" -d "$DATA_DIR/" 2>&1 | tee -a "$LOG_FILE"
153
+ if [ "$KEEP_ARCHIVES" != "1" ]; then
154
+ rm -f "$archive"
155
+ fi
156
+ }
157
+
158
+ cleanup_parts() {
159
+ if [ "$KEEP_ARCHIVES" = "1" ]; then
160
+ return
161
+ fi
162
+ rm -f "$@"
163
+ }
164
+
165
+ download_single_archive_dataset() {
166
+ local label="$1"
167
+ local expected_dir="$2"
168
+ local archive_name="$3"
169
+ local base_url="$4"
170
+ shift 4
171
+ if [ -d "$expected_dir" ]; then
172
+ log "$label already exists, skipping."
173
+ return
174
+ fi
175
+ cd "$DATA_DIR"
176
+ log "=== Downloading $label ==="
177
+ download_file "$base_url/$archive_name" "$DATA_DIR/$archive_name"
178
+ log "Extracting $label..."
179
+ extract_zip "$DATA_DIR/$archive_name"
180
+ rename_if_needed "$expected_dir" "$@"
181
+ if [ -d "$expected_dir" ]; then
182
+ log "$label done."
183
+ else
184
+ log "WARNING: $label archive extracted, but expected path is still missing: $expected_dir"
185
+ fi
186
+ }
187
+
188
+ download_split_archive_dataset() {
189
+ local label="$1"
190
+ local expected_dir="$2"
191
+ local assembled_archive="$3"
192
+ local base_url="$4"
193
+ local aliases_string="$5"
194
+ shift 5
195
+ local parts=("$@")
196
+ if [ -d "$expected_dir" ] && [ -n "$(find "$expected_dir" -mindepth 1 -maxdepth 1 2>/dev/null | head -n 1)" ]; then
197
+ log "$label already exists, skipping."
198
+ return
199
+ fi
200
+ cd "$DATA_DIR"
201
+ log "=== Downloading $label ==="
202
+ local downloaded_parts=()
203
+ local part
204
+ for part in "${parts[@]}"; do
205
+ download_file "$base_url/$part" "$DATA_DIR/$part"
206
+ downloaded_parts+=("$DATA_DIR/$part")
207
+ done
208
+ log "Assembling $label archive..."
209
+ assemble_parts "$DATA_DIR/$assembled_archive" "${downloaded_parts[@]}"
210
+ log "Extracting $label..."
211
+ extract_zip "$DATA_DIR/$assembled_archive"
212
+ IFS='|' read -r -a alias_candidates <<< "$aliases_string"
213
+ local alias_paths=()
214
+ local alias
215
+ for alias in "${alias_candidates[@]}"; do
216
+ alias_paths+=("$DATA_DIR/$alias")
217
+ done
218
+ rename_if_needed "$expected_dir" "${alias_paths[@]}"
219
+ cleanup_parts "${downloaded_parts[@]}"
220
+ if [ -d "$expected_dir" ]; then
221
+ log "$label done."
222
+ else
223
+ log "WARNING: $label archive extracted, but expected path is still missing: $expected_dir"
224
+ fi
225
+ }
226
+
227
+ if [ "$DOWNLOAD_CHECKPOINT" = "1" ]; then
228
+ log "=== Downloading pretrained checkpoint (3.84 GB) ==="
229
+ if [ ! -f "$CKPT_DIR/checkpoint-10.pth.model" ]; then
230
+ download_file "$HF_CKPT" "$CKPT_DIR/checkpoint-10.pth.model"
231
+ log "Checkpoint downloaded."
232
+ else
233
+ log "Checkpoint already exists, skipping."
234
+ fi
235
+ fi
236
+
237
+ if [ "$DOWNLOAD_ARKITSCENES" = "1" ]; then
238
+ download_hf_repo_dataset \
239
+ "ARKitScenes (HF repo: Pointcept/arkitscenes-compressed)" \
240
+ "$DATA_DIR/processed_arkitscenes" \
241
+ "$HF_ARKITSCENES_REPO"
242
+ fi
243
+
244
+ if [ "$DOWNLOAD_SCANNETPP" = "1" ]; then
245
+ download_hf_repo_dataset \
246
+ "ScanNet++ (HF repo: Pointcept/scannetpp-compressed)" \
247
+ "$DATA_DIR/processed_scannetpp" \
248
+ "$HF_SCANNETPP_REPO"
249
+ fi
250
+
251
+ if [ "$DOWNLOAD_MVS_SYNTH" = "1" ]; then
252
+ download_single_archive_dataset \
253
+ "MVS-Synth (README-backed HF tree)" \
254
+ "$DATA_DIR/processed_mvs_synth" \
255
+ "processed_mvs_synth.zip" \
256
+ "$HF_BASE" \
257
+ "processed_mvs_synth" \
258
+ "$DATA_DIR/processed_mvs_synth"
259
+ fi
260
+
261
+ if [ "$DOWNLOAD_SCANNET" = "1" ]; then
262
+ download_split_archive_dataset \
263
+ "ScanNet (HF split archive: KevinConnorLee/SLF)" \
264
+ "$DATA_DIR/processed_scannet" \
265
+ "processed_scannetv2.zip" \
266
+ "$HF_BASE" \
267
+ "processed_scannetv2|processed_scannet" \
268
+ "processed_scannetv2.zip.part.aa" \
269
+ "processed_scannetv2.zip.part.ab" \
270
+ "processed_scannetv2.zip.part.ac" \
271
+ "processed_scannetv2.zip.part.ad"
272
+ fi
273
+
274
+ if [ "$DOWNLOAD_HYPERSIM" = "1" ]; then
275
+ download_hf_repo_dataset \
276
+ "HyperSim (HF repo: geyongtao/hypersim)" \
277
+ "$DATA_DIR/hypersim" \
278
+ "$HF_HYPERSIM_REPO"
279
+ fi
280
+
281
+ if [ "$DOWNLOAD_BLENDEDMVS" = "1" ]; then
282
+ BLENDEDMVS_URL="${BLENDEDMVS_URL:-$BLENDEDMVS_LOWRES_URL}"
283
+ BLENDEDMVS_ARCHIVE_NAME="${BLENDEDMVS_ARCHIVE_NAME:-blendedmvs_lowres.download}"
284
+ if [ "${BLENDEDMVS_VARIANT:-lowres}" = "highres" ]; then
285
+ BLENDEDMVS_URL="$BLENDEDMVS_HIGHRES_URL"
286
+ BLENDEDMVS_ARCHIVE_NAME="${BLENDEDMVS_ARCHIVE_NAME:-blendedmvs_highres.download}"
287
+ fi
288
+ download_url_dataset \
289
+ "BlendedMVS (${BLENDEDMVS_VARIANT:-lowres})" \
290
+ "$DATA_DIR/processed_blendedmvs" \
291
+ "$BLENDEDMVS_URL" \
292
+ "$BLENDEDMVS_ARCHIVE_NAME"
293
+ fi
294
+
295
+ if [ "$DOWNLOAD_MEGADEPTH" = "1" ] && [ ! -d "$DATA_DIR/processed_megadepth" ]; then
296
+ download_url_dataset \
297
+ "MegaDepth v1 (Cornell official archive)" \
298
+ "$DATA_DIR/processed_megadepth" \
299
+ "$MEGADEPTH_V1_URL" \
300
+ "MegaDepth_v1.tar.gz"
301
+ fi
302
+
303
+ log "=== Download complete ==="
304
+ log "Disk usage:"
305
+ du -sh "$DATA_DIR"/* "$CKPT_DIR"/* 2>/dev/null | tee -a "$LOG_FILE"
306
+ log "Total:"
307
+ du -sh "$DATA_DIR" "$CKPT_DIR" | tee -a "$LOG_FILE"
308
+
309
+ missing_released=()
310
+ missing_all_requested=()
311
+
312
+ if [ "$DOWNLOAD_CHECKPOINT" = "1" ] && [ ! -f "$CKPT_DIR/checkpoint-10.pth.model" ]; then
313
+ missing_released+=("ckpt/checkpoint-10.pth.model")
314
+ fi
315
+ if [ ! -d "$DATA_DIR/processed_scannetpp" ]; then
316
+ missing_released+=("data/train/processed_scannetpp")
317
+ fi
318
+ if [ ! -d "$DATA_DIR/processed_mvs_synth" ]; then
319
+ missing_released+=("data/train/processed_mvs_synth")
320
+ fi
321
+ if [ ! -d "$DATA_DIR/processed_arkitscenes" ]; then
322
+ missing_released+=("data/train/processed_arkitscenes")
323
+ fi
324
+ if [ "$DOWNLOAD_SCANNET" = "1" ] && [ ! -d "$DATA_DIR/processed_scannet" ]; then
325
+ missing_released+=("data/train/processed_scannet")
326
+ fi
327
+ if [ "$DOWNLOAD_HYPERSIM" = "1" ] && [ ! -d "$DATA_DIR/hypersim" ]; then
328
+ missing_released+=("data/train/hypersim")
329
+ fi
330
+ if [ "$DOWNLOAD_BLENDEDMVS" = "1" ] && [ ! -d "$DATA_DIR/processed_blendedmvs" ]; then
331
+ missing_released+=("data/train/processed_blendedmvs")
332
+ fi
333
+
334
+ missing_all_requested=("${missing_released[@]}")
335
+ if [ "$DOWNLOAD_MEGADEPTH" = "1" ] && [ ! -d "$DATA_DIR/processed_megadepth" ]; then
336
+ missing_all_requested+=("data/train/processed_megadepth")
337
+ fi
338
+
339
+ rm -f "$IN_PROGRESS_MARKER"
340
+
341
+ if [ "${#missing_released[@]}" -eq 0 ]; then
342
+ touch "$RELEASED_COMPLETION_MARKER"
343
+ log "Released paper datasets are complete. Marker created at $RELEASED_COMPLETION_MARKER"
344
+ else
345
+ log "WARNING: Missing released download targets: ${missing_released[*]}"
346
+ fi
347
+
348
+ if [ "${#missing_all_requested[@]}" -eq 0 ]; then
349
+ touch "$FULL_COMPLETION_MARKER"
350
+ touch "$PROJECT_DIR/.download_complete"
351
+ log "All requested datasets are complete. Markers created at $FULL_COMPLETION_MARKER and $PROJECT_DIR/.download_complete"
352
+ else
353
+ log "WARNING: Missing requested targets: ${missing_all_requested[*]}"
354
+ fi
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/exp_joint_freeze_frontend_fsdp_8gpu.sh ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
#SBATCH --job-name=sf_smoke_joint_freeze_frontend_fsdp_2gpu
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=12
#SBATCH --gres=gpu:2
#SBATCH --mem=24G
#SBATCH --time=24:00:00
#SBATCH --output=%x_%j.out
#SBATCH --error=%x_%j.err

set -euo pipefail

# Resolve the project layout relative to this script (or the sbatch submit dir).
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_DIR="${PROJECT_DIR:-${SLURM_SUBMIT_DIR:-$(dirname "$SCRIPT_DIR")}}"
SRC_DIR="$PROJECT_DIR/src"

# When migrating to another cluster, change these first: GPUs/port, conda init
# script, data directories, pretrained weights, and the save directory.
# NOTE(review): the SBATCH header requests gpu:2 while NUM_GPUS defaults to 8;
# confirm the intended GPU count before submitting via sbatch.
MASTER_PORT="${MASTER_PORT:-29662}"
NUM_GPUS="${NUM_GPUS:-8}"
CONDA_SH="${CONDA_SH:-/home/23068142r/miniconda3/etc/profile.d/conda.sh}"
CONDA_ENV_NAME="SLAM-Former"

# Dataset roots (overridable per cluster).
DATA_ROOT="${DATA_ROOT:-/home/23068142r/work_dir/data}"
ROOT_ARKIT="${ROOT_ARKIT:-$DATA_ROOT/processed_arkitscenes}"
ROOT_SCANNETPP="${ROOT_SCANNETPP:-$DATA_ROOT/preprocessed_scannetpp}"
ROOT_SCANNET="${ROOT_SCANNET:-$DATA_ROOT/processed_scannet}"
ROOT_SCANNET_FALLBACK="${ROOT_SCANNET_FALLBACK:-$DATA_ROOT/processed_scannetv2}"
ROOT_HYPERSIM="${ROOT_HYPERSIM:-$DATA_ROOT/preprocessed_Hypersim}"
ROOT_BLENDEDMVS="${ROOT_BLENDEDMVS:-$DATA_ROOT/processed_blendedmvs}"
ROOT_MEGADEPTH="${ROOT_MEGADEPTH:-$DATA_ROOT/processed_megadepth}"
ROOT_MVS_SYNTH="${ROOT_MVS_SYNTH:-$DATA_ROOT/processed_mvs_synth}"

# Experiment identity and checkpoints.
EXPERIMENT_ROOT="${EXPERIMENT_ROOT:-paper_smoke_local_8gpu}"
VARIANT_NAME="${VARIANT_NAME:-joint_freeze_frontend_fsdp_sub12}"
EXP_NAME="${EXP_NAME:-paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12}"
SAVE_DIR="${SAVE_DIR:-$PROJECT_DIR/checkpoints/$EXPERIMENT_ROOT/$VARIANT_NAME}"
PRETRAINED="${PRETRAINED:-/home/23068142r/work_dir/projects/e2e-semantic-SLAM-submap/ckpt/checkpoint-10.pth.model}"
RESUME="${RESUME:-null}"

# Training configuration knobs.
CONFIG_NAME="${CONFIG_NAME:-finetune_paper_h20.yaml}"
DIST_STRATEGY="${DIST_STRATEGY:-fsdp}"
AUTO_DISABLE_MISSING="${AUTO_DISABLE_MISSING:-1}"
TRAIN_SUBMAP_MODULES_ONLY="${TRAIN_SUBMAP_MODULES_ONLY:-0}"
DETACH_FRONTEND_TOKENS="${DETACH_FRONTEND_TOKENS:-1}"
NUM_VIEWS_ALL="${NUM_VIEWS_ALL:-64}"
NUM_VIEWS_ARKIT="${NUM_VIEWS_ARKIT:-64}"
NUM_VIEWS_SCANNETPP="${NUM_VIEWS_SCANNETPP:-24}"
NUM_VIEWS_SCANNET="${NUM_VIEWS_SCANNET:-64}"
NUM_VIEWS_HYPERSIM="${NUM_VIEWS_HYPERSIM:-24}"
NUM_VIEWS_BLENDEDMVS="${NUM_VIEWS_BLENDEDMVS:-64}"
NUM_VIEWS_MEGADEPTH="${NUM_VIEWS_MEGADEPTH:-64}"
NUM_VIEWS_MVS_SYNTH="${NUM_VIEWS_MVS_SYNTH:-24}"
SUBMAP_SIZE="${SUBMAP_SIZE:-12}"
SUBMAP_TRAIN_MODE="${SUBMAP_TRAIN_MODE:-full_token}"
SUBMAP_RETRIEVAL_TOPK="${SUBMAP_RETRIEVAL_TOPK:-0}"
SUBMAP_FETCH_SOURCE="${SUBMAP_FETCH_SOURCE:-frontend}"
SUBMAP_DESCRIPTOR_SOURCE="${SUBMAP_DESCRIPTOR_SOURCE:-frontend}"
ENABLE_PSEUDO_GT="${ENABLE_PSEUDO_GT:-0}"
PSEUDO_GT_CACHE_PATH="${PSEUDO_GT_CACHE_PATH:-}"
SKIP_TEST="${SKIP_TEST:-1}"
EPOCHS="${EPOCHS:-2}"

# Per-dataset sample weights (0 disables a dataset).
SAMPLES_ARKIT="${SAMPLES_ARKIT:-0}"
SAMPLES_SCANNETPP="${SAMPLES_SCANNETPP:-16}"
SAMPLES_SCANNET="${SAMPLES_SCANNET:-0}"
SAMPLES_HYPERSIM="${SAMPLES_HYPERSIM:-16}"
SAMPLES_BLENDEDMVS="${SAMPLES_BLENDEDMVS:-0}"
SAMPLES_MEGADEPTH="${SAMPLES_MEGADEPTH:-0}"
SAMPLES_MVS_SYNTH="${SAMPLES_MVS_SYNTH:-16}"
# Empty means "derive from the active datasets" later in the script.
GLOBAL_NUM_VIEWS="${GLOBAL_NUM_VIEWS:-}"
69
+
70
# Fail fast on inconsistent pseudo-GT configuration.
if [[ "$ENABLE_PSEUDO_GT" == "1" ]]; then
    if [[ -z "$PSEUDO_GT_CACHE_PATH" || "$PSEUDO_GT_CACHE_PATH" == "null" ]]; then
        echo "ERROR: ENABLE_PSEUDO_GT=1 requires PSEUDO_GT_CACHE_PATH to be set."
        exit 1
    fi
fi

if [[ ! -f "$PRETRAINED" ]]; then
    echo "ERROR: Missing pretrained checkpoint: $PRETRAINED"
    exit 1
fi

# Activate the conda environment this job runs under.
if [[ ! -f "$CONDA_SH" ]]; then
    echo "ERROR: Missing conda init script: $CONDA_SH"
    exit 1
fi
source "$CONDA_SH"
conda activate "$CONDA_ENV_NAME"
export PATH="$CONDA_PREFIX/bin:$PATH"
# Best-effort CUDA toolkit module load on clusters that provide `module`.
if command -v module >/dev/null 2>&1; then
    module load cuda12.1/toolkit || true
fi

# Runtime environment for the training processes.
export PYTHONPATH="$PROJECT_DIR/src:$PROJECT_DIR:${PYTHONPATH:-}"
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-4}"
export HYDRA_FULL_ERROR="${HYDRA_FULL_ERROR:-1}"
export PYTORCH_ALLOC_CONF="${PYTORCH_ALLOC_CONF:-expandable_segments:True}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"

# Re-export every knob so child processes (accelerate workers) inherit them.
export CONDA_SH CONFIG_NAME EXP_NAME MASTER_PORT DIST_STRATEGY AUTO_DISABLE_MISSING
export SAVE_DIR PRETRAINED RESUME NUM_GPUS
export TRAIN_SUBMAP_MODULES_ONLY DETACH_FRONTEND_TOKENS
export NUM_VIEWS_ALL NUM_VIEWS_ARKIT NUM_VIEWS_SCANNETPP NUM_VIEWS_SCANNET
export NUM_VIEWS_HYPERSIM NUM_VIEWS_BLENDEDMVS NUM_VIEWS_MEGADEPTH NUM_VIEWS_MVS_SYNTH
export SUBMAP_TRAIN_MODE SUBMAP_RETRIEVAL_TOPK SUBMAP_FETCH_SOURCE SUBMAP_DESCRIPTOR_SOURCE
export ENABLE_PSEUDO_GT PSEUDO_GT_CACHE_PATH
export SAMPLES_ARKIT SAMPLES_SCANNETPP SAMPLES_SCANNET SAMPLES_HYPERSIM
export SAMPLES_BLENDEDMVS SAMPLES_MEGADEPTH SAMPLES_MVS_SYNTH
export DATA_ROOT ROOT_ARKIT ROOT_SCANNETPP ROOT_SCANNET ROOT_SCANNET_FALLBACK
export ROOT_HYPERSIM ROOT_BLENDEDMVS ROOT_MEGADEPTH ROOT_MVS_SYNTH
export SKIP_TEST GLOBAL_NUM_VIEWS
142
+
143
# Cheap readiness probe: does this dataset root contain the file layout the
# loaders expect? Probe expressions are mirrored by dataset_probe_hint below.
dataset_is_ready() {
    local ds="$1"
    local root="$2"
    case "$ds" in
        "ARKitScenes")
            [ -f "$root/Training/all_metadata.npz" ]
            ;;
        "ScanNet++")
            [ -f "$root/all_metadata.npz" ]
            ;;
        "ScanNet")
            [ -d "$root/scans_train" ] && [ -n "$(find "$root/scans_train" -mindepth 2 -maxdepth 2 -type f -name 'new_scene_metadata.npz' -print -quit 2>/dev/null)" ]
            ;;
        "HyperSim")
            [ -d "$root" ] && [ -n "$(find "$root" -mindepth 3 -maxdepth 3 -type f -name '*rgb.png' -print -quit 2>/dev/null)" ]
            ;;
        "BlendedMVS")
            [ -f "$root/new_overlap.h5" ]
            ;;
        "MegaDepth")
            [ -f "$root/megadepth_sets_64.npz" ]
            ;;
        "MVS-Synth")
            [ -d "$root" ] && [ -n "$(find "$root" -mindepth 2 -maxdepth 2 -type d -name 'cam' -print -quit 2>/dev/null)" ]
            ;;
        *)
            # Unknown datasets: fall back to a bare existence check.
            [ -e "$root" ]
            ;;
    esac
}
173
+
174
# Human-readable description of the probe used by dataset_is_ready; shown in
# the warning/error messages when a dataset root fails validation.
dataset_probe_hint() {
    local ds="$1"
    local hint="required dataset files"
    case "$ds" in
        "ARKitScenes") hint="Training/all_metadata.npz" ;;
        "ScanNet++")   hint="all_metadata.npz" ;;
        "ScanNet")     hint="scans_train/*/new_scene_metadata.npz" ;;
        "HyperSim")    hint="scene/subscene/*rgb.png" ;;
        "BlendedMVS")  hint="new_overlap.h5" ;;
        "MegaDepth")   hint="megadepth_sets_64.npz" ;;
        "MVS-Synth")   hint="*/cam" ;;
    esac
    echo "$hint"
}
203
+
204
# Echo the first ScanNet root that passes the readiness probe. Prefers $1,
# falls back to $2 (logging the switch to stderr so the stdout capture stays
# clean), and echoes $1 when neither is ready so the later missing-dataset
# handling reports the preferred path.
resolve_scannet_root() {
    local primary="$1"
    local backup="$2"
    if dataset_is_ready "ScanNet" "$primary"; then
        echo "$primary"
        return
    fi
    if [[ "$backup" != "$primary" ]] && dataset_is_ready "ScanNet" "$backup"; then
        echo "INFO: ScanNet root $primary is incomplete; falling back to $backup" >&2
        echo "$backup"
        return
    fi
    echo "$primary"
}
218
+
219
# Validate one dataset root. When the probe fails, either zero out the
# dataset's sample weight (AUTO_DISABLE_MISSING=1) or abort the job.
#   $1 label, $2 dataset root, $3 NAME of the weight variable (indirected).
handle_missing_dataset() {
    local ds="$1"
    local root="$2"
    local var_name="$3"
    local current="${!var_name}"
    # Datasets already weighted out need no probing.
    if [[ "$current" -le 0 ]]; then
        return
    fi
    if dataset_is_ready "$ds" "$root"; then
        return
    fi
    local hint
    hint="$(dataset_probe_hint "$ds")"
    if [[ "$AUTO_DISABLE_MISSING" == "1" ]]; then
        echo "WARNING: Missing or incomplete ${ds} dataset root: ${root}"
        echo "WARNING: Expected ${hint} under ${root}"
        echo "WARNING: Disabling ${ds} by setting ${var_name}=0"
        # Write through the indirection so the caller's SAMPLES_* goes to 0.
        printf -v "$var_name" '0'
    else
        echo "ERROR: Missing or incomplete ${ds} dataset root: ${root}"
        echo "ERROR: Expected ${hint} under ${root}"
        exit 1
    fi
}
242
+
243
# Append "<weight> @ ${<token>}" to DATASET_PARTS when the weight is positive.
# The ${...} is emitted literally so Hydra resolves the dataset token later.
append_dataset() {
    local w="$1"
    local tok="$2"
    if [[ "$w" -gt 0 ]]; then
        DATASET_PARTS+=("${w} @ \${${tok}}")
    fi
}
251
+
252
mkdir -p "$SAVE_DIR/$EXP_NAME"

# Prefer the configured ScanNet root but accept the fallback layout.
ROOT_SCANNET="$(resolve_scannet_root "$ROOT_SCANNET" "$ROOT_SCANNET_FALLBACK")"

# Probe every dataset root; incomplete ones are disabled (or fatal — see
# AUTO_DISABLE_MISSING) by zeroing the corresponding SAMPLES_* weight.
handle_missing_dataset "ARKitScenes" "$ROOT_ARKIT" SAMPLES_ARKIT
handle_missing_dataset "ScanNet++" "$ROOT_SCANNETPP" SAMPLES_SCANNETPP
handle_missing_dataset "ScanNet" "$ROOT_SCANNET" SAMPLES_SCANNET
handle_missing_dataset "HyperSim" "$ROOT_HYPERSIM" SAMPLES_HYPERSIM
handle_missing_dataset "BlendedMVS" "$ROOT_BLENDEDMVS" SAMPLES_BLENDEDMVS
handle_missing_dataset "MegaDepth" "$ROOT_MEGADEPTH" SAMPLES_MEGADEPTH
handle_missing_dataset "MVS-Synth" "$ROOT_MVS_SYNTH" SAMPLES_MVS_SYNTH

# NOTE(review): PRETRAINED was already validated near the top of the script;
# the re-check is kept as cheap insurance.
if [[ ! -f "$PRETRAINED" ]]; then
    echo "ERROR: Missing pretrained checkpoint: $PRETRAINED"
    exit 1
fi

# Hydra overrides controlling pseudo ground-truth supervision.
if [[ "$ENABLE_PSEUDO_GT" == "1" ]]; then
    if [[ -z "$PSEUDO_GT_CACHE_PATH" || "$PSEUDO_GT_CACHE_PATH" == "null" ]]; then
        echo "ERROR: ENABLE_PSEUDO_GT=1 requires PSEUDO_GT_CACHE_PATH to be set."
        exit 1
    fi
    PSEUDO_GT_OVERRIDES=(
        "pseudo_gt.enable=true"
        "pseudo_gt.cache_path=${PSEUDO_GT_CACHE_PATH}"
    )
else
    PSEUDO_GT_OVERRIDES=(
        "pseudo_gt.enable=false"
        "pseudo_gt.cache_path=null"
    )
fi
285
+
286
# Convert shell-style 0/1 flags into Hydra booleans.
if [ "$TRAIN_SUBMAP_MODULES_ONLY" = "1" ]; then
    TRAIN_SUBMAP_MODULES_ONLY_HYDRA="true"
else
    TRAIN_SUBMAP_MODULES_ONLY_HYDRA="false"
fi

if [ "$DETACH_FRONTEND_TOKENS" = "1" ]; then
    DETACH_FRONTEND_TOKENS_HYDRA="true"
else
    DETACH_FRONTEND_TOKENS_HYDRA="false"
fi

# Build the weighted dataset mix; zero-weight datasets are dropped.
DATASET_PARTS=()
append_dataset "$SAMPLES_ARKIT" dataset_arkit
append_dataset "$SAMPLES_SCANNETPP" dataset_scannetpp
append_dataset "$SAMPLES_SCANNET" dataset_scannet
append_dataset "$SAMPLES_HYPERSIM" dataset_hypersim
append_dataset "$SAMPLES_BLENDEDMVS" dataset_blendedmvs
append_dataset "$SAMPLES_MEGADEPTH" dataset_megadepth
append_dataset "$SAMPLES_MVS_SYNTH" dataset_mvs_synth

if [ "${#DATASET_PARTS[@]}" -eq 0 ]; then
    echo "ERROR: No training dataset remains after weight filtering."
    exit 1
fi

# GLOBAL_NUM_VIEWS defaults to the maximum view count over active datasets.
# REFACTOR: the original repeated the same if-block seven times, once per
# dataset; a single loop over "samples:views" pairs computes the identical
# maximum and keeps new datasets a one-line change.
if [ -z "$GLOBAL_NUM_VIEWS" ]; then
    GLOBAL_NUM_VIEWS=0
    for views_spec in \
        "$SAMPLES_ARKIT:$NUM_VIEWS_ARKIT" \
        "$SAMPLES_SCANNETPP:$NUM_VIEWS_SCANNETPP" \
        "$SAMPLES_SCANNET:$NUM_VIEWS_SCANNET" \
        "$SAMPLES_HYPERSIM:$NUM_VIEWS_HYPERSIM" \
        "$SAMPLES_BLENDEDMVS:$NUM_VIEWS_BLENDEDMVS" \
        "$SAMPLES_MEGADEPTH:$NUM_VIEWS_MEGADEPTH" \
        "$SAMPLES_MVS_SYNTH:$NUM_VIEWS_MVS_SYNTH"; do
        spec_samples="${views_spec%%:*}"
        spec_views="${views_spec##*:}"
        if [ "$spec_samples" -gt 0 ] && [ "$spec_views" -gt "$GLOBAL_NUM_VIEWS" ]; then
            GLOBAL_NUM_VIEWS="$spec_views"
        fi
    done
fi
336
+
337
# Join weighted dataset specs with " + " into the Hydra train_dataset string.
TRAIN_DATASET="${DATASET_PARTS[0]}"
for part in "${DATASET_PARTS[@]:1}"; do
    TRAIN_DATASET+=" + ${part}"
done

# Assemble Hydra overrides; extra CLI arguments are appended last so they win.
HYDRA_ARGS=(
    "--config-name" "$CONFIG_NAME"
    "exp_name=$EXP_NAME"
    "save_dir=$SAVE_DIR"
    "pretrained=$PRETRAINED"
    "resume=$RESUME"
    "data_root=$DATA_ROOT"
    "root_arkit=$ROOT_ARKIT"
    "root_scannetpp=$ROOT_SCANNETPP"
    "root_scannet=$ROOT_SCANNET"
    "root_hypersim=$ROOT_HYPERSIM"
    "root_blendedmvs=$ROOT_BLENDEDMVS"
    "root_megadepth=$ROOT_MEGADEPTH"
    "root_mvs_synth=$ROOT_MVS_SYNTH"
    "num_views=$GLOBAL_NUM_VIEWS"
    "num_views_arkit=$NUM_VIEWS_ARKIT"
    "num_views_scannetpp=$NUM_VIEWS_SCANNETPP"
    "num_views_scannet=$NUM_VIEWS_SCANNET"
    "num_views_hypersim=$NUM_VIEWS_HYPERSIM"
    "num_views_blendedmvs=$NUM_VIEWS_BLENDEDMVS"
    "num_views_megadepth=$NUM_VIEWS_MEGADEPTH"
    "num_views_mvs_synth=$NUM_VIEWS_MVS_SYNTH"
    "train_submap_modules_only=$TRAIN_SUBMAP_MODULES_ONLY_HYDRA"
    "detach_frontend_tokens=$DETACH_FRONTEND_TOKENS_HYDRA"
    "submap_train_mode=$SUBMAP_TRAIN_MODE"
    "submap_retrieval_topk=$SUBMAP_RETRIEVAL_TOPK"
    "submap_fetch_source=$SUBMAP_FETCH_SOURCE"
    "submap_descriptor_source=$SUBMAP_DESCRIPTOR_SOURCE"
    "${PSEUDO_GT_OVERRIDES[@]}"
    "train_dataset=$TRAIN_DATASET"
    "epochs=$EPOCHS"
)
# An empty test_dataset override disables evaluation entirely.
if [[ "$SKIP_TEST" == "1" ]]; then
    HYDRA_ARGS+=("test_dataset=")
fi
HYDRA_ARGS+=("$@")
378
+
379
# Echo the effective configuration for the job log.
echo "=== Starting 2GPU local smoke: joint backend+submap with frozen frontend tokens ==="
echo " Project dir : $PROJECT_DIR"
echo " Launch entry : inline"
echo " Config : $CONFIG_NAME"
echo " Experiment : $EXP_NAME"
echo " Save dir : $SAVE_DIR"
echo " Pretrained : $PRETRAINED"
echo " Resume : $RESUME"
echo " Num GPUs : $NUM_GPUS"
echo " Distributed strategy : $DIST_STRATEGY"
echo " Train submap only : $TRAIN_SUBMAP_MODULES_ONLY"
echo " Detach frontend tokens: $DETACH_FRONTEND_TOKENS"
echo " Num views : arkit=$NUM_VIEWS_ARKIT scannetpp=$NUM_VIEWS_SCANNETPP scannet=$NUM_VIEWS_SCANNET hypersim=$NUM_VIEWS_HYPERSIM blendedmvs=$NUM_VIEWS_BLENDEDMVS megadepth=$NUM_VIEWS_MEGADEPTH mvs_synth=$NUM_VIEWS_MVS_SYNTH"
echo " Submap size : $SUBMAP_SIZE"
echo " Samples : arkit=$SAMPLES_ARKIT scannetpp=$SAMPLES_SCANNETPP scannet=$SAMPLES_SCANNET hypersim=$SAMPLES_HYPERSIM blendedmvs=$SAMPLES_BLENDEDMVS megadepth=$SAMPLES_MEGADEPTH mvs_synth=$SAMPLES_MVS_SYNTH"
echo " Epochs override : ${EPOCHS:-<config default>}"
echo " Skip test : $SKIP_TEST"
echo " Train dataset : $TRAIN_DATASET"
echo " Global num views : $GLOBAL_NUM_VIEWS"
echo

# accelerate flags shared by every launch mode.
COMMON_ARGS=(
    --num_machines 1
    --num_processes "$NUM_GPUS"
    --main_process_port "$MASTER_PORT"
    --dynamo_backend no
    --mixed_precision bf16
)

case "$DIST_STRATEGY" in
    fsdp)
        accelerate launch \
            "${COMMON_ARGS[@]}" \
            --use_fsdp \
            --fsdp_sharding_strategy FULL_SHARD \
            --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP \
            --fsdp_transformer_layer_cls_to_wrap BlockRope \
            --fsdp_state_dict_type FULL_STATE_DICT \
            --fsdp_backward_prefetch BACKWARD_PRE \
            --fsdp_use_orig_params true \
            --fsdp_sync_module_states true \
            --fsdp_activation_checkpointing true \
            "$SRC_DIR/finetune.py" \
            "${HYDRA_ARGS[@]}"
        ;;
    ddp)
        # --multi_gpu is only valid with more than one process.
        if [ "$NUM_GPUS" -gt 1 ]; then
            accelerate launch \
                --multi_gpu \
                "${COMMON_ARGS[@]}" \
                "$SRC_DIR/finetune.py" \
                "${HYDRA_ARGS[@]}"
        else
            accelerate launch \
                "${COMMON_ARGS[@]}" \
                "$SRC_DIR/finetune.py" \
                "${HYDRA_ARGS[@]}"
        fi
        ;;
    *)
        echo "ERROR: Unsupported DIST_STRATEGY=$DIST_STRATEGY"
        exit 1
        ;;
esac
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/graph_gated_memory.py ADDED
@@ -0,0 +1,850 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GraphGatedMemoryManager: Sparse Windowed & Recursively Retrieved Submap Backend.
3
+
4
+ Key components:
5
+ - SubMapBuffer: CPU-side storage for historical submap tokens + descriptors
6
+ - GraphGatedMemoryManager: differentiable loop closure with [NO_LOOP] gating,
7
+ recursive covisibility fetching, GPU active workspace management
8
+ - TemporalEmbedWrapper: dual-injection temporal embedding (Fix #7)
9
+ - _safe_oom_retry / _build_temporal_mask: helpers
10
+
11
+ All features are toggled via CLI arguments; when disabled, behaviour is
12
+ identical to the original SLAM-Former pipeline. No original source files
13
+ are modified.
14
+ """
15
+
16
+ import math
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ import numpy as np
21
+ from dataclasses import dataclass, field
22
+ from typing import Dict, List, Optional, Set, Tuple
23
+
24
+
25
+ # ═══════════════════════════════════════════════════════════
26
+ # Helpers
27
+ # ═══════════════════════════════════════════════════════════
28
+
29
+ def _safe_oom_retry(fn, *args, **kwargs):
30
+ """Run *fn*; on CUDA OOM, free cache and retry once."""
31
+ try:
32
+ return fn(*args, **kwargs)
33
+ except RuntimeError as e:
34
+ if "out of memory" in str(e):
35
+ torch.cuda.empty_cache()
36
+ return fn(*args, **kwargs)
37
+ raise
38
+
39
+
40
+ def _build_temporal_mask(frame_id_map: torch.Tensor, P: int) -> torch.Tensor:
41
+ """Build a causal attention mask from a non-contiguous frame-id tensor.
42
+
43
+ Args:
44
+ frame_id_map: [L] int tensor — true temporal frame index for every token.
45
+ P: tokens per frame (patch_h*patch_w + register_tokens).
46
+
47
+ Returns:
48
+ attn_mask: [L, L] float tensor where future-frame positions are -inf.
49
+ """
50
+ L = frame_id_map.shape[0]
51
+ fids = frame_id_map.unsqueeze(1) # [L, 1]
52
+ fids_t = frame_id_map.unsqueeze(0) # [1, L]
53
+ # A token at position i may NOT attend to a token at position j
54
+ # if j belongs to a strictly later frame than i (causal).
55
+ future = fids < fids_t # [L, L] bool
56
+ mask = torch.zeros(L, L, device=frame_id_map.device, dtype=torch.float32)
57
+ mask.masked_fill_(future, float("-inf"))
58
+ return mask
59
+
60
+
61
+ def _build_sinusoidal_pe(max_len: int, dim: int) -> torch.Tensor:
62
+ """Fixed sinusoidal position encoding [max_len, dim]."""
63
+ pe = torch.zeros(max_len, dim)
64
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
65
+ div_term = torch.exp(
66
+ torch.arange(0, dim, 2, dtype=torch.float) * (-math.log(10000.0) / dim)
67
+ )
68
+ pe[:, 0::2] = torch.sin(position * div_term)
69
+ pe[:, 1::2] = torch.cos(position * div_term)
70
+ return pe
71
+
72
+
73
+ # ═══════════════════════════════════════════════════════════
74
+ # SubMapBuffer — CPU-side historical storage
75
+ # ═══════════════════════════════════════════════════════════
76
+
77
+ @dataclass
78
+ class SubMapBuffer:
79
+ """Lightweight CPU buffer for completed submaps.
80
+
81
+ Each entry is keyed by ``submap_id`` (int, monotonically increasing).
82
+
83
+ Attributes:
84
+ cpu_frontend_token_buffer: submap_id → [K, P, 2C] frontend token tensor.
85
+ cpu_backend_token_buffer: submap_id → [K, P, 2C] backend-refined token tensor.
86
+ cpu_frontend_descriptor_buffer: submap_id → [desc_dim] frontend descriptor.
87
+ cpu_backend_descriptor_buffer: submap_id → [desc_dim] backend descriptor.
88
+ cpu_frame_ids: submap_id → list of original frame indices.
89
+ """
90
+ cpu_frontend_token_buffer: Dict[int, torch.Tensor] = field(default_factory=dict)
91
+ cpu_backend_token_buffer: Dict[int, torch.Tensor] = field(default_factory=dict)
92
+ cpu_frontend_descriptor_buffer: Dict[int, torch.Tensor] = field(default_factory=dict)
93
+ cpu_backend_descriptor_buffer: Dict[int, torch.Tensor] = field(default_factory=dict)
94
+ cpu_frame_ids: Dict[int, List[int]] = field(default_factory=dict)
95
+ store_on_cpu: bool = True
96
+ detach_stored: bool = True
97
+ default_token_source: str = "frontend"
98
+ default_descriptor_source: str = "frontend"
99
+ default_writeback_token_source: str = "frontend"
100
+ default_writeback_descriptor_source: str = "frontend"
101
+
102
+ # ── convenience ──────────────────────────────────────
103
+ @property
104
+ def cpu_token_buffer(self) -> Dict[int, torch.Tensor]:
105
+ return self._get_token_buffer(self.default_token_source)
106
+
107
+ @property
108
+ def cpu_descriptor_buffer(self) -> Dict[int, torch.Tensor]:
109
+ return self._get_descriptor_buffer(self.default_descriptor_source)
110
+
111
+ @property
112
+ def num_submaps(self) -> int:
113
+ return len(self.cpu_descriptor_buffer)
114
+
115
+ def _get_token_buffer(self, source: str) -> Dict[int, torch.Tensor]:
116
+ if source == "frontend":
117
+ return self.cpu_frontend_token_buffer
118
+ if source == "backend":
119
+ return self.cpu_backend_token_buffer
120
+ raise ValueError(f"Unsupported token source: {source}")
121
+
122
+ def _get_descriptor_buffer(self, source: str) -> Dict[int, torch.Tensor]:
123
+ if source == "frontend":
124
+ return self.cpu_frontend_descriptor_buffer
125
+ if source == "backend":
126
+ return self.cpu_backend_descriptor_buffer
127
+ raise ValueError(f"Unsupported descriptor source: {source}")
128
+
129
+ def _resolve_token_source(self, source: Optional[str], for_writeback: bool = False) -> str:
130
+ if source is not None:
131
+ return source
132
+ return self.default_writeback_token_source if for_writeback else self.default_token_source
133
+
134
+ def _resolve_descriptor_source(self, source: Optional[str], for_writeback: bool = False) -> str:
135
+ if source is not None:
136
+ return source
137
+ return self.default_writeback_descriptor_source if for_writeback else self.default_descriptor_source
138
+
139
+ def _prepare_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
140
+ if self.detach_stored:
141
+ tensor = tensor.detach()
142
+ if self.store_on_cpu:
143
+ tensor = tensor.cpu()
144
+ return tensor
145
+
146
+ def _move_to_device(self, tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
147
+ if tensor.device == device:
148
+ return tensor
149
+ return tensor.to(device, non_blocking=self.store_on_cpu)
150
+
151
+ def store(
152
+ self,
153
+ submap_id: int,
154
+ frame_ids: List[int],
155
+ frontend_tokens: Optional[torch.Tensor] = None,
156
+ frontend_descriptor: Optional[torch.Tensor] = None,
157
+ backend_tokens: Optional[torch.Tensor] = None,
158
+ backend_descriptor: Optional[torch.Tensor] = None,
159
+ ):
160
+ """Store a completed submap in the configured history banks."""
161
+ if frontend_tokens is not None:
162
+ self.cpu_frontend_token_buffer[submap_id] = self._prepare_tensor(frontend_tokens)
163
+ if frontend_descriptor is not None:
164
+ self.cpu_frontend_descriptor_buffer[submap_id] = self._prepare_tensor(frontend_descriptor)
165
+ if backend_tokens is not None:
166
+ self.cpu_backend_token_buffer[submap_id] = self._prepare_tensor(backend_tokens)
167
+ if backend_descriptor is not None:
168
+ self.cpu_backend_descriptor_buffer[submap_id] = self._prepare_tensor(backend_descriptor)
169
+ self.cpu_frame_ids[submap_id] = list(frame_ids)
170
+
171
+ def fetch_tokens(
172
+ self,
173
+ submap_id: int,
174
+ device: torch.device,
175
+ source: Optional[str] = None,
176
+ ) -> torch.Tensor:
177
+ """Move a submap's tokens to *device*."""
178
+ source = self._resolve_token_source(source)
179
+ return self._move_to_device(self._get_token_buffer(source)[submap_id], device)
180
+
181
+ def fetch_frame_ids(self, submap_id: int) -> List[int]:
182
+ return self.cpu_frame_ids[submap_id]
183
+
184
+ def get_all_descriptors(
185
+ self,
186
+ device: torch.device,
187
+ source: Optional[str] = None,
188
+ ) -> torch.Tensor:
189
+ """Return [num_submaps, desc_dim] on *device*, ordered by submap_id."""
190
+ source = self._resolve_descriptor_source(source)
191
+ descriptor_buffer = self._get_descriptor_buffer(source)
192
+ if not descriptor_buffer:
193
+ return torch.empty(0, device=device)
194
+ ids = sorted(descriptor_buffer.keys())
195
+ descs = torch.stack([descriptor_buffer[i] for i in ids])
196
+ return self._move_to_device(descs, device)
197
+
198
+ def id_at_index(self, index: int, source: Optional[str] = None) -> int:
199
+ """Map a 0-based index (into the descriptor matrix) back to submap_id."""
200
+ source = self._resolve_descriptor_source(source)
201
+ return sorted(self._get_descriptor_buffer(source).keys())[index]
202
+
203
+ def update_descriptor(
204
+ self,
205
+ submap_id: int,
206
+ descriptor: torch.Tensor,
207
+ source: Optional[str] = None,
208
+ ):
209
+ source = self._resolve_descriptor_source(source, for_writeback=True)
210
+ self._get_descriptor_buffer(source)[submap_id] = self._prepare_tensor(descriptor)
211
+
212
+ def update_tokens(
213
+ self,
214
+ submap_id: int,
215
+ tokens: torch.Tensor,
216
+ source: Optional[str] = None,
217
+ ):
218
+ """Write-back refined tokens (after backend) to CPU buffer."""
219
+ source = self._resolve_token_source(source, for_writeback=True)
220
+ self._get_token_buffer(source)[submap_id] = self._prepare_tensor(tokens)
221
+
222
+ def detach_all(self):
223
+ for buffer_dict in (
224
+ self.cpu_frontend_token_buffer,
225
+ self.cpu_backend_token_buffer,
226
+ self.cpu_frontend_descriptor_buffer,
227
+ self.cpu_backend_descriptor_buffer,
228
+ ):
229
+ for sid in list(buffer_dict.keys()):
230
+ buffer_dict[sid] = buffer_dict[sid].detach()
231
+
232
+
233
+ # ═══════════════════════════════════════════════════════════
234
+ # TemporalEmbedWrapper — Dual injection (Fix #7)
235
+ # ═══════════════════════════════════════════════════════════
236
+
237
+ class TemporalEmbedWrapper(nn.Module):
238
+ """Dual temporal embedding: input injection + output injection.
239
+
240
+ Phase 1 (input): Add t_emb to hidden_F[:,:,:C] BEFORE backendT.
241
+ → Temporal info participates in all 36 layers of attention.
242
+ Phase 2 (output): Add projected t_emb to BOTH halves of hidden_B AFTER
243
+ backendT.
244
+ → Guarantees Layer 35 (geometry) and Layer 36 (semantics) both
245
+ carry temporal information for downstream heads.
246
+
247
+ Args:
248
+ embed_dim: feature dimension C (default 1024).
249
+ max_frames: maximum temporal index supported.
250
+ mode: 'learned' | 'sinusoidal'.
251
+ """
252
+
253
+ def __init__(self, embed_dim: int = 1024, max_frames: int = 2000,
254
+ mode: str = "learned"):
255
+ super().__init__()
256
+ self.embed_dim = embed_dim
257
+ self.max_frames = max_frames
258
+ self.mode = mode
259
+
260
+ if mode == "learned":
261
+ self.temporal_embed = nn.Embedding(max_frames, embed_dim)
262
+ elif mode == "sinusoidal":
263
+ pe = _build_sinusoidal_pe(max_frames, embed_dim)
264
+ self.register_buffer("temporal_embed_fixed", pe)
265
+ else:
266
+ raise ValueError(f"Unknown temporal embed mode: {mode}")
267
+
268
+ # Separate projections for the two output halves
269
+ self.output_proj_layer35 = nn.Linear(embed_dim, embed_dim, bias=False)
270
+ self.output_proj_layer36 = nn.Linear(embed_dim, embed_dim, bias=False)
271
+
272
+ # Learnable gates — initialised at 0 (sigmoid → 0.5 at start)
273
+ self.gate_layer35 = nn.Parameter(torch.zeros(1))
274
+ self.gate_layer36 = nn.Parameter(torch.zeros(1))
275
+
276
+ # ── core ─────────────────────────────────────────────
277
+ def get_temporal_embed(self, frame_ids: torch.Tensor) -> torch.Tensor:
278
+ """Return [N, C] temporal embedding for given frame indices."""
279
+ frame_ids = frame_ids.clamp(max=self.max_frames - 1)
280
+ if self.mode == "learned":
281
+ return self.temporal_embed(frame_ids)
282
+ else:
283
+ return self.temporal_embed_fixed[frame_ids]
284
+
285
+ def inject_input(self, hidden_F: torch.Tensor,
286
+ frame_ids: torch.Tensor) -> torch.Tensor:
287
+ """Phase 1: add temporal embedding to first C dims of hidden_F.
288
+
289
+ Args:
290
+ hidden_F: [N, P, 2C] frontend token map.
291
+ frame_ids: [N] long tensor of temporal frame indices.
292
+ Returns:
293
+ hidden_F with temporal embedding added to [:, :, :C].
294
+ """
295
+ N, P, C2 = hidden_F.shape
296
+ C = C2 // 2
297
+ t_emb = self.get_temporal_embed(frame_ids) # [N, C]
298
+ hidden_F = hidden_F.clone()
299
+ hidden_F[:, :, :C] = hidden_F[:, :, :C] + t_emb.unsqueeze(1)
300
+ return hidden_F
301
+
302
+ def inject_output(self, hidden_B: torch.Tensor,
303
+ frame_ids: torch.Tensor) -> torch.Tensor:
304
+ """Phase 2: add projected temporal embedding to BOTH halves of hidden_B.
305
+
306
+ Args:
307
+ hidden_B: [N, P, 2C] backend output (Layer35 || Layer36).
308
+ frame_ids: [N] long tensor.
309
+ Returns:
310
+ hidden_B with gated temporal embedding added to both halves.
311
+ """
312
+ N, P, C2 = hidden_B.shape
313
+ C = C2 // 2
314
+ t_emb = self.get_temporal_embed(frame_ids) # [N, C]
315
+
316
+ t_emb_35 = self.output_proj_layer35(t_emb) # [N, C]
317
+ t_emb_36 = self.output_proj_layer36(t_emb) # [N, C]
318
+
319
+ g35 = torch.sigmoid(self.gate_layer35)
320
+ g36 = torch.sigmoid(self.gate_layer36)
321
+
322
+ hidden_B = hidden_B.clone()
323
+ hidden_B[:, :, :C] = hidden_B[:, :, :C] + g35 * t_emb_35.unsqueeze(1)
324
+ hidden_B[:, :, C:] = hidden_B[:, :, C:] + g36 * t_emb_36.unsqueeze(1)
325
+ return hidden_B
326
+
327
+
328
+ # ═══════════════════════════════════════════════════════════
329
+ # GraphGatedMemoryManager
330
+ # ═══════════════════════════════════════════════════════════
331
+
332
+ class GraphGatedMemoryManager(nn.Module):
333
+ """Sparse Windowed & Recursively Retrieved Submap Backend.
334
+
335
+ Manages:
336
+ * GPU active workspace (S_prev + S_curr, up to 2K frames).
337
+ * Differentiable loop closure with a [NO_LOOP] dummy descriptor.
338
+ * Recursive covisibility fetching via an adjacency graph.
339
+ * CPU ↔ GPU memory offloading.
340
+
341
+ All operations are designed to be DDP-safe (Fix #4): the full
342
+ descriptor / retrieval / backendT path is always executed for every
343
+ batch element, gated by a differentiable multiplier.
344
+
345
+ Args:
346
+ submap_size: K — number of frames per submap.
347
+ max_recursive_submaps: cap on historical submaps fetched at once.
348
+ desc_dim: global descriptor dimension (default 128).
349
+ embed_dim: token feature dimension C (default 1024).
350
+ gumbel_tau: initial Gumbel-Softmax temperature.
351
+ loop_mask_mode: "hard_top1" or "soft_all".
352
+ soft_mask_temperature: temperature for soft_all mode.
353
+ soft_mask_bias: bias for soft_all mode.
354
+ """
355
+
356
+ def __init__(
357
+ self,
358
+ submap_size: int = 10,
359
+ max_recursive_submaps: int = 5,
360
+ desc_dim: int = 128,
361
+ embed_dim: int = 1024,
362
+ gumbel_tau: float = 1.0,
363
+ loop_mask_mode: str = "hard_top1",
364
+ soft_mask_temperature: float = 0.25,
365
+ soft_mask_bias: float = 0.2,
366
+ retain_history_grad: bool = False,
367
+ submap_train_mode: str = "full_token",
368
+ submap_retrieval_topk: int = 5,
369
+ submap_fetch_source: str = "frontend",
370
+ submap_descriptor_source: str = "frontend",
371
+ ):
372
+ super().__init__()
373
+ self.K = submap_size
374
+ self.max_recursive = max_recursive_submaps
375
+ self.desc_dim = desc_dim
376
+ self.embed_dim = embed_dim
377
+ self.gumbel_tau = gumbel_tau
378
+ self.loop_mask_mode = loop_mask_mode
379
+ self.soft_mask_temperature = soft_mask_temperature
380
+ self.soft_mask_bias = soft_mask_bias
381
+ self.retain_history_grad = retain_history_grad
382
+ self.submap_train_mode = submap_train_mode
383
+ self.submap_retrieval_topk = int(submap_retrieval_topk)
384
+ self.submap_fetch_source = submap_fetch_source
385
+ self.submap_descriptor_source = submap_descriptor_source
386
+
387
+ valid_modes = {"full_token", "top5_dual_queue"}
388
+ if self.submap_train_mode not in valid_modes:
389
+ raise ValueError(
390
+ f"Unsupported submap_train_mode: {self.submap_train_mode}. "
391
+ f"Expected one of {sorted(valid_modes)}."
392
+ )
393
+ valid_sources = {"frontend", "backend"}
394
+ if self.submap_fetch_source not in valid_sources:
395
+ raise ValueError(f"Unsupported submap_fetch_source: {self.submap_fetch_source}")
396
+ if self.submap_descriptor_source not in valid_sources:
397
+ raise ValueError(f"Unsupported submap_descriptor_source: {self.submap_descriptor_source}")
398
+
399
+ self.use_dual_queue = self.submap_train_mode == "top5_dual_queue"
400
+ if self.submap_retrieval_topk <= 0:
401
+ self.submap_retrieval_topk = 5 if self.use_dual_queue else 0
402
+
403
+ # ── learnable parameters ─────────────────────────
404
+ # [NO_LOOP] dummy descriptor at index 0
405
+ self.no_loop_descriptor = nn.Parameter(
406
+ torch.randn(1, desc_dim) * 0.02
407
+ )
408
+ # Project pooled tokens → global descriptor
409
+ self.desc_proj = nn.Linear(2 * embed_dim, desc_dim)
410
+
411
+ # ── non-parameter state ──────────────────────────
412
+ self.buffer = self._build_buffer()
413
+ self.adjacency: Dict[int, Set[int]] = {}
414
+
415
+ # Frame-level accumulation for the *current* submap
416
+ self._curr_tokens: List[torch.Tensor] = [] # list of [1, P, 2C]
417
+ self._curr_frame_ids: List[int] = []
418
+ # Tokens for the *previous* submap (kept on GPU for sliding window)
419
+ self._prev_tokens: Optional[torch.Tensor] = None # [K, P, 2C]
420
+ self._prev_frame_ids: List[int] = []
421
+
422
+ self._current_submap_id: int = 0
423
+ self._global_frame_counter: int = 0
424
+
425
+ def _build_buffer(self) -> SubMapBuffer:
426
+ return SubMapBuffer(
427
+ store_on_cpu=not self.retain_history_grad,
428
+ detach_stored=not self.retain_history_grad,
429
+ default_token_source=self.submap_fetch_source,
430
+ default_descriptor_source=self.submap_descriptor_source,
431
+ default_writeback_token_source=(
432
+ "backend" if self.use_dual_queue else self.submap_fetch_source
433
+ ),
434
+ default_writeback_descriptor_source=(
435
+ "backend"
436
+ if (self.use_dual_queue or self.submap_descriptor_source == "backend")
437
+ else self.submap_descriptor_source
438
+ ),
439
+ )
440
+
441
+ # ── properties ───────────────────────────────────────
442
+ @property
443
+ def current_submap_id(self) -> int:
444
+ return self._current_submap_id
445
+
446
+ @property
447
+ def submap_complete(self) -> bool:
448
+ return len(self._curr_tokens) >= self.K
449
+
450
+ # ── accumulate ───────────────────────────────────────
451
+ def accumulate(self, frame_token: torch.Tensor, frame_id: Optional[int] = None):
452
+ """Append a single frame's token to the current submap.
453
+
454
+ Args:
455
+ frame_token: [1, P, 2C] or [P, 2C] — output of frontendT.
456
+ frame_id: original temporal frame index (auto-incremented if None).
457
+ """
458
+ if frame_token.dim() == 2:
459
+ frame_token = frame_token.unsqueeze(0)
460
+ self._curr_tokens.append(frame_token)
461
+ fid = frame_id if frame_id is not None else self._global_frame_counter
462
+ self._curr_frame_ids.append(fid)
463
+ self._global_frame_counter += 1
464
+
465
+ # ── descriptor computation ───────────────────────────
466
+ def compute_descriptor(self, tokens: torch.Tensor) -> torch.Tensor:
467
+ """Pool submap tokens → global descriptor.
468
+
469
+ Args:
470
+ tokens: [K, P, 2C] on GPU.
471
+ Returns:
472
+ descriptor: [1, desc_dim].
473
+ """
474
+ # Mean-pool over frames and patches → [1, 2C]
475
+ device_type = tokens.device.type if tokens.is_cuda else "cpu"
476
+ with torch.amp.autocast(device_type=device_type, enabled=False):
477
+ pooled = tokens.float().mean(dim=(0, 1), keepdim=False).unsqueeze(0)
478
+ pooled = torch.nan_to_num(pooled, nan=0.0, posinf=0.0, neginf=0.0)
479
+ desc = self.desc_proj(pooled.to(dtype=self.desc_proj.weight.dtype)).float()
480
+ return torch.nan_to_num(desc, nan=0.0, posinf=0.0, neginf=0.0)
481
+
482
+ # ── loop retrieval (differentiable, DDP-safe) ────────
483
+ def retrieve(
484
+ self,
485
+ curr_desc: torch.Tensor,
486
+ device: torch.device,
487
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[int], int, int, List[int], List[int], Optional[torch.Tensor]]:
488
+ """Differentiable loop closure retrieval (Fix B2: no Python branching).
489
+
490
+ ALL code paths always fetch tokens so every GPU executes the same
491
+ compute graph. The gate multiplier controls whether retrieved
492
+ tokens contribute to the output.
493
+
494
+ Returns:
495
+ gate: [1, 1] — differentiable: 0 = no loop, 1 = loop.
496
+ retrieved_tokens: [R, P, 2C] on GPU (R may be 0 if no history).
497
+ retrieved_fids: list[int] matching dim-0 of retrieved_tokens.
498
+ n_valid_retrieved: int — how many tokens are real.
499
+ primary_sid: submap id of the primary retrieved submap (-1 if none).
500
+ fetch_ids: list[int] — submap IDs that were fetched (for write-back).
501
+ fetch_token_counts: list[int] — number of tokens per fetched submap.
502
+ retrieval_weights: [R] weights for soft_all mode.
503
+ """
504
+ num_hist = self.buffer.num_submaps
505
+
506
+ if num_hist == 0:
507
+ # No history at all — return empty (no padding waste: A5 fix)
508
+ P = self._curr_tokens[0].shape[1] if self._curr_tokens else 1
509
+ C2 = self._curr_tokens[0].shape[2] if self._curr_tokens else 2 * self.embed_dim
510
+ gate = torch.zeros(1, 1, device=device)
511
+ return gate, torch.empty(0, P, C2, device=device), [], 0, -1, [], [], None
512
+
513
+ if self.loop_mask_mode == "soft_all":
514
+ hist_ids = sorted(self.buffer.cpu_descriptor_buffer.keys())
515
+ prev_sid = self._current_submap_id - 1 if self._current_submap_id > 0 else None
516
+ fetch_ids = [sid for sid in hist_ids if sid != prev_sid]
517
+ P = self._curr_tokens[0].shape[1]
518
+ C2 = self._curr_tokens[0].shape[2]
519
+ if not fetch_ids:
520
+ gate = torch.zeros(1, 1, device=device)
521
+ return gate, torch.empty(0, P, C2, device=device), [], 0, -1, [], [], None
522
+
523
+ hist_descs = torch.stack(
524
+ [self.buffer.cpu_descriptor_buffer[sid] for sid in fetch_ids]
525
+ ).to(device, non_blocking=True).float()
526
+ curr_desc_safe = torch.nan_to_num(curr_desc.float(), nan=0.0, posinf=0.0, neginf=0.0)
527
+ hist_descs = torch.nan_to_num(hist_descs, nan=0.0, posinf=0.0, neginf=0.0)
528
+ sim = F.cosine_similarity(
529
+ curr_desc_safe.unsqueeze(1),
530
+ hist_descs.unsqueeze(0),
531
+ dim=-1,
532
+ )
533
+ sim = torch.nan_to_num(sim, nan=-1.0, posinf=1.0, neginf=-1.0).clamp(min=-1.0, max=1.0)
534
+
535
+ selected_fetch_ids = list(fetch_ids)
536
+ selected_sim = sim
537
+ if self.submap_retrieval_topk > 0 and len(fetch_ids) > self.submap_retrieval_topk:
538
+ topk = min(self.submap_retrieval_topk, len(fetch_ids))
539
+ top_scores, top_indices = torch.topk(sim.squeeze(0), k=topk, dim=-1)
540
+ if topk == 1:
541
+ selected_index_list = [int(top_indices.item())]
542
+ else:
543
+ selected_index_list = [int(idx) for idx in top_indices.tolist()]
544
+ selected_fetch_ids = [fetch_ids[idx] for idx in selected_index_list]
545
+ selected_sim = top_scores.unsqueeze(0)
546
+
547
+ tau = max(float(self.soft_mask_temperature), 1e-6)
548
+ weights = torch.sigmoid((selected_sim - self.soft_mask_bias) / tau)
549
+ weights = torch.nan_to_num(weights, nan=0.0, posinf=1.0, neginf=0.0).clamp(min=0.0, max=1.0)
550
+ gate = weights.max(dim=-1, keepdim=True).values
551
+ selected_idx = int(selected_sim.argmax(dim=-1).item())
552
+ primary_submap_id = selected_fetch_ids[selected_idx]
553
+
554
+ retrieved_list: List[torch.Tensor] = []
555
+ retrieved_fids: List[int] = []
556
+ fetch_token_counts: List[int] = []
557
+ for sid in selected_fetch_ids:
558
+ t = self.buffer.fetch_tokens(sid, device)
559
+ retrieved_list.append(t)
560
+ fetch_token_counts.append(t.shape[0])
561
+ retrieved_fids.extend(self.buffer.fetch_frame_ids(sid))
562
+
563
+ retrieved = torch.cat(retrieved_list, dim=0) if retrieved_list else torch.empty(0, P, C2, device=device)
564
+ n_valid = retrieved.shape[0]
565
+ return gate, retrieved, retrieved_fids, n_valid, primary_submap_id, selected_fetch_ids, fetch_token_counts, weights.squeeze(0)
566
+
567
+ # Build descriptor bank: [NO_LOOP] + all historical
568
+ hist_descs = torch.nan_to_num(
569
+ self.buffer.get_all_descriptors(device).float(),
570
+ nan=0.0,
571
+ posinf=0.0,
572
+ neginf=0.0,
573
+ ) # [H, D]
574
+ bank = torch.cat([self.no_loop_descriptor.float(), hist_descs], dim=0) # [H+1, D]
575
+ curr_desc_safe = torch.nan_to_num(curr_desc.float(), nan=0.0, posinf=0.0, neginf=0.0)
576
+
577
+ # Cosine similarity → Gumbel-Softmax
578
+ sim = F.cosine_similarity(
579
+ curr_desc_safe.unsqueeze(1), # [1, 1, D]
580
+ bank.unsqueeze(0), # [1, H+1, D]
581
+ dim=-1,
582
+ ) # [1, H+1]
583
+ sim = torch.nan_to_num(sim, nan=-1.0, posinf=1.0, neginf=-1.0).clamp(min=-1.0, max=1.0)
584
+ selection = F.gumbel_softmax(sim, tau=max(float(self.gumbel_tau), 1e-6), hard=True, dim=-1)
585
+ selection = torch.nan_to_num(selection, nan=0.0, posinf=1.0, neginf=0.0)
586
+
587
+ selected_idx = selection.argmax(dim=-1).item() # int
588
+
589
+ # Gate: differentiable sum of non-NO_LOOP probabilities
590
+ gate = selection[:, 1:].sum(dim=-1, keepdim=True).clamp(min=0.0, max=1.0) # [1, 1]
591
+
592
+ # ── ALWAYS fetch primary + recursive neighbours (Fix B2) ──
593
+ # When NO_LOOP is selected (idx 0), we still fetch the *best*
594
+ # historical submap's tokens so the compute graph is identical
595
+ # across GPUs; the gate multiplier zeros them out.
596
+ P = self._curr_tokens[0].shape[1]
597
+ C2 = self._curr_tokens[0].shape[2]
598
+ retrieved_list: List[torch.Tensor] = []
599
+ retrieved_fids: List[int] = []
600
+ fetch_token_counts: List[int] = []
601
+
602
+ # Determine which submap to fetch (always pick one)
603
+ if selected_idx > 0:
604
+ primary_submap_id = self.buffer.id_at_index(selected_idx - 1)
605
+ else:
606
+ # NO_LOOP selected — still fetch the highest-similarity submap
607
+ # so the compute graph is the same across GPUs
608
+ non_noloop_sim = sim[0, 1:] # [H]
609
+ fallback_idx = non_noloop_sim.argmax().item()
610
+ primary_submap_id = self.buffer.id_at_index(fallback_idx)
611
+
612
+ fetch_ids = [primary_submap_id]
613
+ neighbours = self.adjacency.get(primary_submap_id, set())
614
+ for nid in sorted(neighbours):
615
+ if len(fetch_ids) >= self.max_recursive:
616
+ break
617
+ if nid != self._current_submap_id and nid in self.buffer.cpu_token_buffer:
618
+ fetch_ids.append(nid)
619
+
620
+ for sid in fetch_ids:
621
+ t = self.buffer.fetch_tokens(sid, device)
622
+ retrieved_list.append(t)
623
+ fetch_token_counts.append(t.shape[0])
624
+ retrieved_fids.extend(self.buffer.fetch_frame_ids(sid))
625
+
626
+ # No fixed-size padding (Fix A5): return actual tokens only
627
+ if retrieved_list:
628
+ retrieved = torch.cat(retrieved_list, dim=0) # [R, P, C2]
629
+ else:
630
+ retrieved = torch.empty(0, P, C2, device=device)
631
+
632
+ n_valid = retrieved.shape[0]
633
+ return gate, retrieved, retrieved_fids, n_valid, primary_submap_id, fetch_ids, fetch_token_counts, None
634
+
635
+ # ── finalize submap ──────────────────────────────────
636
+ def finalize_submap(
637
+ self,
638
+ model,
639
+ device: torch.device,
640
+ temporal_wrapper: Optional[TemporalEmbedWrapper] = None,
641
+ enable_temporal_embed: bool = False,
642
+ enable_loop_closure: bool = False,
643
+ tbptt_window: int = 10,
644
+ ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
645
+ """Finalize the current submap: pool, retrieve, run backend, slide window.
646
+
647
+ DDP-safe: always executes the full compute graph (Fix #4).
648
+ A4 fix: tokens within ``tbptt_window`` recent submaps keep gradients;
649
+ older ones are detached to cap memory.
650
+ A5 fix: retrieved tokens are NOT padded to a fixed size; only real
651
+ tokens enter backendT. The gate multiplier zeros out retrieved
652
+ contributions when NO_LOOP is selected.
653
+
654
+ Args:
655
+ model: SLAMFormer instance.
656
+ device: GPU device.
657
+ temporal_wrapper: TemporalEmbedWrapper (or None).
658
+ enable_temporal_embed: whether to inject temporal embeddings.
659
+ enable_loop_closure: whether to attempt differentiable loop retrieval.
660
+ tbptt_window: number of recent submaps whose stored tokens
661
+ keep gradients (A4 fix). Older submaps are
662
+ detached.
663
+
664
+ Returns:
665
+ backend_out: [N_total, P, 2C] — refined tokens from backendT.
666
+ loop_gate: [1, 1] — differentiable gate (0 = no loop, 1 = loop).
667
+ meta: dict with keys 'n_prev', 'n_curr', 'n_retrieved',
668
+ 'frame_ids' (full list), 'curr_frame_ids'.
669
+ """
670
+ # ── 1. Stack current submap tokens ───────────────
671
+ curr_tokens = torch.cat(self._curr_tokens, dim=0).to(device) # [K, P, 2C]
672
+ curr_desc = self.compute_descriptor(curr_tokens) # [1, D]
673
+
674
+ # ── 2. Loop retrieval (always executed, DDP-safe: Fix B2) ─
675
+ if enable_loop_closure:
676
+ loop_gate, retrieved_tokens, retrieved_fids, n_valid_ret, primary_sid, \
677
+ fetch_ids, fetch_token_counts, retrieval_weights = self.retrieve(curr_desc, device)
678
+ else:
679
+ P, C2 = curr_tokens.shape[1], curr_tokens.shape[2]
680
+ loop_gate = torch.zeros(1, 1, device=device)
681
+ retrieved_tokens = torch.empty(0, P, C2, device=device)
682
+ retrieved_fids = []
683
+ n_valid_ret = 0
684
+ primary_sid = -1
685
+ fetch_ids = []
686
+ fetch_token_counts = []
687
+ retrieval_weights = None
688
+
689
+ # ── 3. Build combined token tensor ───────────────
690
+ parts = []
691
+ fid_parts: List[int] = []
692
+
693
+ if self._prev_tokens is not None:
694
+ parts.append(self._prev_tokens.to(device))
695
+ fid_parts.extend(self._prev_frame_ids)
696
+
697
+ parts.append(curr_tokens)
698
+ fid_parts.extend(self._curr_frame_ids)
699
+
700
+ # A5 fix: only append retrieved tokens if there are any (no zero-padding)
701
+ n_retrieved = 0
702
+ if retrieved_tokens.shape[0] > 0:
703
+ if retrieval_weights is not None and len(fetch_token_counts) == len(retrieval_weights):
704
+ gated_chunks = []
705
+ offset = 0
706
+ for weight, count in zip(retrieval_weights, fetch_token_counts):
707
+ gated_chunks.append(
708
+ retrieved_tokens[offset: offset + count] * weight.reshape(1, 1, 1)
709
+ )
710
+ offset += count
711
+ gated_retrieved = torch.cat(gated_chunks, dim=0) if gated_chunks else retrieved_tokens
712
+ else:
713
+ gated_retrieved = retrieved_tokens * loop_gate.unsqueeze(-1)
714
+ parts.append(gated_retrieved)
715
+ fid_parts.extend(retrieved_fids)
716
+ n_retrieved = gated_retrieved.shape[0]
717
+
718
+ combined = torch.cat(parts, dim=0) if parts else curr_tokens # [N_total, P, 2C]
719
+ frame_ids_tensor = torch.tensor(fid_parts, dtype=torch.long, device=device)
720
+
721
+ # ── 4. Temporal embedding — Phase 1 (input) ──────
722
+ if enable_temporal_embed and temporal_wrapper is not None:
723
+ combined = temporal_wrapper.inject_input(combined, frame_ids_tensor)
724
+
725
+ # ── 5. Run backendT (always, DDP-safe) ───────────
726
+ hidden_B = _safe_oom_retry(model.backendT, combined)
727
+
728
+ # ── 6. Temporal embedding — Phase 2 (output) ─────
729
+ if enable_temporal_embed and temporal_wrapper is not None:
730
+ hidden_B = temporal_wrapper.inject_output(hidden_B, frame_ids_tensor)
731
+
732
+ # ── 7. Update adjacency graph (always, no Python branching) ──
733
+ if enable_loop_closure and primary_sid >= 0:
734
+ cid = self._current_submap_id
735
+ self.adjacency.setdefault(cid, set()).add(primary_sid)
736
+ self.adjacency.setdefault(primary_sid, set()).add(cid)
737
+
738
+ # ── 8. Slice refined tokens for each part ────────
739
+ n_prev = self._prev_tokens.shape[0] if self._prev_tokens is not None else 0
740
+ n_curr = curr_tokens.shape[0]
741
+ completed_submap_id = self._current_submap_id
742
+ prev_sid = completed_submap_id - 1 if n_prev > 0 else None
743
+
744
+ curr_backend_tokens = hidden_B[n_prev:n_prev + n_curr]
745
+ should_store_backend_tokens = self.use_dual_queue or self.submap_fetch_source == "backend"
746
+ should_store_backend_desc = self.use_dual_queue or self.submap_descriptor_source == "backend"
747
+ curr_backend_desc = (
748
+ self.compute_descriptor(curr_backend_tokens) if should_store_backend_desc else None
749
+ )
750
+
751
+ if self.use_dual_queue or should_store_backend_tokens or should_store_backend_desc:
752
+ if prev_sid is not None:
753
+ refined_prev = hidden_B[:n_prev]
754
+ if refined_prev.shape[0] > 0:
755
+ if should_store_backend_tokens:
756
+ self.buffer.update_tokens(prev_sid, refined_prev, source="backend")
757
+ if should_store_backend_desc:
758
+ self.buffer.update_descriptor(
759
+ prev_sid,
760
+ self.compute_descriptor(refined_prev).squeeze(0),
761
+ source="backend",
762
+ )
763
+
764
+ offset = n_prev + n_curr
765
+ for sid, count in zip(fetch_ids, fetch_token_counts):
766
+ refined_ret = hidden_B[offset: offset + count]
767
+ if refined_ret.shape[0] > 0:
768
+ if should_store_backend_tokens:
769
+ self.buffer.update_tokens(sid, refined_ret, source="backend")
770
+ if should_store_backend_desc:
771
+ self.buffer.update_descriptor(
772
+ sid,
773
+ self.compute_descriptor(refined_ret).squeeze(0),
774
+ source="backend",
775
+ )
776
+ offset += count
777
+
778
+ # ── 9. Store current submap
779
+ self.buffer.store(
780
+ submap_id=completed_submap_id,
781
+ frame_ids=self._curr_frame_ids,
782
+ frontend_tokens=curr_tokens,
783
+ frontend_descriptor=curr_desc.squeeze(0),
784
+ backend_tokens=curr_backend_tokens if should_store_backend_tokens else None,
785
+ backend_descriptor=(
786
+ curr_backend_desc.squeeze(0) if curr_backend_desc is not None else None
787
+ ),
788
+ )
789
+
790
+ # ── 10. Release retrieved tensors ─────────────────
791
+ del retrieved_tokens
792
+
793
+ # ── 11. TBPTT: detach memory states periodically ──
794
+ # When submap_id crosses a tbptt_window boundary, detach ALL
795
+ # stored tokens in the buffer and _prev_tokens. This cuts the
796
+ # backward graph at memory-state level (not loss level), so
797
+ # gradients flow within the window but not across.
798
+ next_id = self._current_submap_id + 1
799
+ should_detach_history = (
800
+ tbptt_window is not None and tbptt_window > 0 and next_id % tbptt_window == 0 and next_id > 0
801
+ )
802
+ if should_detach_history:
803
+ self.buffer.detach_all()
804
+
805
+ # ── 12. Slide window ─────────────────────────────
806
+ next_prev_tokens = curr_tokens
807
+ if self.submap_fetch_source == "backend" and should_store_backend_tokens:
808
+ next_prev_tokens = curr_backend_tokens
809
+
810
+ if self.retain_history_grad and not should_detach_history:
811
+ self._prev_tokens = next_prev_tokens
812
+ elif self.retain_history_grad:
813
+ self._prev_tokens = next_prev_tokens.detach()
814
+ else:
815
+ self._prev_tokens = next_prev_tokens.detach().cpu()
816
+ self._prev_frame_ids = list(self._curr_frame_ids)
817
+ self._curr_tokens = []
818
+ self._curr_frame_ids = []
819
+ self._current_submap_id += 1
820
+
821
+ active_curr_desc = curr_desc
822
+ if self.submap_descriptor_source == "backend" and curr_backend_desc is not None:
823
+ active_curr_desc = curr_backend_desc
824
+
825
+ meta = {
826
+ 'submap_id': completed_submap_id,
827
+ 'n_prev': n_prev,
828
+ 'n_curr': n_curr,
829
+ 'n_retrieved': n_retrieved,
830
+ 'frame_ids': fid_parts,
831
+ 'curr_frame_ids': list(self._prev_frame_ids), # after slide, prev = old curr
832
+ 'curr_descriptor': active_curr_desc,
833
+ 'curr_frontend_descriptor': curr_desc,
834
+ 'curr_backend_descriptor': curr_backend_desc,
835
+ 'submap_train_mode': self.submap_train_mode,
836
+ 'submap_descriptor_source': self.submap_descriptor_source,
837
+ }
838
+ return hidden_B, loop_gate, meta
839
+
840
+ # ── reset ────────────────────────────────────────────
841
    def reset(self):
        """Clear all state (e.g. between sequences).

        Rebuilds the submap buffer and clears adjacency/token/frame-id
        bookkeeping so processing restarts at submap 0, frame 0.
        """
        self.buffer = self._build_buffer()  # fresh, empty submap store
        self.adjacency.clear()              # loop-closure graph edges
        self._curr_tokens.clear()           # tokens accumulated for the in-progress submap
        self._curr_frame_ids.clear()
        self._prev_tokens = None            # no previous-submap sliding window yet
        self._prev_frame_ids.clear()
        self._current_submap_id = 0
        self._global_frame_counter = 0
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/mine_pseudo_gt.py ADDED
@@ -0,0 +1,588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List, Optional, Tuple
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import torch
10
+
11
def canonical_view_key_from_values(dataset: Optional[str], label: Optional[str]) -> Optional[str]:
    """Join *dataset* and *label* into the canonical ``dataset::label`` key.

    Returns ``None`` when either component is missing.
    """
    if dataset is not None and label is not None:
        return f"{dataset}::{label}"
    return None
15
+
16
def load_payload(path: Optional[str]):
    """Load a cache payload from *path* (``.json``, ``.jsonl``, or a torch pickle).

    Returns ``None`` for an empty/missing path.  ``.jsonl`` files are wrapped
    as ``{"records": [...]}`` so every format exposes a dict-like shape.
    """
    if not path:
        return None
    path = os.path.expanduser(path)
    if path.endswith(".json"):
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    if path.endswith(".jsonl"):
        records = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:  # skip blank lines
                    records.append(json.loads(line))
        return {"records": records}
    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # trusted cache files through this path.
    return torch.load(path, map_location="cpu", weights_only=False)
32
+
33
+
34
+ def _scalar(value, default=None):
35
+ if isinstance(value, (list, tuple)):
36
+ if not value:
37
+ return default
38
+ return _scalar(value[0], default)
39
+ if torch.is_tensor(value):
40
+ if value.numel() == 0:
41
+ return default
42
+ return value.detach().reshape(-1)[0].cpu().item()
43
+ return value if value is not None else default
44
+
45
+
46
def _float(value, default=0.0):
    """Coerce *value* (possibly nested list/tuple or tensor) to ``float``.

    Falls back to *default* when the scalarised value cannot be converted.
    """
    scalar = _scalar(value, default)
    try:
        return float(scalar)
    except (TypeError, ValueError):
        return float(default)
52
+
53
+
54
+ def _bool(record: Optional[Dict[str, Any]], score_threshold: float = 0.0) -> Optional[bool]:
55
+ if record is None:
56
+ return None
57
+ for key in ("is_positive", "accepted", "match", "loop"):
58
+ if key in record:
59
+ return bool(record[key])
60
+ tag = str(record.get("tag", "")).strip().lower()
61
+ if tag in {"positive", "pos", "match", "loop", "true", "1"}:
62
+ return True
63
+ if tag in {"negative", "neg", "false", "0"}:
64
+ return False
65
+ score = _float(record.get("score"), _float(record.get("confidence"), 0.0))
66
+ if score_threshold > 0:
67
+ return score >= score_threshold
68
+ return None
69
+
70
+
71
def load_pair_cache(path: Optional[str]) -> Dict[Tuple[str, str], Dict[str, Any]]:
    """Load verification records and index them by an order-independent pair key.

    Accepts several payload shapes: a list of records, ``{"records": [...]}``,
    ``{"pairs": [...]}``, or a ``pair_key -> record`` mapping.  Each record is
    keyed by the lexicographically sorted ``(key_a, key_b)`` tuple.
    """
    payload = load_payload(path)
    if payload is None:
        return {}
    raw_records = payload
    if isinstance(payload, dict):
        # Prefer an explicit record list; otherwise treat the dict itself as the map.
        raw_records = payload.get("records", payload.get("pairs", payload))
    if isinstance(raw_records, dict):
        # pair_key -> record mapping: fold the mapping key into each record.
        iterator = [
            ({**value, "pair_key": key} if isinstance(value, dict) else value)
            for key, value in raw_records.items()
        ]
    else:
        iterator = raw_records
    cache = {}
    for record in iterator:
        if not isinstance(record, dict):
            continue  # tolerate malformed entries
        key_a = record.get("key_a") or record.get("frame_key_a")
        key_b = record.get("key_b") or record.get("frame_key_b")
        # Fall back to reconstructing keys from dataset/label fields.
        if key_a is None:
            key_a = canonical_view_key_from_values(record.get("dataset_a"), record.get("label_a"))
        if key_b is None:
            key_b = canonical_view_key_from_values(record.get("dataset_b"), record.get("label_b"))
        # NOTE(review): this local shadows the module-level pair_key() helper.
        pair_key = record.get("pair_key")
        # Last resort: split an "a||b" composite key.
        if (key_a is None or key_b is None) and isinstance(pair_key, str) and "||" in pair_key:
            key_a, key_b = pair_key.split("||", 1)
        if key_a is None or key_b is None or key_a == key_b:
            continue
        pair = (key_a, key_b) if key_a <= key_b else (key_b, key_a)  # canonical order
        cache[pair] = record
    return cache
103
+
104
+
105
def inverse_pose(pose: np.ndarray) -> np.ndarray:
    """Return the inverse of a 4x4 rigid transform (rotation + translation).

    Uses the closed form ``[R^T, -R^T t]`` instead of a general matrix inverse.
    """
    pose = np.asarray(pose, dtype=np.float32)
    rotation_t = pose[:3, :3].T
    translation = pose[:3, 3]
    result = np.eye(4, dtype=np.float32)
    result[:3, :3] = rotation_t
    result[:3, 3] = -(rotation_t @ translation)
    return result
113
+
114
+
115
def pose_distance(pose_a: np.ndarray, pose_b: np.ndarray) -> float:
    """Euclidean distance between the translation components of two 4x4 poses."""
    delta = np.asarray(pose_a[:3, 3]) - np.asarray(pose_b[:3, 3])
    return float(np.linalg.norm(delta))
117
+
118
+
119
def heading_angle_deg(pose_a: np.ndarray, pose_b: np.ndarray) -> float:
    """Angle in degrees between the forward (+z column) axes of two poses.

    Degenerate (near-zero) forward vectors yield the maximal 180 degrees.
    """
    forward_a = np.asarray(pose_a[:3, 2], dtype=np.float32)
    forward_b = np.asarray(pose_b[:3, 2], dtype=np.float32)
    norm_a = np.linalg.norm(forward_a)
    norm_b = np.linalg.norm(forward_b)
    if min(norm_a, norm_b) <= 1e-8:
        return 180.0
    cosine = float(np.clip(np.dot(forward_a / norm_a, forward_b / norm_b), -1.0, 1.0))
    return float(np.degrees(np.arccos(cosine)))
128
+
129
+
130
def sample_pixels(depth: np.ndarray, sample_points: int, rng: np.random.Generator) -> Optional[np.ndarray]:
    """Sample up to *sample_points* pixel centres ``(u, v)`` with positive depth.

    Returns ``None`` when the depth map has no valid pixels.  Coordinates are
    offset by 0.5 so they refer to pixel centres.
    """
    valid_rc = np.argwhere(depth > 0)
    if valid_rc.size == 0:
        return None
    if len(valid_rc) > sample_points:
        chosen = rng.choice(len(valid_rc), size=sample_points, replace=False)
        valid_rc = valid_rc[chosen]
    # argwhere yields (row, col); swap to image order (u, v) = (col, row).
    return valid_rc[:, [1, 0]].astype(np.float32) + 0.5
139
+
140
+
141
def project_points(uv: np.ndarray, depth: np.ndarray, intrinsics: np.ndarray, pose_src: np.ndarray, pose_dst: np.ndarray):
    """Lift pixel-centre samples *uv* from the source view into 3-D and express
    them in the destination camera frame.

    Returns an ``[N, 3]`` array of points in *pose_dst*'s camera coordinates.
    """
    fx = float(intrinsics[0, 0])
    fy = float(intrinsics[1, 1])
    cx = float(intrinsics[0, 2])
    cy = float(intrinsics[1, 2])
    u = uv[:, 0]
    v = uv[:, 1]
    # uv holds pixel centres (+0.5 offset from sample_pixels); subtract 0.5 to
    # recover integer indices, clipped to the depth-map bounds.
    z = depth[np.clip((v - 0.5).astype(np.int64), 0, depth.shape[0] - 1), np.clip((u - 0.5).astype(np.int64), 0, depth.shape[1] - 1)]
    # Back-project with the pinhole model (guard against zero focal lengths).
    x = (u - cx) * z / max(fx, 1e-6)
    y = (v - cy) * z / max(fy, 1e-6)
    points_cam = np.stack([x, y, z], axis=-1)
    # Camera -> world via pose_src, then world -> destination camera.
    world = points_cam @ pose_src[:3, :3].T + pose_src[:3, 3]
    world_to_dst = inverse_pose(pose_dst)
    points_dst = world @ world_to_dst[:3, :3].T + world_to_dst[:3, 3]
    return points_dst
156
+
157
+
158
def directed_overlap(
    src_depth: np.ndarray,
    dst_depth: np.ndarray,
    intrinsics_src: np.ndarray,
    intrinsics_dst: np.ndarray,
    pose_src: np.ndarray,
    pose_dst: np.ndarray,
    sample_points: int,
    depth_tolerance_ratio: float,
    rng: np.random.Generator,
) -> Dict[str, float]:
    """One-directional view overlap: project sampled src pixels into dst.

    Returns a dict with:
      * ``frustum``    — fraction of samples landing inside dst's image bounds,
      * ``depth``      — fraction whose projected depth agrees with dst's depth
                         map within *depth_tolerance_ratio* relative error,
      * ``consistent`` — absolute count of depth-consistent samples,
      * ``count``      — number of samples drawn from the source view.
    """
    uv = sample_pixels(src_depth, sample_points, rng)
    if uv is None:
        # No valid depth in the source view: zero overlap by definition.
        return {"frustum": 0.0, "depth": 0.0, "consistent": 0.0, "count": 0.0}
    points_dst = project_points(uv, src_depth, intrinsics_src, pose_src, pose_dst)
    z_dst = points_dst[:, 2]
    fx = float(intrinsics_dst[0, 0])
    fy = float(intrinsics_dst[1, 1])
    cx = float(intrinsics_dst[0, 2])
    cy = float(intrinsics_dst[1, 2])
    # Pinhole projection into the destination image plane (z clamped positive).
    u_dst = fx * points_dst[:, 0] / np.clip(z_dst, 1e-6, None) + cx
    v_dst = fy * points_dst[:, 1] / np.clip(z_dst, 1e-6, None) + cy
    inside = (
        (z_dst > 1e-6)
        & (u_dst >= 0.0)
        & (u_dst < dst_depth.shape[1])
        & (v_dst >= 0.0)
        & (v_dst < dst_depth.shape[0])
    )
    frustum = float(inside.mean())
    if not inside.any():
        return {"frustum": frustum, "depth": 0.0, "consistent": 0.0, "count": float(len(uv))}
    # Sample the destination depth map at the rounded projected pixels.
    u_idx = np.clip(np.round(u_dst[inside]).astype(np.int64), 0, dst_depth.shape[1] - 1)
    v_idx = np.clip(np.round(v_dst[inside]).astype(np.int64), 0, dst_depth.shape[0] - 1)
    sampled_dst = dst_depth[v_idx, u_idx]
    valid_dst = sampled_dst > 0
    # Depth-consistent: dst has valid depth AND relative disagreement is small.
    consistent = valid_dst & (np.abs(sampled_dst - z_dst[inside]) / np.clip(sampled_dst, 1e-6, None) <= depth_tolerance_ratio)
    # Ratios are over ALL drawn samples, not just those inside the frustum.
    depth_overlap = float(consistent.sum() / max(1, len(uv)))
    return {
        "frustum": frustum,
        "depth": depth_overlap,
        "consistent": float(consistent.sum()),
        "count": float(len(uv)),
    }
202
+
203
+
204
def symmetric_overlap(frame_a, frame_b, depth_a, depth_b, args, rng: np.random.Generator) -> Dict[str, float]:
    """Bidirectional overlap between two frames.

    Averages both projection directions so the measure is symmetric in (a, b);
    raw sample/consistency counts are summed rather than averaged.
    """
    a_to_b = directed_overlap(
        depth_a,
        depth_b,
        frame_a["intrinsics"],
        frame_b["intrinsics"],
        frame_a["pose"],
        frame_b["pose"],
        args.sample_points,
        args.depth_tolerance_ratio,
        rng,
    )
    b_to_a = directed_overlap(
        depth_b,
        depth_a,
        frame_b["intrinsics"],
        frame_a["intrinsics"],
        frame_b["pose"],
        frame_a["pose"],
        args.sample_points,
        args.depth_tolerance_ratio,
        rng,
    )
    return {
        "frustum_overlap": 0.5 * (a_to_b["frustum"] + b_to_a["frustum"]),
        "depth_overlap": 0.5 * (a_to_b["depth"] + b_to_a["depth"]),
        "geometric_support_count": int(round(a_to_b["consistent"] + b_to_a["consistent"])),
        "sample_count": int(round(a_to_b["count"] + b_to_a["count"])),
    }
233
+
234
+
235
def load_arkitscenes(root: str, split: str) -> List[Dict[str, Any]]:
    """Enumerate ARKitScenes frames (pose, intrinsics, depth/image paths).

    Scenes are listed in ``all_metadata.npz``; per-scene frame data comes from
    ``new_scene_metadata.npz``.  Frames without an on-disk depth file are
    skipped; scenes with no usable frames are dropped.
    """
    split_dir = "Training" if split.lower() == "train" else "Test"
    meta_root = Path(root) / split_dir
    all_metadata = np.load(meta_root / "all_metadata.npz")
    scenes = []
    for scene_name in all_metadata["scenes"]:
        scene_name = str(scene_name)
        scene_dir = meta_root / scene_name
        meta_path = scene_dir / "new_scene_metadata.npz"
        if not meta_path.is_file():
            continue
        with np.load(meta_path, allow_pickle=True) as meta:
            images = meta["images"]
            intrinsics = meta["intrinsics"]
            trajectories = meta["trajectories"]
        frames = []
        for basename, intri, pose in zip(images, intrinsics, trajectories):
            basename = str(basename)
            # Build a 3x3 K matrix; intri indices 2-5 are presumably
            # (fx, fy, cx, cy) — TODO confirm against the dataset exporter.
            K = np.eye(3, dtype=np.float32)
            K[0, 0] = intri[2]
            K[1, 1] = intri[3]
            K[0, 2] = intri[4]
            K[1, 2] = intri[5]
            depth_path = scene_dir / "lowres_depth" / basename
            image_path = scene_dir / "vga_wide" / basename.replace(".png", ".jpg")
            if not depth_path.is_file():
                continue  # depth is required downstream; skip frames lacking it
            frames.append({
                "dataset": "arkitscenes",
                "scene": scene_name,
                "label": f"{scene_name}_{basename}",
                "pose": np.asarray(pose, dtype=np.float32),
                "intrinsics": K,
                "depth_path": str(depth_path),
                "image_path": str(image_path),
            })
        if frames:
            scenes.append({"scene": scene_name, "frames": frames})
    return scenes
274
+
275
+
276
def load_scannetpp(root: str) -> List[Dict[str, Any]]:
    """Enumerate ScanNet++ frames (pose, intrinsics, depth/image paths).

    Mirrors ``load_arkitscenes`` but intrinsics are stored as a ready-made
    matrix and depth/image live under ``depth/`` and ``images/``.
    """
    scenes = []
    all_metadata = np.load(Path(root) / "all_metadata.npz")
    for scene_name in all_metadata["scenes"]:
        scene_name = str(scene_name)
        scene_dir = Path(root) / scene_name
        meta_path = scene_dir / "new_scene_metadata.npz"
        if not meta_path.is_file():
            continue
        with np.load(meta_path, allow_pickle=True) as meta:
            images = meta["images"]
            intrinsics = meta["intrinsics"]
            trajectories = meta["trajectories"]
        frames = []
        for basename, intri, pose in zip(images, intrinsics, trajectories):
            basename = str(basename)
            depth_path = scene_dir / "depth" / f"{basename}.png"
            image_path = scene_dir / "images" / f"{basename}.jpg"
            if not depth_path.is_file():
                continue  # depth is required downstream
            frames.append({
                "dataset": "ScanNet++",
                "scene": scene_name,
                "label": f"{scene_name}_{basename}",
                "pose": np.asarray(pose, dtype=np.float32),
                # intri is used as-is; presumably already a 3x3 K — TODO confirm
                "intrinsics": np.asarray(intri, dtype=np.float32),
                "depth_path": str(depth_path),
                "image_path": str(image_path),
            })
        if frames:
            scenes.append({"scene": scene_name, "frames": frames})
    return scenes
308
+
309
+
310
def load_mvs_synth(root: str) -> List[Dict[str, Any]]:
    """Enumerate MVS-Synth frames from per-scene ``rgb``/``depth``/``cam`` dirs.

    Each frame needs a matching ``cam/<name>.npz`` (pose + intrinsics) and a
    ``depth/<name>.npy``; scenes missing any of the three folders are skipped.
    """
    scenes = []
    for scene_name in sorted(os.listdir(root)):
        scene_dir = Path(root) / scene_name
        rgb_dir = scene_dir / "rgb"
        depth_dir = scene_dir / "depth"
        cam_dir = scene_dir / "cam"
        if not rgb_dir.is_dir() or not depth_dir.is_dir() or not cam_dir.is_dir():
            continue
        basenames = sorted([path.stem for path in rgb_dir.glob("*.jpg")])
        frames = []
        for basename in basenames:
            cam_path = cam_dir / f"{basename}.npz"
            depth_path = depth_dir / f"{basename}.npy"
            if not cam_path.is_file() or not depth_path.is_file():
                continue  # require both calibration and depth
            cam = np.load(cam_path)
            frames.append({
                "dataset": "MVS_Synth",
                "scene": scene_name,
                "label": f"{scene_name}_{basename}",
                "pose": np.asarray(cam["pose"], dtype=np.float32),
                "intrinsics": np.asarray(cam["intrinsics"], dtype=np.float32),
                "depth_path": str(depth_path),
                "image_path": str(rgb_dir / f"{basename}.jpg"),
            })
        if frames:
            scenes.append({"scene": scene_name, "frames": frames})
    return scenes
339
+
340
+
341
def load_depth(frame: Dict[str, Any]) -> np.ndarray:
    """Load the depth map referenced by *frame* as a float32 array.

    ``.npy`` depth (MVS-Synth) is filtered: values beyond the 98th valid
    percentile or above 1000 are zeroed (treated as invalid).  Any other
    extension is read via OpenCV as a 16-bit millimetre image and scaled by
    1/1000.  Non-finite values are zeroed in all cases.

    Raises:
        FileNotFoundError: if an image-format depth file cannot be read
            (``cv2.imread`` returns ``None`` on failure instead of raising).
    """
    depth_path = frame["depth_path"]
    if depth_path.endswith(".npy"):
        depth = np.load(depth_path).astype(np.float32)
        valid = depth > 0
        if valid.any():
            # Clip far outliers: synthetic renders can contain huge depths.
            threshold = np.percentile(depth[valid], 98)
            depth[depth > threshold] = 0.0
        depth[depth > 1000.0] = 0.0
    else:
        raw = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
        if raw is None:
            # Fail loudly instead of crashing later with an opaque
            # AttributeError on None.astype(...).
            raise FileNotFoundError(f"Could not read depth image: {depth_path}")
        depth = raw.astype(np.float32) / 1000.0
    depth[~np.isfinite(depth)] = 0.0
    return depth
354
+
355
+
356
def load_scenes(dataset: str, root: str, split: str) -> List[Dict[str, Any]]:
    """Dispatch to the per-dataset scene loader (case-insensitive name).

    Raises:
        ValueError: for an unrecognised *dataset* name.
    """
    loaders = {
        "arkitscenes": lambda: load_arkitscenes(root, split),
        "scannetpp": lambda: load_scannetpp(root),
        "mvs_synth": lambda: load_mvs_synth(root),
    }
    normalised = dataset.lower()
    if normalised not in loaders:
        raise ValueError(f"Unsupported dataset: {dataset}")
    return loaders[normalised]()
365
+
366
+
367
def limit_frames(frames: List[Dict[str, Any]], args) -> List[Dict[str, Any]]:
    """Subsample *frames* by ``args.frame_stride`` (floored at 1) and cap the
    result at ``args.max_frames_per_scene`` (0 means no cap)."""
    stride = max(1, args.frame_stride)
    subsampled = frames[::stride]
    cap = args.max_frames_per_scene
    return subsampled[:cap] if cap > 0 else subsampled
372
+
373
+
374
def pair_key(frame_a: Dict[str, Any], frame_b: Dict[str, Any]) -> Tuple[str, str]:
    """Order-independent cache key for a frame pair (lexicographically sorted)."""
    key_a = canonical_view_key_from_values(frame_a["dataset"], frame_a["label"])
    key_b = canonical_view_key_from_values(frame_b["dataset"], frame_b["label"])
    low, high = sorted((key_a, key_b))
    return (low, high)
378
+
379
+
380
def lookup_cache(cache: Dict[Tuple[str, str], Dict[str, Any]], frame_a: Dict[str, Any], frame_b: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Fetch the cached verification record for a frame pair, or ``None``."""
    return cache.get(pair_key(frame_a, frame_b))
382
+
383
+
384
def l2m_positive(record: Optional[Dict[str, Any]], args) -> bool:
    """Return ``True`` when an L2M matcher *record* passes every acceptance
    threshold configured on *args* (match count, certainty, inlier ratio)."""
    if record is None:
        return False
    match_count = int(round(_float(record.get("l2m_match_count"), _float(record.get("match_count"), 0.0))))
    certainty = _float(record.get("l2m_mean_certainty"), _float(record.get("mean_certainty"), 0.0))
    inlier_ratio = _float(record.get("l2m_inlier_ratio"), _float(record.get("inlier_ratio"), 0.0))
    if match_count < args.l2m_min_match_count:
        return False
    if certainty < args.l2m_min_certainty:
        return False
    return inlier_ratio >= args.l2m_min_inlier_ratio
395
+
396
+
397
def mine_scene(scene: Dict[str, Any], args, sage_cache, l2m_cache, rng: np.random.Generator):
    """Mine positive / hard-negative frame pairs for a single *scene*.

    Pipeline per candidate pair (frame gap >= ``args.min_frame_gap``):
      1. Coarse pose gating on translation distance and heading angle.
      2. Geometric verification via symmetric depth-reprojection overlap.
      3. Optional SAGE-cache veto/confirmation and L2M-cache rescue.

    Returns a list of record dicts (positives tagged ``"positive"``, hard
    negatives tagged ``"hard_negative"``).
    """
    frames = limit_frames(scene["frames"], args)
    if len(frames) <= args.min_frame_gap:
        return []  # too few frames to form any valid pair
    depth_cache: Dict[str, np.ndarray] = {}  # path -> loaded depth (per scene)
    records = []
    num_pairs = 0
    for i in range(len(frames)):
        for j in range(i + args.min_frame_gap, len(frames), max(1, args.pair_step)):
            # Budget only counts pairs that reach the (expensive) overlap check.
            if args.max_pairs_per_scene > 0 and num_pairs >= args.max_pairs_per_scene:
                return records
            frame_a = frames[i]
            frame_b = frames[j]
            dist = pose_distance(frame_a["pose"], frame_b["pose"])
            heading = heading_angle_deg(frame_a["pose"], frame_b["pose"])
            # Coarse gates: positive pairs must be close in pose; hard
            # negatives may be somewhat farther / more rotated.
            positive_coarse = dist <= args.max_translation and heading <= args.max_heading_deg
            negative_coarse = dist <= args.hard_negative_max_translation and heading <= args.hard_negative_max_heading_deg
            if not positive_coarse and not negative_coarse:
                continue
            # Lazy-load depth maps once per frame for the whole scene.
            if frame_a["depth_path"] not in depth_cache:
                depth_cache[frame_a["depth_path"]] = load_depth(frame_a)
            if frame_b["depth_path"] not in depth_cache:
                depth_cache[frame_b["depth_path"]] = load_depth(frame_b)
            overlap = symmetric_overlap(frame_a, frame_b, depth_cache[frame_a["depth_path"]], depth_cache[frame_b["depth_path"]], args, rng)
            frustum_overlap = overlap["frustum_overlap"]
            depth_overlap = overlap["depth_overlap"]
            geometric_support_count = overlap["geometric_support_count"]
            num_pairs += 1
            sage_record = lookup_cache(sage_cache, frame_a, frame_b)
            l2m_record = lookup_cache(l2m_cache, frame_a, frame_b)
            sage_pass = _bool(sage_record, args.sage_min_score)  # True/False/None
            l2m_pass = l2m_positive(l2m_record, args)
            if positive_coarse and frustum_overlap >= args.min_frustum_overlap and depth_overlap >= args.min_depth_overlap:
                # Geometry accepts; external caches can veto or strengthen.
                accepted = True
                verification_path = "geometry_only"
                if sage_cache:
                    if sage_pass is False:
                        accepted = False  # explicit SAGE rejection vetoes
                    elif sage_pass is True:
                        verification_path = "geometry+sage"
                if l2m_cache:
                    if l2m_pass and verification_path == "geometry+sage":
                        verification_path = "geometry+sage+l2m"
                    elif l2m_pass and verification_path != "geometry+sage":
                        # L2M evidence can rescue a SAGE-vetoed pair.
                        verification_path = "geometry+l2m_rescue"
                        accepted = True
                    elif sage_cache and sage_pass is False:
                        accepted = False
                if accepted:
                    l2m_match_count = int(round(_float((l2m_record or {}).get("l2m_match_count"), _float((l2m_record or {}).get("match_count"), 0.0))))
                    l2m_mean_certainty = _float((l2m_record or {}).get("l2m_mean_certainty"), _float((l2m_record or {}).get("mean_certainty"), 0.0))
                    l2m_inlier_ratio = _float((l2m_record or {}).get("l2m_inlier_ratio"), _float((l2m_record or {}).get("inlier_ratio"), 0.0))
                    # Confidence score: best of geometric overlap and any
                    # passing external-verifier scores.
                    score = max(depth_overlap, frustum_overlap)
                    if sage_pass is True:
                        score = max(score, _float((sage_record or {}).get("score"), _float((sage_record or {}).get("confidence"), score)))
                    if l2m_pass:
                        score = max(score, l2m_mean_certainty)
                    record = {
                        "scene": scene["scene"],
                        "dataset_a": frame_a["dataset"],
                        "dataset_b": frame_b["dataset"],
                        "label_a": frame_a["label"],
                        "label_b": frame_b["label"],
                        "key_a": canonical_view_key_from_values(frame_a["dataset"], frame_a["label"]),
                        "key_b": canonical_view_key_from_values(frame_b["dataset"], frame_b["label"]),
                        "tag": "positive",
                        "is_positive": True,
                        "score": float(score),
                        "soft_overlap_target": float(depth_overlap),
                        "overlap": float(depth_overlap),
                        # Weight clamped to [0.05, 1.0] so no pair is zeroed out.
                        "pair_confidence_weight": float(max(0.05, min(1.0, score))),
                        "weight": float(max(0.05, min(1.0, score))),
                        "pose_distance": float(dist),
                        "heading_deg": float(heading),
                        "frustum_overlap": float(frustum_overlap),
                        "depth_overlap": float(depth_overlap),
                        "geometric_support_count": int(geometric_support_count),
                        "verification_path": verification_path,
                        "l2m_match_count": l2m_match_count,
                        "l2m_mean_certainty": float(l2m_mean_certainty),
                        "l2m_inlier_ratio": float(l2m_inlier_ratio),
                        "image_path_a": frame_a["image_path"],
                        "image_path_b": frame_b["image_path"],
                    }
                    records.append(record)
            elif negative_coarse and frustum_overlap <= args.negative_max_frustum_overlap and depth_overlap <= args.negative_max_depth_overlap:
                # Hard negative: pose-plausible pair with (near-)zero overlap.
                score = max(1.0 - max(frustum_overlap, depth_overlap), 0.0)
                records.append({
                    "scene": scene["scene"],
                    "dataset_a": frame_a["dataset"],
                    "dataset_b": frame_b["dataset"],
                    "label_a": frame_a["label"],
                    "label_b": frame_b["label"],
                    "key_a": canonical_view_key_from_values(frame_a["dataset"], frame_a["label"]),
                    "key_b": canonical_view_key_from_values(frame_b["dataset"], frame_b["label"]),
                    "tag": "hard_negative",
                    "is_positive": False,
                    "score": float(score),
                    "soft_overlap_target": 0.0,
                    "overlap": 0.0,
                    "pair_confidence_weight": float(max(0.05, min(1.0, score))),
                    "weight": float(max(0.05, min(1.0, score))),
                    "pose_distance": float(dist),
                    "heading_deg": float(heading),
                    "frustum_overlap": float(frustum_overlap),
                    "depth_overlap": float(depth_overlap),
                    "geometric_support_count": 0,
                    "verification_path": "geometry_negative",
                    "l2m_match_count": 0,
                    "l2m_mean_certainty": 0.0,
                    "l2m_inlier_ratio": 0.0,
                    "image_path_a": frame_a["image_path"],
                    "image_path_b": frame_b["image_path"],
                })
    return records
512
+
513
+
514
def build_argparser():
    """Build the CLI parser for the pseudo-ground-truth mining script."""
    parser = argparse.ArgumentParser()
    # Dataset / IO.
    parser.add_argument("--dataset", required=True, choices=["arkitscenes", "scannetpp", "mvs_synth"])
    parser.add_argument("--root", required=True)
    parser.add_argument("--output", required=True)
    parser.add_argument("--split", default="train")
    parser.add_argument("--seed", type=int, default=42)
    # Scene / frame / pair subsampling (0 = unlimited for the caps).
    for flag, default in (
        ("--max-scenes", 0),
        ("--max-frames-per-scene", 0),
        ("--frame-stride", 1),
        ("--pair-step", 1),
        ("--min-frame-gap", 12),
        ("--max-pairs-per-scene", 0),
        ("--sample-points", 2048),
    ):
        parser.add_argument(flag, type=int, default=default)
    # Geometric acceptance thresholds.
    for flag, default in (
        ("--depth-tolerance-ratio", 0.05),
        ("--max-translation", 2.5),
        ("--max-heading-deg", 45.0),
        ("--hard-negative-max-translation", 3.5),
        ("--hard-negative-max-heading-deg", 75.0),
        ("--min-frustum-overlap", 0.2),
        ("--min-depth-overlap", 0.1),
        ("--negative-max-frustum-overlap", 0.05),
        ("--negative-max-depth-overlap", 0.02),
    ):
        parser.add_argument(flag, type=float, default=default)
    # External verification caches (SAGE / L2M).
    parser.add_argument("--sage-cache", default=None)
    parser.add_argument("--sage-min-score", type=float, default=0.5)
    parser.add_argument("--l2m-cache", default=None)
    parser.add_argument("--l2m-min-match-count", type=int, default=64)
    parser.add_argument("--l2m-min-certainty", type=float, default=0.5)
    parser.add_argument("--l2m-min-inlier-ratio", type=float, default=0.3)
    return parser
544
+
545
+
546
def main():
    """Mine pseudo-GT pairs for one dataset and write them to a JSON file.

    Loads scenes, runs ``mine_scene`` per scene, prints per-scene progress,
    and writes ``{"metadata": ..., "records": [...]}`` to ``--output``.
    """
    args = build_argparser().parse_args()
    rng = np.random.default_rng(args.seed)  # single RNG shared across scenes
    scenes = load_scenes(args.dataset, args.root, args.split)
    if args.max_scenes > 0:
        scenes = scenes[: args.max_scenes]
    sage_cache = load_pair_cache(args.sage_cache)
    l2m_cache = load_pair_cache(args.l2m_cache)
    all_records = []
    scene_stats = []
    for index, scene in enumerate(scenes):
        records = mine_scene(scene, args, sage_cache, l2m_cache, rng)
        all_records.extend(records)
        positives = sum(1 for record in records if record.get("is_positive", False))
        negatives = sum(1 for record in records if not record.get("is_positive", False))
        scene_stats.append({
            "scene": scene["scene"],
            "num_frames": len(scene["frames"]),
            "num_records": len(records),
            "num_positives": positives,
            "num_negatives": negatives,
        })
        print(f"[{index + 1}/{len(scenes)}] {scene['scene']}: {len(records)} records ({positives} pos / {negatives} neg)")
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Embed the full CLI config so results are reproducible from the output alone.
    metadata = {
        "dataset": args.dataset,
        "root": args.root,
        "split": args.split,
        "num_scenes": len(scenes),
        "num_records": len(all_records),
        "num_positive_records": sum(1 for record in all_records if record.get("is_positive", False)),
        "num_negative_records": sum(1 for record in all_records if not record.get("is_positive", False)),
        "args": vars(args),
        "scene_stats": scene_stats,
    }
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump({"metadata": metadata, "records": all_records}, f, indent=2)
    print(f"Wrote {len(all_records)} records to {output_path}")
585
+
586
+
587
+ if __name__ == "__main__":
588
+ main()
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/pseudo_gt.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+
9
# Lowercased tag strings accepted as positive/negative labels when
# normalizing raw cache records (see FramePairPseudoGTDatabase._normalize_record).
_POSITIVE_TAGS = {"positive", "pos", "loop", "match", "1", "true"}
_NEGATIVE_TAGS = {"negative", "neg", "hard_negative", "hard-neg", "0", "false"}
11
+
12
+
13
+ def _cfg_get(cfg, key, default=None):
14
+ if cfg is None:
15
+ return default
16
+ if isinstance(cfg, dict):
17
+ return cfg.get(key, default)
18
+ if hasattr(cfg, "get"):
19
+ return cfg.get(key, default)
20
+ return getattr(cfg, key, default)
21
+
22
+
23
+ def _first_scalar(value):
24
+ if isinstance(value, (list, tuple)):
25
+ if len(value) == 0:
26
+ return None
27
+ return _first_scalar(value[0])
28
+ if torch.is_tensor(value):
29
+ if value.numel() == 0:
30
+ return None
31
+ return value.detach().reshape(-1)[0].cpu().item()
32
+ return value
33
+
34
+
35
def _to_str(value) -> Optional[str]:
    """Collapse `value` to its first scalar and render it as a string (None stays None)."""
    scalar = _first_scalar(value)
    return None if scalar is None else str(scalar)
40
+
41
+
42
def _to_float(value, default: float = 0.0) -> float:
    """Best-effort float conversion of the first scalar in `value`, else `default`."""
    scalar = _first_scalar(value)
    if scalar is None:
        return float(default)
    try:
        return float(scalar)
    except (TypeError, ValueError):
        # Non-numeric scalar (e.g. an arbitrary string): fall back.
        return float(default)
50
+
51
+
52
def canonical_view_key_from_values(dataset: Optional[str], label: Optional[str]) -> Optional[str]:
    """Build the `dataset::label` frame key; None when either part is missing."""
    if dataset is None or label is None:
        return None
    return f"{dataset}::{label}"
56
+
57
+
58
def canonical_view_key(view: Dict[str, Any]) -> Optional[str]:
    """Canonical frame key for a dataloader view dict, or None if dataset/label are absent."""
    if not isinstance(view, dict):
        return None
    dataset = _to_str(view.get("dataset"))
    label = _to_str(view.get("label"))
    return canonical_view_key_from_values(dataset, label)
62
+
63
+
64
class FramePairPseudoGTDatabase:
    """Lookup table of pseudo-ground-truth loop-closure labels for frame pairs.

    Records are keyed by an order-independent pair of canonical frame keys
    (see `canonical_view_key`). When the same pair appears more than once in
    the cache, the record with the highest score wins.
    """

    def __init__(self, records: Dict[Tuple[str, str], Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None):
        # records: sorted (key_a, key_b) tuple -> normalized record dict.
        self.records = records
        self.metadata = metadata or {}

    def __len__(self) -> int:
        return len(self.records)

    @classmethod
    def from_file(cls, path: str) -> "FramePairPseudoGTDatabase":
        """Load a cache from a .json, .jsonl, or torch-serialized file.

        Raises:
            FileNotFoundError: if `path` does not exist.
        """
        path = os.path.expanduser(path)
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Pseudo-GT cache not found: {path}")
        if path.endswith(".json"):
            with open(path, "r", encoding="utf-8") as f:
                payload = json.load(f)
        elif path.endswith(".jsonl"):
            # One JSON record per non-empty line.
            raw_records = []
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        raw_records.append(json.loads(line))
            payload = {"records": raw_records}
        else:
            # Fallback: a torch-serialized payload. NOTE(review):
            # weights_only=False unpickles arbitrary objects -- only load
            # trusted cache files.
            payload = torch.load(path, map_location="cpu", weights_only=False)

        metadata = payload.get("metadata", {}) if isinstance(payload, dict) else {}
        raw_records = payload
        if isinstance(payload, dict):
            # Accept {"records": ...}, {"pairs": ...}, or a bare mapping of records.
            raw_records = payload.get("records", payload.get("pairs", payload))

        records: Dict[Tuple[str, str], Dict[str, Any]] = {}
        if isinstance(raw_records, dict):
            # Mapping form: fold the mapping key into each record as "pair_key".
            iterator = [
                ({**value, "pair_key": key} if isinstance(value, dict) else value)
                for key, value in raw_records.items()
            ]
        else:
            iterator = raw_records
        for raw_record in iterator:
            record = cls._normalize_record(raw_record)
            if record is None:
                continue
            pair = record["pair"]
            previous = records.get(pair)
            # Keep only the highest-scoring record per pair.
            if previous is None or record["score"] > previous["score"]:
                records[pair] = record
        return cls(records, metadata=metadata)

    @classmethod
    def _normalize_record(cls, raw_record: Any) -> Optional[Dict[str, Any]]:
        """Coerce one raw cache entry into the canonical record schema.

        Tries several alternative field names for each value, and returns
        None for malformed entries (non-dicts, missing keys, self-pairs).
        """
        if not isinstance(raw_record, dict):
            return None

        # Frame keys: prefer explicit keys, then dataset/label pairs,
        # then a "keyA||keyB" composite pair_key.
        key_a = _to_str(raw_record.get("key_a") or raw_record.get("frame_key_a"))
        key_b = _to_str(raw_record.get("key_b") or raw_record.get("frame_key_b"))
        if key_a is None:
            key_a = canonical_view_key_from_values(
                _to_str(raw_record.get("dataset_a")),
                _to_str(raw_record.get("label_a")),
            )
        if key_b is None:
            key_b = canonical_view_key_from_values(
                _to_str(raw_record.get("dataset_b")),
                _to_str(raw_record.get("label_b")),
            )
        pair_key = _to_str(raw_record.get("pair_key"))
        if (key_a is None or key_b is None) and pair_key and "||" in pair_key:
            key_a, key_b = pair_key.split("||", 1)
        if key_a is None or key_b is None or key_a == key_b:
            return None
        # Canonical ordering makes the lookup order-independent.
        pair = (key_a, key_b) if key_a <= key_b else (key_b, key_a)

        # An explicit tag overrides the is_positive flag.
        tag = (_to_str(raw_record.get("tag")) or "").strip().lower()
        is_positive = bool(raw_record.get("is_positive", False))
        if tag in _POSITIVE_TAGS:
            is_positive = True
        elif tag in _NEGATIVE_TAGS:
            is_positive = False

        # Scalar fields, each with a chain of fallback names.
        score = _to_float(
            raw_record.get("score"),
            _to_float(raw_record.get("confidence"), _to_float(raw_record.get("pair_confidence_weight"), 0.0)),
        )
        overlap = _to_float(
            raw_record.get("soft_overlap_target"),
            _to_float(raw_record.get("overlap"), score if is_positive else 0.0),
        )
        weight = _to_float(raw_record.get("weight"), _to_float(raw_record.get("pair_confidence_weight"), max(score, overlap)))
        geometric_support_count = int(round(_to_float(raw_record.get("geometric_support_count"), 0.0)))
        l2m_match_count = int(round(_to_float(raw_record.get("l2m_match_count"), 0.0)))
        l2m_mean_certainty = _to_float(raw_record.get("l2m_mean_certainty"), _to_float(raw_record.get("l2m_certainty"), 0.0))
        l2m_inlier_ratio = _to_float(raw_record.get("l2m_inlier_ratio"), 0.0)

        return {
            "pair": pair,
            "is_positive": is_positive,
            "score": float(score),
            "overlap": float(overlap),
            "weight": float(weight),
            "geometric_support_count": geometric_support_count,
            "l2m_match_count": l2m_match_count,
            "l2m_mean_certainty": float(l2m_mean_certainty),
            "l2m_inlier_ratio": float(l2m_inlier_ratio),
        }

    def lookup(self, key_a: Optional[str], key_b: Optional[str]) -> Optional[Dict[str, Any]]:
        """Return the record for an (unordered) frame-key pair, or None."""
        if key_a is None or key_b is None or key_a == key_b:
            return None
        pair = (key_a, key_b) if key_a <= key_b else (key_b, key_a)
        return self.records.get(pair)
176
+
177
+
178
class PseudoGTLoopSupervisor:
    """Turns cached frame-pair pseudo-GT labels into submap-level loop losses.

    Aggregates frame-pair records (from `FramePairPseudoGTDatabase`) into
    per-submap positive/negative targets, then supervises the loop gate
    (binary cross-entropy) and the submap descriptor cosine similarity
    (smooth-L1 regression plus an optional ranking margin).
    """

    def __init__(self, frame_db: FramePairPseudoGTDatabase, pseudo_gt_cfg):
        self.frame_db = frame_db
        # Soft overlap targets vs. hard 0/1 targets for the descriptor loss.
        self.use_soft_targets = bool(_cfg_get(pseudo_gt_cfg, "use_soft_targets", True))
        # Pair records below this score are ignored entirely.
        self.min_confidence = float(_cfg_get(pseudo_gt_cfg, "min_confidence", 0.65))
        # Minimum count of confident positive pairs for a positive submap target.
        self.min_support_pairs = max(1, int(_cfg_get(pseudo_gt_cfg, "min_support_pairs", 1)))
        # Number of best-ranked pairs averaged into the soft target/confidence.
        self.topk_pairs = max(1, int(_cfg_get(pseudo_gt_cfg, "topk_pairs", 4)))
        self.loss_weight_gate = float(_cfg_get(pseudo_gt_cfg, "loss_weight_gate", 0.1))
        self.loss_weight_desc = float(_cfg_get(pseudo_gt_cfg, "loss_weight_desc", 0.1))
        # "hybrid"/"ranking" enable the ranking-margin term on top of regression.
        self.loss_type = str(_cfg_get(pseudo_gt_cfg, "loss_type", "hybrid"))
        # Extra weight multiplier for targets with geometric verification.
        self.geometric_support_scale = float(_cfg_get(pseudo_gt_cfg, "geometric_support_scale", 0.25))
        self.ranking_margin = float(_cfg_get(pseudo_gt_cfg, "ranking_margin", 0.1))
        # Optional gating of geometric support on L2M match quality.
        self.use_l2m = bool(_cfg_get(pseudo_gt_cfg, "use_l2m", False))
        self.l2m_min_certainty = float(_cfg_get(pseudo_gt_cfg, "l2m_min_certainty", 0.0))
        self.l2m_min_inlier_ratio = float(_cfg_get(pseudo_gt_cfg, "l2m_min_inlier_ratio", 0.0))

    @classmethod
    def from_config(cls, pseudo_gt_cfg) -> Optional["PseudoGTLoopSupervisor"]:
        """Build a supervisor from config, or return None when disabled.

        Raises:
            ValueError: if enabled but no cache path is configured.
        """
        if pseudo_gt_cfg is None or not bool(_cfg_get(pseudo_gt_cfg, "enable", False)):
            return None
        cache_path = _cfg_get(pseudo_gt_cfg, "cache_path", None)
        if cache_path in (None, "", "null"):
            raise ValueError("`pseudo_gt.enable=true` requires `pseudo_gt.cache_path`.")
        return cls(FramePairPseudoGTDatabase.from_file(cache_path), pseudo_gt_cfg)

    def _frame_keys(self, batch: Sequence[Dict[str, Any]], frame_ids: Sequence[int]) -> List[str]:
        """Map frame indices to canonical view keys; out-of-range ids are clamped."""
        keys: List[str] = []
        if batch is None:
            return keys
        num_views = len(batch)
        for frame_id in frame_ids:
            if num_views <= 0:
                continue
            # Clamp into [0, num_views - 1] instead of dropping bad ids.
            index = min(max(int(frame_id), 0), num_views - 1)
            key = canonical_view_key(batch[index])
            if key is not None:
                keys.append(key)
        return keys

    def _has_geometry(self, record: Dict[str, Any]) -> bool:
        """True when the record has geometric support passing the L2M quality gates."""
        support = int(record.get("geometric_support_count", 0)) > 0 or int(record.get("l2m_match_count", 0)) > 0
        if not support:
            return False
        # When L2M gating is on, an L2M-supported record must also clear the
        # certainty/inlier-ratio thresholds.
        if self.use_l2m and record.get("l2m_match_count", 0) > 0:
            if float(record.get("l2m_mean_certainty", 0.0)) < self.l2m_min_certainty:
                return False
            if float(record.get("l2m_inlier_ratio", 0.0)) < self.l2m_min_inlier_ratio:
                return False
        return True

    def build_submap_targets(self, batch, current_frame_ids, history_frame_ids_by_submap):
        """Aggregate frame-pair records into one target dict per history submap.

        A submap becomes a positive target when it has at least
        `min_support_pairs` confident positive pairs; it becomes a negative
        target only when it has confident negatives and no positives at all.
        Each target carries: submap_id, binary label, soft label, weight,
        and a geometric-support count.
        """
        current_keys = self._frame_keys(batch, current_frame_ids)
        if not current_keys:
            return []
        targets = []
        for submap_id, history_frame_ids in sorted(history_frame_ids_by_submap.items()):
            history_keys = self._frame_keys(batch, history_frame_ids)
            if not history_keys:
                continue
            # Collect every cached record between current and history frames.
            records = []
            for current_key in current_keys:
                for history_key in history_keys:
                    record = self.frame_db.lookup(current_key, history_key)
                    if record is not None:
                        records.append(record)
            if not records:
                continue
            positives = [record for record in records if record.get("is_positive", False) and record.get("score", 0.0) >= self.min_confidence]
            negatives = [record for record in records if (not record.get("is_positive", False)) and record.get("score", 0.0) >= self.min_confidence]
            if len(positives) >= self.min_support_pairs:
                # Average the top-k strongest positive pairs into the targets.
                ranked = sorted(positives, key=lambda record: max(record.get("weight", 0.0), record.get("overlap", 0.0), record.get("score", 0.0)), reverse=True)[: self.topk_pairs]
                soft_target = sum(record.get("overlap", record.get("score", 0.0)) for record in ranked) / max(1, len(ranked))
                confidence = sum(max(record.get("weight", 0.0), record.get("score", 0.0), record.get("overlap", 0.0)) for record in ranked) / max(1, len(ranked))
                geometry = sum(1 for record in positives if self._has_geometry(record))
                targets.append({
                    "submap_id": int(submap_id),
                    "binary": 1.0,
                    "soft": float(soft_target if self.use_soft_targets else 1.0),
                    "weight": float(max(0.05, min(1.0, confidence))),
                    "geometry": int(geometry),
                })
            elif negatives and not positives:
                ranked = sorted(negatives, key=lambda record: max(record.get("weight", 0.0), record.get("score", 0.0)), reverse=True)[: self.topk_pairs]
                confidence = sum(max(record.get("weight", 0.0), record.get("score", 0.0)) for record in ranked) / max(1, len(ranked))
                targets.append({
                    "submap_id": int(submap_id),
                    "binary": 0.0,
                    "soft": 0.0,
                    "weight": float(max(0.05, min(1.0, confidence))),
                    "geometry": 0,
                })
        return targets

    def compute_loss(self, memory_mgr, batch, hidden_B, meta, loop_gate):
        """Compute the pseudo-GT loop losses for the current submap.

        Returns:
            (total_loss, details) where total_loss is a scalar tensor, or
            (None, {}) when there is nothing to supervise (first submap, no
            history, no cached targets, or no stored history descriptors).
        """
        current_submap_id = int(meta.get("submap_id", -1))
        if current_submap_id <= 0:
            return None, {}
        # Only strictly earlier submaps count as loop-closure history.
        history_frame_ids_by_submap = {
            int(submap_id): list(frame_ids)
            for submap_id, frame_ids in memory_mgr.buffer.cpu_frame_ids.items()
            if int(submap_id) < current_submap_id
        }
        if not history_frame_ids_by_submap:
            return None, {}
        targets = self.build_submap_targets(batch, meta.get("curr_frame_ids", []), history_frame_ids_by_submap)
        if not targets:
            return None, {}

        # Descriptor of the current submap: reuse the precomputed one if
        # available, otherwise rebuild it from the current hidden tokens.
        current_desc = meta.get("curr_descriptor")
        if current_desc is None:
            n_prev = int(meta.get("n_prev", 0))
            n_curr = int(meta.get("n_curr", 0))
            current_tokens = hidden_B[n_prev:n_prev + n_curr]
            if current_tokens.numel() == 0:
                return None, {}
            current_desc = memory_mgr.compute_descriptor(current_tokens)
        current_desc = current_desc.float()

        # Pair each target with its stored history descriptor; drop targets
        # whose descriptor is missing from the CPU buffer.
        valid_targets = []
        history_descs = []
        for target in targets:
            history_desc = memory_mgr.buffer.cpu_descriptor_buffer.get(target["submap_id"])
            if history_desc is None:
                continue
            history_desc = history_desc.reshape(-1).to(current_desc.device, non_blocking=True).float()
            history_descs.append(history_desc)
            valid_targets.append(target)
        if not history_descs:
            return None, {}

        history_descs = torch.stack(history_descs, dim=0)
        # Cosine similarity mapped from [-1, 1] to [0, 1].
        predicted_cosine = F.cosine_similarity(current_desc.expand(history_descs.shape[0], -1), history_descs, dim=-1).clamp(min=-1.0, max=1.0)
        predicted_similarity = 0.5 * (predicted_cosine + 1.0)
        target_binary = torch.tensor([target["binary"] for target in valid_targets], device=predicted_similarity.device, dtype=predicted_similarity.dtype)
        target_soft = torch.tensor([target["soft"] for target in valid_targets], device=predicted_similarity.device, dtype=predicted_similarity.dtype)
        weights = torch.tensor([target["weight"] for target in valid_targets], device=predicted_similarity.device, dtype=predicted_similarity.dtype)
        geometry = torch.tensor([target["geometry"] for target in valid_targets], device=predicted_similarity.device, dtype=predicted_similarity.dtype)
        if self.geometric_support_scale > 0:
            # Boost (capped at 1.5x) targets backed by geometric verification.
            weights = (weights * (1.0 + self.geometric_support_scale * geometry.clamp(max=1.0))).clamp(max=1.5)

        # Gate loss: BCE on the scalar loop gate against "any positive submap".
        gate_loss = predicted_similarity.new_zeros(())
        if self.loss_weight_gate > 0 and torch.is_tensor(loop_gate) and loop_gate.requires_grad:
            gate_target = target_binary.max()
            gate_weight = weights[target_binary.argmax()] if target_binary.numel() > 0 else weights.new_ones(())
            gate_pred = loop_gate.reshape(-1)[0].clamp(min=1e-6, max=1.0 - 1e-6)
            gate_loss = F.binary_cross_entropy(gate_pred, gate_target.clamp(min=0.0, max=1.0), reduction="none") * gate_weight
            gate_loss = gate_loss.mean() * self.loss_weight_gate

        # Descriptor loss: weighted smooth-L1 regression to the soft targets,
        # plus a hinge ranking margin between the best positive and negative.
        descriptor_loss = predicted_similarity.new_zeros(())
        if self.loss_weight_desc > 0:
            regression = F.smooth_l1_loss(predicted_similarity, target_soft.clamp(min=0.0, max=1.0), reduction="none")
            descriptor_loss = (regression * weights).sum() / weights.sum().clamp(min=1e-6)
            if self.loss_type in {"hybrid", "ranking"}:
                pos_mask = target_binary > 0.5
                neg_mask = target_binary < 0.5
                if pos_mask.any() and neg_mask.any():
                    pos_score = predicted_similarity[pos_mask].max()
                    neg_score = predicted_similarity[neg_mask].max()
                    descriptor_loss = descriptor_loss + F.relu(self.ranking_margin - pos_score + neg_score)
            descriptor_loss = descriptor_loss * self.loss_weight_desc

        total_loss = gate_loss + descriptor_loss
        details = {
            "pseudo_gt_gate_loss": float(gate_loss.detach()),
            "pseudo_gt_desc_loss": float(descriptor_loss.detach()),
            "pseudo_gt_total": float(total_loss.detach()),
            "pseudo_gt_pairs": float(len(valid_targets)),
            "pseudo_gt_positive_pairs": float((target_binary > 0.5).sum().detach()),
            "pseudo_gt_negative_pairs": float((target_binary < 0.5).sum().detach()),
        }
        return total_loss, details
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/rerun_helper/__init__.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import trimesh
9
+ from PIL import Image
10
+ from tqdm import tqdm
11
+
12
+ from .geometry_utils import NormalGenerator
13
+
14
+
15
+ import rerun as rr
16
+ from .visualization_utils import reverse_imagenet_normalize, colormap_image
17
+
18
+
19
+ from typing import Dict, Any
20
+
21
# depth prediction normals computer
#PRED_FORMAT_SIZE = [480,640]#[192, 256]
# Target (height, width) every logged image/depth/normal is resized to.
PRED_FORMAT_SIZE = [680,1200]#[192, 256]
# Logging happens on GPU when available; the normal estimator lives there too.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Shared module-level normal estimator, built once for PRED_FORMAT_SIZE.
compute_normals = NormalGenerator(PRED_FORMAT_SIZE[0], PRED_FORMAT_SIZE[1]).to(device)
26
+
27
+
28
def to_device(input_dict, key_ignores=(), device="cuda"):
    """Move tensors in `input_dict` to `device` as floats, in place.

    Args:
        input_dict: dict whose values are tensors (entries listed in
            `key_ignores` are left untouched).
        key_ignores: collection of keys to skip. The default is an immutable
            tuple -- the previous mutable `[]` default is a shared-object
            pitfall (one list object reused across all calls).
        device: target device passed to `.to()`.

    Returns:
        The same dict, mutated in place.
    """
    for key, value in input_dict.items():
        if key not in key_ignores:
            input_dict[key] = value.to(device).float()
    return input_dict
36
+
37
+
38
def log_source_data(src_entity_path: str, src_data: Dict[str, Any]) -> None:
    """Log every source-view camera (pose, intrinsics, RGB) under `src_entity_path`.

    Indexes all tensors with [0] before iterating over the view dimension,
    so this assumes batch size 1.
    """
    # NOTE(review): torch.tensor(...) on an existing tensor copies it and
    # emits a UserWarning; the inner `.to(device)` alone looks sufficient --
    # confirm before changing.
    src_images_k3hw = reverse_imagenet_normalize(
        torch.tensor(src_data["image_b3hw"][0].to(device))
    )
    num_src_cameras = src_data["world_T_cam_b44"][0].shape[0]
    for src_idx in range(num_src_cameras):
        src_cam_path = f"{src_entity_path}/{src_idx}"
        # 4x4 world-from-camera pose and padded intrinsics for this view.
        world_T_cam_44 = src_data["world_T_cam_b44"][0][src_idx].squeeze().cpu().numpy()
        K_44 = src_data["K_s0_b44"][0][src_idx].squeeze().cpu().numpy()
        log_camera(src_cam_path, world_T_cam_44, K_44)
        # Images were already denormalized above, so skip denormalization here.
        log_image(src_cam_path, src_images_k3hw[src_idx], denormalize=False)
49
+
50
+
51
+
52
def log_camera(
    entity_path: str, world_T_cam_44: torch.Tensor, K_44: torch.Tensor, kfd=False, update=False,
) -> None:
    """Log a camera pose (and, for frontend calls, a frustum) to rerun.

    Args:
        entity_path: rerun entity path for this camera.
        world_T_cam_44: 4x4 world-from-camera transform.
        K_44: 4x4 padded intrinsics matrix (not modified by this call).
        kfd: when True, log a pinhole frustum built from the quarter-scale
            intrinsics; otherwise log a fixed-FOV green frustum.
        update: True for backend pose updates -- only the transform is
            logged and any existing frustum is left untouched.
    """
    assert world_T_cam_44.shape == (4, 4)
    assert K_44.shape == (4, 4)
    # Convert and log camera parameters
    Rot, trans = world_T_cam_44[:3, :3], world_T_cam_44[:3, 3]
    # BUGFIX: the original `K_33 = K_44[:3, :3]; K_33[:2] /= 4` divided a
    # slice VIEW in place, silently mutating the caller's K_44 on every call
    # (repeated calls kept shrinking the intrinsics). Multiplying by 1.0
    # materializes a copy first, for both numpy arrays and torch tensors.
    K_33 = K_44[:3, :3] * 1.0
    K_33[:2] /= 4  # intrinsics at quarter resolution to match the frustum size

    rr.log(entity_path, rr.Transform3D(translation=trans, mat3x3=Rot))
    if update:
        # Backend pose refinement: transform only, keep the existing frustum.
        return
    if kfd:
        rr.log(
            entity_path + '/frustum',
            rr.Pinhole(
                image_from_camera=K_33,
                width=PRED_FORMAT_SIZE[1] / 4,
                height=PRED_FORMAT_SIZE[0] / 4,
            ),
        )
    else:
        rr.log(
            entity_path + '/frustum',
            rr.Pinhole(
                fov_y=0.7853982,  # 45 degrees
                aspect_ratio=1.7777778,  # ~16:9
                camera_xyz=None,
                image_plane_distance=0.1,
                color=[0, 255, 0],
                line_width=0.003,
            ),
        )
92
+
93
+
94
+
95
def log_window(
    entity_path: str, world_T_cam_44: torch.Tensor, K_44: torch.Tensor
) -> None:
    """Log only the pose transform for a windowed camera entity (no frustum)."""
    assert world_T_cam_44.shape == (4, 4)
    assert K_44.shape == (4, 4)
    rotation = world_T_cam_44[:3, :3]
    translation = world_T_cam_44[:3, 3]
    rr.log(entity_path, rr.Transform3D(translation=translation, mat3x3=rotation))
103
+
104
+
105
+
106
def log_image(
    entity_path: str, color_frame_b3hw: torch.Tensor, denormalize=True
) -> None:
    """Log one RGB frame, optionally undoing ImageNet normalization first."""
    color_frame_3hw = color_frame_b3hw.squeeze(0)
    if denormalize:
        main_color_3hw = reverse_imagenet_normalize(color_frame_3hw)
    else:
        main_color_3hw = color_frame_3hw
    # CHW float in [0, 1] -> HWC uint8 for PIL.
    rgb_hw3 = main_color_3hw.permute(1, 2, 0).cpu().detach().numpy()
    pil_image = Image.fromarray(np.uint8(rgb_hw3 * 255))
    pil_image = pil_image.resize((PRED_FORMAT_SIZE[1], PRED_FORMAT_SIZE[0]))
    rr.log(f"{entity_path}/image/rgb", rr.Image(pil_image))
120
+
121
+
122
def log_rerun(
    entity_path: str,
    cur_data: Dict[str, Any],
    src_data: Dict[str, Any],
    outputs: Dict[str, Any],
    scene_trimesh_mesh: trimesh.Trimesh,
    should_log_source_cams: bool = True,
) -> None:
    """
    Logs camera intri/extri, depth, rgb, and mesh to rerun.

    Args:
        entity_path: root rerun entity path for this frame.
        cur_data: current-frame tensors (pose, intrinsics, color; batch size 1).
        src_data: source-view tensors, consumed by log_source_data.
        outputs: model outputs; reads "depth_pred_s0_b1hw" and "lowest_cost_bhw".
        scene_trimesh_mesh: fused scene mesh to (re-)log.
        should_log_source_cams: skip the source cameras when False.
    """
    curr_entity_path = f"{entity_path}/current_cam"
    src_entity_path = f"{entity_path}/source_cam"
    if should_log_source_cams:
        log_source_data(src_entity_path, src_data)

    # Current camera pose + intrinsics.
    world_T_cam_44 = cur_data["world_T_cam_b44"].squeeze().cpu().numpy()
    K_44 = cur_data["K_s0_b44"].squeeze().cpu().numpy()
    log_camera(curr_entity_path, world_T_cam_44, K_44)

    # Depth logging
    depth_pred = outputs["depth_pred_s0_b1hw"]
    our_depth_3hw = depth_pred.squeeze(0)
    our_depth_hw3 = our_depth_3hw.permute(1, 2, 0)
    rr.log(
        f"{curr_entity_path}/image/depth",
        rr.DepthImage(our_depth_hw3.numpy(force=True)),
    )

    # Normal logging
    invK_s0_b44 = cur_data["invK_s0_b44"].to(device)
    normals_b3hw = compute_normals(depth_pred, invK_s0_b44)
    # Map normals from [-1, 1] to [0, 1] for visualization.
    our_normals_3hw = 0.5 * (1 + normals_b3hw).squeeze(0)
    pil_normal = Image.fromarray(
        np.uint8(our_normals_3hw.permute(1, 2, 0).cpu().detach().numpy() * 255)
    )
    rr.log(f"{curr_entity_path}/image/normal", rr.Image(pil_normal))

    # Image logging -- prefer the high-resolution color frame when present.
    color_frame_b3hw = (
        cur_data["high_res_color_b3hw"]
        if "high_res_color_b3hw" in cur_data
        else cur_data["image_b3hw"]
    )
    color_frame_3hw = color_frame_b3hw.squeeze(0)
    main_color_3hw = reverse_imagenet_normalize(color_frame_3hw)
    pil_image = Image.fromarray(
        np.uint8(main_color_3hw.permute(1, 2, 0).cpu().detach().numpy() * 255)
    )
    pil_image = pil_image.resize((PRED_FORMAT_SIZE[1], PRED_FORMAT_SIZE[0]))
    rr.log(f"{curr_entity_path}/image/rgb", rr.Image(pil_image))

    # lowest cost guess from the cost volume
    lowest_cost_bhw = outputs["lowest_cost_bhw"]
    lowest_cost_3hw = colormap_image(
        lowest_cost_bhw,
        vmin=0,
        vmax=5,
    )
    pil_cost = Image.fromarray(
        np.uint8(lowest_cost_3hw.permute(1, 2, 0).cpu().detach().numpy() * 255)
    )
    pil_cost = pil_cost.resize((PRED_FORMAT_SIZE[1], PRED_FORMAT_SIZE[0]))
    rr.log("lowest_cost_volume", rr.Image(pil_cost))

    # Fused mesh logging
    rr.log(
        f"{entity_path}/mesh",
        rr.Mesh3D(
            vertex_positions=scene_trimesh_mesh.vertices,
            triangle_indices=scene_trimesh_mesh.faces,
            vertex_colors=scene_trimesh_mesh.visual.vertex_colors,
        ),
    )
196
+
197
+
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/rerun_helper/generic_utils.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import pickle
4
+ from pathlib import Path
5
+
6
+ import kornia
7
+ import torch
8
+ import torchvision.transforms.functional as TF
9
+ from PIL import Image
10
+ from torch import nn
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
def copy_code_state(path):
    """Copies the code directory into the path specified using rsync. It will
    use a .gitignore file to exclude files in rsync. We preserve modification
    times in rsync.

    Args:
        path: destination directory (created if missing).
    """

    # create dir
    Path(os.path.join(path)).mkdir(parents=True, exist_ok=True)

    if os.path.exists("./.gitignore"):
        # use .gitignore to remove junk
        rsync_command = (
            f"rsync -art --exclude-from='./.gitignore' --exclude '.git' . {path}"
        )
    else:
        # Fixed typo in the user-facing warning ("exlcude" -> "exclude").
        print("WARNING: no .gitignore found so can't use that to exclude large "
              "files when making a back up of files in copy_code_state.")
        rsync_command = (
            f"rsync -art --exclude '.git' . {path}"
        )
    # NOTE(review): os.system runs through the shell, so `path` must be
    # trusted (no quoting is applied here).
    os.system(rsync_command)
35
+
36
def readlines(filepath):
    """Read a text file and return its lines as a list (trailing newlines stripped)."""
    with open(filepath, 'r') as handle:
        return handle.read().splitlines()
41
+
42
def normalize_depth_single(depth_11hw, mask_11hw, robust=False):
    """Shift/scale-normalize one depth map using its (optionally masked) values.

    Statistics are taken over the central 80% of sorted valid depths (falling
    back to all valid depths when that trim is empty). `robust` switches from
    mean/std to median/mean-absolute-deviation. Maps with no valid pixels are
    returned unchanged.
    """
    if mask_11hw is None:
        valid_depth_vals_N = torch.flatten(depth_11hw)
    else:
        valid_depth_vals_N = depth_11hw.masked_select(mask_11hw)

    num_valid_pix = valid_depth_vals_N.nelement()
    num_percentile_pix = num_valid_pix // 10

    if num_valid_pix == 0:
        return depth_11hw

    # Trim the lowest/highest 10% of valid depths before computing stats.
    sorted_vals = torch.sort(valid_depth_vals_N)[0]
    depth_flat_N = sorted_vals[num_percentile_pix:-num_percentile_pix]
    if depth_flat_N.nelement() == 0:
        depth_flat_N = valid_depth_vals_N

    if robust:
        depth_shift = depth_flat_N.median()
        depth_scale = torch.mean(torch.abs(depth_flat_N - depth_shift))
    else:
        depth_shift = depth_flat_N.mean()
        depth_scale = depth_flat_N.std()

    return (depth_11hw - depth_shift) / depth_scale
71
+
72
+
73
def normalize_depth(depth_b1hw: torch.Tensor,
                    mask_b1hw: torch.Tensor = None,
                    robust: bool = False):
    """Normalize every depth map in the batch independently (see normalize_depth_single)."""
    depths_11hw = torch.split(depth_b1hw, 1, 0)
    if mask_b1hw is None:
        masks_11hw = [None] * len(depths_11hw)
    else:
        masks_11hw = torch.split(mask_b1hw, 1, 0)

    normalized = [
        normalize_depth_single(depth, mask, robust)
        for depth, mask in zip(depths_11hw, masks_11hw)
    ]
    return torch.cat(normalized, dim=0)
85
+
86
+
87
+
88
def upsample(x):
    """Bilinearly upsample the input tensor by a factor of 2 (corners not aligned)."""
    return nn.functional.interpolate(
        x, scale_factor=2, mode="bilinear", align_corners=False,
    )
98
+
99
def batched_trace(mat_bNN):
    """Trace of each matrix in a batch: sum of the main diagonal over the last two dims."""
    diagonals = mat_bNN.diagonal(offset=0, dim1=-1, dim2=-2)
    return diagonals.sum(-1)
101
+
102
def tensor_B_to_bM(tensor_BS, batch_size, num_views):
    """Unpacks a flattened tensor of tupled elements (BS) into bMS. Tuple size
    is M; S is a wildcard for any trailing dims."""
    inflated_shape = [batch_size, num_views] + list(tensor_BS.shape[1:])
    return tensor_BS.view(inflated_shape)
111
+
112
+
113
def tensor_bM_to_B(tensor_bMS):
    """Packs an inflated tensor of tupled elements (bMS) into BS. Tuple size
    is M; S is a wildcard for any trailing dims."""
    batches, views = tensor_bMS.shape[0], tensor_bMS.shape[1]
    flat_shape = [batches * views] + list(tensor_bMS.shape[2:])
    return tensor_bMS.view(flat_shape)
123
+
124
def combine_dims(x, dim_begin, dim_end):
    """Views x with the dimensions from dim_begin to dim_end folded into one."""
    folded_shape = list(x.shape[:dim_begin]) + [-1] + list(x.shape[dim_end:])
    return x.view(folded_shape)
128
+
129
+
130
def to_gpu(input_dict, key_ignores=()):
    """Move tensors in `input_dict` to the GPU as floats, in place.

    Entries whose key appears in `key_ignores` are skipped. The default is an
    immutable tuple -- the previous mutable `[]` default is the classic
    shared-mutable-default pitfall (one list object reused across calls).

    Returns:
        The same dict, mutated in place.
    """
    for key, value in input_dict.items():
        if key not in key_ignores:
            input_dict[key] = value.cuda().float()
    return input_dict
138
+
139
def imagenet_normalize(image):
    """Normalize an image tensor with ImageNet statistics.

    Computes per-channel (x - mean) / std, matching
    torchvision.transforms.functional.normalize, but in pure torch so it
    broadcasts over any leading batch dims of a (..., 3, H, W) tensor.
    """
    mean = torch.tensor((0.485, 0.456, 0.406),
                        dtype=image.dtype, device=image.device).view(-1, 1, 1)
    std = torch.tensor((0.229, 0.224, 0.225),
                       dtype=image.dtype, device=image.device).view(-1, 1, 1)
    return (image - mean) / std
144
+
145
def reverse_imagenet_normalize(image):
    """Reverse ImageNet normalization in an input image.

    Applies per-channel (x - mean) / std with the inverse-normalization
    constants (i.e. x * imagenet_std + imagenet_mean), matching
    torchvision.transforms.functional.normalize but in pure torch so it
    broadcasts over any leading batch dims of a (..., 3, H, W) tensor.
    """
    mean = torch.tensor((-2.11790393, -2.03571429, -1.80444444),
                        dtype=image.dtype, device=image.device).view(-1, 1, 1)
    std = torch.tensor((4.36681223, 4.46428571, 4.44444444),
                       dtype=image.dtype, device=image.device).view(-1, 1, 1)
    return (image - mean) / std
152
+
153
+
154
def read_image_file(filepath,
                    height=None,
                    width=None,
                    value_scale_factor=1.0,
                    resampling_mode=Image.BILINEAR,
                    disable_warning=False,
                    target_aspect_ratio=None):
    """" Reads an image file using PIL, then optionally resizes the image,
    with selective resampling, scales values, and returns the image as a
    tensor

    Args:
        filepath: path to the image.
        height, width: resolution to resize the image to. Both must not be
            None for scaling to take place.
        value_scale_factor: value to scale image values with, default is 1.0
        resampling_mode: resampling method when resizing using PIL. Default
            is PIL.Image.BILINEAR
        disable_warning: suppress the upscaling warning when True.
        target_aspect_ratio: if not None, will crop the image to match this
            aspect ratio. Default is None

    Returns:
        img: tensor with (optionally) scaled and resized image data.

    """
    img = Image.open(filepath)

    if target_aspect_ratio:
        # BUGFIX: PIL's crop returns a NEW image; the original call discarded
        # the return value, so the aspect-ratio crop never took effect.
        img = crop_image_to_target_ratio(img, target_aspect_ratio)

    # resize if both width and height are not none.
    if height is not None and width is not None:
        img_width, img_height = img.size
        # do we really need to resize? If not, skip.
        if (img_width, img_height) != (width, height):
            # warn if it doesn't make sense (upscaling beyond the source).
            if ((width > img_width or height > img_height) and
                    not disable_warning):
                logger.warning(
                    f"WARNING: target size ({width}, {height}) has a "
                    f"dimension larger than input size ({img_width}, "
                    f"{img_height}).")
            img = img.resize((width, height), resample=resampling_mode)

    img = TF.to_tensor(img).float() * value_scale_factor

    return img
201
+
202
def crop_image_to_target_ratio(image, target_aspect_ratio=4.0/3.0):
    """Center-crop a PIL image so its width/height ratio matches the target."""
    current_ratio = image.width / image.height

    if current_ratio > target_aspect_ratio:
        # Too wide: shrink the width, keep the full height.
        cropped_width = image.height * target_aspect_ratio
        left = (image.width - cropped_width) / 2
        right = (image.width + cropped_width) / 2
        # Crop the center of the image
        image = image.crop((left, 0, right, image.height))
    elif current_ratio < target_aspect_ratio:
        # Too tall: shrink the height, keep the full width.
        cropped_height = image.width / target_aspect_ratio
        top = (image.height - cropped_height) / 2
        bottom = (image.height + cropped_height) / 2
        # Crop the center of the image
        image = image.crop((0, top, image.width, bottom))

    return image
232
+
233
def cache_model_outputs(
    output_path,
    outputs,
    cur_data,
    src_data,
    batch_ind,
    batch_size,
):
    """Pickle per-frame model outputs (plus intrinsics and frame ids) to disk.

    Writes one `<frame_id>.pickle` file per batch element containing that
    element's slice of every entry in `outputs`, its intrinsics, its frame
    id, and the ids of its source frames.
    """

    for elem_ind in range(outputs["depth_pred_s0_b1hw"].shape[0]):
        if "frame_id_string" in cur_data:
            frame_id = cur_data["frame_id_string"][elem_ind]
        else:
            frame_id = (batch_ind * batch_size) + elem_ind
        # BUGFIX: the original `f"{str(frame_id):6d}"` always raised
        # ValueError (format code 'd' is invalid for a str); right-align to
        # a minimum width of 6 instead.
        frame_id = f"{str(frame_id):>6}"

        elem_filepath = os.path.join(output_path, f"{frame_id}.pickle")

        elem_output_dict = {}

        # Per-element slice of every model output (None entries pass through).
        for key in outputs:
            if outputs[key] is not None:
                elem_output_dict[key] = outputs[key][elem_ind].unsqueeze(0)
            else:
                elem_output_dict[key] = None

        # include some auxiliary information
        elem_output_dict["K_full_depth_b44"] = cur_data[
            "K_full_depth_b44"
        ][elem_ind].unsqueeze(0)
        elem_output_dict["K_s0_b44"] = cur_data[
            "K_s0_b44"
        ][elem_ind].unsqueeze(0)

        # NOTE(review): this indexes "frame_id_string" unconditionally even
        # though the branch above treats it as optional -- confirm callers
        # always provide it.
        elem_output_dict["frame_id"] = cur_data["frame_id_string"][elem_ind]
        elem_output_dict["src_ids"] = []
        for src_id_list in src_data["frame_id_string"]:
            elem_output_dict["src_ids"].append(src_id_list[elem_ind])

        with open(elem_filepath, 'wb') as handle:
            pickle.dump(elem_output_dict, handle)
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/rerun_helper/geometry_utils.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import kornia
2
+ import numpy as np
3
+ import torch
4
+ import torch.jit as jit
5
+ import torch.nn.functional as F
6
+ from torch import Tensor
7
+
8
+
9
+
10
@torch.jit.script
def to_homogeneous(input_tensor: Tensor, dim: int = 0) -> Tensor:
    """
    Appends a row/plane of ones along `dim`, turning the coordinates in
    `input_tensor` into homogeneous coordinates.
    """
    # A single slice along `dim` gives the right shape for the ones pad.
    pad_slice = input_tensor.select(dim, 0).unsqueeze(dim)
    return torch.cat([input_tensor, torch.ones_like(pad_slice)], dim=dim)
19
+
20
+
21
class BackprojectDepth(jit.ScriptModule):
    """
    Lifts per-pixel depths into homogeneous 3D camera-space points.

    A grid of pixel-centre coordinates (x + 0.5, y + 0.5, 1) is precomputed
    once for a fixed image size and stored as a buffer so it follows the
    module across devices.
    """

    def __init__(self, height: int, width: int):
        super().__init__()

        self.height = height
        self.width = width

        # Pixel-centre grid: xs[v, u] == u and ys[v, u] == v (shape (H, W)).
        xs, ys = torch.meshgrid(
            torch.arange(self.width),
            torch.arange(self.height),
            indexing='xy',
        )
        grid_2hw = torch.stack((xs, ys), dim=0) + 0.5

        # Flatten to (1, 3, H*W) homogeneous pixel coordinates.
        grid_2N = grid_2hw.flatten(1)
        ones_1N = torch.ones_like(grid_2N[:1])
        pix_coords_13N = torch.cat([grid_2N, ones_1N], dim=0).unsqueeze(0)

        # Registered as a buffer so the grid is put on the correct GPU
        # automatically with the module.
        self.register_buffer("pix_coords_13N", pix_coords_13N)

    @jit.script_method
    def forward(self, depth_b1hw: Tensor, invK_b44: Tensor) -> Tensor:
        """
        Backprojects 2D pixels to homogeneous 3D camera points at the depths
        in depth_b1hw, using the inverse intrinsics invK_b44.
        """
        rays_b3N = torch.matmul(invK_b44[:, :3, :3], self.pix_coords_13N)
        points_b3N = depth_b1hw.flatten(start_dim=2) * rays_b3N
        ones_b1N = torch.ones_like(points_b3N[:, :1])
        return torch.cat([points_b3N, ones_b1N], dim=1)
59
+
60
+
61
class Project3D(jit.ScriptModule):
    """
    Projects homogeneous 3D world points into the 2D camera image.
    """
    def __init__(self, eps: float = 1e-8):
        super().__init__()

        # Small constant guarding the divide-by-depth; stored as a buffer so
        # it travels with the module across devices.
        self.register_buffer("eps", torch.tensor(eps).view(1, 1, 1))

    @jit.script_method
    def forward(self, points_b4N: Tensor,
                K_b44: Tensor, cam_T_world_b44: Tensor) -> Tensor:
        """
        Maps world-space homogeneous points through the extrinsics
        cam_T_world_b44 and intrinsics K_b44, returning pixel coordinates
        stacked with depth as a (b, 3, N) tensor.
        """
        proj_b44 = torch.matmul(K_b44, cam_T_world_b44)

        cam_b3N = torch.matmul(proj_b44[:, :3], points_b4N)

        # Guarded homogeneous divide, following Kornia and OpenCV:
        # https://kornia.readthedocs.io/en/latest/_modules/kornia/geometry/conversions.html#convert_points_from_homogeneous
        z_b1N = cam_b3N[:, 2:]
        valid_mask = torch.abs(z_b1N) > self.eps
        depth_b1N = z_b1N + self.eps
        inv_depth = torch.where(
            valid_mask,
            1.0 / depth_b1N,
            torch.tensor(1.0, device=depth_b1N.device),
        )

        uv_b2N = cam_b3N[:, :2] * inv_depth

        return torch.cat([uv_b2N, depth_b1N], dim=1)
89
+
90
+
91
class NormalGenerator(jit.ScriptModule):
    def __init__(self, height: int, width: int,
                 smoothing_kernel_size: int = 5,
                 smoothing_kernel_std: float = 2.0):
        """
        Estimates per-pixel surface normals from depth maps.
        """
        super().__init__()
        self.height = height
        self.width = width

        # Reused to lift smoothed depths into camera-space points.
        self.backproject = BackprojectDepth(self.height, self.width)

        self.kernel_size = smoothing_kernel_size
        self.std = smoothing_kernel_std

    def forward(self, depth_b1hw: Tensor, invK_b44: Tensor) -> Tensor:
        """
        Gaussian-smooths the incoming depth map, backprojects it into camera
        space (see BackprojectDepth), estimates the spatial gradient of those
        points, and takes the normalized cross product of the two gradient
        directions to obtain a unit normal at each location.
        """
        smoothed_b1hw = kornia.filters.gaussian_blur2d(
            depth_b1hw,
            (self.kernel_size, self.kernel_size),
            (self.std, self.std),
        )
        points_b4N = self.backproject(smoothed_b1hw, invK_b44)
        points_b3hw = points_b4N[:, :3].view(-1, 3, self.height, self.width)

        grads_b32hw = kornia.filters.spatial_gradient(points_b3hw)

        normals_b3hw = torch.cross(
            grads_b32hw[:, :, 0],
            grads_b32hw[:, :, 1],
            dim=1,
        )
        return F.normalize(normals_b3hw, dim=1)
132
+
133
+
134
def get_camera_rays(
    world_T_cam_b44,
    world_points_b3N,
    in_camera_frame,
    cam_T_world_b44=None,
    eps=1e-4,
):
    """Compute unit-norm viewing rays from a camera towards 3D points.

    Args:
        world_T_cam_b44: (b, 4, 4) camera-to-world transforms.
        world_points_b3N: (b, 3, N) world-space points.
        in_camera_frame: when True, rays are expressed in the camera frame;
            otherwise they are world-frame directions from the camera center.
        cam_T_world_b44: optional (b, 4, 4) world-to-camera transforms, used
            only when in_camera_frame is True. ROBUSTNESS FIX: if omitted it
            is now derived by inverting world_T_cam_b44 instead of crashing
            on the None default.
        eps: unused; kept for backward compatibility with existing callers.

    Returns:
        rays_b3N: (b, 3, N) unit-length ray directions.
    """

    if in_camera_frame:
        if cam_T_world_b44 is None:
            # Derive the world-to-camera transform rather than failing with a
            # TypeError on the None default.
            cam_T_world_b44 = torch.inverse(world_T_cam_b44)

        batch_size, _, num_points = world_points_b3N.shape
        # Homogenize the points before applying the 3x4 transform.
        ones_b1N = torch.ones(
            batch_size, 1, num_points, device=world_points_b3N.device
        )
        world_points_b4N = torch.cat([world_points_b3N, ones_b1N], dim=1)
        rays_b3N = torch.matmul(cam_T_world_b44[:, :3, :4], world_points_b4N)
    else:
        # World-frame directions: point minus camera center.
        cam_center_b31 = world_T_cam_b44[:, 0:3, 3][:, :, None]
        rays_b3N = world_points_b3N - cam_center_b31.expand(
            world_points_b3N.shape
        )

    rays_b3N = torch.nn.functional.normalize(rays_b3N, dim=1)

    return rays_b3N
167
+
168
+
169
def pose_distance(pose_b44):
    """
    DVMVS-style pose-change magnitude for a batch of 4x4 relative poses.

    Returns a tuple (combined, rotation, translation) of per-element
    measures, each of shape (b,).
    """

    rot_b33 = pose_b44[:, :3, :3]
    trans_b3 = pose_b44[:, :3, 3]

    # Rotation magnitude from the trace; the trace is capped at 3 so the
    # sqrt argument stays non-negative under numerical noise.
    trace_b = rot_b33.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)
    capped_trace_b = torch.minimum(torch.ones_like(trace_b) * 3.0, trace_b)
    R_measure = torch.sqrt(2 * (1 - capped_trace_b / 3))

    t_measure = torch.norm(trans_b3, dim=1)

    combined_measure = torch.sqrt(t_measure ** 2 + R_measure ** 2)

    return combined_measure, R_measure, t_measure
183
+
184
def qvec2rotmat(qvec):
    """
    Convert a quaternion (w, x, y, z) to a 3x3 rotation matrix.
    """
    w, x, y, z = qvec[0], qvec[1], qvec[2], qvec[3]
    return np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (z * x + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (z * x - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ])
203
+
204
def rotx(t):
    """
    Rotation matrix for angle t (radians) about the x-axis.
    """
    cos_t, sin_t = np.cos(t), np.sin(t)
    return np.array([
        [1, 0, 0],
        [0, cos_t, -sin_t],
        [0, sin_t, cos_t],
    ])
213
+
214
def roty(t):
    """
    Rotation matrix for angle t (radians) about the y-axis.
    """
    cos_t, sin_t = np.cos(t), np.sin(t)
    return np.array([
        [cos_t, 0, sin_t],
        [0, 1, 0],
        [-sin_t, 0, cos_t],
    ])
223
+
224
def rotz(t):
    """
    Rotation matrix for angle t (radians) about the z-axis.
    """
    cos_t, sin_t = np.cos(t), np.sin(t)
    return np.array([
        [cos_t, -sin_t, 0],
        [sin_t, cos_t, 0],
        [0, 0, 1],
    ])
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/rerun_helper/tmp.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import numpy as np
import rerun as rr

# Initialize Rerun (spawns the viewer).
rr.init("Multi-Camera Pose Example", spawn=True)

# Example data for two cameras: 4x4 pose matrices plus a shared 3x3 pinhole
# intrinsic matrix (fx = fy = 500, principal point at (320, 240)).
_shared_intrinsic = np.array([
    [500, 0, 320],
    [0, 500, 240],
    [0, 0, 1],
])

_camera_poses = {
    "camera1": np.array([
        [0.99, -0.10, 0.10, 1.0],
        [0.10, 0.99, -0.10, 2.0],
        [-0.10, 0.10, 0.99, 3.0],
        [0.0, 0.0, 0.0, 1.0],
    ]),
    "camera2": np.array([
        [0.99, 0.10, -0.10, -1.0],
        [-0.10, 0.99, 0.10, -2.0],
        [0.10, -0.10, 0.99, -3.0],
        [0.0, 0.0, 0.0, 1.0],
    ]),
}

# Log each camera's pose and intrinsics to the viewer.
# NOTE(review): rr.log_camera does not exist in current rerun SDK releases —
# presumably this scratch script targets an old/forked API; verify before use.
for _name, _pose in _camera_poses.items():
    rr.log_camera(_name, pose=_pose, intrinsic=_shared_intrinsic.copy())
39
+
checkpoints/paper_smoke_local_8gpu_submap12/joint_freeze_frontend_fsdp_sub12/paper_smoke_joint_freeze_frontend_fsdp_8gpu_sub12/code/04_04-00:52:12/slam/rerun_helper/visualization_utils.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import matplotlib.pyplot as plt
4
+ import moviepy.editor as mpy
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+
9
+ from .generic_utils import reverse_imagenet_normalize
10
+
11
+
12
def colormap_image(
        image_1hw,
        mask_1hw=None,
        invalid_color=(0.0, 0, 0.0),
        flip=True,
        vmin=None,
        vmax=None,
        return_vminvmax=False,
        colormap="turbo",
):
    """
    Colormaps a one channel tensor using a matplotlib colormap.

    Args:
        image_1hw: the tensor to colormap.
        mask_1hw: an optional float mask where 1.0 denotes valid pixels.
        colormap: the colormap to use. Default is turbo.
        invalid_color: the color to use for invalid pixels.
        flip: should we flip the colormap? True by default.
        vmin: if provided uses this as the minimum when normalizing the tensor.
        vmax: if provided uses this as the maximum when normalizing the tensor.
            When either of vmin or vmax are None, they are computed from the
            tensor.
        return_vminvmax: when true, returns vmin and vmax.

    Returns:
        image_cm_3hw: image of the colormapped tensor.
        vmin, vmax: returned when return_vminvmax is true.
    """
    # Derive the normalization range from the (masked) values unless given.
    valid_vals = image_1hw if mask_1hw is None else image_1hw[mask_1hw.bool()]
    if vmin is None:
        vmin = valid_vals.min()
    if vmax is None:
        vmax = valid_vals.max()

    # COMPAT FIX: plt.cm.get_cmap was deprecated in matplotlib 3.7 and
    # removed in 3.9; fall back to the colormap registry on newer versions.
    try:
        cmap_obj = plt.cm.get_cmap(colormap)
    except AttributeError:
        import matplotlib
        cmap_obj = matplotlib.colormaps[colormap]

    # Build a 256-entry RGB lookup table on the image's device.
    cmap = torch.Tensor(
        cmap_obj(torch.linspace(0, 1, 256))[:, :3]
    ).to(image_1hw.device)
    if flip:
        cmap = torch.flip(cmap, (0,))

    h, w = image_1hw.shape[1:]

    # Normalize to [0, 1], then quantize to LUT indices in [0, 255].
    image_norm_1hw = (image_1hw - vmin) / (vmax - vmin)
    image_int_1hw = (torch.clamp(image_norm_1hw * 255, 0, 255)).byte().long()

    image_cm_3hw = cmap[image_int_1hw.flatten(start_dim=1)
                        ].permute([0, 2, 1]).view([-1, h, w])

    # Paint invalid pixels with the fallback color.
    if mask_1hw is not None:
        invalid_color = torch.Tensor(invalid_color).view(3, 1, 1).to(image_1hw.device)
        image_cm_3hw = image_cm_3hw * mask_1hw + invalid_color * (1 - mask_1hw)

    if return_vminvmax:
        return image_cm_3hw, vmin, vmax
    else:
        return image_cm_3hw
73
+
74
def save_viz_video_frames(frame_list, path, fps=30):
    """
    Writes the numpy RGB frames in frame_list to a video file at path.
    """
    mpy.ImageSequenceClip(frame_list, fps=fps).write_videofile(
        path, verbose=False, logger=None,
    )

    return
82
+
83
+
84
def quick_viz_export(
        output_path,
        outputs,
        cur_data,
        batch_ind,
        valid_mask_b,
        batch_size):
    """ Helper function for quickly exporting depth maps during inference.

    For every element in the batch, writes colormapped PNGs for the ground
    truth depth (when it is meaningful), the lowest-cost prediction, the
    predicted depth, and the input color image into output_path.

    Args:
        output_path: directory the PNGs are written to.
        outputs: model outputs; must contain "depth_pred_s0_b1hw" and
            "lowest_cost_bhw".
        cur_data: batch dict with "full_res_depth_b1hw",
            "high_res_color_b3hw" and optionally "frame_id_string".
        batch_ind: index of this batch; used to synthesize frame ids when
            "frame_id_string" is absent.
        valid_mask_b: boolean mask of valid gt depth pixels, matching
            "full_res_depth_b1hw".
        batch_size: nominal batch size, combined with batch_ind for ids.
    """

    # Shared color range across the batch; fall back to [0, 5] when there is
    # no valid gt depth or it is constant.
    if valid_mask_b.sum() == 0:
        batch_vmin = 0.0
        batch_vmax = 5.0
    else:
        batch_vmin = cur_data["full_res_depth_b1hw"][valid_mask_b].min()
        batch_vmax = cur_data["full_res_depth_b1hw"][valid_mask_b].max()

    if batch_vmax == batch_vmin:
        batch_vmin = 0.0
        batch_vmax = 5.0

    for elem_ind in range(outputs["depth_pred_s0_b1hw"].shape[0]):
        if "frame_id_string" in cur_data:
            frame_id = cur_data["frame_id_string"][elem_ind]
        else:
            # Synthesize a zero-padded global frame index. BUGFIX: the
            # previous code formatted a *string* with ":6d", which raises
            # "ValueError: Unknown format code 'd' for object of type 'str'".
            frame_id = (batch_ind * batch_size) + elem_ind
            frame_id = f"{frame_id:06d}"

        # check for valid depths from dataloader
        if valid_mask_b[elem_ind].sum() == 0:
            sample_vmin = 0.0
            sample_vmax = 0.0
        else:
            # these will be the same when the depth map is all ones.
            sample_vmin = cur_data["full_res_depth_b1hw"][elem_ind][valid_mask_b[elem_ind]].min()
            sample_vmax = cur_data["full_res_depth_b1hw"][elem_ind][valid_mask_b[elem_ind]].max()

        # if no meaningful gt depth in dataloader, don't viz gt and
        # set vmin/max to default
        if sample_vmax != sample_vmin:
            full_res_depth_1hw = cur_data["full_res_depth_b1hw"][elem_ind]

            full_res_depth_3hw = colormap_image(
                full_res_depth_1hw,
                vmin=batch_vmin, vmax=batch_vmax,
            )

            full_res_depth_hw3 = np.uint8(
                full_res_depth_3hw.permute(1, 2, 0).cpu().detach().numpy() * 255
            )
            Image.fromarray(full_res_depth_hw3).save(
                os.path.join(output_path, f"{frame_id}_gt_depth.png")
            )

        lowest_cost_3hw = colormap_image(
            outputs["lowest_cost_bhw"][elem_ind].unsqueeze(0),
            vmin=batch_vmin, vmax=batch_vmax,
        )
        pil_image = Image.fromarray(
            np.uint8(lowest_cost_3hw.permute(1, 2, 0).cpu().detach().numpy() * 255)
        )
        pil_image.save(os.path.join(output_path,
                                    f"{frame_id}_lowest_cost_pred.png"))

        depth_3hw = colormap_image(
            outputs["depth_pred_s0_b1hw"][elem_ind],
            vmin=batch_vmin, vmax=batch_vmax)
        pil_image = Image.fromarray(
            np.uint8(depth_3hw.permute(1, 2, 0).cpu().detach().numpy() * 255)
        )
        pil_image.save(os.path.join(output_path, f"{frame_id}_pred_depth.png"))

        # Undo ImageNet normalization before saving the color image.
        main_color_3hw = cur_data["high_res_color_b3hw"][elem_ind]
        main_color_3hw = reverse_imagenet_normalize(main_color_3hw)
        pil_image = Image.fromarray(
            np.uint8(main_color_3hw.permute(1, 2, 0).cpu().detach().numpy() * 255)
        )
        pil_image.save(os.path.join(output_path, f"{frame_id}_color.png"))