diff --git a/outdoor_v48_16gpu/.hydra/config.yaml b/outdoor_v48_16gpu/.hydra/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0ef8945348c850208179dcfcf641fb2587351c35 --- /dev/null +++ b/outdoor_v48_16gpu/.hydra/config.yaml @@ -0,0 +1,68 @@ +teacher: /gpfs/work2/0/prjs0824/qi_proj/ckpt/checkpoint-10.pth.model +pretrained: /gpfs/work2/0/prjs0824/qi_proj/ckpt/checkpoint-10.pth.model +load_only_encoder: false +long_context: false +fixed_length: true +resume: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_8gpu/checkpoint-last.pth +benchmark: false +num_views: 64 +num_test_views: 4 +n_corres_train: 0 +n_corres_test: 0 +train_criterion: DistillLoss() +test_criterion: DistillLoss() +allow_repeat: false +root_vkitti2: /scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti +root_kitti: /scratch-shared/wwei2/eval/kitti_odometry/dataset +root_kitti_velo: /gpfs/work2/0/prjs0824/semantickitti/dataset +root_kitti360: /scratch-shared/wwei2/downloads/kitti360/KITTI-360 +root_kitti360_velo: /scratch-shared/wwei2/downloads/kitti360/KITTI-360 +root_waymo: /scratch-shared/wwei2/waymo_v2 +root_waymo_lidar: /scratch-shared/wwei2/waymo_v2 +dataset_vkitti2: VirtualKITTI2_Multi(allow_repeat=${allow_repeat}, split='train', + ROOT="${root_vkitti2}", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), + (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=${num_views}, + n_corres=${n_corres_train}) +dataset_kitti360: KITTI360_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_kitti360}", + velodyne_root="${root_kitti360_velo}", aug_crop=16, resolution=[(518, 392), (518, + 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, + num_views=${num_views}, n_corres=${n_corres_train}) +dataset_waymo: Waymo_v2_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_waymo}", + lidar_root="${root_waymo_lidar}", aug_crop=16, resolution=[(518, 392), (518, 
336), + (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=${num_views}, + n_corres=${n_corres_train}) +train_dataset: 6000 @ ${dataset_vkitti2} + 6000 @ ${dataset_kitti360} + 5400 @ ${dataset_waymo} +test_dataset: 200 @ VirtualKITTI2_Multi(split='train', ROOT="${root_vkitti2}", resolution=(518, + 154), num_views=${num_test_views}, seed=42, n_corres=${n_corres_test}) +seed: 0 +batch_size: 1 +accum_iter: 1 +gradient_checkpointing: false +epochs: 10 +start_epoch: 0 +start_step: 0 +weight_decay: 0.05 +lr: 1.0e-05 +min_lr: 1.0e-08 +warmup_epochs: 0.5 +amp: 1 +num_workers: 4 +world_size: 1 +local-rank: -1 +dist_url: env:// +rank: 0 +gpu: 0 +distributed: false +dist_backend: nccl +eval_freq: 1 +save_freq: 0.1 +max_checkpoints: 10 +keep_freq: 1 +print_freq: 10 +print_img_freq: 50000000 +num_imgs_vis: 4 +save_dir: /scratch-shared/wwei2/training_upstream/checkpoints +exp_name: outdoor_v48_16gpu +task: StreamVGGT +logdir: ${save_dir}/${exp_name}/logs +output_dir: ${save_dir}/${exp_name}/ diff --git a/outdoor_v48_16gpu/.hydra/hydra.yaml b/outdoor_v48_16gpu/.hydra/hydra.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2d8a02505cc0c2b2595949790c6877f2a77ebe2b --- /dev/null +++ b/outdoor_v48_16gpu/.hydra/hydra.yaml @@ -0,0 +1,156 @@ +hydra: + run: + dir: ${save_dir}/${exp_name} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.num} + launcher: + _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher + sweeper: + _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper + max_batch_size: null + params: null + help: + app_name: ${hydra.job.name} + header: '${hydra.help.app_name} is powered by Hydra. 
+ + ' + footer: 'Powered by Hydra (https://hydra.cc) + + Use --hydra-help to view Hydra specific help + + ' + template: '${hydra.help.header} + + == Configuration groups == + + Compose your configuration from those groups (group=option) + + + $APP_CONFIG_GROUPS + + + == Config == + + Override anything in the config (foo.bar=value) + + + $CONFIG + + + ${hydra.help.footer} + + ' + hydra_help: + template: 'Hydra (${hydra.runtime.version}) + + See https://hydra.cc for more info. + + + == Flags == + + $FLAGS_HELP + + + == Configuration groups == + + Compose your configuration from those groups (For example, append hydra/job_logging=disabled + to command line) + + + $HYDRA_CONFIG_GROUPS + + + Use ''--cfg hydra'' to Show the Hydra config. + + ' + hydra_help: ??? + hydra_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][HYDRA] %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + root: + level: INFO + handlers: + - console + loggers: + logging_example: + level: DEBUG + disable_existing_loggers: false + job_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + file: + class: logging.FileHandler + formatter: simple + filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log + root: + level: INFO + handlers: + - console + - file + disable_existing_loggers: false + env: {} + mode: RUN + searchpath: [] + callbacks: {} + output_subdir: .hydra + overrides: + hydra: + - hydra.mode=RUN + task: + - exp_name=outdoor_v48_16gpu + - resume=/scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_8gpu/checkpoint-last.pth + job: + name: mytrain + chdir: null + override_dirname: exp_name=outdoor_v48_16gpu,resume=/scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_8gpu/checkpoint-last.pth + id: ??? + num: ??? 
+ config_name: outdoor_v48 + env_set: {} + env_copy: [] + config: + override_dirname: + kv_sep: '=' + item_sep: ',' + exclude_keys: [] + runtime: + version: 1.3.2 + version_base: '1.3' + cwd: /gpfs/work2/0/prjs0824/qi_proj/slamformer_upstream/src + config_sources: + - path: hydra.conf + schema: pkg + provider: hydra + - path: /gpfs/work2/0/prjs0824/qi_proj/slamformer_upstream/config + schema: file + provider: main + - path: '' + schema: structured + provider: schema + output_dir: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_16gpu + choices: + hydra/env: default + hydra/callbacks: null + hydra/job_logging: default + hydra/hydra_logging: default + hydra/hydra_help: default + hydra/help: default + hydra/sweeper: basic + hydra/launcher: basic + hydra/output: default + verbose: true diff --git a/outdoor_v48_16gpu/.hydra/overrides.yaml b/outdoor_v48_16gpu/.hydra/overrides.yaml new file mode 100644 index 0000000000000000000000000000000000000000..08b75e3f035c7c063bb8fa22bef29320eee0c9c6 --- /dev/null +++ b/outdoor_v48_16gpu/.hydra/overrides.yaml @@ -0,0 +1,2 @@ +- exp_name=outdoor_v48_16gpu +- resume=/scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_8gpu/checkpoint-last.pth diff --git a/outdoor_v48_16gpu/mytrain.log b/outdoor_v48_16gpu/mytrain.log new file mode 100644 index 0000000000000000000000000000000000000000..b8edfae36d5d97e6a0ca24b0d391600ca948c082 --- /dev/null +++ b/outdoor_v48_16gpu/mytrain.log @@ -0,0 +1,930 @@ +[2026-05-02 09:28:25,135][__main__][INFO] - [RANK 0] output_dir: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_16gpu/ +[2026-05-02 09:28:25,901][__main__][INFO] - [RANK 0] Saving current code to /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_16gpu/code/05_02-09:28:25 +[2026-05-02 09:28:25,901][__main__][INFO] - [RANK 0] job dir: /gpfs/work2/0/prjs0824/qi_proj/slamformer_upstream/src +[2026-05-02 09:28:25,901][__main__][INFO] - [RANK 0] Setting seed to 0 for process 0 +[2026-05-02 
09:28:25,903][__main__][INFO] - [RANK 0] Building train dataset 6000 @ VirtualKITTI2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 6000 @ KITTI360_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", velodyne_root="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 5400 @ Waymo_v2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/waymo_v2", lidar_root="/scratch-shared/wwei2/waymo_v2", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) +[2026-05-02 09:28:25,903][__main__][INFO] - [RANK 0] Building Train Data loader for dataset: 6000 @ VirtualKITTI2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 6000 @ KITTI360_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", velodyne_root="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 5400 @ Waymo_v2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/waymo_v2", lidar_root="/scratch-shared/wwei2/waymo_v2", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, 
num_views=64, n_corres=0) +[2026-05-02 09:32:13,562][__main__][INFO] - [RANK 0] Building test dataset 200 @ VirtualKITTI2_Multi(split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", resolution=(518, 154), num_views=4, seed=42, n_corres=0) +[2026-05-02 09:32:13,562][__main__][INFO] - [RANK 0] Building Test Data loader for dataset: 200 @ VirtualKITTI2_Multi(split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", resolution=(518, 154), num_views=4, seed=42, n_corres=0) +[2026-05-02 09:32:13,641][__main__][INFO] - [RANK 0] Loading model +[2026-05-02 09:32:19,610][__main__][INFO] - [RANK 0] All model parameters: 958696732 +[2026-05-02 09:32:19,610][__main__][INFO] - [RANK 0] >> Creating train criterion = DistillLoss() +[2026-05-02 09:32:19,610][__main__][INFO] - [RANK 0] >> Creating test criterion = DistillLoss() +[2026-05-02 09:32:20,033][__main__][INFO] - [RANK 0] Freezing patch embedding and positional encoding parameters... +[2026-05-02 09:32:20,038][__main__][INFO] - [RANK 0] Frozen 304,376,832 parameters out of 958,696,732 total parameters. (31.75%) +[2026-05-02 09:32:20,038][__main__][INFO] - [RANK 0] Trainable parameters: 654,319,900 (68.25%) +[2026-05-02 09:32:20,038][__main__][INFO] - [RANK 0] Example frozen parameters: register_token, encoder.cls_token, encoder.pos_embed, encoder.register_tokens, encoder.patch_embed.proj.weight... 
+[2026-05-02 09:32:20,063][croco.utils.misc][INFO] - [RANK 0] Param groups = { + "no_decay": { + "weight_decay": 0.0, + "params": [ + "decoder.0.norm1.weight", + "decoder.0.norm1.bias", + "decoder.0.attn.qkv.bias", + "decoder.0.attn.proj.bias", + "decoder.0.attn.q_norm.weight", + "decoder.0.attn.q_norm.bias", + "decoder.0.attn.k_norm.weight", + "decoder.0.attn.k_norm.bias", + "decoder.0.ls1.gamma", + "decoder.0.norm2.weight", + "decoder.0.norm2.bias", + "decoder.0.mlp.fc1.bias", + "decoder.0.mlp.fc2.bias", + "decoder.0.ls2.gamma", + "decoder.1.norm1.weight", + "decoder.1.norm1.bias", + "decoder.1.attn.qkv.bias", + "decoder.1.attn.proj.bias", + "decoder.1.attn.q_norm.weight", + "decoder.1.attn.q_norm.bias", + "decoder.1.attn.k_norm.weight", + "decoder.1.attn.k_norm.bias", + "decoder.1.ls1.gamma", + "decoder.1.norm2.weight", + "decoder.1.norm2.bias", + "decoder.1.mlp.fc1.bias", + "decoder.1.mlp.fc2.bias", + "decoder.1.ls2.gamma", + "decoder.2.norm1.weight", + "decoder.2.norm1.bias", + "decoder.2.attn.qkv.bias", + "decoder.2.attn.proj.bias", + "decoder.2.attn.q_norm.weight", + "decoder.2.attn.q_norm.bias", + "decoder.2.attn.k_norm.weight", + "decoder.2.attn.k_norm.bias", + "decoder.2.ls1.gamma", + "decoder.2.norm2.weight", + "decoder.2.norm2.bias", + "decoder.2.mlp.fc1.bias", + "decoder.2.mlp.fc2.bias", + "decoder.2.ls2.gamma", + "decoder.3.norm1.weight", + "decoder.3.norm1.bias", + "decoder.3.attn.qkv.bias", + "decoder.3.attn.proj.bias", + "decoder.3.attn.q_norm.weight", + "decoder.3.attn.q_norm.bias", + "decoder.3.attn.k_norm.weight", + "decoder.3.attn.k_norm.bias", + "decoder.3.ls1.gamma", + "decoder.3.norm2.weight", + "decoder.3.norm2.bias", + "decoder.3.mlp.fc1.bias", + "decoder.3.mlp.fc2.bias", + "decoder.3.ls2.gamma", + "decoder.4.norm1.weight", + "decoder.4.norm1.bias", + "decoder.4.attn.qkv.bias", + "decoder.4.attn.proj.bias", + "decoder.4.attn.q_norm.weight", + "decoder.4.attn.q_norm.bias", + "decoder.4.attn.k_norm.weight", + "decoder.4.attn.k_norm.bias", + 
"decoder.4.ls1.gamma", + "decoder.4.norm2.weight", + "decoder.4.norm2.bias", + "decoder.4.mlp.fc1.bias", + "decoder.4.mlp.fc2.bias", + "decoder.4.ls2.gamma", + "decoder.5.norm1.weight", + "decoder.5.norm1.bias", + "decoder.5.attn.qkv.bias", + "decoder.5.attn.proj.bias", + "decoder.5.attn.q_norm.weight", + "decoder.5.attn.q_norm.bias", + "decoder.5.attn.k_norm.weight", + "decoder.5.attn.k_norm.bias", + "decoder.5.ls1.gamma", + "decoder.5.norm2.weight", + "decoder.5.norm2.bias", + "decoder.5.mlp.fc1.bias", + "decoder.5.mlp.fc2.bias", + "decoder.5.ls2.gamma", + "decoder.6.norm1.weight", + "decoder.6.norm1.bias", + "decoder.6.attn.qkv.bias", + "decoder.6.attn.proj.bias", + "decoder.6.attn.q_norm.weight", + "decoder.6.attn.q_norm.bias", + "decoder.6.attn.k_norm.weight", + "decoder.6.attn.k_norm.bias", + "decoder.6.ls1.gamma", + "decoder.6.norm2.weight", + "decoder.6.norm2.bias", + "decoder.6.mlp.fc1.bias", + "decoder.6.mlp.fc2.bias", + "decoder.6.ls2.gamma", + "decoder.7.norm1.weight", + "decoder.7.norm1.bias", + "decoder.7.attn.qkv.bias", + "decoder.7.attn.proj.bias", + "decoder.7.attn.q_norm.weight", + "decoder.7.attn.q_norm.bias", + "decoder.7.attn.k_norm.weight", + "decoder.7.attn.k_norm.bias", + "decoder.7.ls1.gamma", + "decoder.7.norm2.weight", + "decoder.7.norm2.bias", + "decoder.7.mlp.fc1.bias", + "decoder.7.mlp.fc2.bias", + "decoder.7.ls2.gamma", + "decoder.8.norm1.weight", + "decoder.8.norm1.bias", + "decoder.8.attn.qkv.bias", + "decoder.8.attn.proj.bias", + "decoder.8.attn.q_norm.weight", + "decoder.8.attn.q_norm.bias", + "decoder.8.attn.k_norm.weight", + "decoder.8.attn.k_norm.bias", + "decoder.8.ls1.gamma", + "decoder.8.norm2.weight", + "decoder.8.norm2.bias", + "decoder.8.mlp.fc1.bias", + "decoder.8.mlp.fc2.bias", + "decoder.8.ls2.gamma", + "decoder.9.norm1.weight", + "decoder.9.norm1.bias", + "decoder.9.attn.qkv.bias", + "decoder.9.attn.proj.bias", + "decoder.9.attn.q_norm.weight", + "decoder.9.attn.q_norm.bias", + "decoder.9.attn.k_norm.weight", + 
"decoder.9.attn.k_norm.bias", + "decoder.9.ls1.gamma", + "decoder.9.norm2.weight", + "decoder.9.norm2.bias", + "decoder.9.mlp.fc1.bias", + "decoder.9.mlp.fc2.bias", + "decoder.9.ls2.gamma", + "decoder.10.norm1.weight", + "decoder.10.norm1.bias", + "decoder.10.attn.qkv.bias", + "decoder.10.attn.proj.bias", + "decoder.10.attn.q_norm.weight", + "decoder.10.attn.q_norm.bias", + "decoder.10.attn.k_norm.weight", + "decoder.10.attn.k_norm.bias", + "decoder.10.ls1.gamma", + "decoder.10.norm2.weight", + "decoder.10.norm2.bias", + "decoder.10.mlp.fc1.bias", + "decoder.10.mlp.fc2.bias", + "decoder.10.ls2.gamma", + "decoder.11.norm1.weight", + "decoder.11.norm1.bias", + "decoder.11.attn.qkv.bias", + "decoder.11.attn.proj.bias", + "decoder.11.attn.q_norm.weight", + "decoder.11.attn.q_norm.bias", + "decoder.11.attn.k_norm.weight", + "decoder.11.attn.k_norm.bias", + "decoder.11.ls1.gamma", + "decoder.11.norm2.weight", + "decoder.11.norm2.bias", + "decoder.11.mlp.fc1.bias", + "decoder.11.mlp.fc2.bias", + "decoder.11.ls2.gamma", + "decoder.12.norm1.weight", + "decoder.12.norm1.bias", + "decoder.12.attn.qkv.bias", + "decoder.12.attn.proj.bias", + "decoder.12.attn.q_norm.weight", + "decoder.12.attn.q_norm.bias", + "decoder.12.attn.k_norm.weight", + "decoder.12.attn.k_norm.bias", + "decoder.12.ls1.gamma", + "decoder.12.norm2.weight", + "decoder.12.norm2.bias", + "decoder.12.mlp.fc1.bias", + "decoder.12.mlp.fc2.bias", + "decoder.12.ls2.gamma", + "decoder.13.norm1.weight", + "decoder.13.norm1.bias", + "decoder.13.attn.qkv.bias", + "decoder.13.attn.proj.bias", + "decoder.13.attn.q_norm.weight", + "decoder.13.attn.q_norm.bias", + "decoder.13.attn.k_norm.weight", + "decoder.13.attn.k_norm.bias", + "decoder.13.ls1.gamma", + "decoder.13.norm2.weight", + "decoder.13.norm2.bias", + "decoder.13.mlp.fc1.bias", + "decoder.13.mlp.fc2.bias", + "decoder.13.ls2.gamma", + "decoder.14.norm1.weight", + "decoder.14.norm1.bias", + "decoder.14.attn.qkv.bias", + "decoder.14.attn.proj.bias", + 
"decoder.14.attn.q_norm.weight", + "decoder.14.attn.q_norm.bias", + "decoder.14.attn.k_norm.weight", + "decoder.14.attn.k_norm.bias", + "decoder.14.ls1.gamma", + "decoder.14.norm2.weight", + "decoder.14.norm2.bias", + "decoder.14.mlp.fc1.bias", + "decoder.14.mlp.fc2.bias", + "decoder.14.ls2.gamma", + "decoder.15.norm1.weight", + "decoder.15.norm1.bias", + "decoder.15.attn.qkv.bias", + "decoder.15.attn.proj.bias", + "decoder.15.attn.q_norm.weight", + "decoder.15.attn.q_norm.bias", + "decoder.15.attn.k_norm.weight", + "decoder.15.attn.k_norm.bias", + "decoder.15.ls1.gamma", + "decoder.15.norm2.weight", + "decoder.15.norm2.bias", + "decoder.15.mlp.fc1.bias", + "decoder.15.mlp.fc2.bias", + "decoder.15.ls2.gamma", + "decoder.16.norm1.weight", + "decoder.16.norm1.bias", + "decoder.16.attn.qkv.bias", + "decoder.16.attn.proj.bias", + "decoder.16.attn.q_norm.weight", + "decoder.16.attn.q_norm.bias", + "decoder.16.attn.k_norm.weight", + "decoder.16.attn.k_norm.bias", + "decoder.16.ls1.gamma", + "decoder.16.norm2.weight", + "decoder.16.norm2.bias", + "decoder.16.mlp.fc1.bias", + "decoder.16.mlp.fc2.bias", + "decoder.16.ls2.gamma", + "decoder.17.norm1.weight", + "decoder.17.norm1.bias", + "decoder.17.attn.qkv.bias", + "decoder.17.attn.proj.bias", + "decoder.17.attn.q_norm.weight", + "decoder.17.attn.q_norm.bias", + "decoder.17.attn.k_norm.weight", + "decoder.17.attn.k_norm.bias", + "decoder.17.ls1.gamma", + "decoder.17.norm2.weight", + "decoder.17.norm2.bias", + "decoder.17.mlp.fc1.bias", + "decoder.17.mlp.fc2.bias", + "decoder.17.ls2.gamma", + "decoder.18.norm1.weight", + "decoder.18.norm1.bias", + "decoder.18.attn.qkv.bias", + "decoder.18.attn.proj.bias", + "decoder.18.attn.q_norm.weight", + "decoder.18.attn.q_norm.bias", + "decoder.18.attn.k_norm.weight", + "decoder.18.attn.k_norm.bias", + "decoder.18.ls1.gamma", + "decoder.18.norm2.weight", + "decoder.18.norm2.bias", + "decoder.18.mlp.fc1.bias", + "decoder.18.mlp.fc2.bias", + "decoder.18.ls2.gamma", + 
"decoder.19.norm1.weight", + "decoder.19.norm1.bias", + "decoder.19.attn.qkv.bias", + "decoder.19.attn.proj.bias", + "decoder.19.attn.q_norm.weight", + "decoder.19.attn.q_norm.bias", + "decoder.19.attn.k_norm.weight", + "decoder.19.attn.k_norm.bias", + "decoder.19.ls1.gamma", + "decoder.19.norm2.weight", + "decoder.19.norm2.bias", + "decoder.19.mlp.fc1.bias", + "decoder.19.mlp.fc2.bias", + "decoder.19.ls2.gamma", + "decoder.20.norm1.weight", + "decoder.20.norm1.bias", + "decoder.20.attn.qkv.bias", + "decoder.20.attn.proj.bias", + "decoder.20.attn.q_norm.weight", + "decoder.20.attn.q_norm.bias", + "decoder.20.attn.k_norm.weight", + "decoder.20.attn.k_norm.bias", + "decoder.20.ls1.gamma", + "decoder.20.norm2.weight", + "decoder.20.norm2.bias", + "decoder.20.mlp.fc1.bias", + "decoder.20.mlp.fc2.bias", + "decoder.20.ls2.gamma", + "decoder.21.norm1.weight", + "decoder.21.norm1.bias", + "decoder.21.attn.qkv.bias", + "decoder.21.attn.proj.bias", + "decoder.21.attn.q_norm.weight", + "decoder.21.attn.q_norm.bias", + "decoder.21.attn.k_norm.weight", + "decoder.21.attn.k_norm.bias", + "decoder.21.ls1.gamma", + "decoder.21.norm2.weight", + "decoder.21.norm2.bias", + "decoder.21.mlp.fc1.bias", + "decoder.21.mlp.fc2.bias", + "decoder.21.ls2.gamma", + "decoder.22.norm1.weight", + "decoder.22.norm1.bias", + "decoder.22.attn.qkv.bias", + "decoder.22.attn.proj.bias", + "decoder.22.attn.q_norm.weight", + "decoder.22.attn.q_norm.bias", + "decoder.22.attn.k_norm.weight", + "decoder.22.attn.k_norm.bias", + "decoder.22.ls1.gamma", + "decoder.22.norm2.weight", + "decoder.22.norm2.bias", + "decoder.22.mlp.fc1.bias", + "decoder.22.mlp.fc2.bias", + "decoder.22.ls2.gamma", + "decoder.23.norm1.weight", + "decoder.23.norm1.bias", + "decoder.23.attn.qkv.bias", + "decoder.23.attn.proj.bias", + "decoder.23.attn.q_norm.weight", + "decoder.23.attn.q_norm.bias", + "decoder.23.attn.k_norm.weight", + "decoder.23.attn.k_norm.bias", + "decoder.23.ls1.gamma", + "decoder.23.norm2.weight", + 
"decoder.23.norm2.bias", + "decoder.23.mlp.fc1.bias", + "decoder.23.mlp.fc2.bias", + "decoder.23.ls2.gamma", + "decoder.24.norm1.weight", + "decoder.24.norm1.bias", + "decoder.24.attn.qkv.bias", + "decoder.24.attn.proj.bias", + "decoder.24.attn.q_norm.weight", + "decoder.24.attn.q_norm.bias", + "decoder.24.attn.k_norm.weight", + "decoder.24.attn.k_norm.bias", + "decoder.24.ls1.gamma", + "decoder.24.norm2.weight", + "decoder.24.norm2.bias", + "decoder.24.mlp.fc1.bias", + "decoder.24.mlp.fc2.bias", + "decoder.24.ls2.gamma", + "decoder.25.norm1.weight", + "decoder.25.norm1.bias", + "decoder.25.attn.qkv.bias", + "decoder.25.attn.proj.bias", + "decoder.25.attn.q_norm.weight", + "decoder.25.attn.q_norm.bias", + "decoder.25.attn.k_norm.weight", + "decoder.25.attn.k_norm.bias", + "decoder.25.ls1.gamma", + "decoder.25.norm2.weight", + "decoder.25.norm2.bias", + "decoder.25.mlp.fc1.bias", + "decoder.25.mlp.fc2.bias", + "decoder.25.ls2.gamma", + "decoder.26.norm1.weight", + "decoder.26.norm1.bias", + "decoder.26.attn.qkv.bias", + "decoder.26.attn.proj.bias", + "decoder.26.attn.q_norm.weight", + "decoder.26.attn.q_norm.bias", + "decoder.26.attn.k_norm.weight", + "decoder.26.attn.k_norm.bias", + "decoder.26.ls1.gamma", + "decoder.26.norm2.weight", + "decoder.26.norm2.bias", + "decoder.26.mlp.fc1.bias", + "decoder.26.mlp.fc2.bias", + "decoder.26.ls2.gamma", + "decoder.27.norm1.weight", + "decoder.27.norm1.bias", + "decoder.27.attn.qkv.bias", + "decoder.27.attn.proj.bias", + "decoder.27.attn.q_norm.weight", + "decoder.27.attn.q_norm.bias", + "decoder.27.attn.k_norm.weight", + "decoder.27.attn.k_norm.bias", + "decoder.27.ls1.gamma", + "decoder.27.norm2.weight", + "decoder.27.norm2.bias", + "decoder.27.mlp.fc1.bias", + "decoder.27.mlp.fc2.bias", + "decoder.27.ls2.gamma", + "decoder.28.norm1.weight", + "decoder.28.norm1.bias", + "decoder.28.attn.qkv.bias", + "decoder.28.attn.proj.bias", + "decoder.28.attn.q_norm.weight", + "decoder.28.attn.q_norm.bias", + 
"decoder.28.attn.k_norm.weight", + "decoder.28.attn.k_norm.bias", + "decoder.28.ls1.gamma", + "decoder.28.norm2.weight", + "decoder.28.norm2.bias", + "decoder.28.mlp.fc1.bias", + "decoder.28.mlp.fc2.bias", + "decoder.28.ls2.gamma", + "decoder.29.norm1.weight", + "decoder.29.norm1.bias", + "decoder.29.attn.qkv.bias", + "decoder.29.attn.proj.bias", + "decoder.29.attn.q_norm.weight", + "decoder.29.attn.q_norm.bias", + "decoder.29.attn.k_norm.weight", + "decoder.29.attn.k_norm.bias", + "decoder.29.ls1.gamma", + "decoder.29.norm2.weight", + "decoder.29.norm2.bias", + "decoder.29.mlp.fc1.bias", + "decoder.29.mlp.fc2.bias", + "decoder.29.ls2.gamma", + "decoder.30.norm1.weight", + "decoder.30.norm1.bias", + "decoder.30.attn.qkv.bias", + "decoder.30.attn.proj.bias", + "decoder.30.attn.q_norm.weight", + "decoder.30.attn.q_norm.bias", + "decoder.30.attn.k_norm.weight", + "decoder.30.attn.k_norm.bias", + "decoder.30.ls1.gamma", + "decoder.30.norm2.weight", + "decoder.30.norm2.bias", + "decoder.30.mlp.fc1.bias", + "decoder.30.mlp.fc2.bias", + "decoder.30.ls2.gamma", + "decoder.31.norm1.weight", + "decoder.31.norm1.bias", + "decoder.31.attn.qkv.bias", + "decoder.31.attn.proj.bias", + "decoder.31.attn.q_norm.weight", + "decoder.31.attn.q_norm.bias", + "decoder.31.attn.k_norm.weight", + "decoder.31.attn.k_norm.bias", + "decoder.31.ls1.gamma", + "decoder.31.norm2.weight", + "decoder.31.norm2.bias", + "decoder.31.mlp.fc1.bias", + "decoder.31.mlp.fc2.bias", + "decoder.31.ls2.gamma", + "decoder.32.norm1.weight", + "decoder.32.norm1.bias", + "decoder.32.attn.qkv.bias", + "decoder.32.attn.proj.bias", + "decoder.32.attn.q_norm.weight", + "decoder.32.attn.q_norm.bias", + "decoder.32.attn.k_norm.weight", + "decoder.32.attn.k_norm.bias", + "decoder.32.ls1.gamma", + "decoder.32.norm2.weight", + "decoder.32.norm2.bias", + "decoder.32.mlp.fc1.bias", + "decoder.32.mlp.fc2.bias", + "decoder.32.ls2.gamma", + "decoder.33.norm1.weight", + "decoder.33.norm1.bias", + "decoder.33.attn.qkv.bias", + 
"decoder.33.attn.proj.bias", + "decoder.33.attn.q_norm.weight", + "decoder.33.attn.q_norm.bias", + "decoder.33.attn.k_norm.weight", + "decoder.33.attn.k_norm.bias", + "decoder.33.ls1.gamma", + "decoder.33.norm2.weight", + "decoder.33.norm2.bias", + "decoder.33.mlp.fc1.bias", + "decoder.33.mlp.fc2.bias", + "decoder.33.ls2.gamma", + "decoder.34.norm1.weight", + "decoder.34.norm1.bias", + "decoder.34.attn.qkv.bias", + "decoder.34.attn.proj.bias", + "decoder.34.attn.q_norm.weight", + "decoder.34.attn.q_norm.bias", + "decoder.34.attn.k_norm.weight", + "decoder.34.attn.k_norm.bias", + "decoder.34.ls1.gamma", + "decoder.34.norm2.weight", + "decoder.34.norm2.bias", + "decoder.34.mlp.fc1.bias", + "decoder.34.mlp.fc2.bias", + "decoder.34.ls2.gamma", + "decoder.35.norm1.weight", + "decoder.35.norm1.bias", + "decoder.35.attn.qkv.bias", + "decoder.35.attn.proj.bias", + "decoder.35.attn.q_norm.weight", + "decoder.35.attn.q_norm.bias", + "decoder.35.attn.k_norm.weight", + "decoder.35.attn.k_norm.bias", + "decoder.35.ls1.gamma", + "decoder.35.norm2.weight", + "decoder.35.norm2.bias", + "decoder.35.mlp.fc1.bias", + "decoder.35.mlp.fc2.bias", + "decoder.35.ls2.gamma", + "point_decoder.projects.bias", + "point_decoder.blocks.0.norm1.weight", + "point_decoder.blocks.0.norm1.bias", + "point_decoder.blocks.0.attn.qkv.bias", + "point_decoder.blocks.0.attn.proj.bias", + "point_decoder.blocks.0.norm2.weight", + "point_decoder.blocks.0.norm2.bias", + "point_decoder.blocks.0.mlp.fc1.bias", + "point_decoder.blocks.0.mlp.fc2.bias", + "point_decoder.blocks.1.norm1.weight", + "point_decoder.blocks.1.norm1.bias", + "point_decoder.blocks.1.attn.qkv.bias", + "point_decoder.blocks.1.attn.proj.bias", + "point_decoder.blocks.1.norm2.weight", + "point_decoder.blocks.1.norm2.bias", + "point_decoder.blocks.1.mlp.fc1.bias", + "point_decoder.blocks.1.mlp.fc2.bias", + "point_decoder.blocks.2.norm1.weight", + "point_decoder.blocks.2.norm1.bias", + "point_decoder.blocks.2.attn.qkv.bias", + 
"point_decoder.blocks.2.attn.proj.bias", + "point_decoder.blocks.2.norm2.weight", + "point_decoder.blocks.2.norm2.bias", + "point_decoder.blocks.2.mlp.fc1.bias", + "point_decoder.blocks.2.mlp.fc2.bias", + "point_decoder.blocks.3.norm1.weight", + "point_decoder.blocks.3.norm1.bias", + "point_decoder.blocks.3.attn.qkv.bias", + "point_decoder.blocks.3.attn.proj.bias", + "point_decoder.blocks.3.norm2.weight", + "point_decoder.blocks.3.norm2.bias", + "point_decoder.blocks.3.mlp.fc1.bias", + "point_decoder.blocks.3.mlp.fc2.bias", + "point_decoder.blocks.4.norm1.weight", + "point_decoder.blocks.4.norm1.bias", + "point_decoder.blocks.4.attn.qkv.bias", + "point_decoder.blocks.4.attn.proj.bias", + "point_decoder.blocks.4.norm2.weight", + "point_decoder.blocks.4.norm2.bias", + "point_decoder.blocks.4.mlp.fc1.bias", + "point_decoder.blocks.4.mlp.fc2.bias", + "point_decoder.linear_out.bias", + "point_head.proj.bias", + "conf_decoder.projects.bias", + "conf_decoder.blocks.0.norm1.weight", + "conf_decoder.blocks.0.norm1.bias", + "conf_decoder.blocks.0.attn.qkv.bias", + "conf_decoder.blocks.0.attn.proj.bias", + "conf_decoder.blocks.0.norm2.weight", + "conf_decoder.blocks.0.norm2.bias", + "conf_decoder.blocks.0.mlp.fc1.bias", + "conf_decoder.blocks.0.mlp.fc2.bias", + "conf_decoder.blocks.1.norm1.weight", + "conf_decoder.blocks.1.norm1.bias", + "conf_decoder.blocks.1.attn.qkv.bias", + "conf_decoder.blocks.1.attn.proj.bias", + "conf_decoder.blocks.1.norm2.weight", + "conf_decoder.blocks.1.norm2.bias", + "conf_decoder.blocks.1.mlp.fc1.bias", + "conf_decoder.blocks.1.mlp.fc2.bias", + "conf_decoder.blocks.2.norm1.weight", + "conf_decoder.blocks.2.norm1.bias", + "conf_decoder.blocks.2.attn.qkv.bias", + "conf_decoder.blocks.2.attn.proj.bias", + "conf_decoder.blocks.2.norm2.weight", + "conf_decoder.blocks.2.norm2.bias", + "conf_decoder.blocks.2.mlp.fc1.bias", + "conf_decoder.blocks.2.mlp.fc2.bias", + "conf_decoder.blocks.3.norm1.weight", + "conf_decoder.blocks.3.norm1.bias", + 
"conf_decoder.blocks.3.attn.qkv.bias", + "conf_decoder.blocks.3.attn.proj.bias", + "conf_decoder.blocks.3.norm2.weight", + "conf_decoder.blocks.3.norm2.bias", + "conf_decoder.blocks.3.mlp.fc1.bias", + "conf_decoder.blocks.3.mlp.fc2.bias", + "conf_decoder.blocks.4.norm1.weight", + "conf_decoder.blocks.4.norm1.bias", + "conf_decoder.blocks.4.attn.qkv.bias", + "conf_decoder.blocks.4.attn.proj.bias", + "conf_decoder.blocks.4.norm2.weight", + "conf_decoder.blocks.4.norm2.bias", + "conf_decoder.blocks.4.mlp.fc1.bias", + "conf_decoder.blocks.4.mlp.fc2.bias", + "conf_decoder.linear_out.bias", + "conf_head.proj.bias", + "camera_decoder.projects.bias", + "camera_decoder.blocks.0.norm1.weight", + "camera_decoder.blocks.0.norm1.bias", + "camera_decoder.blocks.0.attn.qkv.bias", + "camera_decoder.blocks.0.attn.proj.bias", + "camera_decoder.blocks.0.norm2.weight", + "camera_decoder.blocks.0.norm2.bias", + "camera_decoder.blocks.0.mlp.fc1.bias", + "camera_decoder.blocks.0.mlp.fc2.bias", + "camera_decoder.blocks.1.norm1.weight", + "camera_decoder.blocks.1.norm1.bias", + "camera_decoder.blocks.1.attn.qkv.bias", + "camera_decoder.blocks.1.attn.proj.bias", + "camera_decoder.blocks.1.norm2.weight", + "camera_decoder.blocks.1.norm2.bias", + "camera_decoder.blocks.1.mlp.fc1.bias", + "camera_decoder.blocks.1.mlp.fc2.bias", + "camera_decoder.blocks.2.norm1.weight", + "camera_decoder.blocks.2.norm1.bias", + "camera_decoder.blocks.2.attn.qkv.bias", + "camera_decoder.blocks.2.attn.proj.bias", + "camera_decoder.blocks.2.norm2.weight", + "camera_decoder.blocks.2.norm2.bias", + "camera_decoder.blocks.2.mlp.fc1.bias", + "camera_decoder.blocks.2.mlp.fc2.bias", + "camera_decoder.blocks.3.norm1.weight", + "camera_decoder.blocks.3.norm1.bias", + "camera_decoder.blocks.3.attn.qkv.bias", + "camera_decoder.blocks.3.attn.proj.bias", + "camera_decoder.blocks.3.norm2.weight", + "camera_decoder.blocks.3.norm2.bias", + "camera_decoder.blocks.3.mlp.fc1.bias", + "camera_decoder.blocks.3.mlp.fc2.bias", + 
"camera_decoder.blocks.4.norm1.weight", + "camera_decoder.blocks.4.norm1.bias", + "camera_decoder.blocks.4.attn.qkv.bias", + "camera_decoder.blocks.4.attn.proj.bias", + "camera_decoder.blocks.4.norm2.weight", + "camera_decoder.blocks.4.norm2.bias", + "camera_decoder.blocks.4.mlp.fc1.bias", + "camera_decoder.blocks.4.mlp.fc2.bias", + "camera_decoder.linear_out.bias", + "camera_head.res_conv.0.res_conv1.bias", + "camera_head.res_conv.0.res_conv2.bias", + "camera_head.res_conv.0.res_conv3.bias", + "camera_head.res_conv.1.res_conv1.bias", + "camera_head.res_conv.1.res_conv2.bias", + "camera_head.res_conv.1.res_conv3.bias", + "camera_head.more_mlps.0.bias", + "camera_head.more_mlps.2.bias", + "camera_head.fc_t.bias", + "camera_head.fc_rot.bias" + ], + "lr_scale": 1.0 + }, + "decay": { + "weight_decay": 0.05, + "params": [ + "decoder.0.attn.qkv.weight", + "decoder.0.attn.proj.weight", + "decoder.0.mlp.fc1.weight", + "decoder.0.mlp.fc2.weight", + "decoder.1.attn.qkv.weight", + "decoder.1.attn.proj.weight", + "decoder.1.mlp.fc1.weight", + "decoder.1.mlp.fc2.weight", + "decoder.2.attn.qkv.weight", + "decoder.2.attn.proj.weight", + "decoder.2.mlp.fc1.weight", + "decoder.2.mlp.fc2.weight", + "decoder.3.attn.qkv.weight", + "decoder.3.attn.proj.weight", + "decoder.3.mlp.fc1.weight", + "decoder.3.mlp.fc2.weight", + "decoder.4.attn.qkv.weight", + "decoder.4.attn.proj.weight", + "decoder.4.mlp.fc1.weight", + "decoder.4.mlp.fc2.weight", + "decoder.5.attn.qkv.weight", + "decoder.5.attn.proj.weight", + "decoder.5.mlp.fc1.weight", + "decoder.5.mlp.fc2.weight", + "decoder.6.attn.qkv.weight", + "decoder.6.attn.proj.weight", + "decoder.6.mlp.fc1.weight", + "decoder.6.mlp.fc2.weight", + "decoder.7.attn.qkv.weight", + "decoder.7.attn.proj.weight", + "decoder.7.mlp.fc1.weight", + "decoder.7.mlp.fc2.weight", + "decoder.8.attn.qkv.weight", + "decoder.8.attn.proj.weight", + "decoder.8.mlp.fc1.weight", + "decoder.8.mlp.fc2.weight", + "decoder.9.attn.qkv.weight", + "decoder.9.attn.proj.weight", 
+ "decoder.9.mlp.fc1.weight", + "decoder.9.mlp.fc2.weight", + "decoder.10.attn.qkv.weight", + "decoder.10.attn.proj.weight", + "decoder.10.mlp.fc1.weight", + "decoder.10.mlp.fc2.weight", + "decoder.11.attn.qkv.weight", + "decoder.11.attn.proj.weight", + "decoder.11.mlp.fc1.weight", + "decoder.11.mlp.fc2.weight", + "decoder.12.attn.qkv.weight", + "decoder.12.attn.proj.weight", + "decoder.12.mlp.fc1.weight", + "decoder.12.mlp.fc2.weight", + "decoder.13.attn.qkv.weight", + "decoder.13.attn.proj.weight", + "decoder.13.mlp.fc1.weight", + "decoder.13.mlp.fc2.weight", + "decoder.14.attn.qkv.weight", + "decoder.14.attn.proj.weight", + "decoder.14.mlp.fc1.weight", + "decoder.14.mlp.fc2.weight", + "decoder.15.attn.qkv.weight", + "decoder.15.attn.proj.weight", + "decoder.15.mlp.fc1.weight", + "decoder.15.mlp.fc2.weight", + "decoder.16.attn.qkv.weight", + "decoder.16.attn.proj.weight", + "decoder.16.mlp.fc1.weight", + "decoder.16.mlp.fc2.weight", + "decoder.17.attn.qkv.weight", + "decoder.17.attn.proj.weight", + "decoder.17.mlp.fc1.weight", + "decoder.17.mlp.fc2.weight", + "decoder.18.attn.qkv.weight", + "decoder.18.attn.proj.weight", + "decoder.18.mlp.fc1.weight", + "decoder.18.mlp.fc2.weight", + "decoder.19.attn.qkv.weight", + "decoder.19.attn.proj.weight", + "decoder.19.mlp.fc1.weight", + "decoder.19.mlp.fc2.weight", + "decoder.20.attn.qkv.weight", + "decoder.20.attn.proj.weight", + "decoder.20.mlp.fc1.weight", + "decoder.20.mlp.fc2.weight", + "decoder.21.attn.qkv.weight", + "decoder.21.attn.proj.weight", + "decoder.21.mlp.fc1.weight", + "decoder.21.mlp.fc2.weight", + "decoder.22.attn.qkv.weight", + "decoder.22.attn.proj.weight", + "decoder.22.mlp.fc1.weight", + "decoder.22.mlp.fc2.weight", + "decoder.23.attn.qkv.weight", + "decoder.23.attn.proj.weight", + "decoder.23.mlp.fc1.weight", + "decoder.23.mlp.fc2.weight", + "decoder.24.attn.qkv.weight", + "decoder.24.attn.proj.weight", + "decoder.24.mlp.fc1.weight", + "decoder.24.mlp.fc2.weight", + "decoder.25.attn.qkv.weight", + 
"decoder.25.attn.proj.weight", + "decoder.25.mlp.fc1.weight", + "decoder.25.mlp.fc2.weight", + "decoder.26.attn.qkv.weight", + "decoder.26.attn.proj.weight", + "decoder.26.mlp.fc1.weight", + "decoder.26.mlp.fc2.weight", + "decoder.27.attn.qkv.weight", + "decoder.27.attn.proj.weight", + "decoder.27.mlp.fc1.weight", + "decoder.27.mlp.fc2.weight", + "decoder.28.attn.qkv.weight", + "decoder.28.attn.proj.weight", + "decoder.28.mlp.fc1.weight", + "decoder.28.mlp.fc2.weight", + "decoder.29.attn.qkv.weight", + "decoder.29.attn.proj.weight", + "decoder.29.mlp.fc1.weight", + "decoder.29.mlp.fc2.weight", + "decoder.30.attn.qkv.weight", + "decoder.30.attn.proj.weight", + "decoder.30.mlp.fc1.weight", + "decoder.30.mlp.fc2.weight", + "decoder.31.attn.qkv.weight", + "decoder.31.attn.proj.weight", + "decoder.31.mlp.fc1.weight", + "decoder.31.mlp.fc2.weight", + "decoder.32.attn.qkv.weight", + "decoder.32.attn.proj.weight", + "decoder.32.mlp.fc1.weight", + "decoder.32.mlp.fc2.weight", + "decoder.33.attn.qkv.weight", + "decoder.33.attn.proj.weight", + "decoder.33.mlp.fc1.weight", + "decoder.33.mlp.fc2.weight", + "decoder.34.attn.qkv.weight", + "decoder.34.attn.proj.weight", + "decoder.34.mlp.fc1.weight", + "decoder.34.mlp.fc2.weight", + "decoder.35.attn.qkv.weight", + "decoder.35.attn.proj.weight", + "decoder.35.mlp.fc1.weight", + "decoder.35.mlp.fc2.weight", + "point_decoder.projects.weight", + "point_decoder.blocks.0.attn.qkv.weight", + "point_decoder.blocks.0.attn.proj.weight", + "point_decoder.blocks.0.mlp.fc1.weight", + "point_decoder.blocks.0.mlp.fc2.weight", + "point_decoder.blocks.1.attn.qkv.weight", + "point_decoder.blocks.1.attn.proj.weight", + "point_decoder.blocks.1.mlp.fc1.weight", + "point_decoder.blocks.1.mlp.fc2.weight", + "point_decoder.blocks.2.attn.qkv.weight", + "point_decoder.blocks.2.attn.proj.weight", + "point_decoder.blocks.2.mlp.fc1.weight", + "point_decoder.blocks.2.mlp.fc2.weight", + "point_decoder.blocks.3.attn.qkv.weight", + 
"point_decoder.blocks.3.attn.proj.weight", + "point_decoder.blocks.3.mlp.fc1.weight", + "point_decoder.blocks.3.mlp.fc2.weight", + "point_decoder.blocks.4.attn.qkv.weight", + "point_decoder.blocks.4.attn.proj.weight", + "point_decoder.blocks.4.mlp.fc1.weight", + "point_decoder.blocks.4.mlp.fc2.weight", + "point_decoder.linear_out.weight", + "point_head.proj.weight", + "conf_decoder.projects.weight", + "conf_decoder.blocks.0.attn.qkv.weight", + "conf_decoder.blocks.0.attn.proj.weight", + "conf_decoder.blocks.0.mlp.fc1.weight", + "conf_decoder.blocks.0.mlp.fc2.weight", + "conf_decoder.blocks.1.attn.qkv.weight", + "conf_decoder.blocks.1.attn.proj.weight", + "conf_decoder.blocks.1.mlp.fc1.weight", + "conf_decoder.blocks.1.mlp.fc2.weight", + "conf_decoder.blocks.2.attn.qkv.weight", + "conf_decoder.blocks.2.attn.proj.weight", + "conf_decoder.blocks.2.mlp.fc1.weight", + "conf_decoder.blocks.2.mlp.fc2.weight", + "conf_decoder.blocks.3.attn.qkv.weight", + "conf_decoder.blocks.3.attn.proj.weight", + "conf_decoder.blocks.3.mlp.fc1.weight", + "conf_decoder.blocks.3.mlp.fc2.weight", + "conf_decoder.blocks.4.attn.qkv.weight", + "conf_decoder.blocks.4.attn.proj.weight", + "conf_decoder.blocks.4.mlp.fc1.weight", + "conf_decoder.blocks.4.mlp.fc2.weight", + "conf_decoder.linear_out.weight", + "conf_head.proj.weight", + "camera_decoder.projects.weight", + "camera_decoder.blocks.0.attn.qkv.weight", + "camera_decoder.blocks.0.attn.proj.weight", + "camera_decoder.blocks.0.mlp.fc1.weight", + "camera_decoder.blocks.0.mlp.fc2.weight", + "camera_decoder.blocks.1.attn.qkv.weight", + "camera_decoder.blocks.1.attn.proj.weight", + "camera_decoder.blocks.1.mlp.fc1.weight", + "camera_decoder.blocks.1.mlp.fc2.weight", + "camera_decoder.blocks.2.attn.qkv.weight", + "camera_decoder.blocks.2.attn.proj.weight", + "camera_decoder.blocks.2.mlp.fc1.weight", + "camera_decoder.blocks.2.mlp.fc2.weight", + "camera_decoder.blocks.3.attn.qkv.weight", + "camera_decoder.blocks.3.attn.proj.weight", + 
"camera_decoder.blocks.3.mlp.fc1.weight", + "camera_decoder.blocks.3.mlp.fc2.weight", + "camera_decoder.blocks.4.attn.qkv.weight", + "camera_decoder.blocks.4.attn.proj.weight", + "camera_decoder.blocks.4.mlp.fc1.weight", + "camera_decoder.blocks.4.mlp.fc2.weight", + "camera_decoder.linear_out.weight", + "camera_head.res_conv.0.res_conv1.weight", + "camera_head.res_conv.0.res_conv2.weight", + "camera_head.res_conv.0.res_conv3.weight", + "camera_head.res_conv.1.res_conv1.weight", + "camera_head.res_conv.1.res_conv2.weight", + "camera_head.res_conv.1.res_conv3.weight", + "camera_head.more_mlps.0.weight", + "camera_head.more_mlps.2.weight", + "camera_head.fc_t.weight", + "camera_head.fc_rot.weight" + ], + "lr_scale": 1.0 + } +} +[2026-05-02 09:32:22,943][croco.utils.misc][INFO] - [RANK 0] Resume checkpoint /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_8gpu/checkpoint-last.pth +[2026-05-02 09:32:22,968][croco.utils.misc][INFO] - [RANK 0] Moving optimizer state to device: cuda:0 +[2026-05-02 09:32:22,979][croco.utils.misc][INFO] - [RANK 0] & best_so_far=inf +[2026-05-02 09:32:22,980][croco.utils.misc][INFO] - [RANK 0] With optim & sched! 
start_epoch=0 +[2026-05-02 09:32:26,731][__main__][INFO] - [RANK 0] Start training for 10 epochs +[2026-05-02 09:32:26,735][__main__][INFO] - [RANK 0] log_dir: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_16gpu/ +[2026-05-02 09:34:11,147][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 0/1087] eta: 1 day, 7:31:30 lr: 0.000000 epoch: 0.0000 (0.0000) step: 0.0000 (0.0000) loss: 5081.2871 (5081.2871) Lcamera_frontend: 4.1017 (4.1017) Ldepth_frontend: 3.8557 (3.8557) Lpmap_frontend: 9.5382 (9.5382) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 4.0905 (4.0905) Ldepth_mix: 3.8498 (3.8498) Lpmap_mix: 9.5206 (9.5206) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 4.1017 (4.1017) Ldepth_backend: 3.8447 (3.8447) Lpmap_backend: 9.5162 (9.5162) Ltrack_backend: 0.0000 (0.0000) total: 5081.2871 (5081.2871) time: 104.4072 data: 28.0688 max mem: 37991 +[2026-05-02 09:43:10,412][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 10/1087] eta: 17:30:11 lr: 0.000000 epoch: 0.0046 (0.0046) step: 5.0000 (5.0000) loss: 5081.2871 (4411.0437) Lcamera_frontend: 4.1017 (3.5236) Ldepth_frontend: 4.5584 (5.0590) Lpmap_frontend: 10.3262 (10.2973) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 4.0905 (3.5128) Ldepth_mix: 4.5539 (5.0601) Lpmap_mix: 10.3126 (10.2883) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 4.1017 (3.5232) Ldepth_backend: 4.5483 (5.0615) Lpmap_backend: 10.3111 (10.2872) Ltrack_backend: 0.0000 (0.0000) total: 5081.2871 (4411.0437) time: 58.5066 data: 2.6019 max mem: 78413 +[2026-05-02 09:52:24,250][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 20/1087] eta: 16:53:59 lr: 0.000000 epoch: 0.0092 (0.0092) step: 10.0000 (10.0000) loss: 3951.7617 (3824.7185) Lcamera_frontend: 3.1362 (3.0310) Ldepth_frontend: 4.9514 (5.2168) Lpmap_frontend: 10.4624 (10.5351) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.1248 (3.0233) Ldepth_mix: 4.9618 (5.2202) Lpmap_mix: 10.4505 (10.5235) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.1384 (3.0304) Ldepth_backend: 4.9725 (5.2236) 
Lpmap_backend: 10.4471 (10.5182) Ltrack_backend: 0.0000 (0.0000) total: 3951.7617 (3824.7185) time: 54.6495 data: 0.0480 max mem: 78608 +[2026-05-02 10:01:28,558][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 30/1087] eta: 16:29:46 lr: 0.000001 epoch: 0.0184 (0.0138) step: 20.0000 (15.0000) loss: 3951.7617 (3864.5277) Lcamera_frontend: 3.1362 (3.0669) Ldepth_frontend: 4.4016 (5.0846) Lpmap_frontend: 10.4417 (10.3998) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.1248 (3.0592) Ldepth_mix: 4.4010 (5.0863) Lpmap_mix: 10.4259 (10.3879) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.1384 (3.0663) Ldepth_backend: 4.4006 (5.0880) Lpmap_backend: 10.4266 (10.3822) Ltrack_backend: 0.0000 (0.0000) total: 3951.7617 (3864.5277) time: 54.9066 data: 0.0411 max mem: 78608 +[2026-05-02 10:10:45,804][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 40/1087] eta: 16:18:27 lr: 0.000001 epoch: 0.0276 (0.0184) step: 30.0000 (20.0000) loss: 4213.7056 (3929.6543) Lcamera_frontend: 3.3595 (3.1187) Ldepth_frontend: 4.4016 (5.1378) Lpmap_frontend: 10.8743 (10.5532) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.3562 (3.1121) Ldepth_mix: 4.4010 (5.1384) Lpmap_mix: 10.8731 (10.5401) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.3618 (3.1185) Ldepth_backend: 4.4006 (5.1391) Lpmap_backend: 10.8691 (10.5334) Ltrack_backend: 0.0000 (0.0000) total: 4213.7056 (3929.6543) time: 55.0775 data: 0.0382 max mem: 78608 +[2026-05-02 10:19:57,442][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 50/1087] eta: 16:06:01 lr: 0.000001 epoch: 0.0368 (0.0230) step: 40.0000 (25.0000) loss: 3829.2021 (4129.3481) Lcamera_frontend: 3.0496 (3.2863) Ldepth_frontend: 4.7301 (5.0672) Lpmap_frontend: 10.8536 (10.4917) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.0203 (3.2795) Ldepth_mix: 4.7268 (5.0677) Lpmap_mix: 10.8350 (10.4786) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.0508 (3.2863) Ldepth_backend: 4.7242 (5.0684) Lpmap_backend: 10.8184 (10.4724) Ltrack_backend: 0.0000 (0.0000) total: 3829.2021 (4129.3481) 
time: 55.4440 data: 0.0356 max mem: 78608 +[2026-05-02 10:29:03,164][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 60/1087] eta: 15:53:00 lr: 0.000001 epoch: 0.0460 (0.0276) step: 50.0000 (30.0000) loss: 4683.7441 (4309.0910) Lcamera_frontend: 3.7755 (3.4374) Ldepth_frontend: 4.3455 (4.9989) Lpmap_frontend: 10.2377 (10.4435) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.7603 (3.4294) Ldepth_mix: 4.3478 (4.9995) Lpmap_mix: 10.2274 (10.4308) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.7750 (3.4373) Ldepth_backend: 4.3490 (5.0004) Lpmap_backend: 10.2330 (10.4253) Ltrack_backend: 0.0000 (0.0000) total: 4683.7441 (4309.0910) time: 54.8678 data: 0.0347 max mem: 78608 +[2026-05-02 10:38:19,102][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 70/1087] eta: 15:43:31 lr: 0.000001 epoch: 0.0552 (0.0322) step: 60.0000 (35.0000) loss: 4761.5264 (4381.1845) Lcamera_frontend: 3.7872 (3.4981) Ldepth_frontend: 4.1337 (4.9323) Lpmap_frontend: 10.1581 (10.4514) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.7779 (3.4900) Ldepth_mix: 4.1218 (4.9323) Lpmap_mix: 10.1423 (10.4382) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.7848 (3.4980) Ldepth_backend: 4.1105 (4.9326) Lpmap_backend: 10.1356 (10.4330) Ltrack_backend: 0.0000 (0.0000) total: 4761.5264 (4381.1845) time: 55.0822 data: 0.0375 max mem: 78608 +[2026-05-02 10:47:36,569][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 80/1087] eta: 15:34:24 lr: 0.000001 epoch: 0.0644 (0.0368) step: 70.0000 (40.0000) loss: 4113.8643 (4172.3970) Lcamera_frontend: 3.2871 (3.3233) Ldepth_frontend: 4.7628 (5.0174) Lpmap_frontend: 10.1581 (10.4412) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.2732 (3.3154) Ldepth_mix: 4.7717 (5.0186) Lpmap_mix: 10.1423 (10.4278) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.2865 (3.3232) Ldepth_backend: 4.7719 (5.0199) Lpmap_backend: 10.1356 (10.4220) Ltrack_backend: 0.0000 (0.0000) total: 4113.8643 (4172.3970) time: 55.6685 data: 0.0385 max mem: 78608 +[2026-05-02 10:56:52,385][croco.utils.misc][INFO] - 
[RANK 0] Epoch: [0] [ 90/1087] eta: 15:24:57 lr: 0.000002 epoch: 0.0736 (0.0414) step: 80.0000 (45.0000) loss: 4400.7505 (4400.3412) Lcamera_frontend: 3.5322 (3.5147) Ldepth_frontend: 4.7628 (4.9277) Lpmap_frontend: 10.2119 (10.3956) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.5118 (3.5057) Ldepth_mix: 4.7717 (4.9287) Lpmap_mix: 10.1763 (10.3818) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.5319 (3.5146) Ldepth_backend: 4.7719 (4.9298) Lpmap_backend: 10.1456 (10.3764) Ltrack_backend: 0.0000 (0.0000) total: 4400.7505 (4400.3412) time: 55.6631 data: 0.0348 max mem: 78608 +[2026-05-02 11:06:10,773][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 100/1087] eta: 15:15:58 lr: 0.000002 epoch: 0.0828 (0.0460) step: 90.0000 (49.9901) loss: 4485.5415 (4313.6692) Lcamera_frontend: 3.5890 (3.4421) Ldepth_frontend: 4.7662 (4.9857) Lpmap_frontend: 10.1037 (10.3825) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.5679 (3.4326) Ldepth_mix: 4.7777 (4.9867) Lpmap_mix: 10.0893 (10.3682) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.5857 (3.4420) Ldepth_backend: 4.7806 (4.9878) Lpmap_backend: 10.0853 (10.3625) Ltrack_backend: 0.0000 (0.0000) total: 4485.5415 (4313.6692) time: 55.7100 data: 0.0358 max mem: 78608 +[2026-05-02 11:15:41,530][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 110/1087] eta: 15:08:43 lr: 0.000002 epoch: 0.0920 (0.0506) step: 100.0000 (54.9910) loss: 3342.0063 (4247.8372) Lcamera_frontend: 2.6406 (3.3874) Ldepth_frontend: 5.0779 (5.0109) Lpmap_frontend: 9.9820 (10.3502) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.6192 (3.3771) Ldepth_mix: 5.0791 (5.0126) Lpmap_mix: 9.9622 (10.3355) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.6308 (3.3872) Ldepth_backend: 5.0791 (5.0143) Lpmap_backend: 9.9510 (10.3298) Ltrack_backend: 0.0000 (0.0000) total: 3342.0063 (4247.8372) time: 56.4571 data: 0.0390 max mem: 78608 +[2026-05-02 11:24:58,761][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 120/1087] eta: 14:59:19 lr: 0.000002 epoch: 0.1012 (0.0552) step: 
110.0000 (59.9917) loss: 3342.0063 (4146.7741) Lcamera_frontend: 2.6406 (3.3030) Ldepth_frontend: 5.5876 (5.0812) Lpmap_frontend: 9.6714 (10.3031) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.6192 (3.2928) Ldepth_mix: 5.6178 (5.0840) Lpmap_mix: 9.6466 (10.2883) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.6308 (3.3027) Ldepth_backend: 5.6452 (5.0867) Lpmap_backend: 9.6339 (10.2825) Ltrack_backend: 0.0000 (0.0000) total: 3342.0063 (4146.7741) time: 56.3993 data: 0.0383 max mem: 78608 +[2026-05-02 11:34:10,888][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 130/1087] eta: 14:49:17 lr: 0.000002 epoch: 0.1104 (0.0598) step: 120.0000 (64.9924) loss: 2631.2615 (4164.0187) Lcamera_frontend: 2.0486 (3.3171) Ldepth_frontend: 6.1372 (5.1172) Lpmap_frontend: 9.6384 (10.2905) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.0071 (3.3068) Ldepth_mix: 6.1516 (5.1200) Lpmap_mix: 9.6245 (10.2759) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.0492 (3.3169) Ldepth_backend: 6.1575 (5.1227) Lpmap_backend: 9.6327 (10.2702) Ltrack_backend: 0.0000 (0.0000) total: 2631.2615 (4164.0187) time: 55.4666 data: 0.0347 max mem: 78608 +[2026-05-02 11:43:18,641][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 140/1087] eta: 14:38:54 lr: 0.000003 epoch: 0.1196 (0.0644) step: 130.0000 (69.9929) loss: 4879.4307 (4204.6965) Lcamera_frontend: 3.9324 (3.3515) Ldepth_frontend: 4.2465 (5.0660) Lpmap_frontend: 10.0131 (10.2895) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.9262 (3.3411) Ldepth_mix: 4.2521 (5.0687) Lpmap_mix: 9.9941 (10.2750) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.9341 (3.3513) Ldepth_backend: 4.2526 (5.0711) Lpmap_backend: 9.9882 (10.2696) Ltrack_backend: 0.0000 (0.0000) total: 4879.4307 (4204.6965) time: 54.9915 data: 0.0353 max mem: 78608 +[2026-05-02 11:52:28,470][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 150/1087] eta: 14:28:53 lr: 0.000003 epoch: 0.1288 (0.0690) step: 140.0000 (74.9934) loss: 4827.8979 (4246.8539) Lcamera_frontend: 3.8852 (3.3870) 
Ldepth_frontend: 4.4470 (5.0662) Lpmap_frontend: 9.9330 (10.2601) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.8699 (3.3755) Ldepth_mix: 4.4516 (5.0692) Lpmap_mix: 9.9155 (10.2457) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.8851 (3.3868) Ldepth_backend: 4.4521 (5.0718) Lpmap_backend: 9.9124 (10.2407) Ltrack_backend: 0.0000 (0.0000) total: 4827.8979 (4246.8539) time: 54.8778 data: 0.0360 max mem: 78608 +[2026-05-02 12:01:38,236][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 160/1087] eta: 14:18:59 lr: 0.000003 epoch: 0.1380 (0.0736) step: 150.0000 (79.9876) loss: 4349.0220 (4246.2913) Lcamera_frontend: 3.4843 (3.3870) Ldepth_frontend: 4.4890 (5.0494) Lpmap_frontend: 9.5575 (10.2359) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.4510 (3.3753) Ldepth_mix: 4.4851 (5.0522) Lpmap_mix: 9.5402 (10.2213) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.4829 (3.3868) Ldepth_backend: 4.4813 (5.0547) Lpmap_backend: 9.5382 (10.2163) Ltrack_backend: 0.0000 (0.0000) total: 4349.0220 (4246.2913) time: 54.9796 data: 0.0352 max mem: 78608 +[2026-05-02 12:10:55,561][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 170/1087] eta: 14:09:50 lr: 0.000003 epoch: 0.1472 (0.0782) step: 160.0000 (84.9883) loss: 3598.7285 (4161.5525) Lcamera_frontend: 2.8777 (3.3157) Ldepth_frontend: 4.9348 (5.0914) Lpmap_frontend: 10.0642 (10.2605) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.8557 (3.3044) Ldepth_mix: 4.9476 (5.0938) Lpmap_mix: 10.0399 (10.2456) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.8747 (3.3155) Ldepth_backend: 4.9643 (5.0959) Lpmap_backend: 10.0282 (10.2402) Ltrack_backend: 0.0000 (0.0000) total: 3598.7285 (4161.5525) time: 55.3544 data: 0.0373 max mem: 78608 +[2026-05-02 12:20:10,516][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 180/1087] eta: 14:00:29 lr: 0.000003 epoch: 0.1564 (0.0828) step: 170.0000 (89.9890) loss: 2764.1545 (4105.6668) Lcamera_frontend: 2.1421 (3.2688) Ldepth_frontend: 4.9348 (5.1130) Lpmap_frontend: 10.8191 (10.2671) Ltrack_frontend: 0.0000 
(0.0000) Lcamera_mix: 2.1340 (3.2577) Ldepth_mix: 4.9225 (5.1157) Lpmap_mix: 10.8135 (10.2519) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.1400 (3.2686) Ldepth_backend: 4.9108 (5.1179) Lpmap_backend: 10.8064 (10.2462) Ltrack_backend: 0.0000 (0.0000) total: 2764.1545 (4105.6668) time: 55.6139 data: 0.0548 max mem: 78608 +[2026-05-02 12:29:22,152][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 190/1087] eta: 13:50:52 lr: 0.000003 epoch: 0.1656 (0.0874) step: 180.0000 (94.9895) loss: 3678.9834 (4134.5890) Lcamera_frontend: 2.8929 (3.2924) Ldepth_frontend: 4.6002 (5.1258) Lpmap_frontend: 10.6992 (10.2954) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.8951 (3.2812) Ldepth_mix: 4.5934 (5.1285) Lpmap_mix: 10.6930 (10.2804) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.8936 (3.2923) Ldepth_backend: 4.5875 (5.1308) Lpmap_backend: 10.6846 (10.2748) Ltrack_backend: 0.0000 (0.0000) total: 3678.9834 (4134.5890) time: 55.3294 data: 0.0520 max mem: 78608 +[2026-05-02 12:38:26,897][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 200/1087] eta: 13:40:47 lr: 0.000004 epoch: 0.1748 (0.0920) step: 190.0000 (99.9851) loss: 4778.2319 (4191.8865) Lcamera_frontend: 3.8469 (3.3410) Ldepth_frontend: 4.0356 (5.0816) Lpmap_frontend: 10.1077 (10.2607) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.7730 (3.3290) Ldepth_mix: 4.0307 (5.0842) Lpmap_mix: 10.0821 (10.2454) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.8457 (3.3410) Ldepth_backend: 4.0243 (5.0865) Lpmap_backend: 10.0822 (10.2402) Ltrack_backend: 0.0000 (0.0000) total: 4778.2319 (4191.8865) time: 54.8173 data: 0.0351 max mem: 78608 +[2026-05-02 12:47:46,992][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 210/1087] eta: 13:31:52 lr: 0.000004 epoch: 0.1840 (0.0966) step: 200.0000 (104.9858) loss: 4876.9844 (4207.8285) Lcamera_frontend: 3.9352 (3.3538) Ldepth_frontend: 4.4036 (5.0991) Lpmap_frontend: 10.3494 (10.2996) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.9222 (3.3418) Ldepth_mix: 4.3864 (5.1015) Lpmap_mix: 10.3330 
(10.2847) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.9361 (3.3537) Ldepth_backend: 4.3715 (5.1037) Lpmap_backend: 10.3395 (10.2797) Ltrack_backend: 0.0000 (0.0000) total: 4876.9844 (4207.8285) time: 55.2393 data: 0.0366 max mem: 78608 +[2026-05-02 12:57:14,917][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 220/1087] eta: 13:23:26 lr: 0.000004 epoch: 0.1932 (0.1012) step: 210.0000 (109.9864) loss: 2947.5679 (4140.6311) Lcamera_frontend: 2.2958 (3.2974) Ldepth_frontend: 5.4840 (5.1449) Lpmap_frontend: 10.8195 (10.2870) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.2882 (3.2854) Ldepth_mix: 5.4841 (5.1478) Lpmap_mix: 10.8116 (10.2721) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.2958 (3.2973) Ldepth_backend: 5.4839 (5.1505) Lpmap_backend: 10.8040 (10.2669) Ltrack_backend: 0.0000 (0.0000) total: 2947.5679 (4140.6311) time: 56.4000 data: 0.0393 max mem: 78608 +[2026-05-02 13:06:25,838][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 230/1087] eta: 13:13:51 lr: 0.000004 epoch: 0.2024 (0.1058) step: 220.0000 (114.9870) loss: 1353.0189 (4095.3858) Lcamera_frontend: 0.9452 (3.2596) Ldepth_frontend: 6.3815 (5.1701) Lpmap_frontend: 9.9192 (10.2703) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 0.9340 (3.2473) Ldepth_mix: 6.4087 (5.1736) Lpmap_mix: 9.8808 (10.2555) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 0.9452 (3.2596) Ldepth_backend: 6.4298 (5.1769) Lpmap_backend: 9.8618 (10.2505) Ltrack_backend: 0.0000 (0.0000) total: 1353.0189 (4095.3858) time: 55.9422 data: 0.0393 max mem: 78608 +[2026-05-02 13:15:42,474][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 240/1087] eta: 13:04:38 lr: 0.000004 epoch: 0.2116 (0.1104) step: 230.0000 (119.9834) loss: 4021.2910 (4077.2525) Lcamera_frontend: 3.2231 (3.2445) Ldepth_frontend: 4.7492 (5.1738) Lpmap_frontend: 9.9525 (10.2648) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.1804 (3.2323) Ldepth_mix: 4.7549 (5.1775) Lpmap_mix: 9.9478 (10.2501) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.2225 (3.2445) Ldepth_backend: 
4.7598 (5.1811) Lpmap_backend: 9.9530 (10.2452) Ltrack_backend: 0.0000 (0.0000) total: 4021.2910 (4077.2525) time: 55.3777 data: 0.0362 max mem: 78608 +[2026-05-02 13:24:54,521][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 250/1087] eta: 12:55:10 lr: 0.000005 epoch: 0.2208 (0.1150) step: 240.0000 (124.9841) loss: 4021.2910 (4053.2323) Lcamera_frontend: 3.2231 (3.2247) Ldepth_frontend: 4.7114 (5.1795) Lpmap_frontend: 9.6903 (10.2431) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.1804 (3.2119) Ldepth_mix: 4.7149 (5.1835) Lpmap_mix: 9.6684 (10.2283) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.2225 (3.2246) Ldepth_backend: 4.7115 (5.1873) Lpmap_backend: 9.6601 (10.2235) Ltrack_backend: 0.0000 (0.0000) total: 4021.2910 (4053.2323) time: 55.4340 data: 0.0339 max mem: 78608 +[2026-05-02 13:34:12,265][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 260/1087] eta: 12:46:00 lr: 0.000005 epoch: 0.2300 (0.1196) step: 250.0000 (129.9847) loss: 4096.2056 (4089.7413) Lcamera_frontend: 3.2443 (3.2553) Ldepth_frontend: 4.7496 (5.1702) Lpmap_frontend: 9.2404 (10.2283) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.2141 (3.2423) Ldepth_mix: 4.7389 (5.1742) Lpmap_mix: 9.2168 (10.2135) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.2442 (3.2553) Ldepth_backend: 4.7272 (5.1780) Lpmap_backend: 9.2088 (10.2088) Ltrack_backend: 0.0000 (0.0000) total: 4096.2056 (4089.7413) time: 55.4881 data: 0.0340 max mem: 78608 +[2026-05-02 13:43:32,442][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 270/1087] eta: 12:36:58 lr: 0.000005 epoch: 0.2392 (0.1242) step: 260.0000 (134.9852) loss: 4514.8652 (4075.8639) Lcamera_frontend: 3.6334 (3.2437) Ldepth_frontend: 4.7496 (5.1786) Lpmap_frontend: 9.8342 (10.2301) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.6047 (3.2308) Ldepth_mix: 4.7389 (5.1824) Lpmap_mix: 9.8209 (10.2152) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.6336 (3.2436) Ldepth_backend: 4.7272 (5.1862) Lpmap_backend: 9.8241 (10.2105) Ltrack_backend: 0.0000 (0.0000) total: 
4514.8652 (4075.8639) time: 55.8938 data: 0.0351 max mem: 78608 +[2026-05-02 13:52:48,804][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 280/1087] eta: 12:27:43 lr: 0.000005 epoch: 0.2484 (0.1288) step: 270.0000 (139.9822) loss: 4863.7710 (4143.1015) Lcamera_frontend: 3.9223 (3.3000) Ldepth_frontend: 4.6959 (5.1673) Lpmap_frontend: 9.7367 (10.2134) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.9067 (3.2868) Ldepth_mix: 4.7100 (5.1713) Lpmap_mix: 9.7192 (10.1985) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.9224 (3.3000) Ldepth_backend: 4.7193 (5.1750) Lpmap_backend: 9.7193 (10.1939) Ltrack_backend: 0.0000 (0.0000) total: 4863.7710 (4143.1015) time: 55.8260 data: 0.0352 max mem: 78608 +[2026-05-02 14:02:02,794][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 290/1087] eta: 12:18:22 lr: 0.000005 epoch: 0.2576 (0.1334) step: 280.0000 (144.9828) loss: 4426.5898 (4129.0049) Lcamera_frontend: 3.5444 (3.2882) Ldepth_frontend: 4.6959 (5.1711) Lpmap_frontend: 9.6949 (10.2120) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.5284 (3.2750) Ldepth_mix: 4.7100 (5.1751) Lpmap_mix: 9.6806 (10.1968) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.5446 (3.2882) Ldepth_backend: 4.7193 (5.1789) Lpmap_backend: 9.6830 (10.1921) Ltrack_backend: 0.0000 (0.0000) total: 4426.5898 (4129.0049) time: 55.5175 data: 0.0352 max mem: 78608 diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/mytrain.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/mytrain.py new file mode 100644 index 0000000000000000000000000000000000000000..2b90093ae52624b01cdccc958cc9e2f4c50a287b --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/mytrain.py @@ -0,0 +1,601 @@ +# -------------------------------------------------------- +# training code for CUT3R +# -------------------------------------------------------- +# References: +# DUSt3R: https://github.com/naver/dust3r +# -------------------------------------------------------- +import argparse +import datetime +import json +import numpy as np +import os 
+import sys +import time +import math +from collections import defaultdict +from pathlib import Path +from typing import Sized +from itertools import islice + +import torch +import torch.backends.cudnn as cudnn +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter + +torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 + +from dust3r.model import ( + PreTrainedModel, + ARCroco3DStereo, + ARCroco3DStereoConfig, + inf, + strip_module, +) # noqa: F401, needed when loading the model +from dust3r.datasets import get_data_loader +from dust3r.losses_noteacher import * # noqa: F401, needed when loading the model +from dust3r.inference import loss_of_one_batch # noqa +from dust3r.viz import colorize +from dust3r.utils.render import get_render_results +import dust3r.utils.path_to_croco # noqa: F401 +import croco.utils.misc as misc # noqa +from croco.utils.misc import NativeScalerWithGradNormCount as NativeScaler # noqa + +import hydra +from omegaconf import OmegaConf +import logging +import pathlib +from tqdm import tqdm +import random +import builtins +import shutil + +from accelerate import Accelerator +from accelerate import DistributedDataParallelKwargs, InitProcessGroupKwargs +from accelerate.logging import get_logger +from datetime import timedelta +import torch.multiprocessing + +from slamformer.models.slamformer import SLAMFormer # upstream typo: pi3 → slamformer + + +torch.multiprocessing.set_sharing_strategy("file_system") + +printer = get_logger(__name__, log_level="DEBUG") + + +def setup_for_distributed(accelerator: Accelerator): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + force = force or (accelerator.num_processes > 8) + if accelerator.is_main_process or force: + now = datetime.datetime.now().time() + builtin_print("[{}] ".format(now), end="") # print with time stamp + 
builtin_print(*args, **kwargs) + + builtins.print = print + + +def save_current_code(outdir): + now = datetime.datetime.now() # current date and time + date_time = now.strftime("%m_%d-%H:%M:%S") + src_dir = "." + dst_dir = os.path.join(outdir, "code", "{}".format(date_time)) + shutil.copytree( + src_dir, + dst_dir, + ignore=shutil.ignore_patterns( + ".vscode*", + "assets*", + "example*", + "checkpoints*", + "OLD*", + "logs*", + "out*", + "runs*", + "*.png", + "*.mp4", + "*__pycache__*", + "*.git*", + "*.idea*", + "*.zip", + "*.jpg", + ), + dirs_exist_ok=True, + ) + return dst_dir + + +def train(args): + + accelerator = Accelerator( + gradient_accumulation_steps=args.accum_iter, + mixed_precision="bf16", + kwargs_handlers=[ + DistributedDataParallelKwargs(find_unused_parameters=True), + InitProcessGroupKwargs(timeout=timedelta(seconds=6000)), + ], + ) + device = accelerator.device + + setup_for_distributed(accelerator) + + printer.info("output_dir: " + args.output_dir) + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + if accelerator.is_main_process: + dst_dir = save_current_code(outdir=args.output_dir) + printer.info(f"Saving current code to {dst_dir}") + + # auto resume + if not args.resume: + last_ckpt_fname = os.path.join(args.output_dir, f"checkpoint-last.pth") + #last_ckpt_fname = os.path.join(args.output_dir, f"checkpoint-7.pth") + + args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None + + printer.info("job dir: {}".format(os.path.dirname(os.path.realpath(__file__)))) + + # fix the seed + seed = args.seed + accelerator.state.process_index + printer.info( + f"Setting seed to {seed} for process {accelerator.state.process_index}" + ) + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + cudnn.benchmark = args.benchmark + + # training dataset and loader + printer.info("Building train dataset %s", args.train_dataset) + # dataset and loader + data_loader_train = build_dataset( + 
args.train_dataset, + args.batch_size, + args.num_workers, + accelerator=accelerator, + test=False, + fixed_length=args.fixed_length + ) + printer.info("Building test dataset %s", args.test_dataset) + data_loader_test = { + dataset.split("(")[0]: build_dataset( + dataset, + args.batch_size, + args.num_workers, + accelerator=accelerator, + test=True, + fixed_length=True + ) + for dataset in args.test_dataset.split("+") + } + + # model + printer.info("Loading model") + model = SLAMFormer() + teacher = None + + # model: PreTrainedModel = eval(args.model) + printer.info(f"All model parameters: {sum(p.numel() for p in model.parameters())}") + + + printer.info(f">> Creating train criterion = {args.train_criterion}") + train_criterion = eval(args.train_criterion).to(device) + printer.info( + f">> Creating test criterion = {args.test_criterion or args.train_criterion}" + ) + test_criterion = eval(args.test_criterion or args.criterion).to(device) + + model.to(device) + + if args.gradient_checkpointing: + model.gradient_checkpointing_enable() + if args.long_context: + model.fixed_input_length = False + + freeze_keys = None + print('NOTE:', args.pretrained, args.resume) + if args.pretrained and not args.resume: + printer.info(f"Loading pretrained: {args.pretrained}") + ckpt = torch.load(args.pretrained, map_location=device) + ''' + ckpt_ = dict() + for key, v in ckpt.items(): + ckpt_[key[7:]] = v + ''' + ''' + freeze_keys = list(ckpt.keys()) + + ls = dict() + for key, v in ckpt.items(): + if 'aggregator' in key: + key_ = key.replace('aggregator', 'backend_transformer') + ls[key_] = key + for key_ in ls.keys(): + key = ls[key_] + ckpt[key_] = ckpt[key] + ''' + printer.info( + model.load_state_dict(ckpt, strict=False) + ) + del ckpt# in case it occupies memory + ''' + if freeze_keys is None: + freeze_keys = [] + + for name, param in model.named_parameters(): + if 'backend_transformer' not in name: + freeze_keys.append(name) + ''' + ''' + printer.info("Loading teacher model") + 
ckpt_teacher = torch.load(args.teacher, map_location=device) + teacher.load_state_dict(ckpt_teacher, strict=True) + teacher = teacher.to("cuda") + for p in teacher.parameters(): + p.requires_grad = False + teacher.eval() + del ckpt_teacher + + ''' + # freeze + printer.info("Freezing patch embedding and positional encoding parameters...") + frozen_params = 0 + total_params = 0 + + frozen_param_names = [] + + for name, param in model.named_parameters(): + total_params += param.numel() + param.requires_grad = True + + if hasattr(model, 'encoder'):# and hasattr(model.aggregator, 'patch_embed'): + for param in model.encoder.parameters():#aggregator.patch_embed.parameters(): + if param.requires_grad: + param.requires_grad = False + + if hasattr(model, 'register_token'): + model.register_token.requires_grad = False + + # YIJUN: Skip the freezekeys + ''' + for name, param in model.named_parameters(): + if 'camera_decoder' in name or 'camera_head' in name: + print(name) + param.requires_grad = False + ''' + + for name, p in model.named_parameters(): + if not p.requires_grad: + frozen_params += p.numel() + frozen_param_names.append(name) + + printer.info( + f"Frozen {frozen_params:,} parameters out of {total_params:,} total parameters. ({frozen_params / total_params:.2%})") + printer.info( + f"Trainable parameters: {total_params - frozen_params:,} ({(total_params - frozen_params) / total_params:.2%})") + if frozen_param_names: + printer.info( + f"Example frozen parameters: {', '.join(frozen_param_names[:5])}{'...' 
if len(frozen_param_names) > 5 else ''}") + + + + # following timm: set wd as 0 for bias and norm layers + param_groups = misc.get_parameter_groups(model, args.weight_decay) + optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) + # print(optimizer) + loss_scaler = NativeScaler(accelerator=accelerator) + + best_so_far = misc.load_model( + args=args, model_without_ddp=model, optimizer=optimizer, loss_scaler=loss_scaler + ) + if best_so_far is None: + best_so_far = float("inf") + + accelerator.even_batches = False + optimizer, model, data_loader_train = accelerator.prepare( + optimizer, model, data_loader_train + ) + + def write_log_stats(epoch, train_stats, test_stats): + if accelerator.is_main_process: + if log_writer is not None: + log_writer.flush() + + log_stats = dict( + epoch=epoch, **{f"train_{k}": v for k, v in train_stats.items()} + ) + for test_name in data_loader_test: + if test_name not in test_stats: + continue + log_stats.update( + {test_name + "_" + k: v for k, v in test_stats[test_name].items()} + ) + + with open( + os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8" + ) as f: + f.write(json.dumps(log_stats) + "\n") + + def save_model(epoch, fname, best_so_far, data_iter_step): + misc.save_model( + accelerator=accelerator, + args=args, + model_without_ddp=model, + optimizer=optimizer, + loss_scaler=loss_scaler, + epoch=epoch, + step=data_iter_step, + fname=fname, + best_so_far=best_so_far, + ) + + log_writer = ( + SummaryWriter(log_dir=args.output_dir) if accelerator.is_main_process else None + ) + + printer.info(f"Start training for {args.epochs} epochs") + start_time = time.time() + train_stats = test_stats = {} + + for epoch in range(args.start_epoch, args.epochs + 1): + + # Save immediately the last checkpoint + if epoch > args.start_epoch: + if ( + args.save_freq + and np.allclose(epoch / args.save_freq, int(epoch / args.save_freq)) + or epoch == args.epochs + ): + save_model(epoch - 1, "last", best_so_far, 
def save_final_model(accelerator, args, epoch, model_without_ddp, best_so_far=None):
    """Write the end-of-training checkpoint ``checkpoint-final.pth``.

    Args:
        accelerator: Forwarded unchanged to ``misc.save_on_master``.
        args: Run configuration; ``args.output_dir`` is the target directory
            and the whole object is stored under the ``"args"`` key.
        epoch: Epoch number recorded in the checkpoint.
        model_without_ddp: Either a ready-made state dict (stored as is) or a
            module, which is moved to CPU before its ``state_dict()`` is taken.
        best_so_far: Optional best metric; only stored when it is not ``None``.
    """
    checkpoint_path = Path(args.output_dir) / "checkpoint-final.pth"

    # A plain dict is assumed to already be a state dict and is kept untouched.
    if isinstance(model_without_ddp, dict):
        weights = model_without_ddp
    else:
        # NOTE: ``.cpu()`` moves the module itself, not a copy.
        weights = model_without_ddp.cpu().state_dict()

    payload = {"args": args, "model": weights, "epoch": epoch}
    if best_so_far is not None:
        payload["best_so_far"] = best_so_far

    printer.info(f">> Saving model to {checkpoint_path} ...")
    misc.save_on_master(accelerator, payload, checkpoint_path)
loss_scaler, + args, + log_writer=None, +): + assert torch.backends.cuda.matmul.allow_tf32 == True + + model.train(True) + metric_logger = misc.MetricLogger(delimiter=" ") + metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}")) + header = "Epoch: [{}]".format(epoch) + accum_iter = args.accum_iter + + def save_model(epoch, fname, best_so_far, data_iter_step): + unwrapped_model = accelerator.unwrap_model(model) + misc.save_model( + accelerator=accelerator, + args=args, + model_without_ddp=unwrapped_model, + optimizer=optimizer, + loss_scaler=loss_scaler, + epoch=epoch, + step=data_iter_step, + fname=fname, + best_so_far=best_so_far, + ) + + if log_writer is not None: + printer.info("log_dir: {}".format(log_writer.log_dir)) + + if hasattr(data_loader, "dataset") and hasattr(data_loader.dataset, "set_epoch"): + data_loader.dataset.set_epoch(epoch) + if ( + hasattr(data_loader, "batch_sampler") + and hasattr(data_loader.batch_sampler, "batch_sampler") + and hasattr(data_loader.batch_sampler.batch_sampler, "set_epoch") + ): + data_loader.batch_sampler.batch_sampler.set_epoch(epoch) + + + optimizer.zero_grad() + + start_step = args.start_step + + data_iter = metric_logger.log_every(data_loader, args.print_freq, accelerator, header) + + for data_iter_step, batch in enumerate(data_iter): + + with accelerator.accumulate(model): + # change the range of the image to [0, 1] + if isinstance(batch, dict) and "img" in batch: + batch["img"] = (batch["img"] + 1.0) / 2.0 + elif isinstance(batch, list) and all(isinstance(v, dict) and "img" in v for v in batch): + for view in batch: + view["img"] = (view["img"] + 1.0) / 2.0 + + epoch_f = epoch + data_iter_step / len(data_loader) + # we use a per iteration (instead of per epoch) lr scheduler + if data_iter_step % accum_iter == 0: + misc.adjust_learning_rate(optimizer, epoch_f, args) + + epoch_f = epoch + data_iter_step / len(data_loader) + step = int(epoch_f * len(data_loader)) + + result = 
loss_of_one_batch( + batch, + model, + criterion, + accelerator, + teacher=teacher, + inference=False, + symmetrize_batch=False, + use_amp=bool(args.amp), + ) + + loss, loss_details = result["loss"] # criterion returns two values + + loss_value = float(loss) + + if not math.isfinite(loss_value): + print( + f"Loss is {loss_value}, stopping training, loss details: {loss_details}" + ) + sys.exit(1) + if not result.get("already_backprop", False): + loss_scaler( + loss, + optimizer, + parameters=model.parameters(), + update_grad=True, + clip_grad=1.0, + ) + optimizer.zero_grad() + + is_metric = batch[0]["is_metric"] + curr_num_view = len(batch) + + del loss + + tb_vis_img = (data_iter_step + 1) % accum_iter == 0 and ( + (step + 1) % (args.print_img_freq) + ) == 0 + if not tb_vis_img: + del batch + else: + torch.cuda.empty_cache() + + lr = optimizer.param_groups[0]["lr"] + metric_logger.update(epoch=epoch_f) + metric_logger.update(lr=lr) + metric_logger.update(step=step) + # + metric_logger.update(loss=loss_value, **loss_details) + # + if (data_iter_step + 1) % accum_iter == 0 and ( + (data_iter_step + 1) % (accum_iter * args.print_freq) + ) == 0: + loss_value_reduce = accelerator.gather( + torch.tensor(loss_value).to(accelerator.device) + ).mean() # MUST BE EXECUTED BY ALL NODES + + if log_writer is None: + continue + """ We use epoch_1000x as the x-axis in tensorboard. + This calibrates different curves when batch size changes. 
def batch_append(original_list, new_list):
    """Append the i-th item of ``new_list`` to the i-th sub-list, in place.

    Pairing stops at the shorter of the two inputs, so surplus sub-lists or
    surplus items are left untouched.

    Returns:
        ``original_list`` itself (the same object, now mutated).
    """
    paired = zip(original_list, new_list)
    for bucket, item in paired:
        bucket.append(item)
    return original_list
class CameraHead(nn.Module):
    """Iteratively regress a camera encoding from the per-frame camera token.

    The head takes the first token of every frame from the last aggregated
    token tensor, and refines a ``target_dim``-sized pose encoding over several
    iterations. Each iteration conditions the tokens on the previous estimate
    through a shift/scale/gate modulation, runs them through a trunk of
    transformer blocks, and adds the predicted delta to the running estimate.
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",  # Field of view activations: ensures FOV values are positive.
    ):
        """Build the trunk, normalisations and pose regression branch.

        Raises:
            ValueError: If ``pose_encoding_type`` is not ``"absT_quaR_FoV"``.
        """
        super().__init__()

        if pose_encoding_type == "absT_quaR_FoV":
            self.target_dim = 9
        else:
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")

        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth

        # Build the trunk using a sequence of transformer blocks.
        self.trunk = nn.Sequential(
            *[
                Block(
                    dim=dim_in,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    init_values=init_values,
                )
                for _ in range(trunk_depth)
            ]
        )

        # Normalizations for camera token and trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)

        # Learnable empty camera pose token.
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)

        # Module for producing modulation parameters: shift, scale, and a gate.
        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))

        # Adaptive layer normalization without affine parameters.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.pose_branch = Mlp(
            in_features=dim_in,
            hidden_features=dim_in // 2,
            out_features=self.target_dim,
            drop=0,
        )

    def forward(
        self,
        aggregated_tokens_list: list,
        num_iterations: int = 4,
        past_key_values_camera=None,
        use_cache: bool = False,
    ):
        """
        Forward pass to predict camera parameters.

        Args:
            aggregated_tokens_list (list): List of token tensors from the network;
                the last tensor is used for prediction.
            num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
            past_key_values_camera (list, optional): One cache entry per trunk block;
                only consulted when ``use_cache`` is True. ``None`` starts an empty cache.
            use_cache (bool, optional): Run the trunk in cached mode. Defaults to False.

        Returns:
            list: Predicted camera encodings (post-activation), one per iteration,
            when ``use_cache`` is False.
            tuple: ``(pred_pose_enc_list, past_key_values_camera)`` when ``use_cache`` is True.
        """
        # Use tokens from the last block for camera prediction.
        tokens = aggregated_tokens_list[-1]

        # Extract the camera tokens (index 0 along the token axis of every frame).
        pose_tokens = self.token_norm(tokens[:, :, 0])

        if use_cache:
            return self.trunk_fn(pose_tokens, num_iterations, past_key_values_camera, use_cache)
        # Without caching any supplied cache is deliberately ignored.
        return self.trunk_fn(pose_tokens, num_iterations, past_key_values_camera=None, use_cache=use_cache)

    def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int, past_key_values_camera, use_cache: bool):
        """
        Iteratively refine camera pose predictions.

        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C].
            num_iterations (int): Number of refinement iterations.
            past_key_values_camera (list or None): Per-block cache entries, updated in
                place when ``use_cache`` is True. ``None`` is treated as an empty cache.
            use_cache (bool): If True, blocks are called with their cache entry and no
                attention mask; otherwise a mask hiding later frames is applied.

        Returns:
            list: Activated camera encodings from each iteration, or the tuple
            ``(list, past_key_values_camera)`` when ``use_cache`` is True.
        """
        B, S, C = pose_tokens.shape
        pred_pose_enc = None
        pred_pose_enc_list = []

        # Previously a missing cache crashed on ``None[idx]``; start from an
        # empty per-block cache instead.
        if use_cache and past_key_values_camera is None:
            past_key_values_camera = [None] * self.trunk_depth

        # The mask depends only on S, device and dtype, none of which change
        # between refinement iterations, so it is built once on first use.
        attn_mask = None

        for _ in range(num_iterations):
            # Use a learned empty pose for the first iteration.
            if pred_pose_enc is None:
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                # Detach the previous prediction to avoid backprop through time.
                pred_pose_enc = pred_pose_enc.detach()
                module_input = self.embed_pose(pred_pose_enc)

            # Generate modulation parameters and split them into shift, scale, and gate components.
            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)

            # Adaptive layer normalization and modulation.
            pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
            pose_tokens_modulated = pose_tokens_modulated + pose_tokens

            if not use_cache and attn_mask is None:
                # One token per frame: entry (i, j) is masked when j > i, so a
                # frame cannot attend to frames that come after it.
                frame_ids = torch.arange(S, device=pose_tokens_modulated.device)
                future_frame = frame_ids.unsqueeze(1) < frame_ids.unsqueeze(0)
                attn_mask = future_frame.to(pose_tokens_modulated.dtype) * torch.finfo(pose_tokens_modulated.dtype).min

            if use_cache:
                for idx in range(self.trunk_depth):
                    pose_tokens_modulated, block_kv = self.trunk[idx](
                        pose_tokens_modulated,
                        attn_mask=attn_mask,
                        past_key_values=past_key_values_camera[idx],
                        use_cache=True,
                    )
                    past_key_values_camera[idx] = block_kv
            else:
                for idx in range(self.trunk_depth):
                    pose_tokens_modulated = self.trunk[idx](pose_tokens_modulated, attn_mask=attn_mask)

            # Compute the delta update for the pose encoding.
            pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))

            if pred_pose_enc is None:
                pred_pose_enc = pred_pose_enc_delta
            else:
                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta

            # Apply final activation functions for translation, quaternion, and field-of-view.
            activated_pose = activate_pose(
                pred_pose_enc,
                trans_act=self.trans_act,
                quat_act=self.quat_act,
                fl_act=self.fl_act,
            )
            pred_pose_enc_list.append(activated_pose)

        if use_cache:
            return pred_pose_enc_list, past_key_values_camera
        return pred_pose_enc_list
class DPTHead(nn.Module):
    """
    Dense prediction head that fuses tokens from several transformer layers.

    Patch tokens from the selected layers are projected to feature maps,
    resized to four different scales, fused coarse-to-fine by the refinement
    blocks, upsampled, and finally turned into predictions plus confidence
    (or returned as raw features when ``feature_only`` is set).

    Args:
        dim_in (int): Input dimension (channels).
        patch_size (int, optional): Patch size. Default is 14.
        output_dim (int, optional): Number of output channels. Default is 4.
        activation (str, optional): Activation type. Default is "inv_log".
        conf_activation (str, optional): Confidence activation type. Default is "expp1".
        features (int, optional): Feature channels for intermediate representations. Default is 256.
        out_channels (List[int], optional): Output channels for each intermediate layer.
            Default (when None) is [256, 512, 1024, 1024].
        intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
            Default (when None) is [4, 11, 17, 23].
        pos_embed (bool, optional): Whether to use positional embedding. Default is True.
        feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
        down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 4,
        activation: str = "inv_log",
        conf_activation: str = "expp1",
        features: int = 256,
        out_channels: List[int] = None,
        intermediate_layer_idx: List[int] = None,
        pos_embed: bool = True,
        feature_only: bool = False,
        down_ratio: int = 1,
    ) -> None:
        super(DPTHead, self).__init__()

        # Use fresh lists per instance: list literals in the signature would be
        # shared by every instance (``intermediate_layer_idx`` is stored on self).
        if out_channels is None:
            out_channels = [256, 512, 1024, 1024]
        if intermediate_layer_idx is None:
            intermediate_layer_idx = [4, 11, 17, 23]

        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.feature_only = feature_only
        self.down_ratio = down_ratio
        self.intermediate_layer_idx = intermediate_layer_idx

        self.norm = nn.LayerNorm(dim_in)

        # Projection layers for each output channel from tokens.
        self.projects = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=dim_in,
                    out_channels=oc,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
                for oc in out_channels
            ]
        )

        # Resize layers: x4 up, x2 up, identity, x2 down (one per selected layer).
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
                ),
            ]
        )

        self.scratch = _make_scratch(
            out_channels,
            features,
            expand=False,
        )

        # Attach additional modules to scratch.
        self.scratch.stem_transpose = None
        self.scratch.refinenet1 = _make_fusion_block(features)
        self.scratch.refinenet2 = _make_fusion_block(features)
        self.scratch.refinenet3 = _make_fusion_block(features)
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)

        head_features_1 = features
        head_features_2 = 32

        if feature_only:
            self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
        else:
            self.scratch.output_conv1 = nn.Conv2d(
                head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
            )
            conv2_in_channels = head_features_1 // 2

            self.scratch.output_conv2 = nn.Sequential(
                nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
            )

    def forward(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_chunk_size: int = 8,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass through the DPT head, supports processing by chunking frames.

        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
            patch_start_idx (int): Starting index for patch tokens in the token sequence.
                Used to separate patch tokens from other tokens (e.g., camera or register tokens).
            frames_chunk_size (int, optional): Number of frames to process in each chunk.
                If None or larger than S, all frames are processed at once. Default: 8.

        Returns:
            Tensor or Tuple[Tensor, Tensor]:
                - If feature_only=True: feature maps with shape [B, S, C, H_out, W_out].
                - Otherwise: (predictions, confidence) as returned by ``activate_head``,
                  with the leading dimension split back into [B, S, ...].

        Raises:
            ValueError: If chunking is required and ``frames_chunk_size`` is not positive.
        """
        # Unpacking also checks that ``images`` is 5-dimensional.
        _, S, _, _, _ = images.shape

        # If frames_chunk_size is not specified or greater than S, process all frames at once
        if frames_chunk_size is None or frames_chunk_size >= S:
            return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)

        # Otherwise, process frames in chunks to manage memory usage.
        # Raised explicitly because ``assert`` is stripped under ``python -O``.
        if frames_chunk_size <= 0:
            raise ValueError(f"frames_chunk_size must be positive, got {frames_chunk_size}")

        all_preds = []
        all_conf = []

        for frames_start_idx in range(0, S, frames_chunk_size):
            frames_end_idx = min(frames_start_idx + frames_chunk_size, S)

            # Process one chunk of frames
            if self.feature_only:
                chunk_output = self._forward_impl(
                    aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
                )
                all_preds.append(chunk_output)
            else:
                chunk_preds, chunk_conf = self._forward_impl(
                    aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
                )
                all_preds.append(chunk_preds)
                all_conf.append(chunk_conf)

        # Concatenate results along the sequence dimension
        if self.feature_only:
            return torch.cat(all_preds, dim=1)
        return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)

    def _forward_impl(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_start_idx: int = None,
        frames_end_idx: int = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Run the head on all frames, or on the slice [frames_start_idx, frames_end_idx).

        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W].
            patch_start_idx (int): Starting index for patch tokens.
            frames_start_idx (int, optional): Starting index for frames to process.
            frames_end_idx (int, optional): Ending index for frames to process.
                The slice is applied only when both indices are given.

        Returns:
            Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
        """
        chunked = frames_start_idx is not None and frames_end_idx is not None
        if chunked:
            images = images[:, frames_start_idx:frames_end_idx].contiguous()

        B, S, _, H, W = images.shape

        patch_h, patch_w = H // self.patch_size, W // self.patch_size

        out = []

        for dpt_idx, layer_idx in enumerate(self.intermediate_layer_idx):
            # Drop the leading non-patch tokens of every frame.
            x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]

            # Select frames if processing a chunk
            if chunked:
                x = x[:, frames_start_idx:frames_end_idx]

            # Fold frames into the batch axis: [B*S, tokens, C].
            x = x.reshape(B * S, -1, x.shape[-1])

            x = self.norm(x)
            # Tokens -> channels-first feature map on the patch grid.
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[dpt_idx](x)
            if self.pos_embed:
                x = self._apply_pos_embed(x, W, H)
            x = self.resize_layers[dpt_idx](x)

            out.append(x)

        # Fuse features from multiple layers.
        out = self.scratch_forward(out)
        # Interpolate fused output to match target image resolution.
        out = custom_interpolate(
            out,
            (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
            mode="bilinear",
            align_corners=True,
        )

        if self.pos_embed:
            out = self._apply_pos_embed(out, W, H)

        if self.feature_only:
            return out.reshape(B, S, *out.shape[1:])

        out = self.scratch.output_conv2(out)
        preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)

        # Split the folded batch axis back into [B, S, ...].
        preds = preds.reshape(B, S, *preds.shape[1:])
        conf = conf.reshape(B, S, *conf.shape[1:])
        return preds, conf

    def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
        """
        Add a positional embedding, scaled by ``ratio``, to feature map ``x``.

        The embedding is built from a UV grid matching the spatial size of ``x``
        and the image aspect ratio ``W / H``, then broadcast over the batch.
        """
        patch_w = x.shape[-1]
        patch_h = x.shape[-2]
        pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
        pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
        pos_embed = pos_embed * ratio
        # Move the embedding channels first and expand over the batch axis.
        pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pos_embed

    def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Forward pass through the fusion blocks.

        Args:
            features (List[Tensor]): Exactly four feature maps, one per selected layer.

        Returns:
            Tensor: Fused feature map.
        """
        layer_1, layer_2, layer_3, layer_4 = features

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Fuse from the coarsest map upwards; each step is resized to the next
        # level's spatial size. Intermediates are dropped as soon as possible.
        out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        del layer_4_rn, layer_4

        out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
        del layer_3_rn, layer_3

        out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
        del layer_2_rn, layer_2

        out = self.scratch.refinenet1(out, layer_1_rn)
        del layer_1_rn, layer_1

        out = self.scratch.output_conv1(out)
        return out
class ResidualConvUnit(nn.Module):
    """Two activation-then-3x3-conv stages with a skip connection around them."""

    def __init__(self, features, activation, bn, groups=1):
        """Create the two convolutions and the skip-add helper.

        Args:
            features (int): number of input and output channels
            activation: callable applied before each convolution
            bn: stored on the instance; no normalisation layer is created here
            groups (int): convolution groups
        """
        super().__init__()

        self.bn = bn
        self.groups = groups

        # Both convolutions keep channel count and spatial size.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv1 = nn.Conv2d(features, features, **conv_kwargs)
        self.conv2 = nn.Conv2d(features, features, **conv_kwargs)

        # Normalisation slots are left empty; forward skips them while None.
        self.norm1 = None
        self.norm2 = None

        self.activation = activation
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Apply both stages and add the input back.

        Args:
            x (tensor): input

        Returns:
            tensor: output, same shape as the input
        """
        out = x
        stages = ((self.conv1, self.norm1), (self.conv2, self.norm2))
        for conv, norm in stages:
            # NOTE: the first activation is applied directly to ``x``; with an
            # in-place activation the skip branch below sees the activated input.
            out = conv(self.activation(out))
            if norm is not None:
                out = norm(out)

        return self.skip_add.add(out, x)
def custom_interpolate(
    x: torch.Tensor,
    size: Tuple[int, int] = None,
    scale_factor: float = None,
    mode: str = "bilinear",
    align_corners: bool = True,
    max_elements: int = 1610612736,
) -> torch.Tensor:
    """
    Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.

    When the output would hold more than ``max_elements`` values, ``x`` is
    split along dim 0, each piece is interpolated separately and the pieces
    are concatenated again.

    Args:
        x: Input tensor; dims 0 and 1 and the last two dims are used to size the output.
        size: Target (height, width). Takes precedence over ``scale_factor``.
        scale_factor: Used only when ``size`` is None; the target size is the
            truncated product of the last two dims with this factor.
        mode: Interpolation mode forwarded to ``nn.functional.interpolate``.
        align_corners: Forwarded to ``nn.functional.interpolate``.
        max_elements: Output element count above which the input is chunked.
            The default keeps the previous hard-coded limit.

    Returns:
        The interpolated tensor (made contiguous on the chunked path).

    Raises:
        ValueError: If neither ``size`` nor ``scale_factor`` is given.
    """
    if size is None:
        if scale_factor is None:
            raise ValueError("either size or scale_factor must be provided")
        size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))

    # Number of elements in the interpolated result.
    output_elements = size[0] * size[1] * x.shape[0] * x.shape[1]

    if output_elements <= max_elements:
        return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)

    # Splitting happens along dim 0 only, so at most x.shape[0] pieces are
    # produced even if more were requested.
    num_chunks = (output_elements // max_elements) + 1
    pieces = [
        nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners)
        for chunk in torch.chunk(x, chunks=num_chunks, dim=0)
    ]
    return torch.cat(pieces, dim=0).contiguous()
+ + Args: + pose_enc: Tensor containing encoded pose parameters + act_type: Activation type ("linear", "inv_log", "exp", "relu") + + Returns: + Activated pose parameters + """ + if act_type == "linear": + return pose_enc + elif act_type == "inv_log": + return inverse_log_transform(pose_enc) + elif act_type == "exp": + return torch.exp(pose_enc) + elif act_type == "relu": + return F.relu(pose_enc) + else: + raise ValueError(f"Unknown act_type: {act_type}") + + +def activate_head(out, activation="norm_exp", conf_activation="expp1"): + """ + Process network output to extract 3D points and confidence values. + + Args: + out: Network output tensor (B, C, H, W) + activation: Activation type for 3D points + conf_activation: Activation type for confidence values + + Returns: + Tuple of (3D points tensor (B, H, W, C-1), confidence tensor (B, H, W)) + """ + # Move channels from dim 1 to the last dimension => (B, H, W, C) + fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected + + # Split into xyz (first C-1 channels) and confidence (last channel) + xyz = fmap[:, :, :, :-1] + conf = fmap[:, :, :, -1] + + if activation == "norm_exp": + d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8) + xyz_normed = xyz / d + pts3d = xyz_normed * torch.expm1(d) + elif activation == "norm": + pts3d = xyz / xyz.norm(dim=-1, keepdim=True) + elif activation == "exp": + pts3d = torch.exp(xyz) + elif activation == "relu": + pts3d = F.relu(xyz) + elif activation == "inv_log": + pts3d = inverse_log_transform(xyz) + elif activation == "xy_inv_log": + xy, z = xyz.split([2, 1], dim=-1) + z = inverse_log_transform(z) + pts3d = torch.cat([xy * z, z], dim=-1) + elif activation == "sigmoid": + pts3d = torch.sigmoid(xyz) + elif activation == "linear": + pts3d = xyz + else: + raise ValueError(f"Unknown activation: {activation}") + + if conf_activation == "expp1": + conf_out = 1 + conf.exp() + elif conf_activation == "expp0": + conf_out = conf.exp() + elif conf_activation == "sigmoid": + conf_out = torch.sigmoid(conf) + else: + raise
ValueError(f"Unknown conf_activation: {conf_activation}") + + return pts3d, conf_out + + +def inverse_log_transform(y): + """ + Apply inverse log transform: sign(y) * (exp(|y|) - 1) + + Args: + y: Input tensor + + Returns: + Transformed tensor + """ + return torch.sign(y) * (torch.expm1(torch.abs(y))) diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/heads/track_head.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/heads/track_head.py new file mode 100644 index 0000000000000000000000000000000000000000..eca62742213561cb7cb7c5f6c29d23233e816e41 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/heads/track_head.py @@ -0,0 +1,102 @@ +import torch.nn as nn +from .dpt_head import DPTHead +from .track_modules.base_track_predictor import BaseTrackerPredictor + + +class TrackHead(nn.Module): + """ + Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking. + The tracking is performed iteratively, refining predictions over multiple iterations. + """ + + def __init__( + self, + dim_in, + patch_size=14, + features=128, + iters=4, + predict_conf=True, + stride=2, + corr_levels=7, + corr_radius=4, + hidden_size=384, + ): + """ + Initialize the TrackHead module. + + Args: + dim_in (int): Input dimension of tokens from the backbone. + patch_size (int): Size of image patches used in the vision transformer. + features (int): Number of feature channels in the feature extractor output. + iters (int): Number of refinement iterations for tracking predictions. + predict_conf (bool): Whether to predict confidence scores for tracked points. + stride (int): Stride value for the tracker predictor. + corr_levels (int): Number of correlation pyramid levels + corr_radius (int): Radius for correlation computation, controlling the search area. + hidden_size (int): Size of hidden layers in the tracker network. 
+ """ + super().__init__() + + self.patch_size = patch_size + + # Feature extractor based on DPT architecture + # Processes tokens into feature maps for tracking + self.feature_extractor = DPTHead( + dim_in=dim_in, + patch_size=patch_size, + features=features, + feature_only=True, # Only output features, no activation + down_ratio=2, # Reduces spatial dimensions by factor of 2 + pos_embed=False, + ) + + # Tracker module that predicts point trajectories + # Takes feature maps and predicts coordinates and visibility + self.tracker = BaseTrackerPredictor( + latent_dim=features, # Match the output_dim of feature extractor + predict_conf=predict_conf, + stride=stride, + corr_levels=corr_levels, + corr_radius=corr_radius, + hidden_size=hidden_size, + ) + + self.iters = iters + + def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None): + """ + Forward pass of the TrackHead. + + Args: + aggregated_tokens_list (list): List of aggregated tokens from the backbone. + images (torch.Tensor): Input images of shape (B, S, C, H, W) where: + B = batch size, S = sequence length. + patch_start_idx (int): Starting index for patch tokens. + query_points (torch.Tensor): Query points to track, shape (B, N, 2). + Must be a tensor despite the None default; the tracker unpacks query_points.shape. + iters (int, optional): Number of refinement iterations. If None, uses self.iters. + + Returns: + tuple: + - coord_preds (list of torch.Tensor): Predicted coordinates, one tensor per refinement iteration. + - vis_scores (torch.Tensor): Visibility scores for tracked points. + - conf_scores (torch.Tensor or None): Confidence scores for tracked points (None if predict_conf=False).
+ """ + B, S, _, H, W = images.shape + + # Extract features from tokens + # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2 + feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx) + + # Use default iterations if not specified + if iters is None: + iters = self.iters + + # Perform tracking using the extracted features + coord_preds, vis_scores, conf_scores = self.tracker( + query_points=query_points, + fmaps=feature_maps, + iters=iters, + ) + + return coord_preds, vis_scores, conf_scores diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/heads/track_modules/__init__.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/heads/track_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/heads/track_modules/base_track_predictor.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/heads/track_modules/base_track_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..81bc06e873513ac290b06fda523b1862c37d7d17 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/heads/track_modules/base_track_predictor.py @@ -0,0 +1,195 @@ +import torch +import torch.nn as nn +from einops import rearrange, repeat + + +from .blocks import EfficientUpdateFormer, CorrBlock +from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed +from .modules import Mlp + + +class BaseTrackerPredictor(nn.Module): + def __init__( + self, + stride=1, + corr_levels=5, + corr_radius=4, + latent_dim=128, + hidden_size=384, + use_spaceatt=True, + depth=6, + max_scale=518, + predict_conf=True, + ): + super(BaseTrackerPredictor, self).__init__() + self.stride = stride + self.latent_dim = latent_dim + self.corr_levels = corr_levels + self.corr_radius = corr_radius + self.hidden_size = hidden_size + self.max_scale = max_scale + 
self.predict_conf = predict_conf + + self.flows_emb_dim = latent_dim // 2 + + self.corr_mlp = Mlp( + in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2, + hidden_features=self.hidden_size, + out_features=self.latent_dim, + ) + + self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4 + + self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim)) + + space_depth = depth if use_spaceatt else 0 + time_depth = depth + + self.updateformer = EfficientUpdateFormer( + space_depth=space_depth, + time_depth=time_depth, + input_dim=self.transformer_dim, + hidden_size=self.hidden_size, + output_dim=self.latent_dim + 2, + mlp_ratio=4.0, + add_space_attn=use_spaceatt, + ) + + self.fmap_norm = nn.LayerNorm(self.latent_dim) + self.ffeat_norm = nn.GroupNorm(1, self.latent_dim) + + # A linear layer to update track feats at each iteration + self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()) + + self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) + + if predict_conf: + self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) + + def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True): + """ + query_points: B x N x 2, the number of batches, tracks, and xy + fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension. 
+ note HH and WW is the size of feature maps instead of original images + """ + B, N, D = query_points.shape + B, S, C, HH, WW = fmaps.shape + + assert D == 2, "Input points must be 2D coordinates" + + # apply a layernorm to fmaps here + fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2)) + fmaps = fmaps.permute(0, 1, 4, 2, 3) + + # Scale the input query_points because we may downsample the images + # by down_ratio or self.stride + # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map + # its query_points should be query_points/4 + if down_ratio > 1: + query_points = query_points / float(down_ratio) + + query_points = query_points / float(self.stride) + + # Init with coords as the query points + # It means the search will start from the position of query points at the reference frames + coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1) + + # Sample/extract the features of the query points in the query frame + query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0]) + + # init track feats by query feats + track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C + # back up the init coords + coords_backup = coords.clone() + + fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius) + + coord_preds = [] + + # Iterative Refinement + for _ in range(iters): + # Detach the gradients from the last iteration + # (in my experience, not very important for performance) + coords = coords.detach() + + fcorrs = fcorr_fn.corr_sample(track_feats, coords) + + corr_dim = fcorrs.shape[3] + fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim) + fcorrs_ = self.corr_mlp(fcorrs_) + + # Movement of current coords relative to query points + flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) + + flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False) + + # (In my trials, it is also okay to just add the flows_emb instead of concat) + flows_emb = 
torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1) + + track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) + + # Concatenate them as the input for the transformers + transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2) + + # 2D positional embed + pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device) + sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0]) + + sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1) + + x = transformer_input + sampled_pos_emb + + # Add the query ref token to the track feats + query_ref_token = torch.cat( + [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1 + ) + x = x + query_ref_token.to(x.device).to(x.dtype) + + # B, N, S, C + x = rearrange(x, "(b n) s d -> b n s d", b=B) + + # Compute the delta coordinates and delta track features + delta, _ = self.updateformer(x) + + # BN, S, C + delta = rearrange(delta, " b n s d -> (b n) s d", b=B) + delta_coords_ = delta[:, :, :2] + delta_feats_ = delta[:, :, 2:] + + track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim) + delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) + + # Update the track features + track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_ + + track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC + + # B x S x N x 2 + coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3) + + # Force coord0 as query + # because we assume the query points should not be changed + coords[:, 0] = coords_backup[:, 0] + + # The predicted tracks are in the original image scale + if down_ratio > 1: + coord_preds.append(coords * self.stride * down_ratio) + else: + coord_preds.append(coords * self.stride) + + # B, S, N + vis_e = self.vis_predictor(track_feats.reshape(B * S * N, 
self.latent_dim)).reshape(B, S, N) + if apply_sigmoid: + vis_e = torch.sigmoid(vis_e) + + if self.predict_conf: + conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) + if apply_sigmoid: + conf_e = torch.sigmoid(conf_e) + else: + conf_e = None + + if return_feat: + return coord_preds, vis_e, track_feats, query_track_feat, conf_e + else: + return coord_preds, vis_e, conf_e diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/heads/track_modules/blocks.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/heads/track_modules/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..d918dfdc219af05b46911049d3fdfdf02ebac435 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/heads/track_modules/blocks.py @@ -0,0 +1,237 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import bilinear_sampler +from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock + + +class EfficientUpdateFormer(nn.Module): + """ + Transformer model that updates track estimates. 
+ """ + + def __init__( + self, + space_depth=6, + time_depth=6, + input_dim=320, + hidden_size=384, + num_heads=8, + output_dim=130, + mlp_ratio=4.0, + add_space_attn=True, + num_virtual_tracks=64, + ): + super().__init__() + + self.out_channels = 2 + self.num_heads = num_heads + self.hidden_size = hidden_size + self.add_space_attn = add_space_attn + + # Add input LayerNorm before linear projection + self.input_norm = nn.LayerNorm(input_dim) + self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) + + # Add output LayerNorm before final projection + self.output_norm = nn.LayerNorm(hidden_size) + self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) + self.num_virtual_tracks = num_virtual_tracks + + if self.add_space_attn: + self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size)) + else: + self.virual_tracks = None + + self.time_blocks = nn.ModuleList( + [ + AttnBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + attn_class=nn.MultiheadAttention, + ) + for _ in range(time_depth) + ] + ) + + if add_space_attn: + self.space_virtual_blocks = nn.ModuleList( + [ + AttnBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + attn_class=nn.MultiheadAttention, + ) + for _ in range(space_depth) + ] + ) + self.space_point2virtual_blocks = nn.ModuleList( + [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] + ) + self.space_virtual2point_blocks = nn.ModuleList( + [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] + ) + assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001) + + 
self.apply(_basic_init) + + def forward(self, input_tensor, mask=None): + # Apply input LayerNorm + input_tensor = self.input_norm(input_tensor) + tokens = self.input_transform(input_tensor) + + init_tokens = tokens + + B, _, T, _ = tokens.shape + + if self.add_space_attn: + virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) + tokens = torch.cat([tokens, virtual_tokens], dim=1) + + _, N, _, _ = tokens.shape + + j = 0 + for i in range(len(self.time_blocks)): + time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C + + time_tokens = self.time_blocks[i](time_tokens) + + tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C + if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0): + space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C + point_tokens = space_tokens[:, : N - self.num_virtual_tracks] + virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] + + virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask) + virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) + point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask) + + space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) + tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C + j += 1 + + if self.add_space_attn: + tokens = tokens[:, : N - self.num_virtual_tracks] + + tokens = tokens + init_tokens + + # Apply output LayerNorm before final projection + tokens = self.output_norm(tokens) + flow = self.flow_head(tokens) + + return flow, None + + +class CorrBlock: + def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"): + """ + Build a pyramid of feature maps from the input. 
+ + fmaps: Tensor (B, S, C, H, W) + num_levels: number of pyramid levels (each downsampled by factor 2) + radius: search radius for sampling correlation + multiple_track_feats: if True, split the target features per pyramid level + padding_mode: passed to grid_sample / bilinear_sampler + """ + B, S, C, H, W = fmaps.shape + self.S, self.C, self.H, self.W = S, C, H, W + self.num_levels = num_levels + self.radius = radius + self.padding_mode = padding_mode + self.multiple_track_feats = multiple_track_feats + + # Build pyramid: each level is half the spatial resolution of the previous + self.fmaps_pyramid = [fmaps] # level 0 is full resolution + current_fmaps = fmaps + for i in range(num_levels - 1): + B, S, C, H, W = current_fmaps.shape + # Merge batch & sequence dimensions + current_fmaps = current_fmaps.reshape(B * S, C, H, W) + # Avg pool down by factor 2 + current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2) + _, _, H_new, W_new = current_fmaps.shape + current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new) + self.fmaps_pyramid.append(current_fmaps) + + # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling. + # This grid is added to the (scaled) coordinate centroids. + r = self.radius + dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) + dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) + # delta: for every (dy,dx) displacement (i.e. Δx, Δy) + self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2) + + def corr_sample(self, targets, coords): + """ + Instead of storing the entire correlation pyramid, we compute each level's correlation + volume, sample it immediately, then discard it. This saves GPU memory. + + Args: + targets: Tensor (B, S, N, C) — features for the current targets. + coords: Tensor (B, S, N, 2) — coordinates at full resolution. 
+ + Returns: + Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations) + """ + B, S, N, C = targets.shape + + # If you have multiple track features, split them per level. + if self.multiple_track_feats: + targets_split = torch.split(targets, C // self.num_levels, dim=-1) + + out_pyramid = [] + for i, fmaps in enumerate(self.fmaps_pyramid): + # Get current spatial resolution H, W for this pyramid level. + B, S, C, H, W = fmaps.shape + # Reshape feature maps for correlation computation: + # fmap2s: (B, S, C, H*W) + fmap2s = fmaps.view(B, S, C, H * W) + # Choose appropriate target features. + fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C) + + # Compute correlation directly + corrs = compute_corr_level(fmap1, fmap2s, C) + corrs = corrs.view(B, S, N, H, W) + + # Prepare sampling grid: + # Scale down the coordinates for the current level. + centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i) + # Make sure our precomputed delta grid is on the same device/dtype. + delta_lvl = self.delta.to(coords.device).to(coords.dtype) + # Now the grid for grid_sample is: + # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid) + coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2) + + # Sample from the correlation volume using bilinear interpolation. + # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target. + corrs_sampled = bilinear_sampler( + corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode + ) + # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims. + corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2) + out_pyramid.append(corrs_sampled) + + # Concatenate all levels along the last dimension. 
+ out = torch.cat(out_pyramid, dim=-1).contiguous() + return out + + +def compute_corr_level(fmap1, fmap2s, C): + # fmap1: (B, S, N, C) + # fmap2s: (B, S, C, H*W) + corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W) + corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W) + return corrs / math.sqrt(C) diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/heads/track_modules/modules.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/heads/track_modules/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..e3d2db65ac004d4d1f8b52d037916759b92ddc9c --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/heads/track_modules/modules.py @@ -0,0 +1,211 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from typing import Callable +import collections +from torch import Tensor +from itertools import repeat + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +to_2tuple = _ntuple(2) + + +class ResidualBlock(nn.Module): + """ + ResidualBlock: construct a block of two conv layers with residual connections + """ + + def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, + planes, + kernel_size=kernel_size, + padding=1, + stride=stride, + padding_mode="zeros", + ) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=kernel_size, + padding=1, + padding_mode="zeros", + ) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = 
nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + else: + raise NotImplementedError + + if stride == 1: + self.downsample = None + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), + self.norm3, + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class AttnBlock(nn.Module): + def 
__init__( + self, + hidden_size, + num_heads, + attn_class: Callable[..., nn.Module] = nn.MultiheadAttention, + mlp_ratio=4.0, + **block_kwargs + ): + """ + Self attention block + """ + super().__init__() + + self.norm1 = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size) + + self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, mask=None): + # NOTE: `mask` is accepted but not passed to self.attn (the masked call is kept commented out below) + # attn_mask = mask if mask is not None else None + # Normalize before attention + x = self.norm1(x) + + # PyTorch's MultiheadAttention returns attn_output, attn_output_weights + # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask) + + attn_output, _ = self.attn(x, x, x) + + # Residual adds (the attention residual uses the already-normalized x) + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x + + +class CrossAttnBlock(nn.Module): + def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs): + """ + Cross attention block (context_dim is not used; the context is normalized with hidden_size features) + """ + super().__init__() + + self.norm1 = nn.LayerNorm(hidden_size) + self.norm_context = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size) + + self.cross_attn = nn.MultiheadAttention( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs + ) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, context, mask=None): + # Normalize inputs + x = self.norm1(x) + context = self.norm_context(context) + + # Apply cross attention + # Note: nn.MultiheadAttention returns attn_output, attn_output_weights + attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask) + + # Residual adds (the attention residual uses the already-normalized x) + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x diff --git
a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/heads/track_modules/utils.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/heads/track_modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..45a8edc4b885076edf49e1adb4b629889fb8803f --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/heads/track_modules/utils.py @@ -0,0 +1,216 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Union + + +def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor: + """ + This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. + It is a wrapper of get_2d_sincos_pos_embed_from_grid. + Args: + - embed_dim: The embedding dimension. + - grid_size: The grid size. + Returns: + - pos_embed: The generated 2D positional embedding. + """ + if isinstance(grid_size, tuple): + grid_size_h, grid_size_w = grid_size + else: + grid_size_h = grid_size_w = grid_size + grid_h = torch.arange(grid_size_h, dtype=torch.float) + grid_w = torch.arange(grid_size_w, dtype=torch.float) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") + grid = torch.stack(grid, dim=0) + grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if return_grid: + return ( + pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), + grid, + ) + return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) + + +def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor: + """ + This function generates a 2D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - grid: The grid to generate the embedding from. + + Returns: + - emb: The generated 2D positional embedding. 
+ """ + assert embed_dim % 2 == 0 + + # half of the dimensions encode grid[0], the other half grid[1] + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (1, H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (1, H*W, D/2) + + emb = torch.cat([emb_h, emb_w], dim=2) # (1, H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding, float32 of shape (1, M, embed_dim). + """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=torch.double) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb[None].float() + + +def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: + """ + This function generates a 2D positional embedding from given coordinates using sine and cosine functions. + + Args: + - xy: The coordinates to generate the embedding from. + - C: The size of the embedding. + - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. + + Returns: + - pe: The generated 2D positional embedding, (B, N, 2*C), or (B, N, 2*C+2) when cat_coords is set.
+ """ + B, N, D = xy.shape + assert D == 2 + + x = xy[:, :, 0:1] + y = xy[:, :, 1:2] + div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2)) + + pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + + pe_x[:, :, 0::2] = torch.sin(x * div_term) + pe_x[:, :, 1::2] = torch.cos(x * div_term) + + pe_y[:, :, 0::2] = torch.sin(y * div_term) + pe_y[:, :, 1::2] = torch.cos(y * div_term) + + pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, 2*C) + if cat_coords: + pe = torch.cat([xy, pe], dim=2) # (B, N, 2*C+2) + return pe + + +def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): + r"""Sample a tensor using bilinear interpolation + + `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at + coordinates :attr:`coords` using bilinear interpolation. It is the same + as `torch.nn.functional.grid_sample()` but with a different coordinate + convention. + + The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where + :math:`B` is the batch size, :math:`C` is the number of channels, + :math:`H` is the height of the image, and :math:`W` is the width of the + image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is + interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. + + Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, + in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note + that in this case the order of the components is slightly different + from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. + + If `align_corners` is `True`, the coordinate :math:`x` is assumed to be + in the range :math:`[0,W-1]`, with 0 corresponding to the center of the + left-most image pixel and :math:`W-1` to the center of the right-most + pixel.
+ + If `align_corners` is `False`, the coordinate :math:`x` is assumed to + be in the range :math:`[0,W]`, with 0 corresponding to the left edge of + the left-most pixel :math:`W` to the right edge of the right-most + pixel. + + Similar conventions apply to the :math:`y` for the range + :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range + :math:`[0,T-1]` and :math:`[0,T]`. + + Args: + input (Tensor): batch of input images. + coords (Tensor): batch of coordinates. + align_corners (bool, optional): Coordinate convention. Defaults to `True`. + padding_mode (str, optional): Padding mode. Defaults to `"border"`. + + Returns: + Tensor: sampled points. + """ + coords = coords.detach().clone() + ############################################################ + # IMPORTANT: + coords = coords.to(input.device).to(input.dtype) + ############################################################ + + sizes = input.shape[2:] + + assert len(sizes) in [2, 3] + + if len(sizes) == 3: + # t x y -> x y t to match dimensions T H W in grid_sample + coords = coords[..., [1, 2, 0]] + + if align_corners: + scale = torch.tensor( + [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype + ) + else: + scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype) + + coords.mul_(scale) # coords = coords * scale + coords.sub_(1) # coords = coords - 1 + + return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) + + +def sample_features4d(input, coords): + r"""Sample spatial features + + `sample_features4d(input, coords)` samples the spatial features + :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. + + The field is sampled at coordinates :attr:`coords` using bilinear + interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, + 2)`, where each sample has the format :math:`(x_i, y_i)`. 
This uses the + same convention as :func:`bilinear_sampler` with `align_corners=True`. + + The output tensor has one feature per point, and has shape :math:`(B, + R, C)`. + + Args: + input (Tensor): spatial features. + coords (Tensor): points. + + Returns: + Tensor: sampled features. + """ + + B, _, _, _ = input.shape + + # B R 2 -> B R 1 2 + coords = coords.unsqueeze(2) + + # B C R 1 + feats = bilinear_sampler(input, coords) + + return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/heads/utils.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/heads/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4c526f2daace4cc6a52daa3ba212df55123ebf79 --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/heads/utils.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn + + +def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor: + """ + Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC) + + Args: + pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates + embed_dim: Output channel dimension for embeddings + + Returns: + Tensor of shape (H, W, embed_dim) with positional embeddings + """ + H, W, grid_dim = pos_grid.shape + assert grid_dim == 2 + pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2) + + # Process x and y coordinates separately + emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2] + emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2] + + # Combine and reshape + emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D] + + return emb.view(H, W, embed_dim) # [H, W, D] + + +def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid 
using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device) + omega /= embed_dim / 2.0 + omega = 1.0 / omega_0**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb.float() + + +def create_uv_grid( + width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None +) -> torch.Tensor: + """ + Create a normalized UV grid of shape (width, height, 2). + + The grid spans horizontally and vertically according to an aspect ratio, + ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right + corner is at (x_span, y_span), normalized by the diagonal of the plane. + + Args: + width (int): Number of points horizontally. + height (int): Number of points vertically. + aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height. + dtype (torch.dtype, optional): Data type of the resulting tensor. + device (torch.device, optional): Device on which the tensor is created. + + Returns: + torch.Tensor: A (width, height, 2) tensor of UV coordinates. 
+ """ + # Derive aspect ratio if not explicitly provided + if aspect_ratio is None: + aspect_ratio = float(width) / float(height) + + # Compute normalized spans for X and Y + diag_factor = (aspect_ratio**2 + 1.0) ** 0.5 + span_x = aspect_ratio / diag_factor + span_y = 1.0 / diag_factor + + # Establish the linspace boundaries + left_x = -span_x * (width - 1) / width + right_x = span_x * (width - 1) / width + top_y = -span_y * (height - 1) / height + bottom_y = span_y * (height - 1) / height + + # Generate 1D coordinates + x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) + y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) + + # Create 2D meshgrid (width x height) and stack into UV + uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") + uv_grid = torch.stack((uu, vv), dim=-1) + + return uv_grid diff --git a/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/utils/geometry.py b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..8ebd25dbc6cac6b0095956524c4f0628410dd5cb --- /dev/null +++ b/outdoor_v48_16gpu_v2/code/05_02-22:24:00/streamvggt/utils/geometry.py @@ -0,0 +1,166 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import torch +import numpy as np + + +def unproject_depth_map_to_point_map( + depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray +) -> np.ndarray: + """ + Unproject a batch of depth maps to 3D world coordinates. 
+ + Args: + depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W) + extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4) + intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3) + + Returns: + np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3) + """ + if isinstance(depth_map, torch.Tensor): + depth_map = depth_map.cpu().numpy() + if isinstance(extrinsics_cam, torch.Tensor): + extrinsics_cam = extrinsics_cam.cpu().numpy() + if isinstance(intrinsics_cam, torch.Tensor): + intrinsics_cam = intrinsics_cam.cpu().numpy() + + world_points_list = [] + for frame_idx in range(depth_map.shape[0]): + cur_world_points, _, _ = depth_to_world_coords_points( + depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx] + ) + world_points_list.append(cur_world_points) + world_points_array = np.stack(world_points_list, axis=0) + + return world_points_array + + +def depth_to_world_coords_points( + depth_map: np.ndarray, + extrinsic: np.ndarray, + intrinsic: np.ndarray, + eps=1e-8, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Convert a depth map to world coordinates. + + Args: + depth_map (np.ndarray): Depth map of shape (H, W). + intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). + extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world. + + Returns: + tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W). 
+ """ + if depth_map is None: + return None, None, None + + # Valid depth mask + point_mask = depth_map > eps + + # Convert depth map to camera coordinates + cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic) + + # Multiply with the inverse of extrinsic matrix to transform to world coordinates + # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4)) + cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0] + + R_cam_to_world = cam_to_world_extrinsic[:3, :3] + t_cam_to_world = cam_to_world_extrinsic[:3, 3] + + # Apply the rotation and translation to the camera coordinates + world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3 + # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world + + return world_coords_points, cam_coords_points, point_mask + + +def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """ + Convert a depth map to camera coordinates. + + Args: + depth_map (np.ndarray): Depth map of shape (H, W). + intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). 
+ + Returns: + tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3) + """ + H, W = depth_map.shape + assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3" + assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew" + + # Intrinsic parameters + fu, fv = intrinsic[0, 0], intrinsic[1, 1] + cu, cv = intrinsic[0, 2], intrinsic[1, 2] + + # Generate grid of pixel coordinates + u, v = np.meshgrid(np.arange(W), np.arange(H)) + + # Unproject to camera coordinates + x_cam = (u - cu) * depth_map / fu + y_cam = (v - cv) * depth_map / fv + z_cam = depth_map + + # Stack to form camera coordinates + cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) + + return cam_coords + + +def closed_form_inverse_se3(se3, R=None, T=None): + """ + Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch. + + If `R` and `T` are provided, they must correspond to the rotation and translation + components of `se3`. Otherwise, they will be extracted from `se3`. + + Args: + se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices. + R (optional): Nx3x3 array or tensor of rotation matrices. + T (optional): Nx3x1 array or tensor of translation vectors. + + Returns: + Inverted SE3 matrices with the same type and device as `se3`. 
+ + Shapes: + se3: (N, 4, 4) + R: (N, 3, 3) + T: (N, 3, 1) + """ + # Check if se3 is a numpy array or a torch tensor + is_numpy = isinstance(se3, np.ndarray) + + # Validate shapes + if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4): + raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.") + + # Extract R and T if not provided + if R is None: + R = se3[:, :3, :3] # (N,3,3) + if T is None: + T = se3[:, :3, 3:] # (N,3,1) + + # Transpose R + if is_numpy: + # Compute the transpose of the rotation for NumPy + R_transposed = np.transpose(R, (0, 2, 1)) + # -R^T t for NumPy + top_right = -np.matmul(R_transposed, T) + inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1)) + else: + R_transposed = R.transpose(1, 2) # (N,3,3) + top_right = -torch.bmm(R_transposed, T) # (N,3,1) + inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1) + inverted_matrix = inverted_matrix.to(R.dtype).to(R.device) + + inverted_matrix[:, :3, :3] = R_transposed + inverted_matrix[:, :3, 3:] = top_right + + return inverted_matrix diff --git a/outdoor_v48_16gpu_v2/mytrain.log b/outdoor_v48_16gpu_v2/mytrain.log new file mode 100644 index 0000000000000000000000000000000000000000..b97e9aa40a97c77f4d8e3167da83779b155b5205 --- /dev/null +++ b/outdoor_v48_16gpu_v2/mytrain.log @@ -0,0 +1,985 @@ +[2026-05-02 22:24:00,638][__main__][INFO] - [RANK 0] output_dir: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_16gpu_v2/ +[2026-05-02 22:24:01,213][__main__][INFO] - [RANK 0] Saving current code to /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_16gpu_v2/code/05_02-22:24:00 +[2026-05-02 22:24:01,213][__main__][INFO] - [RANK 0] job dir: /gpfs/work2/0/prjs0824/qi_proj/slamformer_upstream/src +[2026-05-02 22:24:01,213][__main__][INFO] - [RANK 0] Setting seed to 0 for process 0 +[2026-05-02 22:24:01,215][__main__][INFO] - [RANK 0] Building train dataset 6000 @ VirtualKITTI2_Multi(allow_repeat=False, split='train', 
ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 6000 @ KITTI360_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", velodyne_root="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 5400 @ Waymo_v2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/waymo_v2", lidar_root="/scratch-shared/wwei2/waymo_v2", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) +[2026-05-02 22:24:01,215][__main__][INFO] - [RANK 0] Building Train Data loader for dataset: 6000 @ VirtualKITTI2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 6000 @ KITTI360_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", velodyne_root="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 5400 @ Waymo_v2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/waymo_v2", lidar_root="/scratch-shared/wwei2/waymo_v2", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) +[2026-05-02 22:27:56,308][__main__][INFO] - [RANK 0] Building test dataset 200 @ 
VirtualKITTI2_Multi(split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", resolution=(518, 154), num_views=4, seed=42, n_corres=0) +[2026-05-02 22:27:56,308][__main__][INFO] - [RANK 0] Building Test Data loader for dataset: 200 @ VirtualKITTI2_Multi(split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", resolution=(518, 154), num_views=4, seed=42, n_corres=0) +[2026-05-02 22:27:56,361][__main__][INFO] - [RANK 0] Loading model +[2026-05-02 22:28:02,069][__main__][INFO] - [RANK 0] All model parameters: 958696732 +[2026-05-02 22:28:02,069][__main__][INFO] - [RANK 0] >> Creating train criterion = DistillLoss() +[2026-05-02 22:28:02,070][__main__][INFO] - [RANK 0] >> Creating test criterion = DistillLoss() +[2026-05-02 22:28:02,336][__main__][INFO] - [RANK 0] Freezing patch embedding and positional encoding parameters... +[2026-05-02 22:28:02,341][__main__][INFO] - [RANK 0] Frozen 304,376,832 parameters out of 958,696,732 total parameters. (31.75%) +[2026-05-02 22:28:02,341][__main__][INFO] - [RANK 0] Trainable parameters: 654,319,900 (68.25%) +[2026-05-02 22:28:02,341][__main__][INFO] - [RANK 0] Example frozen parameters: register_token, encoder.cls_token, encoder.pos_embed, encoder.register_tokens, encoder.patch_embed.proj.weight... 
+[2026-05-02 22:28:02,345][croco.utils.misc][INFO] - [RANK 0] Param groups = { + "no_decay": { + "weight_decay": 0.0, + "params": [ + "decoder.0.norm1.weight", + "decoder.0.norm1.bias", + "decoder.0.attn.qkv.bias", + "decoder.0.attn.proj.bias", + "decoder.0.attn.q_norm.weight", + "decoder.0.attn.q_norm.bias", + "decoder.0.attn.k_norm.weight", + "decoder.0.attn.k_norm.bias", + "decoder.0.ls1.gamma", + "decoder.0.norm2.weight", + "decoder.0.norm2.bias", + "decoder.0.mlp.fc1.bias", + "decoder.0.mlp.fc2.bias", + "decoder.0.ls2.gamma", + "decoder.1.norm1.weight", + "decoder.1.norm1.bias", + "decoder.1.attn.qkv.bias", + "decoder.1.attn.proj.bias", + "decoder.1.attn.q_norm.weight", + "decoder.1.attn.q_norm.bias", + "decoder.1.attn.k_norm.weight", + "decoder.1.attn.k_norm.bias", + "decoder.1.ls1.gamma", + "decoder.1.norm2.weight", + "decoder.1.norm2.bias", + "decoder.1.mlp.fc1.bias", + "decoder.1.mlp.fc2.bias", + "decoder.1.ls2.gamma", + "decoder.2.norm1.weight", + "decoder.2.norm1.bias", + "decoder.2.attn.qkv.bias", + "decoder.2.attn.proj.bias", + "decoder.2.attn.q_norm.weight", + "decoder.2.attn.q_norm.bias", + "decoder.2.attn.k_norm.weight", + "decoder.2.attn.k_norm.bias", + "decoder.2.ls1.gamma", + "decoder.2.norm2.weight", + "decoder.2.norm2.bias", + "decoder.2.mlp.fc1.bias", + "decoder.2.mlp.fc2.bias", + "decoder.2.ls2.gamma", + "decoder.3.norm1.weight", + "decoder.3.norm1.bias", + "decoder.3.attn.qkv.bias", + "decoder.3.attn.proj.bias", + "decoder.3.attn.q_norm.weight", + "decoder.3.attn.q_norm.bias", + "decoder.3.attn.k_norm.weight", + "decoder.3.attn.k_norm.bias", + "decoder.3.ls1.gamma", + "decoder.3.norm2.weight", + "decoder.3.norm2.bias", + "decoder.3.mlp.fc1.bias", + "decoder.3.mlp.fc2.bias", + "decoder.3.ls2.gamma", + "decoder.4.norm1.weight", + "decoder.4.norm1.bias", + "decoder.4.attn.qkv.bias", + "decoder.4.attn.proj.bias", + "decoder.4.attn.q_norm.weight", + "decoder.4.attn.q_norm.bias", + "decoder.4.attn.k_norm.weight", + "decoder.4.attn.k_norm.bias", + 
"decoder.4.ls1.gamma", + "decoder.4.norm2.weight", + "decoder.4.norm2.bias", + "decoder.4.mlp.fc1.bias", + "decoder.4.mlp.fc2.bias", + "decoder.4.ls2.gamma", + "decoder.5.norm1.weight", + "decoder.5.norm1.bias", + "decoder.5.attn.qkv.bias", + "decoder.5.attn.proj.bias", + "decoder.5.attn.q_norm.weight", + "decoder.5.attn.q_norm.bias", + "decoder.5.attn.k_norm.weight", + "decoder.5.attn.k_norm.bias", + "decoder.5.ls1.gamma", + "decoder.5.norm2.weight", + "decoder.5.norm2.bias", + "decoder.5.mlp.fc1.bias", + "decoder.5.mlp.fc2.bias", + "decoder.5.ls2.gamma", + "decoder.6.norm1.weight", + "decoder.6.norm1.bias", + "decoder.6.attn.qkv.bias", + "decoder.6.attn.proj.bias", + "decoder.6.attn.q_norm.weight", + "decoder.6.attn.q_norm.bias", + "decoder.6.attn.k_norm.weight", + "decoder.6.attn.k_norm.bias", + "decoder.6.ls1.gamma", + "decoder.6.norm2.weight", + "decoder.6.norm2.bias", + "decoder.6.mlp.fc1.bias", + "decoder.6.mlp.fc2.bias", + "decoder.6.ls2.gamma", + "decoder.7.norm1.weight", + "decoder.7.norm1.bias", + "decoder.7.attn.qkv.bias", + "decoder.7.attn.proj.bias", + "decoder.7.attn.q_norm.weight", + "decoder.7.attn.q_norm.bias", + "decoder.7.attn.k_norm.weight", + "decoder.7.attn.k_norm.bias", + "decoder.7.ls1.gamma", + "decoder.7.norm2.weight", + "decoder.7.norm2.bias", + "decoder.7.mlp.fc1.bias", + "decoder.7.mlp.fc2.bias", + "decoder.7.ls2.gamma", + "decoder.8.norm1.weight", + "decoder.8.norm1.bias", + "decoder.8.attn.qkv.bias", + "decoder.8.attn.proj.bias", + "decoder.8.attn.q_norm.weight", + "decoder.8.attn.q_norm.bias", + "decoder.8.attn.k_norm.weight", + "decoder.8.attn.k_norm.bias", + "decoder.8.ls1.gamma", + "decoder.8.norm2.weight", + "decoder.8.norm2.bias", + "decoder.8.mlp.fc1.bias", + "decoder.8.mlp.fc2.bias", + "decoder.8.ls2.gamma", + "decoder.9.norm1.weight", + "decoder.9.norm1.bias", + "decoder.9.attn.qkv.bias", + "decoder.9.attn.proj.bias", + "decoder.9.attn.q_norm.weight", + "decoder.9.attn.q_norm.bias", + "decoder.9.attn.k_norm.weight", + 
"decoder.9.attn.k_norm.bias", + "decoder.9.ls1.gamma", + "decoder.9.norm2.weight", + "decoder.9.norm2.bias", + "decoder.9.mlp.fc1.bias", + "decoder.9.mlp.fc2.bias", + "decoder.9.ls2.gamma", + "decoder.10.norm1.weight", + "decoder.10.norm1.bias", + "decoder.10.attn.qkv.bias", + "decoder.10.attn.proj.bias", + "decoder.10.attn.q_norm.weight", + "decoder.10.attn.q_norm.bias", + "decoder.10.attn.k_norm.weight", + "decoder.10.attn.k_norm.bias", + "decoder.10.ls1.gamma", + "decoder.10.norm2.weight", + "decoder.10.norm2.bias", + "decoder.10.mlp.fc1.bias", + "decoder.10.mlp.fc2.bias", + "decoder.10.ls2.gamma", + "decoder.11.norm1.weight", + "decoder.11.norm1.bias", + "decoder.11.attn.qkv.bias", + "decoder.11.attn.proj.bias", + "decoder.11.attn.q_norm.weight", + "decoder.11.attn.q_norm.bias", + "decoder.11.attn.k_norm.weight", + "decoder.11.attn.k_norm.bias", + "decoder.11.ls1.gamma", + "decoder.11.norm2.weight", + "decoder.11.norm2.bias", + "decoder.11.mlp.fc1.bias", + "decoder.11.mlp.fc2.bias", + "decoder.11.ls2.gamma", + "decoder.12.norm1.weight", + "decoder.12.norm1.bias", + "decoder.12.attn.qkv.bias", + "decoder.12.attn.proj.bias", + "decoder.12.attn.q_norm.weight", + "decoder.12.attn.q_norm.bias", + "decoder.12.attn.k_norm.weight", + "decoder.12.attn.k_norm.bias", + "decoder.12.ls1.gamma", + "decoder.12.norm2.weight", + "decoder.12.norm2.bias", + "decoder.12.mlp.fc1.bias", + "decoder.12.mlp.fc2.bias", + "decoder.12.ls2.gamma", + "decoder.13.norm1.weight", + "decoder.13.norm1.bias", + "decoder.13.attn.qkv.bias", + "decoder.13.attn.proj.bias", + "decoder.13.attn.q_norm.weight", + "decoder.13.attn.q_norm.bias", + "decoder.13.attn.k_norm.weight", + "decoder.13.attn.k_norm.bias", + "decoder.13.ls1.gamma", + "decoder.13.norm2.weight", + "decoder.13.norm2.bias", + "decoder.13.mlp.fc1.bias", + "decoder.13.mlp.fc2.bias", + "decoder.13.ls2.gamma", + "decoder.14.norm1.weight", + "decoder.14.norm1.bias", + "decoder.14.attn.qkv.bias", + "decoder.14.attn.proj.bias", + 
"decoder.14.attn.q_norm.weight", + "decoder.14.attn.q_norm.bias", + "decoder.14.attn.k_norm.weight", + "decoder.14.attn.k_norm.bias", + "decoder.14.ls1.gamma", + "decoder.14.norm2.weight", + "decoder.14.norm2.bias", + "decoder.14.mlp.fc1.bias", + "decoder.14.mlp.fc2.bias", + "decoder.14.ls2.gamma", + "decoder.15.norm1.weight", + "decoder.15.norm1.bias", + "decoder.15.attn.qkv.bias", + "decoder.15.attn.proj.bias", + "decoder.15.attn.q_norm.weight", + "decoder.15.attn.q_norm.bias", + "decoder.15.attn.k_norm.weight", + "decoder.15.attn.k_norm.bias", + "decoder.15.ls1.gamma", + "decoder.15.norm2.weight", + "decoder.15.norm2.bias", + "decoder.15.mlp.fc1.bias", + "decoder.15.mlp.fc2.bias", + "decoder.15.ls2.gamma", + "decoder.16.norm1.weight", + "decoder.16.norm1.bias", + "decoder.16.attn.qkv.bias", + "decoder.16.attn.proj.bias", + "decoder.16.attn.q_norm.weight", + "decoder.16.attn.q_norm.bias", + "decoder.16.attn.k_norm.weight", + "decoder.16.attn.k_norm.bias", + "decoder.16.ls1.gamma", + "decoder.16.norm2.weight", + "decoder.16.norm2.bias", + "decoder.16.mlp.fc1.bias", + "decoder.16.mlp.fc2.bias", + "decoder.16.ls2.gamma", + "decoder.17.norm1.weight", + "decoder.17.norm1.bias", + "decoder.17.attn.qkv.bias", + "decoder.17.attn.proj.bias", + "decoder.17.attn.q_norm.weight", + "decoder.17.attn.q_norm.bias", + "decoder.17.attn.k_norm.weight", + "decoder.17.attn.k_norm.bias", + "decoder.17.ls1.gamma", + "decoder.17.norm2.weight", + "decoder.17.norm2.bias", + "decoder.17.mlp.fc1.bias", + "decoder.17.mlp.fc2.bias", + "decoder.17.ls2.gamma", + "decoder.18.norm1.weight", + "decoder.18.norm1.bias", + "decoder.18.attn.qkv.bias", + "decoder.18.attn.proj.bias", + "decoder.18.attn.q_norm.weight", + "decoder.18.attn.q_norm.bias", + "decoder.18.attn.k_norm.weight", + "decoder.18.attn.k_norm.bias", + "decoder.18.ls1.gamma", + "decoder.18.norm2.weight", + "decoder.18.norm2.bias", + "decoder.18.mlp.fc1.bias", + "decoder.18.mlp.fc2.bias", + "decoder.18.ls2.gamma", + 
"decoder.19.norm1.weight", + "decoder.19.norm1.bias", + "decoder.19.attn.qkv.bias", + "decoder.19.attn.proj.bias", + "decoder.19.attn.q_norm.weight", + "decoder.19.attn.q_norm.bias", + "decoder.19.attn.k_norm.weight", + "decoder.19.attn.k_norm.bias", + "decoder.19.ls1.gamma", + "decoder.19.norm2.weight", + "decoder.19.norm2.bias", + "decoder.19.mlp.fc1.bias", + "decoder.19.mlp.fc2.bias", + "decoder.19.ls2.gamma", + "decoder.20.norm1.weight", + "decoder.20.norm1.bias", + "decoder.20.attn.qkv.bias", + "decoder.20.attn.proj.bias", + "decoder.20.attn.q_norm.weight", + "decoder.20.attn.q_norm.bias", + "decoder.20.attn.k_norm.weight", + "decoder.20.attn.k_norm.bias", + "decoder.20.ls1.gamma", + "decoder.20.norm2.weight", + "decoder.20.norm2.bias", + "decoder.20.mlp.fc1.bias", + "decoder.20.mlp.fc2.bias", + "decoder.20.ls2.gamma", + "decoder.21.norm1.weight", + "decoder.21.norm1.bias", + "decoder.21.attn.qkv.bias", + "decoder.21.attn.proj.bias", + "decoder.21.attn.q_norm.weight", + "decoder.21.attn.q_norm.bias", + "decoder.21.attn.k_norm.weight", + "decoder.21.attn.k_norm.bias", + "decoder.21.ls1.gamma", + "decoder.21.norm2.weight", + "decoder.21.norm2.bias", + "decoder.21.mlp.fc1.bias", + "decoder.21.mlp.fc2.bias", + "decoder.21.ls2.gamma", + "decoder.22.norm1.weight", + "decoder.22.norm1.bias", + "decoder.22.attn.qkv.bias", + "decoder.22.attn.proj.bias", + "decoder.22.attn.q_norm.weight", + "decoder.22.attn.q_norm.bias", + "decoder.22.attn.k_norm.weight", + "decoder.22.attn.k_norm.bias", + "decoder.22.ls1.gamma", + "decoder.22.norm2.weight", + "decoder.22.norm2.bias", + "decoder.22.mlp.fc1.bias", + "decoder.22.mlp.fc2.bias", + "decoder.22.ls2.gamma", + "decoder.23.norm1.weight", + "decoder.23.norm1.bias", + "decoder.23.attn.qkv.bias", + "decoder.23.attn.proj.bias", + "decoder.23.attn.q_norm.weight", + "decoder.23.attn.q_norm.bias", + "decoder.23.attn.k_norm.weight", + "decoder.23.attn.k_norm.bias", + "decoder.23.ls1.gamma", + "decoder.23.norm2.weight", + 
"decoder.23.norm2.bias", + "decoder.23.mlp.fc1.bias", + "decoder.23.mlp.fc2.bias", + "decoder.23.ls2.gamma", + "decoder.24.norm1.weight", + "decoder.24.norm1.bias", + "decoder.24.attn.qkv.bias", + "decoder.24.attn.proj.bias", + "decoder.24.attn.q_norm.weight", + "decoder.24.attn.q_norm.bias", + "decoder.24.attn.k_norm.weight", + "decoder.24.attn.k_norm.bias", + "decoder.24.ls1.gamma", + "decoder.24.norm2.weight", + "decoder.24.norm2.bias", + "decoder.24.mlp.fc1.bias", + "decoder.24.mlp.fc2.bias", + "decoder.24.ls2.gamma", + "decoder.25.norm1.weight", + "decoder.25.norm1.bias", + "decoder.25.attn.qkv.bias", + "decoder.25.attn.proj.bias", + "decoder.25.attn.q_norm.weight", + "decoder.25.attn.q_norm.bias", + "decoder.25.attn.k_norm.weight", + "decoder.25.attn.k_norm.bias", + "decoder.25.ls1.gamma", + "decoder.25.norm2.weight", + "decoder.25.norm2.bias", + "decoder.25.mlp.fc1.bias", + "decoder.25.mlp.fc2.bias", + "decoder.25.ls2.gamma", + "decoder.26.norm1.weight", + "decoder.26.norm1.bias", + "decoder.26.attn.qkv.bias", + "decoder.26.attn.proj.bias", + "decoder.26.attn.q_norm.weight", + "decoder.26.attn.q_norm.bias", + "decoder.26.attn.k_norm.weight", + "decoder.26.attn.k_norm.bias", + "decoder.26.ls1.gamma", + "decoder.26.norm2.weight", + "decoder.26.norm2.bias", + "decoder.26.mlp.fc1.bias", + "decoder.26.mlp.fc2.bias", + "decoder.26.ls2.gamma", + "decoder.27.norm1.weight", + "decoder.27.norm1.bias", + "decoder.27.attn.qkv.bias", + "decoder.27.attn.proj.bias", + "decoder.27.attn.q_norm.weight", + "decoder.27.attn.q_norm.bias", + "decoder.27.attn.k_norm.weight", + "decoder.27.attn.k_norm.bias", + "decoder.27.ls1.gamma", + "decoder.27.norm2.weight", + "decoder.27.norm2.bias", + "decoder.27.mlp.fc1.bias", + "decoder.27.mlp.fc2.bias", + "decoder.27.ls2.gamma", + "decoder.28.norm1.weight", + "decoder.28.norm1.bias", + "decoder.28.attn.qkv.bias", + "decoder.28.attn.proj.bias", + "decoder.28.attn.q_norm.weight", + "decoder.28.attn.q_norm.bias", + 
"decoder.28.attn.k_norm.weight", + "decoder.28.attn.k_norm.bias", + "decoder.28.ls1.gamma", + "decoder.28.norm2.weight", + "decoder.28.norm2.bias", + "decoder.28.mlp.fc1.bias", + "decoder.28.mlp.fc2.bias", + "decoder.28.ls2.gamma", + "decoder.29.norm1.weight", + "decoder.29.norm1.bias", + "decoder.29.attn.qkv.bias", + "decoder.29.attn.proj.bias", + "decoder.29.attn.q_norm.weight", + "decoder.29.attn.q_norm.bias", + "decoder.29.attn.k_norm.weight", + "decoder.29.attn.k_norm.bias", + "decoder.29.ls1.gamma", + "decoder.29.norm2.weight", + "decoder.29.norm2.bias", + "decoder.29.mlp.fc1.bias", + "decoder.29.mlp.fc2.bias", + "decoder.29.ls2.gamma", + "decoder.30.norm1.weight", + "decoder.30.norm1.bias", + "decoder.30.attn.qkv.bias", + "decoder.30.attn.proj.bias", + "decoder.30.attn.q_norm.weight", + "decoder.30.attn.q_norm.bias", + "decoder.30.attn.k_norm.weight", + "decoder.30.attn.k_norm.bias", + "decoder.30.ls1.gamma", + "decoder.30.norm2.weight", + "decoder.30.norm2.bias", + "decoder.30.mlp.fc1.bias", + "decoder.30.mlp.fc2.bias", + "decoder.30.ls2.gamma", + "decoder.31.norm1.weight", + "decoder.31.norm1.bias", + "decoder.31.attn.qkv.bias", + "decoder.31.attn.proj.bias", + "decoder.31.attn.q_norm.weight", + "decoder.31.attn.q_norm.bias", + "decoder.31.attn.k_norm.weight", + "decoder.31.attn.k_norm.bias", + "decoder.31.ls1.gamma", + "decoder.31.norm2.weight", + "decoder.31.norm2.bias", + "decoder.31.mlp.fc1.bias", + "decoder.31.mlp.fc2.bias", + "decoder.31.ls2.gamma", + "decoder.32.norm1.weight", + "decoder.32.norm1.bias", + "decoder.32.attn.qkv.bias", + "decoder.32.attn.proj.bias", + "decoder.32.attn.q_norm.weight", + "decoder.32.attn.q_norm.bias", + "decoder.32.attn.k_norm.weight", + "decoder.32.attn.k_norm.bias", + "decoder.32.ls1.gamma", + "decoder.32.norm2.weight", + "decoder.32.norm2.bias", + "decoder.32.mlp.fc1.bias", + "decoder.32.mlp.fc2.bias", + "decoder.32.ls2.gamma", + "decoder.33.norm1.weight", + "decoder.33.norm1.bias", + "decoder.33.attn.qkv.bias", + 
"decoder.33.attn.proj.bias", + "decoder.33.attn.q_norm.weight", + "decoder.33.attn.q_norm.bias", + "decoder.33.attn.k_norm.weight", + "decoder.33.attn.k_norm.bias", + "decoder.33.ls1.gamma", + "decoder.33.norm2.weight", + "decoder.33.norm2.bias", + "decoder.33.mlp.fc1.bias", + "decoder.33.mlp.fc2.bias", + "decoder.33.ls2.gamma", + "decoder.34.norm1.weight", + "decoder.34.norm1.bias", + "decoder.34.attn.qkv.bias", + "decoder.34.attn.proj.bias", + "decoder.34.attn.q_norm.weight", + "decoder.34.attn.q_norm.bias", + "decoder.34.attn.k_norm.weight", + "decoder.34.attn.k_norm.bias", + "decoder.34.ls1.gamma", + "decoder.34.norm2.weight", + "decoder.34.norm2.bias", + "decoder.34.mlp.fc1.bias", + "decoder.34.mlp.fc2.bias", + "decoder.34.ls2.gamma", + "decoder.35.norm1.weight", + "decoder.35.norm1.bias", + "decoder.35.attn.qkv.bias", + "decoder.35.attn.proj.bias", + "decoder.35.attn.q_norm.weight", + "decoder.35.attn.q_norm.bias", + "decoder.35.attn.k_norm.weight", + "decoder.35.attn.k_norm.bias", + "decoder.35.ls1.gamma", + "decoder.35.norm2.weight", + "decoder.35.norm2.bias", + "decoder.35.mlp.fc1.bias", + "decoder.35.mlp.fc2.bias", + "decoder.35.ls2.gamma", + "point_decoder.projects.bias", + "point_decoder.blocks.0.norm1.weight", + "point_decoder.blocks.0.norm1.bias", + "point_decoder.blocks.0.attn.qkv.bias", + "point_decoder.blocks.0.attn.proj.bias", + "point_decoder.blocks.0.norm2.weight", + "point_decoder.blocks.0.norm2.bias", + "point_decoder.blocks.0.mlp.fc1.bias", + "point_decoder.blocks.0.mlp.fc2.bias", + "point_decoder.blocks.1.norm1.weight", + "point_decoder.blocks.1.norm1.bias", + "point_decoder.blocks.1.attn.qkv.bias", + "point_decoder.blocks.1.attn.proj.bias", + "point_decoder.blocks.1.norm2.weight", + "point_decoder.blocks.1.norm2.bias", + "point_decoder.blocks.1.mlp.fc1.bias", + "point_decoder.blocks.1.mlp.fc2.bias", + "point_decoder.blocks.2.norm1.weight", + "point_decoder.blocks.2.norm1.bias", + "point_decoder.blocks.2.attn.qkv.bias", + 
"point_decoder.blocks.2.attn.proj.bias", + "point_decoder.blocks.2.norm2.weight", + "point_decoder.blocks.2.norm2.bias", + "point_decoder.blocks.2.mlp.fc1.bias", + "point_decoder.blocks.2.mlp.fc2.bias", + "point_decoder.blocks.3.norm1.weight", + "point_decoder.blocks.3.norm1.bias", + "point_decoder.blocks.3.attn.qkv.bias", + "point_decoder.blocks.3.attn.proj.bias", + "point_decoder.blocks.3.norm2.weight", + "point_decoder.blocks.3.norm2.bias", + "point_decoder.blocks.3.mlp.fc1.bias", + "point_decoder.blocks.3.mlp.fc2.bias", + "point_decoder.blocks.4.norm1.weight", + "point_decoder.blocks.4.norm1.bias", + "point_decoder.blocks.4.attn.qkv.bias", + "point_decoder.blocks.4.attn.proj.bias", + "point_decoder.blocks.4.norm2.weight", + "point_decoder.blocks.4.norm2.bias", + "point_decoder.blocks.4.mlp.fc1.bias", + "point_decoder.blocks.4.mlp.fc2.bias", + "point_decoder.linear_out.bias", + "point_head.proj.bias", + "conf_decoder.projects.bias", + "conf_decoder.blocks.0.norm1.weight", + "conf_decoder.blocks.0.norm1.bias", + "conf_decoder.blocks.0.attn.qkv.bias", + "conf_decoder.blocks.0.attn.proj.bias", + "conf_decoder.blocks.0.norm2.weight", + "conf_decoder.blocks.0.norm2.bias", + "conf_decoder.blocks.0.mlp.fc1.bias", + "conf_decoder.blocks.0.mlp.fc2.bias", + "conf_decoder.blocks.1.norm1.weight", + "conf_decoder.blocks.1.norm1.bias", + "conf_decoder.blocks.1.attn.qkv.bias", + "conf_decoder.blocks.1.attn.proj.bias", + "conf_decoder.blocks.1.norm2.weight", + "conf_decoder.blocks.1.norm2.bias", + "conf_decoder.blocks.1.mlp.fc1.bias", + "conf_decoder.blocks.1.mlp.fc2.bias", + "conf_decoder.blocks.2.norm1.weight", + "conf_decoder.blocks.2.norm1.bias", + "conf_decoder.blocks.2.attn.qkv.bias", + "conf_decoder.blocks.2.attn.proj.bias", + "conf_decoder.blocks.2.norm2.weight", + "conf_decoder.blocks.2.norm2.bias", + "conf_decoder.blocks.2.mlp.fc1.bias", + "conf_decoder.blocks.2.mlp.fc2.bias", + "conf_decoder.blocks.3.norm1.weight", + "conf_decoder.blocks.3.norm1.bias", + 
"conf_decoder.blocks.3.attn.qkv.bias", + "conf_decoder.blocks.3.attn.proj.bias", + "conf_decoder.blocks.3.norm2.weight", + "conf_decoder.blocks.3.norm2.bias", + "conf_decoder.blocks.3.mlp.fc1.bias", + "conf_decoder.blocks.3.mlp.fc2.bias", + "conf_decoder.blocks.4.norm1.weight", + "conf_decoder.blocks.4.norm1.bias", + "conf_decoder.blocks.4.attn.qkv.bias", + "conf_decoder.blocks.4.attn.proj.bias", + "conf_decoder.blocks.4.norm2.weight", + "conf_decoder.blocks.4.norm2.bias", + "conf_decoder.blocks.4.mlp.fc1.bias", + "conf_decoder.blocks.4.mlp.fc2.bias", + "conf_decoder.linear_out.bias", + "conf_head.proj.bias", + "camera_decoder.projects.bias", + "camera_decoder.blocks.0.norm1.weight", + "camera_decoder.blocks.0.norm1.bias", + "camera_decoder.blocks.0.attn.qkv.bias", + "camera_decoder.blocks.0.attn.proj.bias", + "camera_decoder.blocks.0.norm2.weight", + "camera_decoder.blocks.0.norm2.bias", + "camera_decoder.blocks.0.mlp.fc1.bias", + "camera_decoder.blocks.0.mlp.fc2.bias", + "camera_decoder.blocks.1.norm1.weight", + "camera_decoder.blocks.1.norm1.bias", + "camera_decoder.blocks.1.attn.qkv.bias", + "camera_decoder.blocks.1.attn.proj.bias", + "camera_decoder.blocks.1.norm2.weight", + "camera_decoder.blocks.1.norm2.bias", + "camera_decoder.blocks.1.mlp.fc1.bias", + "camera_decoder.blocks.1.mlp.fc2.bias", + "camera_decoder.blocks.2.norm1.weight", + "camera_decoder.blocks.2.norm1.bias", + "camera_decoder.blocks.2.attn.qkv.bias", + "camera_decoder.blocks.2.attn.proj.bias", + "camera_decoder.blocks.2.norm2.weight", + "camera_decoder.blocks.2.norm2.bias", + "camera_decoder.blocks.2.mlp.fc1.bias", + "camera_decoder.blocks.2.mlp.fc2.bias", + "camera_decoder.blocks.3.norm1.weight", + "camera_decoder.blocks.3.norm1.bias", + "camera_decoder.blocks.3.attn.qkv.bias", + "camera_decoder.blocks.3.attn.proj.bias", + "camera_decoder.blocks.3.norm2.weight", + "camera_decoder.blocks.3.norm2.bias", + "camera_decoder.blocks.3.mlp.fc1.bias", + "camera_decoder.blocks.3.mlp.fc2.bias", + 
"camera_decoder.blocks.4.norm1.weight", + "camera_decoder.blocks.4.norm1.bias", + "camera_decoder.blocks.4.attn.qkv.bias", + "camera_decoder.blocks.4.attn.proj.bias", + "camera_decoder.blocks.4.norm2.weight", + "camera_decoder.blocks.4.norm2.bias", + "camera_decoder.blocks.4.mlp.fc1.bias", + "camera_decoder.blocks.4.mlp.fc2.bias", + "camera_decoder.linear_out.bias", + "camera_head.res_conv.0.res_conv1.bias", + "camera_head.res_conv.0.res_conv2.bias", + "camera_head.res_conv.0.res_conv3.bias", + "camera_head.res_conv.1.res_conv1.bias", + "camera_head.res_conv.1.res_conv2.bias", + "camera_head.res_conv.1.res_conv3.bias", + "camera_head.more_mlps.0.bias", + "camera_head.more_mlps.2.bias", + "camera_head.fc_t.bias", + "camera_head.fc_rot.bias" + ], + "lr_scale": 1.0 + }, + "decay": { + "weight_decay": 0.05, + "params": [ + "decoder.0.attn.qkv.weight", + "decoder.0.attn.proj.weight", + "decoder.0.mlp.fc1.weight", + "decoder.0.mlp.fc2.weight", + "decoder.1.attn.qkv.weight", + "decoder.1.attn.proj.weight", + "decoder.1.mlp.fc1.weight", + "decoder.1.mlp.fc2.weight", + "decoder.2.attn.qkv.weight", + "decoder.2.attn.proj.weight", + "decoder.2.mlp.fc1.weight", + "decoder.2.mlp.fc2.weight", + "decoder.3.attn.qkv.weight", + "decoder.3.attn.proj.weight", + "decoder.3.mlp.fc1.weight", + "decoder.3.mlp.fc2.weight", + "decoder.4.attn.qkv.weight", + "decoder.4.attn.proj.weight", + "decoder.4.mlp.fc1.weight", + "decoder.4.mlp.fc2.weight", + "decoder.5.attn.qkv.weight", + "decoder.5.attn.proj.weight", + "decoder.5.mlp.fc1.weight", + "decoder.5.mlp.fc2.weight", + "decoder.6.attn.qkv.weight", + "decoder.6.attn.proj.weight", + "decoder.6.mlp.fc1.weight", + "decoder.6.mlp.fc2.weight", + "decoder.7.attn.qkv.weight", + "decoder.7.attn.proj.weight", + "decoder.7.mlp.fc1.weight", + "decoder.7.mlp.fc2.weight", + "decoder.8.attn.qkv.weight", + "decoder.8.attn.proj.weight", + "decoder.8.mlp.fc1.weight", + "decoder.8.mlp.fc2.weight", + "decoder.9.attn.qkv.weight", + "decoder.9.attn.proj.weight", 
+ "decoder.9.mlp.fc1.weight", + "decoder.9.mlp.fc2.weight", + "decoder.10.attn.qkv.weight", + "decoder.10.attn.proj.weight", + "decoder.10.mlp.fc1.weight", + "decoder.10.mlp.fc2.weight", + "decoder.11.attn.qkv.weight", + "decoder.11.attn.proj.weight", + "decoder.11.mlp.fc1.weight", + "decoder.11.mlp.fc2.weight", + "decoder.12.attn.qkv.weight", + "decoder.12.attn.proj.weight", + "decoder.12.mlp.fc1.weight", + "decoder.12.mlp.fc2.weight", + "decoder.13.attn.qkv.weight", + "decoder.13.attn.proj.weight", + "decoder.13.mlp.fc1.weight", + "decoder.13.mlp.fc2.weight", + "decoder.14.attn.qkv.weight", + "decoder.14.attn.proj.weight", + "decoder.14.mlp.fc1.weight", + "decoder.14.mlp.fc2.weight", + "decoder.15.attn.qkv.weight", + "decoder.15.attn.proj.weight", + "decoder.15.mlp.fc1.weight", + "decoder.15.mlp.fc2.weight", + "decoder.16.attn.qkv.weight", + "decoder.16.attn.proj.weight", + "decoder.16.mlp.fc1.weight", + "decoder.16.mlp.fc2.weight", + "decoder.17.attn.qkv.weight", + "decoder.17.attn.proj.weight", + "decoder.17.mlp.fc1.weight", + "decoder.17.mlp.fc2.weight", + "decoder.18.attn.qkv.weight", + "decoder.18.attn.proj.weight", + "decoder.18.mlp.fc1.weight", + "decoder.18.mlp.fc2.weight", + "decoder.19.attn.qkv.weight", + "decoder.19.attn.proj.weight", + "decoder.19.mlp.fc1.weight", + "decoder.19.mlp.fc2.weight", + "decoder.20.attn.qkv.weight", + "decoder.20.attn.proj.weight", + "decoder.20.mlp.fc1.weight", + "decoder.20.mlp.fc2.weight", + "decoder.21.attn.qkv.weight", + "decoder.21.attn.proj.weight", + "decoder.21.mlp.fc1.weight", + "decoder.21.mlp.fc2.weight", + "decoder.22.attn.qkv.weight", + "decoder.22.attn.proj.weight", + "decoder.22.mlp.fc1.weight", + "decoder.22.mlp.fc2.weight", + "decoder.23.attn.qkv.weight", + "decoder.23.attn.proj.weight", + "decoder.23.mlp.fc1.weight", + "decoder.23.mlp.fc2.weight", + "decoder.24.attn.qkv.weight", + "decoder.24.attn.proj.weight", + "decoder.24.mlp.fc1.weight", + "decoder.24.mlp.fc2.weight", + "decoder.25.attn.qkv.weight", + 
"decoder.25.attn.proj.weight", + "decoder.25.mlp.fc1.weight", + "decoder.25.mlp.fc2.weight", + "decoder.26.attn.qkv.weight", + "decoder.26.attn.proj.weight", + "decoder.26.mlp.fc1.weight", + "decoder.26.mlp.fc2.weight", + "decoder.27.attn.qkv.weight", + "decoder.27.attn.proj.weight", + "decoder.27.mlp.fc1.weight", + "decoder.27.mlp.fc2.weight", + "decoder.28.attn.qkv.weight", + "decoder.28.attn.proj.weight", + "decoder.28.mlp.fc1.weight", + "decoder.28.mlp.fc2.weight", + "decoder.29.attn.qkv.weight", + "decoder.29.attn.proj.weight", + "decoder.29.mlp.fc1.weight", + "decoder.29.mlp.fc2.weight", + "decoder.30.attn.qkv.weight", + "decoder.30.attn.proj.weight", + "decoder.30.mlp.fc1.weight", + "decoder.30.mlp.fc2.weight", + "decoder.31.attn.qkv.weight", + "decoder.31.attn.proj.weight", + "decoder.31.mlp.fc1.weight", + "decoder.31.mlp.fc2.weight", + "decoder.32.attn.qkv.weight", + "decoder.32.attn.proj.weight", + "decoder.32.mlp.fc1.weight", + "decoder.32.mlp.fc2.weight", + "decoder.33.attn.qkv.weight", + "decoder.33.attn.proj.weight", + "decoder.33.mlp.fc1.weight", + "decoder.33.mlp.fc2.weight", + "decoder.34.attn.qkv.weight", + "decoder.34.attn.proj.weight", + "decoder.34.mlp.fc1.weight", + "decoder.34.mlp.fc2.weight", + "decoder.35.attn.qkv.weight", + "decoder.35.attn.proj.weight", + "decoder.35.mlp.fc1.weight", + "decoder.35.mlp.fc2.weight", + "point_decoder.projects.weight", + "point_decoder.blocks.0.attn.qkv.weight", + "point_decoder.blocks.0.attn.proj.weight", + "point_decoder.blocks.0.mlp.fc1.weight", + "point_decoder.blocks.0.mlp.fc2.weight", + "point_decoder.blocks.1.attn.qkv.weight", + "point_decoder.blocks.1.attn.proj.weight", + "point_decoder.blocks.1.mlp.fc1.weight", + "point_decoder.blocks.1.mlp.fc2.weight", + "point_decoder.blocks.2.attn.qkv.weight", + "point_decoder.blocks.2.attn.proj.weight", + "point_decoder.blocks.2.mlp.fc1.weight", + "point_decoder.blocks.2.mlp.fc2.weight", + "point_decoder.blocks.3.attn.qkv.weight", + 
"point_decoder.blocks.3.attn.proj.weight", + "point_decoder.blocks.3.mlp.fc1.weight", + "point_decoder.blocks.3.mlp.fc2.weight", + "point_decoder.blocks.4.attn.qkv.weight", + "point_decoder.blocks.4.attn.proj.weight", + "point_decoder.blocks.4.mlp.fc1.weight", + "point_decoder.blocks.4.mlp.fc2.weight", + "point_decoder.linear_out.weight", + "point_head.proj.weight", + "conf_decoder.projects.weight", + "conf_decoder.blocks.0.attn.qkv.weight", + "conf_decoder.blocks.0.attn.proj.weight", + "conf_decoder.blocks.0.mlp.fc1.weight", + "conf_decoder.blocks.0.mlp.fc2.weight", + "conf_decoder.blocks.1.attn.qkv.weight", + "conf_decoder.blocks.1.attn.proj.weight", + "conf_decoder.blocks.1.mlp.fc1.weight", + "conf_decoder.blocks.1.mlp.fc2.weight", + "conf_decoder.blocks.2.attn.qkv.weight", + "conf_decoder.blocks.2.attn.proj.weight", + "conf_decoder.blocks.2.mlp.fc1.weight", + "conf_decoder.blocks.2.mlp.fc2.weight", + "conf_decoder.blocks.3.attn.qkv.weight", + "conf_decoder.blocks.3.attn.proj.weight", + "conf_decoder.blocks.3.mlp.fc1.weight", + "conf_decoder.blocks.3.mlp.fc2.weight", + "conf_decoder.blocks.4.attn.qkv.weight", + "conf_decoder.blocks.4.attn.proj.weight", + "conf_decoder.blocks.4.mlp.fc1.weight", + "conf_decoder.blocks.4.mlp.fc2.weight", + "conf_decoder.linear_out.weight", + "conf_head.proj.weight", + "camera_decoder.projects.weight", + "camera_decoder.blocks.0.attn.qkv.weight", + "camera_decoder.blocks.0.attn.proj.weight", + "camera_decoder.blocks.0.mlp.fc1.weight", + "camera_decoder.blocks.0.mlp.fc2.weight", + "camera_decoder.blocks.1.attn.qkv.weight", + "camera_decoder.blocks.1.attn.proj.weight", + "camera_decoder.blocks.1.mlp.fc1.weight", + "camera_decoder.blocks.1.mlp.fc2.weight", + "camera_decoder.blocks.2.attn.qkv.weight", + "camera_decoder.blocks.2.attn.proj.weight", + "camera_decoder.blocks.2.mlp.fc1.weight", + "camera_decoder.blocks.2.mlp.fc2.weight", + "camera_decoder.blocks.3.attn.qkv.weight", + "camera_decoder.blocks.3.attn.proj.weight", + 
"camera_decoder.blocks.3.mlp.fc1.weight", + "camera_decoder.blocks.3.mlp.fc2.weight", + "camera_decoder.blocks.4.attn.qkv.weight", + "camera_decoder.blocks.4.attn.proj.weight", + "camera_decoder.blocks.4.mlp.fc1.weight", + "camera_decoder.blocks.4.mlp.fc2.weight", + "camera_decoder.linear_out.weight", + "camera_head.res_conv.0.res_conv1.weight", + "camera_head.res_conv.0.res_conv2.weight", + "camera_head.res_conv.0.res_conv3.weight", + "camera_head.res_conv.1.res_conv1.weight", + "camera_head.res_conv.1.res_conv2.weight", + "camera_head.res_conv.1.res_conv3.weight", + "camera_head.more_mlps.0.weight", + "camera_head.more_mlps.2.weight", + "camera_head.fc_t.weight", + "camera_head.fc_rot.weight" + ], + "lr_scale": 1.0 + } +} +[2026-05-02 22:28:05,615][croco.utils.misc][INFO] - [RANK 0] Resume checkpoint /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_4gpu_v2/checkpoint-last.pth +[2026-05-02 22:28:05,639][croco.utils.misc][INFO] - [RANK 0] Moving optimizer state to device: cuda:0 +[2026-05-02 22:28:05,650][croco.utils.misc][INFO] - [RANK 0] & best_so_far=inf +[2026-05-02 22:28:05,650][croco.utils.misc][INFO] - [RANK 0] With optim & sched! 
start_epoch=0 +[2026-05-02 22:28:09,695][__main__][INFO] - [RANK 0] Start training for 10 epochs +[2026-05-02 22:28:09,699][__main__][INFO] - [RANK 0] log_dir: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_16gpu_v2/ +[2026-05-02 22:29:54,327][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 0/1087] eta: 1 day, 7:35:24 lr: 0.000000 epoch: 0.0000 (0.0000) step: 0.0000 (0.0000) loss: 4202.3013 (4202.3013) Lcamera_frontend: 3.3617 (3.3617) Ldepth_frontend: 3.0077 (3.0077) Lpmap_frontend: 11.0651 (11.0651) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.3558 (3.3558) Ldepth_mix: 3.0005 (3.0005) Lpmap_mix: 11.0574 (11.0574) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.3619 (3.3619) Ldepth_backend: 2.9936 (2.9936) Lpmap_backend: 11.0585 (11.0585) Ltrack_backend: 0.0000 (0.0000) total: 4202.3013 (4202.3013) time: 104.6228 data: 26.5225 max mem: 37991 +[2026-05-02 22:38:52,688][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 10/1087] eta: 17:29:11 lr: 0.000000 epoch: 0.0046 (0.0046) step: 5.0000 (5.0000) loss: 4242.6206 (3971.1784) Lcamera_frontend: 3.3956 (3.1482) Ldepth_frontend: 3.8659 (4.7507) Lpmap_frontend: 11.4682 (11.4307) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.3859 (3.1407) Ldepth_mix: 3.8547 (4.7488) Lpmap_mix: 11.4647 (11.4261) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.3970 (3.1482) Ldepth_backend: 3.8452 (4.7473) Lpmap_backend: 11.4647 (11.4284) Ltrack_backend: 0.0000 (0.0000) total: 4242.6206 (3971.1784) time: 58.4508 data: 2.4434 max mem: 78413 +[2026-05-02 22:48:07,197][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 20/1087] eta: 16:54:02 lr: 0.000000 epoch: 0.0092 (0.0092) step: 10.0000 (10.0000) loss: 3202.9844 (3467.7612) Lcamera_frontend: 2.5076 (2.7240) Ldepth_frontend: 4.4754 (5.0151) Lpmap_frontend: 11.6369 (11.6323) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.4981 (2.7169) Ldepth_mix: 4.4711 (5.0133) Lpmap_mix: 11.6346 (11.6265) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.5080 (2.7240) Ldepth_backend: 4.4647 
(5.0113) Lpmap_backend: 11.6393 (11.6276) Ltrack_backend: 0.0000 (0.0000) total: 3202.9844 (3467.7612) time: 54.6422 data: 0.0388 max mem: 78608 +[2026-05-02 22:57:15,910][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 30/1087] eta: 16:32:19 lr: 0.000001 epoch: 0.0184 (0.0138) step: 20.0000 (15.0000) loss: 3202.9844 (3465.4637) Lcamera_frontend: 2.5076 (2.7249) Ldepth_frontend: 4.0582 (4.8808) Lpmap_frontend: 11.6826 (11.4852) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.4981 (2.7183) Ldepth_mix: 4.0387 (4.8782) Lpmap_mix: 11.6702 (11.4789) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.5080 (2.7249) Ldepth_backend: 4.0108 (4.8754) Lpmap_backend: 11.6630 (11.4792) Ltrack_backend: 0.0000 (0.0000) total: 3202.9844 (3465.4637) time: 55.1610 data: 0.0417 max mem: 78608 +[2026-05-02 23:06:38,980][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 40/1087] eta: 16:22:50 lr: 0.000001 epoch: 0.0276 (0.0184) step: 30.0000 (20.0000) loss: 3448.9260 (3553.3549) Lcamera_frontend: 2.7252 (2.7964) Ldepth_frontend: 4.0274 (4.9665) Lpmap_frontend: 11.8865 (11.5849) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.7172 (2.7901) Ldepth_mix: 4.0149 (4.9630) Lpmap_mix: 11.8827 (11.5776) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.7249 (2.7962) Ldepth_backend: 4.0034 (4.9592) Lpmap_backend: 11.8838 (11.5771) Ltrack_backend: 0.0000 (0.0000) total: 3448.9260 (3553.3549) time: 55.5890 data: 0.0377 max mem: 78608 +[2026-05-02 23:15:51,240][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 50/1087] eta: 16:09:43 lr: 0.000001 epoch: 0.0368 (0.0230) step: 40.0000 (25.0000) loss: 3448.9260 (3730.1944) Lcamera_frontend: 2.7171 (2.9446) Ldepth_frontend: 4.3692 (4.8668) Lpmap_frontend: 11.9106 (11.5888) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.7163 (2.9386) Ldepth_mix: 4.3693 (4.8634) Lpmap_mix: 11.9057 (11.5817) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.7197 (2.9446) Ldepth_backend: 4.3642 (4.8598) Lpmap_backend: 11.9027 (11.5820) Ltrack_backend: 0.0000 (0.0000) total: 3448.9260 
(3730.1944) time: 55.7664 data: 0.0367 max mem: 78608 +[2026-05-02 23:24:59,529][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 60/1087] eta: 15:56:47 lr: 0.000001 epoch: 0.0460 (0.0276) step: 50.0000 (30.0000) loss: 3853.3931 (3919.1408) Lcamera_frontend: 3.0787 (3.1032) Ldepth_frontend: 3.7769 (4.7695) Lpmap_frontend: 11.8250 (11.5821) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.0701 (3.0965) Ldepth_mix: 3.7747 (4.7661) Lpmap_mix: 11.8170 (11.5752) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.0787 (3.1031) Ldepth_backend: 3.7724 (4.7627) Lpmap_backend: 11.8129 (11.5762) Ltrack_backend: 0.0000 (0.0000) total: 3853.3931 (3919.1408) time: 55.0273 data: 0.0364 max mem: 78608 +[2026-05-02 23:34:17,190][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 70/1087] eta: 15:47:09 lr: 0.000001 epoch: 0.0552 (0.0322) step: 60.0000 (35.0000) loss: 4532.8169 (3989.7981) Lcamera_frontend: 3.6279 (3.1625) Ldepth_frontend: 3.4384 (4.6902) Lpmap_frontend: 11.7723 (11.6082) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.6180 (3.1558) Ldepth_mix: 3.4318 (4.6861) Lpmap_mix: 11.7678 (11.6008) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.6298 (3.1625) Ldepth_backend: 3.4263 (4.6823) Lpmap_backend: 11.7694 (11.6018) Ltrack_backend: 0.0000 (0.0000) total: 4532.8169 (3989.7981) time: 55.2964 data: 0.0360 max mem: 78608 +[2026-05-02 23:43:31,213][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 80/1087] eta: 15:36:51 lr: 0.000001 epoch: 0.0644 (0.0368) step: 70.0000 (40.0000) loss: 3389.3914 (3806.4857) Lcamera_frontend: 2.6853 (3.0091) Ldepth_frontend: 3.9082 (4.7790) Lpmap_frontend: 11.4536 (11.5821) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.6737 (3.0025) Ldepth_mix: 3.9067 (4.7758) Lpmap_mix: 11.4464 (11.5746) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.6854 (3.0091) Ldepth_backend: 3.9059 (4.7728) Lpmap_backend: 11.4431 (11.5753) Ltrack_backend: 0.0000 (0.0000) total: 3389.3914 (3806.4857) time: 55.5826 data: 0.0368 max mem: 78608 +[2026-05-02 
23:52:50,237][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 90/1087] eta: 15:27:41 lr: 0.000002 epoch: 0.0736 (0.0414) step: 80.0000 (45.0000) loss: 3389.3914 (3999.8356) Lcamera_frontend: 2.6853 (3.1719) Ldepth_frontend: 3.8950 (4.6448) Lpmap_frontend: 11.4536 (11.5627) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.6737 (3.1649) Ldepth_mix: 3.8903 (4.6412) Lpmap_mix: 11.4464 (11.5551) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.6854 (3.1718) Ldepth_backend: 3.8884 (4.6377) Lpmap_backend: 11.4431 (11.5559) Ltrack_backend: 0.0000 (0.0000) total: 3389.3914 (3999.8356) time: 55.6517 data: 0.0346 max mem: 78608 +[2026-05-03 00:02:07,619][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 100/1087] eta: 15:18:14 lr: 0.000002 epoch: 0.0828 (0.0460) step: 90.0000 (49.9901) loss: 3579.2693 (3912.0880) Lcamera_frontend: 2.8308 (3.0985) Ldepth_frontend: 3.8950 (4.6963) Lpmap_frontend: 11.3769 (11.5519) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.8226 (3.0910) Ldepth_mix: 3.8903 (4.6922) Lpmap_mix: 11.3684 (11.5436) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.8261 (3.0983) Ldepth_backend: 3.8884 (4.6884) Lpmap_backend: 11.3697 (11.5439) Ltrack_backend: 0.0000 (0.0000) total: 3579.2693 (3912.0880) time: 55.8202 data: 0.0385 max mem: 78608 +[2026-05-03 00:11:35,712][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 110/1087] eta: 15:10:23 lr: 0.000002 epoch: 0.0920 (0.0506) step: 100.0000 (54.9910) loss: 2613.1426 (3837.1782) Lcamera_frontend: 2.0297 (3.0364) Ldepth_frontend: 4.1661 (4.7066) Lpmap_frontend: 11.3587 (11.5084) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.0106 (3.0286) Ldepth_mix: 4.1581 (4.7030) Lpmap_mix: 11.3454 (11.5001) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.0296 (3.0363) Ldepth_backend: 4.1517 (4.6994) Lpmap_backend: 11.3423 (11.5001) Ltrack_backend: 0.0000 (0.0000) total: 2613.1426 (3837.1782) time: 56.2736 data: 0.0437 max mem: 78608 +[2026-05-03 00:20:58,624][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 120/1087] eta: 15:01:34 lr: 
0.000002 epoch: 0.1012 (0.0552) step: 110.0000 (59.9917) loss: 2613.1426 (3752.6934) Lcamera_frontend: 2.0297 (2.9658) Ldepth_frontend: 4.3324 (4.7646) Lpmap_frontend: 11.3587 (11.4705) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.0106 (2.9580) Ldepth_mix: 4.3270 (4.7614) Lpmap_mix: 11.3454 (11.4620) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.0296 (2.9657) Ldepth_backend: 4.3216 (4.7582) Lpmap_backend: 11.3423 (11.4619) Ltrack_backend: 0.0000 (0.0000) total: 2613.1426 (3752.6934) time: 56.5501 data: 0.0401 max mem: 78608 +[2026-05-03 00:30:12,133][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 130/1087] eta: 14:51:32 lr: 0.000002 epoch: 0.1104 (0.0598) step: 120.0000 (64.9924) loss: 2144.1643 (3774.6720) Lcamera_frontend: 1.6416 (2.9835) Ldepth_frontend: 5.1765 (4.8166) Lpmap_frontend: 11.8014 (11.4739) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.6108 (2.9756) Ldepth_mix: 5.1847 (4.8138) Lpmap_mix: 11.7932 (11.4655) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.6411 (2.9834) Ldepth_backend: 5.1862 (4.8108) Lpmap_backend: 11.7894 (11.4657) Ltrack_backend: 0.0000 (0.0000) total: 2144.1643 (3774.6720) time: 55.8203 data: 0.0341 max mem: 78608 +[2026-05-03 00:39:25,201][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 140/1087] eta: 14:41:33 lr: 0.000003 epoch: 0.1196 (0.0644) step: 130.0000 (69.9929) loss: 4428.3130 (3811.5716) Lcamera_frontend: 3.5491 (3.0150) Ldepth_frontend: 3.5344 (4.7627) Lpmap_frontend: 11.5238 (11.4595) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.5423 (3.0068) Ldepth_mix: 3.5241 (4.7599) Lpmap_mix: 11.5170 (11.4512) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.5484 (3.0149) Ldepth_backend: 3.5188 (4.7569) Lpmap_backend: 11.5258 (11.4517) Ltrack_backend: 0.0000 (0.0000) total: 4428.3130 (3811.5716) time: 55.3273 data: 0.0338 max mem: 78608 +[2026-05-03 00:48:29,854][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 150/1087] eta: 14:30:48 lr: 0.000003 epoch: 0.1288 (0.0690) step: 140.0000 (74.9934) loss: 4079.4248 (3840.8996) 
Lcamera_frontend: 3.2705 (3.0399) Ldepth_frontend: 3.7790 (4.7508) Lpmap_frontend: 11.3197 (11.4240) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.2565 (3.0313) Ldepth_mix: 3.7716 (4.7482) Lpmap_mix: 11.3102 (11.4155) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.2702 (3.0398) Ldepth_backend: 3.7647 (4.7454) Lpmap_backend: 11.3096 (11.4160) Ltrack_backend: 0.0000 (0.0000) total: 4079.4248 (3840.8996) time: 54.8851 data: 0.0448 max mem: 78608 +[2026-05-03 00:57:45,911][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 160/1087] eta: 14:21:22 lr: 0.000003 epoch: 0.1380 (0.0736) step: 150.0000 (79.9876) loss: 3706.1357 (3838.3399) Lcamera_frontend: 2.9012 (3.0382) Ldepth_frontend: 3.9068 (4.7365) Lpmap_frontend: 11.0957 (11.3986) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.8827 (3.0295) Ldepth_mix: 3.9063 (4.7341) Lpmap_mix: 11.0950 (11.3902) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.9018 (3.0381) Ldepth_backend: 3.9018 (4.7313) Lpmap_backend: 11.0915 (11.3909) Ltrack_backend: 0.0000 (0.0000) total: 3706.1357 (3838.3399) time: 55.0354 data: 0.0449 max mem: 78608 +[2026-05-03 01:07:11,802][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 170/1087] eta: 14:12:49 lr: 0.000003 epoch: 0.1472 (0.0782) step: 160.0000 (84.9883) loss: 2894.3047 (3769.6565) Lcamera_frontend: 2.2870 (2.9803) Ldepth_frontend: 4.2580 (4.7964) Lpmap_frontend: 11.3292 (11.4032) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.2753 (2.9717) Ldepth_mix: 4.2547 (4.7936) Lpmap_mix: 11.3332 (11.3945) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.2859 (2.9802) Ldepth_backend: 4.2554 (4.7906) Lpmap_backend: 11.3373 (11.3947) Ltrack_backend: 0.0000 (0.0000) total: 2894.3047 (3769.6565) time: 56.0973 data: 0.0360 max mem: 78608 +[2026-05-03 01:16:32,573][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 180/1087] eta: 14:03:45 lr: 0.000003 epoch: 0.1564 (0.0828) step: 170.0000 (89.9890) loss: 2615.0669 (3713.9764) Lcamera_frontend: 2.0345 (2.9336) Ldepth_frontend: 4.8844 (4.8284) Lpmap_frontend: 
11.5031 (11.4055) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.0182 (2.9249) Ldepth_mix: 4.8753 (4.8257) Lpmap_mix: 11.4892 (11.3968) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.0357 (2.9335) Ldepth_backend: 4.8663 (4.8227) Lpmap_backend: 11.4933 (11.3970) Ltrack_backend: 0.0000 (0.0000) total: 2615.0669 (3713.9764) time: 56.3330 data: 0.0386 max mem: 78608 +[2026-05-03 01:25:41,765][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 190/1087] eta: 13:53:45 lr: 0.000003 epoch: 0.1656 (0.0874) step: 180.0000 (94.9895) loss: 3651.0952 (3756.0374) Lcamera_frontend: 2.8870 (2.9682) Ldepth_frontend: 4.3808 (4.8599) Lpmap_frontend: 11.4447 (11.4136) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.8607 (2.9592) Ldepth_mix: 4.3745 (4.8576) Lpmap_mix: 11.4449 (11.4050) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.8858 (2.9682) Ldepth_backend: 4.3690 (4.8550) Lpmap_backend: 11.4457 (11.4054) Ltrack_backend: 0.0000 (0.0000) total: 3651.0952 (3756.0374) time: 55.4980 data: 0.0370 max mem: 78608 +[2026-05-03 01:34:53,678][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 200/1087] eta: 13:44:01 lr: 0.000004 epoch: 0.1748 (0.0920) step: 190.0000 (99.9851) loss: 4310.5767 (3799.8757) Lcamera_frontend: 3.4209 (3.0058) Ldepth_frontend: 3.5405 (4.7962) Lpmap_frontend: 11.0155 (11.3732) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.4082 (2.9966) Ldepth_mix: 3.5322 (4.7936) Lpmap_mix: 11.0036 (11.3645) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.4207 (3.0057) Ldepth_backend: 3.5222 (4.7908) Lpmap_backend: 11.0105 (11.3651) Ltrack_backend: 0.0000 (0.0000) total: 4310.5767 (3799.8757) time: 55.0541 data: 0.0353 max mem: 78608 +[2026-05-03 01:44:10,757][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 210/1087] eta: 13:34:43 lr: 0.000004 epoch: 0.1840 (0.0966) step: 200.0000 (104.9858) loss: 4329.0322 (3826.5264) Lcamera_frontend: 3.4749 (3.0278) Ldepth_frontend: 4.2057 (4.8206) Lpmap_frontend: 11.3225 (11.3813) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.4592 (3.0184) 
Ldepth_mix: 4.1959 (4.8180) Lpmap_mix: 11.3116 (11.3726) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.4761 (3.0276) Ldepth_backend: 4.1878 (4.8154) Lpmap_backend: 11.3086 (11.3733) Ltrack_backend: 0.0000 (0.0000) total: 4329.0322 (3826.5264) time: 55.4476 data: 0.0379 max mem: 78608 +[2026-05-03 01:53:38,674][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 220/1087] eta: 13:26:06 lr: 0.000004 epoch: 0.1932 (0.1012) step: 210.0000 (109.9864) loss: 2949.4153 (3767.0713) Lcamera_frontend: 2.2959 (2.9779) Ldepth_frontend: 5.0282 (4.8638) Lpmap_frontend: 11.3225 (11.3650) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.2833 (2.9686) Ldepth_mix: 5.0255 (4.8616) Lpmap_mix: 11.3116 (11.3561) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.2960 (2.9778) Ldepth_backend: 5.0224 (4.8594) Lpmap_backend: 11.3086 (11.3567) Ltrack_backend: 0.0000 (0.0000) total: 2949.4153 (3767.0713) time: 56.2488 data: 0.0409 max mem: 78608 +[2026-05-03 02:02:50,649][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 230/1087] eta: 13:16:27 lr: 0.000004 epoch: 0.2024 (0.1058) step: 220.0000 (114.9870) loss: 1227.3260 (3724.5507) Lcamera_frontend: 0.8765 (2.9426) Ldepth_frontend: 5.5676 (4.8908) Lpmap_frontend: 11.0429 (11.3324) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 0.8432 (2.9331) Ldepth_mix: 5.5706 (4.8890) Lpmap_mix: 11.0351 (11.3234) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 0.8766 (2.9424) Ldepth_backend: 5.5732 (4.8873) Lpmap_backend: 11.0334 (11.3240) Ltrack_backend: 0.0000 (0.0000) total: 1227.3260 (3724.5507) time: 55.9945 data: 0.0396 max mem: 78608 +[2026-05-03 02:12:12,330][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 240/1087] eta: 13:07:23 lr: 0.000004 epoch: 0.2116 (0.1104) step: 230.0000 (119.9834) loss: 3145.9524 (3707.0761) Lcamera_frontend: 2.4927 (2.9282) Ldepth_frontend: 4.0071 (4.8953) Lpmap_frontend: 10.8839 (11.3124) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.4697 (2.9186) Ldepth_mix: 3.9991 (4.8937) Lpmap_mix: 10.8632 (11.3032) Ltrack_mix: 0.0000 (0.0000) 
Lcamera_backend: 2.4925 (2.9280) Ldepth_backend: 3.9916 (4.8922) Lpmap_backend: 10.8596 (11.3039) Ltrack_backend: 0.0000 (0.0000) total: 3145.9524 (3707.0761) time: 55.6827 data: 0.0365 max mem: 78608 +[2026-05-03 02:21:22,385][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 250/1087] eta: 12:57:40 lr: 0.000005 epoch: 0.2208 (0.1150) step: 240.0000 (124.9841) loss: 3360.2373 (3679.5271) Lcamera_frontend: 2.6443 (2.9057) Ldepth_frontend: 4.0071 (4.8888) Lpmap_frontend: 10.6091 (11.2775) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.6377 (2.8956) Ldepth_mix: 3.9991 (4.8873) Lpmap_mix: 10.5963 (11.2681) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.6462 (2.9055) Ldepth_backend: 3.9916 (4.8858) Lpmap_backend: 10.5966 (11.2687) Ltrack_backend: 0.0000 (0.0000) total: 3360.2373 (3679.5271) time: 55.5867 data: 0.0344 max mem: 78608 +[2026-05-03 02:30:44,932][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 260/1087] eta: 12:48:38 lr: 0.000005 epoch: 0.2300 (0.1196) step: 250.0000 (129.9847) loss: 3391.5752 (3716.7938) Lcamera_frontend: 2.6879 (2.9371) Ldepth_frontend: 4.1649 (4.8764) Lpmap_frontend: 10.5907 (11.2571) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.6734 (2.9268) Ldepth_mix: 4.1562 (4.8750) Lpmap_mix: 10.5793 (11.2476) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.6880 (2.9369) Ldepth_backend: 4.1492 (4.8736) Lpmap_backend: 10.5772 (11.2484) Ltrack_backend: 0.0000 (0.0000) total: 3391.5752 (3716.7938) time: 55.6291 data: 0.0358 max mem: 78608 +[2026-05-03 02:40:01,223][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 270/1087] eta: 12:39:17 lr: 0.000005 epoch: 0.2392 (0.1242) step: 260.0000 (134.9852) loss: 3615.5190 (3707.4362) Lcamera_frontend: 2.8830 (2.9292) Ldepth_frontend: 4.0986 (4.8942) Lpmap_frontend: 11.0972 (11.2496) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.8654 (2.9189) Ldepth_mix: 4.0911 (4.8929) Lpmap_mix: 11.0977 (11.2401) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.8833 (2.9290) Ldepth_backend: 4.0850 (4.8917) Lpmap_backend: 
11.0993 (11.2409) Ltrack_backend: 0.0000 (0.0000) total: 3615.5190 (3707.4362) time: 55.9403 data: 0.0371 max mem: 78608 +[2026-05-03 02:49:14,507][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 280/1087] eta: 12:29:46 lr: 0.000005 epoch: 0.2484 (0.1288) step: 270.0000 (139.9822) loss: 4492.9077 (3771.1266) Lcamera_frontend: 3.6113 (2.9824) Ldepth_frontend: 4.0747 (4.8840) Lpmap_frontend: 11.4186 (11.2470) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.6015 (2.9717) Ldepth_mix: 4.0737 (4.8827) Lpmap_mix: 11.4035 (11.2374) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.6116 (2.9823) Ldepth_backend: 4.0687 (4.8815) Lpmap_backend: 11.4096 (11.2384) Ltrack_backend: 0.0000 (0.0000) total: 4492.9077 (3771.1266) time: 55.4780 data: 0.0360 max mem: 78608 +[2026-05-03 02:58:21,937][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 290/1087] eta: 12:20:02 lr: 0.000005 epoch: 0.2576 (0.1334) step: 280.0000 (144.9828) loss: 4368.5010 (3761.0733) Lcamera_frontend: 3.4531 (2.9740) Ldepth_frontend: 4.2899 (4.9012) Lpmap_frontend: 11.3882 (11.2331) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.4088 (2.9629) Ldepth_mix: 4.2923 (4.9000) Lpmap_mix: 11.3673 (11.2232) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.4621 (2.9739) Ldepth_backend: 4.2874 (4.8989) Lpmap_backend: 11.3639 (11.2242) Ltrack_backend: 0.0000 (0.0000) total: 4368.5010 (3761.0733) time: 55.0356 data: 0.0364 max mem: 78608 +[2026-05-03 03:07:37,235][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 300/1087] eta: 12:10:40 lr: 0.000006 epoch: 0.2668 (0.1380) step: 290.0000 (149.9834) loss: 3446.0640 (3727.2401) Lcamera_frontend: 2.7435 (2.9460) Ldepth_frontend: 4.3725 (4.9154) Lpmap_frontend: 10.6490 (11.2016) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.7125 (2.9346) Ldepth_mix: 4.3465 (4.9148) Lpmap_mix: 10.6345 (11.1915) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.7433 (2.9459) Ldepth_backend: 4.3297 (4.9141) Lpmap_backend: 10.6376 (11.1924) Ltrack_backend: 0.0000 (0.0000) total: 3446.0640 (3727.2401) time: 
55.1362 data: 0.0364 max mem: 78608 +[2026-05-03 03:16:36,477][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 310/1087] eta: 12:00:38 lr: 0.000006 epoch: 0.2760 (0.1426) step: 300.0000 (154.9839) loss: 1748.7461 (3688.8335) Lcamera_frontend: 1.2842 (2.9140) Ldepth_frontend: 5.3224 (4.9373) Lpmap_frontend: 10.1626 (11.1831) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.2635 (2.9026) Ldepth_mix: 5.3156 (4.9370) Lpmap_mix: 10.1451 (11.1727) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.2813 (2.9138) Ldepth_backend: 5.3142 (4.9367) Lpmap_backend: 10.1451 (11.1734) Ltrack_backend: 0.0000 (0.0000) total: 1748.7461 (3688.8335) time: 54.7269 data: 0.0373 max mem: 78608 +[2026-05-03 03:25:45,852][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 320/1087] eta: 11:51:05 lr: 0.000006 epoch: 0.2852 (0.1472) step: 310.0000 (159.9813) loss: 1933.6400 (3667.4677) Lcamera_frontend: 1.4398 (2.8960) Ldepth_frontend: 5.7850 (4.9677) Lpmap_frontend: 11.0741 (11.1793) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.4133 (2.8841) Ldepth_mix: 5.7826 (4.9677) Lpmap_mix: 11.0659 (11.1688) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.4415 (2.8958) Ldepth_backend: 5.7786 (4.9679) Lpmap_backend: 11.0489 (11.1693) Ltrack_backend: 0.0000 (0.0000) total: 1933.6400 (3667.4677) time: 54.4297 data: 0.0381 max mem: 78608 +[2026-05-03 03:35:00,256][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 330/1087] eta: 11:41:44 lr: 0.000006 epoch: 0.2944 (0.1518) step: 320.0000 (164.9819) loss: 1478.3679 (3603.7438) Lcamera_frontend: 1.0652 (2.8427) Ldepth_frontend: 6.5051 (5.0058) Lpmap_frontend: 11.0741 (11.1597) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.0470 (2.8310) Ldepth_mix: 6.5272 (5.0065) Lpmap_mix: 11.0659 (11.1489) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.0661 (2.8425) Ldepth_backend: 6.5682 (5.0076) Lpmap_backend: 11.0489 (11.1493) Ltrack_backend: 0.0000 (0.0000) total: 1478.3679 (3603.7438) time: 55.1857 data: 0.0413 max mem: 78608 +[2026-05-03 
03:44:16,370][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 340/1087] eta: 11:32:28 lr: 0.000006 epoch: 0.3036 (0.1564) step: 330.0000 (169.9824) loss: 660.3922 (3575.3925) Lcamera_frontend: 0.3779 (2.8193) Ldepth_frontend: 5.5826 (5.0136) Lpmap_frontend: 10.2357 (11.1348) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 0.3272 (2.8070) Ldepth_mix: 5.6053 (5.0146) Lpmap_mix: 10.2233 (11.1237) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 0.3828 (2.8191) Ldepth_backend: 5.6269 (5.0157) Lpmap_backend: 10.2307 (11.1241) Ltrack_backend: 0.0000 (0.0000) total: 660.3922 (3575.3925) time: 55.5237 data: 0.0430 max mem: 78608 +[2026-05-03 03:53:29,748][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 350/1087] eta: 11:23:05 lr: 0.000006 epoch: 0.3128 (0.1610) step: 340.0000 (174.9829) loss: 2127.1729 (3558.2403) Lcamera_frontend: 1.6259 (2.8049) Ldepth_frontend: 5.5155 (5.0409) Lpmap_frontend: 10.2984 (11.1208) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.5844 (2.7919) Ldepth_mix: 5.5179 (5.0421) Lpmap_mix: 10.2822 (11.1095) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.6279 (2.8047) Ldepth_backend: 5.5202 (5.0435) Lpmap_backend: 10.2864 (11.1101) Ltrack_backend: 0.0000 (0.0000) total: 2127.1729 (3558.2403) time: 55.4745 data: 0.0385 max mem: 78608 +[2026-05-03 04:02:45,484][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 360/1087] eta: 11:13:49 lr: 0.000007 epoch: 0.3220 (0.1656) step: 350.0000 (179.9806) loss: 3043.0459 (3554.0205) Lcamera_frontend: 2.3864 (2.8014) Ldepth_frontend: 5.2218 (5.0499) Lpmap_frontend: 10.9820 (11.1108) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.2985 (2.7875) Ldepth_mix: 5.2499 (5.0512) Lpmap_mix: 10.9726 (11.0990) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.3887 (2.8013) Ldepth_backend: 5.2633 (5.0527) Lpmap_backend: 10.9631 (11.0997) Ltrack_backend: 0.0000 (0.0000) total: 3043.0459 (3554.0205) time: 55.4556 data: 0.0354 max mem: 78608 +[2026-05-03 04:11:56,777][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 370/1087] eta: 11:04:23 
lr: 0.000007 epoch: 0.3312 (0.1702) step: 360.0000 (184.9811) loss: 3459.4954 (3548.7099) Lcamera_frontend: 2.7502 (2.7974) Ldepth_frontend: 4.5297 (5.0422) Lpmap_frontend: 10.3133 (11.0893) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.6680 (2.7823) Ldepth_mix: 4.5317 (5.0437) Lpmap_mix: 10.2952 (11.0773) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.7496 (2.7973) Ldepth_backend: 4.5328 (5.0452) Lpmap_backend: 10.3024 (11.0782) Ltrack_backend: 0.0000 (0.0000) total: 3459.4954 (3548.7099) time: 55.3513 data: 0.0384 max mem: 78608 +[2026-05-03 04:21:04,067][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 380/1087] eta: 10:54:51 lr: 0.000007 epoch: 0.3404 (0.1748) step: 370.0000 (189.9816) loss: 3702.8796 (3581.1575) Lcamera_frontend: 2.9399 (2.8248) Ldepth_frontend: 4.1629 (5.0211) Lpmap_frontend: 10.4718 (11.0842) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.8749 (2.8083) Ldepth_mix: 4.1504 (5.0225) Lpmap_mix: 10.4454 (11.0719) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.9407 (2.8247) Ldepth_backend: 4.1369 (5.0237) Lpmap_backend: 10.4560 (11.0730) Ltrack_backend: 0.0000 (0.0000) total: 3702.8796 (3581.1575) time: 54.9282 data: 0.0367 max mem: 78608 +[2026-05-03 04:30:12,784][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 390/1087] eta: 10:45:23 lr: 0.000007 epoch: 0.3496 (0.1794) step: 380.0000 (194.9821) loss: 3815.2349 (3621.6174) Lcamera_frontend: 3.0557 (2.8583) Ldepth_frontend: 4.1017 (5.0338) Lpmap_frontend: 11.1255 (11.0971) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.9788 (2.8408) Ldepth_mix: 4.0889 (5.0351) Lpmap_mix: 11.1028 (11.0845) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.0442 (2.8582) Ldepth_backend: 4.0802 (5.0362) Lpmap_backend: 11.1291 (11.0859) Ltrack_backend: 0.0000 (0.0000) total: 3815.2349 (3621.6174) time: 54.7987 data: 0.0372 max mem: 78608 +[2026-05-03 04:39:29,897][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 400/1087] eta: 10:36:10 lr: 0.000007 epoch: 0.3588 (0.1840) step: 390.0000 (199.9800) loss: 3792.3845 
(3632.4914) Lcamera_frontend: 2.9839 (2.8680) Ldepth_frontend: 4.1017 (5.0110) Lpmap_frontend: 10.8711 (11.0800) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.7402 (2.8488) Ldepth_mix: 4.0889 (5.0122) Lpmap_mix: 10.8333 (11.0670) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.9997 (2.8678) Ldepth_backend: 4.0802 (5.0131) Lpmap_backend: 10.8577 (11.0691) Ltrack_backend: 0.0000 (0.0000) total: 3792.3845 (3632.4914) time: 55.2907 data: 0.0400 max mem: 78608 +[2026-05-03 04:48:42,673][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 410/1087] eta: 10:26:50 lr: 0.000008 epoch: 0.3680 (0.1886) step: 400.0000 (204.9805) loss: 3652.0527 (3621.1971) Lcamera_frontend: 2.9191 (2.8586) Ldepth_frontend: 4.2692 (5.0128) Lpmap_frontend: 10.6573 (11.0750) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.8159 (2.8384) Ldepth_mix: 4.2618 (5.0139) Lpmap_mix: 10.5971 (11.0612) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.9032 (2.8586) Ldepth_backend: 4.2422 (5.0146) Lpmap_backend: 10.6582 (11.0634) Ltrack_backend: 0.0000 (0.0000) total: 3652.0527 (3621.1971) time: 55.4943 data: 0.0404 max mem: 78608 +[2026-05-03 04:57:58,689][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 420/1087] eta: 10:17:35 lr: 0.000008 epoch: 0.3772 (0.1932) step: 410.0000 (209.9810) loss: 3599.6470 (3658.2743) Lcamera_frontend: 2.8657 (2.8906) Ldepth_frontend: 4.1086 (4.9857) Lpmap_frontend: 11.1647 (11.0617) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.8159 (2.8664) Ldepth_mix: 4.1172 (4.9868) Lpmap_mix: 11.0853 (11.0471) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.8670 (2.8901) Ldepth_backend: 4.1287 (4.9875) Lpmap_backend: 11.1357 (11.0503) Ltrack_backend: 0.0000 (0.0000) total: 3599.6470 (3658.2743) time: 55.4395 data: 0.0508 max mem: 78608 +[2026-05-03 05:07:12,491][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 430/1087] eta: 10:08:17 lr: 0.000008 epoch: 0.3864 (0.1978) step: 420.0000 (214.9814) loss: 4169.3975 (3674.5172) Lcamera_frontend: 3.3739 (2.9049) Ldepth_frontend: 4.1068 (4.9837) 
Lpmap_frontend: 10.4016 (11.0568) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.2131 (2.8778) Ldepth_mix: 4.0997 (4.9850) Lpmap_mix: 10.2944 (11.0410) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.3480 (2.9039) Ldepth_backend: 4.0674 (4.9860) Lpmap_backend: 10.3913 (11.0453) Ltrack_backend: 0.0000 (0.0000) total: 4169.3975 (3674.5172) time: 55.4907 data: 0.0462 max mem: 78608 +[2026-05-03 05:16:38,261][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 440/1087] eta: 9:59:16 lr: 0.000008 epoch: 0.3956 (0.2024) step: 430.0000 (219.9796) loss: 3696.4048 (3699.8234) Lcamera_frontend: 2.9970 (2.9271) Ldepth_frontend: 4.1289 (4.9705) Lpmap_frontend: 11.1855 (11.0608) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.7336 (2.8954) Ldepth_mix: 4.1181 (4.9718) Lpmap_mix: 11.1081 (11.0446) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.9742 (2.9253) Ldepth_backend: 4.1189 (4.9727) Lpmap_backend: 11.1666 (11.0495) Ltrack_backend: 0.0000 (0.0000) total: 3696.4048 (3699.8234) time: 55.9778 data: 0.0363 max mem: 78608 +[2026-05-03 05:25:56,064][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 450/1087] eta: 9:50:03 lr: 0.000008 epoch: 0.4048 (0.2070) step: 440.0000 (224.9800) loss: 3557.5820 (3685.2850) Lcamera_frontend: 2.9234 (2.9174) Ldepth_frontend: 4.1712 (4.9652) Lpmap_frontend: 10.5566 (11.0486) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.5110 (2.8790) Ldepth_mix: 4.1717 (4.9663) Lpmap_mix: 10.5271 (11.0313) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.8505 (2.9136) Ldepth_backend: 4.1736 (4.9670) Lpmap_backend: 10.5695 (11.0372) Ltrack_backend: 0.0000 (0.0000) total: 3557.5820 (3685.2850) time: 56.1767 data: 0.0383 max mem: 78608 +[2026-05-03 05:35:13,979][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 460/1087] eta: 9:40:50 lr: 0.000008 epoch: 0.4140 (0.2116) step: 450.0000 (229.9805) loss: 3480.1802 (3715.1574) Lcamera_frontend: 2.9304 (2.9444) Ldepth_frontend: 4.2331 (4.9430) Lpmap_frontend: 10.6253 (11.0481) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.5110 
(2.8997) Ldepth_mix: 4.2412 (4.9438) Lpmap_mix: 10.5307 (11.0303) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.7599 (2.9390) Ldepth_backend: 4.2408 (4.9443) Lpmap_backend: 10.6318 (11.0370) Ltrack_backend: 0.0000 (0.0000) total: 3480.1802 (3715.1574) time: 55.7847 data: 0.0405 max mem: 78608 +[2026-05-03 05:44:35,874][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 470/1087] eta: 9:31:42 lr: 0.000009 epoch: 0.4232 (0.2162) step: 460.0000 (234.9809) loss: 3422.6602 (3696.1356) Lcamera_frontend: 2.8272 (2.9314) Ldepth_frontend: 4.5441 (4.9593) Lpmap_frontend: 10.7407 (11.0491) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.4745 (2.8820) Ldepth_mix: 4.5294 (4.9603) Lpmap_mix: 10.7328 (11.0303) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.6872 (2.9228) Ldepth_backend: 4.4900 (4.9611) Lpmap_backend: 10.7370 (11.0375) Ltrack_backend: 0.0000 (0.0000) total: 3422.6602 (3696.1356) time: 55.9904 data: 0.0372 max mem: 78608 +[2026-05-03 05:53:57,443][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 480/1087] eta: 9:22:33 lr: 0.000009 epoch: 0.4324 (0.2208) step: 470.0000 (239.9792) loss: 1953.1704 (3685.9747) Lcamera_frontend: 1.6280 (2.9252) Ldepth_frontend: 4.7953 (4.9523) Lpmap_frontend: 11.3316 (11.0568) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.2892 (2.8721) Ldepth_mix: 4.8133 (4.9534) Lpmap_mix: 11.2683 (11.0365) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.4712 (2.9143) Ldepth_backend: 4.8394 (4.9541) Lpmap_backend: 11.3175 (11.0445) Ltrack_backend: 0.0000 (0.0000) total: 1953.1704 (3685.9747) time: 56.1731 data: 0.0351 max mem: 78608 +[2026-05-03 06:03:23,669][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 490/1087] eta: 9:13:30 lr: 0.000009 epoch: 0.4416 (0.2254) step: 480.0000 (244.9796) loss: 3448.4893 (3715.0480) Lcamera_frontend: 2.8162 (2.9481) Ldepth_frontend: 4.0748 (4.9510) Lpmap_frontend: 11.5225 (11.0709) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.4138 (2.8971) Ldepth_mix: 4.0684 (4.9520) Lpmap_mix: 11.3160 (11.0501) Ltrack_mix: 0.0000 
(0.0000) Lcamera_backend: 2.7307 (2.9384) Ldepth_backend: 4.0582 (4.9524) Lpmap_backend: 11.3917 (11.0582) Ltrack_backend: 0.0000 (0.0000) total: 3448.4893 (3715.0480) time: 56.3896 data: 0.0380 max mem: 78608 +[2026-05-03 06:12:44,873][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 500/1087] eta: 9:04:19 lr: 0.000009 epoch: 0.4508 (0.2300) step: 490.0000 (249.9800) loss: 3181.1260 (3704.2903) Lcamera_frontend: 2.7117 (2.9441) Ldepth_frontend: 4.3194 (4.9549) Lpmap_frontend: 11.1899 (11.0778) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.3901 (2.8867) Ldepth_mix: 4.2974 (4.9559) Lpmap_mix: 11.1532 (11.0555) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.4224 (2.9290) Ldepth_backend: 4.2664 (4.9563) Lpmap_backend: 11.1807 (11.0642) Ltrack_backend: 0.0000 (0.0000) total: 3181.1260 (3704.2903) time: 56.3714 data: 0.0392 max mem: 78608 +[2026-05-03 06:22:10,560][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 510/1087] eta: 8:55:13 lr: 0.000009 epoch: 0.4600 (0.2346) step: 500.0000 (254.9804) loss: 2312.1357 (3664.4238) Lcamera_frontend: 1.6983 (2.9067) Ldepth_frontend: 5.5187 (4.9843) Lpmap_frontend: 11.2489 (11.0845) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.6342 (2.8571) Ldepth_mix: 5.5271 (4.9860) Lpmap_mix: 11.1145 (11.0610) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.8280 (2.8953) Ldepth_backend: 5.5356 (4.9872) Lpmap_backend: 11.2275 (11.0709) Ltrack_backend: 0.0000 (0.0000) total: 2312.1357 (3664.4238) time: 56.3433 data: 0.0392 max mem: 78608 +[2026-05-03 06:31:23,143][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 520/1087] eta: 8:45:52 lr: 0.000010 epoch: 0.4692 (0.2392) step: 510.0000 (259.9808) loss: 2312.1357 (3680.6249) Lcamera_frontend: 1.7494 (2.9227) Ldepth_frontend: 5.5370 (4.9785) Lpmap_frontend: 11.5558 (11.0989) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.6342 (2.8691) Ldepth_mix: 5.5576 (4.9800) Lpmap_mix: 11.4531 (11.0744) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.8280 (2.9087) Ldepth_backend: 5.5728 (4.9808) Lpmap_backend: 
11.4899 (11.0847) Ltrack_backend: 0.0000 (0.0000) total: 2312.1357 (3680.6249) time: 55.9113 data: 0.0349 max mem: 78608 +[2026-05-03 06:40:36,931][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 530/1087] eta: 8:36:33 lr: 0.000010 epoch: 0.4784 (0.2438) step: 520.0000 (264.9812) loss: 3759.7468 (3690.1346) Lcamera_frontend: 3.2916 (2.9316) Ldepth_frontend: 3.9535 (4.9874) Lpmap_frontend: 11.9431 (11.1189) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.8017 (2.8779) Ldepth_mix: 3.9413 (4.9890) Lpmap_mix: 11.8807 (11.0935) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.9341 (2.9161) Ldepth_backend: 3.9288 (4.9899) Lpmap_backend: 11.8630 (11.1037) Ltrack_backend: 0.0000 (0.0000) total: 3759.7468 (3690.1346) time: 55.3174 data: 0.0345 max mem: 78608 +[2026-05-03 06:49:38,804][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 540/1087] eta: 8:27:02 lr: 0.000010 epoch: 0.4876 (0.2484) step: 530.0000 (269.9815) loss: 3041.6226 (3678.3606) Lcamera_frontend: 2.7298 (2.9239) Ldepth_frontend: 4.7545 (4.9937) Lpmap_frontend: 11.9619 (11.1341) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.3337 (2.8681) Ldepth_mix: 4.7512 (4.9953) Lpmap_mix: 11.8882 (11.1076) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.3449 (2.9058) Ldepth_backend: 4.7452 (4.9960) Lpmap_backend: 11.8855 (11.1177) Ltrack_backend: 0.0000 (0.0000) total: 3041.6226 (3678.3606) time: 54.7829 data: 0.0372 max mem: 78608 +[2026-05-03 06:58:46,010][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 550/1087] eta: 8:17:37 lr: 0.000010 epoch: 0.4968 (0.2530) step: 540.0000 (274.9819) loss: 3157.6709 (3695.2321) Lcamera_frontend: 2.8859 (2.9411) Ldepth_frontend: 4.2761 (4.9836) Lpmap_frontend: 11.9345 (11.1517) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.3596 (2.8817) Ldepth_mix: 4.2809 (4.9853) Lpmap_mix: 11.8205 (11.1240) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.4089 (2.9195) Ldepth_backend: 4.2791 (4.9860) Lpmap_backend: 11.8111 (11.1343) Ltrack_backend: 0.0000 (0.0000) total: 3157.6709 (3695.2321) time: 
54.4538 data: 0.0362 max mem: 78608 +[2026-05-03 07:07:59,114][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 560/1087] eta: 8:08:18 lr: 0.000010 epoch: 0.5060 (0.2576) step: 550.0000 (279.9804) loss: 3299.5527 (3685.7435) Lcamera_frontend: 2.8859 (2.9342) Ldepth_frontend: 4.2133 (4.9895) Lpmap_frontend: 12.1331 (11.1706) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.5620 (2.8742) Ldepth_mix: 4.2019 (4.9912) Lpmap_mix: 12.0707 (11.1417) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.5655 (2.9112) Ldepth_backend: 4.2065 (4.9920) Lpmap_backend: 12.0808 (11.1521) Ltrack_backend: 0.0000 (0.0000) total: 3299.5527 (3685.7435) time: 55.0154 data: 0.0348 max mem: 78608 +[2026-05-03 07:17:13,870][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 570/1087] eta: 7:59:01 lr: 0.000010 epoch: 0.5152 (0.2622) step: 560.0000 (284.9807) loss: 2793.5979 (3674.3886) Lcamera_frontend: 2.5606 (2.9278) Ldepth_frontend: 4.4531 (4.9938) Lpmap_frontend: 12.1361 (11.1882) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.9259 (2.8630) Ldepth_mix: 4.4475 (4.9955) Lpmap_mix: 12.0357 (11.1578) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.1464 (2.9013) Ldepth_backend: 4.4535 (4.9964) Lpmap_backend: 12.0743 (11.1686) Ltrack_backend: 0.0000 (0.0000) total: 2793.5979 (3674.3886) time: 55.3924 data: 0.0518 max mem: 78608 +[2026-05-03 07:26:34,637][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 580/1087] eta: 7:49:50 lr: 0.000010 epoch: 0.5244 (0.2668) step: 570.0000 (289.9811) loss: 2793.5979 (3664.5630) Lcamera_frontend: 2.5606 (2.9213) Ldepth_frontend: 4.7775 (5.0027) Lpmap_frontend: 12.1265 (11.2043) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.9259 (2.8544) Ldepth_mix: 4.7610 (5.0045) Lpmap_mix: 12.0290 (11.1729) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.1464 (2.8927) Ldepth_backend: 4.7341 (5.0055) Lpmap_backend: 12.0366 (11.1838) Ltrack_backend: 0.0000 (0.0000) total: 2793.5979 (3664.5630) time: 55.7751 data: 0.0523 max mem: 78608 +[2026-05-03 07:35:47,292][croco.utils.misc][INFO] - 
[RANK 0] Epoch: [0] [ 590/1087] eta: 7:40:31 lr: 0.000010 epoch: 0.5336 (0.2714) step: 580.0000 (294.9814) loss: 3078.1255 (3672.0609) Lcamera_frontend: 2.6543 (2.9290) Ldepth_frontend: 5.1508 (5.0042) Lpmap_frontend: 12.1613 (11.2219) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.3590 (2.8611) Ldepth_mix: 5.1450 (5.0061) Lpmap_mix: 12.0374 (11.1892) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.3715 (2.8986) Ldepth_backend: 5.1472 (5.0072) Lpmap_backend: 12.0366 (11.2002) Ltrack_backend: 0.0000 (0.0000) total: 3078.1255 (3672.0609) time: 55.6706 data: 0.0362 max mem: 78608 +[2026-05-03 07:45:05,642][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 600/1087] eta: 7:31:17 lr: 0.000010 epoch: 0.5428 (0.2760) step: 590.0000 (299.9800) loss: 2420.8750 (3653.3033) Lcamera_frontend: 2.1832 (2.9175) Ldepth_frontend: 5.1508 (5.0129) Lpmap_frontend: 12.2512 (11.2361) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.6398 (2.8439) Ldepth_mix: 5.1450 (5.0148) Lpmap_mix: 12.0856 (11.2016) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.8001 (2.8824) Ldepth_backend: 5.1472 (5.0159) Lpmap_backend: 12.0868 (11.2128) Ltrack_backend: 0.0000 (0.0000) total: 2420.8750 (3653.3033) time: 55.5500 data: 0.0383 max mem: 78608 +[2026-05-03 07:54:26,088][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 610/1087] eta: 7:22:04 lr: 0.000010 epoch: 0.5520 (0.2806) step: 600.0000 (304.9804) loss: 1847.7075 (3619.7879) Lcamera_frontend: 1.7183 (2.8876) Ldepth_frontend: 6.1198 (5.0419) Lpmap_frontend: 12.1447 (11.2512) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.4303 (2.8188) Ldepth_mix: 6.1404 (5.0440) Lpmap_mix: 12.0819 (11.2154) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.4411 (2.8539) Ldepth_backend: 6.1645 (5.0453) Lpmap_backend: 12.0822 (11.2265) Ltrack_backend: 0.0000 (0.0000) total: 1847.7075 (3619.7879) time: 55.9397 data: 0.0387 max mem: 78608 +[2026-05-03 08:03:46,791][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 620/1087] eta: 7:12:52 lr: 0.000010 epoch: 0.5612 (0.2852) step: 
610.0000 (309.9807) loss: 1764.5923 (3613.3664) Lcamera_frontend: 1.7420 (2.8863) Ldepth_frontend: 5.8825 (5.0355) Lpmap_frontend: 12.1447 (11.2645) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.3862 (2.8133) Ldepth_mix: 5.8885 (5.0375) Lpmap_mix: 12.1045 (11.2282) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.3519 (2.8481) Ldepth_backend: 5.8812 (5.0388) Lpmap_backend: 12.0942 (11.2392) Ltrack_backend: 0.0000 (0.0000) total: 1764.5923 (3613.3664) time: 56.0573 data: 0.0348 max mem: 78608 +[2026-05-03 08:13:07,153][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 630/1087] eta: 7:03:38 lr: 0.000010 epoch: 0.5704 (0.2898) step: 620.0000 (314.9810) loss: 3059.9080 (3603.2805) Lcamera_frontend: 2.6628 (2.8808) Ldepth_frontend: 4.4627 (5.0408) Lpmap_frontend: 12.1902 (11.2795) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.3353 (2.8049) Ldepth_mix: 4.4561 (5.0429) Lpmap_mix: 12.1202 (11.2425) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.3512 (2.8391) Ldepth_backend: 4.4338 (5.0442) Lpmap_backend: 12.1125 (11.2534) Ltrack_backend: 0.0000 (0.0000) total: 3059.9080 (3603.2805) time: 56.0523 data: 0.0357 max mem: 78608 +[2026-05-03 08:22:25,019][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 640/1087] eta: 6:54:23 lr: 0.000010 epoch: 0.5796 (0.2944) step: 630.0000 (319.9797) loss: 2445.0342 (3595.1443) Lcamera_frontend: 2.2690 (2.8786) Ldepth_frontend: 5.0677 (5.0415) Lpmap_frontend: 12.1176 (11.2916) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.8117 (2.7976) Ldepth_mix: 5.0777 (5.0436) Lpmap_mix: 12.0499 (11.2539) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.8377 (2.8318) Ldepth_backend: 5.0785 (5.0450) Lpmap_backend: 12.0557 (11.2650) Ltrack_backend: 0.0000 (0.0000) total: 2445.0342 (3595.1443) time: 55.9098 data: 0.0467 max mem: 78608 +[2026-05-03 08:31:48,897][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 650/1087] eta: 6:45:12 lr: 0.000010 epoch: 0.5888 (0.2990) step: 640.0000 (324.9800) loss: 2445.0342 (3586.5967) Lcamera_frontend: 2.3161 (2.8738) 
Ldepth_frontend: 4.6716 (5.0454) Lpmap_frontend: 12.1176 (11.3048) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.8117 (2.7902) Ldepth_mix: 4.6889 (5.0475) Lpmap_mix: 11.9498 (11.2663) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.8377 (2.8243) Ldepth_backend: 4.7019 (5.0488) Lpmap_backend: 11.9930 (11.2774) Ltrack_backend: 0.0000 (0.0000) total: 2445.0342 (3586.5967) time: 56.0863 data: 0.0486 max mem: 78608 +[2026-05-03 08:41:02,453][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 660/1087] eta: 6:35:54 lr: 0.000010 epoch: 0.5980 (0.3036) step: 650.0000 (329.9803) loss: 2730.5120 (3580.1872) Lcamera_frontend: 2.3161 (2.8633) Ldepth_frontend: 5.5045 (5.0635) Lpmap_frontend: 12.1464 (11.3183) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.9963 (2.7858) Ldepth_mix: 5.5256 (5.0657) Lpmap_mix: 12.0755 (11.2787) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.0671 (2.8190) Ldepth_backend: 5.5443 (5.0670) Lpmap_backend: 12.1021 (11.2898) Ltrack_backend: 0.0000 (0.0000) total: 2730.5120 (3580.1872) time: 55.8715 data: 0.0395 max mem: 78608 +[2026-05-03 08:50:18,799][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 670/1087] eta: 6:26:38 lr: 0.000010 epoch: 0.6072 (0.3082) step: 660.0000 (334.9806) loss: 1671.2928 (3558.1377) Lcamera_frontend: 1.3674 (2.8426) Ldepth_frontend: 6.3148 (5.0889) Lpmap_frontend: 12.0345 (11.3293) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.1993 (2.7676) Ldepth_mix: 6.3372 (5.0913) Lpmap_mix: 11.9413 (11.2881) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.2104 (2.8004) Ldepth_backend: 6.3731 (5.0929) Lpmap_backend: 11.9567 (11.2992) Ltrack_backend: 0.0000 (0.0000) total: 1671.2928 (3558.1377) time: 55.4949 data: 0.0354 max mem: 78608 +[2026-05-03 08:59:14,886][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 680/1087] eta: 6:17:09 lr: 0.000010 epoch: 0.6164 (0.3128) step: 670.0000 (339.9794) loss: 1392.4603 (3546.1834) Lcamera_frontend: 1.2751 (2.8348) Ldepth_frontend: 6.0868 (5.0949) Lpmap_frontend: 11.9328 (11.3400) Ltrack_frontend: 
0.0000 (0.0000) Lcamera_mix: 0.8983 (2.7576) Ldepth_mix: 6.0872 (5.0973) Lpmap_mix: 11.8686 (11.2986) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 0.9128 (2.7900) Ldepth_backend: 6.0722 (5.0989) Lpmap_backend: 11.8731 (11.3095) Ltrack_backend: 0.0000 (0.0000) total: 1392.4603 (3546.1834) time: 54.6215 data: 0.0341 max mem: 78608 +[2026-05-03 09:08:26,305][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 690/1087] eta: 6:07:51 lr: 0.000010 epoch: 0.6256 (0.3174) step: 680.0000 (344.9797) loss: 3024.0039 (3550.7322) Lcamera_frontend: 2.7665 (2.8400) Ldepth_frontend: 4.4880 (5.0947) Lpmap_frontend: 12.1197 (11.3505) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.2599 (2.7622) Ldepth_mix: 4.4941 (5.0972) Lpmap_mix: 12.0705 (11.3088) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.2941 (2.7935) Ldepth_backend: 4.4907 (5.0987) Lpmap_backend: 12.0592 (11.3199) Ltrack_backend: 0.0000 (0.0000) total: 3024.0039 (3550.7322) time: 54.3744 data: 0.0348 max mem: 78608 +[2026-05-03 09:17:43,346][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 700/1087] eta: 5:58:35 lr: 0.000010 epoch: 0.6348 (0.3220) step: 690.0000 (349.9800) loss: 2653.3569 (3546.6571) Lcamera_frontend: 2.6315 (2.8399) Ldepth_frontend: 4.4880 (5.0993) Lpmap_frontend: 12.1347 (11.3618) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.9773 (2.7598) Ldepth_mix: 4.4941 (5.1018) Lpmap_mix: 12.0656 (11.3196) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.9968 (2.7895) Ldepth_backend: 4.4907 (5.1032) Lpmap_backend: 12.0694 (11.3308) Ltrack_backend: 0.0000 (0.0000) total: 2653.3569 (3546.6571) time: 55.4217 data: 0.0344 max mem: 78608 +[2026-05-03 09:26:59,488][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 710/1087] eta: 5:49:19 lr: 0.000010 epoch: 0.6440 (0.3266) step: 700.0000 (354.9803) loss: 2276.0320 (3538.5582) Lcamera_frontend: 2.2519 (2.8377) Ldepth_frontend: 4.5876 (5.0951) Lpmap_frontend: 11.9809 (11.3689) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.6760 (2.7523) Ldepth_mix: 4.5948 (5.0976) Lpmap_mix: 
11.8900 (11.3253) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.7012 (2.7823) Ldepth_backend: 4.5958 (5.0990) Lpmap_backend: 11.8968 (11.3366) Ltrack_backend: 0.0000 (0.0000) total: 2276.0320 (3538.5582) time: 55.6586 data: 0.0364 max mem: 78608 +[2026-05-03 09:36:09,910][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 720/1087] eta: 5:40:01 lr: 0.000010 epoch: 0.6532 (0.3312) step: 710.0000 (359.9792) loss: 2183.4126 (3522.9060) Lcamera_frontend: 2.1352 (2.8250) Ldepth_frontend: 5.1236 (5.1032) Lpmap_frontend: 11.8750 (11.3756) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.6790 (2.7410) Ldepth_mix: 5.1267 (5.1058) Lpmap_mix: 11.6824 (11.3311) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.5669 (2.7689) Ldepth_backend: 5.1289 (5.1074) Lpmap_backend: 11.7142 (11.3424) Ltrack_backend: 0.0000 (0.0000) total: 2183.4126 (3522.9060) time: 55.3280 data: 0.0390 max mem: 78608 +[2026-05-03 09:45:24,115][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 730/1087] eta: 5:30:44 lr: 0.000010 epoch: 0.6624 (0.3358) step: 720.0000 (364.9795) loss: 2632.0852 (3538.6334) Lcamera_frontend: 2.3416 (2.8363) Ldepth_frontend: 4.6212 (5.0993) Lpmap_frontend: 11.9530 (11.3871) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.1370 (2.7544) Ldepth_mix: 4.5974 (5.1018) Lpmap_mix: 11.9262 (11.3426) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.9735 (2.7821) Ldepth_backend: 4.6038 (5.1033) Lpmap_backend: 11.8971 (11.3540) Ltrack_backend: 0.0000 (0.0000) total: 2632.0852 (3538.6334) time: 55.2311 data: 0.0394 max mem: 78608 +[2026-05-03 09:54:34,387][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 740/1087] eta: 5:21:25 lr: 0.000010 epoch: 0.6716 (0.3404) step: 730.0000 (369.9798) loss: 3940.0229 (3533.7108) Lcamera_frontend: 3.3570 (2.8329) Ldepth_frontend: 4.6212 (5.1089) Lpmap_frontend: 12.2415 (11.4019) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.8580 (2.7488) Ldepth_mix: 4.5974 (5.1111) Lpmap_mix: 12.1585 (11.3570) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.1289 (2.7778) 
Ldepth_backend: 4.6038 (5.1119) Lpmap_backend: 12.1698 (11.3684) Ltrack_backend: 0.0000 (0.0000) total: 3940.0229 (3533.7108) time: 55.2237 data: 0.0406 max mem: 78608 +[2026-05-03 10:03:56,791][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 750/1087] eta: 5:12:13 lr: 0.000010 epoch: 0.6808 (0.3450) step: 740.0000 (374.9800) loss: 2456.8555 (3527.7547) Lcamera_frontend: 2.2745 (2.8312) Ldepth_frontend: 5.3840 (5.1194) Lpmap_frontend: 12.2415 (11.4125) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.9090 (2.7457) Ldepth_mix: 5.4059 (5.1217) Lpmap_mix: 11.9992 (11.3643) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.8766 (2.7721) Ldepth_backend: 5.4270 (5.1226) Lpmap_backend: 12.1463 (11.3767) Ltrack_backend: 0.0000 (0.0000) total: 2456.8555 (3527.7547) time: 55.6336 data: 0.0390 max mem: 78608 +[2026-05-03 10:13:16,348][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 760/1087] eta: 5:02:58 lr: 0.000010 epoch: 0.6900 (0.3496) step: 750.0000 (379.9790) loss: 2818.4995 (3523.2228) Lcamera_frontend: 2.4893 (2.8302) Ldepth_frontend: 5.7214 (5.1307) Lpmap_frontend: 12.1750 (11.4230) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.5757 (2.7421) Ldepth_mix: 5.6839 (5.1333) Lpmap_mix: 11.9527 (11.3724) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.0683 (2.7677) Ldepth_backend: 5.6014 (5.1345) Lpmap_backend: 12.0599 (11.3857) Ltrack_backend: 0.0000 (0.0000) total: 2818.4995 (3523.2228) time: 56.0967 data: 0.0372 max mem: 78608 +[2026-05-03 10:22:27,922][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 770/1087] eta: 4:53:41 lr: 0.000010 epoch: 0.6992 (0.3542) step: 760.0000 (384.9792) loss: 3387.1812 (3518.5164) Lcamera_frontend: 2.8698 (2.8273) Ldepth_frontend: 5.2634 (5.1386) Lpmap_frontend: 12.1750 (11.4330) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.7236 (2.7379) Ldepth_mix: 5.2865 (5.1413) Lpmap_mix: 12.0448 (11.3823) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.6191 (2.7635) Ldepth_backend: 5.2979 (5.1425) Lpmap_backend: 12.0879 (11.3958) Ltrack_backend: 0.0000 
(0.0000) total: 3387.1812 (3518.5164) time: 55.5538 data: 0.0388 max mem: 78608 +[2026-05-03 10:31:51,892][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 780/1087] eta: 4:44:28 lr: 0.000010 epoch: 0.7084 (0.3588) step: 770.0000 (389.9795) loss: 2987.0747 (3514.9776) Lcamera_frontend: 2.8551 (2.8275) Ldepth_frontend: 4.8573 (5.1382) Lpmap_frontend: 12.1774 (11.4421) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.2459 (2.7352) Ldepth_mix: 4.8739 (5.1410) Lpmap_mix: 11.9692 (11.3895) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.2869 (2.7601) Ldepth_backend: 4.8861 (5.1422) Lpmap_backend: 12.0147 (11.4031) Ltrack_backend: 0.0000 (0.0000) total: 2987.0747 (3514.9776) time: 55.7757 data: 0.0375 max mem: 78608 +[2026-05-03 10:40:55,736][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 790/1087] eta: 4:35:07 lr: 0.000010 epoch: 0.7176 (0.3634) step: 780.0000 (394.9798) loss: 2920.8914 (3533.7044) Lcamera_frontend: 2.8551 (2.8334) Ldepth_frontend: 4.8106 (5.1427) Lpmap_frontend: 12.1774 (11.4514) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.3363 (2.7519) Ldepth_mix: 4.8182 (5.1455) Lpmap_mix: 11.9840 (11.3997) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.2747 (2.7764) Ldepth_backend: 4.8160 (5.1468) Lpmap_backend: 12.0256 (11.4133) Ltrack_backend: 0.0000 (0.0000) total: 2920.8914 (3533.7044) time: 55.3906 data: 0.0353 max mem: 78608 +[2026-05-03 10:50:15,123][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 800/1087] eta: 4:25:53 lr: 0.000010 epoch: 0.7268 (0.3680) step: 790.0000 (399.9788) loss: 4561.9434 (3540.7668) Lcamera_frontend: 3.2215 (2.8399) Ldepth_frontend: 4.5694 (5.1381) Lpmap_frontend: 12.0237 (11.4587) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.5356 (2.7579) Ldepth_mix: 4.5534 (5.1410) Lpmap_mix: 11.9905 (11.4068) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.6264 (2.7822) Ldepth_backend: 4.5277 (5.1421) Lpmap_backend: 11.9463 (11.4203) Ltrack_backend: 0.0000 (0.0000) total: 4561.9434 (3540.7668) time: 55.1615 data: 0.0344 max mem: 78608 
+[2026-05-03 10:59:13,062][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 810/1087] eta: 4:16:31 lr: 0.000010 epoch: 0.7360 (0.3726) step: 800.0000 (404.9790) loss: 3519.1130 (3530.4271) Lcamera_frontend: 2.8982 (2.8309) Ldepth_frontend: 4.3878 (5.1461) Lpmap_frontend: 11.9666 (11.4668) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.7501 (2.7496) Ldepth_mix: 4.3607 (5.1487) Lpmap_mix: 11.9215 (11.4149) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.7690 (2.7734) Ldepth_backend: 4.3180 (5.1496) Lpmap_backend: 11.9312 (11.4281) Ltrack_backend: 0.0000 (0.0000) total: 3519.1130 (3530.4271) time: 54.8662 data: 0.0341 max mem: 78608 +[2026-05-03 11:08:28,691][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 820/1087] eta: 4:07:15 lr: 0.000010 epoch: 0.7452 (0.3772) step: 810.0000 (409.9793) loss: 2573.6677 (3529.3969) Lcamera_frontend: 2.3184 (2.8273) Ldepth_frontend: 6.1467 (5.1690) Lpmap_frontend: 11.9551 (11.4722) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.8582 (2.7502) Ldepth_mix: 6.1585 (5.1718) Lpmap_mix: 11.9255 (11.4196) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.9209 (2.7724) Ldepth_backend: 6.1818 (5.1729) Lpmap_backend: 11.8981 (11.4329) Ltrack_backend: 0.0000 (0.0000) total: 2573.6677 (3529.3969) time: 54.6779 data: 0.0351 max mem: 78608 +[2026-05-03 11:17:46,381][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 830/1087] eta: 3:58:00 lr: 0.000010 epoch: 0.7544 (0.3818) step: 820.0000 (414.9795) loss: 2573.6677 (3534.8521) Lcamera_frontend: 2.3376 (2.8334) Ldepth_frontend: 5.3513 (5.1629) Lpmap_frontend: 11.8924 (11.4769) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.8582 (2.7547) Ldepth_mix: 5.3563 (5.1656) Lpmap_mix: 11.7717 (11.4241) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.9209 (2.7768) Ldepth_backend: 5.3373 (5.1664) Lpmap_backend: 11.7964 (11.4373) Ltrack_backend: 0.0000 (0.0000) total: 2573.6677 (3534.8521) time: 55.6651 data: 0.0355 max mem: 78608 +[2026-05-03 11:26:59,109][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 840/1087] 
eta: 3:48:44 lr: 0.000010 epoch: 0.7636 (0.3864) step: 830.0000 (419.9786) loss: 2918.1069 (3529.2461) Lcamera_frontend: 2.9368 (2.8318) Ldepth_frontend: 4.7339 (5.1652) Lpmap_frontend: 11.7897 (11.4817) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.1487 (2.7492) Ldepth_mix: 4.7480 (5.1679) Lpmap_mix: 11.7146 (11.4282) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.2180 (2.7718) Ldepth_backend: 4.7483 (5.1687) Lpmap_backend: 11.7211 (11.4415) Ltrack_backend: 0.0000 (0.0000) total: 2918.1069 (3529.2461) time: 55.5204 data: 0.0365 max mem: 78608 diff --git a/outdoor_v48_4gpu/.hydra/config.yaml b/outdoor_v48_4gpu/.hydra/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1a7024a22935a37c7434ab6dd4f998e5564642ef --- /dev/null +++ b/outdoor_v48_4gpu/.hydra/config.yaml @@ -0,0 +1,68 @@ +teacher: /gpfs/work2/0/prjs0824/qi_proj/ckpt/checkpoint-10.pth.model +pretrained: /gpfs/work2/0/prjs0824/qi_proj/ckpt/checkpoint-10.pth.model +load_only_encoder: false +long_context: false +fixed_length: true +resume: null +benchmark: false +num_views: 64 +num_test_views: 4 +n_corres_train: 0 +n_corres_test: 0 +train_criterion: DistillLoss() +test_criterion: DistillLoss() +allow_repeat: false +root_vkitti2: /scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti +root_kitti: /scratch-shared/wwei2/eval/kitti_odometry/dataset +root_kitti_velo: /gpfs/work2/0/prjs0824/semantickitti/dataset +root_kitti360: /scratch-shared/wwei2/downloads/kitti360/KITTI-360 +root_kitti360_velo: /scratch-shared/wwei2/downloads/kitti360/KITTI-360 +root_waymo: /scratch-shared/wwei2/waymo_v2 +root_waymo_lidar: /scratch-shared/wwei2/waymo_v2 +dataset_vkitti2: VirtualKITTI2_Multi(allow_repeat=${allow_repeat}, split='train', + ROOT="${root_vkitti2}", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), + (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=${num_views}, + n_corres=${n_corres_train}) +dataset_kitti360: 
KITTI360_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_kitti360}", + velodyne_root="${root_kitti360_velo}", aug_crop=16, resolution=[(518, 392), (518, + 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, + num_views=${num_views}, n_corres=${n_corres_train}) +dataset_waymo: Waymo_v2_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_waymo}", + lidar_root="${root_waymo_lidar}", aug_crop=16, resolution=[(518, 392), (518, 336), + (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=${num_views}, + n_corres=${n_corres_train}) +train_dataset: 6000 @ ${dataset_vkitti2} + 6000 @ ${dataset_kitti360} + 5400 @ ${dataset_waymo} +test_dataset: 200 @ VirtualKITTI2_Multi(split='train', ROOT="${root_vkitti2}", resolution=(518, + 154), num_views=${num_test_views}, seed=42, n_corres=${n_corres_test}) +seed: 0 +batch_size: 1 +accum_iter: 1 +gradient_checkpointing: false +epochs: 10 +start_epoch: 0 +start_step: 0 +weight_decay: 0.05 +lr: 1.0e-05 +min_lr: 1.0e-08 +warmup_epochs: 0.5 +amp: 1 +num_workers: 4 +world_size: 1 +local-rank: -1 +dist_url: env:// +rank: 0 +gpu: 0 +distributed: false +dist_backend: nccl +eval_freq: 1 +save_freq: 0.1 +max_checkpoints: 10 +keep_freq: 1 +print_freq: 10 +print_img_freq: 50000000 +num_imgs_vis: 4 +save_dir: /scratch-shared/wwei2/training_upstream/checkpoints +exp_name: outdoor_v48_4gpu +task: StreamVGGT +logdir: ${save_dir}/${exp_name}/logs +output_dir: ${save_dir}/${exp_name}/ diff --git a/outdoor_v48_4gpu/.hydra/hydra.yaml b/outdoor_v48_4gpu/.hydra/hydra.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e939a008caa059395d85632d6f6d9556122e7181 --- /dev/null +++ b/outdoor_v48_4gpu/.hydra/hydra.yaml @@ -0,0 +1,155 @@ +hydra: + run: + dir: ${save_dir}/${exp_name} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.num} + launcher: + _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher + sweeper: 
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper + max_batch_size: null + params: null + help: + app_name: ${hydra.job.name} + header: '${hydra.help.app_name} is powered by Hydra. + + ' + footer: 'Powered by Hydra (https://hydra.cc) + + Use --hydra-help to view Hydra specific help + + ' + template: '${hydra.help.header} + + == Configuration groups == + + Compose your configuration from those groups (group=option) + + + $APP_CONFIG_GROUPS + + + == Config == + + Override anything in the config (foo.bar=value) + + + $CONFIG + + + ${hydra.help.footer} + + ' + hydra_help: + template: 'Hydra (${hydra.runtime.version}) + + See https://hydra.cc for more info. + + + == Flags == + + $FLAGS_HELP + + + == Configuration groups == + + Compose your configuration from those groups (For example, append hydra/job_logging=disabled + to command line) + + + $HYDRA_CONFIG_GROUPS + + + Use ''--cfg hydra'' to Show the Hydra config. + + ' + hydra_help: ??? + hydra_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][HYDRA] %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + root: + level: INFO + handlers: + - console + loggers: + logging_example: + level: DEBUG + disable_existing_loggers: false + job_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + file: + class: logging.FileHandler + formatter: simple + filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log + root: + level: INFO + handlers: + - console + - file + disable_existing_loggers: false + env: {} + mode: RUN + searchpath: [] + callbacks: {} + output_subdir: .hydra + overrides: + hydra: + - hydra.mode=RUN + task: + - exp_name=outdoor_v48_4gpu + job: + name: mytrain + chdir: null + override_dirname: exp_name=outdoor_v48_4gpu + id: ??? + num: ??? 
+ config_name: outdoor_v48 + env_set: {} + env_copy: [] + config: + override_dirname: + kv_sep: '=' + item_sep: ',' + exclude_keys: [] + runtime: + version: 1.3.2 + version_base: '1.3' + cwd: /gpfs/work2/0/prjs0824/qi_proj/slamformer_upstream/src + config_sources: + - path: hydra.conf + schema: pkg + provider: hydra + - path: /gpfs/work2/0/prjs0824/qi_proj/slamformer_upstream/config + schema: file + provider: main + - path: '' + schema: structured + provider: schema + output_dir: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_4gpu + choices: + hydra/env: default + hydra/callbacks: null + hydra/job_logging: default + hydra/hydra_logging: default + hydra/hydra_help: default + hydra/help: default + hydra/sweeper: basic + hydra/launcher: basic + hydra/output: default + verbose: true diff --git a/outdoor_v48_4gpu/.hydra/overrides.yaml b/outdoor_v48_4gpu/.hydra/overrides.yaml new file mode 100644 index 0000000000000000000000000000000000000000..79e9edb015109c44c06367b945027c517d7f7a3b --- /dev/null +++ b/outdoor_v48_4gpu/.hydra/overrides.yaml @@ -0,0 +1 @@ +- exp_name=outdoor_v48_4gpu diff --git a/outdoor_v48_4gpu/mytrain.log b/outdoor_v48_4gpu/mytrain.log new file mode 100644 index 0000000000000000000000000000000000000000..ea432561f0dc79bc176aaff02959f5204e0bdff1 --- /dev/null +++ b/outdoor_v48_4gpu/mytrain.log @@ -0,0 +1,1858 @@ +[2026-05-01 23:31:17,861][__main__][INFO] - [RANK 0] output_dir: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_4gpu/ +[2026-05-01 23:31:18,027][__main__][INFO] - [RANK 0] Saving current code to /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_4gpu/code/05_01-23:31:17 +[2026-05-01 23:31:18,027][__main__][INFO] - [RANK 0] job dir: /gpfs/work2/0/prjs0824/qi_proj/slamformer_upstream/src +[2026-05-01 23:31:18,027][__main__][INFO] - [RANK 0] Setting seed to 0 for process 0 +[2026-05-01 23:31:18,029][__main__][INFO] - [RANK 0] Building train dataset 6000 @ VirtualKITTI2_Multi(allow_repeat=False, 
split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 6000 @ KITTI360_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", velodyne_root="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 5400 @ Waymo_v2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/waymo_v2", lidar_root="/scratch-shared/wwei2/waymo_v2", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) +[2026-05-01 23:31:18,029][__main__][INFO] - [RANK 0] Building Train Data loader for dataset: 6000 @ VirtualKITTI2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 6000 @ KITTI360_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", velodyne_root="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 5400 @ Waymo_v2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/waymo_v2", lidar_root="/scratch-shared/wwei2/waymo_v2", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) +[2026-05-01 23:35:09,093][__main__][INFO] - [RANK 0] Building test dataset 200 @ 
VirtualKITTI2_Multi(split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", resolution=(518, 154), num_views=4, seed=42, n_corres=0) +[2026-05-01 23:35:09,109][__main__][INFO] - [RANK 0] Building Test Data loader for dataset: 200 @ VirtualKITTI2_Multi(split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", resolution=(518, 154), num_views=4, seed=42, n_corres=0) +[2026-05-01 23:35:09,148][__main__][INFO] - [RANK 0] Loading model +[2026-05-01 23:35:15,072][__main__][INFO] - [RANK 0] All model parameters: 958696732 +[2026-05-01 23:35:15,073][__main__][INFO] - [RANK 0] >> Creating train criterion = DistillLoss() +[2026-05-01 23:35:15,073][__main__][INFO] - [RANK 0] >> Creating test criterion = DistillLoss() +[2026-05-01 23:35:15,579][__main__][INFO] - [RANK 0] Loading pretrained: /gpfs/work2/0/prjs0824/qi_proj/ckpt/checkpoint-10.pth.model +[2026-05-01 23:35:26,982][__main__][INFO] - [RANK 0] _IncompatibleKeys(missing_keys=['register_token', 'image_mean', 'image_std', 'encoder.cls_token', 'encoder.pos_embed', 'encoder.register_tokens', 'encoder.patch_embed.proj.weight', 'encoder.patch_embed.proj.bias', 'encoder.blocks.0.norm1.weight', 'encoder.blocks.0.norm1.bias', 'encoder.blocks.0.attn.qkv.weight', 'encoder.blocks.0.attn.qkv.bias', 'encoder.blocks.0.attn.proj.weight', 'encoder.blocks.0.attn.proj.bias', 'encoder.blocks.0.ls1.gamma', 'encoder.blocks.0.norm2.weight', 'encoder.blocks.0.norm2.bias', 'encoder.blocks.0.mlp.fc1.weight', 'encoder.blocks.0.mlp.fc1.bias', 'encoder.blocks.0.mlp.fc2.weight', 'encoder.blocks.0.mlp.fc2.bias', 'encoder.blocks.0.ls2.gamma', 'encoder.blocks.1.norm1.weight', 'encoder.blocks.1.norm1.bias', 'encoder.blocks.1.attn.qkv.weight', 'encoder.blocks.1.attn.qkv.bias', 'encoder.blocks.1.attn.proj.weight', 'encoder.blocks.1.attn.proj.bias', 'encoder.blocks.1.ls1.gamma', 'encoder.blocks.1.norm2.weight', 'encoder.blocks.1.norm2.bias', 
'encoder.blocks.1.mlp.fc1.weight', 'encoder.blocks.1.mlp.fc1.bias', 'encoder.blocks.1.mlp.fc2.weight', 'encoder.blocks.1.mlp.fc2.bias', 'encoder.blocks.1.ls2.gamma', 'encoder.blocks.2.norm1.weight', 'encoder.blocks.2.norm1.bias', 'encoder.blocks.2.attn.qkv.weight', 'encoder.blocks.2.attn.qkv.bias', 'encoder.blocks.2.attn.proj.weight', 'encoder.blocks.2.attn.proj.bias', 'encoder.blocks.2.ls1.gamma', 'encoder.blocks.2.norm2.weight', 'encoder.blocks.2.norm2.bias', 'encoder.blocks.2.mlp.fc1.weight', 'encoder.blocks.2.mlp.fc1.bias', 'encoder.blocks.2.mlp.fc2.weight', 'encoder.blocks.2.mlp.fc2.bias', 'encoder.blocks.2.ls2.gamma', 'encoder.blocks.3.norm1.weight', 'encoder.blocks.3.norm1.bias', 'encoder.blocks.3.attn.qkv.weight', 'encoder.blocks.3.attn.qkv.bias', 'encoder.blocks.3.attn.proj.weight', 'encoder.blocks.3.attn.proj.bias', 'encoder.blocks.3.ls1.gamma', 'encoder.blocks.3.norm2.weight', 'encoder.blocks.3.norm2.bias', 'encoder.blocks.3.mlp.fc1.weight', 'encoder.blocks.3.mlp.fc1.bias', 'encoder.blocks.3.mlp.fc2.weight', 'encoder.blocks.3.mlp.fc2.bias', 'encoder.blocks.3.ls2.gamma', 'encoder.blocks.4.norm1.weight', 'encoder.blocks.4.norm1.bias', 'encoder.blocks.4.attn.qkv.weight', 'encoder.blocks.4.attn.qkv.bias', 'encoder.blocks.4.attn.proj.weight', 'encoder.blocks.4.attn.proj.bias', 'encoder.blocks.4.ls1.gamma', 'encoder.blocks.4.norm2.weight', 'encoder.blocks.4.norm2.bias', 'encoder.blocks.4.mlp.fc1.weight', 'encoder.blocks.4.mlp.fc1.bias', 'encoder.blocks.4.mlp.fc2.weight', 'encoder.blocks.4.mlp.fc2.bias', 'encoder.blocks.4.ls2.gamma', 'encoder.blocks.5.norm1.weight', 'encoder.blocks.5.norm1.bias', 'encoder.blocks.5.attn.qkv.weight', 'encoder.blocks.5.attn.qkv.bias', 'encoder.blocks.5.attn.proj.weight', 'encoder.blocks.5.attn.proj.bias', 'encoder.blocks.5.ls1.gamma', 'encoder.blocks.5.norm2.weight', 'encoder.blocks.5.norm2.bias', 'encoder.blocks.5.mlp.fc1.weight', 'encoder.blocks.5.mlp.fc1.bias', 'encoder.blocks.5.mlp.fc2.weight', 'encoder.blocks.5.mlp.fc2.bias', 
'encoder.blocks.5.ls2.gamma', 'encoder.blocks.6.norm1.weight', 'encoder.blocks.6.norm1.bias', 'encoder.blocks.6.attn.qkv.weight', 'encoder.blocks.6.attn.qkv.bias', 'encoder.blocks.6.attn.proj.weight', 'encoder.blocks.6.attn.proj.bias', 'encoder.blocks.6.ls1.gamma', 'encoder.blocks.6.norm2.weight', 'encoder.blocks.6.norm2.bias', 'encoder.blocks.6.mlp.fc1.weight', 'encoder.blocks.6.mlp.fc1.bias', 'encoder.blocks.6.mlp.fc2.weight', 'encoder.blocks.6.mlp.fc2.bias', 'encoder.blocks.6.ls2.gamma', 'encoder.blocks.7.norm1.weight', 'encoder.blocks.7.norm1.bias', 'encoder.blocks.7.attn.qkv.weight', 'encoder.blocks.7.attn.qkv.bias', 'encoder.blocks.7.attn.proj.weight', 'encoder.blocks.7.attn.proj.bias', 'encoder.blocks.7.ls1.gamma', 'encoder.blocks.7.norm2.weight', 'encoder.blocks.7.norm2.bias', 'encoder.blocks.7.mlp.fc1.weight', 'encoder.blocks.7.mlp.fc1.bias', 'encoder.blocks.7.mlp.fc2.weight', 'encoder.blocks.7.mlp.fc2.bias', 'encoder.blocks.7.ls2.gamma', 'encoder.blocks.8.norm1.weight', 'encoder.blocks.8.norm1.bias', 'encoder.blocks.8.attn.qkv.weight', 'encoder.blocks.8.attn.qkv.bias', 'encoder.blocks.8.attn.proj.weight', 'encoder.blocks.8.attn.proj.bias', 'encoder.blocks.8.ls1.gamma', 'encoder.blocks.8.norm2.weight', 'encoder.blocks.8.norm2.bias', 'encoder.blocks.8.mlp.fc1.weight', 'encoder.blocks.8.mlp.fc1.bias', 'encoder.blocks.8.mlp.fc2.weight', 'encoder.blocks.8.mlp.fc2.bias', 'encoder.blocks.8.ls2.gamma', 'encoder.blocks.9.norm1.weight', 'encoder.blocks.9.norm1.bias', 'encoder.blocks.9.attn.qkv.weight', 'encoder.blocks.9.attn.qkv.bias', 'encoder.blocks.9.attn.proj.weight', 'encoder.blocks.9.attn.proj.bias', 'encoder.blocks.9.ls1.gamma', 'encoder.blocks.9.norm2.weight', 'encoder.blocks.9.norm2.bias', 'encoder.blocks.9.mlp.fc1.weight', 'encoder.blocks.9.mlp.fc1.bias', 'encoder.blocks.9.mlp.fc2.weight', 'encoder.blocks.9.mlp.fc2.bias', 'encoder.blocks.9.ls2.gamma', 'encoder.blocks.10.norm1.weight', 'encoder.blocks.10.norm1.bias', 'encoder.blocks.10.attn.qkv.weight', 
'encoder.blocks.10.attn.qkv.bias', 'encoder.blocks.10.attn.proj.weight', 'encoder.blocks.10.attn.proj.bias', 'encoder.blocks.10.ls1.gamma', 'encoder.blocks.10.norm2.weight', 'encoder.blocks.10.norm2.bias', 'encoder.blocks.10.mlp.fc1.weight', 'encoder.blocks.10.mlp.fc1.bias', 'encoder.blocks.10.mlp.fc2.weight', 'encoder.blocks.10.mlp.fc2.bias', 'encoder.blocks.10.ls2.gamma', 'encoder.blocks.11.norm1.weight', 'encoder.blocks.11.norm1.bias', 'encoder.blocks.11.attn.qkv.weight', 'encoder.blocks.11.attn.qkv.bias', 'encoder.blocks.11.attn.proj.weight', 'encoder.blocks.11.attn.proj.bias', 'encoder.blocks.11.ls1.gamma', 'encoder.blocks.11.norm2.weight', 'encoder.blocks.11.norm2.bias', 'encoder.blocks.11.mlp.fc1.weight', 'encoder.blocks.11.mlp.fc1.bias', 'encoder.blocks.11.mlp.fc2.weight', 'encoder.blocks.11.mlp.fc2.bias', 'encoder.blocks.11.ls2.gamma', 'encoder.blocks.12.norm1.weight', 'encoder.blocks.12.norm1.bias', 'encoder.blocks.12.attn.qkv.weight', 'encoder.blocks.12.attn.qkv.bias', 'encoder.blocks.12.attn.proj.weight', 'encoder.blocks.12.attn.proj.bias', 'encoder.blocks.12.ls1.gamma', 'encoder.blocks.12.norm2.weight', 'encoder.blocks.12.norm2.bias', 'encoder.blocks.12.mlp.fc1.weight', 'encoder.blocks.12.mlp.fc1.bias', 'encoder.blocks.12.mlp.fc2.weight', 'encoder.blocks.12.mlp.fc2.bias', 'encoder.blocks.12.ls2.gamma', 'encoder.blocks.13.norm1.weight', 'encoder.blocks.13.norm1.bias', 'encoder.blocks.13.attn.qkv.weight', 'encoder.blocks.13.attn.qkv.bias', 'encoder.blocks.13.attn.proj.weight', 'encoder.blocks.13.attn.proj.bias', 'encoder.blocks.13.ls1.gamma', 'encoder.blocks.13.norm2.weight', 'encoder.blocks.13.norm2.bias', 'encoder.blocks.13.mlp.fc1.weight', 'encoder.blocks.13.mlp.fc1.bias', 'encoder.blocks.13.mlp.fc2.weight', 'encoder.blocks.13.mlp.fc2.bias', 'encoder.blocks.13.ls2.gamma', 'encoder.blocks.14.norm1.weight', 'encoder.blocks.14.norm1.bias', 'encoder.blocks.14.attn.qkv.weight', 'encoder.blocks.14.attn.qkv.bias', 'encoder.blocks.14.attn.proj.weight', 
'encoder.blocks.14.attn.proj.bias', 'encoder.blocks.14.ls1.gamma', 'encoder.blocks.14.norm2.weight', 'encoder.blocks.14.norm2.bias', 'encoder.blocks.14.mlp.fc1.weight', 'encoder.blocks.14.mlp.fc1.bias', 'encoder.blocks.14.mlp.fc2.weight', 'encoder.blocks.14.mlp.fc2.bias', 'encoder.blocks.14.ls2.gamma', 'encoder.blocks.15.norm1.weight', 'encoder.blocks.15.norm1.bias', 'encoder.blocks.15.attn.qkv.weight', 'encoder.blocks.15.attn.qkv.bias', 'encoder.blocks.15.attn.proj.weight', 'encoder.blocks.15.attn.proj.bias', 'encoder.blocks.15.ls1.gamma', 'encoder.blocks.15.norm2.weight', 'encoder.blocks.15.norm2.bias', 'encoder.blocks.15.mlp.fc1.weight', 'encoder.blocks.15.mlp.fc1.bias', 'encoder.blocks.15.mlp.fc2.weight', 'encoder.blocks.15.mlp.fc2.bias', 'encoder.blocks.15.ls2.gamma', 'encoder.blocks.16.norm1.weight', 'encoder.blocks.16.norm1.bias', 'encoder.blocks.16.attn.qkv.weight', 'encoder.blocks.16.attn.qkv.bias', 'encoder.blocks.16.attn.proj.weight', 'encoder.blocks.16.attn.proj.bias', 'encoder.blocks.16.ls1.gamma', 'encoder.blocks.16.norm2.weight', 'encoder.blocks.16.norm2.bias', 'encoder.blocks.16.mlp.fc1.weight', 'encoder.blocks.16.mlp.fc1.bias', 'encoder.blocks.16.mlp.fc2.weight', 'encoder.blocks.16.mlp.fc2.bias', 'encoder.blocks.16.ls2.gamma', 'encoder.blocks.17.norm1.weight', 'encoder.blocks.17.norm1.bias', 'encoder.blocks.17.attn.qkv.weight', 'encoder.blocks.17.attn.qkv.bias', 'encoder.blocks.17.attn.proj.weight', 'encoder.blocks.17.attn.proj.bias', 'encoder.blocks.17.ls1.gamma', 'encoder.blocks.17.norm2.weight', 'encoder.blocks.17.norm2.bias', 'encoder.blocks.17.mlp.fc1.weight', 'encoder.blocks.17.mlp.fc1.bias', 'encoder.blocks.17.mlp.fc2.weight', 'encoder.blocks.17.mlp.fc2.bias', 'encoder.blocks.17.ls2.gamma', 'encoder.blocks.18.norm1.weight', 'encoder.blocks.18.norm1.bias', 'encoder.blocks.18.attn.qkv.weight', 'encoder.blocks.18.attn.qkv.bias', 'encoder.blocks.18.attn.proj.weight', 'encoder.blocks.18.attn.proj.bias', 'encoder.blocks.18.ls1.gamma', 
'encoder.blocks.18.norm2.weight', 'encoder.blocks.18.norm2.bias', 'encoder.blocks.18.mlp.fc1.weight', 'encoder.blocks.18.mlp.fc1.bias', 'encoder.blocks.18.mlp.fc2.weight', 'encoder.blocks.18.mlp.fc2.bias', 'encoder.blocks.18.ls2.gamma', 'encoder.blocks.19.norm1.weight', 'encoder.blocks.19.norm1.bias', 'encoder.blocks.19.attn.qkv.weight', 'encoder.blocks.19.attn.qkv.bias', 'encoder.blocks.19.attn.proj.weight', 'encoder.blocks.19.attn.proj.bias', 'encoder.blocks.19.ls1.gamma', 'encoder.blocks.19.norm2.weight', 'encoder.blocks.19.norm2.bias', 'encoder.blocks.19.mlp.fc1.weight', 'encoder.blocks.19.mlp.fc1.bias', 'encoder.blocks.19.mlp.fc2.weight', 'encoder.blocks.19.mlp.fc2.bias', 'encoder.blocks.19.ls2.gamma', 'encoder.blocks.20.norm1.weight', 'encoder.blocks.20.norm1.bias', 'encoder.blocks.20.attn.qkv.weight', 'encoder.blocks.20.attn.qkv.bias', 'encoder.blocks.20.attn.proj.weight', 'encoder.blocks.20.attn.proj.bias', 'encoder.blocks.20.ls1.gamma', 'encoder.blocks.20.norm2.weight', 'encoder.blocks.20.norm2.bias', 'encoder.blocks.20.mlp.fc1.weight', 'encoder.blocks.20.mlp.fc1.bias', 'encoder.blocks.20.mlp.fc2.weight', 'encoder.blocks.20.mlp.fc2.bias', 'encoder.blocks.20.ls2.gamma', 'encoder.blocks.21.norm1.weight', 'encoder.blocks.21.norm1.bias', 'encoder.blocks.21.attn.qkv.weight', 'encoder.blocks.21.attn.qkv.bias', 'encoder.blocks.21.attn.proj.weight', 'encoder.blocks.21.attn.proj.bias', 'encoder.blocks.21.ls1.gamma', 'encoder.blocks.21.norm2.weight', 'encoder.blocks.21.norm2.bias', 'encoder.blocks.21.mlp.fc1.weight', 'encoder.blocks.21.mlp.fc1.bias', 'encoder.blocks.21.mlp.fc2.weight', 'encoder.blocks.21.mlp.fc2.bias', 'encoder.blocks.21.ls2.gamma', 'encoder.blocks.22.norm1.weight', 'encoder.blocks.22.norm1.bias', 'encoder.blocks.22.attn.qkv.weight', 'encoder.blocks.22.attn.qkv.bias', 'encoder.blocks.22.attn.proj.weight', 'encoder.blocks.22.attn.proj.bias', 'encoder.blocks.22.ls1.gamma', 'encoder.blocks.22.norm2.weight', 'encoder.blocks.22.norm2.bias', 
'encoder.blocks.22.mlp.fc1.weight', 'encoder.blocks.22.mlp.fc1.bias', 'encoder.blocks.22.mlp.fc2.weight', 'encoder.blocks.22.mlp.fc2.bias', 'encoder.blocks.22.ls2.gamma', 'encoder.blocks.23.norm1.weight', 'encoder.blocks.23.norm1.bias', 'encoder.blocks.23.attn.qkv.weight', 'encoder.blocks.23.attn.qkv.bias', 'encoder.blocks.23.attn.proj.weight', 'encoder.blocks.23.attn.proj.bias', 'encoder.blocks.23.ls1.gamma', 'encoder.blocks.23.norm2.weight', 'encoder.blocks.23.norm2.bias', 'encoder.blocks.23.mlp.fc1.weight', 'encoder.blocks.23.mlp.fc1.bias', 'encoder.blocks.23.mlp.fc2.weight', 'encoder.blocks.23.mlp.fc2.bias', 'encoder.blocks.23.ls2.gamma', 'encoder.norm.weight', 'encoder.norm.bias', 'decoder.0.norm1.weight', 'decoder.0.norm1.bias', 'decoder.0.attn.qkv.weight', 'decoder.0.attn.qkv.bias', 'decoder.0.attn.proj.weight', 'decoder.0.attn.proj.bias', 'decoder.0.attn.q_norm.weight', 'decoder.0.attn.q_norm.bias', 'decoder.0.attn.k_norm.weight', 'decoder.0.attn.k_norm.bias', 'decoder.0.ls1.gamma', 'decoder.0.norm2.weight', 'decoder.0.norm2.bias', 'decoder.0.mlp.fc1.weight', 'decoder.0.mlp.fc1.bias', 'decoder.0.mlp.fc2.weight', 'decoder.0.mlp.fc2.bias', 'decoder.0.ls2.gamma', 'decoder.1.norm1.weight', 'decoder.1.norm1.bias', 'decoder.1.attn.qkv.weight', 'decoder.1.attn.qkv.bias', 'decoder.1.attn.proj.weight', 'decoder.1.attn.proj.bias', 'decoder.1.attn.q_norm.weight', 'decoder.1.attn.q_norm.bias', 'decoder.1.attn.k_norm.weight', 'decoder.1.attn.k_norm.bias', 'decoder.1.ls1.gamma', 'decoder.1.norm2.weight', 'decoder.1.norm2.bias', 'decoder.1.mlp.fc1.weight', 'decoder.1.mlp.fc1.bias', 'decoder.1.mlp.fc2.weight', 'decoder.1.mlp.fc2.bias', 'decoder.1.ls2.gamma', 'decoder.2.norm1.weight', 'decoder.2.norm1.bias', 'decoder.2.attn.qkv.weight', 'decoder.2.attn.qkv.bias', 'decoder.2.attn.proj.weight', 'decoder.2.attn.proj.bias', 'decoder.2.attn.q_norm.weight', 'decoder.2.attn.q_norm.bias', 'decoder.2.attn.k_norm.weight', 'decoder.2.attn.k_norm.bias', 'decoder.2.ls1.gamma', 
'decoder.2.norm2.weight', 'decoder.2.norm2.bias', 'decoder.2.mlp.fc1.weight', 'decoder.2.mlp.fc1.bias', 'decoder.2.mlp.fc2.weight', 'decoder.2.mlp.fc2.bias', 'decoder.2.ls2.gamma', 'decoder.3.norm1.weight', 'decoder.3.norm1.bias', 'decoder.3.attn.qkv.weight', 'decoder.3.attn.qkv.bias', 'decoder.3.attn.proj.weight', 'decoder.3.attn.proj.bias', 'decoder.3.attn.q_norm.weight', 'decoder.3.attn.q_norm.bias', 'decoder.3.attn.k_norm.weight', 'decoder.3.attn.k_norm.bias', 'decoder.3.ls1.gamma', 'decoder.3.norm2.weight', 'decoder.3.norm2.bias', 'decoder.3.mlp.fc1.weight', 'decoder.3.mlp.fc1.bias', 'decoder.3.mlp.fc2.weight', 'decoder.3.mlp.fc2.bias', 'decoder.3.ls2.gamma', 'decoder.4.norm1.weight', 'decoder.4.norm1.bias', 'decoder.4.attn.qkv.weight', 'decoder.4.attn.qkv.bias', 'decoder.4.attn.proj.weight', 'decoder.4.attn.proj.bias', 'decoder.4.attn.q_norm.weight', 'decoder.4.attn.q_norm.bias', 'decoder.4.attn.k_norm.weight', 'decoder.4.attn.k_norm.bias', 'decoder.4.ls1.gamma', 'decoder.4.norm2.weight', 'decoder.4.norm2.bias', 'decoder.4.mlp.fc1.weight', 'decoder.4.mlp.fc1.bias', 'decoder.4.mlp.fc2.weight', 'decoder.4.mlp.fc2.bias', 'decoder.4.ls2.gamma', 'decoder.5.norm1.weight', 'decoder.5.norm1.bias', 'decoder.5.attn.qkv.weight', 'decoder.5.attn.qkv.bias', 'decoder.5.attn.proj.weight', 'decoder.5.attn.proj.bias', 'decoder.5.attn.q_norm.weight', 'decoder.5.attn.q_norm.bias', 'decoder.5.attn.k_norm.weight', 'decoder.5.attn.k_norm.bias', 'decoder.5.ls1.gamma', 'decoder.5.norm2.weight', 'decoder.5.norm2.bias', 'decoder.5.mlp.fc1.weight', 'decoder.5.mlp.fc1.bias', 'decoder.5.mlp.fc2.weight', 'decoder.5.mlp.fc2.bias', 'decoder.5.ls2.gamma', 'decoder.6.norm1.weight', 'decoder.6.norm1.bias', 'decoder.6.attn.qkv.weight', 'decoder.6.attn.qkv.bias', 'decoder.6.attn.proj.weight', 'decoder.6.attn.proj.bias', 'decoder.6.attn.q_norm.weight', 'decoder.6.attn.q_norm.bias', 'decoder.6.attn.k_norm.weight', 'decoder.6.attn.k_norm.bias', 'decoder.6.ls1.gamma', 'decoder.6.norm2.weight', 
'decoder.6.norm2.bias', 'decoder.6.mlp.fc1.weight', 'decoder.6.mlp.fc1.bias', 'decoder.6.mlp.fc2.weight', 'decoder.6.mlp.fc2.bias', 'decoder.6.ls2.gamma', 'decoder.7.norm1.weight', 'decoder.7.norm1.bias', 'decoder.7.attn.qkv.weight', 'decoder.7.attn.qkv.bias', 'decoder.7.attn.proj.weight', 'decoder.7.attn.proj.bias', 'decoder.7.attn.q_norm.weight', 'decoder.7.attn.q_norm.bias', 'decoder.7.attn.k_norm.weight', 'decoder.7.attn.k_norm.bias', 'decoder.7.ls1.gamma', 'decoder.7.norm2.weight', 'decoder.7.norm2.bias', 'decoder.7.mlp.fc1.weight', 'decoder.7.mlp.fc1.bias', 'decoder.7.mlp.fc2.weight', 'decoder.7.mlp.fc2.bias', 'decoder.7.ls2.gamma', 'decoder.8.norm1.weight', 'decoder.8.norm1.bias', 'decoder.8.attn.qkv.weight', 'decoder.8.attn.qkv.bias', 'decoder.8.attn.proj.weight', 'decoder.8.attn.proj.bias', 'decoder.8.attn.q_norm.weight', 'decoder.8.attn.q_norm.bias', 'decoder.8.attn.k_norm.weight', 'decoder.8.attn.k_norm.bias', 'decoder.8.ls1.gamma', 'decoder.8.norm2.weight', 'decoder.8.norm2.bias', 'decoder.8.mlp.fc1.weight', 'decoder.8.mlp.fc1.bias', 'decoder.8.mlp.fc2.weight', 'decoder.8.mlp.fc2.bias', 'decoder.8.ls2.gamma', 'decoder.9.norm1.weight', 'decoder.9.norm1.bias', 'decoder.9.attn.qkv.weight', 'decoder.9.attn.qkv.bias', 'decoder.9.attn.proj.weight', 'decoder.9.attn.proj.bias', 'decoder.9.attn.q_norm.weight', 'decoder.9.attn.q_norm.bias', 'decoder.9.attn.k_norm.weight', 'decoder.9.attn.k_norm.bias', 'decoder.9.ls1.gamma', 'decoder.9.norm2.weight', 'decoder.9.norm2.bias', 'decoder.9.mlp.fc1.weight', 'decoder.9.mlp.fc1.bias', 'decoder.9.mlp.fc2.weight', 'decoder.9.mlp.fc2.bias', 'decoder.9.ls2.gamma', 'decoder.10.norm1.weight', 'decoder.10.norm1.bias', 'decoder.10.attn.qkv.weight', 'decoder.10.attn.qkv.bias', 'decoder.10.attn.proj.weight', 'decoder.10.attn.proj.bias', 'decoder.10.attn.q_norm.weight', 'decoder.10.attn.q_norm.bias', 'decoder.10.attn.k_norm.weight', 'decoder.10.attn.k_norm.bias', 'decoder.10.ls1.gamma', 'decoder.10.norm2.weight', 
'decoder.10.norm2.bias', 'decoder.10.mlp.fc1.weight', 'decoder.10.mlp.fc1.bias', 'decoder.10.mlp.fc2.weight', 'decoder.10.mlp.fc2.bias', 'decoder.10.ls2.gamma', 'decoder.11.norm1.weight', 'decoder.11.norm1.bias', 'decoder.11.attn.qkv.weight', 'decoder.11.attn.qkv.bias', 'decoder.11.attn.proj.weight', 'decoder.11.attn.proj.bias', 'decoder.11.attn.q_norm.weight', 'decoder.11.attn.q_norm.bias', 'decoder.11.attn.k_norm.weight', 'decoder.11.attn.k_norm.bias', 'decoder.11.ls1.gamma', 'decoder.11.norm2.weight', 'decoder.11.norm2.bias', 'decoder.11.mlp.fc1.weight', 'decoder.11.mlp.fc1.bias', 'decoder.11.mlp.fc2.weight', 'decoder.11.mlp.fc2.bias', 'decoder.11.ls2.gamma', 'decoder.12.norm1.weight', 'decoder.12.norm1.bias', 'decoder.12.attn.qkv.weight', 'decoder.12.attn.qkv.bias', 'decoder.12.attn.proj.weight', 'decoder.12.attn.proj.bias', 'decoder.12.attn.q_norm.weight', 'decoder.12.attn.q_norm.bias', 'decoder.12.attn.k_norm.weight', 'decoder.12.attn.k_norm.bias', 'decoder.12.ls1.gamma', 'decoder.12.norm2.weight', 'decoder.12.norm2.bias', 'decoder.12.mlp.fc1.weight', 'decoder.12.mlp.fc1.bias', 'decoder.12.mlp.fc2.weight', 'decoder.12.mlp.fc2.bias', 'decoder.12.ls2.gamma', 'decoder.13.norm1.weight', 'decoder.13.norm1.bias', 'decoder.13.attn.qkv.weight', 'decoder.13.attn.qkv.bias', 'decoder.13.attn.proj.weight', 'decoder.13.attn.proj.bias', 'decoder.13.attn.q_norm.weight', 'decoder.13.attn.q_norm.bias', 'decoder.13.attn.k_norm.weight', 'decoder.13.attn.k_norm.bias', 'decoder.13.ls1.gamma', 'decoder.13.norm2.weight', 'decoder.13.norm2.bias', 'decoder.13.mlp.fc1.weight', 'decoder.13.mlp.fc1.bias', 'decoder.13.mlp.fc2.weight', 'decoder.13.mlp.fc2.bias', 'decoder.13.ls2.gamma', 'decoder.14.norm1.weight', 'decoder.14.norm1.bias', 'decoder.14.attn.qkv.weight', 'decoder.14.attn.qkv.bias', 'decoder.14.attn.proj.weight', 'decoder.14.attn.proj.bias', 'decoder.14.attn.q_norm.weight', 'decoder.14.attn.q_norm.bias', 'decoder.14.attn.k_norm.weight', 'decoder.14.attn.k_norm.bias', 
'decoder.14.ls1.gamma', 'decoder.14.norm2.weight', 'decoder.14.norm2.bias', 'decoder.14.mlp.fc1.weight', 'decoder.14.mlp.fc1.bias', 'decoder.14.mlp.fc2.weight', 'decoder.14.mlp.fc2.bias', 'decoder.14.ls2.gamma', 'decoder.15.norm1.weight', 'decoder.15.norm1.bias', 'decoder.15.attn.qkv.weight', 'decoder.15.attn.qkv.bias', 'decoder.15.attn.proj.weight', 'decoder.15.attn.proj.bias', 'decoder.15.attn.q_norm.weight', 'decoder.15.attn.q_norm.bias', 'decoder.15.attn.k_norm.weight', 'decoder.15.attn.k_norm.bias', 'decoder.15.ls1.gamma', 'decoder.15.norm2.weight', 'decoder.15.norm2.bias', 'decoder.15.mlp.fc1.weight', 'decoder.15.mlp.fc1.bias', 'decoder.15.mlp.fc2.weight', 'decoder.15.mlp.fc2.bias', 'decoder.15.ls2.gamma', 'decoder.16.norm1.weight', 'decoder.16.norm1.bias', 'decoder.16.attn.qkv.weight', 'decoder.16.attn.qkv.bias', 'decoder.16.attn.proj.weight', 'decoder.16.attn.proj.bias', 'decoder.16.attn.q_norm.weight', 'decoder.16.attn.q_norm.bias', 'decoder.16.attn.k_norm.weight', 'decoder.16.attn.k_norm.bias', 'decoder.16.ls1.gamma', 'decoder.16.norm2.weight', 'decoder.16.norm2.bias', 'decoder.16.mlp.fc1.weight', 'decoder.16.mlp.fc1.bias', 'decoder.16.mlp.fc2.weight', 'decoder.16.mlp.fc2.bias', 'decoder.16.ls2.gamma', 'decoder.17.norm1.weight', 'decoder.17.norm1.bias', 'decoder.17.attn.qkv.weight', 'decoder.17.attn.qkv.bias', 'decoder.17.attn.proj.weight', 'decoder.17.attn.proj.bias', 'decoder.17.attn.q_norm.weight', 'decoder.17.attn.q_norm.bias', 'decoder.17.attn.k_norm.weight', 'decoder.17.attn.k_norm.bias', 'decoder.17.ls1.gamma', 'decoder.17.norm2.weight', 'decoder.17.norm2.bias', 'decoder.17.mlp.fc1.weight', 'decoder.17.mlp.fc1.bias', 'decoder.17.mlp.fc2.weight', 'decoder.17.mlp.fc2.bias', 'decoder.17.ls2.gamma', 'decoder.18.norm1.weight', 'decoder.18.norm1.bias', 'decoder.18.attn.qkv.weight', 'decoder.18.attn.qkv.bias', 'decoder.18.attn.proj.weight', 'decoder.18.attn.proj.bias', 'decoder.18.attn.q_norm.weight', 'decoder.18.attn.q_norm.bias', 
'decoder.18.attn.k_norm.weight', 'decoder.18.attn.k_norm.bias', 'decoder.18.ls1.gamma', 'decoder.18.norm2.weight', 'decoder.18.norm2.bias', 'decoder.18.mlp.fc1.weight', 'decoder.18.mlp.fc1.bias', 'decoder.18.mlp.fc2.weight', 'decoder.18.mlp.fc2.bias', 'decoder.18.ls2.gamma', 'decoder.19.norm1.weight', 'decoder.19.norm1.bias', 'decoder.19.attn.qkv.weight', 'decoder.19.attn.qkv.bias', 'decoder.19.attn.proj.weight', 'decoder.19.attn.proj.bias', 'decoder.19.attn.q_norm.weight', 'decoder.19.attn.q_norm.bias', 'decoder.19.attn.k_norm.weight', 'decoder.19.attn.k_norm.bias', 'decoder.19.ls1.gamma', 'decoder.19.norm2.weight', 'decoder.19.norm2.bias', 'decoder.19.mlp.fc1.weight', 'decoder.19.mlp.fc1.bias', 'decoder.19.mlp.fc2.weight', 'decoder.19.mlp.fc2.bias', 'decoder.19.ls2.gamma', 'decoder.20.norm1.weight', 'decoder.20.norm1.bias', 'decoder.20.attn.qkv.weight', 'decoder.20.attn.qkv.bias', 'decoder.20.attn.proj.weight', 'decoder.20.attn.proj.bias', 'decoder.20.attn.q_norm.weight', 'decoder.20.attn.q_norm.bias', 'decoder.20.attn.k_norm.weight', 'decoder.20.attn.k_norm.bias', 'decoder.20.ls1.gamma', 'decoder.20.norm2.weight', 'decoder.20.norm2.bias', 'decoder.20.mlp.fc1.weight', 'decoder.20.mlp.fc1.bias', 'decoder.20.mlp.fc2.weight', 'decoder.20.mlp.fc2.bias', 'decoder.20.ls2.gamma', 'decoder.21.norm1.weight', 'decoder.21.norm1.bias', 'decoder.21.attn.qkv.weight', 'decoder.21.attn.qkv.bias', 'decoder.21.attn.proj.weight', 'decoder.21.attn.proj.bias', 'decoder.21.attn.q_norm.weight', 'decoder.21.attn.q_norm.bias', 'decoder.21.attn.k_norm.weight', 'decoder.21.attn.k_norm.bias', 'decoder.21.ls1.gamma', 'decoder.21.norm2.weight', 'decoder.21.norm2.bias', 'decoder.21.mlp.fc1.weight', 'decoder.21.mlp.fc1.bias', 'decoder.21.mlp.fc2.weight', 'decoder.21.mlp.fc2.bias', 'decoder.21.ls2.gamma', 'decoder.22.norm1.weight', 'decoder.22.norm1.bias', 'decoder.22.attn.qkv.weight', 'decoder.22.attn.qkv.bias', 'decoder.22.attn.proj.weight', 'decoder.22.attn.proj.bias', 
'decoder.22.attn.q_norm.weight', 'decoder.22.attn.q_norm.bias', 'decoder.22.attn.k_norm.weight', 'decoder.22.attn.k_norm.bias', 'decoder.22.ls1.gamma', 'decoder.22.norm2.weight', 'decoder.22.norm2.bias', 'decoder.22.mlp.fc1.weight', 'decoder.22.mlp.fc1.bias', 'decoder.22.mlp.fc2.weight', 'decoder.22.mlp.fc2.bias', 'decoder.22.ls2.gamma', 'decoder.23.norm1.weight', 'decoder.23.norm1.bias', 'decoder.23.attn.qkv.weight', 'decoder.23.attn.qkv.bias', 'decoder.23.attn.proj.weight', 'decoder.23.attn.proj.bias', 'decoder.23.attn.q_norm.weight', 'decoder.23.attn.q_norm.bias', 'decoder.23.attn.k_norm.weight', 'decoder.23.attn.k_norm.bias', 'decoder.23.ls1.gamma', 'decoder.23.norm2.weight', 'decoder.23.norm2.bias', 'decoder.23.mlp.fc1.weight', 'decoder.23.mlp.fc1.bias', 'decoder.23.mlp.fc2.weight', 'decoder.23.mlp.fc2.bias', 'decoder.23.ls2.gamma', 'decoder.24.norm1.weight', 'decoder.24.norm1.bias', 'decoder.24.attn.qkv.weight', 'decoder.24.attn.qkv.bias', 'decoder.24.attn.proj.weight', 'decoder.24.attn.proj.bias', 'decoder.24.attn.q_norm.weight', 'decoder.24.attn.q_norm.bias', 'decoder.24.attn.k_norm.weight', 'decoder.24.attn.k_norm.bias', 'decoder.24.ls1.gamma', 'decoder.24.norm2.weight', 'decoder.24.norm2.bias', 'decoder.24.mlp.fc1.weight', 'decoder.24.mlp.fc1.bias', 'decoder.24.mlp.fc2.weight', 'decoder.24.mlp.fc2.bias', 'decoder.24.ls2.gamma', 'decoder.25.norm1.weight', 'decoder.25.norm1.bias', 'decoder.25.attn.qkv.weight', 'decoder.25.attn.qkv.bias', 'decoder.25.attn.proj.weight', 'decoder.25.attn.proj.bias', 'decoder.25.attn.q_norm.weight', 'decoder.25.attn.q_norm.bias', 'decoder.25.attn.k_norm.weight', 'decoder.25.attn.k_norm.bias', 'decoder.25.ls1.gamma', 'decoder.25.norm2.weight', 'decoder.25.norm2.bias', 'decoder.25.mlp.fc1.weight', 'decoder.25.mlp.fc1.bias', 'decoder.25.mlp.fc2.weight', 'decoder.25.mlp.fc2.bias', 'decoder.25.ls2.gamma', 'decoder.26.norm1.weight', 'decoder.26.norm1.bias', 'decoder.26.attn.qkv.weight', 'decoder.26.attn.qkv.bias', 
'decoder.26.attn.proj.weight', 'decoder.26.attn.proj.bias', 'decoder.26.attn.q_norm.weight', 'decoder.26.attn.q_norm.bias', 'decoder.26.attn.k_norm.weight', 'decoder.26.attn.k_norm.bias', 'decoder.26.ls1.gamma', 'decoder.26.norm2.weight', 'decoder.26.norm2.bias', 'decoder.26.mlp.fc1.weight', 'decoder.26.mlp.fc1.bias', 'decoder.26.mlp.fc2.weight', 'decoder.26.mlp.fc2.bias', 'decoder.26.ls2.gamma', 'decoder.27.norm1.weight', 'decoder.27.norm1.bias', 'decoder.27.attn.qkv.weight', 'decoder.27.attn.qkv.bias', 'decoder.27.attn.proj.weight', 'decoder.27.attn.proj.bias', 'decoder.27.attn.q_norm.weight', 'decoder.27.attn.q_norm.bias', 'decoder.27.attn.k_norm.weight', 'decoder.27.attn.k_norm.bias', 'decoder.27.ls1.gamma', 'decoder.27.norm2.weight', 'decoder.27.norm2.bias', 'decoder.27.mlp.fc1.weight', 'decoder.27.mlp.fc1.bias', 'decoder.27.mlp.fc2.weight', 'decoder.27.mlp.fc2.bias', 'decoder.27.ls2.gamma', 'decoder.28.norm1.weight', 'decoder.28.norm1.bias', 'decoder.28.attn.qkv.weight', 'decoder.28.attn.qkv.bias', 'decoder.28.attn.proj.weight', 'decoder.28.attn.proj.bias', 'decoder.28.attn.q_norm.weight', 'decoder.28.attn.q_norm.bias', 'decoder.28.attn.k_norm.weight', 'decoder.28.attn.k_norm.bias', 'decoder.28.ls1.gamma', 'decoder.28.norm2.weight', 'decoder.28.norm2.bias', 'decoder.28.mlp.fc1.weight', 'decoder.28.mlp.fc1.bias', 'decoder.28.mlp.fc2.weight', 'decoder.28.mlp.fc2.bias', 'decoder.28.ls2.gamma', 'decoder.29.norm1.weight', 'decoder.29.norm1.bias', 'decoder.29.attn.qkv.weight', 'decoder.29.attn.qkv.bias', 'decoder.29.attn.proj.weight', 'decoder.29.attn.proj.bias', 'decoder.29.attn.q_norm.weight', 'decoder.29.attn.q_norm.bias', 'decoder.29.attn.k_norm.weight', 'decoder.29.attn.k_norm.bias', 'decoder.29.ls1.gamma', 'decoder.29.norm2.weight', 'decoder.29.norm2.bias', 'decoder.29.mlp.fc1.weight', 'decoder.29.mlp.fc1.bias', 'decoder.29.mlp.fc2.weight', 'decoder.29.mlp.fc2.bias', 'decoder.29.ls2.gamma', 'decoder.30.norm1.weight', 'decoder.30.norm1.bias', 
'decoder.30.attn.qkv.weight', 'decoder.30.attn.qkv.bias', 'decoder.30.attn.proj.weight', 'decoder.30.attn.proj.bias', 'decoder.30.attn.q_norm.weight', 'decoder.30.attn.q_norm.bias', 'decoder.30.attn.k_norm.weight', 'decoder.30.attn.k_norm.bias', 'decoder.30.ls1.gamma', 'decoder.30.norm2.weight', 'decoder.30.norm2.bias', 'decoder.30.mlp.fc1.weight', 'decoder.30.mlp.fc1.bias', 'decoder.30.mlp.fc2.weight', 'decoder.30.mlp.fc2.bias', 'decoder.30.ls2.gamma', 'decoder.31.norm1.weight', 'decoder.31.norm1.bias', 'decoder.31.attn.qkv.weight', 'decoder.31.attn.qkv.bias', 'decoder.31.attn.proj.weight', 'decoder.31.attn.proj.bias', 'decoder.31.attn.q_norm.weight', 'decoder.31.attn.q_norm.bias', 'decoder.31.attn.k_norm.weight', 'decoder.31.attn.k_norm.bias', 'decoder.31.ls1.gamma', 'decoder.31.norm2.weight', 'decoder.31.norm2.bias', 'decoder.31.mlp.fc1.weight', 'decoder.31.mlp.fc1.bias', 'decoder.31.mlp.fc2.weight', 'decoder.31.mlp.fc2.bias', 'decoder.31.ls2.gamma', 'decoder.32.norm1.weight', 'decoder.32.norm1.bias', 'decoder.32.attn.qkv.weight', 'decoder.32.attn.qkv.bias', 'decoder.32.attn.proj.weight', 'decoder.32.attn.proj.bias', 'decoder.32.attn.q_norm.weight', 'decoder.32.attn.q_norm.bias', 'decoder.32.attn.k_norm.weight', 'decoder.32.attn.k_norm.bias', 'decoder.32.ls1.gamma', 'decoder.32.norm2.weight', 'decoder.32.norm2.bias', 'decoder.32.mlp.fc1.weight', 'decoder.32.mlp.fc1.bias', 'decoder.32.mlp.fc2.weight', 'decoder.32.mlp.fc2.bias', 'decoder.32.ls2.gamma', 'decoder.33.norm1.weight', 'decoder.33.norm1.bias', 'decoder.33.attn.qkv.weight', 'decoder.33.attn.qkv.bias', 'decoder.33.attn.proj.weight', 'decoder.33.attn.proj.bias', 'decoder.33.attn.q_norm.weight', 'decoder.33.attn.q_norm.bias', 'decoder.33.attn.k_norm.weight', 'decoder.33.attn.k_norm.bias', 'decoder.33.ls1.gamma', 'decoder.33.norm2.weight', 'decoder.33.norm2.bias', 'decoder.33.mlp.fc1.weight', 'decoder.33.mlp.fc1.bias', 'decoder.33.mlp.fc2.weight', 'decoder.33.mlp.fc2.bias', 'decoder.33.ls2.gamma', 
'decoder.34.norm1.weight', 'decoder.34.norm1.bias', 'decoder.34.attn.qkv.weight', 'decoder.34.attn.qkv.bias', 'decoder.34.attn.proj.weight', 'decoder.34.attn.proj.bias', 'decoder.34.attn.q_norm.weight', 'decoder.34.attn.q_norm.bias', 'decoder.34.attn.k_norm.weight', 'decoder.34.attn.k_norm.bias', 'decoder.34.ls1.gamma', 'decoder.34.norm2.weight', 'decoder.34.norm2.bias', 'decoder.34.mlp.fc1.weight', 'decoder.34.mlp.fc1.bias', 'decoder.34.mlp.fc2.weight', 'decoder.34.mlp.fc2.bias', 'decoder.34.ls2.gamma', 'decoder.35.norm1.weight', 'decoder.35.norm1.bias', 'decoder.35.attn.qkv.weight', 'decoder.35.attn.qkv.bias', 'decoder.35.attn.proj.weight', 'decoder.35.attn.proj.bias', 'decoder.35.attn.q_norm.weight', 'decoder.35.attn.q_norm.bias', 'decoder.35.attn.k_norm.weight', 'decoder.35.attn.k_norm.bias', 'decoder.35.ls1.gamma', 'decoder.35.norm2.weight', 'decoder.35.norm2.bias', 'decoder.35.mlp.fc1.weight', 'decoder.35.mlp.fc1.bias', 'decoder.35.mlp.fc2.weight', 'decoder.35.mlp.fc2.bias', 'decoder.35.ls2.gamma', 'point_decoder.projects.weight', 'point_decoder.projects.bias', 'point_decoder.blocks.0.norm1.weight', 'point_decoder.blocks.0.norm1.bias', 'point_decoder.blocks.0.attn.qkv.weight', 'point_decoder.blocks.0.attn.qkv.bias', 'point_decoder.blocks.0.attn.proj.weight', 'point_decoder.blocks.0.attn.proj.bias', 'point_decoder.blocks.0.norm2.weight', 'point_decoder.blocks.0.norm2.bias', 'point_decoder.blocks.0.mlp.fc1.weight', 'point_decoder.blocks.0.mlp.fc1.bias', 'point_decoder.blocks.0.mlp.fc2.weight', 'point_decoder.blocks.0.mlp.fc2.bias', 'point_decoder.blocks.1.norm1.weight', 'point_decoder.blocks.1.norm1.bias', 'point_decoder.blocks.1.attn.qkv.weight', 'point_decoder.blocks.1.attn.qkv.bias', 'point_decoder.blocks.1.attn.proj.weight', 'point_decoder.blocks.1.attn.proj.bias', 'point_decoder.blocks.1.norm2.weight', 'point_decoder.blocks.1.norm2.bias', 'point_decoder.blocks.1.mlp.fc1.weight', 'point_decoder.blocks.1.mlp.fc1.bias', 
'point_decoder.blocks.1.mlp.fc2.weight', 'point_decoder.blocks.1.mlp.fc2.bias', 'point_decoder.blocks.2.norm1.weight', 'point_decoder.blocks.2.norm1.bias', 'point_decoder.blocks.2.attn.qkv.weight', 'point_decoder.blocks.2.attn.qkv.bias', 'point_decoder.blocks.2.attn.proj.weight', 'point_decoder.blocks.2.attn.proj.bias', 'point_decoder.blocks.2.norm2.weight', 'point_decoder.blocks.2.norm2.bias', 'point_decoder.blocks.2.mlp.fc1.weight', 'point_decoder.blocks.2.mlp.fc1.bias', 'point_decoder.blocks.2.mlp.fc2.weight', 'point_decoder.blocks.2.mlp.fc2.bias', 'point_decoder.blocks.3.norm1.weight', 'point_decoder.blocks.3.norm1.bias', 'point_decoder.blocks.3.attn.qkv.weight', 'point_decoder.blocks.3.attn.qkv.bias', 'point_decoder.blocks.3.attn.proj.weight', 'point_decoder.blocks.3.attn.proj.bias', 'point_decoder.blocks.3.norm2.weight', 'point_decoder.blocks.3.norm2.bias', 'point_decoder.blocks.3.mlp.fc1.weight', 'point_decoder.blocks.3.mlp.fc1.bias', 'point_decoder.blocks.3.mlp.fc2.weight', 'point_decoder.blocks.3.mlp.fc2.bias', 'point_decoder.blocks.4.norm1.weight', 'point_decoder.blocks.4.norm1.bias', 'point_decoder.blocks.4.attn.qkv.weight', 'point_decoder.blocks.4.attn.qkv.bias', 'point_decoder.blocks.4.attn.proj.weight', 'point_decoder.blocks.4.attn.proj.bias', 'point_decoder.blocks.4.norm2.weight', 'point_decoder.blocks.4.norm2.bias', 'point_decoder.blocks.4.mlp.fc1.weight', 'point_decoder.blocks.4.mlp.fc1.bias', 'point_decoder.blocks.4.mlp.fc2.weight', 'point_decoder.blocks.4.mlp.fc2.bias', 'point_decoder.linear_out.weight', 'point_decoder.linear_out.bias', 'point_head.proj.weight', 'point_head.proj.bias', 'conf_decoder.projects.weight', 'conf_decoder.projects.bias', 'conf_decoder.blocks.0.norm1.weight', 'conf_decoder.blocks.0.norm1.bias', 'conf_decoder.blocks.0.attn.qkv.weight', 'conf_decoder.blocks.0.attn.qkv.bias', 'conf_decoder.blocks.0.attn.proj.weight', 'conf_decoder.blocks.0.attn.proj.bias', 'conf_decoder.blocks.0.norm2.weight', 
'conf_decoder.blocks.0.norm2.bias', 'conf_decoder.blocks.0.mlp.fc1.weight', 'conf_decoder.blocks.0.mlp.fc1.bias', 'conf_decoder.blocks.0.mlp.fc2.weight', 'conf_decoder.blocks.0.mlp.fc2.bias', 'conf_decoder.blocks.1.norm1.weight', 'conf_decoder.blocks.1.norm1.bias', 'conf_decoder.blocks.1.attn.qkv.weight', 'conf_decoder.blocks.1.attn.qkv.bias', 'conf_decoder.blocks.1.attn.proj.weight', 'conf_decoder.blocks.1.attn.proj.bias', 'conf_decoder.blocks.1.norm2.weight', 'conf_decoder.blocks.1.norm2.bias', 'conf_decoder.blocks.1.mlp.fc1.weight', 'conf_decoder.blocks.1.mlp.fc1.bias', 'conf_decoder.blocks.1.mlp.fc2.weight', 'conf_decoder.blocks.1.mlp.fc2.bias', 'conf_decoder.blocks.2.norm1.weight', 'conf_decoder.blocks.2.norm1.bias', 'conf_decoder.blocks.2.attn.qkv.weight', 'conf_decoder.blocks.2.attn.qkv.bias', 'conf_decoder.blocks.2.attn.proj.weight', 'conf_decoder.blocks.2.attn.proj.bias', 'conf_decoder.blocks.2.norm2.weight', 'conf_decoder.blocks.2.norm2.bias', 'conf_decoder.blocks.2.mlp.fc1.weight', 'conf_decoder.blocks.2.mlp.fc1.bias', 'conf_decoder.blocks.2.mlp.fc2.weight', 'conf_decoder.blocks.2.mlp.fc2.bias', 'conf_decoder.blocks.3.norm1.weight', 'conf_decoder.blocks.3.norm1.bias', 'conf_decoder.blocks.3.attn.qkv.weight', 'conf_decoder.blocks.3.attn.qkv.bias', 'conf_decoder.blocks.3.attn.proj.weight', 'conf_decoder.blocks.3.attn.proj.bias', 'conf_decoder.blocks.3.norm2.weight', 'conf_decoder.blocks.3.norm2.bias', 'conf_decoder.blocks.3.mlp.fc1.weight', 'conf_decoder.blocks.3.mlp.fc1.bias', 'conf_decoder.blocks.3.mlp.fc2.weight', 'conf_decoder.blocks.3.mlp.fc2.bias', 'conf_decoder.blocks.4.norm1.weight', 'conf_decoder.blocks.4.norm1.bias', 'conf_decoder.blocks.4.attn.qkv.weight', 'conf_decoder.blocks.4.attn.qkv.bias', 'conf_decoder.blocks.4.attn.proj.weight', 'conf_decoder.blocks.4.attn.proj.bias', 'conf_decoder.blocks.4.norm2.weight', 'conf_decoder.blocks.4.norm2.bias', 'conf_decoder.blocks.4.mlp.fc1.weight', 'conf_decoder.blocks.4.mlp.fc1.bias', 
'conf_decoder.blocks.4.mlp.fc2.weight', 'conf_decoder.blocks.4.mlp.fc2.bias', 'conf_decoder.linear_out.weight', 'conf_decoder.linear_out.bias', 'conf_head.proj.weight', 'conf_head.proj.bias', 'camera_decoder.projects.weight', 'camera_decoder.projects.bias', 'camera_decoder.blocks.0.norm1.weight', 'camera_decoder.blocks.0.norm1.bias', 'camera_decoder.blocks.0.attn.qkv.weight', 'camera_decoder.blocks.0.attn.qkv.bias', 'camera_decoder.blocks.0.attn.proj.weight', 'camera_decoder.blocks.0.attn.proj.bias', 'camera_decoder.blocks.0.norm2.weight', 'camera_decoder.blocks.0.norm2.bias', 'camera_decoder.blocks.0.mlp.fc1.weight', 'camera_decoder.blocks.0.mlp.fc1.bias', 'camera_decoder.blocks.0.mlp.fc2.weight', 'camera_decoder.blocks.0.mlp.fc2.bias', 'camera_decoder.blocks.1.norm1.weight', 'camera_decoder.blocks.1.norm1.bias', 'camera_decoder.blocks.1.attn.qkv.weight', 'camera_decoder.blocks.1.attn.qkv.bias', 'camera_decoder.blocks.1.attn.proj.weight', 'camera_decoder.blocks.1.attn.proj.bias', 'camera_decoder.blocks.1.norm2.weight', 'camera_decoder.blocks.1.norm2.bias', 'camera_decoder.blocks.1.mlp.fc1.weight', 'camera_decoder.blocks.1.mlp.fc1.bias', 'camera_decoder.blocks.1.mlp.fc2.weight', 'camera_decoder.blocks.1.mlp.fc2.bias', 'camera_decoder.blocks.2.norm1.weight', 'camera_decoder.blocks.2.norm1.bias', 'camera_decoder.blocks.2.attn.qkv.weight', 'camera_decoder.blocks.2.attn.qkv.bias', 'camera_decoder.blocks.2.attn.proj.weight', 'camera_decoder.blocks.2.attn.proj.bias', 'camera_decoder.blocks.2.norm2.weight', 'camera_decoder.blocks.2.norm2.bias', 'camera_decoder.blocks.2.mlp.fc1.weight', 'camera_decoder.blocks.2.mlp.fc1.bias', 'camera_decoder.blocks.2.mlp.fc2.weight', 'camera_decoder.blocks.2.mlp.fc2.bias', 'camera_decoder.blocks.3.norm1.weight', 'camera_decoder.blocks.3.norm1.bias', 'camera_decoder.blocks.3.attn.qkv.weight', 'camera_decoder.blocks.3.attn.qkv.bias', 'camera_decoder.blocks.3.attn.proj.weight', 'camera_decoder.blocks.3.attn.proj.bias', 
'camera_decoder.blocks.3.norm2.weight', 'camera_decoder.blocks.3.norm2.bias', 'camera_decoder.blocks.3.mlp.fc1.weight', 'camera_decoder.blocks.3.mlp.fc1.bias', 'camera_decoder.blocks.3.mlp.fc2.weight', 'camera_decoder.blocks.3.mlp.fc2.bias', 'camera_decoder.blocks.4.norm1.weight', 'camera_decoder.blocks.4.norm1.bias', 'camera_decoder.blocks.4.attn.qkv.weight', 'camera_decoder.blocks.4.attn.qkv.bias', 'camera_decoder.blocks.4.attn.proj.weight', 'camera_decoder.blocks.4.attn.proj.bias', 'camera_decoder.blocks.4.norm2.weight', 'camera_decoder.blocks.4.norm2.bias', 'camera_decoder.blocks.4.mlp.fc1.weight', 'camera_decoder.blocks.4.mlp.fc1.bias', 'camera_decoder.blocks.4.mlp.fc2.weight', 'camera_decoder.blocks.4.mlp.fc2.bias', 'camera_decoder.linear_out.weight', 'camera_decoder.linear_out.bias', 'camera_head.res_conv.0.res_conv1.weight', 'camera_head.res_conv.0.res_conv1.bias', 'camera_head.res_conv.0.res_conv2.weight', 'camera_head.res_conv.0.res_conv2.bias', 'camera_head.res_conv.0.res_conv3.weight', 'camera_head.res_conv.0.res_conv3.bias', 'camera_head.res_conv.1.res_conv1.weight', 'camera_head.res_conv.1.res_conv1.bias', 'camera_head.res_conv.1.res_conv2.weight', 'camera_head.res_conv.1.res_conv2.bias', 'camera_head.res_conv.1.res_conv3.weight', 'camera_head.res_conv.1.res_conv3.bias', 'camera_head.more_mlps.0.weight', 'camera_head.more_mlps.0.bias', 'camera_head.more_mlps.2.weight', 'camera_head.more_mlps.2.bias', 'camera_head.fc_t.weight', 'camera_head.fc_t.bias', 'camera_head.fc_rot.weight', 'camera_head.fc_rot.bias'], unexpected_keys=['module.register_token', 'module.image_mean', 'module.image_std', 'module.encoder.cls_token', 'module.encoder.pos_embed', 'module.encoder.register_tokens', 'module.encoder.patch_embed.proj.weight', 'module.encoder.patch_embed.proj.bias', 'module.encoder.blocks.0.norm1.weight', 'module.encoder.blocks.0.norm1.bias', 'module.encoder.blocks.0.attn.qkv.weight', 'module.encoder.blocks.0.attn.qkv.bias', 
'module.encoder.blocks.0.attn.proj.weight', 'module.encoder.blocks.0.attn.proj.bias', 'module.encoder.blocks.0.ls1.gamma', 'module.encoder.blocks.0.norm2.weight', 'module.encoder.blocks.0.norm2.bias', 'module.encoder.blocks.0.mlp.fc1.weight', 'module.encoder.blocks.0.mlp.fc1.bias', 'module.encoder.blocks.0.mlp.fc2.weight', 'module.encoder.blocks.0.mlp.fc2.bias', 'module.encoder.blocks.0.ls2.gamma', 'module.encoder.blocks.1.norm1.weight', 'module.encoder.blocks.1.norm1.bias', 'module.encoder.blocks.1.attn.qkv.weight', 'module.encoder.blocks.1.attn.qkv.bias', 'module.encoder.blocks.1.attn.proj.weight', 'module.encoder.blocks.1.attn.proj.bias', 'module.encoder.blocks.1.ls1.gamma', 'module.encoder.blocks.1.norm2.weight', 'module.encoder.blocks.1.norm2.bias', 'module.encoder.blocks.1.mlp.fc1.weight', 'module.encoder.blocks.1.mlp.fc1.bias', 'module.encoder.blocks.1.mlp.fc2.weight', 'module.encoder.blocks.1.mlp.fc2.bias', 'module.encoder.blocks.1.ls2.gamma', 'module.encoder.blocks.2.norm1.weight', 'module.encoder.blocks.2.norm1.bias', 'module.encoder.blocks.2.attn.qkv.weight', 'module.encoder.blocks.2.attn.qkv.bias', 'module.encoder.blocks.2.attn.proj.weight', 'module.encoder.blocks.2.attn.proj.bias', 'module.encoder.blocks.2.ls1.gamma', 'module.encoder.blocks.2.norm2.weight', 'module.encoder.blocks.2.norm2.bias', 'module.encoder.blocks.2.mlp.fc1.weight', 'module.encoder.blocks.2.mlp.fc1.bias', 'module.encoder.blocks.2.mlp.fc2.weight', 'module.encoder.blocks.2.mlp.fc2.bias', 'module.encoder.blocks.2.ls2.gamma', 'module.encoder.blocks.3.norm1.weight', 'module.encoder.blocks.3.norm1.bias', 'module.encoder.blocks.3.attn.qkv.weight', 'module.encoder.blocks.3.attn.qkv.bias', 'module.encoder.blocks.3.attn.proj.weight', 'module.encoder.blocks.3.attn.proj.bias', 'module.encoder.blocks.3.ls1.gamma', 'module.encoder.blocks.3.norm2.weight', 'module.encoder.blocks.3.norm2.bias', 'module.encoder.blocks.3.mlp.fc1.weight', 'module.encoder.blocks.3.mlp.fc1.bias', 
'module.encoder.blocks.3.mlp.fc2.weight', 'module.encoder.blocks.3.mlp.fc2.bias', 'module.encoder.blocks.3.ls2.gamma', 'module.encoder.blocks.4.norm1.weight', 'module.encoder.blocks.4.norm1.bias', 'module.encoder.blocks.4.attn.qkv.weight', 'module.encoder.blocks.4.attn.qkv.bias', 'module.encoder.blocks.4.attn.proj.weight', 'module.encoder.blocks.4.attn.proj.bias', 'module.encoder.blocks.4.ls1.gamma', 'module.encoder.blocks.4.norm2.weight', 'module.encoder.blocks.4.norm2.bias', 'module.encoder.blocks.4.mlp.fc1.weight', 'module.encoder.blocks.4.mlp.fc1.bias', 'module.encoder.blocks.4.mlp.fc2.weight', 'module.encoder.blocks.4.mlp.fc2.bias', 'module.encoder.blocks.4.ls2.gamma', 'module.encoder.blocks.5.norm1.weight', 'module.encoder.blocks.5.norm1.bias', 'module.encoder.blocks.5.attn.qkv.weight', 'module.encoder.blocks.5.attn.qkv.bias', 'module.encoder.blocks.5.attn.proj.weight', 'module.encoder.blocks.5.attn.proj.bias', 'module.encoder.blocks.5.ls1.gamma', 'module.encoder.blocks.5.norm2.weight', 'module.encoder.blocks.5.norm2.bias', 'module.encoder.blocks.5.mlp.fc1.weight', 'module.encoder.blocks.5.mlp.fc1.bias', 'module.encoder.blocks.5.mlp.fc2.weight', 'module.encoder.blocks.5.mlp.fc2.bias', 'module.encoder.blocks.5.ls2.gamma', 'module.encoder.blocks.6.norm1.weight', 'module.encoder.blocks.6.norm1.bias', 'module.encoder.blocks.6.attn.qkv.weight', 'module.encoder.blocks.6.attn.qkv.bias', 'module.encoder.blocks.6.attn.proj.weight', 'module.encoder.blocks.6.attn.proj.bias', 'module.encoder.blocks.6.ls1.gamma', 'module.encoder.blocks.6.norm2.weight', 'module.encoder.blocks.6.norm2.bias', 'module.encoder.blocks.6.mlp.fc1.weight', 'module.encoder.blocks.6.mlp.fc1.bias', 'module.encoder.blocks.6.mlp.fc2.weight', 'module.encoder.blocks.6.mlp.fc2.bias', 'module.encoder.blocks.6.ls2.gamma', 'module.encoder.blocks.7.norm1.weight', 'module.encoder.blocks.7.norm1.bias', 'module.encoder.blocks.7.attn.qkv.weight', 'module.encoder.blocks.7.attn.qkv.bias', 
'module.encoder.blocks.7.attn.proj.weight', 'module.encoder.blocks.7.attn.proj.bias', 'module.encoder.blocks.7.ls1.gamma', 'module.encoder.blocks.7.norm2.weight', 'module.encoder.blocks.7.norm2.bias', 'module.encoder.blocks.7.mlp.fc1.weight', 'module.encoder.blocks.7.mlp.fc1.bias', 'module.encoder.blocks.7.mlp.fc2.weight', 'module.encoder.blocks.7.mlp.fc2.bias', 'module.encoder.blocks.7.ls2.gamma', 'module.encoder.blocks.8.norm1.weight', 'module.encoder.blocks.8.norm1.bias', 'module.encoder.blocks.8.attn.qkv.weight', 'module.encoder.blocks.8.attn.qkv.bias', 'module.encoder.blocks.8.attn.proj.weight', 'module.encoder.blocks.8.attn.proj.bias', 'module.encoder.blocks.8.ls1.gamma', 'module.encoder.blocks.8.norm2.weight', 'module.encoder.blocks.8.norm2.bias', 'module.encoder.blocks.8.mlp.fc1.weight', 'module.encoder.blocks.8.mlp.fc1.bias', 'module.encoder.blocks.8.mlp.fc2.weight', 'module.encoder.blocks.8.mlp.fc2.bias', 'module.encoder.blocks.8.ls2.gamma', 'module.encoder.blocks.9.norm1.weight', 'module.encoder.blocks.9.norm1.bias', 'module.encoder.blocks.9.attn.qkv.weight', 'module.encoder.blocks.9.attn.qkv.bias', 'module.encoder.blocks.9.attn.proj.weight', 'module.encoder.blocks.9.attn.proj.bias', 'module.encoder.blocks.9.ls1.gamma', 'module.encoder.blocks.9.norm2.weight', 'module.encoder.blocks.9.norm2.bias', 'module.encoder.blocks.9.mlp.fc1.weight', 'module.encoder.blocks.9.mlp.fc1.bias', 'module.encoder.blocks.9.mlp.fc2.weight', 'module.encoder.blocks.9.mlp.fc2.bias', 'module.encoder.blocks.9.ls2.gamma', 'module.encoder.blocks.10.norm1.weight', 'module.encoder.blocks.10.norm1.bias', 'module.encoder.blocks.10.attn.qkv.weight', 'module.encoder.blocks.10.attn.qkv.bias', 'module.encoder.blocks.10.attn.proj.weight', 'module.encoder.blocks.10.attn.proj.bias', 'module.encoder.blocks.10.ls1.gamma', 'module.encoder.blocks.10.norm2.weight', 'module.encoder.blocks.10.norm2.bias', 'module.encoder.blocks.10.mlp.fc1.weight', 'module.encoder.blocks.10.mlp.fc1.bias', 
'module.encoder.blocks.10.mlp.fc2.weight', 'module.encoder.blocks.10.mlp.fc2.bias', 'module.encoder.blocks.10.ls2.gamma', 'module.encoder.blocks.11.norm1.weight', 'module.encoder.blocks.11.norm1.bias', 'module.encoder.blocks.11.attn.qkv.weight', 'module.encoder.blocks.11.attn.qkv.bias', 'module.encoder.blocks.11.attn.proj.weight', 'module.encoder.blocks.11.attn.proj.bias', 'module.encoder.blocks.11.ls1.gamma', 'module.encoder.blocks.11.norm2.weight', 'module.encoder.blocks.11.norm2.bias', 'module.encoder.blocks.11.mlp.fc1.weight', 'module.encoder.blocks.11.mlp.fc1.bias', 'module.encoder.blocks.11.mlp.fc2.weight', 'module.encoder.blocks.11.mlp.fc2.bias', 'module.encoder.blocks.11.ls2.gamma', 'module.encoder.blocks.12.norm1.weight', 'module.encoder.blocks.12.norm1.bias', 'module.encoder.blocks.12.attn.qkv.weight', 'module.encoder.blocks.12.attn.qkv.bias', 'module.encoder.blocks.12.attn.proj.weight', 'module.encoder.blocks.12.attn.proj.bias', 'module.encoder.blocks.12.ls1.gamma', 'module.encoder.blocks.12.norm2.weight', 'module.encoder.blocks.12.norm2.bias', 'module.encoder.blocks.12.mlp.fc1.weight', 'module.encoder.blocks.12.mlp.fc1.bias', 'module.encoder.blocks.12.mlp.fc2.weight', 'module.encoder.blocks.12.mlp.fc2.bias', 'module.encoder.blocks.12.ls2.gamma', 'module.encoder.blocks.13.norm1.weight', 'module.encoder.blocks.13.norm1.bias', 'module.encoder.blocks.13.attn.qkv.weight', 'module.encoder.blocks.13.attn.qkv.bias', 'module.encoder.blocks.13.attn.proj.weight', 'module.encoder.blocks.13.attn.proj.bias', 'module.encoder.blocks.13.ls1.gamma', 'module.encoder.blocks.13.norm2.weight', 'module.encoder.blocks.13.norm2.bias', 'module.encoder.blocks.13.mlp.fc1.weight', 'module.encoder.blocks.13.mlp.fc1.bias', 'module.encoder.blocks.13.mlp.fc2.weight', 'module.encoder.blocks.13.mlp.fc2.bias', 'module.encoder.blocks.13.ls2.gamma', 'module.encoder.blocks.14.norm1.weight', 'module.encoder.blocks.14.norm1.bias', 'module.encoder.blocks.14.attn.qkv.weight', 
'module.encoder.blocks.14.attn.qkv.bias', 'module.encoder.blocks.14.attn.proj.weight', 'module.encoder.blocks.14.attn.proj.bias', 'module.encoder.blocks.14.ls1.gamma', 'module.encoder.blocks.14.norm2.weight', 'module.encoder.blocks.14.norm2.bias', 'module.encoder.blocks.14.mlp.fc1.weight', 'module.encoder.blocks.14.mlp.fc1.bias', 'module.encoder.blocks.14.mlp.fc2.weight', 'module.encoder.blocks.14.mlp.fc2.bias', 'module.encoder.blocks.14.ls2.gamma', 'module.encoder.blocks.15.norm1.weight', 'module.encoder.blocks.15.norm1.bias', 'module.encoder.blocks.15.attn.qkv.weight', 'module.encoder.blocks.15.attn.qkv.bias', 'module.encoder.blocks.15.attn.proj.weight', 'module.encoder.blocks.15.attn.proj.bias', 'module.encoder.blocks.15.ls1.gamma', 'module.encoder.blocks.15.norm2.weight', 'module.encoder.blocks.15.norm2.bias', 'module.encoder.blocks.15.mlp.fc1.weight', 'module.encoder.blocks.15.mlp.fc1.bias', 'module.encoder.blocks.15.mlp.fc2.weight', 'module.encoder.blocks.15.mlp.fc2.bias', 'module.encoder.blocks.15.ls2.gamma', 'module.encoder.blocks.16.norm1.weight', 'module.encoder.blocks.16.norm1.bias', 'module.encoder.blocks.16.attn.qkv.weight', 'module.encoder.blocks.16.attn.qkv.bias', 'module.encoder.blocks.16.attn.proj.weight', 'module.encoder.blocks.16.attn.proj.bias', 'module.encoder.blocks.16.ls1.gamma', 'module.encoder.blocks.16.norm2.weight', 'module.encoder.blocks.16.norm2.bias', 'module.encoder.blocks.16.mlp.fc1.weight', 'module.encoder.blocks.16.mlp.fc1.bias', 'module.encoder.blocks.16.mlp.fc2.weight', 'module.encoder.blocks.16.mlp.fc2.bias', 'module.encoder.blocks.16.ls2.gamma', 'module.encoder.blocks.17.norm1.weight', 'module.encoder.blocks.17.norm1.bias', 'module.encoder.blocks.17.attn.qkv.weight', 'module.encoder.blocks.17.attn.qkv.bias', 'module.encoder.blocks.17.attn.proj.weight', 'module.encoder.blocks.17.attn.proj.bias', 'module.encoder.blocks.17.ls1.gamma', 'module.encoder.blocks.17.norm2.weight', 'module.encoder.blocks.17.norm2.bias', 
'module.encoder.blocks.17.mlp.fc1.weight', 'module.encoder.blocks.17.mlp.fc1.bias', 'module.encoder.blocks.17.mlp.fc2.weight', 'module.encoder.blocks.17.mlp.fc2.bias', 'module.encoder.blocks.17.ls2.gamma', 'module.encoder.blocks.18.norm1.weight', 'module.encoder.blocks.18.norm1.bias', 'module.encoder.blocks.18.attn.qkv.weight', 'module.encoder.blocks.18.attn.qkv.bias', 'module.encoder.blocks.18.attn.proj.weight', 'module.encoder.blocks.18.attn.proj.bias', 'module.encoder.blocks.18.ls1.gamma', 'module.encoder.blocks.18.norm2.weight', 'module.encoder.blocks.18.norm2.bias', 'module.encoder.blocks.18.mlp.fc1.weight', 'module.encoder.blocks.18.mlp.fc1.bias', 'module.encoder.blocks.18.mlp.fc2.weight', 'module.encoder.blocks.18.mlp.fc2.bias', 'module.encoder.blocks.18.ls2.gamma', 'module.encoder.blocks.19.norm1.weight', 'module.encoder.blocks.19.norm1.bias', 'module.encoder.blocks.19.attn.qkv.weight', 'module.encoder.blocks.19.attn.qkv.bias', 'module.encoder.blocks.19.attn.proj.weight', 'module.encoder.blocks.19.attn.proj.bias', 'module.encoder.blocks.19.ls1.gamma', 'module.encoder.blocks.19.norm2.weight', 'module.encoder.blocks.19.norm2.bias', 'module.encoder.blocks.19.mlp.fc1.weight', 'module.encoder.blocks.19.mlp.fc1.bias', 'module.encoder.blocks.19.mlp.fc2.weight', 'module.encoder.blocks.19.mlp.fc2.bias', 'module.encoder.blocks.19.ls2.gamma', 'module.encoder.blocks.20.norm1.weight', 'module.encoder.blocks.20.norm1.bias', 'module.encoder.blocks.20.attn.qkv.weight', 'module.encoder.blocks.20.attn.qkv.bias', 'module.encoder.blocks.20.attn.proj.weight', 'module.encoder.blocks.20.attn.proj.bias', 'module.encoder.blocks.20.ls1.gamma', 'module.encoder.blocks.20.norm2.weight', 'module.encoder.blocks.20.norm2.bias', 'module.encoder.blocks.20.mlp.fc1.weight', 'module.encoder.blocks.20.mlp.fc1.bias', 'module.encoder.blocks.20.mlp.fc2.weight', 'module.encoder.blocks.20.mlp.fc2.bias', 'module.encoder.blocks.20.ls2.gamma', 'module.encoder.blocks.21.norm1.weight', 
'module.encoder.blocks.21.norm1.bias', 'module.encoder.blocks.21.attn.qkv.weight', 'module.encoder.blocks.21.attn.qkv.bias', 'module.encoder.blocks.21.attn.proj.weight', 'module.encoder.blocks.21.attn.proj.bias', 'module.encoder.blocks.21.ls1.gamma', 'module.encoder.blocks.21.norm2.weight', 'module.encoder.blocks.21.norm2.bias', 'module.encoder.blocks.21.mlp.fc1.weight', 'module.encoder.blocks.21.mlp.fc1.bias', 'module.encoder.blocks.21.mlp.fc2.weight', 'module.encoder.blocks.21.mlp.fc2.bias', 'module.encoder.blocks.21.ls2.gamma', 'module.encoder.blocks.22.norm1.weight', 'module.encoder.blocks.22.norm1.bias', 'module.encoder.blocks.22.attn.qkv.weight', 'module.encoder.blocks.22.attn.qkv.bias', 'module.encoder.blocks.22.attn.proj.weight', 'module.encoder.blocks.22.attn.proj.bias', 'module.encoder.blocks.22.ls1.gamma', 'module.encoder.blocks.22.norm2.weight', 'module.encoder.blocks.22.norm2.bias', 'module.encoder.blocks.22.mlp.fc1.weight', 'module.encoder.blocks.22.mlp.fc1.bias', 'module.encoder.blocks.22.mlp.fc2.weight', 'module.encoder.blocks.22.mlp.fc2.bias', 'module.encoder.blocks.22.ls2.gamma', 'module.encoder.blocks.23.norm1.weight', 'module.encoder.blocks.23.norm1.bias', 'module.encoder.blocks.23.attn.qkv.weight', 'module.encoder.blocks.23.attn.qkv.bias', 'module.encoder.blocks.23.attn.proj.weight', 'module.encoder.blocks.23.attn.proj.bias', 'module.encoder.blocks.23.ls1.gamma', 'module.encoder.blocks.23.norm2.weight', 'module.encoder.blocks.23.norm2.bias', 'module.encoder.blocks.23.mlp.fc1.weight', 'module.encoder.blocks.23.mlp.fc1.bias', 'module.encoder.blocks.23.mlp.fc2.weight', 'module.encoder.blocks.23.mlp.fc2.bias', 'module.encoder.blocks.23.ls2.gamma', 'module.encoder.norm.weight', 'module.encoder.norm.bias', 'module.decoder.0.norm1.weight', 'module.decoder.0.norm1.bias', 'module.decoder.0.attn.qkv.weight', 'module.decoder.0.attn.qkv.bias', 'module.decoder.0.attn.proj.weight', 'module.decoder.0.attn.proj.bias', 'module.decoder.0.attn.q_norm.weight', 
'module.decoder.0.attn.q_norm.bias', 'module.decoder.0.attn.k_norm.weight', 'module.decoder.0.attn.k_norm.bias', 'module.decoder.0.ls1.gamma', 'module.decoder.0.norm2.weight', 'module.decoder.0.norm2.bias', 'module.decoder.0.mlp.fc1.weight', 'module.decoder.0.mlp.fc1.bias', 'module.decoder.0.mlp.fc2.weight', 'module.decoder.0.mlp.fc2.bias', 'module.decoder.0.ls2.gamma', 'module.decoder.1.norm1.weight', 'module.decoder.1.norm1.bias', 'module.decoder.1.attn.qkv.weight', 'module.decoder.1.attn.qkv.bias', 'module.decoder.1.attn.proj.weight', 'module.decoder.1.attn.proj.bias', 'module.decoder.1.attn.q_norm.weight', 'module.decoder.1.attn.q_norm.bias', 'module.decoder.1.attn.k_norm.weight', 'module.decoder.1.attn.k_norm.bias', 'module.decoder.1.ls1.gamma', 'module.decoder.1.norm2.weight', 'module.decoder.1.norm2.bias', 'module.decoder.1.mlp.fc1.weight', 'module.decoder.1.mlp.fc1.bias', 'module.decoder.1.mlp.fc2.weight', 'module.decoder.1.mlp.fc2.bias', 'module.decoder.1.ls2.gamma', 'module.decoder.2.norm1.weight', 'module.decoder.2.norm1.bias', 'module.decoder.2.attn.qkv.weight', 'module.decoder.2.attn.qkv.bias', 'module.decoder.2.attn.proj.weight', 'module.decoder.2.attn.proj.bias', 'module.decoder.2.attn.q_norm.weight', 'module.decoder.2.attn.q_norm.bias', 'module.decoder.2.attn.k_norm.weight', 'module.decoder.2.attn.k_norm.bias', 'module.decoder.2.ls1.gamma', 'module.decoder.2.norm2.weight', 'module.decoder.2.norm2.bias', 'module.decoder.2.mlp.fc1.weight', 'module.decoder.2.mlp.fc1.bias', 'module.decoder.2.mlp.fc2.weight', 'module.decoder.2.mlp.fc2.bias', 'module.decoder.2.ls2.gamma', 'module.decoder.3.norm1.weight', 'module.decoder.3.norm1.bias', 'module.decoder.3.attn.qkv.weight', 'module.decoder.3.attn.qkv.bias', 'module.decoder.3.attn.proj.weight', 'module.decoder.3.attn.proj.bias', 'module.decoder.3.attn.q_norm.weight', 'module.decoder.3.attn.q_norm.bias', 'module.decoder.3.attn.k_norm.weight', 'module.decoder.3.attn.k_norm.bias', 'module.decoder.3.ls1.gamma', 
'module.decoder.3.norm2.weight', 'module.decoder.3.norm2.bias', 'module.decoder.3.mlp.fc1.weight', 'module.decoder.3.mlp.fc1.bias', 'module.decoder.3.mlp.fc2.weight', 'module.decoder.3.mlp.fc2.bias', 'module.decoder.3.ls2.gamma', 'module.decoder.4.norm1.weight', 'module.decoder.4.norm1.bias', 'module.decoder.4.attn.qkv.weight', 'module.decoder.4.attn.qkv.bias', 'module.decoder.4.attn.proj.weight', 'module.decoder.4.attn.proj.bias', 'module.decoder.4.attn.q_norm.weight', 'module.decoder.4.attn.q_norm.bias', 'module.decoder.4.attn.k_norm.weight', 'module.decoder.4.attn.k_norm.bias', 'module.decoder.4.ls1.gamma', 'module.decoder.4.norm2.weight', 'module.decoder.4.norm2.bias', 'module.decoder.4.mlp.fc1.weight', 'module.decoder.4.mlp.fc1.bias', 'module.decoder.4.mlp.fc2.weight', 'module.decoder.4.mlp.fc2.bias', 'module.decoder.4.ls2.gamma', 'module.decoder.5.norm1.weight', 'module.decoder.5.norm1.bias', 'module.decoder.5.attn.qkv.weight', 'module.decoder.5.attn.qkv.bias', 'module.decoder.5.attn.proj.weight', 'module.decoder.5.attn.proj.bias', 'module.decoder.5.attn.q_norm.weight', 'module.decoder.5.attn.q_norm.bias', 'module.decoder.5.attn.k_norm.weight', 'module.decoder.5.attn.k_norm.bias', 'module.decoder.5.ls1.gamma', 'module.decoder.5.norm2.weight', 'module.decoder.5.norm2.bias', 'module.decoder.5.mlp.fc1.weight', 'module.decoder.5.mlp.fc1.bias', 'module.decoder.5.mlp.fc2.weight', 'module.decoder.5.mlp.fc2.bias', 'module.decoder.5.ls2.gamma', 'module.decoder.6.norm1.weight', 'module.decoder.6.norm1.bias', 'module.decoder.6.attn.qkv.weight', 'module.decoder.6.attn.qkv.bias', 'module.decoder.6.attn.proj.weight', 'module.decoder.6.attn.proj.bias', 'module.decoder.6.attn.q_norm.weight', 'module.decoder.6.attn.q_norm.bias', 'module.decoder.6.attn.k_norm.weight', 'module.decoder.6.attn.k_norm.bias', 'module.decoder.6.ls1.gamma', 'module.decoder.6.norm2.weight', 'module.decoder.6.norm2.bias', 'module.decoder.6.mlp.fc1.weight', 'module.decoder.6.mlp.fc1.bias', 
'module.decoder.6.mlp.fc2.weight', 'module.decoder.6.mlp.fc2.bias', 'module.decoder.6.ls2.gamma', 'module.decoder.7.norm1.weight', 'module.decoder.7.norm1.bias', 'module.decoder.7.attn.qkv.weight', 'module.decoder.7.attn.qkv.bias', 'module.decoder.7.attn.proj.weight', 'module.decoder.7.attn.proj.bias', 'module.decoder.7.attn.q_norm.weight', 'module.decoder.7.attn.q_norm.bias', 'module.decoder.7.attn.k_norm.weight', 'module.decoder.7.attn.k_norm.bias', 'module.decoder.7.ls1.gamma', 'module.decoder.7.norm2.weight', 'module.decoder.7.norm2.bias', 'module.decoder.7.mlp.fc1.weight', 'module.decoder.7.mlp.fc1.bias', 'module.decoder.7.mlp.fc2.weight', 'module.decoder.7.mlp.fc2.bias', 'module.decoder.7.ls2.gamma', 'module.decoder.8.norm1.weight', 'module.decoder.8.norm1.bias', 'module.decoder.8.attn.qkv.weight', 'module.decoder.8.attn.qkv.bias', 'module.decoder.8.attn.proj.weight', 'module.decoder.8.attn.proj.bias', 'module.decoder.8.attn.q_norm.weight', 'module.decoder.8.attn.q_norm.bias', 'module.decoder.8.attn.k_norm.weight', 'module.decoder.8.attn.k_norm.bias', 'module.decoder.8.ls1.gamma', 'module.decoder.8.norm2.weight', 'module.decoder.8.norm2.bias', 'module.decoder.8.mlp.fc1.weight', 'module.decoder.8.mlp.fc1.bias', 'module.decoder.8.mlp.fc2.weight', 'module.decoder.8.mlp.fc2.bias', 'module.decoder.8.ls2.gamma', 'module.decoder.9.norm1.weight', 'module.decoder.9.norm1.bias', 'module.decoder.9.attn.qkv.weight', 'module.decoder.9.attn.qkv.bias', 'module.decoder.9.attn.proj.weight', 'module.decoder.9.attn.proj.bias', 'module.decoder.9.attn.q_norm.weight', 'module.decoder.9.attn.q_norm.bias', 'module.decoder.9.attn.k_norm.weight', 'module.decoder.9.attn.k_norm.bias', 'module.decoder.9.ls1.gamma', 'module.decoder.9.norm2.weight', 'module.decoder.9.norm2.bias', 'module.decoder.9.mlp.fc1.weight', 'module.decoder.9.mlp.fc1.bias', 'module.decoder.9.mlp.fc2.weight', 'module.decoder.9.mlp.fc2.bias', 'module.decoder.9.ls2.gamma', 'module.decoder.10.norm1.weight', 
'module.decoder.10.norm1.bias', 'module.decoder.10.attn.qkv.weight', 'module.decoder.10.attn.qkv.bias', 'module.decoder.10.attn.proj.weight', 'module.decoder.10.attn.proj.bias', 'module.decoder.10.attn.q_norm.weight', 'module.decoder.10.attn.q_norm.bias', 'module.decoder.10.attn.k_norm.weight', 'module.decoder.10.attn.k_norm.bias', 'module.decoder.10.ls1.gamma', 'module.decoder.10.norm2.weight', 'module.decoder.10.norm2.bias', 'module.decoder.10.mlp.fc1.weight', 'module.decoder.10.mlp.fc1.bias', 'module.decoder.10.mlp.fc2.weight', 'module.decoder.10.mlp.fc2.bias', 'module.decoder.10.ls2.gamma', 'module.decoder.11.norm1.weight', 'module.decoder.11.norm1.bias', 'module.decoder.11.attn.qkv.weight', 'module.decoder.11.attn.qkv.bias', 'module.decoder.11.attn.proj.weight', 'module.decoder.11.attn.proj.bias', 'module.decoder.11.attn.q_norm.weight', 'module.decoder.11.attn.q_norm.bias', 'module.decoder.11.attn.k_norm.weight', 'module.decoder.11.attn.k_norm.bias', 'module.decoder.11.ls1.gamma', 'module.decoder.11.norm2.weight', 'module.decoder.11.norm2.bias', 'module.decoder.11.mlp.fc1.weight', 'module.decoder.11.mlp.fc1.bias', 'module.decoder.11.mlp.fc2.weight', 'module.decoder.11.mlp.fc2.bias', 'module.decoder.11.ls2.gamma', 'module.decoder.12.norm1.weight', 'module.decoder.12.norm1.bias', 'module.decoder.12.attn.qkv.weight', 'module.decoder.12.attn.qkv.bias', 'module.decoder.12.attn.proj.weight', 'module.decoder.12.attn.proj.bias', 'module.decoder.12.attn.q_norm.weight', 'module.decoder.12.attn.q_norm.bias', 'module.decoder.12.attn.k_norm.weight', 'module.decoder.12.attn.k_norm.bias', 'module.decoder.12.ls1.gamma', 'module.decoder.12.norm2.weight', 'module.decoder.12.norm2.bias', 'module.decoder.12.mlp.fc1.weight', 'module.decoder.12.mlp.fc1.bias', 'module.decoder.12.mlp.fc2.weight', 'module.decoder.12.mlp.fc2.bias', 'module.decoder.12.ls2.gamma', 'module.decoder.13.norm1.weight', 'module.decoder.13.norm1.bias', 'module.decoder.13.attn.qkv.weight', 
'module.decoder.13.attn.qkv.bias', 'module.decoder.13.attn.proj.weight', 'module.decoder.13.attn.proj.bias', 'module.decoder.13.attn.q_norm.weight', 'module.decoder.13.attn.q_norm.bias', 'module.decoder.13.attn.k_norm.weight', 'module.decoder.13.attn.k_norm.bias', 'module.decoder.13.ls1.gamma', 'module.decoder.13.norm2.weight', 'module.decoder.13.norm2.bias', 'module.decoder.13.mlp.fc1.weight', 'module.decoder.13.mlp.fc1.bias', 'module.decoder.13.mlp.fc2.weight', 'module.decoder.13.mlp.fc2.bias', 'module.decoder.13.ls2.gamma', 'module.decoder.14.norm1.weight', 'module.decoder.14.norm1.bias', 'module.decoder.14.attn.qkv.weight', 'module.decoder.14.attn.qkv.bias', 'module.decoder.14.attn.proj.weight', 'module.decoder.14.attn.proj.bias', 'module.decoder.14.attn.q_norm.weight', 'module.decoder.14.attn.q_norm.bias', 'module.decoder.14.attn.k_norm.weight', 'module.decoder.14.attn.k_norm.bias', 'module.decoder.14.ls1.gamma', 'module.decoder.14.norm2.weight', 'module.decoder.14.norm2.bias', 'module.decoder.14.mlp.fc1.weight', 'module.decoder.14.mlp.fc1.bias', 'module.decoder.14.mlp.fc2.weight', 'module.decoder.14.mlp.fc2.bias', 'module.decoder.14.ls2.gamma', 'module.decoder.15.norm1.weight', 'module.decoder.15.norm1.bias', 'module.decoder.15.attn.qkv.weight', 'module.decoder.15.attn.qkv.bias', 'module.decoder.15.attn.proj.weight', 'module.decoder.15.attn.proj.bias', 'module.decoder.15.attn.q_norm.weight', 'module.decoder.15.attn.q_norm.bias', 'module.decoder.15.attn.k_norm.weight', 'module.decoder.15.attn.k_norm.bias', 'module.decoder.15.ls1.gamma', 'module.decoder.15.norm2.weight', 'module.decoder.15.norm2.bias', 'module.decoder.15.mlp.fc1.weight', 'module.decoder.15.mlp.fc1.bias', 'module.decoder.15.mlp.fc2.weight', 'module.decoder.15.mlp.fc2.bias', 'module.decoder.15.ls2.gamma', 'module.decoder.16.norm1.weight', 'module.decoder.16.norm1.bias', 'module.decoder.16.attn.qkv.weight', 'module.decoder.16.attn.qkv.bias', 'module.decoder.16.attn.proj.weight', 
'module.decoder.16.attn.proj.bias', 'module.decoder.16.attn.q_norm.weight', 'module.decoder.16.attn.q_norm.bias', 'module.decoder.16.attn.k_norm.weight', 'module.decoder.16.attn.k_norm.bias', 'module.decoder.16.ls1.gamma', 'module.decoder.16.norm2.weight', 'module.decoder.16.norm2.bias', 'module.decoder.16.mlp.fc1.weight', 'module.decoder.16.mlp.fc1.bias', 'module.decoder.16.mlp.fc2.weight', 'module.decoder.16.mlp.fc2.bias', 'module.decoder.16.ls2.gamma', 'module.decoder.17.norm1.weight', 'module.decoder.17.norm1.bias', 'module.decoder.17.attn.qkv.weight', 'module.decoder.17.attn.qkv.bias', 'module.decoder.17.attn.proj.weight', 'module.decoder.17.attn.proj.bias', 'module.decoder.17.attn.q_norm.weight', 'module.decoder.17.attn.q_norm.bias', 'module.decoder.17.attn.k_norm.weight', 'module.decoder.17.attn.k_norm.bias', 'module.decoder.17.ls1.gamma', 'module.decoder.17.norm2.weight', 'module.decoder.17.norm2.bias', 'module.decoder.17.mlp.fc1.weight', 'module.decoder.17.mlp.fc1.bias', 'module.decoder.17.mlp.fc2.weight', 'module.decoder.17.mlp.fc2.bias', 'module.decoder.17.ls2.gamma', 'module.decoder.18.norm1.weight', 'module.decoder.18.norm1.bias', 'module.decoder.18.attn.qkv.weight', 'module.decoder.18.attn.qkv.bias', 'module.decoder.18.attn.proj.weight', 'module.decoder.18.attn.proj.bias', 'module.decoder.18.attn.q_norm.weight', 'module.decoder.18.attn.q_norm.bias', 'module.decoder.18.attn.k_norm.weight', 'module.decoder.18.attn.k_norm.bias', 'module.decoder.18.ls1.gamma', 'module.decoder.18.norm2.weight', 'module.decoder.18.norm2.bias', 'module.decoder.18.mlp.fc1.weight', 'module.decoder.18.mlp.fc1.bias', 'module.decoder.18.mlp.fc2.weight', 'module.decoder.18.mlp.fc2.bias', 'module.decoder.18.ls2.gamma', 'module.decoder.19.norm1.weight', 'module.decoder.19.norm1.bias', 'module.decoder.19.attn.qkv.weight', 'module.decoder.19.attn.qkv.bias', 'module.decoder.19.attn.proj.weight', 'module.decoder.19.attn.proj.bias', 'module.decoder.19.attn.q_norm.weight', 
'module.decoder.19.attn.q_norm.bias', 'module.decoder.19.attn.k_norm.weight', 'module.decoder.19.attn.k_norm.bias', 'module.decoder.19.ls1.gamma', 'module.decoder.19.norm2.weight', 'module.decoder.19.norm2.bias', 'module.decoder.19.mlp.fc1.weight', 'module.decoder.19.mlp.fc1.bias', 'module.decoder.19.mlp.fc2.weight', 'module.decoder.19.mlp.fc2.bias', 'module.decoder.19.ls2.gamma', 'module.decoder.20.norm1.weight', 'module.decoder.20.norm1.bias', 'module.decoder.20.attn.qkv.weight', 'module.decoder.20.attn.qkv.bias', 'module.decoder.20.attn.proj.weight', 'module.decoder.20.attn.proj.bias', 'module.decoder.20.attn.q_norm.weight', 'module.decoder.20.attn.q_norm.bias', 'module.decoder.20.attn.k_norm.weight', 'module.decoder.20.attn.k_norm.bias', 'module.decoder.20.ls1.gamma', 'module.decoder.20.norm2.weight', 'module.decoder.20.norm2.bias', 'module.decoder.20.mlp.fc1.weight', 'module.decoder.20.mlp.fc1.bias', 'module.decoder.20.mlp.fc2.weight', 'module.decoder.20.mlp.fc2.bias', 'module.decoder.20.ls2.gamma', 'module.decoder.21.norm1.weight', 'module.decoder.21.norm1.bias', 'module.decoder.21.attn.qkv.weight', 'module.decoder.21.attn.qkv.bias', 'module.decoder.21.attn.proj.weight', 'module.decoder.21.attn.proj.bias', 'module.decoder.21.attn.q_norm.weight', 'module.decoder.21.attn.q_norm.bias', 'module.decoder.21.attn.k_norm.weight', 'module.decoder.21.attn.k_norm.bias', 'module.decoder.21.ls1.gamma', 'module.decoder.21.norm2.weight', 'module.decoder.21.norm2.bias', 'module.decoder.21.mlp.fc1.weight', 'module.decoder.21.mlp.fc1.bias', 'module.decoder.21.mlp.fc2.weight', 'module.decoder.21.mlp.fc2.bias', 'module.decoder.21.ls2.gamma', 'module.decoder.22.norm1.weight', 'module.decoder.22.norm1.bias', 'module.decoder.22.attn.qkv.weight', 'module.decoder.22.attn.qkv.bias', 'module.decoder.22.attn.proj.weight', 'module.decoder.22.attn.proj.bias', 'module.decoder.22.attn.q_norm.weight', 'module.decoder.22.attn.q_norm.bias', 'module.decoder.22.attn.k_norm.weight', 
'module.decoder.22.attn.k_norm.bias', 'module.decoder.22.ls1.gamma', 'module.decoder.22.norm2.weight', 'module.decoder.22.norm2.bias', 'module.decoder.22.mlp.fc1.weight', 'module.decoder.22.mlp.fc1.bias', 'module.decoder.22.mlp.fc2.weight', 'module.decoder.22.mlp.fc2.bias', 'module.decoder.22.ls2.gamma', 'module.decoder.23.norm1.weight', 'module.decoder.23.norm1.bias', 'module.decoder.23.attn.qkv.weight', 'module.decoder.23.attn.qkv.bias', 'module.decoder.23.attn.proj.weight', 'module.decoder.23.attn.proj.bias', 'module.decoder.23.attn.q_norm.weight', 'module.decoder.23.attn.q_norm.bias', 'module.decoder.23.attn.k_norm.weight', 'module.decoder.23.attn.k_norm.bias', 'module.decoder.23.ls1.gamma', 'module.decoder.23.norm2.weight', 'module.decoder.23.norm2.bias', 'module.decoder.23.mlp.fc1.weight', 'module.decoder.23.mlp.fc1.bias', 'module.decoder.23.mlp.fc2.weight', 'module.decoder.23.mlp.fc2.bias', 'module.decoder.23.ls2.gamma', 'module.decoder.24.norm1.weight', 'module.decoder.24.norm1.bias', 'module.decoder.24.attn.qkv.weight', 'module.decoder.24.attn.qkv.bias', 'module.decoder.24.attn.proj.weight', 'module.decoder.24.attn.proj.bias', 'module.decoder.24.attn.q_norm.weight', 'module.decoder.24.attn.q_norm.bias', 'module.decoder.24.attn.k_norm.weight', 'module.decoder.24.attn.k_norm.bias', 'module.decoder.24.ls1.gamma', 'module.decoder.24.norm2.weight', 'module.decoder.24.norm2.bias', 'module.decoder.24.mlp.fc1.weight', 'module.decoder.24.mlp.fc1.bias', 'module.decoder.24.mlp.fc2.weight', 'module.decoder.24.mlp.fc2.bias', 'module.decoder.24.ls2.gamma', 'module.decoder.25.norm1.weight', 'module.decoder.25.norm1.bias', 'module.decoder.25.attn.qkv.weight', 'module.decoder.25.attn.qkv.bias', 'module.decoder.25.attn.proj.weight', 'module.decoder.25.attn.proj.bias', 'module.decoder.25.attn.q_norm.weight', 'module.decoder.25.attn.q_norm.bias', 'module.decoder.25.attn.k_norm.weight', 'module.decoder.25.attn.k_norm.bias', 'module.decoder.25.ls1.gamma', 
'module.decoder.25.norm2.weight', 'module.decoder.25.norm2.bias', 'module.decoder.25.mlp.fc1.weight', 'module.decoder.25.mlp.fc1.bias', 'module.decoder.25.mlp.fc2.weight', 'module.decoder.25.mlp.fc2.bias', 'module.decoder.25.ls2.gamma', 'module.decoder.26.norm1.weight', 'module.decoder.26.norm1.bias', 'module.decoder.26.attn.qkv.weight', 'module.decoder.26.attn.qkv.bias', 'module.decoder.26.attn.proj.weight', 'module.decoder.26.attn.proj.bias', 'module.decoder.26.attn.q_norm.weight', 'module.decoder.26.attn.q_norm.bias', 'module.decoder.26.attn.k_norm.weight', 'module.decoder.26.attn.k_norm.bias', 'module.decoder.26.ls1.gamma', 'module.decoder.26.norm2.weight', 'module.decoder.26.norm2.bias', 'module.decoder.26.mlp.fc1.weight', 'module.decoder.26.mlp.fc1.bias', 'module.decoder.26.mlp.fc2.weight', 'module.decoder.26.mlp.fc2.bias', 'module.decoder.26.ls2.gamma', 'module.decoder.27.norm1.weight', 'module.decoder.27.norm1.bias', 'module.decoder.27.attn.qkv.weight', 'module.decoder.27.attn.qkv.bias', 'module.decoder.27.attn.proj.weight', 'module.decoder.27.attn.proj.bias', 'module.decoder.27.attn.q_norm.weight', 'module.decoder.27.attn.q_norm.bias', 'module.decoder.27.attn.k_norm.weight', 'module.decoder.27.attn.k_norm.bias', 'module.decoder.27.ls1.gamma', 'module.decoder.27.norm2.weight', 'module.decoder.27.norm2.bias', 'module.decoder.27.mlp.fc1.weight', 'module.decoder.27.mlp.fc1.bias', 'module.decoder.27.mlp.fc2.weight', 'module.decoder.27.mlp.fc2.bias', 'module.decoder.27.ls2.gamma', 'module.decoder.28.norm1.weight', 'module.decoder.28.norm1.bias', 'module.decoder.28.attn.qkv.weight', 'module.decoder.28.attn.qkv.bias', 'module.decoder.28.attn.proj.weight', 'module.decoder.28.attn.proj.bias', 'module.decoder.28.attn.q_norm.weight', 'module.decoder.28.attn.q_norm.bias', 'module.decoder.28.attn.k_norm.weight', 'module.decoder.28.attn.k_norm.bias', 'module.decoder.28.ls1.gamma', 'module.decoder.28.norm2.weight', 'module.decoder.28.norm2.bias', 
'module.decoder.28.mlp.fc1.weight', 'module.decoder.28.mlp.fc1.bias', 'module.decoder.28.mlp.fc2.weight', 'module.decoder.28.mlp.fc2.bias', 'module.decoder.28.ls2.gamma', 'module.decoder.29.norm1.weight', 'module.decoder.29.norm1.bias', 'module.decoder.29.attn.qkv.weight', 'module.decoder.29.attn.qkv.bias', 'module.decoder.29.attn.proj.weight', 'module.decoder.29.attn.proj.bias', 'module.decoder.29.attn.q_norm.weight', 'module.decoder.29.attn.q_norm.bias', 'module.decoder.29.attn.k_norm.weight', 'module.decoder.29.attn.k_norm.bias', 'module.decoder.29.ls1.gamma', 'module.decoder.29.norm2.weight', 'module.decoder.29.norm2.bias', 'module.decoder.29.mlp.fc1.weight', 'module.decoder.29.mlp.fc1.bias', 'module.decoder.29.mlp.fc2.weight', 'module.decoder.29.mlp.fc2.bias', 'module.decoder.29.ls2.gamma', 'module.decoder.30.norm1.weight', 'module.decoder.30.norm1.bias', 'module.decoder.30.attn.qkv.weight', 'module.decoder.30.attn.qkv.bias', 'module.decoder.30.attn.proj.weight', 'module.decoder.30.attn.proj.bias', 'module.decoder.30.attn.q_norm.weight', 'module.decoder.30.attn.q_norm.bias', 'module.decoder.30.attn.k_norm.weight', 'module.decoder.30.attn.k_norm.bias', 'module.decoder.30.ls1.gamma', 'module.decoder.30.norm2.weight', 'module.decoder.30.norm2.bias', 'module.decoder.30.mlp.fc1.weight', 'module.decoder.30.mlp.fc1.bias', 'module.decoder.30.mlp.fc2.weight', 'module.decoder.30.mlp.fc2.bias', 'module.decoder.30.ls2.gamma', 'module.decoder.31.norm1.weight', 'module.decoder.31.norm1.bias', 'module.decoder.31.attn.qkv.weight', 'module.decoder.31.attn.qkv.bias', 'module.decoder.31.attn.proj.weight', 'module.decoder.31.attn.proj.bias', 'module.decoder.31.attn.q_norm.weight', 'module.decoder.31.attn.q_norm.bias', 'module.decoder.31.attn.k_norm.weight', 'module.decoder.31.attn.k_norm.bias', 'module.decoder.31.ls1.gamma', 'module.decoder.31.norm2.weight', 'module.decoder.31.norm2.bias', 'module.decoder.31.mlp.fc1.weight', 'module.decoder.31.mlp.fc1.bias', 
'module.decoder.31.mlp.fc2.weight', 'module.decoder.31.mlp.fc2.bias', 'module.decoder.31.ls2.gamma', 'module.decoder.32.norm1.weight', 'module.decoder.32.norm1.bias', 'module.decoder.32.attn.qkv.weight', 'module.decoder.32.attn.qkv.bias', 'module.decoder.32.attn.proj.weight', 'module.decoder.32.attn.proj.bias', 'module.decoder.32.attn.q_norm.weight', 'module.decoder.32.attn.q_norm.bias', 'module.decoder.32.attn.k_norm.weight', 'module.decoder.32.attn.k_norm.bias', 'module.decoder.32.ls1.gamma', 'module.decoder.32.norm2.weight', 'module.decoder.32.norm2.bias', 'module.decoder.32.mlp.fc1.weight', 'module.decoder.32.mlp.fc1.bias', 'module.decoder.32.mlp.fc2.weight', 'module.decoder.32.mlp.fc2.bias', 'module.decoder.32.ls2.gamma', 'module.decoder.33.norm1.weight', 'module.decoder.33.norm1.bias', 'module.decoder.33.attn.qkv.weight', 'module.decoder.33.attn.qkv.bias', 'module.decoder.33.attn.proj.weight', 'module.decoder.33.attn.proj.bias', 'module.decoder.33.attn.q_norm.weight', 'module.decoder.33.attn.q_norm.bias', 'module.decoder.33.attn.k_norm.weight', 'module.decoder.33.attn.k_norm.bias', 'module.decoder.33.ls1.gamma', 'module.decoder.33.norm2.weight', 'module.decoder.33.norm2.bias', 'module.decoder.33.mlp.fc1.weight', 'module.decoder.33.mlp.fc1.bias', 'module.decoder.33.mlp.fc2.weight', 'module.decoder.33.mlp.fc2.bias', 'module.decoder.33.ls2.gamma', 'module.decoder.34.norm1.weight', 'module.decoder.34.norm1.bias', 'module.decoder.34.attn.qkv.weight', 'module.decoder.34.attn.qkv.bias', 'module.decoder.34.attn.proj.weight', 'module.decoder.34.attn.proj.bias', 'module.decoder.34.attn.q_norm.weight', 'module.decoder.34.attn.q_norm.bias', 'module.decoder.34.attn.k_norm.weight', 'module.decoder.34.attn.k_norm.bias', 'module.decoder.34.ls1.gamma', 'module.decoder.34.norm2.weight', 'module.decoder.34.norm2.bias', 'module.decoder.34.mlp.fc1.weight', 'module.decoder.34.mlp.fc1.bias', 'module.decoder.34.mlp.fc2.weight', 'module.decoder.34.mlp.fc2.bias', 
'module.decoder.34.ls2.gamma', 'module.decoder.35.norm1.weight', 'module.decoder.35.norm1.bias', 'module.decoder.35.attn.qkv.weight', 'module.decoder.35.attn.qkv.bias', 'module.decoder.35.attn.proj.weight', 'module.decoder.35.attn.proj.bias', 'module.decoder.35.attn.q_norm.weight', 'module.decoder.35.attn.q_norm.bias', 'module.decoder.35.attn.k_norm.weight', 'module.decoder.35.attn.k_norm.bias', 'module.decoder.35.ls1.gamma', 'module.decoder.35.norm2.weight', 'module.decoder.35.norm2.bias', 'module.decoder.35.mlp.fc1.weight', 'module.decoder.35.mlp.fc1.bias', 'module.decoder.35.mlp.fc2.weight', 'module.decoder.35.mlp.fc2.bias', 'module.decoder.35.ls2.gamma', 'module.point_decoder.projects.weight', 'module.point_decoder.projects.bias', 'module.point_decoder.blocks.0.norm1.weight', 'module.point_decoder.blocks.0.norm1.bias', 'module.point_decoder.blocks.0.attn.qkv.weight', 'module.point_decoder.blocks.0.attn.qkv.bias', 'module.point_decoder.blocks.0.attn.proj.weight', 'module.point_decoder.blocks.0.attn.proj.bias', 'module.point_decoder.blocks.0.norm2.weight', 'module.point_decoder.blocks.0.norm2.bias', 'module.point_decoder.blocks.0.mlp.fc1.weight', 'module.point_decoder.blocks.0.mlp.fc1.bias', 'module.point_decoder.blocks.0.mlp.fc2.weight', 'module.point_decoder.blocks.0.mlp.fc2.bias', 'module.point_decoder.blocks.1.norm1.weight', 'module.point_decoder.blocks.1.norm1.bias', 'module.point_decoder.blocks.1.attn.qkv.weight', 'module.point_decoder.blocks.1.attn.qkv.bias', 'module.point_decoder.blocks.1.attn.proj.weight', 'module.point_decoder.blocks.1.attn.proj.bias', 'module.point_decoder.blocks.1.norm2.weight', 'module.point_decoder.blocks.1.norm2.bias', 'module.point_decoder.blocks.1.mlp.fc1.weight', 'module.point_decoder.blocks.1.mlp.fc1.bias', 'module.point_decoder.blocks.1.mlp.fc2.weight', 'module.point_decoder.blocks.1.mlp.fc2.bias', 'module.point_decoder.blocks.2.norm1.weight', 'module.point_decoder.blocks.2.norm1.bias', 
'module.point_decoder.blocks.2.attn.qkv.weight', 'module.point_decoder.blocks.2.attn.qkv.bias', 'module.point_decoder.blocks.2.attn.proj.weight', 'module.point_decoder.blocks.2.attn.proj.bias', 'module.point_decoder.blocks.2.norm2.weight', 'module.point_decoder.blocks.2.norm2.bias', 'module.point_decoder.blocks.2.mlp.fc1.weight', 'module.point_decoder.blocks.2.mlp.fc1.bias', 'module.point_decoder.blocks.2.mlp.fc2.weight', 'module.point_decoder.blocks.2.mlp.fc2.bias', 'module.point_decoder.blocks.3.norm1.weight', 'module.point_decoder.blocks.3.norm1.bias', 'module.point_decoder.blocks.3.attn.qkv.weight', 'module.point_decoder.blocks.3.attn.qkv.bias', 'module.point_decoder.blocks.3.attn.proj.weight', 'module.point_decoder.blocks.3.attn.proj.bias', 'module.point_decoder.blocks.3.norm2.weight', 'module.point_decoder.blocks.3.norm2.bias', 'module.point_decoder.blocks.3.mlp.fc1.weight', 'module.point_decoder.blocks.3.mlp.fc1.bias', 'module.point_decoder.blocks.3.mlp.fc2.weight', 'module.point_decoder.blocks.3.mlp.fc2.bias', 'module.point_decoder.blocks.4.norm1.weight', 'module.point_decoder.blocks.4.norm1.bias', 'module.point_decoder.blocks.4.attn.qkv.weight', 'module.point_decoder.blocks.4.attn.qkv.bias', 'module.point_decoder.blocks.4.attn.proj.weight', 'module.point_decoder.blocks.4.attn.proj.bias', 'module.point_decoder.blocks.4.norm2.weight', 'module.point_decoder.blocks.4.norm2.bias', 'module.point_decoder.blocks.4.mlp.fc1.weight', 'module.point_decoder.blocks.4.mlp.fc1.bias', 'module.point_decoder.blocks.4.mlp.fc2.weight', 'module.point_decoder.blocks.4.mlp.fc2.bias', 'module.point_decoder.linear_out.weight', 'module.point_decoder.linear_out.bias', 'module.point_head.proj.weight', 'module.point_head.proj.bias', 'module.conf_decoder.projects.weight', 'module.conf_decoder.projects.bias', 'module.conf_decoder.blocks.0.norm1.weight', 'module.conf_decoder.blocks.0.norm1.bias', 'module.conf_decoder.blocks.0.attn.qkv.weight', 'module.conf_decoder.blocks.0.attn.qkv.bias', 
'module.conf_decoder.blocks.0.attn.proj.weight', 'module.conf_decoder.blocks.0.attn.proj.bias', 'module.conf_decoder.blocks.0.norm2.weight', 'module.conf_decoder.blocks.0.norm2.bias', 'module.conf_decoder.blocks.0.mlp.fc1.weight', 'module.conf_decoder.blocks.0.mlp.fc1.bias', 'module.conf_decoder.blocks.0.mlp.fc2.weight', 'module.conf_decoder.blocks.0.mlp.fc2.bias', 'module.conf_decoder.blocks.1.norm1.weight', 'module.conf_decoder.blocks.1.norm1.bias', 'module.conf_decoder.blocks.1.attn.qkv.weight', 'module.conf_decoder.blocks.1.attn.qkv.bias', 'module.conf_decoder.blocks.1.attn.proj.weight', 'module.conf_decoder.blocks.1.attn.proj.bias', 'module.conf_decoder.blocks.1.norm2.weight', 'module.conf_decoder.blocks.1.norm2.bias', 'module.conf_decoder.blocks.1.mlp.fc1.weight', 'module.conf_decoder.blocks.1.mlp.fc1.bias', 'module.conf_decoder.blocks.1.mlp.fc2.weight', 'module.conf_decoder.blocks.1.mlp.fc2.bias', 'module.conf_decoder.blocks.2.norm1.weight', 'module.conf_decoder.blocks.2.norm1.bias', 'module.conf_decoder.blocks.2.attn.qkv.weight', 'module.conf_decoder.blocks.2.attn.qkv.bias', 'module.conf_decoder.blocks.2.attn.proj.weight', 'module.conf_decoder.blocks.2.attn.proj.bias', 'module.conf_decoder.blocks.2.norm2.weight', 'module.conf_decoder.blocks.2.norm2.bias', 'module.conf_decoder.blocks.2.mlp.fc1.weight', 'module.conf_decoder.blocks.2.mlp.fc1.bias', 'module.conf_decoder.blocks.2.mlp.fc2.weight', 'module.conf_decoder.blocks.2.mlp.fc2.bias', 'module.conf_decoder.blocks.3.norm1.weight', 'module.conf_decoder.blocks.3.norm1.bias', 'module.conf_decoder.blocks.3.attn.qkv.weight', 'module.conf_decoder.blocks.3.attn.qkv.bias', 'module.conf_decoder.blocks.3.attn.proj.weight', 'module.conf_decoder.blocks.3.attn.proj.bias', 'module.conf_decoder.blocks.3.norm2.weight', 'module.conf_decoder.blocks.3.norm2.bias', 'module.conf_decoder.blocks.3.mlp.fc1.weight', 'module.conf_decoder.blocks.3.mlp.fc1.bias', 'module.conf_decoder.blocks.3.mlp.fc2.weight', 
'module.conf_decoder.blocks.3.mlp.fc2.bias', 'module.conf_decoder.blocks.4.norm1.weight', 'module.conf_decoder.blocks.4.norm1.bias', 'module.conf_decoder.blocks.4.attn.qkv.weight', 'module.conf_decoder.blocks.4.attn.qkv.bias', 'module.conf_decoder.blocks.4.attn.proj.weight', 'module.conf_decoder.blocks.4.attn.proj.bias', 'module.conf_decoder.blocks.4.norm2.weight', 'module.conf_decoder.blocks.4.norm2.bias', 'module.conf_decoder.blocks.4.mlp.fc1.weight', 'module.conf_decoder.blocks.4.mlp.fc1.bias', 'module.conf_decoder.blocks.4.mlp.fc2.weight', 'module.conf_decoder.blocks.4.mlp.fc2.bias', 'module.conf_decoder.linear_out.weight', 'module.conf_decoder.linear_out.bias', 'module.conf_head.proj.weight', 'module.conf_head.proj.bias', 'module.camera_decoder.projects.weight', 'module.camera_decoder.projects.bias', 'module.camera_decoder.blocks.0.norm1.weight', 'module.camera_decoder.blocks.0.norm1.bias', 'module.camera_decoder.blocks.0.attn.qkv.weight', 'module.camera_decoder.blocks.0.attn.qkv.bias', 'module.camera_decoder.blocks.0.attn.proj.weight', 'module.camera_decoder.blocks.0.attn.proj.bias', 'module.camera_decoder.blocks.0.norm2.weight', 'module.camera_decoder.blocks.0.norm2.bias', 'module.camera_decoder.blocks.0.mlp.fc1.weight', 'module.camera_decoder.blocks.0.mlp.fc1.bias', 'module.camera_decoder.blocks.0.mlp.fc2.weight', 'module.camera_decoder.blocks.0.mlp.fc2.bias', 'module.camera_decoder.blocks.1.norm1.weight', 'module.camera_decoder.blocks.1.norm1.bias', 'module.camera_decoder.blocks.1.attn.qkv.weight', 'module.camera_decoder.blocks.1.attn.qkv.bias', 'module.camera_decoder.blocks.1.attn.proj.weight', 'module.camera_decoder.blocks.1.attn.proj.bias', 'module.camera_decoder.blocks.1.norm2.weight', 'module.camera_decoder.blocks.1.norm2.bias', 'module.camera_decoder.blocks.1.mlp.fc1.weight', 'module.camera_decoder.blocks.1.mlp.fc1.bias', 'module.camera_decoder.blocks.1.mlp.fc2.weight', 'module.camera_decoder.blocks.1.mlp.fc2.bias', 
'module.camera_decoder.blocks.2.norm1.weight', 'module.camera_decoder.blocks.2.norm1.bias', 'module.camera_decoder.blocks.2.attn.qkv.weight', 'module.camera_decoder.blocks.2.attn.qkv.bias', 'module.camera_decoder.blocks.2.attn.proj.weight', 'module.camera_decoder.blocks.2.attn.proj.bias', 'module.camera_decoder.blocks.2.norm2.weight', 'module.camera_decoder.blocks.2.norm2.bias', 'module.camera_decoder.blocks.2.mlp.fc1.weight', 'module.camera_decoder.blocks.2.mlp.fc1.bias', 'module.camera_decoder.blocks.2.mlp.fc2.weight', 'module.camera_decoder.blocks.2.mlp.fc2.bias', 'module.camera_decoder.blocks.3.norm1.weight', 'module.camera_decoder.blocks.3.norm1.bias', 'module.camera_decoder.blocks.3.attn.qkv.weight', 'module.camera_decoder.blocks.3.attn.qkv.bias', 'module.camera_decoder.blocks.3.attn.proj.weight', 'module.camera_decoder.blocks.3.attn.proj.bias', 'module.camera_decoder.blocks.3.norm2.weight', 'module.camera_decoder.blocks.3.norm2.bias', 'module.camera_decoder.blocks.3.mlp.fc1.weight', 'module.camera_decoder.blocks.3.mlp.fc1.bias', 'module.camera_decoder.blocks.3.mlp.fc2.weight', 'module.camera_decoder.blocks.3.mlp.fc2.bias', 'module.camera_decoder.blocks.4.norm1.weight', 'module.camera_decoder.blocks.4.norm1.bias', 'module.camera_decoder.blocks.4.attn.qkv.weight', 'module.camera_decoder.blocks.4.attn.qkv.bias', 'module.camera_decoder.blocks.4.attn.proj.weight', 'module.camera_decoder.blocks.4.attn.proj.bias', 'module.camera_decoder.blocks.4.norm2.weight', 'module.camera_decoder.blocks.4.norm2.bias', 'module.camera_decoder.blocks.4.mlp.fc1.weight', 'module.camera_decoder.blocks.4.mlp.fc1.bias', 'module.camera_decoder.blocks.4.mlp.fc2.weight', 'module.camera_decoder.blocks.4.mlp.fc2.bias', 'module.camera_decoder.linear_out.weight', 'module.camera_decoder.linear_out.bias', 'module.camera_head.res_conv.0.res_conv1.weight', 'module.camera_head.res_conv.0.res_conv1.bias', 'module.camera_head.res_conv.0.res_conv2.weight', 
'module.camera_head.res_conv.0.res_conv2.bias', 'module.camera_head.res_conv.0.res_conv3.weight', 'module.camera_head.res_conv.0.res_conv3.bias', 'module.camera_head.res_conv.1.res_conv1.weight', 'module.camera_head.res_conv.1.res_conv1.bias', 'module.camera_head.res_conv.1.res_conv2.weight', 'module.camera_head.res_conv.1.res_conv2.bias', 'module.camera_head.res_conv.1.res_conv3.weight', 'module.camera_head.res_conv.1.res_conv3.bias', 'module.camera_head.more_mlps.0.weight', 'module.camera_head.more_mlps.0.bias', 'module.camera_head.more_mlps.2.weight', 'module.camera_head.more_mlps.2.bias', 'module.camera_head.fc_t.weight', 'module.camera_head.fc_t.bias', 'module.camera_head.fc_rot.weight', 'module.camera_head.fc_rot.bias']) +[2026-05-01 23:35:27,069][__main__][INFO] - [RANK 0] Freezing patch embedding and positional encoding parameters... +[2026-05-01 23:35:27,075][__main__][INFO] - [RANK 0] Frozen 304,376,832 parameters out of 958,696,732 total parameters. (31.75%) +[2026-05-01 23:35:27,075][__main__][INFO] - [RANK 0] Trainable parameters: 654,319,900 (68.25%) +[2026-05-01 23:35:27,075][__main__][INFO] - [RANK 0] Example frozen parameters: register_token, encoder.cls_token, encoder.pos_embed, encoder.register_tokens, encoder.patch_embed.proj.weight... 
+[2026-05-01 23:35:27,078][croco.utils.misc][INFO] - [RANK 0] Param groups = { + "no_decay": { + "weight_decay": 0.0, + "params": [ + "decoder.0.norm1.weight", + "decoder.0.norm1.bias", + "decoder.0.attn.qkv.bias", + "decoder.0.attn.proj.bias", + "decoder.0.attn.q_norm.weight", + "decoder.0.attn.q_norm.bias", + "decoder.0.attn.k_norm.weight", + "decoder.0.attn.k_norm.bias", + "decoder.0.ls1.gamma", + "decoder.0.norm2.weight", + "decoder.0.norm2.bias", + "decoder.0.mlp.fc1.bias", + "decoder.0.mlp.fc2.bias", + "decoder.0.ls2.gamma", + "decoder.1.norm1.weight", + "decoder.1.norm1.bias", + "decoder.1.attn.qkv.bias", + "decoder.1.attn.proj.bias", + "decoder.1.attn.q_norm.weight", + "decoder.1.attn.q_norm.bias", + "decoder.1.attn.k_norm.weight", + "decoder.1.attn.k_norm.bias", + "decoder.1.ls1.gamma", + "decoder.1.norm2.weight", + "decoder.1.norm2.bias", + "decoder.1.mlp.fc1.bias", + "decoder.1.mlp.fc2.bias", + "decoder.1.ls2.gamma", + "decoder.2.norm1.weight", + "decoder.2.norm1.bias", + "decoder.2.attn.qkv.bias", + "decoder.2.attn.proj.bias", + "decoder.2.attn.q_norm.weight", + "decoder.2.attn.q_norm.bias", + "decoder.2.attn.k_norm.weight", + "decoder.2.attn.k_norm.bias", + "decoder.2.ls1.gamma", + "decoder.2.norm2.weight", + "decoder.2.norm2.bias", + "decoder.2.mlp.fc1.bias", + "decoder.2.mlp.fc2.bias", + "decoder.2.ls2.gamma", + "decoder.3.norm1.weight", + "decoder.3.norm1.bias", + "decoder.3.attn.qkv.bias", + "decoder.3.attn.proj.bias", + "decoder.3.attn.q_norm.weight", + "decoder.3.attn.q_norm.bias", + "decoder.3.attn.k_norm.weight", + "decoder.3.attn.k_norm.bias", + "decoder.3.ls1.gamma", + "decoder.3.norm2.weight", + "decoder.3.norm2.bias", + "decoder.3.mlp.fc1.bias", + "decoder.3.mlp.fc2.bias", + "decoder.3.ls2.gamma", + "decoder.4.norm1.weight", + "decoder.4.norm1.bias", + "decoder.4.attn.qkv.bias", + "decoder.4.attn.proj.bias", + "decoder.4.attn.q_norm.weight", + "decoder.4.attn.q_norm.bias", + "decoder.4.attn.k_norm.weight", + "decoder.4.attn.k_norm.bias", + 
"decoder.4.ls1.gamma", + "decoder.4.norm2.weight", + "decoder.4.norm2.bias", + "decoder.4.mlp.fc1.bias", + "decoder.4.mlp.fc2.bias", + "decoder.4.ls2.gamma", + "decoder.5.norm1.weight", + "decoder.5.norm1.bias", + "decoder.5.attn.qkv.bias", + "decoder.5.attn.proj.bias", + "decoder.5.attn.q_norm.weight", + "decoder.5.attn.q_norm.bias", + "decoder.5.attn.k_norm.weight", + "decoder.5.attn.k_norm.bias", + "decoder.5.ls1.gamma", + "decoder.5.norm2.weight", + "decoder.5.norm2.bias", + "decoder.5.mlp.fc1.bias", + "decoder.5.mlp.fc2.bias", + "decoder.5.ls2.gamma", + "decoder.6.norm1.weight", + "decoder.6.norm1.bias", + "decoder.6.attn.qkv.bias", + "decoder.6.attn.proj.bias", + "decoder.6.attn.q_norm.weight", + "decoder.6.attn.q_norm.bias", + "decoder.6.attn.k_norm.weight", + "decoder.6.attn.k_norm.bias", + "decoder.6.ls1.gamma", + "decoder.6.norm2.weight", + "decoder.6.norm2.bias", + "decoder.6.mlp.fc1.bias", + "decoder.6.mlp.fc2.bias", + "decoder.6.ls2.gamma", + "decoder.7.norm1.weight", + "decoder.7.norm1.bias", + "decoder.7.attn.qkv.bias", + "decoder.7.attn.proj.bias", + "decoder.7.attn.q_norm.weight", + "decoder.7.attn.q_norm.bias", + "decoder.7.attn.k_norm.weight", + "decoder.7.attn.k_norm.bias", + "decoder.7.ls1.gamma", + "decoder.7.norm2.weight", + "decoder.7.norm2.bias", + "decoder.7.mlp.fc1.bias", + "decoder.7.mlp.fc2.bias", + "decoder.7.ls2.gamma", + "decoder.8.norm1.weight", + "decoder.8.norm1.bias", + "decoder.8.attn.qkv.bias", + "decoder.8.attn.proj.bias", + "decoder.8.attn.q_norm.weight", + "decoder.8.attn.q_norm.bias", + "decoder.8.attn.k_norm.weight", + "decoder.8.attn.k_norm.bias", + "decoder.8.ls1.gamma", + "decoder.8.norm2.weight", + "decoder.8.norm2.bias", + "decoder.8.mlp.fc1.bias", + "decoder.8.mlp.fc2.bias", + "decoder.8.ls2.gamma", + "decoder.9.norm1.weight", + "decoder.9.norm1.bias", + "decoder.9.attn.qkv.bias", + "decoder.9.attn.proj.bias", + "decoder.9.attn.q_norm.weight", + "decoder.9.attn.q_norm.bias", + "decoder.9.attn.k_norm.weight", + 
"decoder.9.attn.k_norm.bias", + "decoder.9.ls1.gamma", + "decoder.9.norm2.weight", + "decoder.9.norm2.bias", + "decoder.9.mlp.fc1.bias", + "decoder.9.mlp.fc2.bias", + "decoder.9.ls2.gamma", + "decoder.10.norm1.weight", + "decoder.10.norm1.bias", + "decoder.10.attn.qkv.bias", + "decoder.10.attn.proj.bias", + "decoder.10.attn.q_norm.weight", + "decoder.10.attn.q_norm.bias", + "decoder.10.attn.k_norm.weight", + "decoder.10.attn.k_norm.bias", + "decoder.10.ls1.gamma", + "decoder.10.norm2.weight", + "decoder.10.norm2.bias", + "decoder.10.mlp.fc1.bias", + "decoder.10.mlp.fc2.bias", + "decoder.10.ls2.gamma", + "decoder.11.norm1.weight", + "decoder.11.norm1.bias", + "decoder.11.attn.qkv.bias", + "decoder.11.attn.proj.bias", + "decoder.11.attn.q_norm.weight", + "decoder.11.attn.q_norm.bias", + "decoder.11.attn.k_norm.weight", + "decoder.11.attn.k_norm.bias", + "decoder.11.ls1.gamma", + "decoder.11.norm2.weight", + "decoder.11.norm2.bias", + "decoder.11.mlp.fc1.bias", + "decoder.11.mlp.fc2.bias", + "decoder.11.ls2.gamma", + "decoder.12.norm1.weight", + "decoder.12.norm1.bias", + "decoder.12.attn.qkv.bias", + "decoder.12.attn.proj.bias", + "decoder.12.attn.q_norm.weight", + "decoder.12.attn.q_norm.bias", + "decoder.12.attn.k_norm.weight", + "decoder.12.attn.k_norm.bias", + "decoder.12.ls1.gamma", + "decoder.12.norm2.weight", + "decoder.12.norm2.bias", + "decoder.12.mlp.fc1.bias", + "decoder.12.mlp.fc2.bias", + "decoder.12.ls2.gamma", + "decoder.13.norm1.weight", + "decoder.13.norm1.bias", + "decoder.13.attn.qkv.bias", + "decoder.13.attn.proj.bias", + "decoder.13.attn.q_norm.weight", + "decoder.13.attn.q_norm.bias", + "decoder.13.attn.k_norm.weight", + "decoder.13.attn.k_norm.bias", + "decoder.13.ls1.gamma", + "decoder.13.norm2.weight", + "decoder.13.norm2.bias", + "decoder.13.mlp.fc1.bias", + "decoder.13.mlp.fc2.bias", + "decoder.13.ls2.gamma", + "decoder.14.norm1.weight", + "decoder.14.norm1.bias", + "decoder.14.attn.qkv.bias", + "decoder.14.attn.proj.bias", + 
"decoder.14.attn.q_norm.weight", + "decoder.14.attn.q_norm.bias", + "decoder.14.attn.k_norm.weight", + "decoder.14.attn.k_norm.bias", + "decoder.14.ls1.gamma", + "decoder.14.norm2.weight", + "decoder.14.norm2.bias", + "decoder.14.mlp.fc1.bias", + "decoder.14.mlp.fc2.bias", + "decoder.14.ls2.gamma", + "decoder.15.norm1.weight", + "decoder.15.norm1.bias", + "decoder.15.attn.qkv.bias", + "decoder.15.attn.proj.bias", + "decoder.15.attn.q_norm.weight", + "decoder.15.attn.q_norm.bias", + "decoder.15.attn.k_norm.weight", + "decoder.15.attn.k_norm.bias", + "decoder.15.ls1.gamma", + "decoder.15.norm2.weight", + "decoder.15.norm2.bias", + "decoder.15.mlp.fc1.bias", + "decoder.15.mlp.fc2.bias", + "decoder.15.ls2.gamma", + "decoder.16.norm1.weight", + "decoder.16.norm1.bias", + "decoder.16.attn.qkv.bias", + "decoder.16.attn.proj.bias", + "decoder.16.attn.q_norm.weight", + "decoder.16.attn.q_norm.bias", + "decoder.16.attn.k_norm.weight", + "decoder.16.attn.k_norm.bias", + "decoder.16.ls1.gamma", + "decoder.16.norm2.weight", + "decoder.16.norm2.bias", + "decoder.16.mlp.fc1.bias", + "decoder.16.mlp.fc2.bias", + "decoder.16.ls2.gamma", + "decoder.17.norm1.weight", + "decoder.17.norm1.bias", + "decoder.17.attn.qkv.bias", + "decoder.17.attn.proj.bias", + "decoder.17.attn.q_norm.weight", + "decoder.17.attn.q_norm.bias", + "decoder.17.attn.k_norm.weight", + "decoder.17.attn.k_norm.bias", + "decoder.17.ls1.gamma", + "decoder.17.norm2.weight", + "decoder.17.norm2.bias", + "decoder.17.mlp.fc1.bias", + "decoder.17.mlp.fc2.bias", + "decoder.17.ls2.gamma", + "decoder.18.norm1.weight", + "decoder.18.norm1.bias", + "decoder.18.attn.qkv.bias", + "decoder.18.attn.proj.bias", + "decoder.18.attn.q_norm.weight", + "decoder.18.attn.q_norm.bias", + "decoder.18.attn.k_norm.weight", + "decoder.18.attn.k_norm.bias", + "decoder.18.ls1.gamma", + "decoder.18.norm2.weight", + "decoder.18.norm2.bias", + "decoder.18.mlp.fc1.bias", + "decoder.18.mlp.fc2.bias", + "decoder.18.ls2.gamma", + 
"decoder.19.norm1.weight", + "decoder.19.norm1.bias", + "decoder.19.attn.qkv.bias", + "decoder.19.attn.proj.bias", + "decoder.19.attn.q_norm.weight", + "decoder.19.attn.q_norm.bias", + "decoder.19.attn.k_norm.weight", + "decoder.19.attn.k_norm.bias", + "decoder.19.ls1.gamma", + "decoder.19.norm2.weight", + "decoder.19.norm2.bias", + "decoder.19.mlp.fc1.bias", + "decoder.19.mlp.fc2.bias", + "decoder.19.ls2.gamma", + "decoder.20.norm1.weight", + "decoder.20.norm1.bias", + "decoder.20.attn.qkv.bias", + "decoder.20.attn.proj.bias", + "decoder.20.attn.q_norm.weight", + "decoder.20.attn.q_norm.bias", + "decoder.20.attn.k_norm.weight", + "decoder.20.attn.k_norm.bias", + "decoder.20.ls1.gamma", + "decoder.20.norm2.weight", + "decoder.20.norm2.bias", + "decoder.20.mlp.fc1.bias", + "decoder.20.mlp.fc2.bias", + "decoder.20.ls2.gamma", + "decoder.21.norm1.weight", + "decoder.21.norm1.bias", + "decoder.21.attn.qkv.bias", + "decoder.21.attn.proj.bias", + "decoder.21.attn.q_norm.weight", + "decoder.21.attn.q_norm.bias", + "decoder.21.attn.k_norm.weight", + "decoder.21.attn.k_norm.bias", + "decoder.21.ls1.gamma", + "decoder.21.norm2.weight", + "decoder.21.norm2.bias", + "decoder.21.mlp.fc1.bias", + "decoder.21.mlp.fc2.bias", + "decoder.21.ls2.gamma", + "decoder.22.norm1.weight", + "decoder.22.norm1.bias", + "decoder.22.attn.qkv.bias", + "decoder.22.attn.proj.bias", + "decoder.22.attn.q_norm.weight", + "decoder.22.attn.q_norm.bias", + "decoder.22.attn.k_norm.weight", + "decoder.22.attn.k_norm.bias", + "decoder.22.ls1.gamma", + "decoder.22.norm2.weight", + "decoder.22.norm2.bias", + "decoder.22.mlp.fc1.bias", + "decoder.22.mlp.fc2.bias", + "decoder.22.ls2.gamma", + "decoder.23.norm1.weight", + "decoder.23.norm1.bias", + "decoder.23.attn.qkv.bias", + "decoder.23.attn.proj.bias", + "decoder.23.attn.q_norm.weight", + "decoder.23.attn.q_norm.bias", + "decoder.23.attn.k_norm.weight", + "decoder.23.attn.k_norm.bias", + "decoder.23.ls1.gamma", + "decoder.23.norm2.weight", + 
"decoder.23.norm2.bias", + "decoder.23.mlp.fc1.bias", + "decoder.23.mlp.fc2.bias", + "decoder.23.ls2.gamma", + "decoder.24.norm1.weight", + "decoder.24.norm1.bias", + "decoder.24.attn.qkv.bias", + "decoder.24.attn.proj.bias", + "decoder.24.attn.q_norm.weight", + "decoder.24.attn.q_norm.bias", + "decoder.24.attn.k_norm.weight", + "decoder.24.attn.k_norm.bias", + "decoder.24.ls1.gamma", + "decoder.24.norm2.weight", + "decoder.24.norm2.bias", + "decoder.24.mlp.fc1.bias", + "decoder.24.mlp.fc2.bias", + "decoder.24.ls2.gamma", + "decoder.25.norm1.weight", + "decoder.25.norm1.bias", + "decoder.25.attn.qkv.bias", + "decoder.25.attn.proj.bias", + "decoder.25.attn.q_norm.weight", + "decoder.25.attn.q_norm.bias", + "decoder.25.attn.k_norm.weight", + "decoder.25.attn.k_norm.bias", + "decoder.25.ls1.gamma", + "decoder.25.norm2.weight", + "decoder.25.norm2.bias", + "decoder.25.mlp.fc1.bias", + "decoder.25.mlp.fc2.bias", + "decoder.25.ls2.gamma", + "decoder.26.norm1.weight", + "decoder.26.norm1.bias", + "decoder.26.attn.qkv.bias", + "decoder.26.attn.proj.bias", + "decoder.26.attn.q_norm.weight", + "decoder.26.attn.q_norm.bias", + "decoder.26.attn.k_norm.weight", + "decoder.26.attn.k_norm.bias", + "decoder.26.ls1.gamma", + "decoder.26.norm2.weight", + "decoder.26.norm2.bias", + "decoder.26.mlp.fc1.bias", + "decoder.26.mlp.fc2.bias", + "decoder.26.ls2.gamma", + "decoder.27.norm1.weight", + "decoder.27.norm1.bias", + "decoder.27.attn.qkv.bias", + "decoder.27.attn.proj.bias", + "decoder.27.attn.q_norm.weight", + "decoder.27.attn.q_norm.bias", + "decoder.27.attn.k_norm.weight", + "decoder.27.attn.k_norm.bias", + "decoder.27.ls1.gamma", + "decoder.27.norm2.weight", + "decoder.27.norm2.bias", + "decoder.27.mlp.fc1.bias", + "decoder.27.mlp.fc2.bias", + "decoder.27.ls2.gamma", + "decoder.28.norm1.weight", + "decoder.28.norm1.bias", + "decoder.28.attn.qkv.bias", + "decoder.28.attn.proj.bias", + "decoder.28.attn.q_norm.weight", + "decoder.28.attn.q_norm.bias", + 
"decoder.28.attn.k_norm.weight", + "decoder.28.attn.k_norm.bias", + "decoder.28.ls1.gamma", + "decoder.28.norm2.weight", + "decoder.28.norm2.bias", + "decoder.28.mlp.fc1.bias", + "decoder.28.mlp.fc2.bias", + "decoder.28.ls2.gamma", + "decoder.29.norm1.weight", + "decoder.29.norm1.bias", + "decoder.29.attn.qkv.bias", + "decoder.29.attn.proj.bias", + "decoder.29.attn.q_norm.weight", + "decoder.29.attn.q_norm.bias", + "decoder.29.attn.k_norm.weight", + "decoder.29.attn.k_norm.bias", + "decoder.29.ls1.gamma", + "decoder.29.norm2.weight", + "decoder.29.norm2.bias", + "decoder.29.mlp.fc1.bias", + "decoder.29.mlp.fc2.bias", + "decoder.29.ls2.gamma", + "decoder.30.norm1.weight", + "decoder.30.norm1.bias", + "decoder.30.attn.qkv.bias", + "decoder.30.attn.proj.bias", + "decoder.30.attn.q_norm.weight", + "decoder.30.attn.q_norm.bias", + "decoder.30.attn.k_norm.weight", + "decoder.30.attn.k_norm.bias", + "decoder.30.ls1.gamma", + "decoder.30.norm2.weight", + "decoder.30.norm2.bias", + "decoder.30.mlp.fc1.bias", + "decoder.30.mlp.fc2.bias", + "decoder.30.ls2.gamma", + "decoder.31.norm1.weight", + "decoder.31.norm1.bias", + "decoder.31.attn.qkv.bias", + "decoder.31.attn.proj.bias", + "decoder.31.attn.q_norm.weight", + "decoder.31.attn.q_norm.bias", + "decoder.31.attn.k_norm.weight", + "decoder.31.attn.k_norm.bias", + "decoder.31.ls1.gamma", + "decoder.31.norm2.weight", + "decoder.31.norm2.bias", + "decoder.31.mlp.fc1.bias", + "decoder.31.mlp.fc2.bias", + "decoder.31.ls2.gamma", + "decoder.32.norm1.weight", + "decoder.32.norm1.bias", + "decoder.32.attn.qkv.bias", + "decoder.32.attn.proj.bias", + "decoder.32.attn.q_norm.weight", + "decoder.32.attn.q_norm.bias", + "decoder.32.attn.k_norm.weight", + "decoder.32.attn.k_norm.bias", + "decoder.32.ls1.gamma", + "decoder.32.norm2.weight", + "decoder.32.norm2.bias", + "decoder.32.mlp.fc1.bias", + "decoder.32.mlp.fc2.bias", + "decoder.32.ls2.gamma", + "decoder.33.norm1.weight", + "decoder.33.norm1.bias", + "decoder.33.attn.qkv.bias", + 
"decoder.33.attn.proj.bias", + "decoder.33.attn.q_norm.weight", + "decoder.33.attn.q_norm.bias", + "decoder.33.attn.k_norm.weight", + "decoder.33.attn.k_norm.bias", + "decoder.33.ls1.gamma", + "decoder.33.norm2.weight", + "decoder.33.norm2.bias", + "decoder.33.mlp.fc1.bias", + "decoder.33.mlp.fc2.bias", + "decoder.33.ls2.gamma", + "decoder.34.norm1.weight", + "decoder.34.norm1.bias", + "decoder.34.attn.qkv.bias", + "decoder.34.attn.proj.bias", + "decoder.34.attn.q_norm.weight", + "decoder.34.attn.q_norm.bias", + "decoder.34.attn.k_norm.weight", + "decoder.34.attn.k_norm.bias", + "decoder.34.ls1.gamma", + "decoder.34.norm2.weight", + "decoder.34.norm2.bias", + "decoder.34.mlp.fc1.bias", + "decoder.34.mlp.fc2.bias", + "decoder.34.ls2.gamma", + "decoder.35.norm1.weight", + "decoder.35.norm1.bias", + "decoder.35.attn.qkv.bias", + "decoder.35.attn.proj.bias", + "decoder.35.attn.q_norm.weight", + "decoder.35.attn.q_norm.bias", + "decoder.35.attn.k_norm.weight", + "decoder.35.attn.k_norm.bias", + "decoder.35.ls1.gamma", + "decoder.35.norm2.weight", + "decoder.35.norm2.bias", + "decoder.35.mlp.fc1.bias", + "decoder.35.mlp.fc2.bias", + "decoder.35.ls2.gamma", + "point_decoder.projects.bias", + "point_decoder.blocks.0.norm1.weight", + "point_decoder.blocks.0.norm1.bias", + "point_decoder.blocks.0.attn.qkv.bias", + "point_decoder.blocks.0.attn.proj.bias", + "point_decoder.blocks.0.norm2.weight", + "point_decoder.blocks.0.norm2.bias", + "point_decoder.blocks.0.mlp.fc1.bias", + "point_decoder.blocks.0.mlp.fc2.bias", + "point_decoder.blocks.1.norm1.weight", + "point_decoder.blocks.1.norm1.bias", + "point_decoder.blocks.1.attn.qkv.bias", + "point_decoder.blocks.1.attn.proj.bias", + "point_decoder.blocks.1.norm2.weight", + "point_decoder.blocks.1.norm2.bias", + "point_decoder.blocks.1.mlp.fc1.bias", + "point_decoder.blocks.1.mlp.fc2.bias", + "point_decoder.blocks.2.norm1.weight", + "point_decoder.blocks.2.norm1.bias", + "point_decoder.blocks.2.attn.qkv.bias", + 
"point_decoder.blocks.2.attn.proj.bias", + "point_decoder.blocks.2.norm2.weight", + "point_decoder.blocks.2.norm2.bias", + "point_decoder.blocks.2.mlp.fc1.bias", + "point_decoder.blocks.2.mlp.fc2.bias", + "point_decoder.blocks.3.norm1.weight", + "point_decoder.blocks.3.norm1.bias", + "point_decoder.blocks.3.attn.qkv.bias", + "point_decoder.blocks.3.attn.proj.bias", + "point_decoder.blocks.3.norm2.weight", + "point_decoder.blocks.3.norm2.bias", + "point_decoder.blocks.3.mlp.fc1.bias", + "point_decoder.blocks.3.mlp.fc2.bias", + "point_decoder.blocks.4.norm1.weight", + "point_decoder.blocks.4.norm1.bias", + "point_decoder.blocks.4.attn.qkv.bias", + "point_decoder.blocks.4.attn.proj.bias", + "point_decoder.blocks.4.norm2.weight", + "point_decoder.blocks.4.norm2.bias", + "point_decoder.blocks.4.mlp.fc1.bias", + "point_decoder.blocks.4.mlp.fc2.bias", + "point_decoder.linear_out.bias", + "point_head.proj.bias", + "conf_decoder.projects.bias", + "conf_decoder.blocks.0.norm1.weight", + "conf_decoder.blocks.0.norm1.bias", + "conf_decoder.blocks.0.attn.qkv.bias", + "conf_decoder.blocks.0.attn.proj.bias", + "conf_decoder.blocks.0.norm2.weight", + "conf_decoder.blocks.0.norm2.bias", + "conf_decoder.blocks.0.mlp.fc1.bias", + "conf_decoder.blocks.0.mlp.fc2.bias", + "conf_decoder.blocks.1.norm1.weight", + "conf_decoder.blocks.1.norm1.bias", + "conf_decoder.blocks.1.attn.qkv.bias", + "conf_decoder.blocks.1.attn.proj.bias", + "conf_decoder.blocks.1.norm2.weight", + "conf_decoder.blocks.1.norm2.bias", + "conf_decoder.blocks.1.mlp.fc1.bias", + "conf_decoder.blocks.1.mlp.fc2.bias", + "conf_decoder.blocks.2.norm1.weight", + "conf_decoder.blocks.2.norm1.bias", + "conf_decoder.blocks.2.attn.qkv.bias", + "conf_decoder.blocks.2.attn.proj.bias", + "conf_decoder.blocks.2.norm2.weight", + "conf_decoder.blocks.2.norm2.bias", + "conf_decoder.blocks.2.mlp.fc1.bias", + "conf_decoder.blocks.2.mlp.fc2.bias", + "conf_decoder.blocks.3.norm1.weight", + "conf_decoder.blocks.3.norm1.bias", + 
"conf_decoder.blocks.3.attn.qkv.bias", + "conf_decoder.blocks.3.attn.proj.bias", + "conf_decoder.blocks.3.norm2.weight", + "conf_decoder.blocks.3.norm2.bias", + "conf_decoder.blocks.3.mlp.fc1.bias", + "conf_decoder.blocks.3.mlp.fc2.bias", + "conf_decoder.blocks.4.norm1.weight", + "conf_decoder.blocks.4.norm1.bias", + "conf_decoder.blocks.4.attn.qkv.bias", + "conf_decoder.blocks.4.attn.proj.bias", + "conf_decoder.blocks.4.norm2.weight", + "conf_decoder.blocks.4.norm2.bias", + "conf_decoder.blocks.4.mlp.fc1.bias", + "conf_decoder.blocks.4.mlp.fc2.bias", + "conf_decoder.linear_out.bias", + "conf_head.proj.bias", + "camera_decoder.projects.bias", + "camera_decoder.blocks.0.norm1.weight", + "camera_decoder.blocks.0.norm1.bias", + "camera_decoder.blocks.0.attn.qkv.bias", + "camera_decoder.blocks.0.attn.proj.bias", + "camera_decoder.blocks.0.norm2.weight", + "camera_decoder.blocks.0.norm2.bias", + "camera_decoder.blocks.0.mlp.fc1.bias", + "camera_decoder.blocks.0.mlp.fc2.bias", + "camera_decoder.blocks.1.norm1.weight", + "camera_decoder.blocks.1.norm1.bias", + "camera_decoder.blocks.1.attn.qkv.bias", + "camera_decoder.blocks.1.attn.proj.bias", + "camera_decoder.blocks.1.norm2.weight", + "camera_decoder.blocks.1.norm2.bias", + "camera_decoder.blocks.1.mlp.fc1.bias", + "camera_decoder.blocks.1.mlp.fc2.bias", + "camera_decoder.blocks.2.norm1.weight", + "camera_decoder.blocks.2.norm1.bias", + "camera_decoder.blocks.2.attn.qkv.bias", + "camera_decoder.blocks.2.attn.proj.bias", + "camera_decoder.blocks.2.norm2.weight", + "camera_decoder.blocks.2.norm2.bias", + "camera_decoder.blocks.2.mlp.fc1.bias", + "camera_decoder.blocks.2.mlp.fc2.bias", + "camera_decoder.blocks.3.norm1.weight", + "camera_decoder.blocks.3.norm1.bias", + "camera_decoder.blocks.3.attn.qkv.bias", + "camera_decoder.blocks.3.attn.proj.bias", + "camera_decoder.blocks.3.norm2.weight", + "camera_decoder.blocks.3.norm2.bias", + "camera_decoder.blocks.3.mlp.fc1.bias", + "camera_decoder.blocks.3.mlp.fc2.bias", + 
"camera_decoder.blocks.4.norm1.weight", + "camera_decoder.blocks.4.norm1.bias", + "camera_decoder.blocks.4.attn.qkv.bias", + "camera_decoder.blocks.4.attn.proj.bias", + "camera_decoder.blocks.4.norm2.weight", + "camera_decoder.blocks.4.norm2.bias", + "camera_decoder.blocks.4.mlp.fc1.bias", + "camera_decoder.blocks.4.mlp.fc2.bias", + "camera_decoder.linear_out.bias", + "camera_head.res_conv.0.res_conv1.bias", + "camera_head.res_conv.0.res_conv2.bias", + "camera_head.res_conv.0.res_conv3.bias", + "camera_head.res_conv.1.res_conv1.bias", + "camera_head.res_conv.1.res_conv2.bias", + "camera_head.res_conv.1.res_conv3.bias", + "camera_head.more_mlps.0.bias", + "camera_head.more_mlps.2.bias", + "camera_head.fc_t.bias", + "camera_head.fc_rot.bias" + ], + "lr_scale": 1.0 + }, + "decay": { + "weight_decay": 0.05, + "params": [ + "decoder.0.attn.qkv.weight", + "decoder.0.attn.proj.weight", + "decoder.0.mlp.fc1.weight", + "decoder.0.mlp.fc2.weight", + "decoder.1.attn.qkv.weight", + "decoder.1.attn.proj.weight", + "decoder.1.mlp.fc1.weight", + "decoder.1.mlp.fc2.weight", + "decoder.2.attn.qkv.weight", + "decoder.2.attn.proj.weight", + "decoder.2.mlp.fc1.weight", + "decoder.2.mlp.fc2.weight", + "decoder.3.attn.qkv.weight", + "decoder.3.attn.proj.weight", + "decoder.3.mlp.fc1.weight", + "decoder.3.mlp.fc2.weight", + "decoder.4.attn.qkv.weight", + "decoder.4.attn.proj.weight", + "decoder.4.mlp.fc1.weight", + "decoder.4.mlp.fc2.weight", + "decoder.5.attn.qkv.weight", + "decoder.5.attn.proj.weight", + "decoder.5.mlp.fc1.weight", + "decoder.5.mlp.fc2.weight", + "decoder.6.attn.qkv.weight", + "decoder.6.attn.proj.weight", + "decoder.6.mlp.fc1.weight", + "decoder.6.mlp.fc2.weight", + "decoder.7.attn.qkv.weight", + "decoder.7.attn.proj.weight", + "decoder.7.mlp.fc1.weight", + "decoder.7.mlp.fc2.weight", + "decoder.8.attn.qkv.weight", + "decoder.8.attn.proj.weight", + "decoder.8.mlp.fc1.weight", + "decoder.8.mlp.fc2.weight", + "decoder.9.attn.qkv.weight", + "decoder.9.attn.proj.weight", 
+ "decoder.9.mlp.fc1.weight", + "decoder.9.mlp.fc2.weight", + "decoder.10.attn.qkv.weight", + "decoder.10.attn.proj.weight", + "decoder.10.mlp.fc1.weight", + "decoder.10.mlp.fc2.weight", + "decoder.11.attn.qkv.weight", + "decoder.11.attn.proj.weight", + "decoder.11.mlp.fc1.weight", + "decoder.11.mlp.fc2.weight", + "decoder.12.attn.qkv.weight", + "decoder.12.attn.proj.weight", + "decoder.12.mlp.fc1.weight", + "decoder.12.mlp.fc2.weight", + "decoder.13.attn.qkv.weight", + "decoder.13.attn.proj.weight", + "decoder.13.mlp.fc1.weight", + "decoder.13.mlp.fc2.weight", + "decoder.14.attn.qkv.weight", + "decoder.14.attn.proj.weight", + "decoder.14.mlp.fc1.weight", + "decoder.14.mlp.fc2.weight", + "decoder.15.attn.qkv.weight", + "decoder.15.attn.proj.weight", + "decoder.15.mlp.fc1.weight", + "decoder.15.mlp.fc2.weight", + "decoder.16.attn.qkv.weight", + "decoder.16.attn.proj.weight", + "decoder.16.mlp.fc1.weight", + "decoder.16.mlp.fc2.weight", + "decoder.17.attn.qkv.weight", + "decoder.17.attn.proj.weight", + "decoder.17.mlp.fc1.weight", + "decoder.17.mlp.fc2.weight", + "decoder.18.attn.qkv.weight", + "decoder.18.attn.proj.weight", + "decoder.18.mlp.fc1.weight", + "decoder.18.mlp.fc2.weight", + "decoder.19.attn.qkv.weight", + "decoder.19.attn.proj.weight", + "decoder.19.mlp.fc1.weight", + "decoder.19.mlp.fc2.weight", + "decoder.20.attn.qkv.weight", + "decoder.20.attn.proj.weight", + "decoder.20.mlp.fc1.weight", + "decoder.20.mlp.fc2.weight", + "decoder.21.attn.qkv.weight", + "decoder.21.attn.proj.weight", + "decoder.21.mlp.fc1.weight", + "decoder.21.mlp.fc2.weight", + "decoder.22.attn.qkv.weight", + "decoder.22.attn.proj.weight", + "decoder.22.mlp.fc1.weight", + "decoder.22.mlp.fc2.weight", + "decoder.23.attn.qkv.weight", + "decoder.23.attn.proj.weight", + "decoder.23.mlp.fc1.weight", + "decoder.23.mlp.fc2.weight", + "decoder.24.attn.qkv.weight", + "decoder.24.attn.proj.weight", + "decoder.24.mlp.fc1.weight", + "decoder.24.mlp.fc2.weight", + "decoder.25.attn.qkv.weight", + 
"decoder.25.attn.proj.weight", + "decoder.25.mlp.fc1.weight", + "decoder.25.mlp.fc2.weight", + "decoder.26.attn.qkv.weight", + "decoder.26.attn.proj.weight", + "decoder.26.mlp.fc1.weight", + "decoder.26.mlp.fc2.weight", + "decoder.27.attn.qkv.weight", + "decoder.27.attn.proj.weight", + "decoder.27.mlp.fc1.weight", + "decoder.27.mlp.fc2.weight", + "decoder.28.attn.qkv.weight", + "decoder.28.attn.proj.weight", + "decoder.28.mlp.fc1.weight", + "decoder.28.mlp.fc2.weight", + "decoder.29.attn.qkv.weight", + "decoder.29.attn.proj.weight", + "decoder.29.mlp.fc1.weight", + "decoder.29.mlp.fc2.weight", + "decoder.30.attn.qkv.weight", + "decoder.30.attn.proj.weight", + "decoder.30.mlp.fc1.weight", + "decoder.30.mlp.fc2.weight", + "decoder.31.attn.qkv.weight", + "decoder.31.attn.proj.weight", + "decoder.31.mlp.fc1.weight", + "decoder.31.mlp.fc2.weight", + "decoder.32.attn.qkv.weight", + "decoder.32.attn.proj.weight", + "decoder.32.mlp.fc1.weight", + "decoder.32.mlp.fc2.weight", + "decoder.33.attn.qkv.weight", + "decoder.33.attn.proj.weight", + "decoder.33.mlp.fc1.weight", + "decoder.33.mlp.fc2.weight", + "decoder.34.attn.qkv.weight", + "decoder.34.attn.proj.weight", + "decoder.34.mlp.fc1.weight", + "decoder.34.mlp.fc2.weight", + "decoder.35.attn.qkv.weight", + "decoder.35.attn.proj.weight", + "decoder.35.mlp.fc1.weight", + "decoder.35.mlp.fc2.weight", + "point_decoder.projects.weight", + "point_decoder.blocks.0.attn.qkv.weight", + "point_decoder.blocks.0.attn.proj.weight", + "point_decoder.blocks.0.mlp.fc1.weight", + "point_decoder.blocks.0.mlp.fc2.weight", + "point_decoder.blocks.1.attn.qkv.weight", + "point_decoder.blocks.1.attn.proj.weight", + "point_decoder.blocks.1.mlp.fc1.weight", + "point_decoder.blocks.1.mlp.fc2.weight", + "point_decoder.blocks.2.attn.qkv.weight", + "point_decoder.blocks.2.attn.proj.weight", + "point_decoder.blocks.2.mlp.fc1.weight", + "point_decoder.blocks.2.mlp.fc2.weight", + "point_decoder.blocks.3.attn.qkv.weight", + 
"point_decoder.blocks.3.attn.proj.weight", + "point_decoder.blocks.3.mlp.fc1.weight", + "point_decoder.blocks.3.mlp.fc2.weight", + "point_decoder.blocks.4.attn.qkv.weight", + "point_decoder.blocks.4.attn.proj.weight", + "point_decoder.blocks.4.mlp.fc1.weight", + "point_decoder.blocks.4.mlp.fc2.weight", + "point_decoder.linear_out.weight", + "point_head.proj.weight", + "conf_decoder.projects.weight", + "conf_decoder.blocks.0.attn.qkv.weight", + "conf_decoder.blocks.0.attn.proj.weight", + "conf_decoder.blocks.0.mlp.fc1.weight", + "conf_decoder.blocks.0.mlp.fc2.weight", + "conf_decoder.blocks.1.attn.qkv.weight", + "conf_decoder.blocks.1.attn.proj.weight", + "conf_decoder.blocks.1.mlp.fc1.weight", + "conf_decoder.blocks.1.mlp.fc2.weight", + "conf_decoder.blocks.2.attn.qkv.weight", + "conf_decoder.blocks.2.attn.proj.weight", + "conf_decoder.blocks.2.mlp.fc1.weight", + "conf_decoder.blocks.2.mlp.fc2.weight", + "conf_decoder.blocks.3.attn.qkv.weight", + "conf_decoder.blocks.3.attn.proj.weight", + "conf_decoder.blocks.3.mlp.fc1.weight", + "conf_decoder.blocks.3.mlp.fc2.weight", + "conf_decoder.blocks.4.attn.qkv.weight", + "conf_decoder.blocks.4.attn.proj.weight", + "conf_decoder.blocks.4.mlp.fc1.weight", + "conf_decoder.blocks.4.mlp.fc2.weight", + "conf_decoder.linear_out.weight", + "conf_head.proj.weight", + "camera_decoder.projects.weight", + "camera_decoder.blocks.0.attn.qkv.weight", + "camera_decoder.blocks.0.attn.proj.weight", + "camera_decoder.blocks.0.mlp.fc1.weight", + "camera_decoder.blocks.0.mlp.fc2.weight", + "camera_decoder.blocks.1.attn.qkv.weight", + "camera_decoder.blocks.1.attn.proj.weight", + "camera_decoder.blocks.1.mlp.fc1.weight", + "camera_decoder.blocks.1.mlp.fc2.weight", + "camera_decoder.blocks.2.attn.qkv.weight", + "camera_decoder.blocks.2.attn.proj.weight", + "camera_decoder.blocks.2.mlp.fc1.weight", + "camera_decoder.blocks.2.mlp.fc2.weight", + "camera_decoder.blocks.3.attn.qkv.weight", + "camera_decoder.blocks.3.attn.proj.weight", + 
"camera_decoder.blocks.3.mlp.fc1.weight", + "camera_decoder.blocks.3.mlp.fc2.weight", + "camera_decoder.blocks.4.attn.qkv.weight", + "camera_decoder.blocks.4.attn.proj.weight", + "camera_decoder.blocks.4.mlp.fc1.weight", + "camera_decoder.blocks.4.mlp.fc2.weight", + "camera_decoder.linear_out.weight", + "camera_head.res_conv.0.res_conv1.weight", + "camera_head.res_conv.0.res_conv2.weight", + "camera_head.res_conv.0.res_conv3.weight", + "camera_head.res_conv.1.res_conv1.weight", + "camera_head.res_conv.1.res_conv2.weight", + "camera_head.res_conv.1.res_conv3.weight", + "camera_head.more_mlps.0.weight", + "camera_head.more_mlps.2.weight", + "camera_head.fc_t.weight", + "camera_head.fc_rot.weight" + ], + "lr_scale": 1.0 + } +} +[2026-05-01 23:35:30,323][__main__][INFO] - [RANK 0] Start training for 10 epochs +[2026-05-01 23:35:30,327][__main__][INFO] - [RANK 0] log_dir: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_4gpu/ +[2026-05-01 23:36:39,904][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 0/4350] eta: 3 days, 12:04:01 lr: 0.000000 epoch: 0.0000 (0.0000) step: 0.0000 (0.0000) loss: 5439.7471 (5439.7471) Lcamera_frontend: 4.1901 (4.1901) Ldepth_frontend: 16.2068 (16.2068) Lpmap_frontend: 18.4157 (18.4157) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 4.1823 (4.1823) Ldepth_mix: 16.2088 (16.2088) Lpmap_mix: 18.4142 (18.4142) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 4.1872 (4.1872) Ldepth_backend: 16.2021 (16.2021) Lpmap_backend: 18.4084 (18.4084) Ltrack_backend: 0.0000 (0.0000) total: 5439.7471 (5439.7471) time: 69.5728 data: 24.8740 max mem: 32998 +[2026-05-01 23:43:46,940][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 10/4350] eta: 2 days, 6:25:01 lr: 0.000000 epoch: 0.0011 (0.0011) step: 5.0000 (4.8182) loss: 7296.1621 (7751.5101) Lcamera_frontend: 5.7229 (6.1600) Ldepth_frontend: 16.2943 (16.8899) Lpmap_frontend: 18.4157 (18.3318) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 5.7624 (6.1545) Ldepth_mix: 16.2973 (16.8966) Lpmap_mix: 
18.4142 (18.3372) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 5.7257 (6.0972) Ldepth_backend: 16.3080 (16.9024) Lpmap_backend: 18.4084 (18.3392) Ltrack_backend: 0.0000 (0.0000) total: 7296.1621 (7751.5101) time: 45.1386 data: 2.2946 max mem: 78413 +[2026-05-01 23:52:00,851][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 20/4350] eta: 2 days, 8:43:38 lr: 0.000000 epoch: 0.0023 (0.0023) step: 10.0000 (9.8571) loss: 4755.5371 (6484.7880) Lcamera_frontend: 3.8424 (5.0833) Ldepth_frontend: 16.2943 (16.7699) Lpmap_frontend: 18.2400 (18.2326) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.6659 (5.0774) Ldepth_mix: 16.2973 (16.7797) Lpmap_mix: 18.2386 (18.2391) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.6113 (5.0484) Ldepth_backend: 16.3080 (16.7872) Lpmap_backend: 18.2444 (18.2431) Ltrack_backend: 0.0000 (0.0000) total: 4755.5371 (6484.7880) time: 46.0431 data: 0.0400 max mem: 78608 +[2026-05-02 01:09:53,816][__main__][INFO] - [RANK 0] output_dir: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_4gpu/ +[2026-05-02 01:09:54,143][__main__][INFO] - [RANK 0] Saving current code to /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_4gpu/code/05_02-01:09:53 +[2026-05-02 01:09:54,143][__main__][INFO] - [RANK 0] job dir: /gpfs/work2/0/prjs0824/qi_proj/slamformer_upstream/src +[2026-05-02 01:09:54,143][__main__][INFO] - [RANK 0] Setting seed to 0 for process 0 +[2026-05-02 01:09:54,145][__main__][INFO] - [RANK 0] Building train dataset 6000 @ VirtualKITTI2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 6000 @ KITTI360_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", velodyne_root="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", aug_crop=16, resolution=[(518, 
392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 5400 @ Waymo_v2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/waymo_v2", lidar_root="/scratch-shared/wwei2/waymo_v2", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) +[2026-05-02 01:09:54,145][__main__][INFO] - [RANK 0] Building Train Data loader for dataset: 6000 @ VirtualKITTI2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 6000 @ KITTI360_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", velodyne_root="/scratch-shared/wwei2/downloads/kitti360/KITTI-360", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) + 5400 @ Waymo_v2_Multi(allow_repeat=False, split='train', ROOT="/scratch-shared/wwei2/waymo_v2", lidar_root="/scratch-shared/wwei2/waymo_v2", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=64, n_corres=0) +[2026-05-02 01:14:17,361][__main__][INFO] - [RANK 0] Building test dataset 200 @ VirtualKITTI2_Multi(split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", resolution=(518, 154), num_views=4, seed=42, n_corres=0) +[2026-05-02 01:14:17,365][__main__][INFO] - [RANK 0] Building Test Data loader for dataset: 200 @ VirtualKITTI2_Multi(split='train', ROOT="/scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti", resolution=(518, 154), num_views=4, seed=42, n_corres=0) +[2026-05-02 
01:14:17,418][__main__][INFO] - [RANK 0] Loading model +[2026-05-02 01:14:22,843][__main__][INFO] - [RANK 0] All model parameters: 958696732 +[2026-05-02 01:14:22,843][__main__][INFO] - [RANK 0] >> Creating train criterion = DistillLoss() +[2026-05-02 01:14:22,843][__main__][INFO] - [RANK 0] >> Creating test criterion = DistillLoss() +[2026-05-02 01:14:23,519][__main__][INFO] - [RANK 0] Loading pretrained: /gpfs/work2/0/prjs0824/qi_proj/ckpt/checkpoint-10.pth.model +[2026-05-02 01:14:36,040][__main__][INFO] - [RANK 0] _IncompatibleKeys(missing_keys=['register_token', 'image_mean', 'image_std', 'encoder.cls_token', 'encoder.pos_embed', 'encoder.register_tokens', 'encoder.patch_embed.proj.weight', 'encoder.patch_embed.proj.bias', 'encoder.blocks.0.norm1.weight', 'encoder.blocks.0.norm1.bias', 'encoder.blocks.0.attn.qkv.weight', 'encoder.blocks.0.attn.qkv.bias', 'encoder.blocks.0.attn.proj.weight', 'encoder.blocks.0.attn.proj.bias', 'encoder.blocks.0.ls1.gamma', 'encoder.blocks.0.norm2.weight', 'encoder.blocks.0.norm2.bias', 'encoder.blocks.0.mlp.fc1.weight', 'encoder.blocks.0.mlp.fc1.bias', 'encoder.blocks.0.mlp.fc2.weight', 'encoder.blocks.0.mlp.fc2.bias', 'encoder.blocks.0.ls2.gamma', 'encoder.blocks.1.norm1.weight', 'encoder.blocks.1.norm1.bias', 'encoder.blocks.1.attn.qkv.weight', 'encoder.blocks.1.attn.qkv.bias', 'encoder.blocks.1.attn.proj.weight', 'encoder.blocks.1.attn.proj.bias', 'encoder.blocks.1.ls1.gamma', 'encoder.blocks.1.norm2.weight', 'encoder.blocks.1.norm2.bias', 'encoder.blocks.1.mlp.fc1.weight', 'encoder.blocks.1.mlp.fc1.bias', 'encoder.blocks.1.mlp.fc2.weight', 'encoder.blocks.1.mlp.fc2.bias', 'encoder.blocks.1.ls2.gamma', 'encoder.blocks.2.norm1.weight', 'encoder.blocks.2.norm1.bias', 'encoder.blocks.2.attn.qkv.weight', 'encoder.blocks.2.attn.qkv.bias', 'encoder.blocks.2.attn.proj.weight', 'encoder.blocks.2.attn.proj.bias', 'encoder.blocks.2.ls1.gamma', 'encoder.blocks.2.norm2.weight', 'encoder.blocks.2.norm2.bias', 
'encoder.blocks.2.mlp.fc1.weight', 'encoder.blocks.2.mlp.fc1.bias', 'encoder.blocks.2.mlp.fc2.weight', 'encoder.blocks.2.mlp.fc2.bias', 'encoder.blocks.2.ls2.gamma', 'encoder.blocks.3.norm1.weight', 'encoder.blocks.3.norm1.bias', 'encoder.blocks.3.attn.qkv.weight', 'encoder.blocks.3.attn.qkv.bias', 'encoder.blocks.3.attn.proj.weight', 'encoder.blocks.3.attn.proj.bias', 'encoder.blocks.3.ls1.gamma', 'encoder.blocks.3.norm2.weight', 'encoder.blocks.3.norm2.bias', 'encoder.blocks.3.mlp.fc1.weight', 'encoder.blocks.3.mlp.fc1.bias', 'encoder.blocks.3.mlp.fc2.weight', 'encoder.blocks.3.mlp.fc2.bias', 'encoder.blocks.3.ls2.gamma', 'encoder.blocks.4.norm1.weight', 'encoder.blocks.4.norm1.bias', 'encoder.blocks.4.attn.qkv.weight', 'encoder.blocks.4.attn.qkv.bias', 'encoder.blocks.4.attn.proj.weight', 'encoder.blocks.4.attn.proj.bias', 'encoder.blocks.4.ls1.gamma', 'encoder.blocks.4.norm2.weight', 'encoder.blocks.4.norm2.bias', 'encoder.blocks.4.mlp.fc1.weight', 'encoder.blocks.4.mlp.fc1.bias', 'encoder.blocks.4.mlp.fc2.weight', 'encoder.blocks.4.mlp.fc2.bias', 'encoder.blocks.4.ls2.gamma', 'encoder.blocks.5.norm1.weight', 'encoder.blocks.5.norm1.bias', 'encoder.blocks.5.attn.qkv.weight', 'encoder.blocks.5.attn.qkv.bias', 'encoder.blocks.5.attn.proj.weight', 'encoder.blocks.5.attn.proj.bias', 'encoder.blocks.5.ls1.gamma', 'encoder.blocks.5.norm2.weight', 'encoder.blocks.5.norm2.bias', 'encoder.blocks.5.mlp.fc1.weight', 'encoder.blocks.5.mlp.fc1.bias', 'encoder.blocks.5.mlp.fc2.weight', 'encoder.blocks.5.mlp.fc2.bias', 'encoder.blocks.5.ls2.gamma', 'encoder.blocks.6.norm1.weight', 'encoder.blocks.6.norm1.bias', 'encoder.blocks.6.attn.qkv.weight', 'encoder.blocks.6.attn.qkv.bias', 'encoder.blocks.6.attn.proj.weight', 'encoder.blocks.6.attn.proj.bias', 'encoder.blocks.6.ls1.gamma', 'encoder.blocks.6.norm2.weight', 'encoder.blocks.6.norm2.bias', 'encoder.blocks.6.mlp.fc1.weight', 'encoder.blocks.6.mlp.fc1.bias', 'encoder.blocks.6.mlp.fc2.weight', 'encoder.blocks.6.mlp.fc2.bias', 
'encoder.blocks.6.ls2.gamma', 'encoder.blocks.7.norm1.weight', 'encoder.blocks.7.norm1.bias', 'encoder.blocks.7.attn.qkv.weight', 'encoder.blocks.7.attn.qkv.bias', 'encoder.blocks.7.attn.proj.weight', 'encoder.blocks.7.attn.proj.bias', 'encoder.blocks.7.ls1.gamma', 'encoder.blocks.7.norm2.weight', 'encoder.blocks.7.norm2.bias', 'encoder.blocks.7.mlp.fc1.weight', 'encoder.blocks.7.mlp.fc1.bias', 'encoder.blocks.7.mlp.fc2.weight', 'encoder.blocks.7.mlp.fc2.bias', 'encoder.blocks.7.ls2.gamma', 'encoder.blocks.8.norm1.weight', 'encoder.blocks.8.norm1.bias', 'encoder.blocks.8.attn.qkv.weight', 'encoder.blocks.8.attn.qkv.bias', 'encoder.blocks.8.attn.proj.weight', 'encoder.blocks.8.attn.proj.bias', 'encoder.blocks.8.ls1.gamma', 'encoder.blocks.8.norm2.weight', 'encoder.blocks.8.norm2.bias', 'encoder.blocks.8.mlp.fc1.weight', 'encoder.blocks.8.mlp.fc1.bias', 'encoder.blocks.8.mlp.fc2.weight', 'encoder.blocks.8.mlp.fc2.bias', 'encoder.blocks.8.ls2.gamma', 'encoder.blocks.9.norm1.weight', 'encoder.blocks.9.norm1.bias', 'encoder.blocks.9.attn.qkv.weight', 'encoder.blocks.9.attn.qkv.bias', 'encoder.blocks.9.attn.proj.weight', 'encoder.blocks.9.attn.proj.bias', 'encoder.blocks.9.ls1.gamma', 'encoder.blocks.9.norm2.weight', 'encoder.blocks.9.norm2.bias', 'encoder.blocks.9.mlp.fc1.weight', 'encoder.blocks.9.mlp.fc1.bias', 'encoder.blocks.9.mlp.fc2.weight', 'encoder.blocks.9.mlp.fc2.bias', 'encoder.blocks.9.ls2.gamma', 'encoder.blocks.10.norm1.weight', 'encoder.blocks.10.norm1.bias', 'encoder.blocks.10.attn.qkv.weight', 'encoder.blocks.10.attn.qkv.bias', 'encoder.blocks.10.attn.proj.weight', 'encoder.blocks.10.attn.proj.bias', 'encoder.blocks.10.ls1.gamma', 'encoder.blocks.10.norm2.weight', 'encoder.blocks.10.norm2.bias', 'encoder.blocks.10.mlp.fc1.weight', 'encoder.blocks.10.mlp.fc1.bias', 'encoder.blocks.10.mlp.fc2.weight', 'encoder.blocks.10.mlp.fc2.bias', 'encoder.blocks.10.ls2.gamma', 'encoder.blocks.11.norm1.weight', 'encoder.blocks.11.norm1.bias', 
'encoder.blocks.11.attn.qkv.weight', 'encoder.blocks.11.attn.qkv.bias', 'encoder.blocks.11.attn.proj.weight', 'encoder.blocks.11.attn.proj.bias', 'encoder.blocks.11.ls1.gamma', 'encoder.blocks.11.norm2.weight', 'encoder.blocks.11.norm2.bias', 'encoder.blocks.11.mlp.fc1.weight', 'encoder.blocks.11.mlp.fc1.bias', 'encoder.blocks.11.mlp.fc2.weight', 'encoder.blocks.11.mlp.fc2.bias', 'encoder.blocks.11.ls2.gamma', 'encoder.blocks.12.norm1.weight', 'encoder.blocks.12.norm1.bias', 'encoder.blocks.12.attn.qkv.weight', 'encoder.blocks.12.attn.qkv.bias', 'encoder.blocks.12.attn.proj.weight', 'encoder.blocks.12.attn.proj.bias', 'encoder.blocks.12.ls1.gamma', 'encoder.blocks.12.norm2.weight', 'encoder.blocks.12.norm2.bias', 'encoder.blocks.12.mlp.fc1.weight', 'encoder.blocks.12.mlp.fc1.bias', 'encoder.blocks.12.mlp.fc2.weight', 'encoder.blocks.12.mlp.fc2.bias', 'encoder.blocks.12.ls2.gamma', 'encoder.blocks.13.norm1.weight', 'encoder.blocks.13.norm1.bias', 'encoder.blocks.13.attn.qkv.weight', 'encoder.blocks.13.attn.qkv.bias', 'encoder.blocks.13.attn.proj.weight', 'encoder.blocks.13.attn.proj.bias', 'encoder.blocks.13.ls1.gamma', 'encoder.blocks.13.norm2.weight', 'encoder.blocks.13.norm2.bias', 'encoder.blocks.13.mlp.fc1.weight', 'encoder.blocks.13.mlp.fc1.bias', 'encoder.blocks.13.mlp.fc2.weight', 'encoder.blocks.13.mlp.fc2.bias', 'encoder.blocks.13.ls2.gamma', 'encoder.blocks.14.norm1.weight', 'encoder.blocks.14.norm1.bias', 'encoder.blocks.14.attn.qkv.weight', 'encoder.blocks.14.attn.qkv.bias', 'encoder.blocks.14.attn.proj.weight', 'encoder.blocks.14.attn.proj.bias', 'encoder.blocks.14.ls1.gamma', 'encoder.blocks.14.norm2.weight', 'encoder.blocks.14.norm2.bias', 'encoder.blocks.14.mlp.fc1.weight', 'encoder.blocks.14.mlp.fc1.bias', 'encoder.blocks.14.mlp.fc2.weight', 'encoder.blocks.14.mlp.fc2.bias', 'encoder.blocks.14.ls2.gamma', 'encoder.blocks.15.norm1.weight', 'encoder.blocks.15.norm1.bias', 'encoder.blocks.15.attn.qkv.weight', 'encoder.blocks.15.attn.qkv.bias', 
'encoder.blocks.15.attn.proj.weight', 'encoder.blocks.15.attn.proj.bias', 'encoder.blocks.15.ls1.gamma', 'encoder.blocks.15.norm2.weight', 'encoder.blocks.15.norm2.bias', 'encoder.blocks.15.mlp.fc1.weight', 'encoder.blocks.15.mlp.fc1.bias', 'encoder.blocks.15.mlp.fc2.weight', 'encoder.blocks.15.mlp.fc2.bias', 'encoder.blocks.15.ls2.gamma', 'encoder.blocks.16.norm1.weight', 'encoder.blocks.16.norm1.bias', 'encoder.blocks.16.attn.qkv.weight', 'encoder.blocks.16.attn.qkv.bias', 'encoder.blocks.16.attn.proj.weight', 'encoder.blocks.16.attn.proj.bias', 'encoder.blocks.16.ls1.gamma', 'encoder.blocks.16.norm2.weight', 'encoder.blocks.16.norm2.bias', 'encoder.blocks.16.mlp.fc1.weight', 'encoder.blocks.16.mlp.fc1.bias', 'encoder.blocks.16.mlp.fc2.weight', 'encoder.blocks.16.mlp.fc2.bias', 'encoder.blocks.16.ls2.gamma', 'encoder.blocks.17.norm1.weight', 'encoder.blocks.17.norm1.bias', 'encoder.blocks.17.attn.qkv.weight', 'encoder.blocks.17.attn.qkv.bias', 'encoder.blocks.17.attn.proj.weight', 'encoder.blocks.17.attn.proj.bias', 'encoder.blocks.17.ls1.gamma', 'encoder.blocks.17.norm2.weight', 'encoder.blocks.17.norm2.bias', 'encoder.blocks.17.mlp.fc1.weight', 'encoder.blocks.17.mlp.fc1.bias', 'encoder.blocks.17.mlp.fc2.weight', 'encoder.blocks.17.mlp.fc2.bias', 'encoder.blocks.17.ls2.gamma', 'encoder.blocks.18.norm1.weight', 'encoder.blocks.18.norm1.bias', 'encoder.blocks.18.attn.qkv.weight', 'encoder.blocks.18.attn.qkv.bias', 'encoder.blocks.18.attn.proj.weight', 'encoder.blocks.18.attn.proj.bias', 'encoder.blocks.18.ls1.gamma', 'encoder.blocks.18.norm2.weight', 'encoder.blocks.18.norm2.bias', 'encoder.blocks.18.mlp.fc1.weight', 'encoder.blocks.18.mlp.fc1.bias', 'encoder.blocks.18.mlp.fc2.weight', 'encoder.blocks.18.mlp.fc2.bias', 'encoder.blocks.18.ls2.gamma', 'encoder.blocks.19.norm1.weight', 'encoder.blocks.19.norm1.bias', 'encoder.blocks.19.attn.qkv.weight', 'encoder.blocks.19.attn.qkv.bias', 'encoder.blocks.19.attn.proj.weight', 'encoder.blocks.19.attn.proj.bias', 
'encoder.blocks.19.ls1.gamma', 'encoder.blocks.19.norm2.weight', 'encoder.blocks.19.norm2.bias', 'encoder.blocks.19.mlp.fc1.weight', 'encoder.blocks.19.mlp.fc1.bias', 'encoder.blocks.19.mlp.fc2.weight', 'encoder.blocks.19.mlp.fc2.bias', 'encoder.blocks.19.ls2.gamma', 'encoder.blocks.20.norm1.weight', 'encoder.blocks.20.norm1.bias', 'encoder.blocks.20.attn.qkv.weight', 'encoder.blocks.20.attn.qkv.bias', 'encoder.blocks.20.attn.proj.weight', 'encoder.blocks.20.attn.proj.bias', 'encoder.blocks.20.ls1.gamma', 'encoder.blocks.20.norm2.weight', 'encoder.blocks.20.norm2.bias', 'encoder.blocks.20.mlp.fc1.weight', 'encoder.blocks.20.mlp.fc1.bias', 'encoder.blocks.20.mlp.fc2.weight', 'encoder.blocks.20.mlp.fc2.bias', 'encoder.blocks.20.ls2.gamma', 'encoder.blocks.21.norm1.weight', 'encoder.blocks.21.norm1.bias', 'encoder.blocks.21.attn.qkv.weight', 'encoder.blocks.21.attn.qkv.bias', 'encoder.blocks.21.attn.proj.weight', 'encoder.blocks.21.attn.proj.bias', 'encoder.blocks.21.ls1.gamma', 'encoder.blocks.21.norm2.weight', 'encoder.blocks.21.norm2.bias', 'encoder.blocks.21.mlp.fc1.weight', 'encoder.blocks.21.mlp.fc1.bias', 'encoder.blocks.21.mlp.fc2.weight', 'encoder.blocks.21.mlp.fc2.bias', 'encoder.blocks.21.ls2.gamma', 'encoder.blocks.22.norm1.weight', 'encoder.blocks.22.norm1.bias', 'encoder.blocks.22.attn.qkv.weight', 'encoder.blocks.22.attn.qkv.bias', 'encoder.blocks.22.attn.proj.weight', 'encoder.blocks.22.attn.proj.bias', 'encoder.blocks.22.ls1.gamma', 'encoder.blocks.22.norm2.weight', 'encoder.blocks.22.norm2.bias', 'encoder.blocks.22.mlp.fc1.weight', 'encoder.blocks.22.mlp.fc1.bias', 'encoder.blocks.22.mlp.fc2.weight', 'encoder.blocks.22.mlp.fc2.bias', 'encoder.blocks.22.ls2.gamma', 'encoder.blocks.23.norm1.weight', 'encoder.blocks.23.norm1.bias', 'encoder.blocks.23.attn.qkv.weight', 'encoder.blocks.23.attn.qkv.bias', 'encoder.blocks.23.attn.proj.weight', 'encoder.blocks.23.attn.proj.bias', 'encoder.blocks.23.ls1.gamma', 'encoder.blocks.23.norm2.weight', 
'encoder.blocks.23.norm2.bias', 'encoder.blocks.23.mlp.fc1.weight', 'encoder.blocks.23.mlp.fc1.bias', 'encoder.blocks.23.mlp.fc2.weight', 'encoder.blocks.23.mlp.fc2.bias', 'encoder.blocks.23.ls2.gamma', 'encoder.norm.weight', 'encoder.norm.bias', 'decoder.0.norm1.weight', 'decoder.0.norm1.bias', 'decoder.0.attn.qkv.weight', 'decoder.0.attn.qkv.bias', 'decoder.0.attn.proj.weight', 'decoder.0.attn.proj.bias', 'decoder.0.attn.q_norm.weight', 'decoder.0.attn.q_norm.bias', 'decoder.0.attn.k_norm.weight', 'decoder.0.attn.k_norm.bias', 'decoder.0.ls1.gamma', 'decoder.0.norm2.weight', 'decoder.0.norm2.bias', 'decoder.0.mlp.fc1.weight', 'decoder.0.mlp.fc1.bias', 'decoder.0.mlp.fc2.weight', 'decoder.0.mlp.fc2.bias', 'decoder.0.ls2.gamma', 'decoder.1.norm1.weight', 'decoder.1.norm1.bias', 'decoder.1.attn.qkv.weight', 'decoder.1.attn.qkv.bias', 'decoder.1.attn.proj.weight', 'decoder.1.attn.proj.bias', 'decoder.1.attn.q_norm.weight', 'decoder.1.attn.q_norm.bias', 'decoder.1.attn.k_norm.weight', 'decoder.1.attn.k_norm.bias', 'decoder.1.ls1.gamma', 'decoder.1.norm2.weight', 'decoder.1.norm2.bias', 'decoder.1.mlp.fc1.weight', 'decoder.1.mlp.fc1.bias', 'decoder.1.mlp.fc2.weight', 'decoder.1.mlp.fc2.bias', 'decoder.1.ls2.gamma', 'decoder.2.norm1.weight', 'decoder.2.norm1.bias', 'decoder.2.attn.qkv.weight', 'decoder.2.attn.qkv.bias', 'decoder.2.attn.proj.weight', 'decoder.2.attn.proj.bias', 'decoder.2.attn.q_norm.weight', 'decoder.2.attn.q_norm.bias', 'decoder.2.attn.k_norm.weight', 'decoder.2.attn.k_norm.bias', 'decoder.2.ls1.gamma', 'decoder.2.norm2.weight', 'decoder.2.norm2.bias', 'decoder.2.mlp.fc1.weight', 'decoder.2.mlp.fc1.bias', 'decoder.2.mlp.fc2.weight', 'decoder.2.mlp.fc2.bias', 'decoder.2.ls2.gamma', 'decoder.3.norm1.weight', 'decoder.3.norm1.bias', 'decoder.3.attn.qkv.weight', 'decoder.3.attn.qkv.bias', 'decoder.3.attn.proj.weight', 'decoder.3.attn.proj.bias', 'decoder.3.attn.q_norm.weight', 'decoder.3.attn.q_norm.bias', 'decoder.3.attn.k_norm.weight', 
'decoder.3.attn.k_norm.bias', 'decoder.3.ls1.gamma', 'decoder.3.norm2.weight', 'decoder.3.norm2.bias', 'decoder.3.mlp.fc1.weight', 'decoder.3.mlp.fc1.bias', 'decoder.3.mlp.fc2.weight', 'decoder.3.mlp.fc2.bias', 'decoder.3.ls2.gamma', 'decoder.4.norm1.weight', 'decoder.4.norm1.bias', 'decoder.4.attn.qkv.weight', 'decoder.4.attn.qkv.bias', 'decoder.4.attn.proj.weight', 'decoder.4.attn.proj.bias', 'decoder.4.attn.q_norm.weight', 'decoder.4.attn.q_norm.bias', 'decoder.4.attn.k_norm.weight', 'decoder.4.attn.k_norm.bias', 'decoder.4.ls1.gamma', 'decoder.4.norm2.weight', 'decoder.4.norm2.bias', 'decoder.4.mlp.fc1.weight', 'decoder.4.mlp.fc1.bias', 'decoder.4.mlp.fc2.weight', 'decoder.4.mlp.fc2.bias', 'decoder.4.ls2.gamma', 'decoder.5.norm1.weight', 'decoder.5.norm1.bias', 'decoder.5.attn.qkv.weight', 'decoder.5.attn.qkv.bias', 'decoder.5.attn.proj.weight', 'decoder.5.attn.proj.bias', 'decoder.5.attn.q_norm.weight', 'decoder.5.attn.q_norm.bias', 'decoder.5.attn.k_norm.weight', 'decoder.5.attn.k_norm.bias', 'decoder.5.ls1.gamma', 'decoder.5.norm2.weight', 'decoder.5.norm2.bias', 'decoder.5.mlp.fc1.weight', 'decoder.5.mlp.fc1.bias', 'decoder.5.mlp.fc2.weight', 'decoder.5.mlp.fc2.bias', 'decoder.5.ls2.gamma', 'decoder.6.norm1.weight', 'decoder.6.norm1.bias', 'decoder.6.attn.qkv.weight', 'decoder.6.attn.qkv.bias', 'decoder.6.attn.proj.weight', 'decoder.6.attn.proj.bias', 'decoder.6.attn.q_norm.weight', 'decoder.6.attn.q_norm.bias', 'decoder.6.attn.k_norm.weight', 'decoder.6.attn.k_norm.bias', 'decoder.6.ls1.gamma', 'decoder.6.norm2.weight', 'decoder.6.norm2.bias', 'decoder.6.mlp.fc1.weight', 'decoder.6.mlp.fc1.bias', 'decoder.6.mlp.fc2.weight', 'decoder.6.mlp.fc2.bias', 'decoder.6.ls2.gamma', 'decoder.7.norm1.weight', 'decoder.7.norm1.bias', 'decoder.7.attn.qkv.weight', 'decoder.7.attn.qkv.bias', 'decoder.7.attn.proj.weight', 'decoder.7.attn.proj.bias', 'decoder.7.attn.q_norm.weight', 'decoder.7.attn.q_norm.bias', 'decoder.7.attn.k_norm.weight', 'decoder.7.attn.k_norm.bias', 
'decoder.7.ls1.gamma', 'decoder.7.norm2.weight', 'decoder.7.norm2.bias', 'decoder.7.mlp.fc1.weight', 'decoder.7.mlp.fc1.bias', 'decoder.7.mlp.fc2.weight', 'decoder.7.mlp.fc2.bias', 'decoder.7.ls2.gamma', 'decoder.8.norm1.weight', 'decoder.8.norm1.bias', 'decoder.8.attn.qkv.weight', 'decoder.8.attn.qkv.bias', 'decoder.8.attn.proj.weight', 'decoder.8.attn.proj.bias', 'decoder.8.attn.q_norm.weight', 'decoder.8.attn.q_norm.bias', 'decoder.8.attn.k_norm.weight', 'decoder.8.attn.k_norm.bias', 'decoder.8.ls1.gamma', 'decoder.8.norm2.weight', 'decoder.8.norm2.bias', 'decoder.8.mlp.fc1.weight', 'decoder.8.mlp.fc1.bias', 'decoder.8.mlp.fc2.weight', 'decoder.8.mlp.fc2.bias', 'decoder.8.ls2.gamma', 'decoder.9.norm1.weight', 'decoder.9.norm1.bias', 'decoder.9.attn.qkv.weight', 'decoder.9.attn.qkv.bias', 'decoder.9.attn.proj.weight', 'decoder.9.attn.proj.bias', 'decoder.9.attn.q_norm.weight', 'decoder.9.attn.q_norm.bias', 'decoder.9.attn.k_norm.weight', 'decoder.9.attn.k_norm.bias', 'decoder.9.ls1.gamma', 'decoder.9.norm2.weight', 'decoder.9.norm2.bias', 'decoder.9.mlp.fc1.weight', 'decoder.9.mlp.fc1.bias', 'decoder.9.mlp.fc2.weight', 'decoder.9.mlp.fc2.bias', 'decoder.9.ls2.gamma', 'decoder.10.norm1.weight', 'decoder.10.norm1.bias', 'decoder.10.attn.qkv.weight', 'decoder.10.attn.qkv.bias', 'decoder.10.attn.proj.weight', 'decoder.10.attn.proj.bias', 'decoder.10.attn.q_norm.weight', 'decoder.10.attn.q_norm.bias', 'decoder.10.attn.k_norm.weight', 'decoder.10.attn.k_norm.bias', 'decoder.10.ls1.gamma', 'decoder.10.norm2.weight', 'decoder.10.norm2.bias', 'decoder.10.mlp.fc1.weight', 'decoder.10.mlp.fc1.bias', 'decoder.10.mlp.fc2.weight', 'decoder.10.mlp.fc2.bias', 'decoder.10.ls2.gamma', 'decoder.11.norm1.weight', 'decoder.11.norm1.bias', 'decoder.11.attn.qkv.weight', 'decoder.11.attn.qkv.bias', 'decoder.11.attn.proj.weight', 'decoder.11.attn.proj.bias', 'decoder.11.attn.q_norm.weight', 'decoder.11.attn.q_norm.bias', 'decoder.11.attn.k_norm.weight', 'decoder.11.attn.k_norm.bias', 
'decoder.11.ls1.gamma', 'decoder.11.norm2.weight', 'decoder.11.norm2.bias', 'decoder.11.mlp.fc1.weight', 'decoder.11.mlp.fc1.bias', 'decoder.11.mlp.fc2.weight', 'decoder.11.mlp.fc2.bias', 'decoder.11.ls2.gamma', 'decoder.12.norm1.weight', 'decoder.12.norm1.bias', 'decoder.12.attn.qkv.weight', 'decoder.12.attn.qkv.bias', 'decoder.12.attn.proj.weight', 'decoder.12.attn.proj.bias', 'decoder.12.attn.q_norm.weight', 'decoder.12.attn.q_norm.bias', 'decoder.12.attn.k_norm.weight', 'decoder.12.attn.k_norm.bias', 'decoder.12.ls1.gamma', 'decoder.12.norm2.weight', 'decoder.12.norm2.bias', 'decoder.12.mlp.fc1.weight', 'decoder.12.mlp.fc1.bias', 'decoder.12.mlp.fc2.weight', 'decoder.12.mlp.fc2.bias', 'decoder.12.ls2.gamma', 'decoder.13.norm1.weight', 'decoder.13.norm1.bias', 'decoder.13.attn.qkv.weight', 'decoder.13.attn.qkv.bias', 'decoder.13.attn.proj.weight', 'decoder.13.attn.proj.bias', 'decoder.13.attn.q_norm.weight', 'decoder.13.attn.q_norm.bias', 'decoder.13.attn.k_norm.weight', 'decoder.13.attn.k_norm.bias', 'decoder.13.ls1.gamma', 'decoder.13.norm2.weight', 'decoder.13.norm2.bias', 'decoder.13.mlp.fc1.weight', 'decoder.13.mlp.fc1.bias', 'decoder.13.mlp.fc2.weight', 'decoder.13.mlp.fc2.bias', 'decoder.13.ls2.gamma', 'decoder.14.norm1.weight', 'decoder.14.norm1.bias', 'decoder.14.attn.qkv.weight', 'decoder.14.attn.qkv.bias', 'decoder.14.attn.proj.weight', 'decoder.14.attn.proj.bias', 'decoder.14.attn.q_norm.weight', 'decoder.14.attn.q_norm.bias', 'decoder.14.attn.k_norm.weight', 'decoder.14.attn.k_norm.bias', 'decoder.14.ls1.gamma', 'decoder.14.norm2.weight', 'decoder.14.norm2.bias', 'decoder.14.mlp.fc1.weight', 'decoder.14.mlp.fc1.bias', 'decoder.14.mlp.fc2.weight', 'decoder.14.mlp.fc2.bias', 'decoder.14.ls2.gamma', 'decoder.15.norm1.weight', 'decoder.15.norm1.bias', 'decoder.15.attn.qkv.weight', 'decoder.15.attn.qkv.bias', 'decoder.15.attn.proj.weight', 'decoder.15.attn.proj.bias', 'decoder.15.attn.q_norm.weight', 'decoder.15.attn.q_norm.bias', 
'decoder.15.attn.k_norm.weight', 'decoder.15.attn.k_norm.bias', 'decoder.15.ls1.gamma', 'decoder.15.norm2.weight', 'decoder.15.norm2.bias', 'decoder.15.mlp.fc1.weight', 'decoder.15.mlp.fc1.bias', 'decoder.15.mlp.fc2.weight', 'decoder.15.mlp.fc2.bias', 'decoder.15.ls2.gamma', 'decoder.16.norm1.weight', 'decoder.16.norm1.bias', 'decoder.16.attn.qkv.weight', 'decoder.16.attn.qkv.bias', 'decoder.16.attn.proj.weight', 'decoder.16.attn.proj.bias', 'decoder.16.attn.q_norm.weight', 'decoder.16.attn.q_norm.bias', 'decoder.16.attn.k_norm.weight', 'decoder.16.attn.k_norm.bias', 'decoder.16.ls1.gamma', 'decoder.16.norm2.weight', 'decoder.16.norm2.bias', 'decoder.16.mlp.fc1.weight', 'decoder.16.mlp.fc1.bias', 'decoder.16.mlp.fc2.weight', 'decoder.16.mlp.fc2.bias', 'decoder.16.ls2.gamma', 'decoder.17.norm1.weight', 'decoder.17.norm1.bias', 'decoder.17.attn.qkv.weight', 'decoder.17.attn.qkv.bias', 'decoder.17.attn.proj.weight', 'decoder.17.attn.proj.bias', 'decoder.17.attn.q_norm.weight', 'decoder.17.attn.q_norm.bias', 'decoder.17.attn.k_norm.weight', 'decoder.17.attn.k_norm.bias', 'decoder.17.ls1.gamma', 'decoder.17.norm2.weight', 'decoder.17.norm2.bias', 'decoder.17.mlp.fc1.weight', 'decoder.17.mlp.fc1.bias', 'decoder.17.mlp.fc2.weight', 'decoder.17.mlp.fc2.bias', 'decoder.17.ls2.gamma', 'decoder.18.norm1.weight', 'decoder.18.norm1.bias', 'decoder.18.attn.qkv.weight', 'decoder.18.attn.qkv.bias', 'decoder.18.attn.proj.weight', 'decoder.18.attn.proj.bias', 'decoder.18.attn.q_norm.weight', 'decoder.18.attn.q_norm.bias', 'decoder.18.attn.k_norm.weight', 'decoder.18.attn.k_norm.bias', 'decoder.18.ls1.gamma', 'decoder.18.norm2.weight', 'decoder.18.norm2.bias', 'decoder.18.mlp.fc1.weight', 'decoder.18.mlp.fc1.bias', 'decoder.18.mlp.fc2.weight', 'decoder.18.mlp.fc2.bias', 'decoder.18.ls2.gamma', 'decoder.19.norm1.weight', 'decoder.19.norm1.bias', 'decoder.19.attn.qkv.weight', 'decoder.19.attn.qkv.bias', 'decoder.19.attn.proj.weight', 'decoder.19.attn.proj.bias', 
'decoder.19.attn.q_norm.weight', 'decoder.19.attn.q_norm.bias', 'decoder.19.attn.k_norm.weight', 'decoder.19.attn.k_norm.bias', 'decoder.19.ls1.gamma', 'decoder.19.norm2.weight', 'decoder.19.norm2.bias', 'decoder.19.mlp.fc1.weight', 'decoder.19.mlp.fc1.bias', 'decoder.19.mlp.fc2.weight', 'decoder.19.mlp.fc2.bias', 'decoder.19.ls2.gamma', 'decoder.20.norm1.weight', 'decoder.20.norm1.bias', 'decoder.20.attn.qkv.weight', 'decoder.20.attn.qkv.bias', 'decoder.20.attn.proj.weight', 'decoder.20.attn.proj.bias', 'decoder.20.attn.q_norm.weight', 'decoder.20.attn.q_norm.bias', 'decoder.20.attn.k_norm.weight', 'decoder.20.attn.k_norm.bias', 'decoder.20.ls1.gamma', 'decoder.20.norm2.weight', 'decoder.20.norm2.bias', 'decoder.20.mlp.fc1.weight', 'decoder.20.mlp.fc1.bias', 'decoder.20.mlp.fc2.weight', 'decoder.20.mlp.fc2.bias', 'decoder.20.ls2.gamma', 'decoder.21.norm1.weight', 'decoder.21.norm1.bias', 'decoder.21.attn.qkv.weight', 'decoder.21.attn.qkv.bias', 'decoder.21.attn.proj.weight', 'decoder.21.attn.proj.bias', 'decoder.21.attn.q_norm.weight', 'decoder.21.attn.q_norm.bias', 'decoder.21.attn.k_norm.weight', 'decoder.21.attn.k_norm.bias', 'decoder.21.ls1.gamma', 'decoder.21.norm2.weight', 'decoder.21.norm2.bias', 'decoder.21.mlp.fc1.weight', 'decoder.21.mlp.fc1.bias', 'decoder.21.mlp.fc2.weight', 'decoder.21.mlp.fc2.bias', 'decoder.21.ls2.gamma', 'decoder.22.norm1.weight', 'decoder.22.norm1.bias', 'decoder.22.attn.qkv.weight', 'decoder.22.attn.qkv.bias', 'decoder.22.attn.proj.weight', 'decoder.22.attn.proj.bias', 'decoder.22.attn.q_norm.weight', 'decoder.22.attn.q_norm.bias', 'decoder.22.attn.k_norm.weight', 'decoder.22.attn.k_norm.bias', 'decoder.22.ls1.gamma', 'decoder.22.norm2.weight', 'decoder.22.norm2.bias', 'decoder.22.mlp.fc1.weight', 'decoder.22.mlp.fc1.bias', 'decoder.22.mlp.fc2.weight', 'decoder.22.mlp.fc2.bias', 'decoder.22.ls2.gamma', 'decoder.23.norm1.weight', 'decoder.23.norm1.bias', 'decoder.23.attn.qkv.weight', 'decoder.23.attn.qkv.bias', 
'decoder.23.attn.proj.weight', 'decoder.23.attn.proj.bias', 'decoder.23.attn.q_norm.weight', 'decoder.23.attn.q_norm.bias', 'decoder.23.attn.k_norm.weight', 'decoder.23.attn.k_norm.bias', 'decoder.23.ls1.gamma', 'decoder.23.norm2.weight', 'decoder.23.norm2.bias', 'decoder.23.mlp.fc1.weight', 'decoder.23.mlp.fc1.bias', 'decoder.23.mlp.fc2.weight', 'decoder.23.mlp.fc2.bias', 'decoder.23.ls2.gamma', 'decoder.24.norm1.weight', 'decoder.24.norm1.bias', 'decoder.24.attn.qkv.weight', 'decoder.24.attn.qkv.bias', 'decoder.24.attn.proj.weight', 'decoder.24.attn.proj.bias', 'decoder.24.attn.q_norm.weight', 'decoder.24.attn.q_norm.bias', 'decoder.24.attn.k_norm.weight', 'decoder.24.attn.k_norm.bias', 'decoder.24.ls1.gamma', 'decoder.24.norm2.weight', 'decoder.24.norm2.bias', 'decoder.24.mlp.fc1.weight', 'decoder.24.mlp.fc1.bias', 'decoder.24.mlp.fc2.weight', 'decoder.24.mlp.fc2.bias', 'decoder.24.ls2.gamma', 'decoder.25.norm1.weight', 'decoder.25.norm1.bias', 'decoder.25.attn.qkv.weight', 'decoder.25.attn.qkv.bias', 'decoder.25.attn.proj.weight', 'decoder.25.attn.proj.bias', 'decoder.25.attn.q_norm.weight', 'decoder.25.attn.q_norm.bias', 'decoder.25.attn.k_norm.weight', 'decoder.25.attn.k_norm.bias', 'decoder.25.ls1.gamma', 'decoder.25.norm2.weight', 'decoder.25.norm2.bias', 'decoder.25.mlp.fc1.weight', 'decoder.25.mlp.fc1.bias', 'decoder.25.mlp.fc2.weight', 'decoder.25.mlp.fc2.bias', 'decoder.25.ls2.gamma', 'decoder.26.norm1.weight', 'decoder.26.norm1.bias', 'decoder.26.attn.qkv.weight', 'decoder.26.attn.qkv.bias', 'decoder.26.attn.proj.weight', 'decoder.26.attn.proj.bias', 'decoder.26.attn.q_norm.weight', 'decoder.26.attn.q_norm.bias', 'decoder.26.attn.k_norm.weight', 'decoder.26.attn.k_norm.bias', 'decoder.26.ls1.gamma', 'decoder.26.norm2.weight', 'decoder.26.norm2.bias', 'decoder.26.mlp.fc1.weight', 'decoder.26.mlp.fc1.bias', 'decoder.26.mlp.fc2.weight', 'decoder.26.mlp.fc2.bias', 'decoder.26.ls2.gamma', 'decoder.27.norm1.weight', 'decoder.27.norm1.bias', 
'decoder.27.attn.qkv.weight', 'decoder.27.attn.qkv.bias', 'decoder.27.attn.proj.weight', 'decoder.27.attn.proj.bias', 'decoder.27.attn.q_norm.weight', 'decoder.27.attn.q_norm.bias', 'decoder.27.attn.k_norm.weight', 'decoder.27.attn.k_norm.bias', 'decoder.27.ls1.gamma', 'decoder.27.norm2.weight', 'decoder.27.norm2.bias', 'decoder.27.mlp.fc1.weight', 'decoder.27.mlp.fc1.bias', 'decoder.27.mlp.fc2.weight', 'decoder.27.mlp.fc2.bias', 'decoder.27.ls2.gamma', 'decoder.28.norm1.weight', 'decoder.28.norm1.bias', 'decoder.28.attn.qkv.weight', 'decoder.28.attn.qkv.bias', 'decoder.28.attn.proj.weight', 'decoder.28.attn.proj.bias', 'decoder.28.attn.q_norm.weight', 'decoder.28.attn.q_norm.bias', 'decoder.28.attn.k_norm.weight', 'decoder.28.attn.k_norm.bias', 'decoder.28.ls1.gamma', 'decoder.28.norm2.weight', 'decoder.28.norm2.bias', 'decoder.28.mlp.fc1.weight', 'decoder.28.mlp.fc1.bias', 'decoder.28.mlp.fc2.weight', 'decoder.28.mlp.fc2.bias', 'decoder.28.ls2.gamma', 'decoder.29.norm1.weight', 'decoder.29.norm1.bias', 'decoder.29.attn.qkv.weight', 'decoder.29.attn.qkv.bias', 'decoder.29.attn.proj.weight', 'decoder.29.attn.proj.bias', 'decoder.29.attn.q_norm.weight', 'decoder.29.attn.q_norm.bias', 'decoder.29.attn.k_norm.weight', 'decoder.29.attn.k_norm.bias', 'decoder.29.ls1.gamma', 'decoder.29.norm2.weight', 'decoder.29.norm2.bias', 'decoder.29.mlp.fc1.weight', 'decoder.29.mlp.fc1.bias', 'decoder.29.mlp.fc2.weight', 'decoder.29.mlp.fc2.bias', 'decoder.29.ls2.gamma', 'decoder.30.norm1.weight', 'decoder.30.norm1.bias', 'decoder.30.attn.qkv.weight', 'decoder.30.attn.qkv.bias', 'decoder.30.attn.proj.weight', 'decoder.30.attn.proj.bias', 'decoder.30.attn.q_norm.weight', 'decoder.30.attn.q_norm.bias', 'decoder.30.attn.k_norm.weight', 'decoder.30.attn.k_norm.bias', 'decoder.30.ls1.gamma', 'decoder.30.norm2.weight', 'decoder.30.norm2.bias', 'decoder.30.mlp.fc1.weight', 'decoder.30.mlp.fc1.bias', 'decoder.30.mlp.fc2.weight', 'decoder.30.mlp.fc2.bias', 'decoder.30.ls2.gamma', 
'decoder.31.norm1.weight', 'decoder.31.norm1.bias', 'decoder.31.attn.qkv.weight', 'decoder.31.attn.qkv.bias', 'decoder.31.attn.proj.weight', 'decoder.31.attn.proj.bias', 'decoder.31.attn.q_norm.weight', 'decoder.31.attn.q_norm.bias', 'decoder.31.attn.k_norm.weight', 'decoder.31.attn.k_norm.bias', 'decoder.31.ls1.gamma', 'decoder.31.norm2.weight', 'decoder.31.norm2.bias', 'decoder.31.mlp.fc1.weight', 'decoder.31.mlp.fc1.bias', 'decoder.31.mlp.fc2.weight', 'decoder.31.mlp.fc2.bias', 'decoder.31.ls2.gamma', 'decoder.32.norm1.weight', 'decoder.32.norm1.bias', 'decoder.32.attn.qkv.weight', 'decoder.32.attn.qkv.bias', 'decoder.32.attn.proj.weight', 'decoder.32.attn.proj.bias', 'decoder.32.attn.q_norm.weight', 'decoder.32.attn.q_norm.bias', 'decoder.32.attn.k_norm.weight', 'decoder.32.attn.k_norm.bias', 'decoder.32.ls1.gamma', 'decoder.32.norm2.weight', 'decoder.32.norm2.bias', 'decoder.32.mlp.fc1.weight', 'decoder.32.mlp.fc1.bias', 'decoder.32.mlp.fc2.weight', 'decoder.32.mlp.fc2.bias', 'decoder.32.ls2.gamma', 'decoder.33.norm1.weight', 'decoder.33.norm1.bias', 'decoder.33.attn.qkv.weight', 'decoder.33.attn.qkv.bias', 'decoder.33.attn.proj.weight', 'decoder.33.attn.proj.bias', 'decoder.33.attn.q_norm.weight', 'decoder.33.attn.q_norm.bias', 'decoder.33.attn.k_norm.weight', 'decoder.33.attn.k_norm.bias', 'decoder.33.ls1.gamma', 'decoder.33.norm2.weight', 'decoder.33.norm2.bias', 'decoder.33.mlp.fc1.weight', 'decoder.33.mlp.fc1.bias', 'decoder.33.mlp.fc2.weight', 'decoder.33.mlp.fc2.bias', 'decoder.33.ls2.gamma', 'decoder.34.norm1.weight', 'decoder.34.norm1.bias', 'decoder.34.attn.qkv.weight', 'decoder.34.attn.qkv.bias', 'decoder.34.attn.proj.weight', 'decoder.34.attn.proj.bias', 'decoder.34.attn.q_norm.weight', 'decoder.34.attn.q_norm.bias', 'decoder.34.attn.k_norm.weight', 'decoder.34.attn.k_norm.bias', 'decoder.34.ls1.gamma', 'decoder.34.norm2.weight', 'decoder.34.norm2.bias', 'decoder.34.mlp.fc1.weight', 'decoder.34.mlp.fc1.bias', 'decoder.34.mlp.fc2.weight', 
'decoder.34.mlp.fc2.bias', 'decoder.34.ls2.gamma', 'decoder.35.norm1.weight', 'decoder.35.norm1.bias', 'decoder.35.attn.qkv.weight', 'decoder.35.attn.qkv.bias', 'decoder.35.attn.proj.weight', 'decoder.35.attn.proj.bias', 'decoder.35.attn.q_norm.weight', 'decoder.35.attn.q_norm.bias', 'decoder.35.attn.k_norm.weight', 'decoder.35.attn.k_norm.bias', 'decoder.35.ls1.gamma', 'decoder.35.norm2.weight', 'decoder.35.norm2.bias', 'decoder.35.mlp.fc1.weight', 'decoder.35.mlp.fc1.bias', 'decoder.35.mlp.fc2.weight', 'decoder.35.mlp.fc2.bias', 'decoder.35.ls2.gamma', 'point_decoder.projects.weight', 'point_decoder.projects.bias', 'point_decoder.blocks.0.norm1.weight', 'point_decoder.blocks.0.norm1.bias', 'point_decoder.blocks.0.attn.qkv.weight', 'point_decoder.blocks.0.attn.qkv.bias', 'point_decoder.blocks.0.attn.proj.weight', 'point_decoder.blocks.0.attn.proj.bias', 'point_decoder.blocks.0.norm2.weight', 'point_decoder.blocks.0.norm2.bias', 'point_decoder.blocks.0.mlp.fc1.weight', 'point_decoder.blocks.0.mlp.fc1.bias', 'point_decoder.blocks.0.mlp.fc2.weight', 'point_decoder.blocks.0.mlp.fc2.bias', 'point_decoder.blocks.1.norm1.weight', 'point_decoder.blocks.1.norm1.bias', 'point_decoder.blocks.1.attn.qkv.weight', 'point_decoder.blocks.1.attn.qkv.bias', 'point_decoder.blocks.1.attn.proj.weight', 'point_decoder.blocks.1.attn.proj.bias', 'point_decoder.blocks.1.norm2.weight', 'point_decoder.blocks.1.norm2.bias', 'point_decoder.blocks.1.mlp.fc1.weight', 'point_decoder.blocks.1.mlp.fc1.bias', 'point_decoder.blocks.1.mlp.fc2.weight', 'point_decoder.blocks.1.mlp.fc2.bias', 'point_decoder.blocks.2.norm1.weight', 'point_decoder.blocks.2.norm1.bias', 'point_decoder.blocks.2.attn.qkv.weight', 'point_decoder.blocks.2.attn.qkv.bias', 'point_decoder.blocks.2.attn.proj.weight', 'point_decoder.blocks.2.attn.proj.bias', 'point_decoder.blocks.2.norm2.weight', 'point_decoder.blocks.2.norm2.bias', 'point_decoder.blocks.2.mlp.fc1.weight', 'point_decoder.blocks.2.mlp.fc1.bias', 
'point_decoder.blocks.2.mlp.fc2.weight', 'point_decoder.blocks.2.mlp.fc2.bias', 'point_decoder.blocks.3.norm1.weight', 'point_decoder.blocks.3.norm1.bias', 'point_decoder.blocks.3.attn.qkv.weight', 'point_decoder.blocks.3.attn.qkv.bias', 'point_decoder.blocks.3.attn.proj.weight', 'point_decoder.blocks.3.attn.proj.bias', 'point_decoder.blocks.3.norm2.weight', 'point_decoder.blocks.3.norm2.bias', 'point_decoder.blocks.3.mlp.fc1.weight', 'point_decoder.blocks.3.mlp.fc1.bias', 'point_decoder.blocks.3.mlp.fc2.weight', 'point_decoder.blocks.3.mlp.fc2.bias', 'point_decoder.blocks.4.norm1.weight', 'point_decoder.blocks.4.norm1.bias', 'point_decoder.blocks.4.attn.qkv.weight', 'point_decoder.blocks.4.attn.qkv.bias', 'point_decoder.blocks.4.attn.proj.weight', 'point_decoder.blocks.4.attn.proj.bias', 'point_decoder.blocks.4.norm2.weight', 'point_decoder.blocks.4.norm2.bias', 'point_decoder.blocks.4.mlp.fc1.weight', 'point_decoder.blocks.4.mlp.fc1.bias', 'point_decoder.blocks.4.mlp.fc2.weight', 'point_decoder.blocks.4.mlp.fc2.bias', 'point_decoder.linear_out.weight', 'point_decoder.linear_out.bias', 'point_head.proj.weight', 'point_head.proj.bias', 'conf_decoder.projects.weight', 'conf_decoder.projects.bias', 'conf_decoder.blocks.0.norm1.weight', 'conf_decoder.blocks.0.norm1.bias', 'conf_decoder.blocks.0.attn.qkv.weight', 'conf_decoder.blocks.0.attn.qkv.bias', 'conf_decoder.blocks.0.attn.proj.weight', 'conf_decoder.blocks.0.attn.proj.bias', 'conf_decoder.blocks.0.norm2.weight', 'conf_decoder.blocks.0.norm2.bias', 'conf_decoder.blocks.0.mlp.fc1.weight', 'conf_decoder.blocks.0.mlp.fc1.bias', 'conf_decoder.blocks.0.mlp.fc2.weight', 'conf_decoder.blocks.0.mlp.fc2.bias', 'conf_decoder.blocks.1.norm1.weight', 'conf_decoder.blocks.1.norm1.bias', 'conf_decoder.blocks.1.attn.qkv.weight', 'conf_decoder.blocks.1.attn.qkv.bias', 'conf_decoder.blocks.1.attn.proj.weight', 'conf_decoder.blocks.1.attn.proj.bias', 'conf_decoder.blocks.1.norm2.weight', 'conf_decoder.blocks.1.norm2.bias', 
'conf_decoder.blocks.1.mlp.fc1.weight', 'conf_decoder.blocks.1.mlp.fc1.bias', 'conf_decoder.blocks.1.mlp.fc2.weight', 'conf_decoder.blocks.1.mlp.fc2.bias', 'conf_decoder.blocks.2.norm1.weight', 'conf_decoder.blocks.2.norm1.bias', 'conf_decoder.blocks.2.attn.qkv.weight', 'conf_decoder.blocks.2.attn.qkv.bias', 'conf_decoder.blocks.2.attn.proj.weight', 'conf_decoder.blocks.2.attn.proj.bias', 'conf_decoder.blocks.2.norm2.weight', 'conf_decoder.blocks.2.norm2.bias', 'conf_decoder.blocks.2.mlp.fc1.weight', 'conf_decoder.blocks.2.mlp.fc1.bias', 'conf_decoder.blocks.2.mlp.fc2.weight', 'conf_decoder.blocks.2.mlp.fc2.bias', 'conf_decoder.blocks.3.norm1.weight', 'conf_decoder.blocks.3.norm1.bias', 'conf_decoder.blocks.3.attn.qkv.weight', 'conf_decoder.blocks.3.attn.qkv.bias', 'conf_decoder.blocks.3.attn.proj.weight', 'conf_decoder.blocks.3.attn.proj.bias', 'conf_decoder.blocks.3.norm2.weight', 'conf_decoder.blocks.3.norm2.bias', 'conf_decoder.blocks.3.mlp.fc1.weight', 'conf_decoder.blocks.3.mlp.fc1.bias', 'conf_decoder.blocks.3.mlp.fc2.weight', 'conf_decoder.blocks.3.mlp.fc2.bias', 'conf_decoder.blocks.4.norm1.weight', 'conf_decoder.blocks.4.norm1.bias', 'conf_decoder.blocks.4.attn.qkv.weight', 'conf_decoder.blocks.4.attn.qkv.bias', 'conf_decoder.blocks.4.attn.proj.weight', 'conf_decoder.blocks.4.attn.proj.bias', 'conf_decoder.blocks.4.norm2.weight', 'conf_decoder.blocks.4.norm2.bias', 'conf_decoder.blocks.4.mlp.fc1.weight', 'conf_decoder.blocks.4.mlp.fc1.bias', 'conf_decoder.blocks.4.mlp.fc2.weight', 'conf_decoder.blocks.4.mlp.fc2.bias', 'conf_decoder.linear_out.weight', 'conf_decoder.linear_out.bias', 'conf_head.proj.weight', 'conf_head.proj.bias', 'camera_decoder.projects.weight', 'camera_decoder.projects.bias', 'camera_decoder.blocks.0.norm1.weight', 'camera_decoder.blocks.0.norm1.bias', 'camera_decoder.blocks.0.attn.qkv.weight', 'camera_decoder.blocks.0.attn.qkv.bias', 'camera_decoder.blocks.0.attn.proj.weight', 'camera_decoder.blocks.0.attn.proj.bias', 
'camera_decoder.blocks.0.norm2.weight', 'camera_decoder.blocks.0.norm2.bias', 'camera_decoder.blocks.0.mlp.fc1.weight', 'camera_decoder.blocks.0.mlp.fc1.bias', 'camera_decoder.blocks.0.mlp.fc2.weight', 'camera_decoder.blocks.0.mlp.fc2.bias', 'camera_decoder.blocks.1.norm1.weight', 'camera_decoder.blocks.1.norm1.bias', 'camera_decoder.blocks.1.attn.qkv.weight', 'camera_decoder.blocks.1.attn.qkv.bias', 'camera_decoder.blocks.1.attn.proj.weight', 'camera_decoder.blocks.1.attn.proj.bias', 'camera_decoder.blocks.1.norm2.weight', 'camera_decoder.blocks.1.norm2.bias', 'camera_decoder.blocks.1.mlp.fc1.weight', 'camera_decoder.blocks.1.mlp.fc1.bias', 'camera_decoder.blocks.1.mlp.fc2.weight', 'camera_decoder.blocks.1.mlp.fc2.bias', 'camera_decoder.blocks.2.norm1.weight', 'camera_decoder.blocks.2.norm1.bias', 'camera_decoder.blocks.2.attn.qkv.weight', 'camera_decoder.blocks.2.attn.qkv.bias', 'camera_decoder.blocks.2.attn.proj.weight', 'camera_decoder.blocks.2.attn.proj.bias', 'camera_decoder.blocks.2.norm2.weight', 'camera_decoder.blocks.2.norm2.bias', 'camera_decoder.blocks.2.mlp.fc1.weight', 'camera_decoder.blocks.2.mlp.fc1.bias', 'camera_decoder.blocks.2.mlp.fc2.weight', 'camera_decoder.blocks.2.mlp.fc2.bias', 'camera_decoder.blocks.3.norm1.weight', 'camera_decoder.blocks.3.norm1.bias', 'camera_decoder.blocks.3.attn.qkv.weight', 'camera_decoder.blocks.3.attn.qkv.bias', 'camera_decoder.blocks.3.attn.proj.weight', 'camera_decoder.blocks.3.attn.proj.bias', 'camera_decoder.blocks.3.norm2.weight', 'camera_decoder.blocks.3.norm2.bias', 'camera_decoder.blocks.3.mlp.fc1.weight', 'camera_decoder.blocks.3.mlp.fc1.bias', 'camera_decoder.blocks.3.mlp.fc2.weight', 'camera_decoder.blocks.3.mlp.fc2.bias', 'camera_decoder.blocks.4.norm1.weight', 'camera_decoder.blocks.4.norm1.bias', 'camera_decoder.blocks.4.attn.qkv.weight', 'camera_decoder.blocks.4.attn.qkv.bias', 'camera_decoder.blocks.4.attn.proj.weight', 'camera_decoder.blocks.4.attn.proj.bias', 'camera_decoder.blocks.4.norm2.weight', 
'camera_decoder.blocks.4.norm2.bias', 'camera_decoder.blocks.4.mlp.fc1.weight', 'camera_decoder.blocks.4.mlp.fc1.bias', 'camera_decoder.blocks.4.mlp.fc2.weight', 'camera_decoder.blocks.4.mlp.fc2.bias', 'camera_decoder.linear_out.weight', 'camera_decoder.linear_out.bias', 'camera_head.res_conv.0.res_conv1.weight', 'camera_head.res_conv.0.res_conv1.bias', 'camera_head.res_conv.0.res_conv2.weight', 'camera_head.res_conv.0.res_conv2.bias', 'camera_head.res_conv.0.res_conv3.weight', 'camera_head.res_conv.0.res_conv3.bias', 'camera_head.res_conv.1.res_conv1.weight', 'camera_head.res_conv.1.res_conv1.bias', 'camera_head.res_conv.1.res_conv2.weight', 'camera_head.res_conv.1.res_conv2.bias', 'camera_head.res_conv.1.res_conv3.weight', 'camera_head.res_conv.1.res_conv3.bias', 'camera_head.more_mlps.0.weight', 'camera_head.more_mlps.0.bias', 'camera_head.more_mlps.2.weight', 'camera_head.more_mlps.2.bias', 'camera_head.fc_t.weight', 'camera_head.fc_t.bias', 'camera_head.fc_rot.weight', 'camera_head.fc_rot.bias'], unexpected_keys=['module.register_token', 'module.image_mean', 'module.image_std', 'module.encoder.cls_token', 'module.encoder.pos_embed', 'module.encoder.register_tokens', 'module.encoder.patch_embed.proj.weight', 'module.encoder.patch_embed.proj.bias', 'module.encoder.blocks.0.norm1.weight', 'module.encoder.blocks.0.norm1.bias', 'module.encoder.blocks.0.attn.qkv.weight', 'module.encoder.blocks.0.attn.qkv.bias', 'module.encoder.blocks.0.attn.proj.weight', 'module.encoder.blocks.0.attn.proj.bias', 'module.encoder.blocks.0.ls1.gamma', 'module.encoder.blocks.0.norm2.weight', 'module.encoder.blocks.0.norm2.bias', 'module.encoder.blocks.0.mlp.fc1.weight', 'module.encoder.blocks.0.mlp.fc1.bias', 'module.encoder.blocks.0.mlp.fc2.weight', 'module.encoder.blocks.0.mlp.fc2.bias', 'module.encoder.blocks.0.ls2.gamma', 'module.encoder.blocks.1.norm1.weight', 'module.encoder.blocks.1.norm1.bias', 'module.encoder.blocks.1.attn.qkv.weight', 'module.encoder.blocks.1.attn.qkv.bias', 
'module.encoder.blocks.1.attn.proj.weight', 'module.encoder.blocks.1.attn.proj.bias', 'module.encoder.blocks.1.ls1.gamma', 'module.encoder.blocks.1.norm2.weight', 'module.encoder.blocks.1.norm2.bias', 'module.encoder.blocks.1.mlp.fc1.weight', 'module.encoder.blocks.1.mlp.fc1.bias', 'module.encoder.blocks.1.mlp.fc2.weight', 'module.encoder.blocks.1.mlp.fc2.bias', 'module.encoder.blocks.1.ls2.gamma', 'module.encoder.blocks.2.norm1.weight', 'module.encoder.blocks.2.norm1.bias', 'module.encoder.blocks.2.attn.qkv.weight', 'module.encoder.blocks.2.attn.qkv.bias', 'module.encoder.blocks.2.attn.proj.weight', 'module.encoder.blocks.2.attn.proj.bias', 'module.encoder.blocks.2.ls1.gamma', 'module.encoder.blocks.2.norm2.weight', 'module.encoder.blocks.2.norm2.bias', 'module.encoder.blocks.2.mlp.fc1.weight', 'module.encoder.blocks.2.mlp.fc1.bias', 'module.encoder.blocks.2.mlp.fc2.weight', 'module.encoder.blocks.2.mlp.fc2.bias', 'module.encoder.blocks.2.ls2.gamma', 'module.encoder.blocks.3.norm1.weight', 'module.encoder.blocks.3.norm1.bias', 'module.encoder.blocks.3.attn.qkv.weight', 'module.encoder.blocks.3.attn.qkv.bias', 'module.encoder.blocks.3.attn.proj.weight', 'module.encoder.blocks.3.attn.proj.bias', 'module.encoder.blocks.3.ls1.gamma', 'module.encoder.blocks.3.norm2.weight', 'module.encoder.blocks.3.norm2.bias', 'module.encoder.blocks.3.mlp.fc1.weight', 'module.encoder.blocks.3.mlp.fc1.bias', 'module.encoder.blocks.3.mlp.fc2.weight', 'module.encoder.blocks.3.mlp.fc2.bias', 'module.encoder.blocks.3.ls2.gamma', 'module.encoder.blocks.4.norm1.weight', 'module.encoder.blocks.4.norm1.bias', 'module.encoder.blocks.4.attn.qkv.weight', 'module.encoder.blocks.4.attn.qkv.bias', 'module.encoder.blocks.4.attn.proj.weight', 'module.encoder.blocks.4.attn.proj.bias', 'module.encoder.blocks.4.ls1.gamma', 'module.encoder.blocks.4.norm2.weight', 'module.encoder.blocks.4.norm2.bias', 'module.encoder.blocks.4.mlp.fc1.weight', 'module.encoder.blocks.4.mlp.fc1.bias', 
'module.encoder.blocks.4.mlp.fc2.weight', 'module.encoder.blocks.4.mlp.fc2.bias', 'module.encoder.blocks.4.ls2.gamma', 'module.encoder.blocks.5.norm1.weight', 'module.encoder.blocks.5.norm1.bias', 'module.encoder.blocks.5.attn.qkv.weight', 'module.encoder.blocks.5.attn.qkv.bias', 'module.encoder.blocks.5.attn.proj.weight', 'module.encoder.blocks.5.attn.proj.bias', 'module.encoder.blocks.5.ls1.gamma', 'module.encoder.blocks.5.norm2.weight', 'module.encoder.blocks.5.norm2.bias', 'module.encoder.blocks.5.mlp.fc1.weight', 'module.encoder.blocks.5.mlp.fc1.bias', 'module.encoder.blocks.5.mlp.fc2.weight', 'module.encoder.blocks.5.mlp.fc2.bias', 'module.encoder.blocks.5.ls2.gamma', 'module.encoder.blocks.6.norm1.weight', 'module.encoder.blocks.6.norm1.bias', 'module.encoder.blocks.6.attn.qkv.weight', 'module.encoder.blocks.6.attn.qkv.bias', 'module.encoder.blocks.6.attn.proj.weight', 'module.encoder.blocks.6.attn.proj.bias', 'module.encoder.blocks.6.ls1.gamma', 'module.encoder.blocks.6.norm2.weight', 'module.encoder.blocks.6.norm2.bias', 'module.encoder.blocks.6.mlp.fc1.weight', 'module.encoder.blocks.6.mlp.fc1.bias', 'module.encoder.blocks.6.mlp.fc2.weight', 'module.encoder.blocks.6.mlp.fc2.bias', 'module.encoder.blocks.6.ls2.gamma', 'module.encoder.blocks.7.norm1.weight', 'module.encoder.blocks.7.norm1.bias', 'module.encoder.blocks.7.attn.qkv.weight', 'module.encoder.blocks.7.attn.qkv.bias', 'module.encoder.blocks.7.attn.proj.weight', 'module.encoder.blocks.7.attn.proj.bias', 'module.encoder.blocks.7.ls1.gamma', 'module.encoder.blocks.7.norm2.weight', 'module.encoder.blocks.7.norm2.bias', 'module.encoder.blocks.7.mlp.fc1.weight', 'module.encoder.blocks.7.mlp.fc1.bias', 'module.encoder.blocks.7.mlp.fc2.weight', 'module.encoder.blocks.7.mlp.fc2.bias', 'module.encoder.blocks.7.ls2.gamma', 'module.encoder.blocks.8.norm1.weight', 'module.encoder.blocks.8.norm1.bias', 'module.encoder.blocks.8.attn.qkv.weight', 'module.encoder.blocks.8.attn.qkv.bias', 
'module.encoder.blocks.8.attn.proj.weight', 'module.encoder.blocks.8.attn.proj.bias', 'module.encoder.blocks.8.ls1.gamma', 'module.encoder.blocks.8.norm2.weight', 'module.encoder.blocks.8.norm2.bias', 'module.encoder.blocks.8.mlp.fc1.weight', 'module.encoder.blocks.8.mlp.fc1.bias', 'module.encoder.blocks.8.mlp.fc2.weight', 'module.encoder.blocks.8.mlp.fc2.bias', 'module.encoder.blocks.8.ls2.gamma', 'module.encoder.blocks.9.norm1.weight', 'module.encoder.blocks.9.norm1.bias', 'module.encoder.blocks.9.attn.qkv.weight', 'module.encoder.blocks.9.attn.qkv.bias', 'module.encoder.blocks.9.attn.proj.weight', 'module.encoder.blocks.9.attn.proj.bias', 'module.encoder.blocks.9.ls1.gamma', 'module.encoder.blocks.9.norm2.weight', 'module.encoder.blocks.9.norm2.bias', 'module.encoder.blocks.9.mlp.fc1.weight', 'module.encoder.blocks.9.mlp.fc1.bias', 'module.encoder.blocks.9.mlp.fc2.weight', 'module.encoder.blocks.9.mlp.fc2.bias', 'module.encoder.blocks.9.ls2.gamma', 'module.encoder.blocks.10.norm1.weight', 'module.encoder.blocks.10.norm1.bias', 'module.encoder.blocks.10.attn.qkv.weight', 'module.encoder.blocks.10.attn.qkv.bias', 'module.encoder.blocks.10.attn.proj.weight', 'module.encoder.blocks.10.attn.proj.bias', 'module.encoder.blocks.10.ls1.gamma', 'module.encoder.blocks.10.norm2.weight', 'module.encoder.blocks.10.norm2.bias', 'module.encoder.blocks.10.mlp.fc1.weight', 'module.encoder.blocks.10.mlp.fc1.bias', 'module.encoder.blocks.10.mlp.fc2.weight', 'module.encoder.blocks.10.mlp.fc2.bias', 'module.encoder.blocks.10.ls2.gamma', 'module.encoder.blocks.11.norm1.weight', 'module.encoder.blocks.11.norm1.bias', 'module.encoder.blocks.11.attn.qkv.weight', 'module.encoder.blocks.11.attn.qkv.bias', 'module.encoder.blocks.11.attn.proj.weight', 'module.encoder.blocks.11.attn.proj.bias', 'module.encoder.blocks.11.ls1.gamma', 'module.encoder.blocks.11.norm2.weight', 'module.encoder.blocks.11.norm2.bias', 'module.encoder.blocks.11.mlp.fc1.weight', 'module.encoder.blocks.11.mlp.fc1.bias', 
'module.encoder.blocks.11.mlp.fc2.weight', 'module.encoder.blocks.11.mlp.fc2.bias', 'module.encoder.blocks.11.ls2.gamma', 'module.encoder.blocks.12.norm1.weight', 'module.encoder.blocks.12.norm1.bias', 'module.encoder.blocks.12.attn.qkv.weight', 'module.encoder.blocks.12.attn.qkv.bias', 'module.encoder.blocks.12.attn.proj.weight', 'module.encoder.blocks.12.attn.proj.bias', 'module.encoder.blocks.12.ls1.gamma', 'module.encoder.blocks.12.norm2.weight', 'module.encoder.blocks.12.norm2.bias', 'module.encoder.blocks.12.mlp.fc1.weight', 'module.encoder.blocks.12.mlp.fc1.bias', 'module.encoder.blocks.12.mlp.fc2.weight', 'module.encoder.blocks.12.mlp.fc2.bias', 'module.encoder.blocks.12.ls2.gamma', 'module.encoder.blocks.13.norm1.weight', 'module.encoder.blocks.13.norm1.bias', 'module.encoder.blocks.13.attn.qkv.weight', 'module.encoder.blocks.13.attn.qkv.bias', 'module.encoder.blocks.13.attn.proj.weight', 'module.encoder.blocks.13.attn.proj.bias', 'module.encoder.blocks.13.ls1.gamma', 'module.encoder.blocks.13.norm2.weight', 'module.encoder.blocks.13.norm2.bias', 'module.encoder.blocks.13.mlp.fc1.weight', 'module.encoder.blocks.13.mlp.fc1.bias', 'module.encoder.blocks.13.mlp.fc2.weight', 'module.encoder.blocks.13.mlp.fc2.bias', 'module.encoder.blocks.13.ls2.gamma', 'module.encoder.blocks.14.norm1.weight', 'module.encoder.blocks.14.norm1.bias', 'module.encoder.blocks.14.attn.qkv.weight', 'module.encoder.blocks.14.attn.qkv.bias', 'module.encoder.blocks.14.attn.proj.weight', 'module.encoder.blocks.14.attn.proj.bias', 'module.encoder.blocks.14.ls1.gamma', 'module.encoder.blocks.14.norm2.weight', 'module.encoder.blocks.14.norm2.bias', 'module.encoder.blocks.14.mlp.fc1.weight', 'module.encoder.blocks.14.mlp.fc1.bias', 'module.encoder.blocks.14.mlp.fc2.weight', 'module.encoder.blocks.14.mlp.fc2.bias', 'module.encoder.blocks.14.ls2.gamma', 'module.encoder.blocks.15.norm1.weight', 'module.encoder.blocks.15.norm1.bias', 'module.encoder.blocks.15.attn.qkv.weight', 
'module.encoder.blocks.15.attn.qkv.bias', 'module.encoder.blocks.15.attn.proj.weight', 'module.encoder.blocks.15.attn.proj.bias', 'module.encoder.blocks.15.ls1.gamma', 'module.encoder.blocks.15.norm2.weight', 'module.encoder.blocks.15.norm2.bias', 'module.encoder.blocks.15.mlp.fc1.weight', 'module.encoder.blocks.15.mlp.fc1.bias', 'module.encoder.blocks.15.mlp.fc2.weight', 'module.encoder.blocks.15.mlp.fc2.bias', 'module.encoder.blocks.15.ls2.gamma', 'module.encoder.blocks.16.norm1.weight', 'module.encoder.blocks.16.norm1.bias', 'module.encoder.blocks.16.attn.qkv.weight', 'module.encoder.blocks.16.attn.qkv.bias', 'module.encoder.blocks.16.attn.proj.weight', 'module.encoder.blocks.16.attn.proj.bias', 'module.encoder.blocks.16.ls1.gamma', 'module.encoder.blocks.16.norm2.weight', 'module.encoder.blocks.16.norm2.bias', 'module.encoder.blocks.16.mlp.fc1.weight', 'module.encoder.blocks.16.mlp.fc1.bias', 'module.encoder.blocks.16.mlp.fc2.weight', 'module.encoder.blocks.16.mlp.fc2.bias', 'module.encoder.blocks.16.ls2.gamma', 'module.encoder.blocks.17.norm1.weight', 'module.encoder.blocks.17.norm1.bias', 'module.encoder.blocks.17.attn.qkv.weight', 'module.encoder.blocks.17.attn.qkv.bias', 'module.encoder.blocks.17.attn.proj.weight', 'module.encoder.blocks.17.attn.proj.bias', 'module.encoder.blocks.17.ls1.gamma', 'module.encoder.blocks.17.norm2.weight', 'module.encoder.blocks.17.norm2.bias', 'module.encoder.blocks.17.mlp.fc1.weight', 'module.encoder.blocks.17.mlp.fc1.bias', 'module.encoder.blocks.17.mlp.fc2.weight', 'module.encoder.blocks.17.mlp.fc2.bias', 'module.encoder.blocks.17.ls2.gamma', 'module.encoder.blocks.18.norm1.weight', 'module.encoder.blocks.18.norm1.bias', 'module.encoder.blocks.18.attn.qkv.weight', 'module.encoder.blocks.18.attn.qkv.bias', 'module.encoder.blocks.18.attn.proj.weight', 'module.encoder.blocks.18.attn.proj.bias', 'module.encoder.blocks.18.ls1.gamma', 'module.encoder.blocks.18.norm2.weight', 'module.encoder.blocks.18.norm2.bias', 
'module.encoder.blocks.18.mlp.fc1.weight', 'module.encoder.blocks.18.mlp.fc1.bias', 'module.encoder.blocks.18.mlp.fc2.weight', 'module.encoder.blocks.18.mlp.fc2.bias', 'module.encoder.blocks.18.ls2.gamma', 'module.encoder.blocks.19.norm1.weight', 'module.encoder.blocks.19.norm1.bias', 'module.encoder.blocks.19.attn.qkv.weight', 'module.encoder.blocks.19.attn.qkv.bias', 'module.encoder.blocks.19.attn.proj.weight', 'module.encoder.blocks.19.attn.proj.bias', 'module.encoder.blocks.19.ls1.gamma', 'module.encoder.blocks.19.norm2.weight', 'module.encoder.blocks.19.norm2.bias', 'module.encoder.blocks.19.mlp.fc1.weight', 'module.encoder.blocks.19.mlp.fc1.bias', 'module.encoder.blocks.19.mlp.fc2.weight', 'module.encoder.blocks.19.mlp.fc2.bias', 'module.encoder.blocks.19.ls2.gamma', 'module.encoder.blocks.20.norm1.weight', 'module.encoder.blocks.20.norm1.bias', 'module.encoder.blocks.20.attn.qkv.weight', 'module.encoder.blocks.20.attn.qkv.bias', 'module.encoder.blocks.20.attn.proj.weight', 'module.encoder.blocks.20.attn.proj.bias', 'module.encoder.blocks.20.ls1.gamma', 'module.encoder.blocks.20.norm2.weight', 'module.encoder.blocks.20.norm2.bias', 'module.encoder.blocks.20.mlp.fc1.weight', 'module.encoder.blocks.20.mlp.fc1.bias', 'module.encoder.blocks.20.mlp.fc2.weight', 'module.encoder.blocks.20.mlp.fc2.bias', 'module.encoder.blocks.20.ls2.gamma', 'module.encoder.blocks.21.norm1.weight', 'module.encoder.blocks.21.norm1.bias', 'module.encoder.blocks.21.attn.qkv.weight', 'module.encoder.blocks.21.attn.qkv.bias', 'module.encoder.blocks.21.attn.proj.weight', 'module.encoder.blocks.21.attn.proj.bias', 'module.encoder.blocks.21.ls1.gamma', 'module.encoder.blocks.21.norm2.weight', 'module.encoder.blocks.21.norm2.bias', 'module.encoder.blocks.21.mlp.fc1.weight', 'module.encoder.blocks.21.mlp.fc1.bias', 'module.encoder.blocks.21.mlp.fc2.weight', 'module.encoder.blocks.21.mlp.fc2.bias', 'module.encoder.blocks.21.ls2.gamma', 'module.encoder.blocks.22.norm1.weight', 
'module.encoder.blocks.22.norm1.bias', 'module.encoder.blocks.22.attn.qkv.weight', 'module.encoder.blocks.22.attn.qkv.bias', 'module.encoder.blocks.22.attn.proj.weight', 'module.encoder.blocks.22.attn.proj.bias', 'module.encoder.blocks.22.ls1.gamma', 'module.encoder.blocks.22.norm2.weight', 'module.encoder.blocks.22.norm2.bias', 'module.encoder.blocks.22.mlp.fc1.weight', 'module.encoder.blocks.22.mlp.fc1.bias', 'module.encoder.blocks.22.mlp.fc2.weight', 'module.encoder.blocks.22.mlp.fc2.bias', 'module.encoder.blocks.22.ls2.gamma', 'module.encoder.blocks.23.norm1.weight', 'module.encoder.blocks.23.norm1.bias', 'module.encoder.blocks.23.attn.qkv.weight', 'module.encoder.blocks.23.attn.qkv.bias', 'module.encoder.blocks.23.attn.proj.weight', 'module.encoder.blocks.23.attn.proj.bias', 'module.encoder.blocks.23.ls1.gamma', 'module.encoder.blocks.23.norm2.weight', 'module.encoder.blocks.23.norm2.bias', 'module.encoder.blocks.23.mlp.fc1.weight', 'module.encoder.blocks.23.mlp.fc1.bias', 'module.encoder.blocks.23.mlp.fc2.weight', 'module.encoder.blocks.23.mlp.fc2.bias', 'module.encoder.blocks.23.ls2.gamma', 'module.encoder.norm.weight', 'module.encoder.norm.bias', 'module.decoder.0.norm1.weight', 'module.decoder.0.norm1.bias', 'module.decoder.0.attn.qkv.weight', 'module.decoder.0.attn.qkv.bias', 'module.decoder.0.attn.proj.weight', 'module.decoder.0.attn.proj.bias', 'module.decoder.0.attn.q_norm.weight', 'module.decoder.0.attn.q_norm.bias', 'module.decoder.0.attn.k_norm.weight', 'module.decoder.0.attn.k_norm.bias', 'module.decoder.0.ls1.gamma', 'module.decoder.0.norm2.weight', 'module.decoder.0.norm2.bias', 'module.decoder.0.mlp.fc1.weight', 'module.decoder.0.mlp.fc1.bias', 'module.decoder.0.mlp.fc2.weight', 'module.decoder.0.mlp.fc2.bias', 'module.decoder.0.ls2.gamma', 'module.decoder.1.norm1.weight', 'module.decoder.1.norm1.bias', 'module.decoder.1.attn.qkv.weight', 'module.decoder.1.attn.qkv.bias', 'module.decoder.1.attn.proj.weight', 'module.decoder.1.attn.proj.bias', 
'module.decoder.1.attn.q_norm.weight', 'module.decoder.1.attn.q_norm.bias', 'module.decoder.1.attn.k_norm.weight', 'module.decoder.1.attn.k_norm.bias', 'module.decoder.1.ls1.gamma', 'module.decoder.1.norm2.weight', 'module.decoder.1.norm2.bias', 'module.decoder.1.mlp.fc1.weight', 'module.decoder.1.mlp.fc1.bias', 'module.decoder.1.mlp.fc2.weight', 'module.decoder.1.mlp.fc2.bias', 'module.decoder.1.ls2.gamma', 'module.decoder.2.norm1.weight', 'module.decoder.2.norm1.bias', 'module.decoder.2.attn.qkv.weight', 'module.decoder.2.attn.qkv.bias', 'module.decoder.2.attn.proj.weight', 'module.decoder.2.attn.proj.bias', 'module.decoder.2.attn.q_norm.weight', 'module.decoder.2.attn.q_norm.bias', 'module.decoder.2.attn.k_norm.weight', 'module.decoder.2.attn.k_norm.bias', 'module.decoder.2.ls1.gamma', 'module.decoder.2.norm2.weight', 'module.decoder.2.norm2.bias', 'module.decoder.2.mlp.fc1.weight', 'module.decoder.2.mlp.fc1.bias', 'module.decoder.2.mlp.fc2.weight', 'module.decoder.2.mlp.fc2.bias', 'module.decoder.2.ls2.gamma', 'module.decoder.3.norm1.weight', 'module.decoder.3.norm1.bias', 'module.decoder.3.attn.qkv.weight', 'module.decoder.3.attn.qkv.bias', 'module.decoder.3.attn.proj.weight', 'module.decoder.3.attn.proj.bias', 'module.decoder.3.attn.q_norm.weight', 'module.decoder.3.attn.q_norm.bias', 'module.decoder.3.attn.k_norm.weight', 'module.decoder.3.attn.k_norm.bias', 'module.decoder.3.ls1.gamma', 'module.decoder.3.norm2.weight', 'module.decoder.3.norm2.bias', 'module.decoder.3.mlp.fc1.weight', 'module.decoder.3.mlp.fc1.bias', 'module.decoder.3.mlp.fc2.weight', 'module.decoder.3.mlp.fc2.bias', 'module.decoder.3.ls2.gamma', 'module.decoder.4.norm1.weight', 'module.decoder.4.norm1.bias', 'module.decoder.4.attn.qkv.weight', 'module.decoder.4.attn.qkv.bias', 'module.decoder.4.attn.proj.weight', 'module.decoder.4.attn.proj.bias', 'module.decoder.4.attn.q_norm.weight', 'module.decoder.4.attn.q_norm.bias', 'module.decoder.4.attn.k_norm.weight', 
'module.decoder.4.attn.k_norm.bias', 'module.decoder.4.ls1.gamma', 'module.decoder.4.norm2.weight', 'module.decoder.4.norm2.bias', 'module.decoder.4.mlp.fc1.weight', 'module.decoder.4.mlp.fc1.bias', 'module.decoder.4.mlp.fc2.weight', 'module.decoder.4.mlp.fc2.bias', 'module.decoder.4.ls2.gamma', 'module.decoder.5.norm1.weight', 'module.decoder.5.norm1.bias', 'module.decoder.5.attn.qkv.weight', 'module.decoder.5.attn.qkv.bias', 'module.decoder.5.attn.proj.weight', 'module.decoder.5.attn.proj.bias', 'module.decoder.5.attn.q_norm.weight', 'module.decoder.5.attn.q_norm.bias', 'module.decoder.5.attn.k_norm.weight', 'module.decoder.5.attn.k_norm.bias', 'module.decoder.5.ls1.gamma', 'module.decoder.5.norm2.weight', 'module.decoder.5.norm2.bias', 'module.decoder.5.mlp.fc1.weight', 'module.decoder.5.mlp.fc1.bias', 'module.decoder.5.mlp.fc2.weight', 'module.decoder.5.mlp.fc2.bias', 'module.decoder.5.ls2.gamma', 'module.decoder.6.norm1.weight', 'module.decoder.6.norm1.bias', 'module.decoder.6.attn.qkv.weight', 'module.decoder.6.attn.qkv.bias', 'module.decoder.6.attn.proj.weight', 'module.decoder.6.attn.proj.bias', 'module.decoder.6.attn.q_norm.weight', 'module.decoder.6.attn.q_norm.bias', 'module.decoder.6.attn.k_norm.weight', 'module.decoder.6.attn.k_norm.bias', 'module.decoder.6.ls1.gamma', 'module.decoder.6.norm2.weight', 'module.decoder.6.norm2.bias', 'module.decoder.6.mlp.fc1.weight', 'module.decoder.6.mlp.fc1.bias', 'module.decoder.6.mlp.fc2.weight', 'module.decoder.6.mlp.fc2.bias', 'module.decoder.6.ls2.gamma', 'module.decoder.7.norm1.weight', 'module.decoder.7.norm1.bias', 'module.decoder.7.attn.qkv.weight', 'module.decoder.7.attn.qkv.bias', 'module.decoder.7.attn.proj.weight', 'module.decoder.7.attn.proj.bias', 'module.decoder.7.attn.q_norm.weight', 'module.decoder.7.attn.q_norm.bias', 'module.decoder.7.attn.k_norm.weight', 'module.decoder.7.attn.k_norm.bias', 'module.decoder.7.ls1.gamma', 'module.decoder.7.norm2.weight', 'module.decoder.7.norm2.bias', 
'module.decoder.7.mlp.fc1.weight', 'module.decoder.7.mlp.fc1.bias', 'module.decoder.7.mlp.fc2.weight', 'module.decoder.7.mlp.fc2.bias', 'module.decoder.7.ls2.gamma', 'module.decoder.8.norm1.weight', 'module.decoder.8.norm1.bias', 'module.decoder.8.attn.qkv.weight', 'module.decoder.8.attn.qkv.bias', 'module.decoder.8.attn.proj.weight', 'module.decoder.8.attn.proj.bias', 'module.decoder.8.attn.q_norm.weight', 'module.decoder.8.attn.q_norm.bias', 'module.decoder.8.attn.k_norm.weight', 'module.decoder.8.attn.k_norm.bias', 'module.decoder.8.ls1.gamma', 'module.decoder.8.norm2.weight', 'module.decoder.8.norm2.bias', 'module.decoder.8.mlp.fc1.weight', 'module.decoder.8.mlp.fc1.bias', 'module.decoder.8.mlp.fc2.weight', 'module.decoder.8.mlp.fc2.bias', 'module.decoder.8.ls2.gamma', 'module.decoder.9.norm1.weight', 'module.decoder.9.norm1.bias', 'module.decoder.9.attn.qkv.weight', 'module.decoder.9.attn.qkv.bias', 'module.decoder.9.attn.proj.weight', 'module.decoder.9.attn.proj.bias', 'module.decoder.9.attn.q_norm.weight', 'module.decoder.9.attn.q_norm.bias', 'module.decoder.9.attn.k_norm.weight', 'module.decoder.9.attn.k_norm.bias', 'module.decoder.9.ls1.gamma', 'module.decoder.9.norm2.weight', 'module.decoder.9.norm2.bias', 'module.decoder.9.mlp.fc1.weight', 'module.decoder.9.mlp.fc1.bias', 'module.decoder.9.mlp.fc2.weight', 'module.decoder.9.mlp.fc2.bias', 'module.decoder.9.ls2.gamma', 'module.decoder.10.norm1.weight', 'module.decoder.10.norm1.bias', 'module.decoder.10.attn.qkv.weight', 'module.decoder.10.attn.qkv.bias', 'module.decoder.10.attn.proj.weight', 'module.decoder.10.attn.proj.bias', 'module.decoder.10.attn.q_norm.weight', 'module.decoder.10.attn.q_norm.bias', 'module.decoder.10.attn.k_norm.weight', 'module.decoder.10.attn.k_norm.bias', 'module.decoder.10.ls1.gamma', 'module.decoder.10.norm2.weight', 'module.decoder.10.norm2.bias', 'module.decoder.10.mlp.fc1.weight', 'module.decoder.10.mlp.fc1.bias', 'module.decoder.10.mlp.fc2.weight', 
'module.decoder.10.mlp.fc2.bias', 'module.decoder.10.ls2.gamma', 'module.decoder.11.norm1.weight', 'module.decoder.11.norm1.bias', 'module.decoder.11.attn.qkv.weight', 'module.decoder.11.attn.qkv.bias', 'module.decoder.11.attn.proj.weight', 'module.decoder.11.attn.proj.bias', 'module.decoder.11.attn.q_norm.weight', 'module.decoder.11.attn.q_norm.bias', 'module.decoder.11.attn.k_norm.weight', 'module.decoder.11.attn.k_norm.bias', 'module.decoder.11.ls1.gamma', 'module.decoder.11.norm2.weight', 'module.decoder.11.norm2.bias', 'module.decoder.11.mlp.fc1.weight', 'module.decoder.11.mlp.fc1.bias', 'module.decoder.11.mlp.fc2.weight', 'module.decoder.11.mlp.fc2.bias', 'module.decoder.11.ls2.gamma', 'module.decoder.12.norm1.weight', 'module.decoder.12.norm1.bias', 'module.decoder.12.attn.qkv.weight', 'module.decoder.12.attn.qkv.bias', 'module.decoder.12.attn.proj.weight', 'module.decoder.12.attn.proj.bias', 'module.decoder.12.attn.q_norm.weight', 'module.decoder.12.attn.q_norm.bias', 'module.decoder.12.attn.k_norm.weight', 'module.decoder.12.attn.k_norm.bias', 'module.decoder.12.ls1.gamma', 'module.decoder.12.norm2.weight', 'module.decoder.12.norm2.bias', 'module.decoder.12.mlp.fc1.weight', 'module.decoder.12.mlp.fc1.bias', 'module.decoder.12.mlp.fc2.weight', 'module.decoder.12.mlp.fc2.bias', 'module.decoder.12.ls2.gamma', 'module.decoder.13.norm1.weight', 'module.decoder.13.norm1.bias', 'module.decoder.13.attn.qkv.weight', 'module.decoder.13.attn.qkv.bias', 'module.decoder.13.attn.proj.weight', 'module.decoder.13.attn.proj.bias', 'module.decoder.13.attn.q_norm.weight', 'module.decoder.13.attn.q_norm.bias', 'module.decoder.13.attn.k_norm.weight', 'module.decoder.13.attn.k_norm.bias', 'module.decoder.13.ls1.gamma', 'module.decoder.13.norm2.weight', 'module.decoder.13.norm2.bias', 'module.decoder.13.mlp.fc1.weight', 'module.decoder.13.mlp.fc1.bias', 'module.decoder.13.mlp.fc2.weight', 'module.decoder.13.mlp.fc2.bias', 'module.decoder.13.ls2.gamma', 
'module.decoder.14.norm1.weight', 'module.decoder.14.norm1.bias', 'module.decoder.14.attn.qkv.weight', 'module.decoder.14.attn.qkv.bias', 'module.decoder.14.attn.proj.weight', 'module.decoder.14.attn.proj.bias', 'module.decoder.14.attn.q_norm.weight', 'module.decoder.14.attn.q_norm.bias', 'module.decoder.14.attn.k_norm.weight', 'module.decoder.14.attn.k_norm.bias', 'module.decoder.14.ls1.gamma', 'module.decoder.14.norm2.weight', 'module.decoder.14.norm2.bias', 'module.decoder.14.mlp.fc1.weight', 'module.decoder.14.mlp.fc1.bias', 'module.decoder.14.mlp.fc2.weight', 'module.decoder.14.mlp.fc2.bias', 'module.decoder.14.ls2.gamma', 'module.decoder.15.norm1.weight', 'module.decoder.15.norm1.bias', 'module.decoder.15.attn.qkv.weight', 'module.decoder.15.attn.qkv.bias', 'module.decoder.15.attn.proj.weight', 'module.decoder.15.attn.proj.bias', 'module.decoder.15.attn.q_norm.weight', 'module.decoder.15.attn.q_norm.bias', 'module.decoder.15.attn.k_norm.weight', 'module.decoder.15.attn.k_norm.bias', 'module.decoder.15.ls1.gamma', 'module.decoder.15.norm2.weight', 'module.decoder.15.norm2.bias', 'module.decoder.15.mlp.fc1.weight', 'module.decoder.15.mlp.fc1.bias', 'module.decoder.15.mlp.fc2.weight', 'module.decoder.15.mlp.fc2.bias', 'module.decoder.15.ls2.gamma', 'module.decoder.16.norm1.weight', 'module.decoder.16.norm1.bias', 'module.decoder.16.attn.qkv.weight', 'module.decoder.16.attn.qkv.bias', 'module.decoder.16.attn.proj.weight', 'module.decoder.16.attn.proj.bias', 'module.decoder.16.attn.q_norm.weight', 'module.decoder.16.attn.q_norm.bias', 'module.decoder.16.attn.k_norm.weight', 'module.decoder.16.attn.k_norm.bias', 'module.decoder.16.ls1.gamma', 'module.decoder.16.norm2.weight', 'module.decoder.16.norm2.bias', 'module.decoder.16.mlp.fc1.weight', 'module.decoder.16.mlp.fc1.bias', 'module.decoder.16.mlp.fc2.weight', 'module.decoder.16.mlp.fc2.bias', 'module.decoder.16.ls2.gamma', 'module.decoder.17.norm1.weight', 'module.decoder.17.norm1.bias', 
'module.decoder.17.attn.qkv.weight', 'module.decoder.17.attn.qkv.bias', 'module.decoder.17.attn.proj.weight', 'module.decoder.17.attn.proj.bias', 'module.decoder.17.attn.q_norm.weight', 'module.decoder.17.attn.q_norm.bias', 'module.decoder.17.attn.k_norm.weight', 'module.decoder.17.attn.k_norm.bias', 'module.decoder.17.ls1.gamma', 'module.decoder.17.norm2.weight', 'module.decoder.17.norm2.bias', 'module.decoder.17.mlp.fc1.weight', 'module.decoder.17.mlp.fc1.bias', 'module.decoder.17.mlp.fc2.weight', 'module.decoder.17.mlp.fc2.bias', 'module.decoder.17.ls2.gamma', 'module.decoder.18.norm1.weight', 'module.decoder.18.norm1.bias', 'module.decoder.18.attn.qkv.weight', 'module.decoder.18.attn.qkv.bias', 'module.decoder.18.attn.proj.weight', 'module.decoder.18.attn.proj.bias', 'module.decoder.18.attn.q_norm.weight', 'module.decoder.18.attn.q_norm.bias', 'module.decoder.18.attn.k_norm.weight', 'module.decoder.18.attn.k_norm.bias', 'module.decoder.18.ls1.gamma', 'module.decoder.18.norm2.weight', 'module.decoder.18.norm2.bias', 'module.decoder.18.mlp.fc1.weight', 'module.decoder.18.mlp.fc1.bias', 'module.decoder.18.mlp.fc2.weight', 'module.decoder.18.mlp.fc2.bias', 'module.decoder.18.ls2.gamma', 'module.decoder.19.norm1.weight', 'module.decoder.19.norm1.bias', 'module.decoder.19.attn.qkv.weight', 'module.decoder.19.attn.qkv.bias', 'module.decoder.19.attn.proj.weight', 'module.decoder.19.attn.proj.bias', 'module.decoder.19.attn.q_norm.weight', 'module.decoder.19.attn.q_norm.bias', 'module.decoder.19.attn.k_norm.weight', 'module.decoder.19.attn.k_norm.bias', 'module.decoder.19.ls1.gamma', 'module.decoder.19.norm2.weight', 'module.decoder.19.norm2.bias', 'module.decoder.19.mlp.fc1.weight', 'module.decoder.19.mlp.fc1.bias', 'module.decoder.19.mlp.fc2.weight', 'module.decoder.19.mlp.fc2.bias', 'module.decoder.19.ls2.gamma', 'module.decoder.20.norm1.weight', 'module.decoder.20.norm1.bias', 'module.decoder.20.attn.qkv.weight', 'module.decoder.20.attn.qkv.bias', 
'module.decoder.20.attn.proj.weight', 'module.decoder.20.attn.proj.bias', 'module.decoder.20.attn.q_norm.weight', 'module.decoder.20.attn.q_norm.bias', 'module.decoder.20.attn.k_norm.weight', 'module.decoder.20.attn.k_norm.bias', 'module.decoder.20.ls1.gamma', 'module.decoder.20.norm2.weight', 'module.decoder.20.norm2.bias', 'module.decoder.20.mlp.fc1.weight', 'module.decoder.20.mlp.fc1.bias', 'module.decoder.20.mlp.fc2.weight', 'module.decoder.20.mlp.fc2.bias', 'module.decoder.20.ls2.gamma', 'module.decoder.21.norm1.weight', 'module.decoder.21.norm1.bias', 'module.decoder.21.attn.qkv.weight', 'module.decoder.21.attn.qkv.bias', 'module.decoder.21.attn.proj.weight', 'module.decoder.21.attn.proj.bias', 'module.decoder.21.attn.q_norm.weight', 'module.decoder.21.attn.q_norm.bias', 'module.decoder.21.attn.k_norm.weight', 'module.decoder.21.attn.k_norm.bias', 'module.decoder.21.ls1.gamma', 'module.decoder.21.norm2.weight', 'module.decoder.21.norm2.bias', 'module.decoder.21.mlp.fc1.weight', 'module.decoder.21.mlp.fc1.bias', 'module.decoder.21.mlp.fc2.weight', 'module.decoder.21.mlp.fc2.bias', 'module.decoder.21.ls2.gamma', 'module.decoder.22.norm1.weight', 'module.decoder.22.norm1.bias', 'module.decoder.22.attn.qkv.weight', 'module.decoder.22.attn.qkv.bias', 'module.decoder.22.attn.proj.weight', 'module.decoder.22.attn.proj.bias', 'module.decoder.22.attn.q_norm.weight', 'module.decoder.22.attn.q_norm.bias', 'module.decoder.22.attn.k_norm.weight', 'module.decoder.22.attn.k_norm.bias', 'module.decoder.22.ls1.gamma', 'module.decoder.22.norm2.weight', 'module.decoder.22.norm2.bias', 'module.decoder.22.mlp.fc1.weight', 'module.decoder.22.mlp.fc1.bias', 'module.decoder.22.mlp.fc2.weight', 'module.decoder.22.mlp.fc2.bias', 'module.decoder.22.ls2.gamma', 'module.decoder.23.norm1.weight', 'module.decoder.23.norm1.bias', 'module.decoder.23.attn.qkv.weight', 'module.decoder.23.attn.qkv.bias', 'module.decoder.23.attn.proj.weight', 'module.decoder.23.attn.proj.bias', 
'module.decoder.23.attn.q_norm.weight', 'module.decoder.23.attn.q_norm.bias', 'module.decoder.23.attn.k_norm.weight', 'module.decoder.23.attn.k_norm.bias', 'module.decoder.23.ls1.gamma', 'module.decoder.23.norm2.weight', 'module.decoder.23.norm2.bias', 'module.decoder.23.mlp.fc1.weight', 'module.decoder.23.mlp.fc1.bias', 'module.decoder.23.mlp.fc2.weight', 'module.decoder.23.mlp.fc2.bias', 'module.decoder.23.ls2.gamma', 'module.decoder.24.norm1.weight', 'module.decoder.24.norm1.bias', 'module.decoder.24.attn.qkv.weight', 'module.decoder.24.attn.qkv.bias', 'module.decoder.24.attn.proj.weight', 'module.decoder.24.attn.proj.bias', 'module.decoder.24.attn.q_norm.weight', 'module.decoder.24.attn.q_norm.bias', 'module.decoder.24.attn.k_norm.weight', 'module.decoder.24.attn.k_norm.bias', 'module.decoder.24.ls1.gamma', 'module.decoder.24.norm2.weight', 'module.decoder.24.norm2.bias', 'module.decoder.24.mlp.fc1.weight', 'module.decoder.24.mlp.fc1.bias', 'module.decoder.24.mlp.fc2.weight', 'module.decoder.24.mlp.fc2.bias', 'module.decoder.24.ls2.gamma', 'module.decoder.25.norm1.weight', 'module.decoder.25.norm1.bias', 'module.decoder.25.attn.qkv.weight', 'module.decoder.25.attn.qkv.bias', 'module.decoder.25.attn.proj.weight', 'module.decoder.25.attn.proj.bias', 'module.decoder.25.attn.q_norm.weight', 'module.decoder.25.attn.q_norm.bias', 'module.decoder.25.attn.k_norm.weight', 'module.decoder.25.attn.k_norm.bias', 'module.decoder.25.ls1.gamma', 'module.decoder.25.norm2.weight', 'module.decoder.25.norm2.bias', 'module.decoder.25.mlp.fc1.weight', 'module.decoder.25.mlp.fc1.bias', 'module.decoder.25.mlp.fc2.weight', 'module.decoder.25.mlp.fc2.bias', 'module.decoder.25.ls2.gamma', 'module.decoder.26.norm1.weight', 'module.decoder.26.norm1.bias', 'module.decoder.26.attn.qkv.weight', 'module.decoder.26.attn.qkv.bias', 'module.decoder.26.attn.proj.weight', 'module.decoder.26.attn.proj.bias', 'module.decoder.26.attn.q_norm.weight', 'module.decoder.26.attn.q_norm.bias', 
'module.decoder.26.attn.k_norm.weight', 'module.decoder.26.attn.k_norm.bias', 'module.decoder.26.ls1.gamma', 'module.decoder.26.norm2.weight', 'module.decoder.26.norm2.bias', 'module.decoder.26.mlp.fc1.weight', 'module.decoder.26.mlp.fc1.bias', 'module.decoder.26.mlp.fc2.weight', 'module.decoder.26.mlp.fc2.bias', 'module.decoder.26.ls2.gamma', 'module.decoder.27.norm1.weight', 'module.decoder.27.norm1.bias', 'module.decoder.27.attn.qkv.weight', 'module.decoder.27.attn.qkv.bias', 'module.decoder.27.attn.proj.weight', 'module.decoder.27.attn.proj.bias', 'module.decoder.27.attn.q_norm.weight', 'module.decoder.27.attn.q_norm.bias', 'module.decoder.27.attn.k_norm.weight', 'module.decoder.27.attn.k_norm.bias', 'module.decoder.27.ls1.gamma', 'module.decoder.27.norm2.weight', 'module.decoder.27.norm2.bias', 'module.decoder.27.mlp.fc1.weight', 'module.decoder.27.mlp.fc1.bias', 'module.decoder.27.mlp.fc2.weight', 'module.decoder.27.mlp.fc2.bias', 'module.decoder.27.ls2.gamma', 'module.decoder.28.norm1.weight', 'module.decoder.28.norm1.bias', 'module.decoder.28.attn.qkv.weight', 'module.decoder.28.attn.qkv.bias', 'module.decoder.28.attn.proj.weight', 'module.decoder.28.attn.proj.bias', 'module.decoder.28.attn.q_norm.weight', 'module.decoder.28.attn.q_norm.bias', 'module.decoder.28.attn.k_norm.weight', 'module.decoder.28.attn.k_norm.bias', 'module.decoder.28.ls1.gamma', 'module.decoder.28.norm2.weight', 'module.decoder.28.norm2.bias', 'module.decoder.28.mlp.fc1.weight', 'module.decoder.28.mlp.fc1.bias', 'module.decoder.28.mlp.fc2.weight', 'module.decoder.28.mlp.fc2.bias', 'module.decoder.28.ls2.gamma', 'module.decoder.29.norm1.weight', 'module.decoder.29.norm1.bias', 'module.decoder.29.attn.qkv.weight', 'module.decoder.29.attn.qkv.bias', 'module.decoder.29.attn.proj.weight', 'module.decoder.29.attn.proj.bias', 'module.decoder.29.attn.q_norm.weight', 'module.decoder.29.attn.q_norm.bias', 'module.decoder.29.attn.k_norm.weight', 'module.decoder.29.attn.k_norm.bias', 
'module.decoder.29.ls1.gamma', 'module.decoder.29.norm2.weight', 'module.decoder.29.norm2.bias', 'module.decoder.29.mlp.fc1.weight', 'module.decoder.29.mlp.fc1.bias', 'module.decoder.29.mlp.fc2.weight', 'module.decoder.29.mlp.fc2.bias', 'module.decoder.29.ls2.gamma', 'module.decoder.30.norm1.weight', 'module.decoder.30.norm1.bias', 'module.decoder.30.attn.qkv.weight', 'module.decoder.30.attn.qkv.bias', 'module.decoder.30.attn.proj.weight', 'module.decoder.30.attn.proj.bias', 'module.decoder.30.attn.q_norm.weight', 'module.decoder.30.attn.q_norm.bias', 'module.decoder.30.attn.k_norm.weight', 'module.decoder.30.attn.k_norm.bias', 'module.decoder.30.ls1.gamma', 'module.decoder.30.norm2.weight', 'module.decoder.30.norm2.bias', 'module.decoder.30.mlp.fc1.weight', 'module.decoder.30.mlp.fc1.bias', 'module.decoder.30.mlp.fc2.weight', 'module.decoder.30.mlp.fc2.bias', 'module.decoder.30.ls2.gamma', 'module.decoder.31.norm1.weight', 'module.decoder.31.norm1.bias', 'module.decoder.31.attn.qkv.weight', 'module.decoder.31.attn.qkv.bias', 'module.decoder.31.attn.proj.weight', 'module.decoder.31.attn.proj.bias', 'module.decoder.31.attn.q_norm.weight', 'module.decoder.31.attn.q_norm.bias', 'module.decoder.31.attn.k_norm.weight', 'module.decoder.31.attn.k_norm.bias', 'module.decoder.31.ls1.gamma', 'module.decoder.31.norm2.weight', 'module.decoder.31.norm2.bias', 'module.decoder.31.mlp.fc1.weight', 'module.decoder.31.mlp.fc1.bias', 'module.decoder.31.mlp.fc2.weight', 'module.decoder.31.mlp.fc2.bias', 'module.decoder.31.ls2.gamma', 'module.decoder.32.norm1.weight', 'module.decoder.32.norm1.bias', 'module.decoder.32.attn.qkv.weight', 'module.decoder.32.attn.qkv.bias', 'module.decoder.32.attn.proj.weight', 'module.decoder.32.attn.proj.bias', 'module.decoder.32.attn.q_norm.weight', 'module.decoder.32.attn.q_norm.bias', 'module.decoder.32.attn.k_norm.weight', 'module.decoder.32.attn.k_norm.bias', 'module.decoder.32.ls1.gamma', 'module.decoder.32.norm2.weight', 
'module.decoder.32.norm2.bias', 'module.decoder.32.mlp.fc1.weight', 'module.decoder.32.mlp.fc1.bias', 'module.decoder.32.mlp.fc2.weight', 'module.decoder.32.mlp.fc2.bias', 'module.decoder.32.ls2.gamma', 'module.decoder.33.norm1.weight', 'module.decoder.33.norm1.bias', 'module.decoder.33.attn.qkv.weight', 'module.decoder.33.attn.qkv.bias', 'module.decoder.33.attn.proj.weight', 'module.decoder.33.attn.proj.bias', 'module.decoder.33.attn.q_norm.weight', 'module.decoder.33.attn.q_norm.bias', 'module.decoder.33.attn.k_norm.weight', 'module.decoder.33.attn.k_norm.bias', 'module.decoder.33.ls1.gamma', 'module.decoder.33.norm2.weight', 'module.decoder.33.norm2.bias', 'module.decoder.33.mlp.fc1.weight', 'module.decoder.33.mlp.fc1.bias', 'module.decoder.33.mlp.fc2.weight', 'module.decoder.33.mlp.fc2.bias', 'module.decoder.33.ls2.gamma', 'module.decoder.34.norm1.weight', 'module.decoder.34.norm1.bias', 'module.decoder.34.attn.qkv.weight', 'module.decoder.34.attn.qkv.bias', 'module.decoder.34.attn.proj.weight', 'module.decoder.34.attn.proj.bias', 'module.decoder.34.attn.q_norm.weight', 'module.decoder.34.attn.q_norm.bias', 'module.decoder.34.attn.k_norm.weight', 'module.decoder.34.attn.k_norm.bias', 'module.decoder.34.ls1.gamma', 'module.decoder.34.norm2.weight', 'module.decoder.34.norm2.bias', 'module.decoder.34.mlp.fc1.weight', 'module.decoder.34.mlp.fc1.bias', 'module.decoder.34.mlp.fc2.weight', 'module.decoder.34.mlp.fc2.bias', 'module.decoder.34.ls2.gamma', 'module.decoder.35.norm1.weight', 'module.decoder.35.norm1.bias', 'module.decoder.35.attn.qkv.weight', 'module.decoder.35.attn.qkv.bias', 'module.decoder.35.attn.proj.weight', 'module.decoder.35.attn.proj.bias', 'module.decoder.35.attn.q_norm.weight', 'module.decoder.35.attn.q_norm.bias', 'module.decoder.35.attn.k_norm.weight', 'module.decoder.35.attn.k_norm.bias', 'module.decoder.35.ls1.gamma', 'module.decoder.35.norm2.weight', 'module.decoder.35.norm2.bias', 'module.decoder.35.mlp.fc1.weight', 
'module.decoder.35.mlp.fc1.bias', 'module.decoder.35.mlp.fc2.weight', 'module.decoder.35.mlp.fc2.bias', 'module.decoder.35.ls2.gamma', 'module.point_decoder.projects.weight', 'module.point_decoder.projects.bias', 'module.point_decoder.blocks.0.norm1.weight', 'module.point_decoder.blocks.0.norm1.bias', 'module.point_decoder.blocks.0.attn.qkv.weight', 'module.point_decoder.blocks.0.attn.qkv.bias', 'module.point_decoder.blocks.0.attn.proj.weight', 'module.point_decoder.blocks.0.attn.proj.bias', 'module.point_decoder.blocks.0.norm2.weight', 'module.point_decoder.blocks.0.norm2.bias', 'module.point_decoder.blocks.0.mlp.fc1.weight', 'module.point_decoder.blocks.0.mlp.fc1.bias', 'module.point_decoder.blocks.0.mlp.fc2.weight', 'module.point_decoder.blocks.0.mlp.fc2.bias', 'module.point_decoder.blocks.1.norm1.weight', 'module.point_decoder.blocks.1.norm1.bias', 'module.point_decoder.blocks.1.attn.qkv.weight', 'module.point_decoder.blocks.1.attn.qkv.bias', 'module.point_decoder.blocks.1.attn.proj.weight', 'module.point_decoder.blocks.1.attn.proj.bias', 'module.point_decoder.blocks.1.norm2.weight', 'module.point_decoder.blocks.1.norm2.bias', 'module.point_decoder.blocks.1.mlp.fc1.weight', 'module.point_decoder.blocks.1.mlp.fc1.bias', 'module.point_decoder.blocks.1.mlp.fc2.weight', 'module.point_decoder.blocks.1.mlp.fc2.bias', 'module.point_decoder.blocks.2.norm1.weight', 'module.point_decoder.blocks.2.norm1.bias', 'module.point_decoder.blocks.2.attn.qkv.weight', 'module.point_decoder.blocks.2.attn.qkv.bias', 'module.point_decoder.blocks.2.attn.proj.weight', 'module.point_decoder.blocks.2.attn.proj.bias', 'module.point_decoder.blocks.2.norm2.weight', 'module.point_decoder.blocks.2.norm2.bias', 'module.point_decoder.blocks.2.mlp.fc1.weight', 'module.point_decoder.blocks.2.mlp.fc1.bias', 'module.point_decoder.blocks.2.mlp.fc2.weight', 'module.point_decoder.blocks.2.mlp.fc2.bias', 'module.point_decoder.blocks.3.norm1.weight', 'module.point_decoder.blocks.3.norm1.bias', 
'module.point_decoder.blocks.3.attn.qkv.weight', 'module.point_decoder.blocks.3.attn.qkv.bias', 'module.point_decoder.blocks.3.attn.proj.weight', 'module.point_decoder.blocks.3.attn.proj.bias', 'module.point_decoder.blocks.3.norm2.weight', 'module.point_decoder.blocks.3.norm2.bias', 'module.point_decoder.blocks.3.mlp.fc1.weight', 'module.point_decoder.blocks.3.mlp.fc1.bias', 'module.point_decoder.blocks.3.mlp.fc2.weight', 'module.point_decoder.blocks.3.mlp.fc2.bias', 'module.point_decoder.blocks.4.norm1.weight', 'module.point_decoder.blocks.4.norm1.bias', 'module.point_decoder.blocks.4.attn.qkv.weight', 'module.point_decoder.blocks.4.attn.qkv.bias', 'module.point_decoder.blocks.4.attn.proj.weight', 'module.point_decoder.blocks.4.attn.proj.bias', 'module.point_decoder.blocks.4.norm2.weight', 'module.point_decoder.blocks.4.norm2.bias', 'module.point_decoder.blocks.4.mlp.fc1.weight', 'module.point_decoder.blocks.4.mlp.fc1.bias', 'module.point_decoder.blocks.4.mlp.fc2.weight', 'module.point_decoder.blocks.4.mlp.fc2.bias', 'module.point_decoder.linear_out.weight', 'module.point_decoder.linear_out.bias', 'module.point_head.proj.weight', 'module.point_head.proj.bias', 'module.conf_decoder.projects.weight', 'module.conf_decoder.projects.bias', 'module.conf_decoder.blocks.0.norm1.weight', 'module.conf_decoder.blocks.0.norm1.bias', 'module.conf_decoder.blocks.0.attn.qkv.weight', 'module.conf_decoder.blocks.0.attn.qkv.bias', 'module.conf_decoder.blocks.0.attn.proj.weight', 'module.conf_decoder.blocks.0.attn.proj.bias', 'module.conf_decoder.blocks.0.norm2.weight', 'module.conf_decoder.blocks.0.norm2.bias', 'module.conf_decoder.blocks.0.mlp.fc1.weight', 'module.conf_decoder.blocks.0.mlp.fc1.bias', 'module.conf_decoder.blocks.0.mlp.fc2.weight', 'module.conf_decoder.blocks.0.mlp.fc2.bias', 'module.conf_decoder.blocks.1.norm1.weight', 'module.conf_decoder.blocks.1.norm1.bias', 'module.conf_decoder.blocks.1.attn.qkv.weight', 'module.conf_decoder.blocks.1.attn.qkv.bias', 
'module.conf_decoder.blocks.1.attn.proj.weight', 'module.conf_decoder.blocks.1.attn.proj.bias', 'module.conf_decoder.blocks.1.norm2.weight', 'module.conf_decoder.blocks.1.norm2.bias', 'module.conf_decoder.blocks.1.mlp.fc1.weight', 'module.conf_decoder.blocks.1.mlp.fc1.bias', 'module.conf_decoder.blocks.1.mlp.fc2.weight', 'module.conf_decoder.blocks.1.mlp.fc2.bias', 'module.conf_decoder.blocks.2.norm1.weight', 'module.conf_decoder.blocks.2.norm1.bias', 'module.conf_decoder.blocks.2.attn.qkv.weight', 'module.conf_decoder.blocks.2.attn.qkv.bias', 'module.conf_decoder.blocks.2.attn.proj.weight', 'module.conf_decoder.blocks.2.attn.proj.bias', 'module.conf_decoder.blocks.2.norm2.weight', 'module.conf_decoder.blocks.2.norm2.bias', 'module.conf_decoder.blocks.2.mlp.fc1.weight', 'module.conf_decoder.blocks.2.mlp.fc1.bias', 'module.conf_decoder.blocks.2.mlp.fc2.weight', 'module.conf_decoder.blocks.2.mlp.fc2.bias', 'module.conf_decoder.blocks.3.norm1.weight', 'module.conf_decoder.blocks.3.norm1.bias', 'module.conf_decoder.blocks.3.attn.qkv.weight', 'module.conf_decoder.blocks.3.attn.qkv.bias', 'module.conf_decoder.blocks.3.attn.proj.weight', 'module.conf_decoder.blocks.3.attn.proj.bias', 'module.conf_decoder.blocks.3.norm2.weight', 'module.conf_decoder.blocks.3.norm2.bias', 'module.conf_decoder.blocks.3.mlp.fc1.weight', 'module.conf_decoder.blocks.3.mlp.fc1.bias', 'module.conf_decoder.blocks.3.mlp.fc2.weight', 'module.conf_decoder.blocks.3.mlp.fc2.bias', 'module.conf_decoder.blocks.4.norm1.weight', 'module.conf_decoder.blocks.4.norm1.bias', 'module.conf_decoder.blocks.4.attn.qkv.weight', 'module.conf_decoder.blocks.4.attn.qkv.bias', 'module.conf_decoder.blocks.4.attn.proj.weight', 'module.conf_decoder.blocks.4.attn.proj.bias', 'module.conf_decoder.blocks.4.norm2.weight', 'module.conf_decoder.blocks.4.norm2.bias', 'module.conf_decoder.blocks.4.mlp.fc1.weight', 'module.conf_decoder.blocks.4.mlp.fc1.bias', 'module.conf_decoder.blocks.4.mlp.fc2.weight', 
'module.conf_decoder.blocks.4.mlp.fc2.bias', 'module.conf_decoder.linear_out.weight', 'module.conf_decoder.linear_out.bias', 'module.conf_head.proj.weight', 'module.conf_head.proj.bias', 'module.camera_decoder.projects.weight', 'module.camera_decoder.projects.bias', 'module.camera_decoder.blocks.0.norm1.weight', 'module.camera_decoder.blocks.0.norm1.bias', 'module.camera_decoder.blocks.0.attn.qkv.weight', 'module.camera_decoder.blocks.0.attn.qkv.bias', 'module.camera_decoder.blocks.0.attn.proj.weight', 'module.camera_decoder.blocks.0.attn.proj.bias', 'module.camera_decoder.blocks.0.norm2.weight', 'module.camera_decoder.blocks.0.norm2.bias', 'module.camera_decoder.blocks.0.mlp.fc1.weight', 'module.camera_decoder.blocks.0.mlp.fc1.bias', 'module.camera_decoder.blocks.0.mlp.fc2.weight', 'module.camera_decoder.blocks.0.mlp.fc2.bias', 'module.camera_decoder.blocks.1.norm1.weight', 'module.camera_decoder.blocks.1.norm1.bias', 'module.camera_decoder.blocks.1.attn.qkv.weight', 'module.camera_decoder.blocks.1.attn.qkv.bias', 'module.camera_decoder.blocks.1.attn.proj.weight', 'module.camera_decoder.blocks.1.attn.proj.bias', 'module.camera_decoder.blocks.1.norm2.weight', 'module.camera_decoder.blocks.1.norm2.bias', 'module.camera_decoder.blocks.1.mlp.fc1.weight', 'module.camera_decoder.blocks.1.mlp.fc1.bias', 'module.camera_decoder.blocks.1.mlp.fc2.weight', 'module.camera_decoder.blocks.1.mlp.fc2.bias', 'module.camera_decoder.blocks.2.norm1.weight', 'module.camera_decoder.blocks.2.norm1.bias', 'module.camera_decoder.blocks.2.attn.qkv.weight', 'module.camera_decoder.blocks.2.attn.qkv.bias', 'module.camera_decoder.blocks.2.attn.proj.weight', 'module.camera_decoder.blocks.2.attn.proj.bias', 'module.camera_decoder.blocks.2.norm2.weight', 'module.camera_decoder.blocks.2.norm2.bias', 'module.camera_decoder.blocks.2.mlp.fc1.weight', 'module.camera_decoder.blocks.2.mlp.fc1.bias', 'module.camera_decoder.blocks.2.mlp.fc2.weight', 'module.camera_decoder.blocks.2.mlp.fc2.bias', 
'module.camera_decoder.blocks.3.norm1.weight', 'module.camera_decoder.blocks.3.norm1.bias', 'module.camera_decoder.blocks.3.attn.qkv.weight', 'module.camera_decoder.blocks.3.attn.qkv.bias', 'module.camera_decoder.blocks.3.attn.proj.weight', 'module.camera_decoder.blocks.3.attn.proj.bias', 'module.camera_decoder.blocks.3.norm2.weight', 'module.camera_decoder.blocks.3.norm2.bias', 'module.camera_decoder.blocks.3.mlp.fc1.weight', 'module.camera_decoder.blocks.3.mlp.fc1.bias', 'module.camera_decoder.blocks.3.mlp.fc2.weight', 'module.camera_decoder.blocks.3.mlp.fc2.bias', 'module.camera_decoder.blocks.4.norm1.weight', 'module.camera_decoder.blocks.4.norm1.bias', 'module.camera_decoder.blocks.4.attn.qkv.weight', 'module.camera_decoder.blocks.4.attn.qkv.bias', 'module.camera_decoder.blocks.4.attn.proj.weight', 'module.camera_decoder.blocks.4.attn.proj.bias', 'module.camera_decoder.blocks.4.norm2.weight', 'module.camera_decoder.blocks.4.norm2.bias', 'module.camera_decoder.blocks.4.mlp.fc1.weight', 'module.camera_decoder.blocks.4.mlp.fc1.bias', 'module.camera_decoder.blocks.4.mlp.fc2.weight', 'module.camera_decoder.blocks.4.mlp.fc2.bias', 'module.camera_decoder.linear_out.weight', 'module.camera_decoder.linear_out.bias', 'module.camera_head.res_conv.0.res_conv1.weight', 'module.camera_head.res_conv.0.res_conv1.bias', 'module.camera_head.res_conv.0.res_conv2.weight', 'module.camera_head.res_conv.0.res_conv2.bias', 'module.camera_head.res_conv.0.res_conv3.weight', 'module.camera_head.res_conv.0.res_conv3.bias', 'module.camera_head.res_conv.1.res_conv1.weight', 'module.camera_head.res_conv.1.res_conv1.bias', 'module.camera_head.res_conv.1.res_conv2.weight', 'module.camera_head.res_conv.1.res_conv2.bias', 'module.camera_head.res_conv.1.res_conv3.weight', 'module.camera_head.res_conv.1.res_conv3.bias', 'module.camera_head.more_mlps.0.weight', 'module.camera_head.more_mlps.0.bias', 'module.camera_head.more_mlps.2.weight', 'module.camera_head.more_mlps.2.bias', 
'module.camera_head.fc_t.weight', 'module.camera_head.fc_t.bias', 'module.camera_head.fc_rot.weight', 'module.camera_head.fc_rot.bias']) +[2026-05-02 01:14:36,496][__main__][INFO] - [RANK 0] Freezing patch embedding and positional encoding parameters... +[2026-05-02 01:14:36,502][__main__][INFO] - [RANK 0] Frozen 304,376,832 parameters out of 958,696,732 total parameters. (31.75%) +[2026-05-02 01:14:36,502][__main__][INFO] - [RANK 0] Trainable parameters: 654,319,900 (68.25%) +[2026-05-02 01:14:36,502][__main__][INFO] - [RANK 0] Example frozen parameters: register_token, encoder.cls_token, encoder.pos_embed, encoder.register_tokens, encoder.patch_embed.proj.weight... +[2026-05-02 01:14:36,505][croco.utils.misc][INFO] - [RANK 0] Param groups = { + "no_decay": { + "weight_decay": 0.0, + "params": [ + "decoder.0.norm1.weight", + "decoder.0.norm1.bias", + "decoder.0.attn.qkv.bias", + "decoder.0.attn.proj.bias", + "decoder.0.attn.q_norm.weight", + "decoder.0.attn.q_norm.bias", + "decoder.0.attn.k_norm.weight", + "decoder.0.attn.k_norm.bias", + "decoder.0.ls1.gamma", + "decoder.0.norm2.weight", + "decoder.0.norm2.bias", + "decoder.0.mlp.fc1.bias", + "decoder.0.mlp.fc2.bias", + "decoder.0.ls2.gamma", + "decoder.1.norm1.weight", + "decoder.1.norm1.bias", + "decoder.1.attn.qkv.bias", + "decoder.1.attn.proj.bias", + "decoder.1.attn.q_norm.weight", + "decoder.1.attn.q_norm.bias", + "decoder.1.attn.k_norm.weight", + "decoder.1.attn.k_norm.bias", + "decoder.1.ls1.gamma", + "decoder.1.norm2.weight", + "decoder.1.norm2.bias", + "decoder.1.mlp.fc1.bias", + "decoder.1.mlp.fc2.bias", + "decoder.1.ls2.gamma", + "decoder.2.norm1.weight", + "decoder.2.norm1.bias", + "decoder.2.attn.qkv.bias", + "decoder.2.attn.proj.bias", + "decoder.2.attn.q_norm.weight", + "decoder.2.attn.q_norm.bias", + "decoder.2.attn.k_norm.weight", + "decoder.2.attn.k_norm.bias", + "decoder.2.ls1.gamma", + "decoder.2.norm2.weight", + "decoder.2.norm2.bias", + "decoder.2.mlp.fc1.bias", + "decoder.2.mlp.fc2.bias", + 
"decoder.2.ls2.gamma", + "decoder.3.norm1.weight", + "decoder.3.norm1.bias", + "decoder.3.attn.qkv.bias", + "decoder.3.attn.proj.bias", + "decoder.3.attn.q_norm.weight", + "decoder.3.attn.q_norm.bias", + "decoder.3.attn.k_norm.weight", + "decoder.3.attn.k_norm.bias", + "decoder.3.ls1.gamma", + "decoder.3.norm2.weight", + "decoder.3.norm2.bias", + "decoder.3.mlp.fc1.bias", + "decoder.3.mlp.fc2.bias", + "decoder.3.ls2.gamma", + "decoder.4.norm1.weight", + "decoder.4.norm1.bias", + "decoder.4.attn.qkv.bias", + "decoder.4.attn.proj.bias", + "decoder.4.attn.q_norm.weight", + "decoder.4.attn.q_norm.bias", + "decoder.4.attn.k_norm.weight", + "decoder.4.attn.k_norm.bias", + "decoder.4.ls1.gamma", + "decoder.4.norm2.weight", + "decoder.4.norm2.bias", + "decoder.4.mlp.fc1.bias", + "decoder.4.mlp.fc2.bias", + "decoder.4.ls2.gamma", + "decoder.5.norm1.weight", + "decoder.5.norm1.bias", + "decoder.5.attn.qkv.bias", + "decoder.5.attn.proj.bias", + "decoder.5.attn.q_norm.weight", + "decoder.5.attn.q_norm.bias", + "decoder.5.attn.k_norm.weight", + "decoder.5.attn.k_norm.bias", + "decoder.5.ls1.gamma", + "decoder.5.norm2.weight", + "decoder.5.norm2.bias", + "decoder.5.mlp.fc1.bias", + "decoder.5.mlp.fc2.bias", + "decoder.5.ls2.gamma", + "decoder.6.norm1.weight", + "decoder.6.norm1.bias", + "decoder.6.attn.qkv.bias", + "decoder.6.attn.proj.bias", + "decoder.6.attn.q_norm.weight", + "decoder.6.attn.q_norm.bias", + "decoder.6.attn.k_norm.weight", + "decoder.6.attn.k_norm.bias", + "decoder.6.ls1.gamma", + "decoder.6.norm2.weight", + "decoder.6.norm2.bias", + "decoder.6.mlp.fc1.bias", + "decoder.6.mlp.fc2.bias", + "decoder.6.ls2.gamma", + "decoder.7.norm1.weight", + "decoder.7.norm1.bias", + "decoder.7.attn.qkv.bias", + "decoder.7.attn.proj.bias", + "decoder.7.attn.q_norm.weight", + "decoder.7.attn.q_norm.bias", + "decoder.7.attn.k_norm.weight", + "decoder.7.attn.k_norm.bias", + "decoder.7.ls1.gamma", + "decoder.7.norm2.weight", + "decoder.7.norm2.bias", + "decoder.7.mlp.fc1.bias", + 
"decoder.7.mlp.fc2.bias", + "decoder.7.ls2.gamma", + "decoder.8.norm1.weight", + "decoder.8.norm1.bias", + "decoder.8.attn.qkv.bias", + "decoder.8.attn.proj.bias", + "decoder.8.attn.q_norm.weight", + "decoder.8.attn.q_norm.bias", + "decoder.8.attn.k_norm.weight", + "decoder.8.attn.k_norm.bias", + "decoder.8.ls1.gamma", + "decoder.8.norm2.weight", + "decoder.8.norm2.bias", + "decoder.8.mlp.fc1.bias", + "decoder.8.mlp.fc2.bias", + "decoder.8.ls2.gamma", + "decoder.9.norm1.weight", + "decoder.9.norm1.bias", + "decoder.9.attn.qkv.bias", + "decoder.9.attn.proj.bias", + "decoder.9.attn.q_norm.weight", + "decoder.9.attn.q_norm.bias", + "decoder.9.attn.k_norm.weight", + "decoder.9.attn.k_norm.bias", + "decoder.9.ls1.gamma", + "decoder.9.norm2.weight", + "decoder.9.norm2.bias", + "decoder.9.mlp.fc1.bias", + "decoder.9.mlp.fc2.bias", + "decoder.9.ls2.gamma", + "decoder.10.norm1.weight", + "decoder.10.norm1.bias", + "decoder.10.attn.qkv.bias", + "decoder.10.attn.proj.bias", + "decoder.10.attn.q_norm.weight", + "decoder.10.attn.q_norm.bias", + "decoder.10.attn.k_norm.weight", + "decoder.10.attn.k_norm.bias", + "decoder.10.ls1.gamma", + "decoder.10.norm2.weight", + "decoder.10.norm2.bias", + "decoder.10.mlp.fc1.bias", + "decoder.10.mlp.fc2.bias", + "decoder.10.ls2.gamma", + "decoder.11.norm1.weight", + "decoder.11.norm1.bias", + "decoder.11.attn.qkv.bias", + "decoder.11.attn.proj.bias", + "decoder.11.attn.q_norm.weight", + "decoder.11.attn.q_norm.bias", + "decoder.11.attn.k_norm.weight", + "decoder.11.attn.k_norm.bias", + "decoder.11.ls1.gamma", + "decoder.11.norm2.weight", + "decoder.11.norm2.bias", + "decoder.11.mlp.fc1.bias", + "decoder.11.mlp.fc2.bias", + "decoder.11.ls2.gamma", + "decoder.12.norm1.weight", + "decoder.12.norm1.bias", + "decoder.12.attn.qkv.bias", + "decoder.12.attn.proj.bias", + "decoder.12.attn.q_norm.weight", + "decoder.12.attn.q_norm.bias", + "decoder.12.attn.k_norm.weight", + "decoder.12.attn.k_norm.bias", + "decoder.12.ls1.gamma", + 
"decoder.12.norm2.weight", + "decoder.12.norm2.bias", + "decoder.12.mlp.fc1.bias", + "decoder.12.mlp.fc2.bias", + "decoder.12.ls2.gamma", + "decoder.13.norm1.weight", + "decoder.13.norm1.bias", + "decoder.13.attn.qkv.bias", + "decoder.13.attn.proj.bias", + "decoder.13.attn.q_norm.weight", + "decoder.13.attn.q_norm.bias", + "decoder.13.attn.k_norm.weight", + "decoder.13.attn.k_norm.bias", + "decoder.13.ls1.gamma", + "decoder.13.norm2.weight", + "decoder.13.norm2.bias", + "decoder.13.mlp.fc1.bias", + "decoder.13.mlp.fc2.bias", + "decoder.13.ls2.gamma", + "decoder.14.norm1.weight", + "decoder.14.norm1.bias", + "decoder.14.attn.qkv.bias", + "decoder.14.attn.proj.bias", + "decoder.14.attn.q_norm.weight", + "decoder.14.attn.q_norm.bias", + "decoder.14.attn.k_norm.weight", + "decoder.14.attn.k_norm.bias", + "decoder.14.ls1.gamma", + "decoder.14.norm2.weight", + "decoder.14.norm2.bias", + "decoder.14.mlp.fc1.bias", + "decoder.14.mlp.fc2.bias", + "decoder.14.ls2.gamma", + "decoder.15.norm1.weight", + "decoder.15.norm1.bias", + "decoder.15.attn.qkv.bias", + "decoder.15.attn.proj.bias", + "decoder.15.attn.q_norm.weight", + "decoder.15.attn.q_norm.bias", + "decoder.15.attn.k_norm.weight", + "decoder.15.attn.k_norm.bias", + "decoder.15.ls1.gamma", + "decoder.15.norm2.weight", + "decoder.15.norm2.bias", + "decoder.15.mlp.fc1.bias", + "decoder.15.mlp.fc2.bias", + "decoder.15.ls2.gamma", + "decoder.16.norm1.weight", + "decoder.16.norm1.bias", + "decoder.16.attn.qkv.bias", + "decoder.16.attn.proj.bias", + "decoder.16.attn.q_norm.weight", + "decoder.16.attn.q_norm.bias", + "decoder.16.attn.k_norm.weight", + "decoder.16.attn.k_norm.bias", + "decoder.16.ls1.gamma", + "decoder.16.norm2.weight", + "decoder.16.norm2.bias", + "decoder.16.mlp.fc1.bias", + "decoder.16.mlp.fc2.bias", + "decoder.16.ls2.gamma", + "decoder.17.norm1.weight", + "decoder.17.norm1.bias", + "decoder.17.attn.qkv.bias", + "decoder.17.attn.proj.bias", + "decoder.17.attn.q_norm.weight", + "decoder.17.attn.q_norm.bias", 
+ "decoder.17.attn.k_norm.weight", + "decoder.17.attn.k_norm.bias", + "decoder.17.ls1.gamma", + "decoder.17.norm2.weight", + "decoder.17.norm2.bias", + "decoder.17.mlp.fc1.bias", + "decoder.17.mlp.fc2.bias", + "decoder.17.ls2.gamma", + "decoder.18.norm1.weight", + "decoder.18.norm1.bias", + "decoder.18.attn.qkv.bias", + "decoder.18.attn.proj.bias", + "decoder.18.attn.q_norm.weight", + "decoder.18.attn.q_norm.bias", + "decoder.18.attn.k_norm.weight", + "decoder.18.attn.k_norm.bias", + "decoder.18.ls1.gamma", + "decoder.18.norm2.weight", + "decoder.18.norm2.bias", + "decoder.18.mlp.fc1.bias", + "decoder.18.mlp.fc2.bias", + "decoder.18.ls2.gamma", + "decoder.19.norm1.weight", + "decoder.19.norm1.bias", + "decoder.19.attn.qkv.bias", + "decoder.19.attn.proj.bias", + "decoder.19.attn.q_norm.weight", + "decoder.19.attn.q_norm.bias", + "decoder.19.attn.k_norm.weight", + "decoder.19.attn.k_norm.bias", + "decoder.19.ls1.gamma", + "decoder.19.norm2.weight", + "decoder.19.norm2.bias", + "decoder.19.mlp.fc1.bias", + "decoder.19.mlp.fc2.bias", + "decoder.19.ls2.gamma", + "decoder.20.norm1.weight", + "decoder.20.norm1.bias", + "decoder.20.attn.qkv.bias", + "decoder.20.attn.proj.bias", + "decoder.20.attn.q_norm.weight", + "decoder.20.attn.q_norm.bias", + "decoder.20.attn.k_norm.weight", + "decoder.20.attn.k_norm.bias", + "decoder.20.ls1.gamma", + "decoder.20.norm2.weight", + "decoder.20.norm2.bias", + "decoder.20.mlp.fc1.bias", + "decoder.20.mlp.fc2.bias", + "decoder.20.ls2.gamma", + "decoder.21.norm1.weight", + "decoder.21.norm1.bias", + "decoder.21.attn.qkv.bias", + "decoder.21.attn.proj.bias", + "decoder.21.attn.q_norm.weight", + "decoder.21.attn.q_norm.bias", + "decoder.21.attn.k_norm.weight", + "decoder.21.attn.k_norm.bias", + "decoder.21.ls1.gamma", + "decoder.21.norm2.weight", + "decoder.21.norm2.bias", + "decoder.21.mlp.fc1.bias", + "decoder.21.mlp.fc2.bias", + "decoder.21.ls2.gamma", + "decoder.22.norm1.weight", + "decoder.22.norm1.bias", + "decoder.22.attn.qkv.bias", + 
"decoder.22.attn.proj.bias", + "decoder.22.attn.q_norm.weight", + "decoder.22.attn.q_norm.bias", + "decoder.22.attn.k_norm.weight", + "decoder.22.attn.k_norm.bias", + "decoder.22.ls1.gamma", + "decoder.22.norm2.weight", + "decoder.22.norm2.bias", + "decoder.22.mlp.fc1.bias", + "decoder.22.mlp.fc2.bias", + "decoder.22.ls2.gamma", + "decoder.23.norm1.weight", + "decoder.23.norm1.bias", + "decoder.23.attn.qkv.bias", + "decoder.23.attn.proj.bias", + "decoder.23.attn.q_norm.weight", + "decoder.23.attn.q_norm.bias", + "decoder.23.attn.k_norm.weight", + "decoder.23.attn.k_norm.bias", + "decoder.23.ls1.gamma", + "decoder.23.norm2.weight", + "decoder.23.norm2.bias", + "decoder.23.mlp.fc1.bias", + "decoder.23.mlp.fc2.bias", + "decoder.23.ls2.gamma", + "decoder.24.norm1.weight", + "decoder.24.norm1.bias", + "decoder.24.attn.qkv.bias", + "decoder.24.attn.proj.bias", + "decoder.24.attn.q_norm.weight", + "decoder.24.attn.q_norm.bias", + "decoder.24.attn.k_norm.weight", + "decoder.24.attn.k_norm.bias", + "decoder.24.ls1.gamma", + "decoder.24.norm2.weight", + "decoder.24.norm2.bias", + "decoder.24.mlp.fc1.bias", + "decoder.24.mlp.fc2.bias", + "decoder.24.ls2.gamma", + "decoder.25.norm1.weight", + "decoder.25.norm1.bias", + "decoder.25.attn.qkv.bias", + "decoder.25.attn.proj.bias", + "decoder.25.attn.q_norm.weight", + "decoder.25.attn.q_norm.bias", + "decoder.25.attn.k_norm.weight", + "decoder.25.attn.k_norm.bias", + "decoder.25.ls1.gamma", + "decoder.25.norm2.weight", + "decoder.25.norm2.bias", + "decoder.25.mlp.fc1.bias", + "decoder.25.mlp.fc2.bias", + "decoder.25.ls2.gamma", + "decoder.26.norm1.weight", + "decoder.26.norm1.bias", + "decoder.26.attn.qkv.bias", + "decoder.26.attn.proj.bias", + "decoder.26.attn.q_norm.weight", + "decoder.26.attn.q_norm.bias", + "decoder.26.attn.k_norm.weight", + "decoder.26.attn.k_norm.bias", + "decoder.26.ls1.gamma", + "decoder.26.norm2.weight", + "decoder.26.norm2.bias", + "decoder.26.mlp.fc1.bias", + "decoder.26.mlp.fc2.bias", + 
"decoder.26.ls2.gamma", + "decoder.27.norm1.weight", + "decoder.27.norm1.bias", + "decoder.27.attn.qkv.bias", + "decoder.27.attn.proj.bias", + "decoder.27.attn.q_norm.weight", + "decoder.27.attn.q_norm.bias", + "decoder.27.attn.k_norm.weight", + "decoder.27.attn.k_norm.bias", + "decoder.27.ls1.gamma", + "decoder.27.norm2.weight", + "decoder.27.norm2.bias", + "decoder.27.mlp.fc1.bias", + "decoder.27.mlp.fc2.bias", + "decoder.27.ls2.gamma", + "decoder.28.norm1.weight", + "decoder.28.norm1.bias", + "decoder.28.attn.qkv.bias", + "decoder.28.attn.proj.bias", + "decoder.28.attn.q_norm.weight", + "decoder.28.attn.q_norm.bias", + "decoder.28.attn.k_norm.weight", + "decoder.28.attn.k_norm.bias", + "decoder.28.ls1.gamma", + "decoder.28.norm2.weight", + "decoder.28.norm2.bias", + "decoder.28.mlp.fc1.bias", + "decoder.28.mlp.fc2.bias", + "decoder.28.ls2.gamma", + "decoder.29.norm1.weight", + "decoder.29.norm1.bias", + "decoder.29.attn.qkv.bias", + "decoder.29.attn.proj.bias", + "decoder.29.attn.q_norm.weight", + "decoder.29.attn.q_norm.bias", + "decoder.29.attn.k_norm.weight", + "decoder.29.attn.k_norm.bias", + "decoder.29.ls1.gamma", + "decoder.29.norm2.weight", + "decoder.29.norm2.bias", + "decoder.29.mlp.fc1.bias", + "decoder.29.mlp.fc2.bias", + "decoder.29.ls2.gamma", + "decoder.30.norm1.weight", + "decoder.30.norm1.bias", + "decoder.30.attn.qkv.bias", + "decoder.30.attn.proj.bias", + "decoder.30.attn.q_norm.weight", + "decoder.30.attn.q_norm.bias", + "decoder.30.attn.k_norm.weight", + "decoder.30.attn.k_norm.bias", + "decoder.30.ls1.gamma", + "decoder.30.norm2.weight", + "decoder.30.norm2.bias", + "decoder.30.mlp.fc1.bias", + "decoder.30.mlp.fc2.bias", + "decoder.30.ls2.gamma", + "decoder.31.norm1.weight", + "decoder.31.norm1.bias", + "decoder.31.attn.qkv.bias", + "decoder.31.attn.proj.bias", + "decoder.31.attn.q_norm.weight", + "decoder.31.attn.q_norm.bias", + "decoder.31.attn.k_norm.weight", + "decoder.31.attn.k_norm.bias", + "decoder.31.ls1.gamma", + 
"decoder.31.norm2.weight", + "decoder.31.norm2.bias", + "decoder.31.mlp.fc1.bias", + "decoder.31.mlp.fc2.bias", + "decoder.31.ls2.gamma", + "decoder.32.norm1.weight", + "decoder.32.norm1.bias", + "decoder.32.attn.qkv.bias", + "decoder.32.attn.proj.bias", + "decoder.32.attn.q_norm.weight", + "decoder.32.attn.q_norm.bias", + "decoder.32.attn.k_norm.weight", + "decoder.32.attn.k_norm.bias", + "decoder.32.ls1.gamma", + "decoder.32.norm2.weight", + "decoder.32.norm2.bias", + "decoder.32.mlp.fc1.bias", + "decoder.32.mlp.fc2.bias", + "decoder.32.ls2.gamma", + "decoder.33.norm1.weight", + "decoder.33.norm1.bias", + "decoder.33.attn.qkv.bias", + "decoder.33.attn.proj.bias", + "decoder.33.attn.q_norm.weight", + "decoder.33.attn.q_norm.bias", + "decoder.33.attn.k_norm.weight", + "decoder.33.attn.k_norm.bias", + "decoder.33.ls1.gamma", + "decoder.33.norm2.weight", + "decoder.33.norm2.bias", + "decoder.33.mlp.fc1.bias", + "decoder.33.mlp.fc2.bias", + "decoder.33.ls2.gamma", + "decoder.34.norm1.weight", + "decoder.34.norm1.bias", + "decoder.34.attn.qkv.bias", + "decoder.34.attn.proj.bias", + "decoder.34.attn.q_norm.weight", + "decoder.34.attn.q_norm.bias", + "decoder.34.attn.k_norm.weight", + "decoder.34.attn.k_norm.bias", + "decoder.34.ls1.gamma", + "decoder.34.norm2.weight", + "decoder.34.norm2.bias", + "decoder.34.mlp.fc1.bias", + "decoder.34.mlp.fc2.bias", + "decoder.34.ls2.gamma", + "decoder.35.norm1.weight", + "decoder.35.norm1.bias", + "decoder.35.attn.qkv.bias", + "decoder.35.attn.proj.bias", + "decoder.35.attn.q_norm.weight", + "decoder.35.attn.q_norm.bias", + "decoder.35.attn.k_norm.weight", + "decoder.35.attn.k_norm.bias", + "decoder.35.ls1.gamma", + "decoder.35.norm2.weight", + "decoder.35.norm2.bias", + "decoder.35.mlp.fc1.bias", + "decoder.35.mlp.fc2.bias", + "decoder.35.ls2.gamma", + "point_decoder.projects.bias", + "point_decoder.blocks.0.norm1.weight", + "point_decoder.blocks.0.norm1.bias", + "point_decoder.blocks.0.attn.qkv.bias", + 
"point_decoder.blocks.0.attn.proj.bias", + "point_decoder.blocks.0.norm2.weight", + "point_decoder.blocks.0.norm2.bias", + "point_decoder.blocks.0.mlp.fc1.bias", + "point_decoder.blocks.0.mlp.fc2.bias", + "point_decoder.blocks.1.norm1.weight", + "point_decoder.blocks.1.norm1.bias", + "point_decoder.blocks.1.attn.qkv.bias", + "point_decoder.blocks.1.attn.proj.bias", + "point_decoder.blocks.1.norm2.weight", + "point_decoder.blocks.1.norm2.bias", + "point_decoder.blocks.1.mlp.fc1.bias", + "point_decoder.blocks.1.mlp.fc2.bias", + "point_decoder.blocks.2.norm1.weight", + "point_decoder.blocks.2.norm1.bias", + "point_decoder.blocks.2.attn.qkv.bias", + "point_decoder.blocks.2.attn.proj.bias", + "point_decoder.blocks.2.norm2.weight", + "point_decoder.blocks.2.norm2.bias", + "point_decoder.blocks.2.mlp.fc1.bias", + "point_decoder.blocks.2.mlp.fc2.bias", + "point_decoder.blocks.3.norm1.weight", + "point_decoder.blocks.3.norm1.bias", + "point_decoder.blocks.3.attn.qkv.bias", + "point_decoder.blocks.3.attn.proj.bias", + "point_decoder.blocks.3.norm2.weight", + "point_decoder.blocks.3.norm2.bias", + "point_decoder.blocks.3.mlp.fc1.bias", + "point_decoder.blocks.3.mlp.fc2.bias", + "point_decoder.blocks.4.norm1.weight", + "point_decoder.blocks.4.norm1.bias", + "point_decoder.blocks.4.attn.qkv.bias", + "point_decoder.blocks.4.attn.proj.bias", + "point_decoder.blocks.4.norm2.weight", + "point_decoder.blocks.4.norm2.bias", + "point_decoder.blocks.4.mlp.fc1.bias", + "point_decoder.blocks.4.mlp.fc2.bias", + "point_decoder.linear_out.bias", + "point_head.proj.bias", + "conf_decoder.projects.bias", + "conf_decoder.blocks.0.norm1.weight", + "conf_decoder.blocks.0.norm1.bias", + "conf_decoder.blocks.0.attn.qkv.bias", + "conf_decoder.blocks.0.attn.proj.bias", + "conf_decoder.blocks.0.norm2.weight", + "conf_decoder.blocks.0.norm2.bias", + "conf_decoder.blocks.0.mlp.fc1.bias", + "conf_decoder.blocks.0.mlp.fc2.bias", + "conf_decoder.blocks.1.norm1.weight", + 
"conf_decoder.blocks.1.norm1.bias", + "conf_decoder.blocks.1.attn.qkv.bias", + "conf_decoder.blocks.1.attn.proj.bias", + "conf_decoder.blocks.1.norm2.weight", + "conf_decoder.blocks.1.norm2.bias", + "conf_decoder.blocks.1.mlp.fc1.bias", + "conf_decoder.blocks.1.mlp.fc2.bias", + "conf_decoder.blocks.2.norm1.weight", + "conf_decoder.blocks.2.norm1.bias", + "conf_decoder.blocks.2.attn.qkv.bias", + "conf_decoder.blocks.2.attn.proj.bias", + "conf_decoder.blocks.2.norm2.weight", + "conf_decoder.blocks.2.norm2.bias", + "conf_decoder.blocks.2.mlp.fc1.bias", + "conf_decoder.blocks.2.mlp.fc2.bias", + "conf_decoder.blocks.3.norm1.weight", + "conf_decoder.blocks.3.norm1.bias", + "conf_decoder.blocks.3.attn.qkv.bias", + "conf_decoder.blocks.3.attn.proj.bias", + "conf_decoder.blocks.3.norm2.weight", + "conf_decoder.blocks.3.norm2.bias", + "conf_decoder.blocks.3.mlp.fc1.bias", + "conf_decoder.blocks.3.mlp.fc2.bias", + "conf_decoder.blocks.4.norm1.weight", + "conf_decoder.blocks.4.norm1.bias", + "conf_decoder.blocks.4.attn.qkv.bias", + "conf_decoder.blocks.4.attn.proj.bias", + "conf_decoder.blocks.4.norm2.weight", + "conf_decoder.blocks.4.norm2.bias", + "conf_decoder.blocks.4.mlp.fc1.bias", + "conf_decoder.blocks.4.mlp.fc2.bias", + "conf_decoder.linear_out.bias", + "conf_head.proj.bias", + "camera_decoder.projects.bias", + "camera_decoder.blocks.0.norm1.weight", + "camera_decoder.blocks.0.norm1.bias", + "camera_decoder.blocks.0.attn.qkv.bias", + "camera_decoder.blocks.0.attn.proj.bias", + "camera_decoder.blocks.0.norm2.weight", + "camera_decoder.blocks.0.norm2.bias", + "camera_decoder.blocks.0.mlp.fc1.bias", + "camera_decoder.blocks.0.mlp.fc2.bias", + "camera_decoder.blocks.1.norm1.weight", + "camera_decoder.blocks.1.norm1.bias", + "camera_decoder.blocks.1.attn.qkv.bias", + "camera_decoder.blocks.1.attn.proj.bias", + "camera_decoder.blocks.1.norm2.weight", + "camera_decoder.blocks.1.norm2.bias", + "camera_decoder.blocks.1.mlp.fc1.bias", + "camera_decoder.blocks.1.mlp.fc2.bias", + 
"camera_decoder.blocks.2.norm1.weight", + "camera_decoder.blocks.2.norm1.bias", + "camera_decoder.blocks.2.attn.qkv.bias", + "camera_decoder.blocks.2.attn.proj.bias", + "camera_decoder.blocks.2.norm2.weight", + "camera_decoder.blocks.2.norm2.bias", + "camera_decoder.blocks.2.mlp.fc1.bias", + "camera_decoder.blocks.2.mlp.fc2.bias", + "camera_decoder.blocks.3.norm1.weight", + "camera_decoder.blocks.3.norm1.bias", + "camera_decoder.blocks.3.attn.qkv.bias", + "camera_decoder.blocks.3.attn.proj.bias", + "camera_decoder.blocks.3.norm2.weight", + "camera_decoder.blocks.3.norm2.bias", + "camera_decoder.blocks.3.mlp.fc1.bias", + "camera_decoder.blocks.3.mlp.fc2.bias", + "camera_decoder.blocks.4.norm1.weight", + "camera_decoder.blocks.4.norm1.bias", + "camera_decoder.blocks.4.attn.qkv.bias", + "camera_decoder.blocks.4.attn.proj.bias", + "camera_decoder.blocks.4.norm2.weight", + "camera_decoder.blocks.4.norm2.bias", + "camera_decoder.blocks.4.mlp.fc1.bias", + "camera_decoder.blocks.4.mlp.fc2.bias", + "camera_decoder.linear_out.bias", + "camera_head.res_conv.0.res_conv1.bias", + "camera_head.res_conv.0.res_conv2.bias", + "camera_head.res_conv.0.res_conv3.bias", + "camera_head.res_conv.1.res_conv1.bias", + "camera_head.res_conv.1.res_conv2.bias", + "camera_head.res_conv.1.res_conv3.bias", + "camera_head.more_mlps.0.bias", + "camera_head.more_mlps.2.bias", + "camera_head.fc_t.bias", + "camera_head.fc_rot.bias" + ], + "lr_scale": 1.0 + }, + "decay": { + "weight_decay": 0.05, + "params": [ + "decoder.0.attn.qkv.weight", + "decoder.0.attn.proj.weight", + "decoder.0.mlp.fc1.weight", + "decoder.0.mlp.fc2.weight", + "decoder.1.attn.qkv.weight", + "decoder.1.attn.proj.weight", + "decoder.1.mlp.fc1.weight", + "decoder.1.mlp.fc2.weight", + "decoder.2.attn.qkv.weight", + "decoder.2.attn.proj.weight", + "decoder.2.mlp.fc1.weight", + "decoder.2.mlp.fc2.weight", + "decoder.3.attn.qkv.weight", + "decoder.3.attn.proj.weight", + "decoder.3.mlp.fc1.weight", + "decoder.3.mlp.fc2.weight", + 
"decoder.4.attn.qkv.weight", + "decoder.4.attn.proj.weight", + "decoder.4.mlp.fc1.weight", + "decoder.4.mlp.fc2.weight", + "decoder.5.attn.qkv.weight", + "decoder.5.attn.proj.weight", + "decoder.5.mlp.fc1.weight", + "decoder.5.mlp.fc2.weight", + "decoder.6.attn.qkv.weight", + "decoder.6.attn.proj.weight", + "decoder.6.mlp.fc1.weight", + "decoder.6.mlp.fc2.weight", + "decoder.7.attn.qkv.weight", + "decoder.7.attn.proj.weight", + "decoder.7.mlp.fc1.weight", + "decoder.7.mlp.fc2.weight", + "decoder.8.attn.qkv.weight", + "decoder.8.attn.proj.weight", + "decoder.8.mlp.fc1.weight", + "decoder.8.mlp.fc2.weight", + "decoder.9.attn.qkv.weight", + "decoder.9.attn.proj.weight", + "decoder.9.mlp.fc1.weight", + "decoder.9.mlp.fc2.weight", + "decoder.10.attn.qkv.weight", + "decoder.10.attn.proj.weight", + "decoder.10.mlp.fc1.weight", + "decoder.10.mlp.fc2.weight", + "decoder.11.attn.qkv.weight", + "decoder.11.attn.proj.weight", + "decoder.11.mlp.fc1.weight", + "decoder.11.mlp.fc2.weight", + "decoder.12.attn.qkv.weight", + "decoder.12.attn.proj.weight", + "decoder.12.mlp.fc1.weight", + "decoder.12.mlp.fc2.weight", + "decoder.13.attn.qkv.weight", + "decoder.13.attn.proj.weight", + "decoder.13.mlp.fc1.weight", + "decoder.13.mlp.fc2.weight", + "decoder.14.attn.qkv.weight", + "decoder.14.attn.proj.weight", + "decoder.14.mlp.fc1.weight", + "decoder.14.mlp.fc2.weight", + "decoder.15.attn.qkv.weight", + "decoder.15.attn.proj.weight", + "decoder.15.mlp.fc1.weight", + "decoder.15.mlp.fc2.weight", + "decoder.16.attn.qkv.weight", + "decoder.16.attn.proj.weight", + "decoder.16.mlp.fc1.weight", + "decoder.16.mlp.fc2.weight", + "decoder.17.attn.qkv.weight", + "decoder.17.attn.proj.weight", + "decoder.17.mlp.fc1.weight", + "decoder.17.mlp.fc2.weight", + "decoder.18.attn.qkv.weight", + "decoder.18.attn.proj.weight", + "decoder.18.mlp.fc1.weight", + "decoder.18.mlp.fc2.weight", + "decoder.19.attn.qkv.weight", + "decoder.19.attn.proj.weight", + "decoder.19.mlp.fc1.weight", + 
"decoder.19.mlp.fc2.weight", + "decoder.20.attn.qkv.weight", + "decoder.20.attn.proj.weight", + "decoder.20.mlp.fc1.weight", + "decoder.20.mlp.fc2.weight", + "decoder.21.attn.qkv.weight", + "decoder.21.attn.proj.weight", + "decoder.21.mlp.fc1.weight", + "decoder.21.mlp.fc2.weight", + "decoder.22.attn.qkv.weight", + "decoder.22.attn.proj.weight", + "decoder.22.mlp.fc1.weight", + "decoder.22.mlp.fc2.weight", + "decoder.23.attn.qkv.weight", + "decoder.23.attn.proj.weight", + "decoder.23.mlp.fc1.weight", + "decoder.23.mlp.fc2.weight", + "decoder.24.attn.qkv.weight", + "decoder.24.attn.proj.weight", + "decoder.24.mlp.fc1.weight", + "decoder.24.mlp.fc2.weight", + "decoder.25.attn.qkv.weight", + "decoder.25.attn.proj.weight", + "decoder.25.mlp.fc1.weight", + "decoder.25.mlp.fc2.weight", + "decoder.26.attn.qkv.weight", + "decoder.26.attn.proj.weight", + "decoder.26.mlp.fc1.weight", + "decoder.26.mlp.fc2.weight", + "decoder.27.attn.qkv.weight", + "decoder.27.attn.proj.weight", + "decoder.27.mlp.fc1.weight", + "decoder.27.mlp.fc2.weight", + "decoder.28.attn.qkv.weight", + "decoder.28.attn.proj.weight", + "decoder.28.mlp.fc1.weight", + "decoder.28.mlp.fc2.weight", + "decoder.29.attn.qkv.weight", + "decoder.29.attn.proj.weight", + "decoder.29.mlp.fc1.weight", + "decoder.29.mlp.fc2.weight", + "decoder.30.attn.qkv.weight", + "decoder.30.attn.proj.weight", + "decoder.30.mlp.fc1.weight", + "decoder.30.mlp.fc2.weight", + "decoder.31.attn.qkv.weight", + "decoder.31.attn.proj.weight", + "decoder.31.mlp.fc1.weight", + "decoder.31.mlp.fc2.weight", + "decoder.32.attn.qkv.weight", + "decoder.32.attn.proj.weight", + "decoder.32.mlp.fc1.weight", + "decoder.32.mlp.fc2.weight", + "decoder.33.attn.qkv.weight", + "decoder.33.attn.proj.weight", + "decoder.33.mlp.fc1.weight", + "decoder.33.mlp.fc2.weight", + "decoder.34.attn.qkv.weight", + "decoder.34.attn.proj.weight", + "decoder.34.mlp.fc1.weight", + "decoder.34.mlp.fc2.weight", + "decoder.35.attn.qkv.weight", + "decoder.35.attn.proj.weight", 
+ "decoder.35.mlp.fc1.weight", + "decoder.35.mlp.fc2.weight", + "point_decoder.projects.weight", + "point_decoder.blocks.0.attn.qkv.weight", + "point_decoder.blocks.0.attn.proj.weight", + "point_decoder.blocks.0.mlp.fc1.weight", + "point_decoder.blocks.0.mlp.fc2.weight", + "point_decoder.blocks.1.attn.qkv.weight", + "point_decoder.blocks.1.attn.proj.weight", + "point_decoder.blocks.1.mlp.fc1.weight", + "point_decoder.blocks.1.mlp.fc2.weight", + "point_decoder.blocks.2.attn.qkv.weight", + "point_decoder.blocks.2.attn.proj.weight", + "point_decoder.blocks.2.mlp.fc1.weight", + "point_decoder.blocks.2.mlp.fc2.weight", + "point_decoder.blocks.3.attn.qkv.weight", + "point_decoder.blocks.3.attn.proj.weight", + "point_decoder.blocks.3.mlp.fc1.weight", + "point_decoder.blocks.3.mlp.fc2.weight", + "point_decoder.blocks.4.attn.qkv.weight", + "point_decoder.blocks.4.attn.proj.weight", + "point_decoder.blocks.4.mlp.fc1.weight", + "point_decoder.blocks.4.mlp.fc2.weight", + "point_decoder.linear_out.weight", + "point_head.proj.weight", + "conf_decoder.projects.weight", + "conf_decoder.blocks.0.attn.qkv.weight", + "conf_decoder.blocks.0.attn.proj.weight", + "conf_decoder.blocks.0.mlp.fc1.weight", + "conf_decoder.blocks.0.mlp.fc2.weight", + "conf_decoder.blocks.1.attn.qkv.weight", + "conf_decoder.blocks.1.attn.proj.weight", + "conf_decoder.blocks.1.mlp.fc1.weight", + "conf_decoder.blocks.1.mlp.fc2.weight", + "conf_decoder.blocks.2.attn.qkv.weight", + "conf_decoder.blocks.2.attn.proj.weight", + "conf_decoder.blocks.2.mlp.fc1.weight", + "conf_decoder.blocks.2.mlp.fc2.weight", + "conf_decoder.blocks.3.attn.qkv.weight", + "conf_decoder.blocks.3.attn.proj.weight", + "conf_decoder.blocks.3.mlp.fc1.weight", + "conf_decoder.blocks.3.mlp.fc2.weight", + "conf_decoder.blocks.4.attn.qkv.weight", + "conf_decoder.blocks.4.attn.proj.weight", + "conf_decoder.blocks.4.mlp.fc1.weight", + "conf_decoder.blocks.4.mlp.fc2.weight", + "conf_decoder.linear_out.weight", + "conf_head.proj.weight", + 
"camera_decoder.projects.weight", + "camera_decoder.blocks.0.attn.qkv.weight", + "camera_decoder.blocks.0.attn.proj.weight", + "camera_decoder.blocks.0.mlp.fc1.weight", + "camera_decoder.blocks.0.mlp.fc2.weight", + "camera_decoder.blocks.1.attn.qkv.weight", + "camera_decoder.blocks.1.attn.proj.weight", + "camera_decoder.blocks.1.mlp.fc1.weight", + "camera_decoder.blocks.1.mlp.fc2.weight", + "camera_decoder.blocks.2.attn.qkv.weight", + "camera_decoder.blocks.2.attn.proj.weight", + "camera_decoder.blocks.2.mlp.fc1.weight", + "camera_decoder.blocks.2.mlp.fc2.weight", + "camera_decoder.blocks.3.attn.qkv.weight", + "camera_decoder.blocks.3.attn.proj.weight", + "camera_decoder.blocks.3.mlp.fc1.weight", + "camera_decoder.blocks.3.mlp.fc2.weight", + "camera_decoder.blocks.4.attn.qkv.weight", + "camera_decoder.blocks.4.attn.proj.weight", + "camera_decoder.blocks.4.mlp.fc1.weight", + "camera_decoder.blocks.4.mlp.fc2.weight", + "camera_decoder.linear_out.weight", + "camera_head.res_conv.0.res_conv1.weight", + "camera_head.res_conv.0.res_conv2.weight", + "camera_head.res_conv.0.res_conv3.weight", + "camera_head.res_conv.1.res_conv1.weight", + "camera_head.res_conv.1.res_conv2.weight", + "camera_head.res_conv.1.res_conv3.weight", + "camera_head.more_mlps.0.weight", + "camera_head.more_mlps.2.weight", + "camera_head.fc_t.weight", + "camera_head.fc_rot.weight" + ], + "lr_scale": 1.0 + } +} +[2026-05-02 01:14:39,625][__main__][INFO] - [RANK 0] Start training for 10 epochs +[2026-05-02 01:14:39,629][__main__][INFO] - [RANK 0] log_dir: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_4gpu/ +[2026-05-02 01:15:48,588][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 0/4350] eta: 3 days, 11:19:13 lr: 0.000000 epoch: 0.0000 (0.0000) step: 0.0000 (0.0000) loss: 5439.7471 (5439.7471) Lcamera_frontend: 4.1901 (4.1901) Ldepth_frontend: 16.2068 (16.2068) Lpmap_frontend: 18.4157 (18.4157) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 4.1823 (4.1823) Ldepth_mix: 16.2088 (16.2088) 
Lpmap_mix: 18.4142 (18.4142) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 4.1872 (4.1872) Ldepth_backend: 16.2021 (16.2021) Lpmap_backend: 18.4084 (18.4084) Ltrack_backend: 0.0000 (0.0000) total: 5439.7471 (5439.7471) time: 68.9548 data: 24.5773 max mem: 32998 +[2026-05-02 01:22:55,417][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 10/4350] eta: 2 days, 6:19:56 lr: 0.000000 epoch: 0.0011 (0.0011) step: 5.0000 (4.8182) loss: 7296.2388 (7751.6382) Lcamera_frontend: 5.7233 (6.1598) Ldepth_frontend: 16.2956 (16.8900) Lpmap_frontend: 18.4157 (18.3318) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 5.7620 (6.1551) Ldepth_mix: 16.2991 (16.8970) Lpmap_mix: 18.4142 (18.3373) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 5.7258 (6.0973) Ldepth_backend: 16.3094 (16.9028) Lpmap_backend: 18.4084 (18.3396) Ltrack_backend: 0.0000 (0.0000) total: 7296.2388 (7751.6382) time: 45.0683 data: 2.2670 max mem: 78413 +[2026-05-02 01:31:09,091][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 20/4350] eta: 2 days, 8:39:49 lr: 0.000000 epoch: 0.0023 (0.0023) step: 10.0000 (9.8571) loss: 4755.7080 (6485.0480) Lcamera_frontend: 3.8349 (5.0833) Ldepth_frontend: 16.2956 (16.7722) Lpmap_frontend: 18.2402 (18.2332) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.6581 (5.0778) Ldepth_mix: 16.2991 (16.7825) Lpmap_mix: 18.2422 (18.2399) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.6115 (5.0486) Ldepth_backend: 16.3094 (16.7899) Lpmap_backend: 18.2483 (18.2439) Ltrack_backend: 0.0000 (0.0000) total: 4755.7080 (6485.0480) time: 46.0186 data: 0.0390 max mem: 78608 +[2026-05-02 01:39:23,133][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 30/4350] eta: 2 days, 9:25:00 lr: 0.000000 epoch: 0.0046 (0.0034) step: 20.0000 (14.8710) loss: 3735.7661 (5619.1712) Lcamera_frontend: 2.8120 (4.3580) Ldepth_frontend: 16.8454 (16.9577) Lpmap_frontend: 18.1869 (18.2127) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.7571 (4.3618) Ldepth_mix: 16.8627 (16.9682) Lpmap_mix: 18.1918 (18.2193) Ltrack_mix: 0.0000 (0.0000) 
Lcamera_backend: 2.7541 (4.3248) Ldepth_backend: 16.8758 (16.9780) Lpmap_backend: 18.1962 (18.2246) Ltrack_backend: 0.0000 (0.0000) total: 3735.7661 (5619.1712) time: 49.3760 data: 0.0394 max mem: 78608 +[2026-05-02 01:46:20,590][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 40/4350] eta: 2 days, 7:30:02 lr: 0.000000 epoch: 0.0069 (0.0046) step: 30.0000 (19.9024) loss: 4175.3257 (5892.0705) Lcamera_frontend: 3.1523 (4.5835) Ldepth_frontend: 16.5594 (16.7492) Lpmap_frontend: 18.1080 (18.1761) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.1597 (4.5845) Ldepth_mix: 16.5518 (16.7583) Lpmap_mix: 18.1183 (18.1816) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.1245 (4.5559) Ldepth_backend: 16.5328 (16.7661) Lpmap_backend: 18.1399 (18.1858) Ltrack_backend: 0.0000 (0.0000) total: 4175.3257 (5892.0705) time: 45.5671 data: 0.0371 max mem: 78608 +[2026-05-02 01:54:45,060][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 50/4350] eta: 2 days, 8:19:46 lr: 0.000000 epoch: 0.0092 (0.0057) step: 40.0000 (24.8824) loss: 4175.3257 (5870.8307) Lcamera_frontend: 3.1523 (4.5652) Ldepth_frontend: 15.8311 (16.5923) Lpmap_frontend: 17.9634 (18.1143) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.1597 (4.5649) Ldepth_mix: 15.8259 (16.6015) Lpmap_mix: 17.9731 (18.1198) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.1285 (4.5411) Ldepth_backend: 15.8247 (16.6095) Lpmap_backend: 17.9753 (18.1242) Ltrack_backend: 0.0000 (0.0000) total: 4175.3257 (5870.8307) time: 46.0933 data: 0.0386 max mem: 78608 +[2026-05-02 02:02:54,303][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 60/4350] eta: 2 days, 8:32:35 lr: 0.000000 epoch: 0.0115 (0.0069) step: 50.0000 (29.8852) loss: 3993.8896 (5538.5084) Lcamera_frontend: 3.0130 (4.2841) Ldepth_frontend: 15.9295 (16.5670) Lpmap_frontend: 17.8362 (18.0703) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.9901 (4.2851) Ldepth_mix: 15.9719 (16.5775) Lpmap_mix: 17.8467 (18.0770) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.9972 (4.2656) Ldepth_backend: 
16.0150 (16.5863) Lpmap_backend: 17.8505 (18.0821) Ltrack_backend: 0.0000 (0.0000) total: 3993.8896 (5538.5084) time: 49.6855 data: 0.0376 max mem: 78608 +[2026-05-02 02:10:41,238][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 70/4350] eta: 2 days, 8:17:05 lr: 0.000000 epoch: 0.0138 (0.0080) step: 60.0000 (34.9014) loss: 2019.8086 (5077.0649) Lcamera_frontend: 1.3256 (3.8990) Ldepth_frontend: 16.2305 (16.6073) Lpmap_frontend: 17.6795 (18.0000) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.3377 (3.8994) Ldepth_mix: 16.2331 (16.6158) Lpmap_mix: 17.6791 (18.0054) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.3201 (3.8817) Ldepth_backend: 16.2375 (16.6228) Lpmap_backend: 17.6731 (18.0092) Ltrack_backend: 0.0000 (0.0000) total: 2019.8086 (5077.0649) time: 47.8087 data: 0.0391 max mem: 78608 +[2026-05-02 02:18:44,614][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 80/4350] eta: 2 days, 8:17:55 lr: 0.000000 epoch: 0.0161 (0.0092) step: 70.0000 (39.9136) loss: 2104.1528 (5019.4932) Lcamera_frontend: 1.4395 (3.8555) Ldepth_frontend: 15.7767 (16.4471) Lpmap_frontend: 17.4410 (17.9200) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.4340 (3.8518) Ldepth_mix: 15.7863 (16.4542) Lpmap_mix: 17.4408 (17.9243) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.4361 (3.8361) Ldepth_backend: 15.7929 (16.4599) Lpmap_backend: 17.4133 (17.9275) Ltrack_backend: 0.0000 (0.0000) total: 2104.1528 (5019.4932) time: 47.5140 data: 0.0396 max mem: 78608 +[2026-05-02 02:25:53,868][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 90/4350] eta: 2 days, 7:34:28 lr: 0.000000 epoch: 0.0184 (0.0103) step: 80.0000 (44.9011) loss: 4923.5767 (5067.0565) Lcamera_frontend: 3.8066 (3.8988) Ldepth_frontend: 14.4957 (16.2270) Lpmap_frontend: 17.3165 (17.8568) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.7645 (3.8937) Ldepth_mix: 14.4941 (16.2327) Lpmap_mix: 17.3146 (17.8605) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.7776 (3.8786) Ldepth_backend: 14.4895 (16.2375) Lpmap_backend: 17.3208 (17.8632) 
Ltrack_backend: 0.0000 (0.0000) total: 4923.5767 (5067.0565) time: 45.6226 data: 0.0348 max mem: 78608 +[2026-05-02 02:33:27,363][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 100/4350] eta: 2 days, 7:15:17 lr: 0.000000 epoch: 0.0207 (0.0115) step: 90.0000 (49.8911) loss: 4923.5767 (5125.7103) Lcamera_frontend: 3.8066 (3.9501) Ldepth_frontend: 14.0182 (16.0263) Lpmap_frontend: 17.1792 (17.7799) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.7645 (3.9445) Ldepth_mix: 14.0235 (16.0316) Lpmap_mix: 17.1757 (17.7832) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.7776 (3.9304) Ldepth_backend: 14.0175 (16.0360) Lpmap_backend: 17.1734 (17.7857) Ltrack_backend: 0.0000 (0.0000) total: 4923.5767 (5125.7103) time: 44.1277 data: 0.0344 max mem: 78608 +[2026-05-02 02:41:37,317][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 110/4350] eta: 2 days, 7:21:26 lr: 0.000001 epoch: 0.0230 (0.0126) step: 100.0000 (54.8829) loss: 4877.1230 (5084.8790) Lcamera_frontend: 3.7518 (3.9181) Ldepth_frontend: 13.4376 (15.8005) Lpmap_frontend: 16.9117 (17.6830) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.7503 (3.9128) Ldepth_mix: 13.4419 (15.8053) Lpmap_mix: 16.9061 (17.6860) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.7520 (3.8998) Ldepth_backend: 13.4440 (15.8096) Lpmap_backend: 16.8987 (17.6885) Ltrack_backend: 0.0000 (0.0000) total: 4877.1230 (5084.8790) time: 47.1700 data: 0.0380 max mem: 78608 +[2026-05-02 02:48:51,964][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 120/4350] eta: 2 days, 6:52:59 lr: 0.000001 epoch: 0.0253 (0.0138) step: 110.0000 (59.8926) loss: 4647.9844 (4950.6465) Lcamera_frontend: 3.5584 (3.8076) Ldepth_frontend: 13.1211 (15.6595) Lpmap_frontend: 16.5926 (17.5811) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.5538 (3.8025) Ldepth_mix: 13.1282 (15.6641) Lpmap_mix: 16.5907 (17.5839) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.5589 (3.7906) Ldepth_backend: 13.1469 (15.6683) Lpmap_backend: 16.5881 (17.5863) Ltrack_backend: 0.0000 (0.0000) total: 4647.9844 
(4950.6465) time: 46.2298 data: 0.0378 max mem: 78608 +[2026-05-02 02:56:33,901][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 130/4350] eta: 2 days, 6:42:26 lr: 0.000001 epoch: 0.0276 (0.0149) step: 120.0000 (64.9008) loss: 2733.9695 (4868.5874) Lcamera_frontend: 1.9781 (3.7412) Ldepth_frontend: 12.9465 (15.4629) Lpmap_frontend: 16.4048 (17.4734) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.9763 (3.7365) Ldepth_mix: 12.9361 (15.4666) Lpmap_mix: 16.4147 (17.4757) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.9768 (3.7255) Ldepth_backend: 12.9314 (15.4697) Lpmap_backend: 16.4238 (17.4775) Ltrack_backend: 0.0000 (0.0000) total: 2733.9695 (4868.5874) time: 44.8291 data: 0.0333 max mem: 78608 +[2026-05-02 03:04:23,553][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 140/4350] eta: 2 days, 6:36:08 lr: 0.000001 epoch: 0.0299 (0.0161) step: 130.0000 (69.9078) loss: 4043.9023 (4806.1195) Lcamera_frontend: 3.0814 (3.6918) Ldepth_frontend: 12.6396 (15.2376) Lpmap_frontend: 15.9093 (17.3541) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.0799 (3.6874) Ldepth_mix: 12.6302 (15.2405) Lpmap_mix: 15.9025 (17.3558) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.0790 (3.6770) Ldepth_backend: 12.6213 (15.2429) Lpmap_backend: 15.8975 (17.3572) Ltrack_backend: 0.0000 (0.0000) total: 4043.9023 (4806.1195) time: 46.5793 data: 0.0337 max mem: 78608 +[2026-05-02 03:12:20,033][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 150/4350] eta: 2 days, 6:32:45 lr: 0.000001 epoch: 0.0322 (0.0172) step: 140.0000 (74.9139) loss: 4388.3794 (4815.9643) Lcamera_frontend: 3.4041 (3.7034) Ldepth_frontend: 11.0197 (14.9571) Lpmap_frontend: 15.5645 (17.2238) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.4033 (3.6992) Ldepth_mix: 11.0160 (14.9594) Lpmap_mix: 15.5572 (17.2252) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.4037 (3.6895) Ldepth_backend: 11.0117 (14.9614) Lpmap_backend: 15.5540 (17.2263) Ltrack_backend: 0.0000 (0.0000) total: 4388.3794 (4815.9643) time: 47.3033 data: 0.0364 max mem: 
78608 +[2026-05-02 03:21:04,409][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 160/4350] eta: 2 days, 6:49:35 lr: 0.000001 epoch: 0.0345 (0.0184) step: 150.0000 (79.9130) loss: 4584.6353 (4828.3560) Lcamera_frontend: 3.5562 (3.7166) Ldepth_frontend: 10.6893 (14.7231) Lpmap_frontend: 15.2648 (17.0993) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.5570 (3.7127) Ldepth_mix: 10.6896 (14.7245) Lpmap_mix: 15.2685 (17.1002) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.5572 (3.7035) Ldepth_backend: 10.6855 (14.7257) Lpmap_backend: 15.2754 (17.1008) Ltrack_backend: 0.0000 (0.0000) total: 4584.6353 (4828.3560) time: 50.0357 data: 0.0355 max mem: 78608 +[2026-05-02 03:29:40,139][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 170/4350] eta: 2 days, 6:59:56 lr: 0.000001 epoch: 0.0368 (0.0195) step: 160.0000 (84.9064) loss: 3588.4995 (4690.3600) Lcamera_frontend: 2.7438 (3.6041) Ldepth_frontend: 10.4000 (14.5376) Lpmap_frontend: 14.9626 (16.9780) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.7423 (3.6003) Ldepth_mix: 10.3896 (14.5381) Lpmap_mix: 14.9625 (16.9782) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.7432 (3.5917) Ldepth_backend: 10.3826 (14.5386) Lpmap_backend: 14.9690 (16.9784) Ltrack_backend: 0.0000 (0.0000) total: 3588.4995 (4690.3600) time: 52.0014 data: 0.0363 max mem: 78608 +[2026-05-02 03:37:55,423][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 180/4350] eta: 2 days, 7:00:20 lr: 0.000001 epoch: 0.0391 (0.0207) step: 170.0000 (89.9006) loss: 4503.9561 (4779.2116) Lcamera_frontend: 3.5134 (3.6813) Ldepth_frontend: 9.8712 (14.2839) Lpmap_frontend: 14.9626 (16.8743) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.5094 (3.6777) Ldepth_mix: 9.8633 (14.2836) Lpmap_mix: 14.9625 (16.8741) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.5106 (3.6694) Ldepth_backend: 9.8648 (14.2832) Lpmap_backend: 14.9690 (16.8739) Ltrack_backend: 0.0000 (0.0000) total: 4503.9561 (4779.2116) time: 50.5506 data: 0.0408 max mem: 78608 +[2026-05-02 
03:46:40,065][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 190/4350] eta: 2 days, 7:10:29 lr: 0.000001 epoch: 0.0414 (0.0218) step: 180.0000 (94.8953) loss: 4562.6719 (4768.5866) Lcamera_frontend: 3.5725 (3.6756) Ldepth_frontend: 9.1872 (14.0536) Lpmap_frontend: 14.7721 (16.7478) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.5724 (3.6721) Ldepth_mix: 9.1771 (14.0527) Lpmap_mix: 14.7675 (16.7472) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.5717 (3.6642) Ldepth_backend: 9.1655 (14.0516) Lpmap_backend: 14.7637 (16.7467) Ltrack_backend: 0.0000 (0.0000) total: 4562.6719 (4768.5866) time: 50.9961 data: 0.0451 max mem: 78608 +[2026-05-02 03:55:24,681][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 200/4350] eta: 2 days, 7:18:45 lr: 0.000001 epoch: 0.0437 (0.0230) step: 190.0000 (99.8905) loss: 4451.8721 (4785.7538) Lcamera_frontend: 3.4677 (3.6933) Ldepth_frontend: 8.8989 (13.7981) Lpmap_frontend: 14.2427 (16.6206) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.4711 (3.6899) Ldepth_mix: 8.8772 (13.7967) Lpmap_mix: 14.2358 (16.6197) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.4657 (3.6825) Ldepth_backend: 8.8569 (13.7951) Lpmap_backend: 14.2313 (16.6190) Ltrack_backend: 0.0000 (0.0000) total: 4451.8721 (4785.7538) time: 52.4627 data: 0.0424 max mem: 78608 +[2026-05-02 04:03:17,038][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 210/4350] eta: 2 days, 7:08:18 lr: 0.000001 epoch: 0.0460 (0.0241) step: 200.0000 (104.8863) loss: 4841.5708 (4780.2808) Lcamera_frontend: 3.8159 (3.6921) Ldepth_frontend: 8.1239 (13.5487) Lpmap_frontend: 14.1029 (16.4993) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.8157 (3.6888) Ldepth_mix: 8.1218 (13.5469) Lpmap_mix: 14.0988 (16.4980) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.8157 (3.6817) Ldepth_backend: 8.1202 (13.5448) Lpmap_backend: 14.1007 (16.4971) Ltrack_backend: 0.0000 (0.0000) total: 4841.5708 (4780.2808) time: 49.8470 data: 0.0386 max mem: 78608 +[2026-05-02 04:11:36,536][croco.utils.misc][INFO] - [RANK 0] Epoch: 
[0] [ 220/4350] eta: 2 days, 7:06:31 lr: 0.000001 epoch: 0.0483 (0.0253) step: 210.0000 (109.8824) loss: 4116.9346 (4724.9302) Lcamera_frontend: 3.2286 (3.6486) Ldepth_frontend: 7.7186 (13.3494) Lpmap_frontend: 13.9703 (16.3888) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.2278 (3.6455) Ldepth_mix: 7.7106 (13.3470) Lpmap_mix: 13.9648 (16.3871) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.2290 (3.6387) Ldepth_backend: 7.6947 (13.3442) Lpmap_backend: 13.9617 (16.3859) Ltrack_backend: 0.0000 (0.0000) total: 4116.9346 (4724.9302) time: 48.5874 data: 0.0391 max mem: 78608 +[2026-05-02 04:19:14,374][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 230/4350] eta: 2 days, 6:51:49 lr: 0.000001 epoch: 0.0506 (0.0264) step: 220.0000 (114.8874) loss: 3592.6655 (4753.4447) Lcamera_frontend: 2.7775 (3.6758) Ldepth_frontend: 7.3422 (13.1020) Lpmap_frontend: 13.5992 (16.2612) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.7765 (3.6728) Ldepth_mix: 7.3393 (13.0991) Lpmap_mix: 13.5960 (16.2592) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.7781 (3.6663) Ldepth_backend: 7.3332 (13.0959) Lpmap_backend: 13.5926 (16.2578) Ltrack_backend: 0.0000 (0.0000) total: 3592.6655 (4753.4447) time: 47.8630 data: 0.0361 max mem: 78608 +[2026-05-02 04:27:23,632][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 240/4350] eta: 2 days, 6:46:38 lr: 0.000001 epoch: 0.0529 (0.0276) step: 230.0000 (119.8921) loss: 4193.9175 (4729.9065) Lcamera_frontend: 3.2969 (3.6589) Ldepth_frontend: 7.8148 (12.9080) Lpmap_frontend: 13.3712 (16.1506) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.2965 (3.6559) Ldepth_mix: 7.8154 (12.9046) Lpmap_mix: 13.3681 (16.1484) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.2980 (3.6498) Ldepth_backend: 7.8168 (12.9010) Lpmap_backend: 13.3686 (16.1467) Ltrack_backend: 0.0000 (0.0000) total: 4193.9175 (4729.9065) time: 47.3547 data: 0.0370 max mem: 78608 +[2026-05-02 04:34:53,206][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 250/4350] eta: 2 days, 6:30:24 lr: 0.000001 
epoch: 0.0552 (0.0287) step: 240.0000 (124.8964) loss: 2745.3713 (4679.1104) Lcamera_frontend: 2.0429 (3.6190) Ldepth_frontend: 7.9078 (12.7415) Lpmap_frontend: 13.2886 (16.0429) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.0439 (3.6161) Ldepth_mix: 7.8983 (12.7376) Lpmap_mix: 13.2856 (16.0403) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.0335 (3.6103) Ldepth_backend: 7.8948 (12.7336) Lpmap_backend: 13.2888 (16.0384) Ltrack_backend: 0.0000 (0.0000) total: 2745.3713 (4679.1104) time: 46.9415 data: 0.0391 max mem: 78608 +[2026-05-02 04:42:30,433][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 260/4350] eta: 2 days, 6:16:51 lr: 0.000001 epoch: 0.0575 (0.0299) step: 250.0000 (129.9004) loss: 3481.0742 (4684.3168) Lcamera_frontend: 2.7115 (3.6267) Ldepth_frontend: 6.5131 (12.4980) Lpmap_frontend: 12.9252 (15.9200) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.7106 (3.6239) Ldepth_mix: 6.5117 (12.4940) Lpmap_mix: 12.9262 (15.9173) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.7102 (3.6184) Ldepth_backend: 6.5104 (12.4900) Lpmap_backend: 12.9302 (15.9155) Ltrack_backend: 0.0000 (0.0000) total: 3481.0742 (4684.3168) time: 45.3399 data: 0.0383 max mem: 78608 +[2026-05-02 04:50:35,517][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 270/4350] eta: 2 days, 6:10:41 lr: 0.000001 epoch: 0.0598 (0.0310) step: 260.0000 (134.9041) loss: 4508.8149 (4655.1027) Lcamera_frontend: 3.5818 (3.6054) Ldepth_frontend: 5.7420 (12.2852) Lpmap_frontend: 12.7981 (15.8062) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.5797 (3.6027) Ldepth_mix: 5.7361 (12.2812) Lpmap_mix: 12.7930 (15.8035) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.5820 (3.5973) Ldepth_backend: 5.7358 (12.2772) Lpmap_backend: 12.7943 (15.8018) Ltrack_backend: 0.0000 (0.0000) total: 4508.8149 (4655.1027) time: 47.1099 data: 0.0365 max mem: 78608 +[2026-05-02 04:57:25,163][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 280/4350] eta: 2 days, 5:46:11 lr: 0.000001 epoch: 0.0621 (0.0322) step: 270.0000 (139.9075) loss: 
4191.4390 (4647.1543) Lcamera_frontend: 3.2958 (3.6021) Ldepth_frontend: 6.0389 (12.0579) Lpmap_frontend: 12.7756 (15.6775) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.2970 (3.5994) Ldepth_mix: 6.0483 (12.0538) Lpmap_mix: 12.7697 (15.6747) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.2946 (3.5943) Ldepth_backend: 6.0530 (12.0495) Lpmap_backend: 12.7654 (15.6729) Ltrack_backend: 0.0000 (0.0000) total: 4191.4390 (4647.1543) time: 44.7235 data: 0.0370 max mem: 78608 +[2026-05-02 05:05:15,640][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 290/4350] eta: 2 days, 5:37:03 lr: 0.000001 epoch: 0.0644 (0.0333) step: 280.0000 (144.9107) loss: 4670.9741 (4670.9792) Lcamera_frontend: 3.7119 (3.6246) Ldepth_frontend: 5.9312 (11.8789) Lpmap_frontend: 12.4036 (15.5709) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.7154 (3.6219) Ldepth_mix: 5.9225 (11.8747) Lpmap_mix: 12.3986 (15.5680) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.7163 (3.6171) Ldepth_backend: 5.9176 (11.8703) Lpmap_backend: 12.4043 (15.5663) Ltrack_backend: 0.0000 (0.0000) total: 4670.9741 (4670.9792) time: 43.9986 data: 0.0408 max mem: 78608 +[2026-05-02 05:13:02,642][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 300/4350] eta: 2 days, 5:27:14 lr: 0.000001 epoch: 0.0667 (0.0345) step: 290.0000 (149.9136) loss: 3980.8718 (4625.0386) Lcamera_frontend: 3.1443 (3.5890) Ldepth_frontend: 5.6045 (11.7002) Lpmap_frontend: 12.3428 (15.4632) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.1428 (3.5863) Ldepth_mix: 5.6022 (11.6960) Lpmap_mix: 12.3420 (15.4603) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.1448 (3.5817) Ldepth_backend: 5.5965 (11.6915) Lpmap_backend: 12.3449 (15.4586) Ltrack_backend: 0.0000 (0.0000) total: 3980.8718 (4625.0386) time: 46.8735 data: 0.0404 max mem: 78608 +[2026-05-02 05:21:19,142][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 310/4350] eta: 2 days, 5:23:57 lr: 0.000001 epoch: 0.0690 (0.0356) step: 300.0000 (154.9164) loss: 3719.2571 (4613.9610) Lcamera_frontend: 2.9313 (3.5826) 
Ldepth_frontend: 5.3743 (11.5036) Lpmap_frontend: 12.1202 (15.3530) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.9310 (3.5800) Ldepth_mix: 5.3739 (11.4993) Lpmap_mix: 12.1138 (15.3500) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.9313 (3.5756) Ldepth_backend: 5.3734 (11.4949) Lpmap_backend: 12.1080 (15.3483) Ltrack_backend: 0.0000 (0.0000) total: 3719.2571 (4613.9610) time: 48.1747 data: 0.0357 max mem: 78608 +[2026-05-02 05:29:31,705][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 320/4350] eta: 2 days, 5:19:31 lr: 0.000001 epoch: 0.0713 (0.0368) step: 310.0000 (159.9159) loss: 4960.5903 (4596.7354) Lcamera_frontend: 3.9675 (3.5705) Ldepth_frontend: 4.8698 (11.3546) Lpmap_frontend: 12.2881 (15.2610) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.9699 (3.5679) Ldepth_mix: 4.8641 (11.3503) Lpmap_mix: 12.2837 (15.2579) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.9682 (3.5636) Ldepth_backend: 4.8602 (11.3458) Lpmap_backend: 12.2865 (15.2562) Ltrack_backend: 0.0000 (0.0000) total: 4960.5903 (4596.7354) time: 49.4530 data: 0.0506 max mem: 78608 +[2026-05-02 05:37:15,018][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 330/4350] eta: 2 days, 5:08:55 lr: 0.000002 epoch: 0.0736 (0.0379) step: 320.0000 (164.9154) loss: 5038.3403 (4647.2778) Lcamera_frontend: 4.0410 (3.6152) Ldepth_frontend: 4.6207 (11.1695) Lpmap_frontend: 12.4004 (15.1692) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 4.0408 (3.6127) Ldepth_mix: 4.5988 (11.1652) Lpmap_mix: 12.3953 (15.1661) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 4.0401 (3.6086) Ldepth_backend: 4.5872 (11.1607) Lpmap_backend: 12.3927 (15.1645) Ltrack_backend: 0.0000 (0.0000) total: 5038.3403 (4647.2778) time: 47.7890 data: 0.0505 max mem: 78608 +[2026-05-02 05:45:09,610][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 340/4350] eta: 2 days, 5:00:42 lr: 0.000002 epoch: 0.0759 (0.0391) step: 330.0000 (169.9150) loss: 6248.1436 (4689.8981) Lcamera_frontend: 5.0322 (3.6532) Ldepth_frontend: 4.6842 (10.9979) Lpmap_frontend: 
11.9383 (15.0774) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 5.0228 (3.6507) Ldepth_mix: 4.6807 (10.9936) Lpmap_mix: 11.9347 (15.0743) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 5.0310 (3.6467) Ldepth_backend: 4.6768 (10.9891) Lpmap_backend: 11.9321 (15.0727) Ltrack_backend: 0.0000 (0.0000) total: 6248.1436 (4689.8981) time: 46.8883 data: 0.0330 max mem: 78608 +[2026-05-02 05:52:53,173][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 350/4350] eta: 2 days, 4:50:26 lr: 0.000002 epoch: 0.0782 (0.0402) step: 340.0000 (174.9145) loss: 4489.6079 (4672.2387) Lcamera_frontend: 3.5861 (3.6412) Ldepth_frontend: 4.6842 (10.8194) Lpmap_frontend: 11.7572 (14.9707) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.5837 (3.6386) Ldepth_mix: 4.6807 (10.8150) Lpmap_mix: 11.7550 (14.9675) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.5867 (3.6349) Ldepth_backend: 4.6768 (10.8105) Lpmap_backend: 11.7575 (14.9660) Ltrack_backend: 0.0000 (0.0000) total: 4489.6079 (4672.2387) time: 46.9055 data: 0.0399 max mem: 78608 +[2026-05-02 06:00:42,547][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 360/4350] eta: 2 days, 4:41:22 lr: 0.000002 epoch: 0.0805 (0.0414) step: 350.0000 (179.9141) loss: 4359.2837 (4677.3988) Lcamera_frontend: 3.4775 (3.6479) Ldepth_frontend: 4.4092 (10.6521) Lpmap_frontend: 11.6426 (14.8854) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.4747 (3.6453) Ldepth_mix: 4.3984 (10.6478) Lpmap_mix: 11.6373 (14.8821) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.4771 (3.6418) Ldepth_backend: 4.3781 (10.6432) Lpmap_backend: 11.6349 (14.8806) Ltrack_backend: 0.0000 (0.0000) total: 4359.2837 (4677.3988) time: 46.6467 data: 0.0455 max mem: 78608 +[2026-05-02 06:08:24,937][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 370/4350] eta: 2 days, 4:31:07 lr: 0.000002 epoch: 0.0828 (0.0425) step: 360.0000 (184.9137) loss: 4740.8579 (4667.8413) Lcamera_frontend: 3.7744 (3.6421) Ldepth_frontend: 4.5237 (10.5071) Lpmap_frontend: 12.0848 (14.8030) Ltrack_frontend: 0.0000 (0.0000) 
Lcamera_mix: 3.7638 (3.6394) Ldepth_mix: 4.5182 (10.5028) Lpmap_mix: 12.0942 (14.7997) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.7747 (3.6361) Ldepth_backend: 4.5145 (10.4983) Lpmap_backend: 12.0973 (14.7983) Ltrack_backend: 0.0000 (0.0000) total: 4740.8579 (4667.8413) time: 46.5881 data: 0.0395 max mem: 78608 +[2026-05-02 06:16:52,155][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 380/4350] eta: 2 days, 4:28:47 lr: 0.000002 epoch: 0.0851 (0.0437) step: 370.0000 (189.9134) loss: 4655.8022 (4642.4782) Lcamera_frontend: 3.7364 (3.6232) Ldepth_frontend: 4.7282 (10.3629) Lpmap_frontend: 11.7485 (14.7112) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.7308 (3.6206) Ldepth_mix: 4.7133 (10.3585) Lpmap_mix: 11.7448 (14.7079) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.7364 (3.6174) Ldepth_backend: 4.7069 (10.3540) Lpmap_backend: 11.7383 (14.7065) Ltrack_backend: 0.0000 (0.0000) total: 4655.8022 (4642.4782) time: 48.4802 data: 0.0391 max mem: 78608 +[2026-05-02 06:25:02,362][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 390/4350] eta: 2 days, 4:23:15 lr: 0.000002 epoch: 0.0874 (0.0448) step: 380.0000 (194.9130) loss: 4655.8022 (4656.7480) Lcamera_frontend: 3.7364 (3.6373) Ldepth_frontend: 3.9775 (10.2164) Lpmap_frontend: 11.1651 (14.6171) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.7308 (3.6347) Ldepth_mix: 3.9693 (10.2120) Lpmap_mix: 11.1612 (14.6137) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.7364 (3.6317) Ldepth_backend: 3.9635 (10.2076) Lpmap_backend: 11.1563 (14.6123) Ltrack_backend: 0.0000 (0.0000) total: 4655.8022 (4656.7480) time: 49.8683 data: 0.0426 max mem: 78608 +[2026-05-02 06:32:40,740][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 400/4350] eta: 2 days, 4:12:23 lr: 0.000002 epoch: 0.0897 (0.0460) step: 390.0000 (199.9127) loss: 4428.9600 (4618.8660) Lcamera_frontend: 3.5498 (3.6079) Ldepth_frontend: 4.1744 (10.0852) Lpmap_frontend: 10.8094 (14.5211) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.5452 (3.6053) Ldepth_mix: 4.1709 
(10.0809) Lpmap_mix: 10.8027 (14.5177) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.5496 (3.6024) Ldepth_backend: 4.1680 (10.0765) Lpmap_backend: 10.8033 (14.5163) Ltrack_backend: 0.0000 (0.0000) total: 4428.9600 (4618.8660) time: 47.4239 data: 0.0470 max mem: 78608 +[2026-05-02 06:40:29,620][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 410/4350] eta: 2 days, 4:03:20 lr: 0.000002 epoch: 0.0920 (0.0471) step: 400.0000 (204.9124) loss: 4394.9678 (4603.6969) Lcamera_frontend: 3.5124 (3.5973) Ldepth_frontend: 4.0765 (9.9538) Lpmap_frontend: 11.1108 (14.4389) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.5064 (3.5947) Ldepth_mix: 4.0689 (9.9496) Lpmap_mix: 11.1036 (14.4355) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.5126 (3.5919) Ldepth_backend: 4.0628 (9.9452) Lpmap_backend: 11.1045 (14.4341) Ltrack_backend: 0.0000 (0.0000) total: 4394.9678 (4603.6969) time: 46.3576 data: 0.0470 max mem: 78608 +[2026-05-02 06:48:07,254][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 420/4350] eta: 2 days, 3:52:36 lr: 0.000002 epoch: 0.0943 (0.0483) step: 410.0000 (209.9121) loss: 4223.4790 (4568.6497) Lcamera_frontend: 3.3775 (3.5699) Ldepth_frontend: 4.1790 (9.8462) Lpmap_frontend: 11.1108 (14.3521) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.3668 (3.5672) Ldepth_mix: 4.1751 (9.8419) Lpmap_mix: 11.1036 (14.3485) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.3774 (3.5647) Ldepth_backend: 4.1693 (9.8376) Lpmap_backend: 11.1045 (14.3472) Ltrack_backend: 0.0000 (0.0000) total: 4223.4790 (4568.6497) time: 46.3227 data: 0.0384 max mem: 78608 +[2026-05-02 06:56:01,669][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 430/4350] eta: 2 days, 3:44:34 lr: 0.000002 epoch: 0.0966 (0.0494) step: 420.0000 (214.9118) loss: 3679.8608 (4566.3171) Lcamera_frontend: 2.9084 (3.5698) Ldepth_frontend: 4.7060 (9.7321) Lpmap_frontend: 10.7529 (14.2725) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.9113 (3.5671) Ldepth_mix: 4.7047 (9.7278) Lpmap_mix: 10.7470 (14.2689) Ltrack_mix: 0.0000 
(0.0000) Lcamera_backend: 2.9068 (3.5647) Ldepth_backend: 4.7049 (9.7235) Lpmap_backend: 10.7480 (14.2676) Ltrack_backend: 0.0000 (0.0000) total: 3679.8608 (4566.3171) time: 46.6023 data: 0.0393 max mem: 78608 +[2026-05-02 07:03:53,559][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 440/4350] eta: 2 days, 3:36:09 lr: 0.000002 epoch: 0.0989 (0.0506) step: 430.0000 (219.9116) loss: 3679.8608 (4528.5202) Lcamera_frontend: 2.9084 (3.5395) Ldepth_frontend: 5.7144 (9.6702) Lpmap_frontend: 11.5739 (14.2122) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.9113 (3.5367) Ldepth_mix: 5.7130 (9.6660) Lpmap_mix: 11.5758 (14.2085) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.9068 (3.5345) Ldepth_backend: 5.7153 (9.6619) Lpmap_backend: 11.5766 (14.2071) Ltrack_backend: 0.0000 (0.0000) total: 3679.8608 (4528.5202) time: 47.3151 data: 0.0427 max mem: 78608 +[2026-05-02 07:12:09,673][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 450/4350] eta: 2 days, 3:31:15 lr: 0.000002 epoch: 0.1011 (0.0517) step: 440.0000 (224.9135) loss: 4102.1382 (4525.2955) Lcamera_frontend: 3.2252 (3.5385) Ldepth_frontend: 4.7831 (9.5636) Lpmap_frontend: 11.2134 (14.1356) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.2183 (3.5357) Ldepth_mix: 4.7791 (9.5595) Lpmap_mix: 11.2071 (14.1318) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.2259 (3.5336) Ldepth_backend: 4.7714 (9.5553) Lpmap_backend: 11.2107 (14.1304) Ltrack_backend: 0.0000 (0.0000) total: 4102.1382 (4525.2955) time: 48.3975 data: 0.0428 max mem: 78608 +[2026-05-02 07:20:48,452][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 460/4350] eta: 2 days, 3:29:24 lr: 0.000002 epoch: 0.1034 (0.0529) step: 450.0000 (229.9154) loss: 4274.7393 (4506.7233) Lcamera_frontend: 3.4219 (3.5246) Ldepth_frontend: 4.4790 (9.4752) Lpmap_frontend: 10.6576 (14.0635) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.4149 (3.5217) Ldepth_mix: 4.4725 (9.4711) Lpmap_mix: 10.6511 (14.0597) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.4215 (3.5198) Ldepth_backend: 4.4745 
(9.4671) Lpmap_backend: 10.6527 (14.0583) Ltrack_backend: 0.0000 (0.0000) total: 4274.7393 (4506.7233) time: 50.7390 data: 0.0426 max mem: 78608 +[2026-05-02 07:29:09,762][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 470/4350] eta: 2 days, 3:24:52 lr: 0.000002 epoch: 0.1057 (0.0540) step: 460.0000 (234.9172) loss: 2672.1321 (4487.0258) Lcamera_frontend: 2.0881 (3.5096) Ldepth_frontend: 4.9512 (9.3907) Lpmap_frontend: 10.9740 (13.9991) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.0739 (3.5066) Ldepth_mix: 4.9512 (9.3868) Lpmap_mix: 10.9654 (13.9953) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.0877 (3.5049) Ldepth_backend: 4.9571 (9.3828) Lpmap_backend: 10.9658 (13.9939) Ltrack_backend: 0.0000 (0.0000) total: 2672.1321 (4487.0258) time: 51.0014 data: 0.0387 max mem: 78608 +[2026-05-02 07:36:18,752][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 480/4350] eta: 2 days, 3:10:28 lr: 0.000002 epoch: 0.1080 (0.0552) step: 470.0000 (239.9189) loss: 2672.1321 (4457.7627) Lcamera_frontend: 2.0881 (3.4867) Ldepth_frontend: 4.8497 (9.3036) Lpmap_frontend: 10.8684 (13.9270) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.0739 (3.4837) Ldepth_mix: 4.8514 (9.2997) Lpmap_mix: 10.8650 (13.9231) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.0877 (3.4821) Ldepth_backend: 4.8518 (9.2958) Lpmap_backend: 10.8630 (13.9218) Ltrack_backend: 0.0000 (0.0000) total: 2672.1321 (4457.7627) time: 46.5149 data: 0.0371 max mem: 78608 +[2026-05-02 07:44:16,908][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 490/4350] eta: 2 days, 3:02:48 lr: 0.000002 epoch: 0.1103 (0.0563) step: 480.0000 (244.9206) loss: 4570.7773 (4492.1682) Lcamera_frontend: 3.6608 (3.5170) Ldepth_frontend: 4.1658 (9.2019) Lpmap_frontend: 10.6379 (13.8618) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.6551 (3.5140) Ldepth_mix: 4.1629 (9.1981) Lpmap_mix: 10.6353 (13.8579) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.6603 (3.5124) Ldepth_backend: 4.1498 (9.1941) Lpmap_backend: 10.6326 (13.8566) Ltrack_backend: 0.0000 
(0.0000) total: 4570.7773 (4492.1682) time: 45.3572 data: 0.0400 max mem: 78608 +[2026-05-02 07:51:59,484][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 500/4350] eta: 2 days, 2:53:09 lr: 0.000002 epoch: 0.1126 (0.0575) step: 490.0000 (249.9222) loss: 5368.7139 (4490.8729) Lcamera_frontend: 4.3299 (3.5171) Ldepth_frontend: 3.9676 (9.1234) Lpmap_frontend: 10.9663 (13.8088) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 4.3251 (3.5141) Ldepth_mix: 3.9632 (9.1197) Lpmap_mix: 10.9597 (13.8048) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 4.3292 (3.5127) Ldepth_backend: 3.9585 (9.1159) Lpmap_backend: 10.9590 (13.8035) Ltrack_backend: 0.0000 (0.0000) total: 5368.7139 (4490.8729) time: 47.0364 data: 0.0392 max mem: 78608 +[2026-05-02 07:59:26,021][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 510/4350] eta: 2 days, 2:41:33 lr: 0.000002 epoch: 0.1149 (0.0586) step: 500.0000 (254.9237) loss: 5368.7139 (4521.7455) Lcamera_frontend: 4.3299 (3.5443) Ldepth_frontend: 4.2202 (9.0275) Lpmap_frontend: 11.0006 (13.7536) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 4.3251 (3.5413) Ldepth_mix: 4.2150 (9.0237) Lpmap_mix: 10.9940 (13.7495) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 4.3292 (3.5400) Ldepth_backend: 4.2075 (9.0198) Lpmap_backend: 10.9915 (13.7482) Ltrack_backend: 0.0000 (0.0000) total: 5368.7139 (4521.7455) time: 45.4547 data: 0.0358 max mem: 78608 +[2026-05-02 08:06:46,825][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 520/4350] eta: 2 days, 2:29:24 lr: 0.000002 epoch: 0.1172 (0.0598) step: 510.0000 (259.9251) loss: 5378.4395 (4514.5812) Lcamera_frontend: 4.3361 (3.5396) Ldepth_frontend: 4.3307 (8.9524) Lpmap_frontend: 10.8269 (13.6930) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 4.3329 (3.5365) Ldepth_mix: 4.3201 (8.9487) Lpmap_mix: 10.8119 (13.6889) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 4.3365 (3.5354) Ldepth_backend: 4.3105 (8.9448) Lpmap_backend: 10.8128 (13.6876) Ltrack_backend: 0.0000 (0.0000) total: 5378.4395 (4514.5812) time: 44.3650 data: 0.0366 
max mem: 78608 +[2026-05-02 08:14:57,910][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 530/4350] eta: 2 days, 2:23:28 lr: 0.000002 epoch: 0.1195 (0.0609) step: 520.0000 (264.9266) loss: 4904.3882 (4528.3333) Lcamera_frontend: 3.9475 (3.5524) Ldepth_frontend: 3.8068 (8.8720) Lpmap_frontend: 10.8597 (13.6391) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.9312 (3.5493) Ldepth_mix: 3.7953 (8.8683) Lpmap_mix: 10.8571 (13.6348) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.9461 (3.5482) Ldepth_backend: 3.7868 (8.8644) Lpmap_backend: 10.8528 (13.6336) Ltrack_backend: 0.0000 (0.0000) total: 4904.3882 (4528.3333) time: 46.5908 data: 0.0396 max mem: 78608 +[2026-05-02 08:22:51,469][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 540/4350] eta: 2 days, 2:15:23 lr: 0.000002 epoch: 0.1218 (0.0621) step: 530.0000 (269.9279) loss: 4816.3633 (4526.7610) Lcamera_frontend: 3.8457 (3.5523) Ldepth_frontend: 3.9130 (8.7945) Lpmap_frontend: 11.0622 (13.5900) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.8394 (3.5491) Ldepth_mix: 3.9054 (8.7908) Lpmap_mix: 11.0505 (13.5857) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.8518 (3.5481) Ldepth_backend: 3.8939 (8.7869) Lpmap_backend: 11.0470 (13.5845) Ltrack_backend: 0.0000 (0.0000) total: 4816.3633 (4526.7610) time: 48.2297 data: 0.0382 max mem: 78608 +[2026-05-02 08:30:49,267][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 550/4350] eta: 2 days, 2:07:49 lr: 0.000003 epoch: 0.1241 (0.0632) step: 540.0000 (274.9292) loss: 4476.8623 (4529.0379) Lcamera_frontend: 3.5796 (3.5555) Ldepth_frontend: 3.9130 (8.7119) Lpmap_frontend: 10.7210 (13.5338) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.5713 (3.5523) Ldepth_mix: 3.9054 (8.7082) Lpmap_mix: 10.7124 (13.5294) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.5792 (3.5514) Ldepth_backend: 3.8939 (8.7043) Lpmap_backend: 10.7145 (13.5282) Ltrack_backend: 0.0000 (0.0000) total: 4476.8623 (4529.0379) time: 47.5677 data: 0.0388 max mem: 78608 +[2026-05-02 
08:38:50,704][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 560/4350] eta: 2 days, 2:00:38 lr: 0.000003 epoch: 0.1264 (0.0644) step: 550.0000 (279.9305) loss: 4251.3789 (4501.8727) Lcamera_frontend: 3.4037 (3.5340) Ldepth_frontend: 3.7894 (8.6479) Lpmap_frontend: 10.5072 (13.4768) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 3.3982 (3.5307) Ldepth_mix: 3.7742 (8.6444) Lpmap_mix: 10.4941 (13.4723) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 3.4028 (3.5300) Ldepth_backend: 3.7652 (8.6406) Lpmap_backend: 10.4951 (13.4712) Ltrack_backend: 0.0000 (0.0000) total: 4251.3789 (4501.8727) time: 47.9616 data: 0.0424 max mem: 78608 +[2026-05-02 08:46:58,046][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 570/4350] eta: 2 days, 1:54:04 lr: 0.000003 epoch: 0.1287 (0.0655) step: 560.0000 (284.9317) loss: 2446.9312 (4483.3973) Lcamera_frontend: 1.8582 (3.5196) Ldepth_frontend: 5.2431 (8.5927) Lpmap_frontend: 10.4526 (13.4309) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 1.8515 (3.5163) Ldepth_mix: 5.2475 (8.5892) Lpmap_mix: 10.4325 (13.4264) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 1.8560 (3.5157) Ldepth_backend: 5.2615 (8.5855) Lpmap_backend: 10.4244 (13.4252) Ltrack_backend: 0.0000 (0.0000) total: 2446.9312 (4483.3973) time: 48.4378 data: 0.0402 max mem: 78608 +[2026-05-02 08:54:51,066][croco.utils.misc][INFO] - [RANK 0] Epoch: [0] [ 580/4350] eta: 2 days, 1:45:54 lr: 0.000003 epoch: 0.1310 (0.0667) step: 570.0000 (289.9329) loss: 3681.2949 (4484.2430) Lcamera_frontend: 2.8750 (3.5214) Ldepth_frontend: 4.5045 (8.5301) Lpmap_frontend: 10.4526 (13.3799) Ltrack_frontend: 0.0000 (0.0000) Lcamera_mix: 2.8603 (3.5180) Ldepth_mix: 4.5061 (8.5266) Lpmap_mix: 10.4325 (13.3752) Ltrack_mix: 0.0000 (0.0000) Lcamera_backend: 2.8769 (3.5175) Ldepth_backend: 4.5067 (8.5230) Lpmap_backend: 10.4244 (13.3741) Ltrack_backend: 0.0000 (0.0000) total: 3681.2949 (4484.2430) time: 48.0132 data: 0.0375 max mem: 78608 diff --git a/outdoor_v48_4gpu_v2/.hydra/config.yaml 
b/outdoor_v48_4gpu_v2/.hydra/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1d0020abb27a135318f331690c1d3264102eb430 --- /dev/null +++ b/outdoor_v48_4gpu_v2/.hydra/config.yaml @@ -0,0 +1,68 @@ +teacher: /gpfs/work2/0/prjs0824/qi_proj/ckpt/checkpoint-10.pth.model +pretrained: /gpfs/work2/0/prjs0824/qi_proj/ckpt/checkpoint-10.pth.model +load_only_encoder: false +long_context: false +fixed_length: true +resume: null +benchmark: false +num_views: 64 +num_test_views: 4 +n_corres_train: 0 +n_corres_test: 0 +train_criterion: DistillLoss() +test_criterion: DistillLoss() +allow_repeat: false +root_vkitti2: /scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti +root_kitti: /scratch-shared/wwei2/eval/kitti_odometry/dataset +root_kitti_velo: /gpfs/work2/0/prjs0824/semantickitti/dataset +root_kitti360: /scratch-shared/wwei2/downloads/kitti360/KITTI-360 +root_kitti360_velo: /scratch-shared/wwei2/downloads/kitti360/KITTI-360 +root_waymo: /scratch-shared/wwei2/waymo_v2 +root_waymo_lidar: /scratch-shared/wwei2/waymo_v2 +dataset_vkitti2: VirtualKITTI2_Multi(allow_repeat=${allow_repeat}, split='train', + ROOT="${root_vkitti2}", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294), + (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=${num_views}, + n_corres=${n_corres_train}) +dataset_kitti360: KITTI360_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_kitti360}", + velodyne_root="${root_kitti360_velo}", aug_crop=16, resolution=[(518, 392), (518, + 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, + num_views=${num_views}, n_corres=${n_corres_train}) +dataset_waymo: Waymo_v2_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_waymo}", + lidar_root="${root_waymo_lidar}", aug_crop=16, resolution=[(518, 392), (518, 336), + (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=${num_views}, + 
n_corres=${n_corres_train}) +train_dataset: 6000 @ ${dataset_vkitti2} + 6000 @ ${dataset_kitti360} + 5400 @ ${dataset_waymo} +test_dataset: 200 @ VirtualKITTI2_Multi(split='train', ROOT="${root_vkitti2}", resolution=(518, + 154), num_views=${num_test_views}, seed=42, n_corres=${n_corres_test}) +seed: 0 +batch_size: 1 +accum_iter: 1 +gradient_checkpointing: false +epochs: 10 +start_epoch: 0 +start_step: 0 +weight_decay: 0.05 +lr: 1.0e-05 +min_lr: 1.0e-08 +warmup_epochs: 0.5 +amp: 1 +num_workers: 4 +world_size: 1 +local-rank: -1 +dist_url: env:// +rank: 0 +gpu: 0 +distributed: false +dist_backend: nccl +eval_freq: 1 +save_freq: 0.1 +max_checkpoints: 10 +keep_freq: 1 +print_freq: 10 +print_img_freq: 50000000 +num_imgs_vis: 4 +save_dir: /scratch-shared/wwei2/training_upstream/checkpoints +exp_name: outdoor_v48_4gpu_v2 +task: StreamVGGT +logdir: ${save_dir}/${exp_name}/logs +output_dir: ${save_dir}/${exp_name}/ diff --git a/outdoor_v48_4gpu_v2/.hydra/hydra.yaml b/outdoor_v48_4gpu_v2/.hydra/hydra.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e36a39209469db6ef288f0a76e8726b9cbb7c11b --- /dev/null +++ b/outdoor_v48_4gpu_v2/.hydra/hydra.yaml @@ -0,0 +1,155 @@ +hydra: + run: + dir: ${save_dir}/${exp_name} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.num} + launcher: + _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher + sweeper: + _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper + max_batch_size: null + params: null + help: + app_name: ${hydra.job.name} + header: '${hydra.help.app_name} is powered by Hydra. 
+ + ' + footer: 'Powered by Hydra (https://hydra.cc) + + Use --hydra-help to view Hydra specific help + + ' + template: '${hydra.help.header} + + == Configuration groups == + + Compose your configuration from those groups (group=option) + + + $APP_CONFIG_GROUPS + + + == Config == + + Override anything in the config (foo.bar=value) + + + $CONFIG + + + ${hydra.help.footer} + + ' + hydra_help: + template: 'Hydra (${hydra.runtime.version}) + + See https://hydra.cc for more info. + + + == Flags == + + $FLAGS_HELP + + + == Configuration groups == + + Compose your configuration from those groups (For example, append hydra/job_logging=disabled + to command line) + + + $HYDRA_CONFIG_GROUPS + + + Use ''--cfg hydra'' to Show the Hydra config. + + ' + hydra_help: ??? + hydra_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][HYDRA] %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + root: + level: INFO + handlers: + - console + loggers: + logging_example: + level: DEBUG + disable_existing_loggers: false + job_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + file: + class: logging.FileHandler + formatter: simple + filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log + root: + level: INFO + handlers: + - console + - file + disable_existing_loggers: false + env: {} + mode: RUN + searchpath: [] + callbacks: {} + output_subdir: .hydra + overrides: + hydra: + - hydra.mode=RUN + task: + - exp_name=outdoor_v48_4gpu_v2 + job: + name: mytrain + chdir: null + override_dirname: exp_name=outdoor_v48_4gpu_v2 + id: ??? + num: ??? 
+ config_name: outdoor_v48 + env_set: {} + env_copy: [] + config: + override_dirname: + kv_sep: '=' + item_sep: ',' + exclude_keys: [] + runtime: + version: 1.3.2 + version_base: '1.3' + cwd: /gpfs/work2/0/prjs0824/qi_proj/slamformer_upstream/src + config_sources: + - path: hydra.conf + schema: pkg + provider: hydra + - path: /gpfs/work2/0/prjs0824/qi_proj/slamformer_upstream/config + schema: file + provider: main + - path: '' + schema: structured + provider: schema + output_dir: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_4gpu_v2 + choices: + hydra/env: default + hydra/callbacks: null + hydra/job_logging: default + hydra/hydra_logging: default + hydra/hydra_help: default + hydra/help: default + hydra/sweeper: basic + hydra/launcher: basic + hydra/output: default + verbose: true diff --git a/outdoor_v48_4gpu_v2/.hydra/overrides.yaml b/outdoor_v48_4gpu_v2/.hydra/overrides.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f5672c5a1c47b238c798b03683844c953dc9c6fc --- /dev/null +++ b/outdoor_v48_4gpu_v2/.hydra/overrides.yaml @@ -0,0 +1 @@ +- exp_name=outdoor_v48_4gpu_v2 diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/__init__.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7efad5b728988b807497c6a70c1b5de61904cc7d --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/__init__.py @@ -0,0 +1,91 @@ +from .utils.transforms import * +from .base.batched_sampler import BatchedRandomSampler # noqa +from .arkitscenes import ARKitScenes_Multi # noqa +from .arkitscenes_highres import ARKitScenesHighRes_Multi +from .bedlam import BEDLAM_Multi +from .blendedmvs import BlendedMVS_Multi # noqa +from .co3d import Co3d_Multi # noqa +from .cop3d import Cop3D_Multi +from .dl3dv import DL3DV_Multi +from .dynamic_replica import DynamicReplica +from .eden import EDEN_Multi +from .hypersim import HyperSim_Multi 
+from .hoi4d import HOI4D_Multi +from .irs import IRS +from .mapfree import MapFree_Multi +from .megadepth import MegaDepth_Multi # noqa +from .mp3d import MP3D_Multi +from .mvimgnet import MVImgNet_Multi +from .mvs_synth import MVS_Synth_Multi +from .omniobject3d import OmniObject3D_Multi +from .pointodyssey import PointOdyssey_Multi +from .realestate10k import RE10K_Multi +from .scannet import ScanNet_Multi +from .scannetpp import ScanNetpp_Multi # noqa +from .smartportraits import SmartPortraits_Multi +from .spring import Spring +from .synscapes import SynScapes +from .tartanair import TartanAir_Multi +from .threedkb import ThreeDKenBurns +from .uasol import UASOL_Multi +from .urbansyn import UrbanSyn +from .unreal4k import UnReal4K_Multi +from .vkitti2 import VirtualKITTI2_Multi # noqa +from .waymo import Waymo_Multi # noqa (legacy h5 format) +from .waymo_v2 import Waymo_v2_Multi # noqa (parquet v2.0.1, with TOP-lidar) +from .kitti import KITTI_Multi # noqa (KITTI odometry + Velodyne) +from .kitti360 import KITTI360_Multi # noqa (KITTI-360 + Velodyne) +from .wildrgbd import WildRGBD_Multi # noqa + +from .habitat_hm3d import HabitatHM3D_Multi + + +from accelerate import Accelerator + + +def get_data_loader( + dataset, + batch_size, + num_workers=8, + shuffle=True, + drop_last=True, + pin_mem=True, + accelerator: Accelerator = None, + fixed_length=False, +): + import torch + + # pytorch dataset + if isinstance(dataset, str): + dataset = eval(dataset) + + try: + sampler = dataset.make_sampler( + batch_size, + shuffle=shuffle, + drop_last=drop_last, + world_size=accelerator.num_processes, + fixed_length=fixed_length, + ) + shuffle = False + + data_loader = torch.utils.data.DataLoader( + dataset, + batch_sampler=sampler, + num_workers=num_workers, + pin_memory=pin_mem, + ) + + except (AttributeError, NotImplementedError): + sampler = None + + data_loader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, 
+ pin_memory=pin_mem, + drop_last=drop_last, + ) + + return data_loader diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/arkitscenes.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/arkitscenes.py new file mode 100644 index 0000000000000000000000000000000000000000..49d69d414a3b452d5619f5c8cbc55a89ec158a5b --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/arkitscenes.py @@ -0,0 +1,246 @@ +import os.path as osp +import os +import sys +import itertools + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +import cv2 +import numpy as np + +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2, imread_pil + + +def stratified_sampling(indices, num_samples, rng=None): + if num_samples > len(indices): + raise ValueError("num_samples cannot exceed the number of available indices.") + elif num_samples == len(indices): + return indices + + sorted_indices = sorted(indices) + stride = len(sorted_indices) / num_samples + sampled_indices = [] + if rng is None: + rng = np.random.default_rng() + + for i in range(num_samples): + start = int(i * stride) + end = int((i + 1) * stride) + # Ensure end does not exceed the list + end = min(end, len(sorted_indices)) + if start < end: + # Randomly select within the current stratum + rand_idx = rng.integers(start, end) + sampled_indices.append(sorted_indices[rand_idx]) + else: + # In case of any rounding issues, select the last index + sampled_indices.append(sorted_indices[-1]) + + return rng.permutation(sampled_indices) + + +class ARKitScenes_Multi(BaseMultiViewDataset): + def __init__(self, *args, split, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + self.max_interval = 8 + super().__init__(*args, **kwargs) + if split == "train": + self.split = "Training" + elif split == "test": + self.split = "Test" + else: + raise ValueError("") + + self.loaded_data = 
self._load_data(self.split) + print('DATA: arkit', len(self)) + + def _load_data(self, split): + with np.load(osp.join(self.ROOT, split, "all_metadata.npz")) as data: + self.scenes: np.ndarray = data["scenes"] + ''' + high_res_list = np.array( + [ + d + for d in os.listdir( + os.path.join( + self.ROOT.rstrip("/"),# + "_highres", + split if split == "Training" else "Test",#"Validation", + ) + ) + if os.path.join(self.ROOT, split, d) + #if os.path.join(self.ROOT + "_highres", split, d) + ] + ) + self.scenes = np.setdiff1d(self.scenes, high_res_list) + ''' + offset = 0 + counts = [] + scenes = [] + sceneids = [] + images = [] + intrinsics = [] + trajectories = [] + groups = [] + id_ranges = [] + j = 0 + for scene_idx, scene in enumerate(self.scenes): + scene_dir = osp.join(self.ROOT, self.split, scene) + with np.load( + osp.join(scene_dir, "new_scene_metadata.npz"), allow_pickle=True + ) as data: + imgs = data["images"] + intrins = data["intrinsics"] + traj = data["trajectories"] + min_seq_len = ( + self.num_views + if not self.allow_repeat + else max(self.num_views // 3, 3) + ) + if len(imgs) < min_seq_len: + print(f"Skipping {scene}") + continue + + collections = {} + assert "image_collection" in data, "Image collection not found" + collections["image"] = data["image_collection"] + + num_imgs = imgs.shape[0] + img_groups = [] + min_group_len = ( + self.num_views + if not self.allow_repeat + else max(self.num_views // 3, 3) + ) + for ref_id, group in collections["image"].item().items(): + if len(group) + 1 < min_group_len: + continue + + # groups are (idx, score)s + group.insert(0, (ref_id, 1.0)) + group = [int(x[0] + offset) for x in group] + img_groups.append(sorted(group)) + + if len(img_groups) == 0: + print(f"Skipping {scene}") + continue + + scenes.append(scene) + sceneids.extend([j] * num_imgs) + id_ranges.extend([(offset, offset + num_imgs) for _ in range(num_imgs)]) + images.extend(imgs) + K = np.expand_dims(np.eye(3), 0).repeat(num_imgs, 0) + + K[:, 0, 0] = 
[fx for _, _, fx, _, _, _ in intrins] + K[:, 1, 1] = [fy for _, _, _, fy, _, _ in intrins] + K[:, 0, 2] = [cx for _, _, _, _, cx, _ in intrins] + K[:, 1, 2] = [cy for _, _, _, _, _, cy in intrins] + intrinsics.extend(list(K)) + trajectories.extend(list(traj)) + + # offset groups + groups.extend(img_groups) + counts.append(offset) + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.id_ranges = id_ranges + self.images = images + self.intrinsics = intrinsics + self.trajectories = trajectories + self.groups = groups + + def __len__(self): + return len(self.groups) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + + if rng.choice([True, False]): + image_idxs = np.arange(self.id_ranges[idx][0], self.id_ranges[idx][1]) + cut_off = num_views if not self.allow_repeat else max(num_views // 3, 3) + start_image_idxs = image_idxs[: len(image_idxs) - cut_off + 1] + start_id = rng.choice(start_image_idxs) + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + image_idxs.tolist(), + rng, + max_interval=self.max_interval, + video_prob=0.8, + fix_interval_prob=0.5, + block_shuffle=16, + ) + image_idxs = np.array(image_idxs)[pos] + else: + ordered_video = False + image_idxs = self.groups[idx] + image_idxs = rng.permutation(image_idxs) + if len(image_idxs) > num_views: + image_idxs = image_idxs[:num_views] + else: + if rng.random() < 0.8: + image_idxs = rng.choice(image_idxs, size=num_views, replace=True) + else: + repeat_num = num_views // len(image_idxs) + 1 + image_idxs = np.tile(image_idxs, repeat_num)[:num_views] + + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.split, self.scenes[scene_id]) + + intrinsics = self.intrinsics[view_idx] + camera_pose = self.trajectories[view_idx] + basename = self.images[view_idx] + assert ( + basename[:8] == self.scenes[scene_id] + ), 
f"{basename}, {self.scenes[scene_id]}" + # print(scene_dir, basename) + # Load RGB image + rgb_image = imread_pil( + osp.join(scene_dir, "vga_wide", basename.replace(".png", ".jpg")) + ) + # Load depthmap + depthmap = imread_cv2( + osp.join(scene_dir, "lowres_depth", basename), cv2.IMREAD_UNCHANGED + ) + depthmap = depthmap.astype(np.float32) / 1000.0 + depthmap[~np.isfinite(depthmap)] = 0 # invalid + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.75, 0.2, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="arkitscenes", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(0.98, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/arkitscenes_highres.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/arkitscenes_highres.py new file mode 100644 index 0000000000000000000000000000000000000000..92826e1c46a067ed93ffc30d0470685085377bf6 --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/arkitscenes_highres.py @@ -0,0 +1,175 @@ +import os.path as osp +import os +import sys +import itertools + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +import cv2 +import numpy as np +import h5py +import math +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2 + + +class 
ARKitScenesHighRes_Multi(BaseMultiViewDataset): + def __init__(self, *args, split, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.max_interval = 8 + self.is_metric = True + super().__init__(*args, **kwargs) + if split == "train": + self.split = "Training" + elif split == "test": + self.split = "Validation" + else: + raise ValueError("") + + self.loaded_data = self._load_data(self.split) + + def _load_data(self, split): + all_scenes = sorted( + [ + d + for d in os.listdir(osp.join(self.ROOT, split)) + if osp.isdir(osp.join(self.ROOT, split, d)) + ] + ) + offset = 0 + scenes = [] + sceneids = [] + images = [] + start_img_ids = [] + scene_img_list = [] + timestamps = [] + intrinsics = [] + trajectories = [] + scene_id = 0 + for scene in all_scenes: + scene_dir = osp.join(self.ROOT, self.split, scene) + with np.load(osp.join(scene_dir, "scene_metadata.npz")) as data: + imgs_with_indices = sorted( + enumerate(data["images"]), key=lambda x: x[1] + ) + imgs = [x[1] for x in imgs_with_indices] + cut_off = ( + self.num_views + if not self.allow_repeat + else max(self.num_views // 3, 3) + ) + if len(imgs) < cut_off: + print(f"Skipping {scene}") + continue + indices = [x[0] for x in imgs_with_indices] + tsps = np.array( + [float(img_name.split("_")[1][:-4]) for img_name in imgs] + ) + assert [img[:8] == scene for img in imgs], f"{scene}, {imgs}" + num_imgs = data["images"].shape[0] + img_ids = list(np.arange(num_imgs) + offset) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + scenes.append(scene) + scene_img_list.append(img_ids) + sceneids.extend([scene_id] * num_imgs) + images.extend(imgs) + start_img_ids.extend(start_img_ids_) + timestamps.extend(tsps) + + K = np.expand_dims(np.eye(3), 0).repeat(num_imgs, 0) + intrins = data["intrinsics"][indices] + K[:, 0, 0] = [fx for _, _, fx, _, _, _ in intrins] + K[:, 1, 1] = [fy for _, _, _, fy, _, _ in intrins] + K[:, 0, 2] = [cx for _, _, _, _, cx, _ in intrins] + K[:, 1, 2] = [cy for _, _, _, _, _, cy in 
intrins] + intrinsics.extend(list(K)) + trajectories.extend(list(data["trajectories"][indices])) + + # offset groups + offset += num_imgs + scene_id += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.scene_img_list = scene_img_list + self.intrinsics = intrinsics + self.trajectories = trajectories + self.start_img_ids = start_img_ids + assert len(self.images) == len(self.intrinsics) == len(self.trajectories) + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=self.max_interval, + block_shuffle=16, + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.split, self.scenes[scene_id]) + + intrinsics = self.intrinsics[view_idx] + camera_pose = self.trajectories[view_idx] + basename = self.images[view_idx] + assert ( + basename[:8] == self.scenes[scene_id] + ), f"{basename}, {self.scenes[scene_id]}" + # print(scene_dir, basename) + # Load RGB image + rgb_image = imread_cv2( + osp.join(scene_dir, "vga_wide", basename.replace(".png", ".jpg")) + ) + # Load depthmap + depthmap = imread_cv2( + osp.join(scene_dir, "highres_depth", basename), cv2.IMREAD_UNCHANGED + ) + depthmap = depthmap.astype(np.float32) / 1000.0 + depthmap[~np.isfinite(depthmap)] = 0 # invalid + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.7, 0.25, 0.05] + ) + + views.append( + dict( + img=rgb_image, + 
def get_ray_map(c2w1, c2w2, intrinsics, h, w):
    """Return an ``(h, w, 6)`` ray map of view 2 expressed relative to view 1.

    The first three channels hold the ray origin, i.e. the translation of
    ``inv(c2w1) @ c2w2``, repeated for every pixel. The last three channels hold
    a unit vector per pixel: the pixel ``(x, y, 1)`` is un-projected with
    ``inv(intrinsics)``, transformed by the relative pose using a homogeneous
    coordinate of 1 (so the translation is included), and then normalised.

    Args:
        c2w1: 4x4 pose matrix of the reference view.
        c2w2: 4x4 pose matrix of the current view.
        intrinsics: 3x3 camera matrix of the current view.
        h: number of rows of the ray map.
        w: number of columns of the ray map.
    """
    rel_pose = np.linalg.inv(c2w1) @ c2w2

    # integer pixel coordinates, one (x, y, 1) row per pixel in row-major order
    ys, xs = np.indices((h, w))
    pixels = np.stack([xs, ys, np.ones_like(xs)], axis=-1).reshape(-1, 3)

    cam_points = np.linalg.inv(intrinsics) @ pixels.T
    cam_points_h = np.vstack([cam_points, np.ones_like(cam_points[0])])
    directions = (rel_pose @ cam_points_h).T[:, :3].reshape(h, w, 3)
    directions = directions / np.linalg.norm(directions, axis=-1, keepdims=True)

    origins = np.broadcast_to(rel_pose[:3, 3], (h, w, 3))
    return np.concatenate([origins, directions], axis=-1)
@staticmethod + def efficient_random_intervals( + start, + num_elements, + interval_range, + fixed_interval_prob=0.8, + weights=None, + seed=42, + ): + if random.random() < fixed_interval_prob: + intervals = random.choices(interval_range, weights=weights) * ( + num_elements - 1 + ) + else: + intervals = [ + random.choices(interval_range, weights=weights)[0] + for _ in range(num_elements - 1) + ] + return list(itertools.accumulate([start] + intervals)) + + def sample_based_on_timestamps(self, i, timestamps, num_views, interval=1): + time_diffs = np.abs(timestamps - timestamps[i]) + ids_candidate = np.where(time_diffs < interval)[0] + ids_candidate = np.sort(ids_candidate) + if (self.allow_repeat and len(ids_candidate) < num_views // 3) or ( + len(ids_candidate) < num_views + ): + return [] + ids_sel_list = [] + ids_candidate_left = ids_candidate.copy() + while len(ids_candidate_left) >= num_views: + ids_sel = np.random.choice(ids_candidate_left, num_views, replace=False) + ids_sel_list.append(sorted(ids_sel)) + ids_candidate_left = np.setdiff1d(ids_candidate_left, ids_sel) + + if len(ids_candidate_left) > 0 and len(ids_candidate) >= num_views: + ids_sel = np.concatenate( + [ + ids_candidate_left, + np.random.choice( + np.setdiff1d(ids_candidate, ids_candidate_left), + num_views - len(ids_candidate_left), + replace=False, + ), + ] + ) + ids_sel_list.append(sorted(ids_sel)) + + if self.allow_repeat: + ids_sel_list.append( + sorted(np.random.choice(ids_candidate, num_views, replace=True)) + ) + + # add sequences with fixed intervals (all possible intervals) + pos_i = np.where(ids_candidate == i)[0][0] + curr_interval = 1 + stop = len(ids_candidate) < num_views + while not stop: + pos_sel = [pos_i] + count = 0 + while len(pos_sel) < num_views: + if count % 2 == 0: + curr_pos_i = pos_sel[-1] + curr_interval + if curr_pos_i >= len(ids_candidate): + stop = True + break + pos_sel.append(curr_pos_i) + else: + curr_pos_i = pos_sel[0] - curr_interval + if curr_pos_i < 0: + 
stop = True + break + pos_sel.insert(0, curr_pos_i) + count += 1 + if not stop and len(pos_sel) == num_views: + ids_sel = sorted([ids_candidate[pos] for pos in pos_sel]) + if ids_sel not in ids_sel_list: + ids_sel_list.append(ids_sel) + curr_interval += 1 + return ids_sel_list + + @staticmethod + def blockwise_shuffle(x, rng, block_shuffle): + if block_shuffle is None: + return rng.permutation(x).tolist() + else: + assert block_shuffle > 0 + blocks = [x[i : i + block_shuffle] for i in range(0, len(x), block_shuffle)] + shuffled_blocks = [rng.permutation(block).tolist() for block in blocks] + shuffled_list = [item for block in shuffled_blocks for item in block] + return shuffled_list + + def get_seq_from_start_id( + self, + num_views, + id_ref, + ids_all, + rng, + min_interval=1, + max_interval=25, + video_prob=0.5, + fix_interval_prob=0.5, + block_shuffle=None, + ): + """ + args: + num_views: number of views to return + id_ref: the reference id (first id) + ids_all: all the ids + rng: random number generator + max_interval: maximum interval between two views + returns: + pos: list of positions of the views in ids_all, i.e., index for ids_all + is_video: True if the views are consecutive + """ + assert min_interval > 0, f"min_interval should be > 0, got {min_interval}" + assert ( + min_interval <= max_interval + ), f"min_interval should be <= max_interval, got {min_interval} and {max_interval}" + assert id_ref in ids_all + pos_ref = ids_all.index(id_ref) + all_possible_pos = np.arange(pos_ref, len(ids_all)) + + remaining_sum = len(ids_all) - 1 - pos_ref + + if remaining_sum >= num_views - 1: + if remaining_sum == num_views - 1: + assert ids_all[-num_views] == id_ref + return [pos_ref + i for i in range(num_views)], True + max_interval = min(max_interval, 2 * remaining_sum // (num_views - 1)) + intervals = [ + rng.choice(range(min_interval, max_interval + 1)) + for _ in range(num_views - 1) + ] + + # if video or collection + if rng.random() < video_prob: + # if fixed 
interval or random + if rng.random() < fix_interval_prob: + # regular interval + fixed_interval = rng.choice( + range( + 1, + min(remaining_sum // (num_views - 1) + 1, max_interval + 1), + ) + ) + intervals = [fixed_interval for _ in range(num_views - 1)] + is_video = True + else: + is_video = False + + pos = list(itertools.accumulate([pos_ref] + intervals)) + pos = [p for p in pos if p < len(ids_all)] + pos_candidates = [p for p in all_possible_pos if p not in pos] + pos = ( + pos + + rng.choice( + pos_candidates, num_views - len(pos), replace=False + ).tolist() + ) + + pos = ( + sorted(pos) + if is_video + else self.blockwise_shuffle(pos, rng, block_shuffle) + ) + #elif remaining_sum>1: + else: + # assert self.allow_repeat + uniq_num = remaining_sum + new_pos_ref = rng.choice(np.arange(pos_ref + 1)) + new_remaining_sum = len(ids_all) - 1 - new_pos_ref + new_max_interval = min(max_interval, new_remaining_sum // (uniq_num - 1)) + new_intervals = [ + rng.choice(range(1, new_max_interval + 1)) for _ in range(uniq_num - 1) + ] + + revisit_random = rng.random() + video_random = rng.random() + + if rng.random() < fix_interval_prob and video_random < video_prob: + # regular interval + fixed_interval = rng.choice(range(1, new_max_interval + 1)) + new_intervals = [fixed_interval for _ in range(uniq_num - 1)] + pos = list(itertools.accumulate([new_pos_ref] + new_intervals)) + + is_video = False + if revisit_random < 0.5 or video_prob == 1.0: # revisit, video / collection + is_video = video_random < video_prob + pos = ( + self.blockwise_shuffle(pos, rng, block_shuffle) + if not is_video + else pos + ) + num_full_repeat = num_views // uniq_num + pos = ( + pos * num_full_repeat + + pos[: num_views - len(pos) * num_full_repeat] + ) + elif revisit_random < 0.9: # random + pos = rng.choice(pos, num_views, replace=True) + else: # ordered + pos = sorted(rng.choice(pos, num_views, replace=True)) + assert len(pos) == num_views + return pos, is_video + + def 
get_img_and_ray_masks(self, is_metric, v, rng, p=[0.8, 0.15, 0.05]): + # generate img mask and raymap mask + if v == 0 or (not is_metric): + img_mask = True + raymap_mask = False + else: + rand_val = rng.random() + if rand_val < p[0]: + img_mask = True + raymap_mask = False + elif rand_val < p[0] + p[1]: + img_mask = False + raymap_mask = True + else: + img_mask = True + raymap_mask = True + return img_mask, raymap_mask + + def get_stats(self): + return f"{len(self)} groups of views" + + def __repr__(self): + resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]" + return ( + f"""{type(self).__name__}({self.get_stats()}, + {self.num_views=}, + {self.split=}, + {self.seed=}, + resolutions={resolutions_str}, + {self.transform=})""".replace( + "self.", "" + ) + .replace("\n", "") + .replace(" ", "") + ) + + def _get_views(self, idx, resolution, rng, num_views): + raise NotImplementedError() + + def __getitem__(self, idx): + # print("Receiving:" , idx) + if isinstance(idx, (tuple, list, np.ndarray)): + # the idx is specifying the aspect-ratio + idx, ar_idx, nview = idx + else: + assert len(self._resolutions) == 1 + ar_idx = 0 + nview = self.num_views + + assert nview >= 1 and nview <= self.num_views + # set-up the rng + if self.seed: # reseed for each __getitem__ + self._rng = np.random.default_rng(seed=self.seed + idx) + elif not hasattr(self, "_rng"): + seed = torch.randint(0, 2**32, (1,)).item() + self._rng = np.random.default_rng(seed=seed) + + if self.aug_crop > 1 and self.seq_aug_crop: + self.delta_target_resolution = self._rng.integers(0, self.aug_crop) + + # over-loaded code + resolution = self._resolutions[ + ar_idx + ] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler) + views = self._get_views(idx, resolution, self._rng, nview) + assert len(views) == nview + + if "camera_pose" not in views[0]: + views[0]["camera_pose"] = np.ones((4, 4), dtype=np.float32) + first_view_camera_pose = views[0]["camera_pose"] + transform = 
SeqColorJitter() if self.is_seq_color_jitter else self.transform + + for v, view in enumerate(views): + assert ( + "pts3d" not in view + ), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}" + view["idx"] = (idx, ar_idx, v) + + # encode the image + width, height = view["img"].size + + view["true_shape"] = np.int32((height, width)) + view["img"] = transform(view["img"]) + view["sky_mask"] = view["depthmap"] < 0 + + assert "camera_intrinsics" in view + if "camera_pose" not in view: + view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32) + else: + assert np.isfinite( + view["camera_pose"] + ).all(), f"NaN in camera pose for view {view_name(view)}" + + ray_map = get_ray_map( + first_view_camera_pose, + view["camera_pose"], + view["camera_intrinsics"], + height, + width, + ) + view["ray_map"] = ray_map.astype(np.float32) + + assert "pts3d" not in view + assert "valid_mask" not in view + assert np.isfinite( + view["depthmap"] + ).all(), f"NaN in depthmap for view {view_name(view)}" + pts3d, pts3d_local, valid_mask = depthmap_to_absolute_camera_coordinates(**view) + + + + + view["pts3d"] = pts3d + view["pts3d_local"] = pts3d_local + view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1) + + # check all datatypes + for key, val in view.items(): + res, err_msg = is_good_type(key, val) + assert res, f"{err_msg} with {key}={val} for view {view_name(view)}" + K = view["camera_intrinsics"] + if False: + if random.random() > 0.3:#self.cojitter_ratio: + images = torch.stack([view['img'] for view in views],axis=0) + images = self.image_aug(images) + for v, view in enumerate(views): + view['img'] = images[v] + + else: + for view in views: + view['img'] = self.image_aug(view['img'][None])[0] + + if self.n_corres > 0: + ref_view = views[0] + for view in views: + corres1, corres2, valid = extract_correspondences_from_pts3d( + ref_view, view, self.n_corres, self._rng, nneg=self.nneg + ) + 
view["corres"] = (corres1, corres2) + view["valid_corres"] = valid + + # last thing done! + for view in views: + view["rng"] = int.from_bytes(self._rng.bytes(4), "big") + return views + + def _set_resolutions(self, resolutions): + assert resolutions is not None, "undefined resolution" + + if not isinstance(resolutions, list): + resolutions = [resolutions] + + self._resolutions = [] + for resolution in resolutions: + if isinstance(resolution, int): + width = height = resolution + else: + width, height = resolution + assert isinstance( + width, int + ), f"Bad type for {width=} {type(width)=}, should be int" + assert isinstance( + height, int + ), f"Bad type for {height=} {type(height)=}, should be int" + self._resolutions.append((width, height)) + + def _crop_resize_if_necessary( + self, image, depthmap, intrinsics, resolution, rng=None, info=None + ): + """This function: + - first downsizes the image with LANCZOS inteprolation, + which is better than bilinear interpolation in + """ + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + + # downscale with lanczos interpolation so that image.size == resolution + # cropping centered on the principal point + W, H = image.size + cx, cy = intrinsics[:2, 2].round().astype(int) + min_margin_x = min(cx, W - cx) + min_margin_y = min(cy, H - cy) + assert min_margin_x > W / 5, f"Bad principal point in view={info}" + assert min_margin_y > H / 5, f"Bad principal point in view={info}" + # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy) + l, t = cx - min_margin_x, cy - min_margin_y + r, b = cx + min_margin_x, cy + min_margin_y + crop_bbox = (l, t, r, b) + image, depthmap, intrinsics = cropping.crop_image_depthmap( + image, depthmap, intrinsics, crop_bbox + ) + + # transpose the resolution if necessary + W, H = image.size # new size + + # high-quality Lanczos down-scaling + target_resolution = np.array(resolution) + if self.aug_crop > 1: + target_resolution 
def transpose_to_landscape(view):
    """Transpose a portrait view in place so that its arrays become landscape.

    Nothing happens when ``view["true_shape"]`` (``height, width``) is already
    landscape or square. For a portrait view the spatial axes of ``img``,
    ``valid_mask``, ``depthmap``, ``pts3d``, ``ray_map`` and ``sky_mask`` are
    swapped, the first two rows of ``camera_intrinsics`` are exchanged, and
    the x/y columns of both ``corres`` arrays are exchanged when that key is
    present. ``true_shape`` itself is left untouched.

    Args:
        view: dict describing one view; it is modified in place.

    Raises:
        AssertionError: if an array does not have the shape implied by
            ``true_shape``.
    """
    height, width = view["true_shape"]

    if width < height:
        # rectify portrait to landscape
        assert view["img"].shape == (3, height, width)
        view["img"] = view["img"].swapaxes(1, 2)

        assert view["valid_mask"].shape == (height, width)
        view["valid_mask"] = view["valid_mask"].swapaxes(0, 1)

        assert view["depthmap"].shape == (height, width)
        view["depthmap"] = view["depthmap"].swapaxes(0, 1)

        assert view["pts3d"].shape == (height, width, 3)
        view["pts3d"] = view["pts3d"].swapaxes(0, 1)

        # transpose x and y pixels
        # NOTE(review): only the rows are permuted, not the columns; confirm
        # this is what the consumers of `camera_intrinsics` expect.
        view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]]

        assert view["ray_map"].shape == (height, width, 6)
        view["ray_map"] = view["ray_map"].swapaxes(0, 1)

        assert view["sky_mask"].shape == (height, width)
        view["sky_mask"] = view["sky_mask"].swapaxes(0, 1)

        if "corres" in view:
            # transpose correspondences x and y
            corres = view["corres"]
            if isinstance(corres, tuple):
                # `corres` is stored as a tuple by the dataset, and tuples do
                # not support item assignment, so a new tuple is built.
                view["corres"] = (corres[0][:, [1, 0]], corres[1][:, [1, 0]])
            else:
                # mutable containers keep being updated in place
                corres[0] = corres[0][:, [1, 0]]
                corres[1] = corres[1][:, [1, 0]]
def round_by(total, multiple, up=False):
    """Round ``total`` to a multiple of ``multiple``.

    Rounds down (floor) by default; with ``up=True`` it rounds up to the next
    multiple instead, leaving exact multiples unchanged.
    """
    # adding `multiple - 1` before the floor division turns it into a ceiling
    adjusted = total + multiple - 1 if up else total
    quotient = adjusted // multiple
    return quotient * multiple
class EasyDataset:
    """Base class giving datasets a small algebra for resizing and combining.

    Examples:
    ---------
        2 * dataset ==> duplicate each element 2x

        10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary)

        dataset1 + dataset2 ==> concatenate datasets
    """

    def __add__(self, other):
        """``self + other``: concatenation of the two datasets."""
        return CatDataset([self, other])

    def __rmul__(self, factor):
        """``factor * self``: every element repeated ``factor`` times."""
        return MulDataset(factor, self)

    def __rmatmul__(self, factor):
        """``factor @ self``: dataset resized to exactly ``factor`` elements."""
        return ResizedDataset(factor, self)

    def set_epoch(self, epoch):
        """Hook called with the epoch number; the base class ignores it."""
        return None

    def make_sampler(
        self, batch_size, shuffle=True, drop_last=True, world_size=1, rank=0, fixed_length=False
    ):
        """Build the batch sampler used to iterate over this dataset.

        Every batch shares one resolution index and one number of views. The
        number of views ranges from 4 up to ``self.num_views``, or is pinned to
        ``self.num_views`` when ``fixed_length`` is true. ``rank`` is accepted
        but not used here.

        Raises:
            NotImplementedError: if ``shuffle`` is false.
        """
        if not shuffle:
            # cannot deal yet
            raise NotImplementedError()

        pool_size = len(self._resolutions)
        max_views = self.num_views
        min_views = max_views if fixed_length else 4
        sampler = CustomRandomSampler(
            self,
            batch_size,
            pool_size,
            min_views,
            max_views,
            world_size,
            warmup=1,
            drop_last=drop_last,
        )
        return BatchedRandomSampler(sampler, batch_size, drop_last)
class CatDataset(EasyDataset):
    """Concatenation of several datasets.

    Indices run through the datasets in the order they were given. A sample
    that fails to load is replaced by the next loadable one (see
    ``__getitem__``).
    """

    def __init__(self, datasets):
        for dataset in datasets:
            assert isinstance(dataset, EasyDataset)
        self.datasets = datasets
        # _cum_sizes[i] is the total number of elements in datasets[0..i]
        self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])

    def __len__(self):
        return self._cum_sizes[-1]

    def __repr__(self):
        # remove uselessly long transform
        return " + ".join(
            repr(dataset).replace(
                ",transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))",
                "",
            )
            for dataset in self.datasets
        )

    def set_epoch(self, epoch):
        """Forward the epoch number to every wrapped dataset."""
        for dataset in self.datasets:
            dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        """Return the sample at ``idx``, skipping forward over samples that fail.

        ``idx`` is either a plain index or an ``(index, other, another)`` tuple
        whose two extra items are forwarded untouched to the wrapped dataset.
        When loading fails, the error is printed and the next index (wrapping
        around at the end) is tried instead.

        Raises:
            IndexError: if the index is outside ``[0, len(self))``.
            RuntimeError: if every index has been tried once and none loaded.
        """
        other = another = None
        if isinstance(idx, tuple):
            idx, other, another = idx

        total = len(self)
        if not (0 <= idx < total):
            raise IndexError()

        last_error = None
        # One pass over the whole dataset at most: the previous `while True`
        # never terminated when no sample could be loaded.
        for _ in range(total):
            db_idx = np.searchsorted(self._cum_sizes, idx, "right")
            dataset = self.datasets[db_idx]
            new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0)

            if other is not None and another is not None:
                new_idx = (new_idx, other, another)

            try:
                return dataset[new_idx]
            except Exception as e:  # deliberate best effort: skip broken samples
                print(e)
                print("DATA ERROR", new_idx)
                last_error = e
                idx = (idx + 1) % total

        raise RuntimeError(
            f"CatDataset: none of the {total} samples could be loaded"
        ) from last_error

    @property
    def _resolutions(self):
        """Resolutions shared by all datasets (asserted to be identical)."""
        resolutions = self.datasets[0]._resolutions
        for dataset in self.datasets[1:]:
            assert tuple(dataset._resolutions) == tuple(resolutions)
        return resolutions

    @property
    def num_views(self):
        """Number of views shared by all datasets (asserted to be identical)."""
        num_views = self.datasets[0].num_views
        for dataset in self.datasets[1:]:
            assert dataset.num_views == num_views
        return num_views
"""Dataset of outdoor street scenes, 5 images each time""" + + def __init__(self, *args, ROOT, split=None, **kwargs): + self.ROOT = ROOT + self.video = False + self.is_metric = False + super().__init__(*args, **kwargs) + # assert split is None + self._load_data() + + def _load_data(self): + self.data_dict = self.read_h5_file(os.path.join(self.ROOT, "new_overlap.h5")) + self.num_imgs = sum( + [len(self.data_dict[s]["basenames"]) for s in self.data_dict.keys()] + ) + self.num_scenes = len(self.data_dict.keys()) + self.invalid_scenes = [] + self.is_reachable_cache = {scene: {} for scene in self.data_dict.keys()} + + def read_h5_file(self, h5_file_path): + data_dict = {} + self.all_ref_imgs = [] + with h5py.File(h5_file_path, "r") as f: + for scene_dir in tqdm(f.keys()): + group = f[scene_dir] + basenames = group["basenames"][:] + indices = group["indices"][:] + values = group["values"][:] + shape = group.attrs["shape"] + # Reconstruct the sparse matrix + score_matrix = np.zeros(shape, dtype=np.float32) + score_matrix[indices[0], indices[1]] = values + data_dict[scene_dir] = { + "basenames": basenames, + "score_matrix": self.build_adjacency_list(score_matrix), + } + self.all_ref_imgs.extend( + [(scene_dir, b) for b in range(len(basenames))] + ) + return data_dict + + @staticmethod + def build_adjacency_list(S, thresh=0.2): + adjacency_list = [[] for _ in range(len(S))] + S = S - thresh + S[S < 0] = 0 + rows, cols = np.nonzero(S) + for i, j in zip(rows, cols): + adjacency_list[i].append((j, S[i][j])) + return adjacency_list + + @staticmethod + def is_reachable(adjacency_list, start_index, k): + visited = set() + stack = [start_index] + while stack and len(visited) < k: + node = stack.pop() + if node not in visited: + visited.add(node) + for neighbor in adjacency_list[node]: + if neighbor[0] not in visited: + stack.append(neighbor[0]) + return len(visited) >= k + + @staticmethod + def random_sequence_no_revisit_with_backtracking( + adjacency_list, k, start_index, rng: 
    @staticmethod
    def random_sequence_with_optional_repeats(
        adjacency_list,
        k,
        start_index,
        rng: np.random.Generator,
        max_k=None,
        max_attempts=100,
    ):
        """Random walk over the graph that prefers new nodes but may revisit old ones.

        Starting at ``start_index``, the walk repeatedly moves to a neighbour
        of the current node, drawn with probability proportional to the edge
        weight. Unvisited neighbours are preferred; when there are none, any
        neighbour may be drawn again. The walk stops once the path holds
        ``max_k`` nodes, after ``max_attempts`` steps, or at a node without
        neighbours.

        Args:
            adjacency_list: for each node, a list of ``(neighbour, weight)`` pairs.
            k: minimum number of distinct nodes the path must contain.
            start_index: node the walk starts from.
            rng: generator used for every random draw.
            max_k: length of the returned path; defaults to ``k``.
            max_attempts: upper bound on the number of walk steps.

        Returns:
            A list of ``max_k`` node indices if at least ``k`` distinct nodes
            were reached (a shorter walk is padded with nodes drawn uniformly
            from the path itself), otherwise ``None``.
        """
        if max_k is None:
            max_k = k
        path = [start_index]
        visited = set([start_index])
        current_index = start_index
        attempts = 0

        while len(path) < max_k and attempts < max_attempts:
            attempts += 1
            neighbors = adjacency_list[current_index]
            neighbor_idxs = [n[0] for n in neighbors]
            neighbor_weights = [n[1] for n in neighbors]

            if not neighbor_idxs:
                # No neighbors, cannot proceed further
                break

            # Try to find unvisited neighbors
            unvisited_neighbors = [
                (idx, wgt)
                for idx, wgt in zip(neighbor_idxs, neighbor_weights)
                if idx not in visited
            ]
            if unvisited_neighbors:
                # Select among unvisited neighbors
                unvisited_idxs = [idx for idx, _ in unvisited_neighbors]
                unvisited_weights = [wgt for _, wgt in unvisited_neighbors]
                # weights are renormalised over the unvisited subset only
                probabilities = np.array(unvisited_weights) / np.sum(unvisited_weights)
                next_index = rng.choice(unvisited_idxs, p=probabilities)
                visited.add(next_index)
            else:
                # All neighbors visited, but we need to reach length max_k
                # So we can revisit nodes
                probabilities = np.array(neighbor_weights) / np.sum(neighbor_weights)
                next_index = rng.choice(neighbor_idxs, p=probabilities)

            # the walk always continues from the node that was just appended
            path.append(next_index)
            current_index = next_index

        if len(set(path)) >= k:
            # If path is shorter than max_k, extend it by repeating existing elements
            while len(path) < max_k:
                # Randomly select nodes from the existing path to repeat
                next_index = rng.choice(path)
                path.append(next_index)
            return path
        else:
            # Could not reach k unique nodes
            return None
+ ) + return None + if not allow_repeat: + sequence = self.random_sequence_no_revisit_with_backtracking( + adj_list, cutoff, start_index, rng + ) + else: + sequence = self.random_sequence_with_optional_repeats( + adj_list, cutoff, start_index, rng, max_k=num_views + ) + if not sequence: + self.is_reachable_cache[scene][start_index] = False + print("Failed to generate a sequence without revisiting.") + return sequence + + def _get_views(self, idx, resolution, rng: np.random.Generator, num_views): + MAX_RETRIES = 100 # Maximum attempts to find a valid sequence + MAX_SCENE_RETRIES = 50 # Maximum attempts to find a valid scene + + scene_info, ref_img_idx = self.all_ref_imgs[idx] + invalid_seq = True + ordered_video = False + + outer_retry_count = 0 + + while invalid_seq and outer_retry_count < MAX_RETRIES: + outer_retry_count += 1 + + basenames = self.data_dict[scene_info]["basenames"] + if ( + sum( + [ + (1 - int(x)) + for x in list(self.is_reachable_cache[scene_info].values()) + ] + ) + > len(basenames) - self.num_views + ): + self.invalid_scenes.append(scene_info) + + inner_retry_count = 0 + while scene_info in self.invalid_scenes and inner_retry_count < MAX_SCENE_RETRIES: + inner_retry_count += 1 + idx = rng.integers(low=0, high=len(self.all_ref_imgs)) + scene_info, ref_img_idx = self.all_ref_imgs[idx] + basenames = self.data_dict[scene_info]["basenames"] + + # If we exhausted inner retries, skip to next sample + if inner_retry_count >= MAX_SCENE_RETRIES: + import warnings + warnings.warn( + f"BlendedMVS: Could not find valid scene after {MAX_SCENE_RETRIES} attempts. " + f"Skipping sample idx={idx}. This might indicate data quality issues." 
+ ) + # Try with a completely random sample + idx = rng.integers(low=0, high=len(self.all_ref_imgs)) + scene_info, ref_img_idx = self.all_ref_imgs[idx] + basenames = self.data_dict[scene_info]["basenames"] + + score_matrix = self.data_dict[scene_info]["score_matrix"] + imgs_idxs = self.generate_sequence( + scene_info, score_matrix, num_views, ref_img_idx, rng, self.allow_repeat + ) + + if imgs_idxs is None: + random_direction = 2 * rng.choice(2) - 1 + for offset in range(1, len(basenames)): + tentative_im_idx = ( + ref_img_idx + (random_direction * offset) + ) % len(basenames) + if ( + tentative_im_idx not in self.is_reachable_cache[scene_info] + or self.is_reachable_cache[scene_info][tentative_im_idx] + ): + ref_img_idx = tentative_im_idx + break + else: + invalid_seq = False + + # If we exhausted all retries, raise an error instead of hanging + if outer_retry_count >= MAX_RETRIES: + import warnings + warnings.warn( + f"BlendedMVS: Failed to generate valid sequence after {MAX_RETRIES} attempts. " + f"Skipping sample idx={idx}. This might indicate severe data quality issues." 
+ ) + # As a last resort, try one more time with a completely random sample + idx = rng.integers(low=0, high=len(self.all_ref_imgs)) + scene_info, ref_img_idx = self.all_ref_imgs[idx] + basenames = self.data_dict[scene_info]["basenames"] + score_matrix = self.data_dict[scene_info]["score_matrix"] + imgs_idxs = self.generate_sequence( + scene_info, score_matrix, num_views, ref_img_idx, rng, self.allow_repeat + ) + # If still None, use sequential indices as fallback + if imgs_idxs is None: + imgs_idxs = list(range(min(num_views, len(basenames)))) + + views = [] + for view_idx in imgs_idxs: + scene_dir = osp.join(self.ROOT, scene_info) + impath = basenames[view_idx].decode("utf-8") + image = imread_pil(osp.join(scene_dir, impath + ".jpg")) + depthmap = imread_cv2(osp.join(scene_dir, impath + ".exr")) + camera_params = np.load(osp.join(scene_dir, impath + ".npz")) + + intrinsics = np.float32(camera_params["intrinsics"]) + camera_pose = np.eye(4, dtype=np.float32) + camera_pose[:3, :3] = camera_params["R_cam2world"] + camera_pose[:3, 3] = camera_params["t_cam2world"] + + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(scene_dir, impath) + ) + + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset="BlendedMVS", + label=osp.relpath(scene_dir, self.ROOT), + is_metric=self.is_metric, + is_video=ordered_video, + instance=osp.join(scene_dir, impath + ".jpg"), + quantile=np.array(0.97, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views + diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/co3d.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/co3d.py new file mode 100644 index 0000000000000000000000000000000000000000..98dcc820fcd70fd496396ef000c22aeb2adee35a --- 
# patch header (verbatim): /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/co3d.py @@ -0,0 +1,190 @@
import os.path as osp
import json
import itertools
from collections import deque
import sys

sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
import cv2
import numpy as np
import time

from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
from dust3r.utils.image import imread_cv2


class Co3d_Multi(BaseMultiViewDataset):
    """Multi-view dataset over the sequences listed in ``selected_seqs_<split>.json``.

    Scenes are keyed by ``(object, instance)``. Images whose depth map is empty
    after cropping are remembered per resolution and avoided afterwards; scenes
    with too few usable images are flagged invalid and resampled.
    """

    def __init__(self, mask_bg="rand", *args, ROOT, **kwargs):
        """Load the scene index and enumerate every admissible reference frame.

        Args:
            mask_bg: ``True``/``False`` to always/never zero the depth outside
                the object mask, ``"rand"`` to decide randomly per sample.
            ROOT: Dataset root directory.
        """
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        assert mask_bg in (True, False, "rand")
        self.mask_bg = mask_bg
        self.is_metric = False
        self.dataset_label = "Co3d_v2"

        # load all scenes
        with open(osp.join(self.ROOT, f"selected_seqs_{self.split}.json"), "r") as f:
            self.scenes = json.load(f)
            self.scenes = {k: v for k, v in self.scenes.items() if len(v) > 0}
            self.scenes = {
                (k, k2): v2 for k, v in self.scenes.items() for k2, v2 in v.items()
            }
        self.scene_list = list(self.scenes.keys())
        # Minimum number of images a scene needs to yield one sample.
        cut_off = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )
        self.cut_off = cut_off
        # One entry per (scene, frame id) that leaves at least `cut_off` frames
        # from that position to the end of the scene's frame list.
        self.all_ref_imgs = [
            (key, value)
            for key, values in self.scenes.items()
            for value in values[: len(values) - cut_off + 1]
        ]
        # scene -> {resolution: [bool per image]}, True marks an unusable image.
        self.invalidate = {scene: {} for scene in self.scene_list}
        self.invalid_scenes = {scene: False for scene in self.scene_list}

    def __len__(self):
        return len(self.all_ref_imgs)

    def _get_metadatapath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, "images", f"frame{view_idx:06n}.npz")

    def _get_impath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, "images", f"frame{view_idx:06n}.jpg")

    def _get_depthpath(self, obj, instance, view_idx):
        return osp.join(
            self.ROOT, obj, instance, "depths", f"frame{view_idx:06n}.jpg.geometric.png"
        )

    def _get_maskpath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, "masks", f"frame{view_idx:06n}.png")

    def _read_depthmap(self, depthpath, input_metadata):
        """Read a 16-bit depth image and rescale it by the stored ``maximum_depth``."""
        depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
        depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num(
            input_metadata["maximum_depth"]
        )
        return depthmap

    def _get_views(self, idx, resolution, rng, num_views):
        """Return ``num_views`` view dicts for sample ``idx``.

        Loops until a sequence with ``num_views`` views that are not all the
        same image has been loaded.

        Raises:
            RuntimeError: if no valid scene can be drawn any more (every
                resampling attempt hit a scene flagged invalid).
        """
        invalid_seq = True
        scene_info, ref_img_idx = self.all_ref_imgs[idx]

        while invalid_seq:
            # Resample while the scene is flagged invalid. The attempt count is
            # bounded (same policy as HabitatHM3D_Multi) so that a dataset with
            # only invalid scenes fails loudly instead of hanging the worker.
            scene_retry = 0
            while self.invalid_scenes[scene_info]:
                scene_retry += 1
                if scene_retry > len(self.all_ref_imgs):
                    raise RuntimeError(
                        f"[{self.dataset_label}] All scenes are invalid! "
                        f"Cannot find valid scene after {scene_retry} attempts."
                    )
                idx = rng.integers(low=0, high=len(self.all_ref_imgs))
                scene_info, ref_img_idx = self.all_ref_imgs[idx]

            obj, instance = scene_info

            image_pool = self.scenes[obj, instance]
            if len(image_pool) < self.cut_off:
                print("Invalid scene!")
                self.invalid_scenes[scene_info] = True
                continue

            imgs_idxs, ordered_video = self.get_seq_from_start_id(
                num_views, ref_img_idx, image_pool, rng
            )

            if resolution not in self.invalidate[obj, instance]:  # flag invalid images
                self.invalidate[obj, instance][resolution] = [
                    False for _ in range(len(image_pool))
                ]
            # decide now if we mask the bg
            mask_bg = (self.mask_bg == True) or (
                self.mask_bg == "rand" and rng.choice(2, p=[0.9, 0.1])
            )
            views = []

            imgs_idxs = deque(imgs_idxs)

            while len(imgs_idxs) > 0:  # some images (few) have zero depth
                if (
                    len(image_pool) - sum(self.invalidate[obj, instance][resolution])
                    < self.cut_off
                ):
                    print("Invalid scene!")
                    invalid_seq = True
                    self.invalid_scenes[scene_info] = True
                    break

                im_idx = imgs_idxs.pop()
                if self.invalidate[obj, instance][resolution][im_idx]:
                    # search for a valid image
                    ordered_video = False
                    random_direction = 2 * rng.choice(2) - 1
                    for offset in range(1, len(image_pool)):
                        tentative_im_idx = (im_idx + (random_direction * offset)) % len(
                            image_pool
                        )
                        if not self.invalidate[obj, instance][resolution][
                            tentative_im_idx
                        ]:
                            im_idx = tentative_im_idx
                            break
                view_idx = image_pool[im_idx]
                impath = self._get_impath(obj, instance, view_idx)
                depthpath = self._get_depthpath(obj, instance, view_idx)

                # load camera params
                metadata_path = self._get_metadatapath(obj, instance, view_idx)
                input_metadata = np.load(metadata_path)
                camera_pose = input_metadata["camera_pose"].astype(np.float32)
                intrinsics = input_metadata["camera_intrinsics"].astype(np.float32)

                # load image and depth
                rgb_image = imread_cv2(impath)
                depthmap = self._read_depthmap(depthpath, input_metadata)

                if mask_bg:
                    # load object mask
                    maskpath = self._get_maskpath(obj, instance, view_idx)
                    maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(
                        np.float32
                    )
                    maskmap = (maskmap / 255.0) > 0.1

                    # update the depthmap with mask
                    depthmap *= maskmap
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath
                )
                num_valid = (depthmap > 0.0).sum()
                if num_valid == 0:
                    # problem, invalidate image and retry
                    self.invalidate[obj, instance][resolution][im_idx] = True
                    imgs_idxs.append(im_idx)
                    continue

                # generate img mask and raymap mask
                img_mask, ray_mask = self.get_img_and_ray_masks(
                    self.is_metric, len(views), rng
                )

                views.append(
                    dict(
                        img=rgb_image,
                        depthmap=depthmap,
                        camera_pose=camera_pose,
                        camera_intrinsics=intrinsics,
                        dataset=self.dataset_label,
                        label=osp.join(obj, instance),
                        instance=osp.split(impath)[1],
                        is_metric=self.is_metric,
                        is_video=ordered_video,
                        quantile=np.array(0.9, dtype=np.float32),
                        img_mask=img_mask,
                        ray_mask=ray_mask,
                        camera_only=False,
                        depth_only=False,
                        single_view=False,
                        reset=False,
                    )
                )

            # Accept only complete sequences that are not one image repeated.
            if len(views) == num_views and not all(
                [view["instance"] == views[0]["instance"] for view in views]
            ):
                invalid_seq = False
        return views


# patch header (verbatim): diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/cop3d.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/cop3d.py new file mode 100644 index
# patch header (verbatim): 0000000000000000000000000000000000000000..aa93c7d109f80d70869250b8a44daf59cf202e0f --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/cop3d.py @@ -0,0 +1,110 @@
import os.path as osp
import sys

sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
import cv2
import numpy as np

from dust3r.datasets.co3d import Co3d_Multi
from dust3r.utils.image import imread_cv2


class Cop3D_Multi(Co3d_Multi):
    """Variant of ``Co3d_Multi`` for data without depth maps.

    Depth is replaced by an all-ones map of the image size and every view is
    emitted with ``camera_only=True``.
    """

    def __init__(self, mask_bg="rand", *args, ROOT, **kwargs):
        super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs)
        self.dataset_label = "Cop3D"
        self.is_metric = False

    def _get_metadatapath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, "images", f"frame{view_idx:06n}.npz")

    def _get_impath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, "images", f"frame{view_idx:06n}.jpg")

    def _get_depthpath(self, obj, instance, view_idx):
        # no depth, pseudo path just for getting the right resolution
        return osp.join(self.ROOT, obj, instance, "images", f"frame{view_idx:06n}.jpg")

    def _get_maskpath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, "masks", f"frame{view_idx:06n}.png")

    def _read_depthmap(self, impath, input_metadata):
        """Return an all-ones float32 map with the height/width of the image at ``impath``."""
        # no depth, set to all ones
        img = imread_cv2(impath, cv2.IMREAD_UNCHANGED)
        depthmap = np.ones_like(img[..., 0], dtype=np.float32)
        return depthmap

    def _get_views(self, idx, resolution, rng, num_views):
        """Return ``num_views`` view dicts for sample ``idx``.

        Loops until a sequence with ``num_views`` views that are not all the
        same image has been loaded.

        Raises:
            RuntimeError: if no valid scene can be drawn any more (every
                resampling attempt hit a scene flagged invalid).
        """
        invalid_seq = True
        scene_info, ref_img_idx = self.all_ref_imgs[idx]

        while invalid_seq:
            # Bounded resampling (same policy as HabitatHM3D_Multi): without a
            # limit this loop never ends once every scene is flagged invalid.
            scene_retry = 0
            while self.invalid_scenes[scene_info]:
                scene_retry += 1
                if scene_retry > len(self.all_ref_imgs):
                    raise RuntimeError(
                        f"[{self.dataset_label}] All scenes are invalid! "
                        f"Cannot find valid scene after {scene_retry} attempts."
                    )
                idx = rng.integers(low=0, high=len(self.all_ref_imgs))
                scene_info, ref_img_idx = self.all_ref_imgs[idx]

            obj, instance = scene_info

            image_pool = self.scenes[obj, instance]
            # NOTE(review): the parent class compares against self.cut_off
            # here; with allow_repeat the two thresholds differ - confirm
            # which one is intended.
            if len(image_pool) < self.num_views:
                print("Invalid scene!")
                self.invalid_scenes[scene_info] = True
                continue

            imgs_idxs, ordered_video = self.get_seq_from_start_id(
                num_views,
                ref_img_idx,
                image_pool,
                rng,
                max_interval=5,
                video_prob=1.0,
                fix_interval_prob=0.9,
            )

            views = []

            for im_idx in imgs_idxs:
                view_idx = image_pool[im_idx]
                impath = self._get_impath(obj, instance, view_idx)
                depthpath = self._get_depthpath(obj, instance, view_idx)

                # load camera params
                metadata_path = self._get_metadatapath(obj, instance, view_idx)
                input_metadata = np.load(metadata_path)
                camera_pose = input_metadata["camera_pose"].astype(np.float32)
                intrinsics = input_metadata["camera_intrinsics"].astype(np.float32)

                # load image and depth
                rgb_image = imread_cv2(impath)
                depthmap = self._read_depthmap(depthpath, input_metadata)

                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath
                )

                views.append(
                    dict(
                        img=rgb_image,
                        depthmap=depthmap,
                        camera_pose=camera_pose,
                        camera_intrinsics=intrinsics,
                        dataset=self.dataset_label,
                        label=osp.join(obj, instance),
                        instance=osp.split(impath)[1],
                        is_metric=self.is_metric,
                        is_video=ordered_video,
                        quantile=np.array(0.96, dtype=np.float32),
                        img_mask=True,
                        ray_mask=False,
                        camera_only=True,
                        depth_only=False,
                        single_view=False,
                        reset=False,
                    )
                )

            # Accept only complete sequences that are not one image repeated.
            if len(views) == num_views and not all(
                [view["instance"] == views[0]["instance"] for view in views]
            ):
                invalid_seq = False
        return views


# patch residue (verbatim): diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/dl3dv.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/dl3dv.py new file mode 100644 index 0000000000000000000000000000000000000000..2650d573123b86f10c99bb663ec399372808fe37 --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/dl3dv.py @@ -0,0 +1,166 @@ +import os.path as osp +import os +import sys +import itertools + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +import cv2 +import numpy as np + +from dust3r.datasets.base.base_multiview_dataset import
# dl3dv.py -- the head of this import block sits on the preceding flattened
# patch line; it is repeated here so the class below has its names in scope.
import os.path as osp
import os
import sys
import itertools

sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
import cv2
import numpy as np

from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
from dust3r.utils.image import imread_cv2


class DL3DV_Multi(BaseMultiViewDataset):
    """Video dataset reading ``<ROOT>/<scene>/<subscene>/dense/{rgb,depth,cam,...}``."""

    def __init__(self, *args, split, ROOT, **kwargs):
        """``split`` is required by the signature but not forwarded to the base class."""
        self.ROOT = ROOT
        self.video = True
        self.max_interval = 20
        self.is_metric = False
        super().__init__(*args, **kwargs)

        # _load_data() fills attributes and returns None.
        self.loaded_data = self._load_data()

    def _load_data(self):
        """Index all non-empty sub-scenes and assign a global id to every image."""
        self.all_scenes = sorted(
            [f for f in os.listdir(self.ROOT) if os.path.isdir(osp.join(self.ROOT, f))]
        )
        subscenes = []
        for scene in self.all_scenes:
            # not empty
            subscenes.extend(
                [
                    osp.join(scene, f)
                    for f in os.listdir(osp.join(self.ROOT, scene))
                    if os.path.isdir(osp.join(self.ROOT, scene, f))
                    and len(os.listdir(osp.join(self.ROOT, scene, f))) > 0
                ]
            )

        offset = 0
        scenes = []
        sceneids = []
        images = []
        scene_img_list = []
        start_img_ids = []
        j = 0

        # Minimum scene length; does not depend on the scene, so computed once.
        cut_off = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )

        for scene_idx, scene in enumerate(subscenes):
            scene_dir = osp.join(self.ROOT, scene, "dense")
            rgb_paths = sorted(
                [
                    f
                    for f in os.listdir(os.path.join(scene_dir, "rgb"))
                    if f.endswith(".png")
                ]
            )
            assert len(rgb_paths) > 0, f"{scene_dir} is empty."
            num_imgs = len(rgb_paths)

            if num_imgs < cut_off:
                print(f"Skipping {scene}")
                continue

            img_ids = list(np.arange(num_imgs) + offset)
            start_img_ids_ = img_ids[: num_imgs - cut_off + 1]

            scenes.append(scene)
            scene_img_list.append(img_ids)
            sceneids.extend([j] * num_imgs)
            images.extend(rgb_paths)
            start_img_ids.extend(start_img_ids_)
            offset += num_imgs
            j += 1

        self.scenes = scenes
        self.sceneids = sceneids
        self.images = images
        self.start_img_ids = start_img_ids
        self.scene_img_list = scene_img_list

    def __len__(self):
        return len(self.start_img_ids)

    def get_image_num(self):
        return len(self.images)

    def _get_views(self, idx, resolution, rng, num_views):
        """Load a sequence of views starting at start image ``idx``.

        Depth post-processing: non-finite values -> 0, sky -> -1, outliers -> 0,
        and everything above the 98th percentile of the positive depths -> 0.
        """
        start_id = self.start_img_ids[idx]
        scene_id = self.sceneids[start_id]
        all_image_ids = self.scene_img_list[scene_id]
        pos, ordered_video = self.get_seq_from_start_id(
            num_views,
            start_id,
            all_image_ids,
            rng,
            max_interval=self.max_interval,
            block_shuffle=25,
        )
        image_idxs = np.array(all_image_ids)[pos]

        views = []
        for view_idx in image_idxs:
            scene_id = self.sceneids[view_idx]
            scene_dir = osp.join(self.ROOT, self.scenes[scene_id], "dense")

            rgb_path = self.images[view_idx]
            basename = rgb_path[:-4]

            rgb_image = imread_cv2(
                osp.join(scene_dir, "rgb", rgb_path), cv2.IMREAD_COLOR
            )
            depthmap = np.load(osp.join(scene_dir, "depth", basename + ".npy")).astype(
                np.float32
            )
            depthmap[~np.isfinite(depthmap)] = 0  # invalid
            cam_file = np.load(osp.join(scene_dir, "cam", basename + ".npz"))
            sky_mask = (
                cv2.imread(
                    osp.join(scene_dir, "sky_mask", rgb_path), cv2.IMREAD_UNCHANGED
                )
                >= 127
            )
            outlier_mask = cv2.imread(
                osp.join(scene_dir, "outlier_mask", rgb_path), cv2.IMREAD_UNCHANGED
            )
            depthmap[sky_mask] = -1.0
            depthmap[outlier_mask >= 127] = 0.0
            depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0)
            positive_depths = depthmap[depthmap > 0]
            threshold = (
                np.percentile(positive_depths, 98) if positive_depths.size > 0 else 0
            )
            depthmap[depthmap > threshold] = 0.0

            intrinsics = cam_file["intrinsic"].astype(np.float32)
            camera_pose = cam_file["pose"].astype(np.float32)

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
            )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap.astype(np.float32),
                    camera_pose=camera_pose.astype(np.float32),
                    camera_intrinsics=intrinsics.astype(np.float32),
                    dataset="dl3dv",
                    label=self.scenes[scene_id] + "_" + rgb_path,
                    instance=osp.join(scene_dir, "rgb", rgb_path),
                    is_metric=self.is_metric,
                    is_video=ordered_video,
                    quantile=np.array(0.9, dtype=np.float32),
                    img_mask=True,
                    ray_mask=False,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )
        return views


# patch header (verbatim): diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/dynamic_replica.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/dynamic_replica.py new file mode 100644 index 0000000000000000000000000000000000000000..1d816e58be6518e1274fa84fa8c6a7cae73741ca --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/dynamic_replica.py @@ -0,0 +1,137 @@
import os.path as osp
import cv2
import numpy as np
import itertools
import os
import sys

sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
from tqdm import tqdm
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
from dust3r.utils.image import imread_cv2


class DynamicReplica(BaseMultiViewDataset):
    """Metric video dataset reading ``<ROOT>/<split>/<scene>/left/{rgb,depth,cam}``."""

    def __init__(self, *args, ROOT, **kwargs):
        self.ROOT = ROOT
        self.video = True
        self.is_metric = True
        self.max_interval = 16
        super().__init__(*args, **kwargs)

        # _load_data() fills attributes and returns None.
        self.loaded_data = self._load_data(self.split)

    def _load_data(self, split):
        """Index every scene of ``split`` with enough frames, in numeric frame order."""
        self.scenes = os.listdir(os.path.join(self.ROOT, split))

        offset = 0
        scenes = []
        sceneids = []
        scene_img_list = []
        images = []
        start_img_ids = []

        # Minimum scene length; does not depend on the scene, so computed once.
        cut_off = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )

        j = 0
        for scene in tqdm(self.scenes):
            scene_dir = osp.join(self.ROOT, self.split, scene, "left")
            rgb_dir = osp.join(scene_dir, "rgb")
            basenames = sorted(
                [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")],
                key=lambda x: float(x),
            )
            num_imgs = len(basenames)
            img_ids = list(np.arange(num_imgs) + offset)
            if num_imgs < cut_off:
                print(f"Skipping {scene}")
                continue

            start_img_ids_ = img_ids[: num_imgs - cut_off + 1]
            start_img_ids.extend(start_img_ids_)
            sceneids.extend([j] * num_imgs)
            images.extend(basenames)
            scenes.append(scene)
            scene_img_list.append(img_ids)

            # offset groups
            offset += num_imgs
            j += 1

        self.scenes = scenes
        self.sceneids = sceneids
        self.images = images
        self.start_img_ids = start_img_ids
        self.scene_img_list = scene_img_list

    def __len__(self):
        return len(self.start_img_ids)

    def get_image_num(self):
        return len(self.images)

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` views of a fixed-interval video clip starting at ``idx``."""
        start_id = self.start_img_ids[idx]
        all_image_ids = self.scene_img_list[self.sceneids[start_id]]
        pos, ordered_video = self.get_seq_from_start_id(
            num_views,
            start_id,
            all_image_ids,
            rng,
            max_interval=self.max_interval,
            video_prob=1.0,
            fix_interval_prob=1.0,
        )
        image_idxs = np.array(all_image_ids)[pos]

        views = []
        for v, view_idx in enumerate(image_idxs):
            scene_id = self.sceneids[view_idx]
            scene_dir = osp.join(self.ROOT, self.split, self.scenes[scene_id], "left")
            rgb_dir = osp.join(scene_dir, "rgb")
            depth_dir = osp.join(scene_dir, "depth")
            cam_dir = osp.join(scene_dir, "cam")

            basename = self.images[view_idx]

            # Load RGB image
            rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".png"))
            # Load depthmap
            depthmap = np.load(osp.join(depth_dir, basename + ".npy"))
            depthmap[~np.isfinite(depthmap)] = 0  # invalid

            cam = np.load(osp.join(cam_dir, basename + ".npz"))
            camera_pose = cam["pose"]
            intrinsics = cam["intrinsics"]
            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
            )

            # generate img mask and raymap mask
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.85, 0.10, 0.05]
            )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap.astype(np.float32),
                    camera_pose=camera_pose.astype(np.float32),
                    camera_intrinsics=intrinsics.astype(np.float32),
                    dataset="dynamic_replica",
                    label=self.scenes[scene_id] + "_" + basename,
                    instance=f"{str(idx)}_{str(view_idx)}",
                    is_metric=self.is_metric,
                    is_video=ordered_video,
                    quantile=np.array(1.0, dtype=np.float32),
                    img_mask=img_mask,
                    ray_mask=ray_mask,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )
        assert len(views) == num_views
        return views


# patch header (verbatim): diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/habitat_hm3d.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/habitat_hm3d.py new file mode 100644 index 0000000000000000000000000000000000000000..aa3d3422ccc4b19753630d09a39beee191bae8fe --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/habitat_hm3d.py @@ -0,0 +1,174 @@
import os.path as osp
import cv2
import numpy as np
import itertools
import os
import sys

sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
from tqdm import tqdm
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
from dust3r.utils.image import imread_cv2


class HabitatHM3D_Multi(BaseMultiViewDataset):
    """Video dataset reading ``image_<n>.png`` / ``depth_<n>.png`` / ``<n>.npz`` per scene."""

    def __init__(self, *args, ROOT, **kwargs):
        self.ROOT = ROOT
        self.video = True
        self.is_metric = False
        self.max_interval = 8
        super().__init__(*args, **kwargs)
        # _load_data() fills attributes and returns None.
        self.loaded_data = self._load_data()

    def _load_data(self):
        """Index every scene with enough frames; start ids are stored as ``(scene, id)``."""
        self.scenes = os.listdir(self.ROOT)

        offset = 0
        scenes = []
        sceneids = []
        scene_img_list = []
        images = []
        start_img_ids = []

        # Minimum scene length; does not depend on the scene, so computed once.
        cut_off = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )

        j = 0
        for scene in tqdm(self.scenes):
            scene_dir = osp.join(self.ROOT, scene)
            basenames = sorted(
                [f[:-4] for f in os.listdir(scene_dir) if f.endswith(".npz")],
                key=lambda x: int(x),
            )

            num_imgs = len(basenames)
            # TODO: because current minghui's training data is backward moving, now use seq from -1 to 0
            img_ids = list(np.arange(num_imgs) + offset)
            if num_imgs < cut_off:
                print(f"Skipping {scene}")
                continue
            start_img_ids_ = img_ids[: num_imgs - cut_off + 1]

            start_img_ids.extend([(scene, img_id) for img_id in start_img_ids_])
            sceneids.extend([j] * num_imgs)
            images.extend(basenames)
            scenes.append(scene)
            scene_img_list.append(img_ids)

            # offset groups
            offset += num_imgs
            j += 1

        self.scenes = scenes
        self.sceneids = sceneids
        self.images = images
        self.start_img_ids = start_img_ids
        self.scene_img_list = scene_img_list

        self.invalid_scenes = {scene: False for scene in self.scenes}

    def __len__(self):
        return len(self.start_img_ids)

    def get_image_num(self):
        return len(self.images)

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` views, resampling when a scene fails to load.

        Raises:
            RuntimeError: after 100 failed attempts, or when no valid scene
                can be drawn any more.
        """
        invalid_seq = True
        # Scene name and start image id of the requested sample.
        scene, start_id = self.start_img_ids[idx]

        # Cap the retries so an endless loop cannot stall distributed training.
        max_retries = 100
        retry_count = 0

        while invalid_seq:
            retry_count += 1

            # Give up once the retry budget is exhausted.
            if retry_count > max_retries:
                raise RuntimeError(
                    f"[HabitatHM3D] Failed to get valid views after {max_retries} retries. "
                    f"idx={idx}, scene={scene}, num_views={num_views}. "
                    f"This may indicate insufficient valid frames in the dataset."
                )

            # Emit a warning at the 50th attempt.
            if retry_count == 50:
                print(f"[HabitatHM3D WARNING] Already retried {retry_count} times for idx={idx}, scene={scene}")

            # If the current scene is flagged invalid, draw a new random
            # scene / start image id.
            scene_retry = 0
            while self.invalid_scenes[scene]:
                scene_retry += 1
                if scene_retry > len(self.start_img_ids):
                    raise RuntimeError(
                        f"[HabitatHM3D] All scenes are invalid! Cannot find valid scene after {scene_retry} attempts."
                    )
                idx = rng.integers(low=0, high=len(self.start_img_ids))
                scene, start_id = self.start_img_ids[idx]

            # All image ids of the current scene.
            all_image_ids = self.scene_img_list[self.sceneids[start_id]]
            # Positions of the sampled sequence within `all_image_ids`, plus
            # the ordered-video flag.
            pos, ordered_video = self.get_seq_from_start_id(
                num_views, start_id, all_image_ids, rng, max_interval=self.max_interval
            )
            # Global image ids of the sampled sequence.
            image_idxs = np.array(all_image_ids)[pos]

            views = []
            load_failed = False
            for view_idx in image_idxs:
                scene_id = self.sceneids[view_idx]
                scene_dir = osp.join(self.ROOT, self.scenes[scene_id])

                basename = self.images[view_idx]

                try:
                    # Load RGB image
                    rgb_image = imread_cv2(osp.join(scene_dir, "image_" + basename + ".png"))
                    # Load depthmap
                    depthmap = imread_cv2(
                        osp.join(scene_dir, "depth_" + basename + ".png"), cv2.IMREAD_UNCHANGED
                    )
                    depthmap = depthmap.astype(np.float32) / 1000
                    depthmap[~np.isfinite(depthmap)] = 0  # invalid

                    camera_params = np.load(osp.join(scene_dir, basename + ".npz"))
                    intrinsics = np.float32(camera_params["intrinsics"])
                    camera_pose = np.eye(4, dtype=np.float32)
                    camera_pose[:3, :3] = camera_params["R_cam2world"]
                    camera_pose[:3, 3] = camera_params["t_cam2world"]
                except Exception as e:
                    print(f"[HabitatHM3D] Error loading {scene} {basename}: {e}, skipping scene")
                    self.invalid_scenes[scene] = True
                    load_failed = True
                    break

                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
                )

                views.append(
                    dict(
                        img=rgb_image,
                        depthmap=depthmap.astype(np.float32),
                        camera_pose=camera_pose.astype(np.float32),
                        camera_intrinsics=intrinsics.astype(np.float32),
                        dataset="habitatHM3D",
                        label=self.scenes[scene_id] + "_" + basename,
                        instance=f"{str(idx)}_{str(view_idx)}",
                        is_metric=self.is_metric,
                        is_video=ordered_video,
                        quantile=np.array(0.98, dtype=np.float32),
                        img_mask=True,
                        ray_mask=False,
                        camera_only=True,
                        depth_only=False,
                        single_view=False,
                        reset=False,
                    )
                )

            # Leave the loop only when every view was loaded successfully.
            if not load_failed and len(views) == num_views:
                invalid_seq = False

        return views


# patch header (verbatim): diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/hoi4d.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/hoi4d.py new file mode 100644 index 0000000000000000000000000000000000000000..b602df5d4dd1493d02377039379fd2ffb3b08ba2 --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/hoi4d.py @@ -0,0 +1,84 @@
import os.path as osp
import cv2
import numpy as np
import itertools
import os
import sys
sys.path.append(osp.join(osp.dirname(__file__), '..','..'))
from tqdm import tqdm
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
from dust3r.utils.image import imread_cv2


class HOI4D_Multi(BaseMultiViewDataset):
    """Single-view dataset: every sample is ``num_views`` unrelated images.

    Camera poses are not available; an identity pose is emitted as placeholder
    and every view carries ``single_view=True`` / ``reset=True``.
    """

    def __init__(self, *args, ROOT, **kwargs):
        self.ROOT = ROOT
        self.video = True
        self.is_metric = True
        super().__init__(*args, **kwargs)
        # _load_data() fills attributes and returns None.
        self.loaded_data = self._load_data()

    def _load_data(self):
        """Collect ``(scene, basename)`` for every ``.png`` under ``<scene>/rgb``."""
        scenes = os.listdir(self.ROOT)
        img_names = []
        for scene in scenes:
            scene_dir = osp.join(self.ROOT, scene)
            rgb_dir = osp.join(scene_dir, 'rgb')
            basenames = sorted([f[:-4] for f in os.listdir(rgb_dir) if f.endswith('.png')])
            img_names.extend([(scene, basename) for basename in basenames])

        self.img_names = img_names

    def __len__(self):
        return len(self.img_names)

    def get_image_num(self):
        return len(self.img_names)

    def _get_views(self, idx, resolution, rng, num_views):
        """Draw ``num_views`` random images; redraw the whole set if one fails to load.

        Raises:
            RuntimeError: if no complete set could be loaded within 100 draws
                (same retry policy as ``HabitatHM3D_Multi``), instead of
                looping forever on a broken dataset.
        """
        new_seed = rng.integers(0, 2**32) + idx
        new_rng = np.random.default_rng(new_seed)
        max_retries = 100
        retry_count = 0
        invalid_seq = True
        while invalid_seq:
            retry_count += 1
            if retry_count > max_retries:
                raise RuntimeError(
                    f"[HOI4D] Failed to get valid views after {max_retries} retries. "
                    f"idx={idx}, num_views={num_views}."
                )
            img_names = new_rng.choice(self.img_names, num_views, replace=False)

            views = []
            for v, img_name in enumerate(img_names):
                # Load RGB image
                scene, img_name = img_name
                try:
                    rgb_image = imread_cv2(osp.join(self.ROOT, scene, "rgb", f"{img_name}.png"))
                    depthmap = np.load(osp.join(self.ROOT, scene, "depth", f"{img_name}.npy"))
                    depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0)

                    intrinsics = np.load(osp.join(self.ROOT, scene, "cam", f"{img_name}.npz"))["intrinsics"]
                except Exception as e:
                    # `Exception` (not a bare except) so Ctrl-C / SystemExit
                    # still propagate; the cause is reported for debugging.
                    print(f"Error loading {scene} {img_name}: {e}, skipping")
                    break
                # camera pose is not provided, placeholder
                camera_pose = np.eye(4)

                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics, resolution, rng=rng, info=img_name)

                views.append(dict(
                    img=rgb_image,
                    depthmap=depthmap.astype(np.float32),
                    camera_pose=camera_pose.astype(np.float32),
                    camera_intrinsics=intrinsics.astype(np.float32),
                    dataset='HOI4D',
                    label=img_name,
                    instance=osp.join(self.ROOT, scene, "rgb", f"{img_name}.png"),
                    is_metric=self.is_metric,
                    is_video=False,
                    quantile=np.array(0.99, dtype=np.float32),
                    img_mask=True,
                    ray_mask=False,
                    camera_only=False,
                    depth_only=False,
                    single_view=True,
                    reset=True,
                ))
            if len(views) == num_views:
                invalid_seq = False
        return views


# patch residue (verbatim): diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/hypersim.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/hypersim.py new file mode 100644 index 0000000000000000000000000000000000000000..141c1c95b49923923a87d6b4baf8fe32b00f98e6 --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/hypersim.py @@ -0,0 +1,142 @@ +import os.path as osp +import os +import sys +import itertools + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +import cv2 +import numpy as np + +from
class HyperSim_Multi(BaseMultiViewDataset):
    """Multi-view loader for pre-processed HyperSim sub-scenes.

    Every non-empty ``<ROOT>/<scene>/<subscene>`` directory is one sequence.
    Its ``*.png`` files are the frames; depth and camera data are read from
    sibling files obtained by replacing ``rgb.png`` in the frame name with
    ``depth.npy`` / ``cam.npz``.
    """

    def __init__(self, *args, split, ROOT, **kwargs):
        """Set dataset attributes, initialise the base class and index data.

        ``split`` is accepted but neither stored nor forwarded to the base
        class.
        """
        self.ROOT = ROOT
        self.video = True
        self.is_metric = True
        self.max_interval = 4
        super().__init__(*args, **kwargs)

        # _load_data() has no return value, so loaded_data ends up as None.
        self.loaded_data = self._load_data()
        print('DATA: hypersim', len(self))

    def _load_data(self):
        """Scan ``ROOT`` and build the flat frame / sequence index tables."""
        root = self.ROOT
        self.all_scenes = sorted(
            name for name in os.listdir(root) if os.path.isdir(osp.join(root, name))
        )

        # Keep every sub-directory that contains at least one entry.
        subscenes = []
        for scene in self.all_scenes:
            for sub in os.listdir(osp.join(root, scene)):
                sub_dir = osp.join(root, scene, sub)
                if os.path.isdir(sub_dir) and len(os.listdir(sub_dir)) > 0:
                    subscenes.append(osp.join(scene, sub))

        # Minimum number of frames a sequence needs to be usable.
        min_frames = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )

        kept_scenes = []
        frame_to_scene = []
        frame_files = []
        window_starts = []
        ids_per_scene = []
        next_global_id = 0
        for rel_dir in subscenes:
            abs_dir = osp.join(root, rel_dir)
            pngs = sorted(f for f in os.listdir(abs_dir) if f.endswith(".png"))
            assert len(pngs) > 0, f"{abs_dir} is empty."

            count = len(pngs)
            if count < min_frames:
                print(f"Skipping {rel_dir}")
                continue

            global_ids = list(np.arange(count) + next_global_id)

            # The index of this sequence is its position among kept ones.
            frame_to_scene.extend([len(kept_scenes)] * count)
            kept_scenes.append(rel_dir)
            ids_per_scene.append(global_ids)
            frame_files.extend(pngs)
            # Only ids with at least `min_frames` frames from there on can
            # start a window.
            window_starts.extend(global_ids[: count - min_frames + 1])
            next_global_id += count

        self.scenes = kept_scenes
        self.sceneids = frame_to_scene
        self.images = frame_files
        self.scene_img_list = ids_per_scene
        self.start_img_ids = window_starts

    def __len__(self):
        """Ten sample indices are exposed per valid start frame."""
        return len(self.start_img_ids) * 10

    def get_image_num(self):
        """Total number of indexed frames."""
        return len(self.images)

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` frames of one sequence as view dictionaries.

        ``idx`` is integer-divided by 10 to recover the start-frame slot, the
        inverse of the x10 factor applied in ``__len__``.
        """
        idx = idx // 10
        first_id = self.start_img_ids[idx]
        candidate_ids = self.scene_img_list[self.sceneids[first_id]]
        pos, ordered_video = self.get_seq_from_start_id(
            num_views,
            first_id,
            candidate_ids,
            rng,
            max_interval=self.max_interval,
            block_shuffle=16,
        )
        chosen = np.array(candidate_ids)[pos]

        views = []
        for v, view_idx in enumerate(chosen):
            scene_id = self.sceneids[view_idx]
            scene_name = self.scenes[scene_id]
            scene_dir = osp.join(self.ROOT, scene_name)

            rgb_path = self.images[view_idx]
            depth_file = osp.join(scene_dir, rgb_path.replace("rgb.png", "depth.npy"))
            cam_file_path = osp.join(scene_dir, rgb_path.replace("rgb.png", "cam.npz"))

            rgb_image = imread_pil(osp.join(scene_dir, rgb_path))
            depthmap = np.load(depth_file).astype(np.float32)
            # Non-finite depth values are marked invalid with 0.
            depthmap[~np.isfinite(depthmap)] = 0
            cam_file = np.load(cam_file_path)
            intrinsics = cam_file["intrinsics"].astype(np.float32)
            camera_pose = cam_file["pose"].astype(np.float32)

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
            )

            # Image mask and ray-map mask for this view.
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.75, 0.2, 0.05]
            )

            view = dict(
                img=rgb_image,
                depthmap=depthmap.astype(np.float32),
                camera_pose=camera_pose.astype(np.float32),
                camera_intrinsics=intrinsics.astype(np.float32),
                dataset="hypersim",
                label=scene_name + "_" + rgb_path,
                instance=f"{str(idx)}_{str(view_idx)}",
                is_metric=self.is_metric,
                is_video=ordered_video,
                quantile=np.array(1.0, dtype=np.float32),
                img_mask=img_mask,
                ray_mask=ray_mask,
                camera_only=False,
                depth_only=False,
                single_view=False,
                reset=False,
            )
            views.append(view)

        assert len(views) == num_views
        return views
def _parse_perspective_intrinsics(path: str):
    """Read P_rect_00, R_rect_00 and S_rect_00 from a perspective.txt file.

    Returns:
        ``(P_rect_00, R_rect_00, S_rect_00)`` where P is a (3, 4) float64
        array, R is a (3, 3) float64 array (identity when the file has no
        ``R_rect_00:`` line) and S is an ``(int, int)`` tuple, documented
        upstream as (W, H), or ``None`` when absent.

    Raises:
        RuntimeError: if the file has no ``P_rect_00:`` line.
    """
    # One converter per recognised line prefix; the prefixes are mutually
    # exclusive, so the first match is the only match.
    converters = {
        "P_rect_00:": lambda nums: np.array(nums, dtype=np.float64).reshape(3, 4),
        "R_rect_00:": lambda nums: np.array(nums, dtype=np.float64).reshape(3, 3),
        "S_rect_00:": lambda nums: (int(nums[0]), int(nums[1])),  # (W, H)
    }

    found = {}
    with open(path) as fh:
        for raw in fh:
            entry = raw.strip()
            for tag, convert in converters.items():
                if entry.startswith(tag):
                    numbers = [float(tok) for tok in entry.split()[1:]]
                    found[tag] = convert(numbers)
                    break

    if "P_rect_00:" not in found:
        raise RuntimeError(f"P_rect_00 missing in {path}")

    P_rect = found["P_rect_00:"]
    if "R_rect_00:" in found:
        R_rect = found["R_rect_00:"]
    else:
        R_rect = np.eye(3, dtype=np.float64)
    S_rect = found.get("S_rect_00:")
    return P_rect, R_rect, S_rect
def _project_velo_to_depth_kitti360(velo_pts, P_rect_00, T_velo_to_cam_rect, H, W,
                                    min_depth=0.5, max_depth=80.0):
    """Project a velodyne scan onto the image plane as a sparse depthmap.

    The projection is ``pixel_h = P_rect_00 @ T_velo_to_cam_rect @ velo_h``.
    When several points land on the same pixel, the smallest depth is kept.

    Args:
        velo_pts: (N, >=3) array; only the first three columns (x, y, z) are
            used.
        P_rect_00: (3, 4) projection matrix.
        T_velo_to_cam_rect: (4, 4) transform from the velodyne frame to the
            camera frame in which ``P_rect_00`` applies.
        H, W: output height and width in pixels.
        min_depth: points with camera-frame z or projected depth not greater
            than this are dropped.
        max_depth: points with projected depth not smaller than this are
            dropped.

    Returns:
        (H, W) float32 array holding the depth at hit pixels and -1.0
        everywhere else.
    """
    depthmap = np.full((H, W), -1.0, dtype=np.float32)

    pts_h = np.concatenate(
        [velo_pts[:, :3].astype(np.float64), np.ones((velo_pts.shape[0], 1))],
        axis=1,
    )
    cam = pts_h @ T_velo_to_cam_rect.T  # (N, 4) in the camera frame
    cam = cam[cam[:, 2] > min_depth]
    if cam.shape[0] == 0:
        return depthmap

    uv_h = cam @ P_rect_00.T  # (M, 3) homogeneous pixel coordinates
    z = uv_h[:, 2]
    valid = z > min_depth
    z = z[valid]
    u = uv_h[valid, 0] / z
    v = uv_h[valid, 1] / z

    in_img = (u >= 0) & (u < W) & (v >= 0) & (v < H) & (z < max_depth)
    # u, v are non-negative here, so the integer cast is a floor.
    u = u[in_img].astype(np.int32)
    v = v[in_img].astype(np.int32)
    z = z[in_img]
    if z.size == 0:
        return depthmap

    # Keep the nearest point per pixel. A plain fancy-index assignment with
    # repeated indices has no guaranteed write order in NumPy, so reduce
    # with an unbuffered minimum instead of relying on "last write wins".
    flat_idx = v.astype(np.int64) * W + u
    nearest = np.full(H * W, np.inf, dtype=np.float32)
    np.minimum.at(nearest, flat_idx, z.astype(np.float32))
    hit = np.isfinite(nearest)
    depthmap.reshape(-1)[hit] = nearest[hit]  # reshape(-1) is a view here
    return depthmap
class KITTI360_Multi(BaseMultiViewDataset):
    """KITTI-360 perspective camera ``image_00`` sequences.

    Frames come from ``ROOT/data_2d_raw/<seq>/image_00/data_rect`` and poses
    from ``ROOT/data_poses/<seq>/cam0_to_world.txt``. A view is flagged
    ``camera_only`` unless a velodyne ``.bin`` scan for that frame exists and
    projects at least one positive depth into the image.
    """

    def __init__(self, ROOT, *args, velodyne_root=None, **kwargs):
        """Store roots and defaults, initialise the base class, index data.

        Args:
            ROOT: KITTI-360 root directory.
            velodyne_root: root holding ``data_3d_raw``; falls back to
                ``ROOT`` when falsy.
        """
        self.ROOT = ROOT
        self.velodyne_root = velodyne_root or ROOT
        self.video = True
        self.is_metric = True
        self.max_interval = 4
        super().__init__(*args, **kwargs)
        self._load_data(self.split)

    @staticmethod
    def _frames_with_pose(img_dir, poses_dict):
        """Return sorted frame ids that have a pose and a non-empty PNG."""
        frame_ids = []
        for fname in os.listdir(img_dir):
            if not fname.endswith(".png"):
                continue
            try:
                fid = int(osp.splitext(fname)[0])
            except ValueError:
                continue
            if fid not in poses_dict:
                continue
            # Zero-byte files (e.g. from an aborted download) and files whose
            # size cannot be read are ignored.
            try:
                size = osp.getsize(osp.join(img_dir, fname))
            except OSError:
                continue
            if size <= 0:
                continue
            frame_ids.append(fid)
        frame_ids.sort()
        return frame_ids

    def _load_data(self, split=None):
        """Read the shared calibration and index every usable sequence."""
        calib_dir = osp.join(self.ROOT, "calibration")
        P_rect, R_rect, _ = _parse_perspective_intrinsics(
            osp.join(calib_dir, "perspective.txt")
        )
        self.P_rect_00 = P_rect.copy()
        self.K = P_rect[:, :3].copy().astype(np.float32)

        # Lidar projection needs velodyne -> rectified camera: the inverse of
        # the cam->velo calibration followed by the rectifying rotation.
        cam_to_velo_path = osp.join(calib_dir, "calib_cam_to_velo.txt")
        if not osp.isfile(cam_to_velo_path):
            self.T_velo_to_cam_rect = None  # lidar disabled
        else:
            velo_to_cam = np.linalg.inv(_parse_cam_to_velo(cam_to_velo_path))
            rect_h = np.eye(4, dtype=np.float64)
            rect_h[:3, :3] = R_rect
            self.T_velo_to_cam_rect = rect_h @ velo_to_cam

        # Any split other than "train" selects the test sequences.
        wanted = TRAIN_SEQS if split == "train" else TEST_SEQS

        scenes = []
        seq_poses = []
        seq_frame_ids = []
        seq_velo_dir = []
        scene_img_list = []
        sceneids = []
        start_img_ids = []
        next_global_id = 0

        for seq in wanted:
            img_dir = osp.join(self.ROOT, "data_2d_raw", seq, "image_00", "data_rect")
            pose_path = osp.join(self.ROOT, "data_poses", seq, "cam0_to_world.txt")
            if not (osp.isdir(img_dir) and osp.isfile(pose_path)):
                continue

            poses_dict = _load_kitti360_poses(pose_path)
            avail = self._frames_with_pose(img_dir, poses_dict)
            if not avail:
                continue

            poses = np.stack([poses_dict[f] for f in avail], axis=0)
            n_imgs = len(avail)
            if self.allow_repeat:
                cut_off = max(self.num_views // 3, 3)
            else:
                cut_off = self.num_views
            if n_imgs < cut_off:
                continue

            # Per-sequence velodyne directory, or None when it is absent.
            velo_dir = osp.join(
                self.velodyne_root, "data_3d_raw", seq, "velodyne_points", "data"
            )
            if not osp.isdir(velo_dir):
                velo_dir = None

            global_ids = list(np.arange(n_imgs) + next_global_id)

            sceneids.extend([len(scenes)] * n_imgs)
            scenes.append(seq)
            seq_poses.append(poses)
            seq_frame_ids.append(np.asarray(avail, dtype=np.int64))
            seq_velo_dir.append(velo_dir)
            scene_img_list.append(global_ids)
            start_img_ids.extend(global_ids[: n_imgs - cut_off + 1])
            next_global_id += n_imgs

        self.scenes = scenes
        self.seq_poses = seq_poses
        self.seq_frame_ids = seq_frame_ids
        self.seq_velo_dir = seq_velo_dir
        self.scene_img_list = scene_img_list
        self.sceneids = sceneids
        self.start_img_ids = start_img_ids

    def __len__(self):
        """One sample per valid start frame."""
        return len(self.start_img_ids)

    def get_image_num(self):
        """Total number of indexed frames over all sequences."""
        return sum(len(p) for p in self.seq_poses)

    def get_stats(self):
        """Short human-readable summary of the index."""
        return f"{len(self)} groups across {len(self.scenes)} KITTI-360 sequences"

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` frames of one sequence as view dictionaries."""
        start_id = self.start_img_ids[idx]
        scene_id = self.sceneids[start_id]
        all_image_ids = self.scene_img_list[scene_id]
        sid = self.scenes[scene_id]
        img_dir = osp.join(self.ROOT, "data_2d_raw", sid, "image_00", "data_rect")
        frame_ids = self.seq_frame_ids[scene_id]
        poses = self.seq_poses[scene_id]
        velo_dir = self.seq_velo_dir[scene_id]
        has_lidar = velo_dir is not None and self.T_velo_to_cam_rect is not None

        pos, ordered_video = self.get_seq_from_start_id(
            num_views,
            start_id,
            all_image_ids,
            rng,
            max_interval=self.max_interval,
            video_prob=1.0,
            fix_interval_prob=0.9,
        )
        # The returned positions are used directly as per-sequence indices.
        local_idxs = np.asarray(pos, dtype=int)

        views = []
        for v, raw_lid in enumerate(local_idxs):
            lid = int(raw_lid)
            fid = int(frame_ids[lid])
            img_path = osp.join(img_dir, f"{fid:010d}.png")
            image = imread_cv2(img_path)
            H, W = image.shape[:2]

            # Sparse depth from lidar when a scan for this frame exists.
            depthmap = None
            if has_lidar:
                bin_path = osp.join(velo_dir, f"{fid:010d}.bin")
                if osp.isfile(bin_path):
                    depthmap = _project_velo_to_depth_kitti360(
                        _load_velodyne_bin(bin_path),
                        self.P_rect_00,
                        self.T_velo_to_cam_rect,
                        H,
                        W,
                    )
            if depthmap is None:
                depthmap = np.full((H, W), -1.0, dtype=np.float32)
                frame_has_lidar = False
            else:
                frame_has_lidar = bool((depthmap > 0).any())

            intrinsics = self.K.copy()
            camera_pose = poses[lid].astype(np.float32)

            image, depthmap, intrinsics = self._crop_resize_if_necessary(
                image, depthmap, intrinsics, resolution, rng, info=(img_dir, img_path)
            )

            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.85, 0.1, 0.05]
            )

            view = dict(
                img=image,
                depthmap=depthmap,
                camera_pose=camera_pose,
                camera_intrinsics=intrinsics,
                dataset="KITTI360",
                label=img_dir,
                is_metric=self.is_metric,
                instance=f"{sid}/image_00/{fid:010d}.png",
                is_video=ordered_video,
                quantile=np.array(1.0, dtype=np.float32),
                img_mask=img_mask,
                ray_mask=ray_mask,
                camera_only=not frame_has_lidar,
                depth_only=False,
                single_view=False,
                reset=False,
                scene_tag=f"kitti360/{sid}",
            )
            views.append(view)

        assert len(views) == num_views
        return views
class MapFree_Multi(BaseMultiViewDataset):
    """Multi-view loader for the pre-processed MapFree dataset.

    Each scene directory holds ``dense<k>`` video sequences (``rgb`` frames
    plus sibling ``depth`` / ``cam`` / ``sky_mask`` files) and a
    ``metadata.pkl`` mapping a reference image to other images, each with a
    score. A sample is drawn either as a window from the reference image's
    video sequence or as a score-weighted collection from its group.
    """

    def __init__(self, ROOT, *args, **kwargs):
        """Store the root, initialise the base class and load the index."""
        self.ROOT = ROOT
        self.video = True
        self.is_metric = True
        self.max_interval = 30
        super().__init__(*args, **kwargs)

        self._load_data()

    def imgid2path(self, img_id, scene):
        """Map a ``(sequence_id, frame_id)`` pair to its RGB image path."""
        first_seq_id, first_frame_id = img_id
        return os.path.join(
            self.ROOT,
            scene,
            f"dense{first_seq_id}",
            "rgb",
            f"frame_{first_frame_id:05d}.jpg",
        )

    def path2imgid(self, subscene, filename):
        """Inverse of ``imgid2path`` for ``dense<k>`` / ``frame_<n>.jpg``."""
        first_seq_id = int(subscene[5:])  # strip the "dense" prefix
        first_frame_id = int(filename[6:-4])  # strip "frame_" and ".jpg"
        return [first_seq_id, first_frame_id]

    def _load_data(self):
        """Load the index from the HDF5 cache, or build it and write the cache.

        The cache lives at ``<ROOT>/cached_metadata_50_col_only.h5``. Building
        it walks every scene directory and unpickles its ``metadata.pkl``.
        """
        cache_file = f"{self.ROOT}/cached_metadata_50_col_only.h5"
        if os.path.exists(cache_file):
            print(f"Loading cached metadata from {cache_file}")
            with h5py.File(cache_file, "r") as hf:
                self.scenes = list(map(lambda x: x.decode("utf-8"), hf["scenes"][:]))
                self.sceneids = hf["sceneids"][:]
                self.scope = hf["scope"][:]
                self.video_flags = hf["video_flags"][:]
                self.groups = hf["groups"][:]
                self.id_ranges = hf["id_ranges"][:]
                self.images = hf["images"][:]
            return

        scene_dirs = sorted(
            d for d in os.listdir(self.ROOT) if os.path.isdir(os.path.join(self.ROOT, d))
        )
        scenes = []
        sceneids = []
        groups = []
        scope = []
        images = []
        id_ranges = []
        is_video = []
        start = 0  # running row offset into the concatenated `groups`
        offset = 0  # running index into the flat `images` list

        for j, scene in enumerate(tqdm(scene_dirs)):
            scenes.append(scene)

            # Video sequences: record each sequence's [begin, end) slice of
            # the flat image list, keyed by the number in its directory name.
            # Keying by number rather than by sorted list position keeps the
            # lookup below correct for non-contiguous ids and for names that
            # sort lexicographically, such as "dense10" before "dense2".
            subscenes = sorted(
                d
                for d in os.listdir(os.path.join(self.ROOT, scene))
                if d.startswith("dense")
            )
            id_range_by_seq = {}
            for subscene in subscenes:
                rgb_paths = sorted(
                    d
                    for d in os.listdir(os.path.join(self.ROOT, scene, subscene, "rgb"))
                    if d.endswith(".jpg")
                )
                assert (
                    len(rgb_paths) > 0
                ), f"{os.path.join(self.ROOT, scene, subscene)} is empty."
                num_imgs = len(rgb_paths)
                images.extend(
                    [self.path2imgid(subscene, rgb_path) for rgb_path in rgb_paths]
                )
                id_range_by_seq[int(subscene[5:])] = (offset, offset + num_imgs)
                offset += num_imgs

            # Image collections. NOTE: pickle executes code on load; only use
            # metadata files from a trusted preprocessing run.
            with open(os.path.join(self.ROOT, scene, "metadata.pkl"), "rb") as fh:
                metadata = pickle.load(fh)

            img_groups = []
            for ref_img in metadata.keys():
                other_imgs = metadata[ref_img]
                if len(other_imgs) + 1 < self.num_views:
                    continue
                # Rows are (seq_id, frame_id, score); the reference image
                # goes first with score 1.
                group = [(*other_img[0], other_img[1]) for other_img in other_imgs]
                group.insert(0, (*ref_img, 1))
                img_groups.append(np.array(group))
                id_ranges.append(id_range_by_seq[ref_img[0]])
                scope.append(start)
                start = start + len(group)

            num_groups = len(img_groups)
            sceneids.extend([j] * num_groups)
            groups.extend(img_groups)
            is_video.extend([False] * num_groups)

        self.scenes = np.array(scenes)
        self.sceneids = np.array(sceneids)
        self.scope = np.array(scope)
        self.video_flags = np.array(is_video)
        self.groups = np.concatenate(groups, 0)
        self.id_ranges = np.array(id_ranges)
        self.images = np.array(images)

        with h5py.File(cache_file, "w") as h5f:
            h5f.create_dataset(
                "scenes",
                data=self.scenes.astype(object),
                dtype=h5py.string_dtype(encoding="utf-8"),
                compression="lzf",
                chunks=True,
            )
            for key in ("sceneids", "scope", "video_flags", "groups", "id_ranges", "images"):
                h5f.create_dataset(
                    key, data=getattr(self, key), compression="lzf", chunks=True
                )

    def __len__(self):
        """One sample per image group."""
        return len(self.scope)

    def get_image_num(self):
        """Total number of indexed video frames."""
        return len(self.images)

    def get_stats(self):
        """Short human-readable summary of the index."""
        return f"{len(self)} groups of views"

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` views for group ``idx``.

        With probability 0.6 the views are a window of the reference image's
        video sequence; otherwise they are drawn from the group with
        probability proportional to the stored score.
        """
        scene = self.scenes[self.sceneids[idx]]
        if rng.random() < 0.6:
            ids = np.arange(self.id_ranges[idx][0], self.id_ranges[idx][1])
            cut_off = num_views if not self.allow_repeat else max(num_views // 3, 3)
            start_ids = ids[: len(ids) - cut_off + 1]
            start_id = rng.choice(start_ids)
            pos, ordered_video = self.get_seq_from_start_id(
                num_views,
                start_id,
                ids.tolist(),
                rng,
                max_interval=self.max_interval,
                video_prob=0.8,
                fix_interval_prob=0.5,
                block_shuffle=16,
            )
            ids = np.array(ids)[pos]
            image_idxs = self.images[ids]
        else:
            ordered_video = False
            seq_start_index = self.scope[idx]
            # The last group runs to the end of `groups`.
            seq_end_index = self.scope[idx + 1] if idx < len(self.scope) - 1 else None
            rows = self.groups[seq_start_index:seq_end_index]
            image_idxs, overlap_scores = rows[:, :2], rows[:, 2]
            # Sampling without replacement needs at least num_views entries
            # with non-zero probability.
            replace = bool(
                self.allow_repeat
                or len(overlap_scores[overlap_scores > 0]) < num_views
            )
            image_idxs = rng.choice(
                image_idxs,
                num_views,
                replace=replace,
                p=overlap_scores / np.sum(overlap_scores),
            )
            image_idxs = image_idxs.astype(np.int64)

        views = []
        for v, view_idx in enumerate(image_idxs):
            img_path = self.imgid2path(view_idx, scene)
            depth_path = img_path.replace("rgb", "depth").replace(".jpg", ".npy")
            cam_path = img_path.replace("rgb", "cam").replace(".jpg", ".npz")
            sky_mask_path = img_path.replace("rgb", "sky_mask")
            image = imread_cv2(img_path)
            depthmap = np.load(depth_path)
            camera_params = np.load(cam_path)
            sky_mask = cv2.imread(sky_mask_path, cv2.IMREAD_UNCHANGED) >= 127

            intrinsics = camera_params["intrinsic"].astype(np.float32)
            camera_pose = camera_params["pose"].astype(np.float32)

            # Sky pixels get -1; far (> 400) and non-finite depths get 0.
            depthmap[sky_mask] = -1.0
            depthmap[depthmap > 400.0] = 0.0
            depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0)
            # Drop the top 2% of the remaining positive depths.
            positive = depthmap[depthmap > 0]
            threshold = np.percentile(positive, 98) if positive.size > 0 else 0
            depthmap[depthmap > threshold] = 0.0

            image, depthmap, intrinsics = self._crop_resize_if_necessary(
                image, depthmap, intrinsics, resolution, rng, info=(img_path)
            )
            # Image mask and ray-map mask for this view.
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.75, 0.2, 0.05]
            )

            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    dataset="MapFree",
                    label=img_path,
                    is_metric=self.is_metric,
                    instance=img_path,
                    is_video=ordered_video,
                    quantile=np.array(0.96, dtype=np.float32),
                    img_mask=img_mask,
                    ray_mask=ray_mask,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )
        assert len(views) == num_views
        return views
def select_scene(self, scene, *instances, opposite=False):
    """Restrict ``self.sets`` to view sets of the named scene(s).

    Args:
        scene: a scene-name prefix, or an iterable of prefixes; a scene is
            selected when its entry in ``self.all_scenes`` starts with any
            of them.
        *instances: not supported; passing any raises NotImplementedError.
        opposite: when True, keep the sets that do NOT belong to the
            selected scenes instead.

    Raises:
        AssertionError: if no scene matches, or if the selection is empty.
        NotImplementedError: if ``instances`` is non-empty.
    """
    scenes = (scene,) if isinstance(scene, str) else tuple(scene)
    scene_id = [s.startswith(scenes) for s in self.all_scenes]
    assert any(scene_id), "no scene found"
    # Column 0 of each set row is the scene index. np.isin replaces the
    # deprecated np.in1d and gives the same result on this 1-D column.
    valid = np.isin(self.sets[:, 0], np.nonzero(scene_id)[0])
    if instances:
        raise NotImplementedError("selecting instances not implemented")
    if opposite:
        valid = ~valid
    assert valid.any()
    self.sets = self.sets[valid]
def _load_data(self):
    """Index every scene under ``ROOT`` into flat frame / sequence tables.

    Fills ``self.scenes`` (kept scene names), ``self.images`` (frame
    basenames, all scenes concatenated), ``self.sceneids`` (scene index per
    frame), ``self.scene_img_list`` (global frame ids per scene) and
    ``self.start_img_ids`` (global ids that may start a window). Scenes with
    fewer frames than the required window length are skipped.
    """
    # os.listdir() returns entries in arbitrary order; sort so the
    # index -> sample mapping is the same on every machine and run, and keep
    # directories only so a stray file in ROOT cannot break the rgb listing.
    self.scenes = sorted(
        d for d in os.listdir(self.ROOT) if osp.isdir(osp.join(self.ROOT, d))
    )

    # Minimum number of frames a scene needs (loop invariant).
    cut_off = (
        self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
    )

    offset = 0
    scenes = []
    sceneids = []
    scene_img_list = []
    images = []
    start_img_ids = []

    for scene in tqdm(self.scenes):
        rgb_dir = osp.join(self.ROOT, scene, "rgb")
        # Strip the 4-character ".jpg" suffix to get the frame basename.
        basenames = sorted(
            f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".jpg")
        )
        num_imgs = len(basenames)

        if num_imgs < cut_off:
            print(f"Skipping {scene}")
            continue

        img_ids = list(np.arange(num_imgs) + offset)
        # Only ids with at least `cut_off` frames from there on can start a
        # window.
        start_img_ids.extend(img_ids[: num_imgs - cut_off + 1])
        # The scene index is the position among kept scenes.
        sceneids.extend([len(scenes)] * num_imgs)
        images.extend(basenames)
        scenes.append(scene)
        scene_img_list.append(img_ids)

        offset += num_imgs

    self.scenes = scenes
    self.sceneids = sceneids
    self.images = images
    self.start_img_ids = start_img_ids
    self.scene_img_list = scene_img_list
def extract_number(filename):
    """Return the first run of digits in ``filename`` as an int, or 0 if none."""
    digits = re.search(r"\d+", filename)
    return 0 if digits is None else int(digits.group())
def _get_views(self, idx, resolution, rng, num_views):
    """Load ``num_views`` frames of one object scene as view dictionaries.

    Depth and the translation part of the camera pose are both divided by
    the scene's entry in ``self.scales`` and then by 1000.
    """
    # Entries of start_img_ids are (scene, id) pairs; only the id is used.
    _, first_id = self.start_img_ids[idx]
    candidate_ids = self.scene_img_list[self.sceneids[first_id]]
    pos, ordered_video = self.get_seq_from_start_id(
        num_views, first_id, candidate_ids, rng, max_interval=100, video_prob=0.0
    )
    chosen = np.array(candidate_ids)[pos]

    views = []
    for v, view_idx in enumerate(chosen):
        scene_id = self.sceneids[view_idx]
        scene_name = self.scenes[scene_id]
        scene_dir = osp.join(self.ROOT, scene_name)
        basename = self.images[view_idx]

        rgb_image = imread_cv2(osp.join(scene_dir, "rgb", basename + ".png"))
        depthmap = np.load(osp.join(scene_dir, "depth", basename + ".npy"))
        cam = np.load(osp.join(scene_dir, "cam", basename + ".npz"))
        camera_pose = cam["pose"]
        intrinsics = cam["intrinsics"]

        # Apply the same rescaling to depth and to the pose translation.
        scale = self.scales[scene_name]
        depthmap = depthmap / scale / 1000.0
        camera_pose[:3, 3] = camera_pose[:3, 3] / scale / 1000.0

        rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
            rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
        )

        img_mask, ray_mask = self.get_img_and_ray_masks(
            self.is_metric, v, rng, p=[0.8, 0.15, 0.05]
        )

        view = dict(
            img=rgb_image,
            depthmap=depthmap.astype(np.float32),
            camera_pose=camera_pose.astype(np.float32),
            camera_intrinsics=intrinsics.astype(np.float32),
            dataset="OmniObject3D",
            label=scene_name + "_" + basename,
            instance=f"{str(idx)}_{str(view_idx)}",
            is_metric=self.is_metric,
            is_video=ordered_video,
            quantile=np.array(1.0, dtype=np.float32),
            img_mask=img_mask,
            ray_mask=ray_mask,
            camera_only=False,
            depth_only=False,
            single_view=False,
            reset=False,
        )
        views.append(view)

    assert len(views) == num_views
    return views
'cab_h_bench_3rd', 'cab_h_bench_ego1', 'cab_h_bench_ego2', + "cnb_dlab_0215_3rd", + "cnb_dlab_0215_ego1", + "cnb_dlab_0225_3rd", + "cnb_dlab_0225_ego1", + "dancing", + "dancingroom0_3rd", + "footlab_3rd", + "footlab_ego1", + "footlab_ego2", + "girl", + "girl_egocentric", + "human_egocentric", + "human_in_scene", + "human_in_scene1", + "kg", + "kg_ego1", + "kg_ego2", + "kitchen_gfloor", + "kitchen_gfloor_ego1", + "kitchen_gfloor_ego2", + "scene_carb_h_tables", + "scene_carb_h_tables_ego1", + "scene_carb_h_tables_ego2", + "scene_j716_3rd", + "scene_j716_ego1", + "scene_j716_ego2", + "scene_recording_20210910_S05_S06_0_3rd", + "scene_recording_20210910_S05_S06_0_ego2", + "scene1_0129", + "scene1_0129_ego", + "seminar_h52_3rd", + "seminar_h52_ego1", + "seminar_h52_ego2", + ] + self.loaded_data = self._load_data(self.split) + + def _load_data(self, split): + root = os.path.join(self.ROOT, split) + self.scenes = [] + + offset = 0 + scenes = [] + sceneids = [] + scene_img_list = [] + images = [] + start_img_ids = [] + + j = 0 + for scene in tqdm(os.listdir(root)): + if scene not in self.scenes_to_use: + continue + scene_dir = osp.join(root, scene) + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".jpg")] + ) + num_imgs = len(basenames) + img_ids = list(np.arange(num_imgs) + offset) + cut_off = ( + self.num_views if not self.allow_repeat else max(self.num_views // 3, 3) + ) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + # start_img_ids_ = img_ids[:-self.num_views+1] + + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + + start_img_ids.extend(start_img_ids_) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + def 
__len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=self.max_interval, + video_prob=1.0, + fix_interval_prob=1.0, + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.split, self.scenes[scene_id]) + rgb_dir = osp.join(scene_dir, "rgb") + depth_dir = osp.join(scene_dir, "depth") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".jpg")) + # Load depthmap + depthmap = np.load(osp.join(depth_dir, basename + ".npy")) + depthmap[~np.isfinite(depthmap)] = 0 # invalid + depthmap[depthmap > 1000] = 0.0 + + cam = np.load(osp.join(cam_dir, basename + ".npz")) + camera_pose = cam["pose"] + intrinsics = cam["intrinsics"] + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.9, 0.05, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="PointOdyssey", + label=self.scenes[scene_id] + "_" + basename, + instance=osp.join(rgb_dir, basename + ".jpg"), + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(1.0, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert 
len(views) == num_views + return views diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/realestate10k.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/realestate10k.py new file mode 100644 index 0000000000000000000000000000000000000000..34526946529905640be4ee49d0530b950bafdb04 --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/realestate10k.py @@ -0,0 +1,139 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2 + + +class RE10K_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = False + self.max_interval = 128 + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data() + + def _load_data(self): + self.scenes = os.listdir(self.ROOT) + + offset = 0 + scenes = [] + sceneids = [] + scene_img_list = [] + images = [] + start_img_ids = [] + + j = 0 + for scene in tqdm(self.scenes): + scene_dir = osp.join(self.ROOT, scene) + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")], + key=lambda x: int(x), + ) + + num_imgs = len(basenames) + img_ids = list(np.arange(num_imgs) + offset) + cut_off = ( + self.num_views if not self.allow_repeat else max(self.num_views // 3, 3) + ) + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + start_img_ids.extend([(scene, id) for id in start_img_ids_]) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids 
= start_img_ids + self.scene_img_list = scene_img_list + + self.invalid_scenes = {scene: False for scene in self.scenes} + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + invalid_seq = True + scene, start_id = self.start_img_ids[idx] + + while invalid_seq: + while self.invalid_scenes[scene]: + idx = rng.integers(low=0, high=len(self.start_img_ids)) + scene, start_id = self.start_img_ids[idx] + + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, start_id, all_image_ids, rng, max_interval=self.max_interval + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + for view_idx in image_idxs: + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + rgb_dir = osp.join(scene_dir, "rgb") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + try: + # Load RGB image + rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".png")) + # Load depthmap, no depth, set to all ones + depthmap = np.ones_like(rgb_image[..., 0], dtype=np.float32) + cam = np.load(osp.join(cam_dir, basename + ".npz")) + intrinsics = cam["intrinsics"] + camera_pose = cam["pose"] + except: + print(f"Error loading {scene} {basename}, skipping") + self.invalid_scenes[scene] = True + break + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="realestate10k", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(0.98, dtype=np.float32), + img_mask=True, + 
ray_mask=False, + camera_only=True, + depth_only=False, + single_view=False, + reset=False, + ) + ) + if len(views) == num_views: + invalid_seq = False + return views diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/scannet.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/scannet.py new file mode 100644 index 0000000000000000000000000000000000000000..0b6644615d2e9761a2a3cec8178a22be5f316afa --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/scannet.py @@ -0,0 +1,149 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2, imread_pil + + +class ScanNet_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + self.max_interval = 30 + super().__init__(*args, **kwargs) + + self.loaded_data = self._load_data(self.split) + print('DATA: scannet', len(self)) + + def _load_data(self, split): + self.scene_root = osp.join( + self.ROOT, "scans_train" if split == "train" else "scans_test" + ) + self.scenes = [ + scene for scene in os.listdir(self.scene_root) if scene.startswith("scene") + ] + + offset = 0 + scenes = [] + sceneids = [] + scene_img_list = [] + images = [] + start_img_ids = [] + + j = 0 + for scene in tqdm(self.scenes): + scene_dir = osp.join(self.scene_root, scene) + with np.load( + osp.join(scene_dir, "new_scene_metadata.npz"), allow_pickle=True + ) as data: + basenames = data["images"] + num_imgs = len(basenames) + img_ids = list(np.arange(num_imgs) + offset) + cut_off = ( + self.num_views + if not self.allow_repeat + else max(self.num_views // 3, 3) + ) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + + 
start_img_ids.extend(start_img_ids_) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=self.max_interval, + video_prob=0.6, + fix_interval_prob=0.6, + block_shuffle=16, + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.scene_root, self.scenes[scene_id]) + rgb_dir = osp.join(scene_dir, "color") + depth_dir = osp.join(scene_dir, "depth") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_pil(osp.join(rgb_dir, basename + ".jpg")) + # Load depthmap + depthmap = imread_cv2( + osp.join(depth_dir, basename + ".png"), cv2.IMREAD_UNCHANGED + ) + depthmap = depthmap.astype(np.float32) / 1000 + depthmap[~np.isfinite(depthmap)] = 0 # invalid + + cam = np.load(osp.join(cam_dir, basename + ".npz")) + camera_pose = cam["pose"] + intrinsics = cam["intrinsics"] + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.75, 0.2, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + 
camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="ScanNet", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(0.98, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/scannetpp.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/scannetpp.py new file mode 100644 index 0000000000000000000000000000000000000000..4cca5c0ade3ccf79f97f31e5f30a823e032152c6 --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/scannetpp.py @@ -0,0 +1,211 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) + +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2, imread_pil + + +class ScanNetpp_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + self.max_interval = 3 + super().__init__(*args, **kwargs) + assert self.split == "train" + self.loaded_data = self._load_data() + + def _load_data(self): + with np.load(osp.join(self.ROOT, "all_metadata.npz")) as data: + self.scenes = data["scenes"] + offset = 0 + scenes = [] + sceneids = [] + images = [] + intrinsics = [] + trajectories = [] + groups = [] + id_ranges = [] + j = 0 + self.image_num = 0 + for scene in self.scenes: + scene_dir = osp.join(self.ROOT, scene) + with np.load( + osp.join(scene_dir, "new_scene_metadata.npz"), allow_pickle=True + ) as data: + imgs = data["images"] + self.image_num += len(imgs) + img_ids = np.arange(len(imgs)).tolist() + intrins = 
data["intrinsics"] + traj = data["trajectories"] + imgs_on_disk = sorted(os.listdir(osp.join(scene_dir, "images"))) + imgs_on_disk = list(map(lambda x: x[:-4], imgs_on_disk)) + + dslr_ids = [ + i + offset + for i in img_ids + if imgs[i].startswith("DSC") and imgs[i] in imgs_on_disk + ] + iphone_ids = [ + i + offset + for i in img_ids + if imgs[i].startswith("frame") and imgs[i] in imgs_on_disk + ] + + num_imgs = len(imgs) + assert max(dslr_ids) < min(iphone_ids) + assert "image_collection" in data + + img_groups = [] + img_id_ranges = [] + + # 使用与其他数据集一致的 cut_off 逻辑 + min_group_len = ( + self.num_views + if not self.allow_repeat + else max(self.num_views // 3, 3) + ) + + for ref_id, group in data["image_collection"].item().items(): + if len(group) + 1 < min_group_len: + continue + group.insert(0, (ref_id, 1.0)) + sorted_group = sorted(group, key=lambda x: x[1], reverse=True) + group = [int(x[0] + offset) for x in sorted_group] + + # 确定对应的视频帧列表 + if imgs[ref_id].startswith("frame"): + video_ids = dslr_ids + else: + video_ids = iphone_ids + + # 只有当视频帧列表足够长时才添加 + if len(video_ids) >= min_group_len: + img_groups.append(sorted(group)) + img_id_ranges.append(video_ids) + + if len(img_groups) == 0: + print(f"Skipping {scene}") + continue + scenes.append(scene) + sceneids.extend([j] * num_imgs) + images.extend(imgs) + intrinsics.append(intrins) + trajectories.append(traj) + + # offset groups + groups.extend(img_groups) + id_ranges.extend(img_id_ranges) + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.intrinsics = np.concatenate(intrinsics, axis=0) + self.trajectories = np.concatenate(trajectories, axis=0) + self.id_ranges = id_ranges + self.groups = groups + + def __len__(self): + return len(self.groups) * 10 + + def get_image_num(self): + return self.image_num + + def _get_views(self, idx, resolution, rng, num_views): + idx = idx // 10 + image_idxs = self.groups[idx] + rand_val = rng.random() + + 
image_idxs_video = self.id_ranges[idx] + cut_off = num_views if not self.allow_repeat else max(num_views // 3, 3) + start_image_idxs = image_idxs_video[: len(image_idxs_video) - cut_off + 1] + + if rand_val < 0.7 and len(start_image_idxs) > 0: + start_id = rng.choice(start_image_idxs) + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + image_idxs_video, + rng, + max_interval=self.max_interval, + video_prob=0.8, + fix_interval_prob=0.5, + block_shuffle=16, + ) + image_idxs = np.array(image_idxs_video)[pos] + + else: + ordered_video = True + # ordered video with varying intervals + num_candidates = len(image_idxs) + max_id = min(num_candidates, int(num_views * (2 + 2 * rng.random()))) + + # 确保有足够的候选帧 + if num_candidates < num_views: + # 如果候选帧不足,使用重复采样 + image_idxs = sorted(rng.choice(image_idxs, size=num_views, replace=True)) + else: + image_idxs = sorted(rng.permutation(image_idxs[:max_id])[:num_views]) + + if rand_val > 0.75: + ordered_video = False + image_idxs = rng.permutation(image_idxs) + + + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + + intrinsics = self.intrinsics[view_idx] + camera_pose = self.trajectories[view_idx] + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_pil(osp.join(scene_dir, "images", basename + ".jpg")) + # Load depthmap + depthmap = imread_cv2( + osp.join(scene_dir, "depth", basename + ".png"), cv2.IMREAD_UNCHANGED + ) + depthmap = depthmap.astype(np.float32) / 1000 + depthmap[~np.isfinite(depthmap)] = 0 # invalid + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.75, 0.2, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + 
camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="ScanNet++", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(0.99, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/smartportraits.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/smartportraits.py new file mode 100644 index 0000000000000000000000000000000000000000..a5955aecd651f2bf1f6a666b0869b5d97816cf5f --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/smartportraits.py @@ -0,0 +1,85 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2 + + +class SmartPortraits_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data() + + def _load_data(self): + scenes = os.listdir(self.ROOT) + img_names = [] + for scene in scenes: + scene_dir = osp.join(self.ROOT, scene) + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")] + ) + img_names.extend([(scene, basename) for basename in basenames]) + + self.img_names = img_names + + def __len__(self): + return len(self.img_names) + + def get_image_num(self): + return len(self.img_names) + + def _get_views(self, idx, resolution, rng, num_views): + new_seed = rng.integers(0, 2**32) + idx + new_rng = 
np.random.default_rng(new_seed) + img_names = new_rng.choice(self.img_names, num_views, replace=False) + + views = [] + for v, img_name in enumerate(img_names): + # Load RGB image + scene, img_name = img_name + rgb_image = imread_cv2(osp.join(self.ROOT, scene, "rgb", f"{img_name}.png")) + depthmap = np.load(osp.join(self.ROOT, scene, "depth", f"{img_name}.npy")) + depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0) + + intrinsics = np.load(osp.join(self.ROOT, scene, "cam", f"{img_name}.npz"))[ + "intrinsics" + ] + # camera pose is not provided, placeholder + camera_pose = np.eye(4) + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=img_name + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="SmartPortraits", + label=img_name, + instance=osp.join(self.ROOT, scene, "rgb", f"{img_name}.png"), + is_metric=self.is_metric, + is_video=False, + quantile=np.array(0.98, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=False, + depth_only=False, + single_view=True, + reset=True, + ) + ) + assert len(views) == num_views + return views diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/tartanair.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/tartanair.py new file mode 100644 index 0000000000000000000000000000000000000000..760d0e9d6921bb31354fbe505821b550d301f83a --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/tartanair.py @@ -0,0 +1,164 @@ +import os.path as osp +import numpy as np +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) + +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2 + + +class 
TartanAir_Multi(BaseMultiViewDataset): + + def __init__(self, ROOT, *args, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + self.max_interval = 20 + super().__init__(*args, **kwargs) + # loading all + assert self.split is None + self._load_data() + + def _load_data(self): + scene_dirs = sorted( + [ + d + for d in os.listdir(self.ROOT) + if os.path.isdir(os.path.join(self.ROOT, d)) + ] + ) + + offset = 0 + scenes = [] + sceneids = [] + images = [] + scene_img_list = [] + start_img_ids = [] + j = 0 + + for scene in scene_dirs: + for mode in ["Easy", "Hard"]: + seq_dirs = sorted( + [ + os.path.join(self.ROOT, scene, mode, d) + for d in os.listdir(os.path.join(self.ROOT, scene, mode)) + if os.path.isdir(os.path.join(self.ROOT, scene, mode, d)) + ] + ) + for seq_dir in seq_dirs: + basenames = sorted( + [f[:-8] for f in os.listdir(seq_dir) if f.endswith(".png")] + ) + num_imgs = len(basenames) + cut_off = ( + self.num_views + if not self.allow_repeat + else max(self.num_views // 3, 3) + ) + + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + img_ids = list(np.arange(num_imgs) + offset) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + scenes.append(seq_dir) + scene_img_list.append(img_ids) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + start_img_ids.extend(start_img_ids_) + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def get_stats(self): + return f"{len(self)} groups of views" + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + scene_id = self.sceneids[start_id] + all_image_ids = self.scene_img_list[scene_id] + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + 
max_interval=self.max_interval, + video_prob=0.8, + fix_interval_prob=0.8, + block_shuffle=16, + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = self.scenes[scene_id] + basename = self.images[view_idx] + + img = basename + "_rgb.png" + image = imread_cv2(osp.join(scene_dir, img)) + depthmap = np.load(osp.join(scene_dir, basename + "_depth.npy")) + camera_params = np.load(osp.join(scene_dir, basename + "_cam.npz")) + + intrinsics = camera_params["camera_intrinsics"] + camera_pose = camera_params["camera_pose"] + + sky_mask = depthmap >= 1000 + depthmap[sky_mask] = -1.0 # sky + depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0) + threshold = ( + np.percentile(depthmap[depthmap > 0], 98) + if depthmap[depthmap > 0].size > 0 + else 0 + ) + depthmap[depthmap > threshold] = 0.0 + + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(scene_dir, img) + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.75, 0.2, 0.05] + ) + + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset="TartanAir", + label=scene_dir, + is_metric=self.is_metric, + instance=scene_dir + "_" + img, + is_video=ordered_video, + quantile=np.array(1.0, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/threedkb.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/threedkb.py new file mode 100644 index 0000000000000000000000000000000000000000..face09abd00f76cd62e7654b1b673e9d1d3394b7 --- /dev/null +++ 
b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/threedkb.py @@ -0,0 +1,111 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2 + + +class ThreeDKenBurns(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = False + self.is_metric = False + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data() + + def _load_data(self): + self.scenes = os.listdir(self.ROOT) + + offset = 0 + scenes = [] + sceneids = [] + images = [] + img_ids = [] + + j = 0 + for scene in tqdm(self.scenes): + scene_dir = osp.join(self.ROOT, scene) + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")] + ) + + num_imgs = len(basenames) + img_ids_ = list(np.arange(num_imgs) + offset) + + img_ids.extend(img_ids_) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.img_ids = img_ids + + def __len__(self): + return len(self.img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + new_seed = rng.integers(0, 2**32) + idx + new_rng = np.random.default_rng(new_seed) + image_idxs = new_rng.choice(self.img_ids, num_views, replace=False) + + views = [] + for view_idx in image_idxs: + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + rgb_dir = osp.join(scene_dir, "rgb") + depth_dir = osp.join(scene_dir, "depth") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(rgb_dir, 
basename + ".png")) + depthmap = imread_cv2(osp.join(depth_dir, basename + ".exr")) + depthmap[depthmap > 20000] = 0.0 + depthmap = depthmap / 1000.0 + cam = np.load(osp.join(cam_dir, basename + ".npz")) + intrinsics = cam["intrinsics"] + camera_pose = np.eye(4) + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="3DKenBurns", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=False, + quantile=np.array(1.0, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=False, + depth_only=False, + single_view=True, + reset=True, + ) + ) + assert len(views) == num_views + return views diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/uasol.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/uasol.py new file mode 100644 index 0000000000000000000000000000000000000000..b91b43bdd6a27691ac5016b22c183ac300d219a9 --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/uasol.py @@ -0,0 +1,148 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2 + +import re + + +def extract_number(filename): + match = re.search(r"\d+", filename) + if match: + return int(match.group()) + return 0 + + +class UASOL_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + self.max_interval = 40 + super().__init__(*args, **kwargs) + self.loaded_data = 
self._load_data() + + def _load_data(self): + self.scenes = os.listdir(self.ROOT) + + offset = 0 + scenes = [] + sceneids = [] + scene_img_list = [] + images = [] + start_img_ids = [] + + j = 0 + for scene in tqdm(self.scenes): + scene_dir = osp.join(self.ROOT, scene) + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")], + key=extract_number, + ) + num_imgs = len(basenames) + img_ids = list(np.arange(num_imgs) + offset) + # start_img_ids_ = img_ids[:-self.num_views+1] + cut_off = ( + self.num_views if not self.allow_repeat else max(self.num_views // 3, 3) + ) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + + start_img_ids.extend(start_img_ids_) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=self.max_interval, + video_prob=0.75, + fix_interval_prob=0.75, + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + rgb_dir = osp.join(scene_dir, "rgb") + depth_dir = osp.join(scene_dir, "depth") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(rgb_dir, basename + 
".png")) + # Load depthmap + depthmap = np.load(osp.join(depth_dir, basename + ".npy")) + depthmap[~np.isfinite(depthmap)] = 0 # invalid + depthmap[depthmap >= 20] = 0 # invalid + + cam = np.load(osp.join(cam_dir, basename + ".npz")) + camera_pose = cam["pose"] + intrinsics = cam["intrinsics"] + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.75, 0.2, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="UASOL", + label=self.scenes[scene_id] + "_" + basename, + instance=osp.join(rgb_dir, basename + ".png"), + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(0.9, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/unreal4k.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/unreal4k.py new file mode 100644 index 0000000000000000000000000000000000000000..4d9092928daacf527c99e1958bbee85ef9110035 --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/unreal4k.py @@ -0,0 +1,159 @@ +import os.path as osp +import numpy as np +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) + +from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset +from dust3r.utils.image import imread_cv2 + +R_conv = np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]).astype( + np.float32 +) + + +class UnReal4K_Multi(BaseMultiViewDataset): + + def __init__(self, ROOT, 
def _load_data(self):
    """Index the frames of every ``<ROOT>/<scene>/<0|1>`` sequence directory.

    Every sequence with enough frames gets a contiguous range of global image
    ids. The results are stored on ``self.scenes`` (sequence directory paths),
    ``self.sceneids``, ``self.images`` (frame basenames),
    ``self.start_img_ids`` and ``self.scene_img_list``.
    """
    scene_dirs = sorted(
        d for d in os.listdir(self.ROOT) if os.path.isdir(os.path.join(self.ROOT, d))
    )
    seq_dirs = sorted(
        os.path.join(self.ROOT, scene, mode)
        for scene in scene_dirs
        for mode in ["0", "1"]
    )

    # Frames are loaded as ``<basename>_rgb.png`` by ``_get_views``. Match that
    # exact suffix rather than any ``.png`` so an unrelated PNG in the folder
    # cannot produce a bogus or duplicated basename.
    rgb_suffix = "_rgb.png"

    # Fewest frames a sequence needs in order to yield at least one start position.
    cut_off = (
        self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
    )

    offset = 0  # global id of the first frame of the sequence being indexed
    scenes = []
    sceneids = []
    images = []
    start_img_ids = []
    scene_img_list = []

    for seq_dir in seq_dirs:
        basenames = sorted(
            f[: -len(rgb_suffix)]
            for f in os.listdir(seq_dir)
            if f.endswith(rgb_suffix)
        )
        num_imgs = len(basenames)
        if num_imgs < cut_off:
            print(f"Skipping {seq_dir}")
            continue

        img_ids = list(np.arange(num_imgs) + offset)
        # Only frames followed by at least ``cut_off - 1`` others may start a sequence.
        start_img_ids.extend(img_ids[: num_imgs - cut_off + 1])
        # ``len(scenes)`` is the position this sequence is about to take in ``scenes``.
        sceneids.extend([len(scenes)] * num_imgs)
        images.extend(basenames)
        scenes.append(seq_dir)
        scene_img_list.append(img_ids)

        # offset groups
        offset += num_imgs

    self.scenes = scenes
    self.sceneids = sceneids
    self.images = images
    self.start_img_ids = start_img_ids
    self.scene_img_list = scene_img_list
self.sceneids[view_idx] + scene_dir = self.scenes[scene_id] + basename = self.images[view_idx] + + img = basename + "_rgb.png" + image = imread_cv2(osp.join(scene_dir, img)) + depthmap = np.load(osp.join(scene_dir, basename + "_depth.npy")) + camera_params = np.load(osp.join(scene_dir, basename + ".npz")) + + intrinsics = camera_params["intrinsics"].astype(np.float32) + camera_pose = camera_params["cam2world"].astype(np.float32) + + camera_pose = R_conv @ camera_pose + + sky_mask = depthmap >= 1000 + depthmap[sky_mask] = -1.0 # sky + threshold = ( + np.percentile(depthmap[depthmap > 0], 98) + if depthmap[depthmap > 0].size > 0 + else 0 + ) + depthmap[depthmap > threshold] = 0.0 + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(scene_dir, img) + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.75, 0.2, 0.05] + ) + + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset="UnReal4K", + label=scene_dir, + is_metric=self.is_metric, + instance=scene_dir + "_" + img, + is_video=ordered_video, + quantile=np.array(1.0, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/utils/__init__.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
def reproject_view(pts3d, view2):
    """Project 3D points into ``view2`` and return raveled pixel indices.

    ``view2`` must provide ``pts3d`` (only its first two dimensions are used,
    as the target image shape), ``camera_intrinsics`` and ``camera_pose``; the
    pose is inverted with ``inv`` before being handed to :func:`reproject`.
    """
    shape = view2["pts3d"].shape[:2]
    return reproject(
        pts3d, view2["camera_intrinsics"], inv(view2["camera_pose"]), shape
    )


def reproject(pts3d, K, world2cam, shape):
    """Project an ``(H, W, 3)`` point map with ``K @ world2cam[:3]``.

    Returns ``((H, W), flat_indices)`` where ``flat_indices`` are the projected
    positions quantized by :func:`ravel_xy` inside an image of size ``shape``.
    """
    H, W, THREE = pts3d.shape
    assert THREE == 3

    # reproject in camera2 space; division warnings are silenced here
    # (presumably for points with zero depth; confirm against ``geotrf``)
    with np.errstate(divide="ignore", invalid="ignore"):
        pos = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2)

    # quantize to pixel positions
    return (H, W), ravel_xy(pos, shape)


def ravel_xy(pos, shape):
    """Encode ``(x, y)`` positions as flat indices ``x + W * y``.

    Positions are rounded to the nearest integer and clamped to the
    ``H x W`` image given by ``shape`` before being encoded.
    """
    H, W = shape
    with np.errstate(invalid="ignore"):
        qx, qy = pos.reshape(-1, 2).round().astype(np.int32).T
    quantized_pos = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip(
        min=0, max=H - 1, out=qy
    )
    return quantized_pos


def unravel_xy(pos, shape):
    """Convert flat indices ``x + W * y`` back to an ``(N, 2)`` array of ``(x, y)``.

    The result is assembled explicitly from the ``(row, col)`` arrays returned
    by ``np.unravel_index`` instead of reaching into the ``.base`` buffer of
    that result, which is an undocumented implementation detail of NumPy.
    """
    rows, cols = np.unravel_index(pos, shape)
    return np.stack((cols, rows), axis=-1)


def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False):
    """Find mutually consistent matches between two flat correspondence maps.

    An index ``i`` of image 1 is reciprocal when
    ``corres_2_to_1[corres_1_to_2[i]] == i``.

    Returns ``(pos1, pos2)``, the reciprocal indices in image 1 and their
    matches in image 2, preceded by the boolean reciprocity mask over image 1
    when ``ret_recip`` is true.
    """
    is_reciprocal1 = corres_2_to_1[corres_1_to_2] == np.arange(len(corres_1_to_2))
    pos1 = is_reciprocal1.nonzero()[0]
    pos2 = corres_1_to_2[pos1]
    if ret_recip:
        return is_reciprocal1, pos1, pos2
    return pos1, pos2


def extract_correspondences_from_pts3d(
    view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0
):
    """Sample pixel correspondences between two views from their 3D point maps.

    Args:
        view1, view2: views providing ``pts3d``, ``camera_intrinsics`` and
            ``camera_pose``; both are converted with ``to_numpy`` first.
        target_n_corres: number of correspondences to return, or ``None`` to
            return every reciprocal correspondence.
        rng: random source exposing ``permutation`` and ``choice``.
        ret_xy: when true, positions are returned as ``(x, y)`` pairs instead
            of flat indices.
        nneg: fraction of ``target_n_corres`` that should be negatives
            (non-reciprocal pixels sampled independently in each image).

    Returns:
        ``(pos1, pos2)`` when ``target_n_corres`` is ``None``; otherwise
        ``(pos1, pos2, valid)`` where ``valid`` is true for the positive
        correspondences and false for the appended negatives.
    """
    view1, view2 = to_numpy((view1, view2))
    # project pixels from image1 --> 3d points --> image2 pixels
    shape1, corres1_to_2 = reproject_view(view1["pts3d"], view2)
    shape2, corres2_to_1 = reproject_view(view2["pts3d"], view1)

    # compute reciprocal correspondences:
    # pos1 == valid pixels (correspondences) in image1
    is_reciprocal1, pos1, pos2 = reciprocal_1d(
        corres1_to_2, corres2_to_1, ret_recip=True
    )
    is_reciprocal2 = corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1))

    if target_n_corres is None:
        if ret_xy:
            pos1 = unravel_xy(pos1, shape1)
            pos2 = unravel_xy(pos2, shape2)
        return pos1, pos2

    available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum())
    target_n_positives = int(target_n_corres * (1 - nneg))
    n_positives = min(len(pos1), target_n_positives)
    n_negatives = min(target_n_corres - n_positives, available_negatives)

    if n_negatives + n_positives != target_n_corres:
        # should be really rare => when there are not enough negatives
        # in that case, break nneg and add a few more positives ?
        n_positives = target_n_corres - n_negatives
        assert n_positives <= len(pos1)

    assert n_positives <= len(pos1)
    assert n_positives <= len(pos2)
    assert n_negatives <= (~is_reciprocal1).sum()
    assert n_negatives <= (~is_reciprocal2).sum()
    assert n_positives + n_negatives == target_n_corres

    valid = np.ones(n_positives, dtype=bool)
    if n_positives < len(pos1):
        # random sub-sampling of valid correspondences
        perm = rng.permutation(len(pos1))[:n_positives]
        pos1 = pos1[perm]
        pos2 = pos2[perm]

    if n_negatives > 0:
        # add false correspondences if not enough
        def norm(p):
            # turn a boolean mask into a uniform distribution over its true entries
            return p / p.sum()

        pos1 = np.r_[
            pos1,
            rng.choice(
                shape1[0] * shape1[1],
                size=n_negatives,
                replace=False,
                p=norm(~is_reciprocal1),
            ),
        ]
        pos2 = np.r_[
            pos2,
            rng.choice(
                shape2[0] * shape2[1],
                size=n_negatives,
                replace=False,
                p=norm(~is_reciprocal2),
            ),
        ]
        valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)]

    # convert (x+W*y) back to 2d (x,y) coordinates
    if ret_xy:
        pos1 = unravel_xy(pos1, shape1)
        pos2 = unravel_xy(pos2, shape2)
    return pos1, pos2, valid
+# +# -------------------------------------------------------- +# croppping utilities +# -------------------------------------------------------- +import PIL.Image +import os + +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 # noqa +import numpy as np # noqa +from dust3r.utils.geometry import ( + colmap_to_opencv_intrinsics, + opencv_to_colmap_intrinsics, +) # noqa + +try: + lanczos = PIL.Image.Resampling.LANCZOS + bicubic = PIL.Image.Resampling.BICUBIC +except AttributeError: + lanczos = PIL.Image.LANCZOS + bicubic = PIL.Image.BICUBIC + + +class ImageList: + """Convenience class to aply the same operation to a whole set of images.""" + + def __init__(self, images): + if not isinstance(images, (tuple, list, set)): + images = [images] + self.images = [] + for image in images: + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + self.images.append(image) + + def __len__(self): + return len(self.images) + + def to_pil(self): + return tuple(self.images) if len(self.images) > 1 else self.images[0] + + @property + def size(self): + sizes = [im.size for im in self.images] + assert all(sizes[0] == s for s in sizes) + return sizes[0] + + def resize(self, *args, **kwargs): + return ImageList(self._dispatch("resize", *args, **kwargs)) + + def crop(self, *args, **kwargs): + return ImageList(self._dispatch("crop", *args, **kwargs)) + + def _dispatch(self, func, *args, **kwargs): + return [getattr(im, func)(*args, **kwargs) for im in self.images] + + +def rescale_image_depthmap( + image, depthmap, camera_intrinsics, output_resolution, force=True +): + """Jointly rescale a (image, depthmap) + so that (out_width, out_height) >= output_res + """ + image = ImageList(image) + input_resolution = np.array(image.size) # (W,H) + output_resolution = np.array(output_resolution) + if depthmap is not None: + # can also use this with masks instead of depthmaps + assert tuple(depthmap.shape[:2]) == image.size[::-1] + + # define output resolution + 
def camera_matrix_of_crop(
    input_camera_matrix,
    input_resolution,
    output_resolution,
    scaling=1,
    offset_factor=0.5,
    offset=None,
):
    """Return the intrinsics of a view after scaling and then cropping it.

    The margins are what remains of the scaled input resolution once
    ``output_resolution`` is removed; they must be non-negative. Unless an
    explicit ``offset`` is given, the crop origin is placed at
    ``offset_factor * margins`` (0.5 centres the crop). Scaling and offset are
    applied to the matrix in COLMAP convention, and the result is converted
    back to OpenCV convention.
    """
    scaled_resolution = np.asarray(input_resolution) * scaling
    margins = scaled_resolution - output_resolution
    assert np.all(margins >= 0.0)
    if offset is None:
        offset = offset_factor * margins

    # Work on the COLMAP-convention copy produced by the conversion helper.
    colmap_K = opencv_to_colmap_intrinsics(input_camera_matrix)
    colmap_K[:2, :] *= scaling
    colmap_K[:2, 2] -= offset
    return colmap_to_opencv_intrinsics(colmap_K)
def bbox_from_intrinsics_in_out(
    input_camera_matrix, output_camera_matrix, output_resolution
):
    """Derive the ``(left, top, right, bottom)`` crop box implied by two intrinsics.

    The top-left corner is the rounded difference between the principal points
    of the input and output camera matrices; the box then extends by
    ``output_resolution`` given as ``(width, height)``.
    """
    out_width, out_height = output_resolution
    principal_shift = input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]
    left, top = np.int32(np.round(principal_shift))
    return (left, top, left + out_width, top + out_height)
+# +# -------------------------------------------------------- +# DUST3R default transforms +# -------------------------------------------------------- +import torchvision.transforms as tvf +from dust3r.utils.image import ImgNorm + +# define the standard image transforms +ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm]) + + +def _check_input(value, center=1, bound=(0, float("inf")), clip_first_on_zero=True): + if isinstance(value, (int, float)): + if value < 0: + raise ValueError(f"If is a single number, it must be non negative.") + value = [center - float(value), center + float(value)] + if clip_first_on_zero: + value[0] = max(value[0], 0.0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + value = [float(value[0]), float(value[1])] + else: + raise TypeError(f"should be a single number or a list/tuple with length 2.") + + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError(f"values should be between {bound}, but got {value}.") + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) 
for hue, do nothing + if value[0] == value[1] == center: + return None + else: + return tuple(value) + + +import torch +import torchvision.transforms.functional as F + + +def SeqColorJitter(): + """ + Return a color jitter transform with same random parameters + """ + brightness = _check_input(0.5) + contrast = _check_input(0.5) + saturation = _check_input(0.5) + hue = _check_input(0.1, center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) + + fn_idx = torch.randperm(4) + brightness_factor = ( + None + if brightness is None + else float(torch.empty(1).uniform_(brightness[0], brightness[1])) + ) + contrast_factor = ( + None + if contrast is None + else float(torch.empty(1).uniform_(contrast[0], contrast[1])) + ) + saturation_factor = ( + None + if saturation is None + else float(torch.empty(1).uniform_(saturation[0], saturation[1])) + ) + hue_factor = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1])) + + def _color_jitter(img): + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img = F.adjust_brightness(img, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img = F.adjust_contrast(img, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img = F.adjust_saturation(img, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img = F.adjust_hue(img, hue_factor) + return ImgNorm(img) + + return _color_jitter diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/vkitti2.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/vkitti2.py new file mode 100644 index 0000000000000000000000000000000000000000..438e24f425fdb610b870c4d7b7f02b66ce8e3246 --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/vkitti2.py @@ -0,0 +1,169 @@ +import os.path as osp +import numpy as np +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) + +from 
class VirtualKITTI2_Multi(BaseMultiViewDataset):
    """Multi-view dataset over Virtual KITTI 2 sequences.

    Every ``<ROOT>/<scene>/<sequence>/<Camera_0|Camera_1>`` folder is indexed
    as one video; ``_get_views`` samples ``num_views`` frames from a single
    such video.
    """

    def __init__(self, ROOT, *args, **kwargs):
        self.ROOT = ROOT
        self.video = True
        self.is_metric = True
        self.max_interval = 5
        super().__init__(*args, **kwargs)
        # loading all
        self._load_data(self.split)

    def _load_data(self, split=None):
        """Index the frames of every camera folder selected by ``split``.

        With the scene directories sorted by name, ``"train"`` keeps all but
        the last scene, ``"test"`` keeps only the last one, and any other value
        keeps them all. The results are stored on ``self.scenes``,
        ``self.sceneids``, ``self.images``, ``self.start_img_ids`` and
        ``self.scene_img_list``.
        """
        scene_dirs = sorted(
            [
                d
                for d in os.listdir(self.ROOT)
                if os.path.isdir(os.path.join(self.ROOT, d))
            ]
        )
        if split == "train":
            scene_dirs = scene_dirs[:-1]
        elif split == "test":
            scene_dirs = scene_dirs[-1:]

        # Sequence folders, kept as paths relative to ROOT ("<scene>/<sequence>").
        seq_dirs = []
        for scene in scene_dirs:
            seq_dirs += sorted(
                [
                    os.path.join(scene, d)
                    for d in os.listdir(os.path.join(self.ROOT, scene))
                    if os.path.isdir(os.path.join(self.ROOT, scene, d))
                ]
            )

        # Fewest frames a video needs in order to yield at least one start position.
        cut_off = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )

        offset = 0  # global id of the first frame of the video being indexed
        scenes = []
        sceneids = []
        images = []
        scene_img_list = []
        start_img_ids = []
        j = 0  # index of the video being indexed within ``scenes``

        for seq in seq_dirs:
            seq_path = osp.join(self.ROOT, seq)
            for cam in ["Camera_0", "Camera_1"]:
                # The basename is the first 5 characters of each .jpg file name.
                basenames = sorted(
                    [
                        f[:5]
                        for f in os.listdir(seq_path + "/" + cam)
                        if f.endswith(".jpg")
                    ]
                )
                num_imgs = len(basenames)
                if num_imgs < cut_off:
                    # Name the video that is actually being skipped, not the
                    # stale ``scene`` variable left over from the loop above.
                    print(f"Skipping {seq}/{cam}")
                    continue
                img_ids = list(np.arange(num_imgs) + offset)
                start_img_ids_ = img_ids[: num_imgs - cut_off + 1]

                scenes.append(seq + "/" + cam)
                scene_img_list.append(img_ids)
                sceneids.extend([j] * num_imgs)
                images.extend(basenames)
                start_img_ids.extend(start_img_ids_)
                offset += num_imgs
                j += 1

        self.scenes = scenes
        self.sceneids = sceneids
        self.images = images
        self.start_img_ids = start_img_ids
        self.scene_img_list = scene_img_list

    def __len__(self):
        return len(self.start_img_ids)

    def get_image_num(self):
        """Total number of indexed frames."""
        return len(self.images)

    def get_stats(self):
        return f"{len(self)} groups of views"

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` frames of the video containing start position ``idx``.

        Returns a list of per-view dicts holding the cropped/resized image,
        the depth map, the camera pose and intrinsics, and the sampling flags.
        """
        start_id = self.start_img_ids[idx]
        scene_id = self.sceneids[start_id]
        all_image_ids = self.scene_img_list[scene_id]
        pos, ordered_video = self.get_seq_from_start_id(
            num_views,
            start_id,
            all_image_ids,
            rng,
            max_interval=self.max_interval,
            video_prob=1.0,
            fix_interval_prob=0.9,
        )
        image_idxs = np.array(all_image_ids)[pos]

        views = []

        for v, view_idx in enumerate(image_idxs):
            scene_id = self.sceneids[view_idx]
            scene_dir = osp.join(self.ROOT, self.scenes[scene_id])
            basename = self.images[view_idx]

            img = basename + "_rgb.jpg"
            image = imread_cv2(osp.join(scene_dir, img))
            # The stored depth PNG is divided by 100 on load
            # (presumably centimetres to metres; TODO confirm).
            depthmap = (
                cv2.imread(
                    osp.join(scene_dir, basename + "_depth.png"),
                    cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH,
                ).astype(np.float32)
                / 100.0
            )
            camera_params = np.load(osp.join(scene_dir, basename + "_cam.npz"))

            intrinsics = camera_params["camera_intrinsics"]
            camera_pose = camera_params["camera_pose"]

            sky_mask = depthmap >= 655
            depthmap[sky_mask] = -1.0  # sky

            image, depthmap, intrinsics = self._crop_resize_if_necessary(
                image, depthmap, intrinsics, resolution, rng, info=(scene_dir, img)
            )

            # generate img mask and raymap mask
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.85, 0.1, 0.05]
            )

            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    dataset="VirtualKITTI2",
                    label=scene_dir,
                    is_metric=self.is_metric,
                    instance=scene_dir + "_" + img,
                    is_video=ordered_video,
                    quantile=np.array(1.0, dtype=np.float32),
                    img_mask=img_mask,
                    ray_mask=ray_mask,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )
        assert len(views) == num_views
        return views
class Waymo_Multi(BaseMultiViewDataset):
    """Dataset of outdoor street scenes, sampled as ``num_views`` frames per item.

    Files in a scene directory are named ``<frame_id>_<seq_id>.<ext>``; every
    ``(scene, seq_id)`` pair is indexed as one video.
    """

    def __init__(self, *args, ROOT, **kwargs):
        self.ROOT = ROOT
        self.max_interval = 8
        self.video = True
        self.is_metric = True
        super().__init__(*args, **kwargs)
        assert self.split is None
        self._load_data()

    def load_invalid_dict(self, h5_file_path):
        """Read the per-scene sets of invalid files from an HDF5 file.

        Each entry of ``<scene>/invalid_pairs`` is a UTF-8 byte string that is
        split on ``"_"`` into a tuple. Returns ``{scene: set_of_tuples}``.
        """
        invalid_dict = {}
        with h5py.File(h5_file_path, "r") as h5f:
            for scene in h5f:
                data = h5f[scene]["invalid_pairs"][:]
                invalid_dict[scene] = {
                    tuple(pair.decode("utf-8").split("_")) for pair in data
                }
        return invalid_dict

    def _load_data(self):
        """Index every valid frame under ``self.ROOT``, grouped by scene and sequence.

        The results are stored on ``self.scenes`` (``(scene, seq_id)`` tuples),
        ``self.sceneids``, ``self.images`` (frame ids), ``self.start_img_ids``
        and ``self.scene_img_list``. ``self.is_video`` is set to an empty list.
        """
        invalid_dict = self.load_invalid_dict(
            os.path.join(self.ROOT, "invalid_files.h5")
        )
        scene_dirs = sorted(
            [
                d
                for d in os.listdir(self.ROOT)
                if os.path.isdir(os.path.join(self.ROOT, d))
            ]
        )

        # Fewest frames a sequence needs in order to yield at least one start position.
        cut_off = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )

        offset = 0  # global id of the first frame of the sequence being indexed
        scenes = []
        sceneids = []
        images = []
        start_img_ids = []
        scene_img_list = []
        j = 0  # index of the sequence being indexed within ``scenes``

        for scene in scene_dirs:
            scene_dir = osp.join(self.ROOT, scene)
            invalid_pairs = invalid_dict.get(scene, set())
            seq2frames = {}
            for f in os.listdir(scene_dir):
                if not f.endswith(".jpg"):
                    continue
                name_parts = f[:-4].split("_")
                frame_id = name_parts[0]
                seq_id = name_parts[1]
                # NOTE(review): sequence "5" is always excluded; the reason is
                # not visible here.
                if seq_id == "5":
                    continue
                if (seq_id, frame_id) in invalid_pairs:
                    continue  # Skip invalid files
                seq2frames.setdefault(seq_id, []).append(frame_id)

            # os.listdir() order is arbitrary, so walk the sequences in sorted
            # order to give every process the same idx -> sample mapping.
            for seq_id in sorted(seq2frames):
                frame_ids = sorted(seq2frames[seq_id])
                num_imgs = len(frame_ids)
                if num_imgs < cut_off:
                    print(f"Skipping {scene}_{seq_id}")
                    continue

                img_ids = list(np.arange(num_imgs) + offset)
                start_img_ids_ = img_ids[: num_imgs - cut_off + 1]

                scenes.append((scene, seq_id))
                sceneids.extend([j] * num_imgs)
                images.extend(frame_ids)
                start_img_ids.extend(start_img_ids_)
                scene_img_list.append(img_ids)

                offset += num_imgs
                j += 1

        self.scenes = scenes
        self.sceneids = sceneids
        self.images = images
        self.start_img_ids = start_img_ids
        self.scene_img_list = scene_img_list
        self.is_video = []

    def __len__(self):
        return len(self.start_img_ids)

    def get_image_num(self):
        """Total number of indexed frames."""
        return len(self.images)

    def get_stats(self):
        return f"{len(self)} groups of views"

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` frames of the sequence containing start position ``idx``.

        Returns a list of per-view dicts holding the cropped/resized image,
        the depth map, the camera pose and intrinsics, and the sampling flags.
        """
        start_id = self.start_img_ids[idx]
        all_image_ids = self.scene_img_list[self.sceneids[start_id]]
        _, seq_id = self.scenes[self.sceneids[start_id]]
        # Sequence "4" is sampled with half the usual maximum frame interval.
        max_interval = self.max_interval // 2 if seq_id == "4" else self.max_interval
        pos, ordered_video = self.get_seq_from_start_id(
            num_views,
            start_id,
            all_image_ids,
            rng,
            max_interval=max_interval,
            video_prob=0.9,
            fix_interval_prob=0.9,
            block_shuffle=16,
        )
        image_idxs = np.array(all_image_ids)[pos]
        # NOTE(review): the flag returned by the sampler is deliberately
        # overridden, so every view is reported as part of an ordered video.
        ordered_video = True

        views = []

        for v, view_idx in enumerate(image_idxs):
            scene_id = self.sceneids[view_idx]
            scene_dir, seq_id = self.scenes[scene_id]
            scene_dir = osp.join(self.ROOT, scene_dir)
            frame_id = self.images[view_idx]

            impath = f"{frame_id}_{seq_id}"
            image = imread_cv2(osp.join(scene_dir, impath + ".jpg"))
            depthmap = imread_cv2(osp.join(scene_dir, impath + ".exr"))
            camera_params = np.load(osp.join(scene_dir, impath + ".npz"))

            intrinsics = np.float32(camera_params["intrinsics"])
            camera_pose = np.float32(camera_params["cam2world"])

            image, depthmap, intrinsics = self._crop_resize_if_necessary(
                image, depthmap, intrinsics, resolution, rng, info=(scene_dir, impath)
            )

            # generate img mask and raymap mask
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.85, 0.10, 0.05]
            )

            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    dataset="Waymo",
                    label=osp.relpath(scene_dir, self.ROOT),
                    is_metric=self.is_metric,
                    instance=osp.join(scene_dir, impath + ".jpg"),
                    is_video=ordered_video,
                    quantile=np.array(0.98, dtype=np.float32),
                    img_mask=img_mask,
                    ray_mask=ray_mask,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )

        return views
def _get_views(self, idx, resolution, rng, num_views):
    """Return the parent class's views with the WildRGBD quantile applied.

    Every view produced by the parent implementation must already be flagged
    as metric; its ``quantile`` entry is replaced by 0.96.
    """
    metric_views = super()._get_views(idx, resolution, rng, num_views)
    for entry in metric_views:
        assert entry["is_metric"]
        # A fresh 0-d float32 array per view, so views do not share one object.
        entry["quantile"] = np.array(0.96, dtype=np.float32)
    return metric_views
class PoseDecoder(nn.Module):
    """MLP head that maps a pose feature vector to a pose encoding."""

    def __init__(
        self,
        hidden_size=768,
        mlp_ratio=4,
        pose_encoding_type="absT_quaR",
    ):
        """
        Args:
            hidden_size: width of the incoming pose feature.
            mlp_ratio: hidden width of the MLP as a multiple of ``hidden_size``.
            pose_encoding_type: only ``"absT_quaR"`` (7 values) is implemented.

        Raises:
            ValueError: if ``pose_encoding_type`` is not supported.
        """
        super().__init__()

        self.pose_encoding_type = pose_encoding_type
        # Fail early with a clear message instead of an AttributeError on the
        # never-assigned ``target_dim`` further down.
        if self.pose_encoding_type != "absT_quaR":
            raise ValueError(
                f"Unsupported pose_encoding_type {pose_encoding_type!r}; "
                "only 'absT_quaR' is implemented."
            )
        self.target_dim = 7

        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            out_features=self.target_dim,
            drop=0,
        )

    def forward(
        self,
        pose_feat,
    ):
        """
        pose_feat: BxC

        Returns the raw MLP output of shape Bx7.
        """

        pred_cameras = self.mlp(pose_feat)  # Bx7, 3 for absT, 4 for quaR
        return pred_cameras


class PoseEncoder(nn.Module):
    """Embed a camera as a feature vector (harmonic pose embedding followed by an MLP)."""

    def __init__(
        self,
        hidden_size=768,
        mlp_ratio=4,
        pose_mode=("exp", -inf, inf),
        pose_encoding_type="absT_quaR",
    ):
        """
        Args:
            hidden_size: width of the produced pose feature.
            mlp_ratio: hidden width of the MLP as a multiple of ``hidden_size``.
            pose_mode: passed to ``postprocess_pose`` (with ``inverse=True``)
                in ``forward``.
            pose_encoding_type: only ``"absT_quaR"`` (7 values) is implemented.

        Raises:
            ValueError: if ``pose_encoding_type`` is not supported.
        """
        super().__init__()
        self.pose_encoding_type = pose_encoding_type
        self.pose_mode = pose_mode

        # Fail early with a clear message instead of an AttributeError on the
        # never-assigned ``target_dim`` further down.
        if self.pose_encoding_type != "absT_quaR":
            raise ValueError(
                f"Unsupported pose_encoding_type {pose_encoding_type!r}; "
                "only 'absT_quaR' is implemented."
            )
        self.target_dim = 7

        self.embed_pose = PoseEmbedding(
            target_dim=self.target_dim,
            out_dim=hidden_size,
            n_harmonic_functions=10,
            append_input=True,
        )
        self.pose_encoder = Mlp(
            in_features=self.embed_pose.out_dim,
            hidden_features=int(hidden_size * mlp_ratio),
            out_features=hidden_size,
            drop=0,
        )

    def forward(self, camera):
        """Encode ``camera`` into a pose feature.

        The camera is converted to a pose encoding, mapped through the inverse
        of the head post-processing selected by ``self.pose_mode``, embedded,
        and passed through the MLP.
        """
        # Imported here rather than at module level
        # (presumably to avoid a circular import; TODO confirm).
        from dust3r.heads.postprocess import postprocess_pose

        pose_enc = camera_to_pose_encoding(
            camera,
            pose_encoding_type=self.pose_encoding_type,
        ).to(camera.dtype)
        pose_enc = postprocess_pose(pose_enc, self.pose_mode, inverse=True)
        pose_feat = self.embed_pose(pose_enc)
        pose_feat = self.pose_encoder(pose_feat)
        return pose_feat
described in + `NeRF `_ + and the integrated position encoding in + `MIP-NeRF `_. + + During the inference you can provide the extra argument `diag_cov`. + + If `diag_cov is None`, it converts + rays parametrized with a `ray_bundle` to 3D points by + extending each ray according to the corresponding length. + Then it converts each feature + (i.e. vector along the last dimension) in `x` + into a series of harmonic features `embedding`, + where for each i in range(dim) the following are present + in embedding[...]:: + + [ + sin(f_1*x[..., i]), + sin(f_2*x[..., i]), + ... + sin(f_N * x[..., i]), + cos(f_1*x[..., i]), + cos(f_2*x[..., i]), + ... + cos(f_N * x[..., i]), + x[..., i], # only present if append_input is True. + ] + + where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar + denoting the i-th frequency of the harmonic embedding. + + + If `diag_cov is not None`, it approximates + conical frustums following a ray bundle as gaussians, + defined by x, the means of the gaussians and diag_cov, + the diagonal covariances. + Then it converts each gaussian + into a series of harmonic features `embedding`, + where for each i in range(dim) the following are present + in embedding[...]:: + + [ + sin(f_1*x[..., i]) * exp(0.5 * f_1**2 * diag_cov[..., i,]), + sin(f_2*x[..., i]) * exp(0.5 * f_2**2 * diag_cov[..., i,]), + ... + sin(f_N * x[..., i]) * exp(0.5 * f_N**2 * diag_cov[..., i,]), + cos(f_1*x[..., i]) * exp(0.5 * f_1**2 * diag_cov[..., i,]), + cos(f_2*x[..., i]) * exp(0.5 * f_2**2 * diag_cov[..., i,]),, + ... + cos(f_N * x[..., i]) * exp(0.5 * f_N**2 * diag_cov[..., i,]), + x[..., i], # only present if append_input is True. + ] + + where N equals `n_harmonic_functions-1`, and f_i is a scalar + denoting the i-th frequency of the harmonic embedding. 
+ + If `logspace==True`, the frequencies `[f_1, ..., f_N]` are + powers of 2: + `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)` + + If `logspace==False`, frequencies are linearly spaced between + `1.0` and `2**(n_harmonic_functions-1)`: + `f_1, ..., f_N = torch.linspace( + 1.0, 2**(n_harmonic_functions-1), n_harmonic_functions + )` + + Note that `x` is also premultiplied by the base frequency `omega_0` + before evaluating the harmonic functions. + + Args: + n_harmonic_functions: int, number of harmonic + features + omega_0: float, base frequency + logspace: bool, Whether to space the frequencies in + logspace or linear space + append_input: bool, whether to concat the original + input to the harmonic embedding. If true the + output is of the form (embed.sin(), embed.cos(), x) + """ + super().__init__() + + if logspace: + frequencies = 2.0 ** torch.arange(n_harmonic_functions, dtype=torch.float32) + else: + frequencies = torch.linspace( + 1.0, + 2.0 ** (n_harmonic_functions - 1), + n_harmonic_functions, + dtype=torch.float32, + ) + + self.register_buffer("_frequencies", frequencies * omega_0, persistent=False) + self.register_buffer( + "_zero_half_pi", + torch.tensor([0.0, 0.5 * torch.pi]), + persistent=False, + ) + self.append_input = append_input + + def forward( + self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs + ) -> torch.Tensor: + """ + Args: + x: tensor of shape [..., dim] + diag_cov: An optional tensor of shape `(..., dim)` + representing the diagonal covariance matrices of our Gaussians, joined with x + as means of the Gaussians. 
+ + Returns: + embedding: a harmonic embedding of `x` of shape + [..., (n_harmonic_functions * 2 + int(append_input)) * num_points_per_ray] + """ + + embed = x[..., None] * self._frequencies + + embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None] + + embed = embed.sin() + if diag_cov is not None: + x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2) + exp_var = torch.exp(-0.5 * x_var) + + embed = embed * exp_var[..., None, :, :] + + embed = embed.reshape(*x.shape[:-1], -1) + + if self.append_input: + return torch.cat([embed, x], dim=-1) + return embed + + @staticmethod + def get_output_dim_static( + input_dims: int, n_harmonic_functions: int, append_input: bool + ) -> int: + """ + Utility to help predict the shape of the output of `forward`. + + Args: + input_dims: length of the last dimension of the input tensor + n_harmonic_functions: number of embedding frequencies + append_input: whether or not to concat the original + input to the harmonic embedding + Returns: + int: the length of the last dimension of the output tensor + """ + return input_dims * (2 * n_harmonic_functions + int(append_input)) + + def get_output_dim(self, input_dims: int = 3) -> int: + """ + Same as above. The default for input_dims is 3 for 3D applications + which use harmonic embedding for positional encoding, + so the input might be xyz. 
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """
    Returns torch.sqrt(torch.max(0, x))
    but with a zero subgradient where x is 0.
    """
    result = torch.zeros_like(x)
    is_positive = x > 0
    # Only strictly positive entries go through sqrt, so zeros get no gradient.
    result[is_positive] = x[is_positive].sqrt()
    return result


def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).

    Raises:
        ValueError: if the last two dimensions are not 3x3.
    """
    if not (matrix.size(-1) == 3 and matrix.size(-2) == 3):
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    flat = matrix.reshape(batch_dim + (9,))
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(flat, dim=-1)

    # Magnitudes of 2*r, 2*i, 2*j, 2*k (clamped at zero before the sqrt).
    squared = torch.stack(
        [
            1.0 + m00 + m11 + m22,
            1.0 + m00 - m11 - m22,
            1.0 - m00 + m11 - m22,
            1.0 - m00 - m11 + m22,
        ],
        dim=-1,
    )
    q_abs = _sqrt_positive_part(squared)

    # Row n is the quaternion scaled by 2*q_abs[n]; one candidate per component.
    row_r = torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1)
    row_i = torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1)
    row_j = torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1)
    row_k = torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1)
    quat_by_rijk = torch.stack([row_r, row_i, row_j, row_k], dim=-2)

    # Floor the denominator at 0.1; the candidate picked below always has a
    # larger denominator for a valid rotation, the floor only guards the rest.
    floor = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    denominator = 2.0 * torch.maximum(q_abs[..., None], floor)
    quat_candidates = quat_by_rijk / denominator

    # Keep, per matrix, the candidate built from the largest component.
    best = F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5
    chosen = quat_candidates[best, :].reshape(batch_dim + (4,))
    return standardize_quaternion(chosen)


def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert a unit quaternion to a standard form: one in which the real
    part is non negative.

    Args:
        quaternions: Quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    unit = F.normalize(quaternions, p=2, dim=-1)
    negative_real = unit[..., 0:1] < 0
    return torch.where(negative_real, -unit, unit)


def camera_to_pose_encoding(
    camera,
    pose_encoding_type="absT_quaR",
):
    """
    Inverse to pose_encoding_to_camera
    camera: opencv, cam2world

    Returns the translation column followed by the rotation quaternion
    (real part first), concatenated along the last dimension.

    Raises:
        ValueError: if ``pose_encoding_type`` is not ``"absT_quaR"``.
    """
    if pose_encoding_type != "absT_quaR":
        raise ValueError(f"Unknown pose encoding {pose_encoding_type}")

    rotation_quat = matrix_to_quaternion(camera[:, :3, :3])
    translation = camera[:, :3, 3]
    return torch.cat([translation, rotation_quat], dim=-1)


def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)

    # 2 / |q|^2, so non-unit quaternions still give a proper rotation.
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    entries = (
        1 - two_s * (j * j + k * k),
        two_s * (i * j - k * r),
        two_s * (i * k + j * r),
        two_s * (i * j + k * r),
        1 - two_s * (i * i + k * k),
        two_s * (j * k - i * r),
        two_s * (i * k - j * r),
        two_s * (j * k + i * r),
        1 - two_s * (i * i + j * j),
    )
    stacked = torch.stack(entries, -1)
    return stacked.reshape(quaternions.shape[:-1] + (3, 3))
def quaternion_conjugate(q):
    """Compute the conjugate of quaternion q (w, x, y, z)."""
    real_part = q[..., :1]
    imag_part = q[..., 1:]
    return torch.cat([real_part, -imag_part], dim=-1)


def quaternion_multiply(q1, q2):
    """Multiply two quaternions q1 and q2."""
    w1, x1, y1, z1 = q1.unbind(dim=-1)
    w2, x2, y2, z2 = q2.unbind(dim=-1)

    # Hamilton product, component by component (w, x, y, z).
    product = (
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    )
    return torch.stack(product, dim=-1)


def rotate_vector(q, v):
    """Rotate vector v by quaternion q."""
    scalar = q[..., :1]
    axis = q[..., 1:]

    # v' = v + w * t + axis x t, with t = 2 * (axis x v)
    t = 2.0 * torch.cross(axis, v, dim=-1)
    return v + scalar * t + torch.cross(axis, t, dim=-1)


def relative_pose_absT_quatR(t1, q1, t2, q2):
    """Compute the relative translation and quaternion between two poses."""
    inverse_q1 = quaternion_conjugate(q1)
    relative_rotation = quaternion_multiply(inverse_q1, q2)
    # Translation difference expressed in the frame of the first pose.
    relative_translation = rotate_vector(inverse_q1, t2 - t1)
    return relative_translation, relative_rotation
def todevice(batch, device, callback=None, non_blocking=False):
    """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).

    Containers are traversed recursively; dicts, tuples and lists are rebuilt
    with the same type and their leaves are converted.

    Args:
        batch: list, tuple, dict of tensors or other things.
        device: pytorch device or 'numpy'.
        callback: function that would be called on every sub-element
            (containers and leaves alike) before it is processed.
        non_blocking: forwarded to ``Tensor.to`` for every tensor leaf.

    Returns:
        The same structure with tensors moved to ``device``. With
        ``device == "numpy"`` tensors become numpy arrays; otherwise numpy
        arrays become tensors on ``device``. Other leaves (including ``None``)
        are returned unchanged.
    """
    if callback:
        batch = callback(batch)

    # The recursion must forward `callback` and `non_blocking`; previously
    # both were silently dropped below the top level.
    if isinstance(batch, dict):
        return {
            k: todevice(v, device, callback=callback, non_blocking=non_blocking)
            for k, v in batch.items()
        }

    if isinstance(batch, (tuple, list)):
        return type(batch)(
            todevice(x, device, callback=callback, non_blocking=non_blocking)
            for x in batch
        )

    x = batch
    if device == "numpy":
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x


to_device = todevice  # alias


def to_numpy(x):
    """Convert every tensor in ``x`` to a numpy array."""
    return todevice(x, "numpy")


def to_cpu(x):
    """Move every tensor / numpy array in ``x`` to a CPU tensor."""
    return todevice(x, "cpu")


def to_cuda(x):
    """Move every tensor / numpy array in ``x`` to a CUDA tensor."""
    return todevice(x, "cuda")
def xy_grid(
    W,
    H,
    device=None,
    origin=(0, 0),
    unsqueeze=None,
    cat_dim=-1,
    homogeneous=False,
    **arange_kw,
):
    """Build a pixel-coordinate grid.

    With the defaults the output is a (H,W,2) array with
        output[j,i,0] = i + origin[0]
        output[j,i,1] = j + origin[1]

    Args:
        W, H: grid width and height.
        device: ``None`` builds numpy arrays; anything else builds torch
            tensors on that device.
        origin: (x, y) offset added to the coordinates.
        unsqueeze: if not None, a singleton axis is inserted at this position
            in every component before stacking.
        cat_dim: axis along which the components are stacked; ``None``
            returns the tuple of components instead.
        homogeneous: append a third component filled with ones.
        **arange_kw: forwarded to ``arange`` (e.g. ``dtype``).

    Returns:
        The stacked grid, or a tuple of components when ``cat_dim`` is None.
        The dtype of the coordinates is whatever ``arange`` produces for the
        given ``arange_kw``.
    """
    use_numpy = device is None
    if use_numpy:
        arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
    else:

        def arange(*a, **kw):
            return torch.arange(*a, device=device, **kw)

        def ones(*a):
            return torch.ones(*a, device=device)

        meshgrid, stack = torch.meshgrid, torch.stack

    tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
    # Normalise to a tuple: np.meshgrid may return a list, which cannot be
    # concatenated with the tuple added for the homogeneous component.
    grid = tuple(meshgrid(tw, th, indexing="xy"))
    if homogeneous:
        grid = grid + (ones((H, W)),)
    if unsqueeze is not None:
        # Apply to every component (previously the homogeneous one was
        # dropped) and support numpy arrays, which have no `.unsqueeze`.
        if use_numpy:
            grid = tuple(np.expand_dims(g, unsqueeze) for g in grid)
        else:
            grid = tuple(g.unsqueeze(unsqueeze) for g in grid)
    if cat_dim is not None:
        grid = stack(grid, cat_dim)
    return grid
def inv(mat):
    """Invert a torch or numpy matrix.

    Raises:
        ValueError: if ``mat`` is neither a torch tensor nor a numpy array.
    """
    backends = (
        (torch.Tensor, torch.linalg.inv),
        (np.ndarray, np.linalg.inv),
    )
    for matrix_type, invert in backends:
        if isinstance(mat, matrix_type):
            return invert(mat)
    raise ValueError(f"bad matrix type = {type(mat)}")
def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
    """Back-project a depth map into 3D points in the camera frame.

    Args:
        depthmap: HxW array of depths.
        camera_intrinsics: a 3x3 matrix; the skew entries ``[0, 1]`` and
            ``[1, 0]`` must be zero.
        pseudo_focal: optional HxW array of per-pixel focal lengths; when
            given it replaces both focal lengths of ``camera_intrinsics``.

    Returns:
        ``(X_cam, valid_mask)``: a float32 HxWx3 point map in camera
        coordinates and a boolean HxW mask that is True where depth > 0.
    """
    camera_intrinsics = np.float32(camera_intrinsics)
    H, W = depthmap.shape

    # Only skew-free pinhole intrinsics are supported.
    assert camera_intrinsics[0, 1] == 0.0
    assert camera_intrinsics[1, 0] == 0.0
    if pseudo_focal is None:
        fu = camera_intrinsics[0, 0]
        fv = camera_intrinsics[1, 1]
    else:
        assert pseudo_focal.shape == (H, W)
        fu = fv = pseudo_focal
    cu = camera_intrinsics[0, 2]
    cv = camera_intrinsics[1, 2]

    u, v = np.meshgrid(np.arange(W), np.arange(H))
    z_cam = depthmap
    x_cam = (u - cu) * z_cam / fu
    y_cam = (v - cv) * z_cam / fv
    X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
    valid_mask = depthmap > 0.0
    return X_cam, valid_mask


def depthmap_to_absolute_camera_coordinates(
    depthmap, camera_intrinsics, camera_pose, **kw
):
    """Back-project a depth map and express the points in world coordinates.

    Args:
        depthmap: HxW array of depths.
        camera_intrinsics: a 3x3 matrix.
        camera_pose: a 3x4 or 4x4 cam2world matrix (only ``[:3, :3]`` and
            ``[:3, 3]`` are read), or ``None`` to skip the transformation.
        **kw: accepted and ignored, so callers can pass extra keyword
            arguments; nothing in ``kw`` is forwarded.

    Returns:
        ``(X_world, X_cam, valid_mask)``: the HxWx3 point map in world
        coordinates, the HxWx3 point map in camera coordinates, and the
        boolean HxW validity mask. When ``camera_pose`` is ``None``,
        ``X_world`` is the same array object as ``X_cam``.
    """
    X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)

    X_world = X_cam  # default
    if camera_pose is not None:
        R_cam2world = camera_pose[:3, :3]
        t_cam2world = camera_pose[:3, 3]

        # Rotate every pixel's point, then translate.
        X_world = (
            np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
        )

    return X_world, X_cam, valid_mask


def colmap_to_opencv_intrinsics(K):
    """
    Modify camera intrinsics to follow a different convention.
    Coordinates of the center of the top-left pixels are by default:
    - (0.5, 0.5) in Colmap
    - (0,0) in OpenCV

    Returns a modified copy; the input matrix is left untouched.
    """
    K = K.copy()
    K[0, 2] -= 0.5
    K[1, 2] -= 0.5
    return K
+ Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] += 0.5 + K[1, 2] += 0.5 + return K + + +def normalize_pointcloud( + pts1, pts2, norm_mode="avg_dis", valid1=None, valid2=None, ret_factor=False +): + """renorm pointmaps pts1, pts2 with norm_mode""" + assert pts1.ndim >= 3 and pts1.shape[-1] == 3 + assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3) + norm_mode, dis_mode = norm_mode.split("_") + + if norm_mode == "avg": + + nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3) + nan_pts2, nnz2 = ( + invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0) + ) + all_pts = ( + torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1 + ) + + all_dis = all_pts.norm(dim=-1) + if dis_mode == "dis": + pass # do nothing + elif dis_mode == "log1p": + all_dis = torch.log1p(all_dis) + elif dis_mode == "warp-log1p": + + log_dis = torch.log1p(all_dis) + warp_factor = log_dis / all_dis.clip(min=1e-8) + H1, W1 = pts1.shape[1:-1] + pts1 = pts1 * warp_factor[:, : W1 * H1].view(-1, H1, W1, 1) + if pts2 is not None: + H2, W2 = pts2.shape[1:-1] + pts2 = pts2 * warp_factor[:, W1 * H1 :].view(-1, H2, W2, 1) + all_dis = log_dis # this is their true distance afterwards + else: + raise ValueError(f"bad {dis_mode=}") + + norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8) + else: + + nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3) + nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None + all_pts = ( + torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1 + ) + + all_dis = all_pts.norm(dim=-1) + + if norm_mode == "avg": + norm_factor = all_dis.nanmean(dim=1) + elif norm_mode == "median": + norm_factor = all_dis.nanmedian(dim=1).values.detach() + elif norm_mode == "sqrt": + norm_factor = all_dis.sqrt().nanmean(dim=1) ** 2 + else: + raise ValueError(f"bad {norm_mode=}") + + norm_factor = 
def normalize_pointcloud_group(
    pts_list,
    norm_mode="avg_dis",
    valid_list=None,
    conf_list=None,
    ret_factor=False,
    ret_factor_only=False,
):
    """Renormalise a group of pointmaps with one shared scale factor.

    Args:
        pts_list: list of pointmaps, each with at least 3 dimensions and a
            last dimension of size 3. The ``warp-log1p`` mode additionally
            requires exactly two dimensions between batch and xyz.
        norm_mode: ``"<norm>_<dis>"``. ``<norm>`` is ``avg``, ``median`` or
            ``sqrt``; ``<dis>`` is ``dis``, ``log1p`` or ``warp-log1p`` and is
            only used when ``<norm>`` is ``avg``.
        valid_list: per-pointmap validity masks, or ``None`` for no masking.
        conf_list: optional per-pointmap confidences used as weights in
            ``avg`` mode.
        ret_factor: also return the normalisation factor.
        ret_factor_only: return only the normalisation factor.

    Returns:
        The list of normalised pointmaps, optionally with the factor, or the
        factor alone when ``ret_factor_only`` is set.

    Raises:
        ValueError: on an unknown ``<norm>`` or ``<dis>`` mode.
    """
    for pts in pts_list:
        assert pts.ndim >= 3 and pts.shape[-1] == 3

    if valid_list is None:
        # The default used to crash inside zip(); pass a None mask per view
        # instead, as normalize_pointcloud does with its own defaults.
        valid_list = [None] * len(pts_list)

    norm_mode, dis_mode = norm_mode.split("_")

    if norm_mode == "avg":
        zeroed_pts = [
            invalid_to_zeros(pts, valid, ndim=3)[0]
            for pts, valid in zip(pts_list, valid_list)
        ]
        all_pts = torch.cat(zeroed_pts, dim=1)
        if conf_list is not None:
            zeroed_conf = [
                invalid_to_zeros(conf[..., None], valid, ndim=3)[0]
                for conf, valid in zip(conf_list, valid_list)
            ]
            all_conf = torch.cat(zeroed_conf, dim=1)[..., 0]
        else:
            # NOTE(review): with unit weights the denominator below counts
            # every point, invalid ones included, whereas normalize_pointcloud
            # divides by the number of valid points -- confirm this is meant.
            all_conf = torch.ones_like(all_pts[..., 0])

        all_dis = all_pts.norm(dim=-1)
        if dis_mode == "dis":
            pass  # do nothing
        elif dis_mode == "log1p":
            all_dis = torch.log1p(all_dis)
        elif dis_mode == "warp-log1p":
            log_dis = torch.log1p(all_dis)
            warp_factor = log_dis / all_dis.clip(min=1e-8)
            # Each pointmap owns a contiguous run of H*W columns of
            # warp_factor. The old code summed the (H, W) shape tuples
            # themselves, which raises TypeError.
            warped, start = [], 0
            for pts in pts_list:
                H, W = pts.shape[1:-1]
                stop = start + H * W
                warped.append(pts * warp_factor[:, start:stop].view(-1, H, W, 1))
                start = stop
            pts_list = warped
            all_dis = log_dis  # this is their true distance afterwards
        else:
            raise ValueError(f"bad {dis_mode=}")

        norm_factor = (all_conf * all_dis).sum(dim=1) / (all_conf.sum(dim=1) + 1e-8)
    else:
        nan_pts_list = [
            invalid_to_nans(pts, valid, ndim=3)
            for pts, valid in zip(pts_list, valid_list)
        ]
        all_pts = torch.cat(nan_pts_list, dim=1)
        all_dis = all_pts.norm(dim=-1)

        # "avg" is handled above, so only the NaN-aware modes remain here.
        if norm_mode == "median":
            norm_factor = all_dis.nanmedian(dim=1).values.detach()
        elif norm_mode == "sqrt":
            norm_factor = all_dis.sqrt().nanmean(dim=1) ** 2
        else:
            raise ValueError(f"bad {norm_mode=}")

    norm_factor = norm_factor.clip(min=1e-8)
    # Append singleton axes so the factor broadcasts against the pointmaps.
    missing_axes = pts_list[0].ndim - norm_factor.ndim
    if missing_axes > 0:
        norm_factor = norm_factor.reshape(norm_factor.shape + (1,) * missing_axes)

    if ret_factor_only:
        return norm_factor

    res = [pts / norm_factor for pts in pts_list]
    if ret_factor:
        return res, norm_factor
    return res
def weighted_procrustes(A, B, w, use_weights=True, eps=1e-16, return_T=False):
    """Estimate the rigid transform (R, t) that maps points A onto points B.

    Solves the (optionally weighted) orthogonal Procrustes problem so that
    ``B ~= A @ R^T + t`` per batch element, with ``det(R) = +1``.

    Args:
        A: torch tensor B x N x 3, source points.
        B: torch tensor B x N x 3, target points.
        w: torch tensor B x N, per-correspondence weights (only read when
            ``use_weights`` is True).
        use_weights: use ``w``; otherwise all points count equally.
        eps: added to the weight sum to avoid division by zero.
        return_T: return one B x 4 x 4 homogeneous matrix instead of (R, t).

    Returns:
        ``T`` (B x 4 x 4) if ``return_T`` else ``(R, t)`` with R of shape
        B x 3 x 3 and t squeezed (B x 3, or 3 when the batch size is 1).
    """
    assert len(A) == len(B)
    if use_weights:
        W1 = torch.abs(w).sum(1, keepdim=True)
        w_norm = (w / (W1 + eps)).unsqueeze(-1)
        a_mean = (w_norm * A).sum(dim=1, keepdim=True)
        b_mean = (w_norm * B).sum(dim=1, keepdim=True)

        A_c = A - a_mean
        B_c = B - b_mean

        # Weighted cross-covariance of the centred point sets.
        H = torch.einsum("bni,bnj->bij", A_c, w_norm * B_c)
    else:
        a_mean = A.mean(axis=1, keepdim=True)
        b_mean = B.mean(axis=1, keepdim=True)

        A_c = A - a_mean
        B_c = B - b_mean

        H = torch.einsum("bij,bik->bjk", A_c, B_c)

    # torch.svd is deprecated; torch.linalg.svd returns V^H instead of V.
    U, _, Vh = torch.linalg.svd(H)  # U: B x 3 x 3, Vh: B x 3 x 3
    V = Vh.transpose(-2, -1)

    # Build helpers in A's dtype and device; a default float32 eye would make
    # the matmuls below fail for float64 input.
    Z = torch.eye(3, dtype=A.dtype, device=A.device).repeat(A.shape[0], 1, 1)
    # Flip the last axis when needed so R is a rotation, not a reflection.
    Z[:, -1, -1] = torch.sign(torch.linalg.det(U @ V.transpose(1, 2)))  # B
    R = V @ Z @ U.transpose(1, 2)  # B x 3 x 3
    t = b_mean - torch.einsum("bij,bjk->bik", R, a_mean.transpose(-2, -1)).transpose(
        -2, -1
    )  # B x 1 x 3
    if return_T:
        T = torch.eye(4, dtype=A.dtype, device=A.device).repeat(A.shape[0], 1, 1)
        T[:, :3, :3] = R
        T[:, :3, 3] = t[:, 0, :]
        return T
    # NOTE: squeeze() also drops the batch axis when the batch size is 1; kept
    # unchanged for backward compatibility with existing callers.
    return R, t.squeeze()
+# +# -------------------------------------------------------- +# modified from DUSt3R + +import os +import torch +import numpy as np +import PIL.Image +from PIL.ImageOps import exif_transpose +import torchvision.transforms as tvf + +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 # noqa + +try: + from pillow_heif import register_heif_opener # noqa + + register_heif_opener() + heif_support_enabled = True +except ImportError: + heif_support_enabled = False + +ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + + +def img_to_arr(img): + if isinstance(img, str): + img = imread_cv2(img) + return img + + +def imread_cv2(path, options=cv2.IMREAD_COLOR): + """Open an image or a depthmap with opencv-python.""" + if path.endswith((".exr", "EXR")): + options = cv2.IMREAD_ANYDEPTH + img = cv2.imread(path, options) + if img is None: + raise IOError(f"Could not load image={path} with {options=}") + if img.ndim == 3: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + +def imread_pil(path): + """Open an RGB image using PIL and return as numpy array.""" + img = PIL.Image.open(path) + img = exif_transpose(img) + img = img.convert("RGB") + return np.array(img) + + +def rgb(ftensor, true_shape=None): + if isinstance(ftensor, list): + return [rgb(x, true_shape=true_shape) for x in ftensor] + if isinstance(ftensor, torch.Tensor): + ftensor = ftensor.detach().cpu().numpy() # H,W,3 + if ftensor.ndim == 3 and ftensor.shape[0] == 3: + ftensor = ftensor.transpose(1, 2, 0) + elif ftensor.ndim == 4 and ftensor.shape[1] == 3: + ftensor = ftensor.transpose(0, 2, 3, 1) + if true_shape is not None: + H, W = true_shape + ftensor = ftensor[:H, :W] + if ftensor.dtype == np.uint8: + img = np.float32(ftensor) / 255 + else: + img = (ftensor * 0.5) + 0.5 + return img.clip(min=0, max=1) + + +def _resize_pil_image(img, long_edge_size): + S = max(img.size) + if S > long_edge_size: + interp = PIL.Image.LANCZOS + elif S <= long_edge_size: + interp = 
def load_images_for_eval(
    folder_or_list, size, square_ok=False, verbose=True, crop=True
):
    """Open and convert all images in a list or folder to the DUSt3R input format.

    Args:
        folder_or_list: Directory path (its sorted listing is scanned) or a
            list of image paths.
        size: Target size. With ``224`` the short side is resized to 224 and
            the image is made square; otherwise the long side is resized to
            ``size`` and the width/height are reduced to multiples of 14.
        square_ok: When False, square images (non-224 mode) get a 4:3 target.
        verbose: Print progress information.
        crop: Center-crop to the target size when True; otherwise resize to
            the target size with LANCZOS resampling.

    Returns:
        A list of dicts with keys ``img`` (``ImgNorm`` output with a leading
        batch axis), ``true_shape`` (int32 ``[[H, W]]``), ``idx`` and
        ``instance``.

    Raises:
        ValueError: If ``folder_or_list`` is neither a ``str`` nor a ``list``.
        AssertionError: If no supported image was found.
    """
    if isinstance(folder_or_list, str):
        if verbose:
            print(f">> Loading images from {folder_or_list}")
        root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))

    elif isinstance(folder_or_list, list):
        if verbose:
            print(f">> Loading a list of {len(folder_or_list)} images")
        root, folder_content = "", folder_or_list

    else:
        raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})")

    supported_images_extensions = [".jpg", ".jpeg", ".png"]
    if heif_support_enabled:
        supported_images_extensions += [".heic", ".heif"]
    supported_images_extensions = tuple(supported_images_extensions)

    imgs = []
    for path in folder_content:
        if not path.lower().endswith(supported_images_extensions):
            continue
        # convert() returns a fully loaded copy, so the file handle can be
        # released right away instead of waiting for garbage collection.
        with PIL.Image.open(os.path.join(root, path)) as opened:
            img = exif_transpose(opened).convert("RGB")
        W1, H1 = img.size
        if size == 224:
            # resize short side to 224 (then crop)
            img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1)))
        else:
            # resize long side to `size`
            img = _resize_pil_image(img, size)
        W, H = img.size
        cx, cy = W // 2, H // 2
        if size == 224:
            half = min(cx, cy)
            if crop:
                img = img.crop((cx - half, cy - half, cx + half, cy + half))
            else:  # resize
                img = img.resize((2 * half, 2 * half), PIL.Image.LANCZOS)
        else:
            halfw, halfh = ((2 * cx) // 14) * 7, ((2 * cy) // 14) * 7
            if not (square_ok) and W == H:
                halfh = 3 * halfw / 4
            if crop:
                img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh))
            else:  # resize
                # halfh is a float after the 4:3 adjustment above, and
                # Image.resize() only accepts integer sizes.
                target = (2 * halfw, int(round(2 * halfh)))
                img = img.resize(target, PIL.Image.LANCZOS)
        W2, H2 = img.size
        if verbose:
            print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}")
        imgs.append(
            dict(
                img=ImgNorm(img)[None],
                true_shape=np.int32([img.size[::-1]]),
                idx=len(imgs),
                instance=str(len(imgs)),
            )
        )

    assert imgs, "no images foud at " + root
    if verbose:
        print(f" (Found {len(imgs)} images)")
    return imgs
def transpose_to_landscape(head, activate=True):
    """Wrap ``head`` so that it always predicts in landscape orientation.

    With ``activate=True`` the returned wrapper predicts in the correct
    aspect ratio, transposes portrait results back, and stacks everything
    together again. With ``activate=False`` the wrapper only checks that all
    entries of ``true_shape`` are identical and forwards to ``head``.

    Args:
        head: Callable ``head(decout, (H, W), **kwargs)`` returning a dict of
            tensors.
        activate: Select the orientation-aware wrapper.

    Returns:
        The wrapper ``wrapper(decout, true_shape, **kwargs)``.
    """

    def wrapper_no(decout, true_shape, **kwargs):
        assert true_shape[0:1].allclose(true_shape), "true_shape must be all identical"
        H, W = true_shape[0].cpu().tolist()
        return head(decout, (H, W), **kwargs)

    def wrapper_yes(decout, true_shape, **kwargs):
        B = len(true_shape)

        # Landscape convention: H is the short side, W the long one.
        H, W = int(true_shape.min()), int(true_shape.max())

        height, width = true_shape.T
        is_landscape = width >= height
        is_portrait = ~is_landscape

        # Fast paths: the whole batch shares one orientation.
        if is_landscape.all():
            return head(decout, (H, W), **kwargs)
        if is_portrait.all():
            return transposed(head(decout, (W, H), **kwargs))

        def selout(ar):
            return [d[ar] for d in decout]

        # Default to the caller's kwargs so heads called without "pos" work
        # in a mixed batch too; previously these names were only bound when
        # "pos" was present.
        kwargs_landscape = kwargs
        kwargs_portrait = kwargs
        if "pos" in kwargs:
            kwargs_landscape = kwargs.copy()
            kwargs_landscape["pos"] = kwargs["pos"][is_landscape]
            kwargs_portrait = kwargs.copy()
            kwargs_portrait["pos"] = kwargs["pos"][is_portrait]
        l_result = head(selout(is_landscape), (H, W), **kwargs_landscape)
        p_result = transposed(head(selout(is_portrait), (W, H), **kwargs_portrait))

        # Scatter both partial results back into their batch positions.
        result = {}
        for k in l_result | p_result:
            x = l_result[k].new(B, *l_result[k].shape[1:])
            x[is_landscape] = l_result[k]
            x[is_portrait] = p_result[k]
            result[k] = x

        return result

    return wrapper_yes if activate else wrapper_no
def parallel_threads(
    function,
    args,
    workers=0,
    star_args=False,
    kw_args=False,
    front_num=1,
    Pool=ThreadPool,
    **tqdm_kw
):
    """Map ``function`` over ``args`` in a pool, with a tqdm progress bar.

    Essentially returns
        res = [ function(arg)    # default
                function(*arg)   # if star_args is True
                function(**arg)  # if kw_args is True
                for arg in args]

    Note:
        the first ``front_num`` elements of ``args`` are run sequentially in
        the calling thread, which can be useful for debugging. A non-positive
        ``workers`` is increased by ``cpu_count()`` until it is positive, and
        ``workers == 1`` runs everything sequentially.
    """
    while workers <= 0:
        workers += cpu_count()
    if workers == 1:
        front_num = float("inf")

    # Progress-bar total; unknown when `args` has no length.
    try:
        n_args_parallel = len(args) - front_num
    except TypeError:
        n_args_parallel = None
    args = iter(args)

    def invoke(arg):
        if star_args:
            return function(*arg)
        if kw_args:
            return function(**arg)
        return function(arg)

    exhausted = object()
    front = []
    while len(front) < front_num:
        item = next(args, exhausted)
        if item is exhausted:
            return front  # end of the iterable
        front.append(invoke(item))

    with Pool(workers) as pool:
        if star_args:
            futures = pool.imap(starcall, [(function, a) for a in args])
        elif kw_args:
            futures = pool.imap(starstarcall, [(function, a) for a in args])
        else:
            futures = pool.imap(function, args)

        out = list(tqdm(futures, total=n_args_parallel, **tqdm_kw))
    return front + out
+# +# -------------------------------------------------------- +# modified from DUSt3R + +import sys +import os.path as path +import importlib + +HERE_PATH = path.normpath(path.dirname(__file__)) +CROCO_REPO_PATH = path.normpath(path.join(HERE_PATH, "../../croco")) +CROCO_MODELS_PATH = path.join(CROCO_REPO_PATH, "models") +# IMPORTANT: +# Do NOT add `.../src/croco` directly to sys.path, otherwise subfolders like +# `croco/datasets` become a top-level module named `datasets`, which will shadow +# HuggingFace `datasets` and break `accelerate` (and others). +# Instead, add `.../src` so we import as `croco.*`. +SRC_PATH = path.normpath(path.join(HERE_PATH, "../../..")) + +if path.isdir(CROCO_MODELS_PATH): + + # Prefer adding the `src` directory; this enables `import croco...` without + # polluting top-level module names. + if SRC_PATH not in sys.path: + sys.path.insert(0, SRC_PATH) + + # In case an old run already inserted CROCO_REPO_PATH, remove it to avoid + # shadowing top-level modules (e.g., `datasets`). + while CROCO_REPO_PATH in sys.path: + sys.path.remove(CROCO_REPO_PATH) + + # Backward-compat: DUSt3R code expects `models.*` to exist as a top-level package + # (historically achieved by adding CROCO_REPO_PATH to sys.path). We keep that + # import path working by aliasing `croco.models` to `models` without exposing + # other top-level names like `datasets`. + try: + _croco_models = importlib.import_module("croco.models") + sys.modules.setdefault("models", _croco_models) + except Exception: + # If croco isn't importable yet, downstream import will raise a clearer error. + pass +else: + raise ImportError( + f"croco is not initialized, could not find: {CROCO_MODELS_PATH}.\n " + "Did you forget to run 'git submodule update --init --recursive' ?" 
def render(
    intrinsics: torch.Tensor,
    pts3d: torch.Tensor,
    rgbs: torch.Tensor | None = None,
    scale: float = 0.002,
    opacity: float = 0.95,
):
    """Rasterize per-pixel 3D points with gsplat and collect the depth channel.

    Args:
        intrinsics: One intrinsics matrix per batch item; its length sets the
            batch size.
        pts3d: Points whose axes 1 and 2 are used as the image height and
            width; flattened to ``(batch, -1, 3)`` before rendering.
        rgbs: Optional colours with 3 channels on axis 1; all-ones colours
            are used when omitted.
        scale: Value used for every axis of every splat's scale.
        opacity: Opacity assigned to every splat.

    Returns:
        ``(rendered_rgbs, rendered_depths, accs)`` where ``rendered_depths``
        is the concatenation of channel 3 of each ``"RGB+D"`` render.
        NOTE(review): ``rendered_rgbs`` and ``accs`` are never filled and are
        always returned as empty lists.
    """
    device = pts3d.device
    batch_size = len(intrinsics)
    height, width = pts3d.shape[1:3]
    pts3d = pts3d.reshape(batch_size, -1, 3)
    num_pts = pts3d.shape[1]

    # Random unit quaternions, shared by all batch items.
    quats = torch.randn((num_pts, 4), device=device)
    quats = quats / quats.norm(dim=-1, keepdim=True)
    scales = scale * torch.ones((num_pts, 3), device=device)
    opacities = opacity * torch.ones((num_pts), device=device)

    if rgbs is None:
        rgbs = torch.ones_like(pts3d[:, :, :3])
    else:
        assert rgbs.shape[1] == 3
        rgbs = rgbs.reshape(batch_size, 3, -1).transpose(1, 2)

    # The same identity matrix is passed as the view matrix for every item.
    identity_view = torch.eye(4, device=device)[None]

    rendered_rgbs = []
    depth_maps = []
    accs = []
    for i in range(batch_size):
        rgbd, acc, _ = rasterization(
            pts3d[i],
            quats,
            scales,
            opacities,
            rgbs[i],
            identity_view,
            intrinsics[[i]],
            width=width,
            height=height,
            packed=False,
            render_mode="RGB+D",
        )
        depth_maps.append(rgbd[..., 3])

    rendered_depths = torch.cat(depth_maps, dim=0)

    return rendered_rgbs, rendered_depths, accs
inv(gts[0]["camera_pose"]).to(device) + intrinsics = gts[0]["camera_intrinsics"].to(device) + pred = pred["pts3d_in_other_view"] + gt_img = gt["img"].to(device) + gt_pts3d = gt["pts3d"].to(device) + + _, depth, _ = render(intrinsics, pred, gt_img) + _, gt_depth, _ = render(intrinsics, geotrf(camera, gt_pts3d), gt_img) + depths.append(depth) + gt_depths.append(gt_depth) + return depths, gt_depths diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/__init__.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae847e46898077fe3d8701b8a181d7b4e3d41cd9 --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +__version__ = "0.0.1" diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/hub/__init__.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/hub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/hub/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
def _make_dinov2_model(
    *,
    arch_name: str = "vit_large",
    img_size: int = 518,
    patch_size: int = 14,
    init_values: float = 1.0,
    ffn_layer: str = "mlp",
    block_chunks: int = 0,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    """Build a DINOv2 vision transformer and optionally load its pretrained weights.

    ``weights`` may be a ``Weights`` member or its name; an unknown name
    raises ``AssertionError``. Extra ``kwargs`` override the default
    constructor arguments. When ``pretrained`` is true, the checkpoint is
    downloaded from ``_DINOV2_BASE_URL`` and loaded with ``strict=True``.
    """
    from ..models import vision_transformer as vits

    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    model_base_name = _make_dinov2_model_name(arch_name, patch_size)

    # Defaults first; caller-supplied kwargs take precedence.
    vit_kwargs = {
        "img_size": img_size,
        "patch_size": patch_size,
        "init_values": init_values,
        "ffn_layer": ffn_layer,
        "block_chunks": block_chunks,
        "num_register_tokens": num_register_tokens,
        "interpolate_antialias": interpolate_antialias,
        "interpolate_offset": interpolate_offset,
    }
    vit_kwargs.update(**kwargs)
    builder = vits.__dict__[arch_name]
    model = builder(**vit_kwargs)

    if not pretrained:
        return model

    # The directory uses the base name; the file name also carries the register suffix.
    model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
    url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
    state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)
    return model
+ """ + return _make_dinov2_model( + arch_name="vit_base", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/hub/utils.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/hub/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9c6641404093652d5a2f19b4cf283d976ec39e64 --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/hub/utils.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
class CenterPadding(nn.Module):
    """Pad every axis after the first two up to the next multiple of ``multiple``.

    The padding is split between both sides of an axis; when the amount is
    odd, the extra element goes to the right/end side.
    """

    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        """Return the ``(left, right)`` padding that rounds ``size`` up to a multiple."""
        missing = math.ceil(size / self.multiple) * self.multiple - size
        left = missing // 2
        return left, missing - left

    @torch.inference_mode()
    def forward(self, x):
        # F.pad expects the pairs ordered from the last axis backwards.
        pads = []
        for axis_len in x.shape[:1:-1]:
            pads.extend(self._get_pad(axis_len))
        return F.pad(x, pads)
class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    ``forward`` accepts ``attn_bias`` for signature compatibility with
    subclasses but does not use it.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        # 1 / sqrt(head_dim), applied to the queries.
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (3, batch, heads, tokens, head_dim)
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        fused = fused.permute(2, 0, 3, 1, 4)
        queries = fused[0] * self.scale
        keys = fused[1]
        values = fused[2]

        weights = (queries @ keys.transpose(-2, -1)).softmax(dim=-1)
        weights = self.attn_drop(weights)

        merged = (weights @ values).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
class Block(nn.Module):
    """Pre-norm transformer block: attention and FFN, each added as a residual.

    Computes ``x + ls1(attn(norm1(x)))`` followed by ``x + ls2(mlp(norm2(x)))``.
    ``ls1``/``ls2`` are ``LayerScale`` modules when ``init_values`` is truthy
    and identities otherwise. While training with ``drop_path > 0`` the
    residual branches are dropped stochastically: by batch sub-sampling when
    the rate exceeds 0.1, and through the ``DropPath`` modules otherwise.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        """Apply the attention and FFN residual branches to ``x``."""

        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # The FFN branch uses its own DropPath module; it previously
            # reused drop_path1, leaving drop_path2 unused.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add a scaled residual onto the rows of ``x`` selected by ``brange``.

    The residual is cast to ``x``'s dtype. Without ``scaling_vector`` both
    tensors are flattened from axis 1 and combined with ``torch.index_add``
    along axis 0, so the result is 2-D. With ``scaling_vector`` the work is
    delegated to ``scaled_index_add``.
    """
    cast_residual = residual.to(dtype=x.dtype)
    if scaling_vector is not None:
        return scaled_index_add(
            x, brange, cast_residual, scaling=scaling_vector, alpha=residual_scale_factor
        )
    return torch.index_add(
        x.flatten(1), 0, brange, cast_residual.flatten(1), alpha=residual_scale_factor
    )
class NestedTensorBlock(Block):
    """``Block`` that also accepts a list of tensors processed as one nested batch.

    A single ``Tensor`` is handled by ``Block.forward``. A list is
    concatenated and run through attention with the bias returned by
    ``get_attn_bias_and_cat``; this path requires xFormers and a
    ``MemEffAttention`` attention module.
    """

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            # Layer scale is not applied here: it is passed to the residual
            # addition as `scaling_vector` instead.
            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Guard on ls2, the module whose gamma is read; the check
                # previously inspected ls1.
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch on the input type: ``Tensor`` or list of tensors."""
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/layers/drop_path.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5 --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py


from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Apply stochastic depth: zero out whole entries along dim 0 of ``x``.

    One Bernoulli draw is made per index of the first dimension and broadcast
    over all remaining dimensions. Surviving entries are rescaled by
    ``1 / keep_prob`` so the expected value of the output equals ``x``.

    Args:
        x: Input tensor; the mask has shape ``(x.shape[0], 1, ..., 1)``.
        drop_prob: Probability of dropping an entry. ``None`` or ``0.0``
            disables dropping.
        training: Dropping is only applied when this is true.

    Returns:
        ``x`` itself (same object) when dropping is disabled, otherwise a new
        tensor ``x * mask``.
    """
    # ``DropPath()`` defaults ``drop_prob`` to None; without this guard the
    # subtraction below raises TypeError as soon as the module is in train mode.
    if drop_prob is None or drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One mask value per leading index, broadcastable against any rank of x.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(mask_shape).bernoulli_(keep_prob)
    # Skip the rescale when everything is dropped to avoid dividing by zero.
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        """Store the drop probability; ``None`` means no path is ever dropped."""
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Run :func:`drop_path` using this module's probability and train/eval mode."""
        return drop_path(x, self.drop_prob, self.training)
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110

from typing import Union

import torch
from torch import Tensor
from torch import nn


class LayerScale(nn.Module):
    """Multiply the input by a learnable vector ``gamma`` of length ``dim``.

    ``gamma`` is initialised to ``init_values * ones(dim)`` and is applied by
    ordinary broadcasting multiplication.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        """Create the scale parameter.

        Args:
            dim: Length of the learnable scale vector.
            init_values: Initial value (scalar or tensor) multiplied into a
                vector of ones.
            inplace: When true, ``forward`` mutates and returns its input.
        """
        super().__init__()
        self.inplace = inplace
        initial_scale = init_values * torch.ones(dim)
        self.gamma = nn.Parameter(initial_scale)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` scaled by ``gamma`` (in place if configured so)."""
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py


from typing import Callable, Optional

from torch import Tensor, nn


class Mlp(nn.Module):
    """Two-layer feed-forward block: linear -> activation -> dropout -> linear -> dropout."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        """Build the two linear layers, the activation and the dropout.

        Args:
            in_features: Size of the input's last dimension.
            hidden_features: Width of the hidden layer; falls back to
                ``in_features`` when falsy.
            out_features: Size of the output's last dimension; falls back to
                ``in_features`` when falsy.
            act_layer: Called with no arguments to create the activation module.
            drop: Probability handed to the single shared ``nn.Dropout``.
            bias: Whether both linear layers carry a bias term.
        """
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        # Registration order (fc1, act, fc2, drop) is kept as-is: it determines
        # both the state-dict layout and the order of random initialisation.
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        """Apply fc1, activation, dropout, fc2 and dropout again."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = 
self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/layers/swiglu_ffn.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..5ce211515774d42e04c8b51003bae53b88f14b35 --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/layers/swiglu_ffn.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (SwiGLU)") + else: + # 
warnings.warn("xFormers is disabled (SwiGLU)") + raise ImportError +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + # warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/models/__init__.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3fdff20badbd5244bf79f16bf18dd2cb73982265 --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/models/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging + +from . 
import vision_transformer as vits + + +logger = logging.getLogger("dinov2") + + +def build_model(args, only_teacher=False, img_size=224): + args.arch = args.arch.removesuffix("_memeff") + if "vit" in args.arch: + vit_kwargs = dict( + img_size=img_size, + patch_size=args.patch_size, + init_values=args.layerscale, + ffn_layer=args.ffn_layer, + block_chunks=args.block_chunks, + qkv_bias=args.qkv_bias, + proj_bias=args.proj_bias, + ffn_bias=args.ffn_bias, + num_register_tokens=args.num_register_tokens, + interpolate_offset=args.interpolate_offset, + interpolate_antialias=args.interpolate_antialias, + ) + teacher = vits.__dict__[args.arch](**vit_kwargs) + if only_teacher: + return teacher, teacher.embed_dim + student = vits.__dict__[args.arch]( + **vit_kwargs, + drop_path_rate=args.drop_path_rate, + drop_path_uniform=args.drop_path_uniform, + ) + embed_dim = student.embed_dim + return student, teacher, embed_dim + + +def build_model_from_cfg(cfg, only_teacher=False): + return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/models/vision_transformer.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..73f15cfb082d0fe629f8aa312c9d9b27a64ad4e7 --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/models/vision_transformer.py @@ -0,0 +1,404 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from torch.nn.init import trunc_normal_ + +from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block +from ...layers.attention import FlashAttention + + +# logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of 
attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x 
in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + # logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + # logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + # logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + attn_class=FlashAttention + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // 
self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + if self.training: + x = checkpoint(blk, x, use_reentrant=False) + else: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + 
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + if self.training: + x = checkpoint(blk, x, use_reentrant=False) + else: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. 
If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + 
**kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/utils/__init__.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from enum import Enum
import os
from pathlib import Path
from typing import Any, Dict, Optional


class ClusterType(Enum):
    """Clusters for which checkpoint paths and SLURM settings are defined."""

    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"


def _guess_cluster_type() -> ClusterType:
    """Infer the cluster from ``os.uname()``; anything unrecognised is FAIR."""
    host = os.uname()
    if host.sysname != "Linux":
        return ClusterType.FAIR
    if host.release.endswith("-aws"):
        # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
        return ClusterType.AWS
    if host.nodename.startswith("rsc"):
        # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
        return ClusterType.RSC
    return ClusterType.FAIR


def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
    """Return ``cluster_type`` unchanged, or a guess when it is ``None``."""
    return _guess_cluster_type() if cluster_type is None else cluster_type


def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the absolute checkpoint root directory for the (resolved) cluster."""
    resolved = get_cluster_type(cluster_type)
    if resolved is None:
        return None

    dirname_by_cluster = {
        ClusterType.AWS: "checkpoints",
        ClusterType.FAIR: "checkpoint",
        ClusterType.RSC: "checkpoint/dino",
    }
    return Path("/").joinpath(dirname_by_cluster[resolved])


def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the checkpoint root joined with the value of the ``USER`` env var.

    Raises:
        AssertionError: if ``USER`` is not set in the environment.
    """
    root = get_checkpoint_path(cluster_type)
    if root is None:
        return None

    username = os.environ.get("USER")
    assert username is not None
    return root / username


def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the SLURM partition name used on the (resolved) cluster."""
    resolved = get_cluster_type(cluster_type)
    if resolved is None:
        return None

    partition_by_cluster = {
        ClusterType.AWS: "learnlab",
        ClusterType.FAIR: "learnlab",
        ClusterType.RSC: "learn",
    }
    return partition_by_cluster[resolved]


def get_slurm_executor_parameters(
    nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
) -> Dict[str, Any]:
    """Build SLURM executor parameters for a job.

    Args:
        nodes: Number of nodes to request.
        num_gpus_per_node: GPUs per node; also used as tasks per node.
        cluster_type: Target cluster; guessed from the host when ``None``.
        **kwargs: Extra parameters, applied last so they override the defaults.

    Returns:
        A dict of executor parameters. On AWS the ``mem_gb`` entry is removed.
    """
    resolved = get_cluster_type(cluster_type)
    # AWS and RSC both use 12 CPUs per task; everything else keeps the default of 10.
    cpus_per_task = 12 if resolved in (ClusterType.AWS, ClusterType.RSC) else 10

    params = {
        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_task": cpus_per_task,
        "nodes": nodes,
        "slurm_partition": get_slurm_partition(cluster_type),
    }
    if resolved == ClusterType.AWS:
        params.pop("mem_gb")

    # set additional parameters / apply overrides
    params.update(kwargs)
    return params
+ +import math +import logging +import os + +from omegaconf import OmegaConf + +import dinov2.distributed as distributed +from dinov2.logging import setup_logging +from dinov2.utils import utils +from dinov2.configs import dinov2_default_config + + +logger = logging.getLogger("dinov2") + + +def apply_scaling_rules_to_cfg(cfg): # to fix + if cfg.optim.scaling_rule == "sqrt_wrt_1024": + base_lr = cfg.optim.base_lr + cfg.optim.lr = base_lr + cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0) + logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}") + else: + raise NotImplementedError + return cfg + + +def write_config(cfg, output_dir, name="config.yaml"): + logger.info(OmegaConf.to_yaml(cfg)) + saved_cfg_path = os.path.join(output_dir, name) + with open(saved_cfg_path, "w") as f: + OmegaConf.save(config=cfg, f=f) + return saved_cfg_path + + +def get_cfg_from_args(args): + args.output_dir = os.path.abspath(args.output_dir) + args.opts += [f"train.output_dir={args.output_dir}"] + default_cfg = OmegaConf.create(dinov2_default_config) + cfg = OmegaConf.load(args.config_file) + cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) + return cfg + + +def default_setup(args): + distributed.enable(overwrite=True) + seed = getattr(args, "seed", 0) + rank = distributed.get_global_rank() + + global logger + setup_logging(output=args.output_dir, level=logging.INFO) + logger = logging.getLogger("dinov2") + + utils.fix_random_seeds(seed + rank) + logger.info("git:\n {}\n".format(utils.get_sha())) + logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) + + +def setup(args): + """ + Create configs and perform basic setups. 
# Anything accepted by ``as_torch_dtype``: a dtype name, a NumPy dtype or a torch dtype.
TypeSpec = Union[str, np.dtype, torch.dtype]


# NumPy dtype -> equivalent torch dtype, for the dtypes listed here only.
_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
    np.dtype("bool"): torch.bool,
    np.dtype("uint8"): torch.uint8,
    np.dtype("int8"): torch.int8,
    np.dtype("int16"): torch.int16,
    np.dtype("int32"): torch.int32,
    np.dtype("int64"): torch.int64,
    np.dtype("float16"): torch.float16,
    np.dtype("float32"): torch.float32,
    np.dtype("float64"): torch.float64,
    np.dtype("complex64"): torch.complex64,
    np.dtype("complex128"): torch.complex128,
}


def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
    """Convert a dtype specification to a ``torch.dtype``.

    Args:
        dtype: a ``torch.dtype`` (returned unchanged), a ``np.dtype``, or a
            string that ``np.dtype`` can parse (e.g. ``"float32"``).

    Returns:
        The matching ``torch.dtype``.

    Raises:
        AssertionError: if ``dtype`` is none of the accepted kinds.
        KeyError: if the NumPy dtype has no entry in ``_NUMPY_TO_TORCH_DTYPE``.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        dtype = np.dtype(dtype)
    # Raised explicitly (rather than via ``assert``) so the check also runs under ``python -O``.
    if not isinstance(dtype, np.dtype):
        raise AssertionError(f"Expected an instance of numpy dtype, got {type(dtype)}")
    return _NUMPY_TO_TORCH_DTYPE[dtype]
b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/dinov2/utils/param_groups.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from collections import defaultdict +import logging + + +logger = logging.getLogger("dinov2") + + +def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False): + """ + Calculate lr decay rate for different ViT blocks. + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith("backbone") or force_is_backbone: + if ( + ".pos_embed" in name + or ".patch_embed" in name + or ".mask_token" in name + or ".cls_token" in name + or ".register_tokens" in name + ): + layer_id = 0 + elif force_is_backbone and ( + "pos_embed" in name + or "patch_embed" in name + or "mask_token" in name + or "cls_token" in name + or "register_tokens" in name + ): + layer_id = 0 + elif ".blocks." in name and ".residual." not in name: + layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + elif chunked_blocks and "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1 + elif "blocks." in name and "residual." 
not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1 + + return lr_decay_rate ** (num_layers + 1 - layer_id) + + +def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0): + chunked_blocks = False + if hasattr(model, "n_blocks"): + logger.info("chunked fsdp") + n_blocks = model.n_blocks + chunked_blocks = model.chunked_blocks + elif hasattr(model, "blocks"): + logger.info("first code branch") + n_blocks = len(model.blocks) + elif hasattr(model, "backbone"): + logger.info("second code branch") + n_blocks = len(model.backbone.blocks) + else: + logger.info("else code branch") + n_blocks = 0 + all_param_groups = [] + + for name, param in model.named_parameters(): + name = name.replace("_fsdp_wrapped_module.", "") + if not param.requires_grad: + continue + decay_rate = get_vit_lr_decay_rate( + name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks + ) + d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name} + + if "last_layer" in name: + d.update({"is_last_layer": True}) + + if name.endswith(".bias") or "norm" in name or "gamma" in name: + d.update({"wd_multiplier": 0.0}) + + if "patch_embed" in name: + d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult}) + + all_param_groups.append(d) + logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""") + + return all_param_groups + + +def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")): + fused_params_groups = defaultdict(lambda: {"params": []}) + for d in all_params_groups: + identifier = "" + for k in keys: + identifier += k + str(d[k]) + "_" + + for k in keys: + fused_params_groups[identifier][k] = d[k] + fused_params_groups[identifier]["params"].append(d["params"]) + + return fused_params_groups.values() diff --git 
def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
    """
    Load a checkpoint into ``model`` non-strictly.

    ``pretrained_weights`` is fetched through ``torch.hub`` when it parses as a
    URL (has a scheme) and read with ``torch.load`` otherwise, always onto the
    CPU. If ``checkpoint_key`` is given and present, that sub-dict is used.
    Every occurrence of "module." and then "backbone." is removed from the
    keys before loading.
    """
    looks_like_url = bool(urlparse(pretrained_weights).scheme)
    if looks_like_url:
        state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
    else:
        state_dict = torch.load(pretrained_weights, map_location="cpu")

    if checkpoint_key is not None and checkpoint_key in state_dict:
        state_dict = state_dict[checkpoint_key]

    # Strip the wrapper markers one after the other (two separate passes).
    for marker in ("module.", "backbone."):
        state_dict = {key.replace(marker, ""): value for key, value in state_dict.items()}

    model.load_state_dict(state_dict, strict=False)


def fix_random_seeds(seed=31):
    """
    Seed the torch (CPU and all CUDA devices), NumPy and ``random`` generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def get_sha():
    """
    Describe the git state of the directory containing this file.

    Returns a string of the form "sha: ..., status: ..., branch: ...". Any
    failure while querying git is ignored, leaving the not-yet-filled fields
    at their defaults ("N/A" / "clean").
    """
    repo_dir = os.path.dirname(os.path.abspath(__file__))

    def git(*arguments):
        return subprocess.check_output(["git", *arguments], cwd=repo_dir).decode("ascii").strip()

    sha, status, branch = "N/A", "clean", "N/A"
    try:
        sha = git("rev-parse", "HEAD")
        # Output deliberately discarded (and not decoded), exactly as before.
        subprocess.check_output(["git", "diff"], cwd=repo_dir)
        status = "has uncommitted changes" if git("diff-index", "HEAD") else "clean"
        branch = git("rev-parse", "--abbrev-ref", "HEAD")
    except Exception:
        pass
    return f"sha: {sha}, status: {status}, branch: {branch}"


class CosineScheduler:
    """
    Precomputed schedule: zeros, then a linear warmup, then a cosine decay.

    Indexing with an iteration number returns the scheduled value; indices at
    or beyond ``total_iters`` return ``final_value``.
    """

    def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
        super().__init__()
        self.final_value = final_value
        self.total_iters = total_iters

        frozen_part = np.zeros((freeze_iters))
        warmup_part = np.linspace(start_warmup_value, base_value, warmup_iters)

        steps = np.arange(total_iters - warmup_iters - freeze_iters)
        cosine_part = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * steps / len(steps)))

        self.schedule = np.concatenate((frozen_part, warmup_part, cosine_part))
        assert len(self.schedule) == self.total_iters

    def __getitem__(self, it):
        return self.final_value if it >= self.total_iters else self.schedule[it]


def has_batchnorms(model):
    """
    Tell whether ``model`` contains any (1d/2d/3d/sync) batch-norm module.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    return any(isinstance(module, bn_types) for module in model.modules())
b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/layers/attention.py @@ -0,0 +1,369 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn +import torch + +from torch.nn.functional import scaled_dot_product_attention +from torch.nn.attention import SDPBackend + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Attention)") + else: + # warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 
2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:,:,i] for i in range(3)] + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlashAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1, 3) + + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:,:,i] for i in range(3)] + + if q.dtype == torch.bfloat16: + with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION): + x = scaled_dot_product_attention(q, k, v) + else: + with nn.attention.sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]): + x = scaled_dot_product_attention(q, k, v) + + x = x.transpose(1, 2).reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +""" +Following is written by GPT-4o +""" +class CrossAttentionRope(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + qk_norm: bool = False, + norm_layer: nn.Module = nn.LayerNorm, + rope=None, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + # Separate projection layers for query, key, and value + self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.k_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.v_proj = nn.Linear(dim, dim, bias=qkv_bias) + + self.q_norm = 
class MemEffCrossAttentionRope(CrossAttentionRope):
    """Cross-attention with optional RoPE that uses xFormers' memory-efficient kernel when available."""

    def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_bias=None, qpos=None, kpos=None) -> Tensor:
        """
        Args:
            query: Tensor of shape (B, N, C), input query
            key: Tensor of shape (B, M, C), input key
            value: Tensor of shape (B, M, C), input value
            attn_bias: Optional attention bias; only supported when xFormers is available
            qpos: positions handed to ``self.rope`` for the queries
            kpos: positions handed to ``self.rope`` for the keys
        Returns:
            Tensor of shape (B, N, C), output of cross-attention
        Raises:
            AssertionError: if ``attn_bias`` is given but xFormers is not available
        """
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            # Forward the positions too: the parent applies RoPE with them, and
            # dropping them would hand ``None`` positions to ``self.rope``.
            return super().forward(query, key, value, attn_bias, qpos=qpos, kpos=kpos)

        B, N, C = query.shape
        _, M, _ = key.shape
        head_dim = C // self.num_heads

        # Project query, key, and value
        q = self.q_proj(query).reshape(B, N, self.num_heads, head_dim)
        k = self.k_proj(key).reshape(B, M, self.num_heads, head_dim)
        v = self.v_proj(value).reshape(B, M, self.num_heads, head_dim)

        # Normalisation and RoPE are applied in the (B, heads, tokens, head_dim) layout.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)

        if self.rope is not None:
            q = self.rope(q, qpos)
            k = self.rope(k, kpos)

        # Back to (B, tokens, heads, head_dim), the layout passed to the kernel together with v.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)

        # Compute memory-efficient attention
        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape(B, N, C)

        # Final projection
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class MemEffAttentionRope(AttentionRope):
    """Self-attention with optional RoPE that uses xFormers' memory-efficient kernel when available."""

    def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
        """
        Args:
            x: Tensor of shape (B, N, C)
            attn_bias: Optional attention bias; only supported when xFormers is available
            xpos: positions handed to ``self.rope`` for both queries and keys
        Returns:
            Tensor of shape (B, N, C)
        Raises:
            AssertionError: if ``attn_bias`` is given but xFormers is not available
        """
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            # Forward the positions too: the parent applies RoPE with them, and
            # dropping them would hand ``None`` positions to ``self.rope``.
            return super().forward(x, xpos=xpos)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        # (B, heads, 3, N, head_dim) -> q, k, v each (B, heads, N, head_dim)
        qkv = qkv.transpose(1, 3)
        q, k, v = [qkv[:, :, i] for i in range(3)]
        q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)

        if self.rope is not None:
            q = self.rope(q, xpos)
            k = self.rope(k, xpos)

        # Back to (B, N, heads, head_dim), the layout passed to the kernel.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
def get_attn_score(blk_class, x, frame_num, token_length, xpos=None):
    """Compute a per-frame score from a block's raw (pre-softmax) attention logits.

    The logits are rebuilt from ``blk_class.norm1`` and the projections, norms,
    RoPE and scale of ``blk_class.attn``. They are summed over heads, averaged
    inside each (query-frame, key-frame) tile of ``token_length`` x
    ``token_length`` tokens, and finally summed over key frames.

    Args:
        blk_class: block exposing ``norm1`` and ``attn`` (with ``qkv``,
            ``num_heads``, ``q_norm``, ``k_norm``, ``rope`` and ``scale``).
        x: tensor of shape (B, N, C) with ``N == frame_num * token_length``.
        frame_num: number of frames the tokens are split into.
        token_length: number of tokens per frame.
        xpos: positions handed to the RoPE module, if the block has one.

    Returns:
        Tensor of shape (B, frame_num).
    """
    attention = blk_class.attn
    normed = blk_class.norm1(x)

    batch, tokens, channels = normed.shape
    packed = attention.qkv(normed).reshape(batch, tokens, 3, attention.num_heads, channels // attention.num_heads)

    # (batch, heads, 3, tokens, head_dim) -> q, k, v each (batch, heads, tokens, head_dim)
    query, key, value = packed.transpose(1, 3).unbind(2)
    query = attention.q_norm(query).to(value.dtype)
    key = attention.k_norm(key).to(value.dtype)

    if attention.rope is not None:
        query = attention.rope(query, xpos)
        key = attention.rope(key, xpos)

    # The original swapped the head/token axes and immediately swapped them back;
    # that round trip is a no-op, so the matmul is done directly in this layout.
    logits = torch.matmul(query * attention.scale, key.transpose(-2, -1))

    per_token = logits.sum(dim=1)
    tiled = per_token.reshape(batch, frame_num, token_length, frame_num, token_length)
    return tiled.mean(dim=[2, 4]).sum(-1)
class Block(nn.Module):
    """Transformer block: a normalised attention branch and a normalised MLP branch, each added residually.

    Args:
        dim: channel size.
        num_heads: number of attention heads.
        mlp_ratio: hidden size of the MLP as a multiple of ``dim``.
        qkv_bias / proj_bias / ffn_bias: bias switches forwarded to the sub-layers.
        drop: dropout probability forwarded to the attention projection and the MLP.
        attn_drop: dropout probability forwarded to the attention weights.
        init_values: if truthy, each branch is wrapped in a ``LayerScale`` with this init.
        drop_path: stochastic-depth rate; also stored as ``sample_drop_ratio``.
        act_layer / norm_layer / attn_class / ffn_layer: factories for the sub-layers.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        """Apply the attention branch, then the MLP branch, to a tensor ``x``."""

        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # Each branch uses its own drop-path module (the FFN branch used drop_path1 before).
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Add a residual computed on a random subset of the batch only.

    A random subset of ``max(int(b * (1 - sample_drop_ratio)), 1)`` samples is
    run through ``residual_func``; the result is added back to those samples,
    scaled by ``b / subset_size``. The other samples are returned unchanged.

    Args:
        x: tensor of shape (b, n, d).
        residual_func: branch to evaluate on the selected samples.
        sample_drop_ratio: fraction of the batch to skip.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    # Scale up so the summed residual over the batch matches the no-drop case.
    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add ``residual`` (scaled) into the rows of ``x`` selected by ``brange``.

    Without ``scaling_vector`` both tensors are flattened from dim 1 and
    combined with ``torch.index_add``; the flattened result is returned. With
    it, the work is delegated to xFormers' ``scaled_index_add``.
    """
    residual = residual.to(dtype=x.dtype) if scaling_vector is not None else residual
    if scaling_vector is not None:
        return scaled_index_add(x, brange, residual, scaling=scaling_vector, alpha=residual_scale_factor)

    flat_residual = residual.flatten(1).to(dtype=x.dtype)
    return torch.index_add(x.flatten(1), 0, brange, flat_residual, alpha=residual_scale_factor)


# Block-diagonal attention biases, keyed by the tuple of (batch size, sequence length) pairs.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    if branges is not None:
        batch_sizes = [indices.shape[0] for indices in branges]
    else:
        batch_sizes = [x.shape[0] for x in x_list]

    cache_key = tuple((size, x.shape[1]) for size, x in zip(batch_sizes, x_list))
    if cache_key not in attn_bias_cache:
        # One sequence length entry per selected sample, in list order.
        seqlens = [x.shape[1] for size, x in zip(batch_sizes, x_list) for _ in range(size)]
        bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        bias._batch_sizes = batch_sizes
        attn_bias_cache[cache_key] = bias

    if branges is None:
        # Fold every tensor's batch into the token axis, then concatenate along it.
        cat_tensors = torch.cat(tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list), dim=1)
    else:
        flattened = [x.flatten(1) for x in x_list]
        cat_tensors = index_select_cat(flattened, branges).view(1, -1, x_list[0].shape[-1])

    return attn_bias_cache[cache_key], cat_tensors
class NestedTensorBlock(Block):
    """``Block`` that additionally accepts a list of tensors, processed jointly through xFormers."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:
            # LayerScale is not applied inside the residual functions here; its
            # gamma is handed over as ``scaling_vector`` instead.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Check the module whose gamma is read (previously ``self.ls1`` was tested here).
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch on the input type: a ``Tensor`` uses ``Block.forward``, a ``list`` the nested path.

        Raises:
            AssertionError: for a list input without xFormers, or for any other input type.
        """
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
class CrossBlockRope(nn.Module):
    """Decoder-style block: self-attention on ``x``, cross-attention from ``x`` to ``y``, then an MLP.

    Each of the three branches is normalised first, optionally passed through
    a ``LayerScale`` (when ``init_values`` is truthy) and added residually.
    ``rope`` and ``qk_norm`` are forwarded to both attention modules.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        cross_attn_class: Callable[..., nn.Module] = CrossAttentionRope,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        init_values=None,
        qk_norm: bool = False,
        rope=None,
    ) -> None:
        super().__init__()

        def make_layer_scale() -> nn.Module:
            return LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

        # NOTE(review): both attention factories receive ``rope`` and ``qk_norm``;
        # callers presumably always pass an ``attn_class`` that accepts them — confirm
        # that the default ``Attention`` does.
        attention_kwargs = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            rope=rope,
            qk_norm=qk_norm,
        )

        # Sub-modules are created in the same order as before.
        self.ls1 = make_layer_scale()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(dim, **attention_kwargs)

        self.ls2 = make_layer_scale()
        self.ls_y = make_layer_scale()
        self.norm2 = norm_layer(dim)
        self.norm_y = norm_layer(dim)
        self.cross_attn = cross_attn_class(dim, **attention_kwargs)

        self.norm3 = norm_layer(dim)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            bias=ffn_bias,
        )

    def forward(self, x: Tensor, y: Tensor, xpos=None, ypos=None) -> Tensor:
        """Update ``x`` using itself and the context ``y``.

        ``xpos`` is passed to the self-attention and as the query positions of
        the cross-attention; ``ypos`` as its key positions.
        """
        # Self-attention branch.
        x = x + self.ls1(self.attn(self.norm1(x), xpos=xpos))

        # Cross-attention branch: the normalised context serves as both key and value.
        context = self.norm_y(y)
        x = x + self.ls_y(self.cross_attn(self.norm2(x), context, context, qpos=xpos, kpos=ypos))

        # Feed-forward branch.
        x = x + self.ls2(self.mlp(self.norm3(x)))
        return x
# --- slamformer/models/layers/camera_head.py ---
import torch
import torch.nn as nn
from copy import deepcopy
import torch.nn.functional as F


# code adapted from 'https://github.com/nianticlabs/marepo/blob/9a45e2bb07e5bb8cb997620088d352b439b13e0e/transformer/transformer.py#L172'
class ResConvBlock(nn.Module):
    """Residual block of three per-token linear layers.

    The layers were originally 1x1 convolutions and have been replaced by
    ``nn.Linear``, so the block operates on channels-last input ``[..., C]``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # The skip path must act on the same (last) axis as the Linear layers.
        # A Conv2d here would treat a [BN, hw, C] token tensor as an unbatched
        # image and fail, so the projection is a Linear as well.
        if in_channels == out_channels:
            self.head_skip = nn.Identity()
        else:
            self.head_skip = nn.Linear(in_channels, out_channels)

        self.res_conv1 = nn.Linear(in_channels, out_channels)
        self.res_conv2 = nn.Linear(out_channels, out_channels)
        self.res_conv3 = nn.Linear(out_channels, out_channels)

    def forward(self, res):
        """Return ``head_skip(res) + relu(l3(relu(l2(relu(l1(res))))))``."""
        x = F.relu(self.res_conv1(res))
        x = F.relu(self.res_conv2(x))
        x = F.relu(self.res_conv3(x))
        return self.head_skip(res) + x


class CameraHead(nn.Module):
    """Regress one 4x4 camera pose per frame from its patch tokens.

    Tokens pass through two :class:`ResConvBlock` s, are average-pooled over
    the patch grid, refined by a small MLP and mapped to a translation (3) and
    a 9-D rotation that is projected onto SO(3) by SVD.
    """

    def __init__(self, dim=512):
        super().__init__()
        output_dim = dim
        self.res_conv = nn.ModuleList([ResConvBlock(output_dim, output_dim) for _ in range(2)])
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.more_mlps = nn.Sequential(
            nn.Linear(output_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
            nn.ReLU(),
        )
        self.fc_t = nn.Linear(output_dim, 3)
        self.fc_rot = nn.Linear(output_dim, 9)

    def forward(self, feat, patch_h, patch_w):
        """Predict poses.

        Args:
            feat: tokens of shape ``[BN, patch_h * patch_w, C]``.
            patch_h, patch_w: patch-grid size used to fold tokens back to 2-D
                for pooling.

        Returns:
            ``[BN, 4, 4]`` float32 homogeneous pose matrices.
        """
        BN, _, _ = feat.shape

        for block in self.res_conv:
            feat = block(feat)

        # [BN, hw, C] -> [BN, C, patch_h, patch_w] -> global average -> [BN, C]
        feat = self.avgpool(feat.permute(0, 2, 1).reshape(BN, -1, patch_h, patch_w).contiguous())
        feat = feat.view(feat.size(0), -1)

        feat = self.more_mlps(feat)  # [BN, C]
        # The output layers and the SVD run in float32 with CUDA autocast off.
        with torch.amp.autocast(device_type='cuda', enabled=False):
            out_t = self.fc_t(feat.float())  # [BN, 3]
            out_r = self.fc_rot(feat.float())  # [BN, 9]
            pose = self.convert_pose_to_4x4(BN, out_r, out_t, feat.device)

        return pose

    def convert_pose_to_4x4(self, B, out_r, out_t, device):
        """Assemble ``[B, 4, 4]`` poses from raw rotation (``[B, 9]`` or ``[B, 3, 3]``) and translation (``[B, 3]``)."""
        out_r = self.svd_orthogonalize(out_r)  # [B, 3, 3]
        pose = torch.zeros((B, 4, 4), device=device, dtype=out_r.dtype)
        pose[:, :3, :3] = out_r
        pose[:, :3, 3] = out_t
        pose[:, 3, 3] = 1.
        return pose

    def svd_orthogonalize(self, m):
        """Convert a 9D representation to SO(3) using SVD orthogonalization.

        Args:
            m: ``[BATCH, 3, 3]`` matrices, or any tensor with fewer than three
                dimensions whose elements reshape to ``[-1, 3, 3]`` (e.g. ``[BATCH, 9]``).

        Returns:
            ``[BATCH, 3, 3]`` rotation matrices with determinant +1.
        """
        if m.dim() < 3:
            m = m.reshape((-1, 3, 3))
        # Rows are L2-normalised before the matrix is transposed and decomposed.
        m_transpose = torch.transpose(F.normalize(m, p=2, dim=-1), dim0=-1, dim1=-2)
        # torch.svd is deprecated; linalg.svd returns V^H instead of V.
        u, _, vh = torch.linalg.svd(m_transpose)
        v = vh.transpose(-2, -1)
        u_t = u.transpose(-2, -1)
        det = torch.det(torch.matmul(v, u_t))
        # Flip the last column of V when V U^T would be a reflection (det = -1).
        v_fixed = torch.cat([v[:, :, :-1], v[:, :, -1:] * det.view(-1, 1, 1)], dim=2)
        return torch.matmul(v_fixed, u_t)
+ """ + + def __init__( + self, + dim_in: int, + patch_size: int = 14, + output_dim: int = 4, + activation: str = "inv_log", + conf_activation: str = "expp1", + features: int = 256, + out_channels: List[int] = [256, 512, 1024, 1024], + intermediate_layer_idx: List[int] = [0], #[4, 11, 17, 23], + pos_embed: bool = True, + feature_only: bool = False, + down_ratio: int = 1, + ) -> None: + super(DPTHead, self).__init__() + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.feature_only = feature_only + self.down_ratio = down_ratio + self.intermediate_layer_idx = intermediate_layer_idx + + self.norm = nn.LayerNorm(dim_in) + + # Projection layers for each output channel from tokens. + self.projects = nn.ModuleList( + [ + nn.Conv2d( + in_channels=dim_in, + out_channels=oc, + kernel_size=1, + stride=1, + padding=0, + ) + for oc in out_channels + ] + ) + + # Resize layers for upsampling feature maps. + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 + ), + ] + ) + + self.scratch = _make_scratch( + out_channels, + features, + expand=False, + ) + + # Attach additional modules to scratch. 
+ self.scratch.stem_transpose = None + #self.scratch.refinenet1 = _make_fusion_block(features) + #self.scratch.refinenet2 = _make_fusion_block(features) + #self.scratch.refinenet3 = _make_fusion_block(features) + self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False) + + head_features_1 = features + head_features_2 = 32 + + if feature_only: + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1) + else: + self.scratch.output_conv1 = nn.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1 + ) + conv2_in_channels = head_features_1 // 2 + + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0), + ) + + def forward( + self, + aggregated_tokens_list: List[torch.Tensor], + patch_start_idx: int = 0, + frames_chunk_size: int = 8, + shape_: tuple = (1,16,3, 518, 392) + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Forward pass through the DPT head, supports processing by chunking frames. + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. + patch_start_idx (int): Starting index for patch tokens in the token sequence. + Used to separate patch tokens from other tokens (e.g., camera or register tokens). + frames_chunk_size (int, optional): Number of frames to process in each chunk. + If None or larger than S, all frames are processed at once. Default: 8. 
+ + Returns: + Tensor or Tuple[Tensor, Tensor]: + - If feature_only=True: Feature maps with shape [B, S, C, H, W] + - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W] + """ + B, S, _, H, W = shape_ + + # If frames_chunk_size is not specified or greater than S, process all frames at once + + return self._forward_impl(aggregated_tokens_list, patch_start_idx, shape_ = shape_) + + + def _forward_impl( + self, + aggregated_tokens_list: List[torch.Tensor], + patch_start_idx: int, + frames_start_idx: int = None, + frames_end_idx: int = None, + shape_: tuple = (1,16,3, 518, 392) + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. + images (Tensor): Input images with shape [B, S, 3, H, W]. + patch_start_idx (int): Starting index for patch tokens. + frames_start_idx (int, optional): Starting index for frames to process. + frames_end_idx (int, optional): Ending index for frames to process. + + Returns: + Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence). + """ + + B, S, _, H, W = shape_ + + patch_h, patch_w = H // self.patch_size, W // self.patch_size + + out = [] + dpt_idx = 0 + + for layer_idx in self.intermediate_layer_idx: + x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:] + + # Select frames if processing a chunk + + + x = x.reshape(B * S, -1, x.shape[-1]) + x = self.norm(x) + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[dpt_idx](x) + if self.pos_embed: + x = self._apply_pos_embed(x, W, H) + x = self.resize_layers[dpt_idx](x) + + out.append(x) + dpt_idx += 1 + + # Fuse features from multiple layers. + out = self.scratch_forward(out) + # Interpolate fused output to match target image resolution. 
+ out = custom_interpolate( + out, + (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)), + mode="bilinear", + align_corners=True, + ) + + if self.pos_embed: + out = self._apply_pos_embed(out, W, H) + + if self.feature_only: + return out.reshape(B, S, *out.shape[1:]) + + out = self.scratch.output_conv2(out) + preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation) + + preds = preds.reshape(B, S, *preds.shape[1:]) + conf = conf.reshape(B, S, *conf.shape[1:]) + return preds, conf + + def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: + """ + Apply positional embedding to tensor x. + """ + patch_w = x.shape[-1] + patch_h = x.shape[-2] + pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device) + pos_embed = position_grid_to_embed(pos_embed, x.shape[1]) + pos_embed = pos_embed * ratio + pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) + return x + pos_embed + + def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor: + """ + Forward pass through the fusion blocks. + + Args: + features (List[Tensor]): List of feature maps from different layers. + + Returns: + Tensor: Fused feature map. 
+ """ + layer_4 = features[0] + + layer_4_rn = self.scratch.layer4_rn(layer_4) + + out = self.scratch.refinenet4(layer_4_rn, size=layer_4_rn.shape[2:]) + del layer_4_rn, layer_4 + out = self.scratch.output_conv1(out) + return out + + +def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module: + return FeatureFusionBlock( + features, + nn.ReLU(inplace=True), + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=size, + has_residual=has_residual, + groups=groups, + ) + + +def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module: + scratch = nn.Module() + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer4_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn, groups=1): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + self.groups = groups + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.norm1 = None + self.norm2 = None + + self.activation = activation + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. 
+ + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.norm1 is not None: + out = self.norm1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.norm2 is not None: + out = self.norm2(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None, + has_residual=True, + groups=1, + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + self.groups = groups + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups + ) + + if has_residual: + self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups) + + self.has_residual = has_residual + self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups) + + self.skip_add = nn.quantized.FloatFunctional() + self.size = size + + def forward(self, *xs, size=None): + """Forward pass. 
+ + Returns: + tensor: output + """ + output = xs[0] + + if self.has_residual: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) + output = self.out_conv(output) + + return output + + +def custom_interpolate( + x: torch.Tensor, + size: Tuple[int, int] = None, + scale_factor: float = None, + mode: str = "bilinear", + align_corners: bool = True, +) -> torch.Tensor: + """ + Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate. + """ + if size is None: + size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) + + INT_MAX = 1610612736 + + input_elements = size[0] * size[1] * x.shape[0] * x.shape[1] + + if input_elements > INT_MAX: + chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0) + interpolated_chunks = [ + nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks + ] + x = torch.cat(interpolated_chunks, dim=0) + return x.contiguous() + else: + return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners) diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/layers/pos_embed.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/layers/pos_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..9e3ceb86b2d3992636a28b1a4abb3f5722dd959e --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/layers/pos_embed.py @@ -0,0 +1,606 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. 
# Modified: Added RoPE3D_ChunkAware and PositionGetter3D_ChunkAware for chunk-aware position encoding
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).


# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------


import numpy as np

import torch


# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
    """Build a fixed 2D sine-cosine position embedding.

    Args:
        embed_dim: embedding width; must be divisible by 4 (two axes, sin + cos each).
        grid_size: int for a square ``grid_size x grid_size`` grid, or a
            ``(height, width)`` pair for a rectangular grid.
        n_cls_token: number of all-zero rows prepended for extra tokens.

    Returns:
        ``np.ndarray`` of shape ``[n_cls_token + height*width, embed_dim]``,
        rows in row-major (height-major) order.
    """
    if isinstance(grid_size, (tuple, list)):
        grid_rows, grid_cols = grid_size
    else:
        grid_rows = grid_cols = grid_size

    grid_h = np.arange(grid_rows, dtype=np.float32)
    grid_w = np.arange(grid_cols, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_rows, grid_cols])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if n_cls_token > 0:
        pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a stacked coordinate grid; each of ``grid[0]`` / ``grid[1]`` gets half the channels.

    Returns an array of shape ``(H*W, embed_dim)``.
    """
    assert embed_dim % 2 == 0

    # First half of the channels encodes grid[0], second half grid[1].
    # (With the grid built by get_2d_sincos_pos_embed, grid[0] holds the
    # w-coordinates because meshgrid is called with w first.)
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sine-cosine encoding of scalar positions.

    Args:
        embed_dim: output dimension for each position (even).
        pos: array of positions of any shape; it is flattened to ``(M,)``.

    Returns:
        ``(M, embed_dim)`` array laid out as ``[sin(pos * w) | cos(pos * w)]``
        with frequencies ``w_k = 10000 ** (-k / (embed_dim / 2))``.
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    return np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (M, D)


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    """Resize ``checkpoint_model['pos_embed']`` in place to the model's patch grid.

    Does nothing when the key is missing or the (square) grid sizes already
    match. Extra tokens (e.g. class token) at the front are copied unchanged;
    only the patch-position part is bicubically interpolated.

    Args:
        model: object exposing ``patch_embed.num_patches`` and ``pos_embed``.
        checkpoint_model: state-dict-like mapping; modified in place.
    """
    if 'pos_embed' not in checkpoint_model:
        return

    pos_embed_checkpoint = checkpoint_model['pos_embed']
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.patch_embed.num_patches
    num_extra_tokens = model.pos_embed.shape[-2] - num_patches
    # height (== width) for the checkpoint position embedding
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    # height (== width) for the new position embedding
    new_size = int(num_patches ** 0.5)
    if orig_size == new_size:
        return

    print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
    # class_token and dist_token are kept unchanged
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    # only the position tokens are interpolated
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    pos_tokens = torch.nn.functional.interpolate(
        pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model['pos_embed'] = torch.cat((extra_tokens, pos_tokens), dim=1)
#----------------------------------------------------------
# RoPE2D: RoPE implementation in 2D
#----------------------------------------------------------

try:
    from models.curope import cuRoPE2D
    RoPE2D = cuRoPE2D
except ImportError:
    print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead')

    class RoPE2D(torch.nn.Module):
        """Pure-PyTorch 2D rotary embedding: first half of the channels rotates with y, second half with x."""

        def __init__(self, freq=100.0, F0=1.0):
            super().__init__()
            self.base = freq
            self.F0 = F0
            self.cache = {}

        def get_cos_sin(self, D, seq_len, device, dtype):
            """Return cached ``(cos, sin)`` tables of shape ``(seq_len, D)``."""
            key = (D, seq_len, device, dtype)
            if key not in self.cache:
                inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
                t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
                freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
                freqs = torch.cat((freqs, freqs), dim=-1)
                self.cache[key] = (freqs.cos(), freqs.sin())  # (Seq, Dim) each
            return self.cache[key]

        @staticmethod
        def rotate_half(x):
            x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)

        def apply_rope1d(self, tokens, pos1d, cos, sin):
            assert pos1d.ndim == 2
            # Look up the per-token rows and add a broadcast axis for the heads.
            cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
            sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
            return (tokens * cos) + (self.rotate_half(tokens) * sin)

        def forward(self, tokens, positions):
            """
            input:
                * tokens: batch_size x nheads x ntokens x dim
                * positions: batch_size x ntokens x 2 (y and x position of each token)
            output:
                * tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
            """
            assert tokens.size(3) % 2 == 0, "number of dimensions should be a multiple of two"
            D = tokens.size(3) // 2
            assert positions.ndim == 3 and positions.shape[-1] == 2  # Batch, Seq, 2
            cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype)
            # split features into two along the feature dimension, and apply rope1d on each half
            y, x = tokens.chunk(2, dim=-1)
            y = self.apply_rope1d(y, positions[:, :, 0], cos, sin)
            x = self.apply_rope1d(x, positions[:, :, 1], cos, sin)
            return torch.cat((y, x), dim=-1)


def _tile_over_patch_grid(frame_coords, width, h, w, device):
    """Expand per-frame leading coordinates over an ``h x w`` patch grid.

    Args:
        frame_coords: one list of ``width`` ints per frame.
        width: number of leading coordinates per frame.
        h, w: patch-grid size.
        device: device of the result.

    Returns:
        ``(n*h*w, width + 2)`` long tensor whose rows are
        ``(*frame_coords[f], y, x)`` in frame-major, then row-major order.
    """
    n = len(frame_coords)
    lead = torch.tensor(frame_coords, dtype=torch.long, device=device).reshape(n, width)
    ys = torch.arange(h, device=device).repeat_interleave(w)
    xs = torch.arange(w, device=device).repeat(h)
    spatial = torch.stack((ys, xs), dim=-1)  # (h*w, 2)
    return torch.cat((lead.repeat_interleave(h * w, dim=0), spatial.repeat(n, 1)), dim=-1)


# patch embedding
class PositionGetter(object):
    """ return positions of patches """

    def __init__(self):
        self.cache_positions = {}

    def __call__(self, b, h, w, device):
        """Return a fresh ``(b, h*w, 2)`` long tensor of ``(y, x)`` patch positions on ``device``."""
        if (h, w) not in self.cache_positions:
            x = torch.arange(w, device=device)
            y = torch.arange(h, device=device)
            self.cache_positions[h, w] = torch.cartesian_prod(y, x)  # (h*w, 2)
        # The cache is keyed by grid size only, so move the entry to the
        # requested device (no-op when it already lives there).
        pos = self.cache_positions[h, w].to(device).view(1, h * w, 2).expand(b, -1, 2).clone()
        return pos


class _SplitRoPE(torch.nn.Module):
    """Shared machinery for the fixed-split rotary embeddings below."""

    def get_cos_sin(self, D, seq_len, device, dtype, base):
        """Return cached ``(cos, sin)`` tables of shape ``(seq_len, D)`` for frequency base ``base``."""
        key = (D, seq_len, device, dtype, base)
        if key not in self.cache:
            # For D dimensions, we need D/2 frequencies.
            # When D is odd, we compute ceil(D/2) frequencies and truncate.
            half_D = (D + 1) // 2
            inv_freq = 1.0 / (base ** (torch.arange(0, half_D).float().to(device) * 2 / D))
            t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
            # Duplicate and truncate to exactly D dimensions
            freqs = torch.cat((freqs, freqs), dim=-1)[:, :D]
            self.cache[key] = (freqs.cos(), freqs.sin())
        return self.cache[key]

    @staticmethod
    def rotate_half(x):
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rope1d(self, tokens, pos1d, cos, sin):
        assert pos1d.ndim == 2
        cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
        sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
        return (tokens * cos) + (self.rotate_half(tokens) * sin)


#----------------------------------------------------------
# RoPE3D: RoPE implementation in 3D (ref_id + space)
# Fixed dimension split: ref_id=22, y=21, x=21 (total 64)
#----------------------------------------------------------

class RoPE3D(_SplitRoPE):
    """
    RoPE implementation in 3D (ref_id, height, width)

    Fixed dimension split for head_dim=64:
    - ref_id: 22 dimensions
    - y: 21 dimensions
    - x: 21 dimensions

    NOTE(review): D_y and D_x are odd, so ``rotate_half`` splits them into
    halves of 10 and 11 channels and the transform is not a pure pairwise
    rotation. Left as is because changing it would change the encoding seen
    by already-trained weights; confirm this is intended.
    """

    def __init__(self, freq=100.0, F0=1.0, ref_freq=10.0):
        super().__init__()
        self.base = freq  # for spatial dimensions (y, x)
        self.ref_base = ref_freq  # for ref_id dimension
        self.F0 = F0
        self.cache = {}

        # Fixed dimension split
        self.D_ref = 22  # ref_id gets 22 dims
        self.D_y = 21  # y gets 21 dims
        self.D_x = 21  # x gets 21 dims

    def forward(self, tokens, positions):
        """
        input:
            * tokens: batch_size x nheads x ntokens x dim
            * positions: batch_size x ntokens x 3 (ref_id, y, x)
        output:
            * tokens after applying RoPE3D (batch_size x nheads x ntokens x dim)
        """
        dim = tokens.size(3)
        assert dim == self.D_ref + self.D_y + self.D_x, \
            f"dim {dim} != {self.D_ref} + {self.D_y} + {self.D_x}"
        assert positions.ndim == 3 and positions.shape[-1] == 3  # Batch, Seq, 3

        # (channel width, frequency base) for the ref_id, y and x coordinates.
        layout = ((self.D_ref, self.ref_base), (self.D_y, self.base), (self.D_x, self.base))
        rotated = []
        start = 0
        for axis, (width, base) in enumerate(layout):
            pos1d = positions[:, :, axis]
            cos, sin = self.get_cos_sin(width, int(pos1d.max()) + 1, tokens.device, tokens.dtype, base)
            rotated.append(self.apply_rope1d(tokens[..., start:start + width], pos1d, cos, sin))
            start += width
        return torch.cat(rotated, dim=-1)


class PositionGetter3D(object):
    """ return 3D positions of patches (ref_id, y, x) """

    def __init__(self):
        self.cache_positions = {}

    def __call__(self, b, n, h, w, device, chunk_size=None):
        """
        Args:
            b: batch size
            n: number of frames
            h: height in patches
            w: width in patches
            device: torch device
            chunk_size: size of each chunk for computing ref_id.
                        If None, ``n // 2`` is used.
        Returns:
            pos: (b, n*h*w, 3) with (ref_id, y, x)

        ref_id formula:
            - frame 0: ref_id = 0
            - other frames: ref_id = ((frame_idx - 1) // chunk_size) * chunk_size
        """
        if chunk_size is None:
            chunk_size = n // 2

        cache_key = (n, h, w, chunk_size)
        if cache_key not in self.cache_positions:
            # Per-frame ids are computed in Python (n entries only); frame 0 is
            # special-cased, so n == 1 never divides by the default chunk size 0.
            frame_coords = [
                [0 if frame_idx == 0 else ((frame_idx - 1) // chunk_size) * chunk_size]
                for frame_idx in range(n)
            ]
            self.cache_positions[cache_key] = _tile_over_patch_grid(frame_coords, 1, h, w, device)

        pos = self.cache_positions[cache_key].to(device).view(1, n * h * w, 3).expand(b, -1, 3).clone()
        return pos


#----------------------------------------------------------
# RoPE4D_RefAware: RoPE with reference awareness (ref_id, t, y, x)
# 4D encoding solves dim divisibility issues (64 / 4 = 16)
#----------------------------------------------------------

class RoPE4D_ChunkAware(_SplitRoPE):
    """
    Chunk-Aware 4D RoPE implementation: (chunk_idx, t, y, x)

    - chunk_idx: which chunk this frame belongs to (0, 1, 2, ...)
    - t: temporal position within the chunk (0 for the first frame, 1..chunk_size for others)
    - y, x: spatial position in the patch grid

    Dimension split for head_dim=64:
    - chunk_idx: 8 dims
    - t: 8 dims
    - y: 24 dims
    - x: 24 dims

    All widths are even, so every part is a proper pairwise rotation.
    """

    def __init__(self, freq=100.0, chunk_freq=10000, temporal_freq=1000.0):
        super().__init__()
        self.base = freq  # for spatial dimensions (y, x)
        self.chunk_base = chunk_freq  # for chunk_idx dimension
        self.temporal_base = temporal_freq  # for temporal dimension t
        self.cache = {}

        # Fixed dimension split (8 + 8 + 24 + 24 = 64)
        self.D_chunk = 8  # chunk_idx gets 8 dims
        self.D_t = 8  # t gets 8 dims
        self.D_y = 24  # y gets 24 dims
        self.D_x = 24  # x gets 24 dims

    def forward(self, tokens, positions):
        """
        input:
            * tokens: batch_size x nheads x ntokens x dim
            * positions: batch_size x ntokens x 4 (chunk_idx, t, y, x)
        output:
            * tokens after applying RoPE (batch_size x nheads x ntokens x dim)

        Dimension layout: [chunk(8) | t(8) | y(24) | x(24)]
        """
        dim = tokens.size(3)
        expected_dim = self.D_chunk + self.D_t + self.D_y + self.D_x
        assert dim == expected_dim, f"dim {dim} != expected {expected_dim}"
        assert positions.ndim == 3 and positions.shape[-1] == 4

        # (channel width, frequency base) for chunk_idx, t, y and x.
        layout = (
            (self.D_chunk, self.chunk_base),
            (self.D_t, self.temporal_base),
            (self.D_y, self.base),
            (self.D_x, self.base),
        )
        rotated = []
        start = 0
        for axis, (width, base) in enumerate(layout):
            pos1d = positions[:, :, axis]
            cos, sin = self.get_cos_sin(width, int(pos1d.max()) + 1, tokens.device, tokens.dtype, base)
            rotated.append(self.apply_rope1d(tokens[..., start:start + width], pos1d, cos, sin))
            start += width
        return torch.cat(rotated, dim=-1)


# Legacy alias - redirect to new class
class RoPE4D_RefAware(RoPE4D_ChunkAware):
    """Backward compatibility alias for RoPE4D_ChunkAware (different default frequencies)."""

    def __init__(self, freq=100.0, ref_freq=10.0, temporal_freq=10.0):
        super().__init__(freq=freq, chunk_freq=ref_freq, temporal_freq=temporal_freq)


# Aliases for backward compatibility
RoPE3D_RefAware = RoPE4D_RefAware
RoPE3D_ChunkAware = RoPE4D_RefAware


class PositionGetter4D_ChunkAware(object):
    """
    Return 4D positions: (chunk_idx, t, y, x)

    chunk_idx: which chunk this frame belongs to
    t: temporal position within the chunk
    y, x: spatial position in patch grid

    Frame 0 is (chunk_idx=0, t=0); every later frame f gets
    chunk_idx = (f - 1) // chunk_size and t = (f - 1) % chunk_size + 1.

    Example (N=12, chunk_size=4):
        Frame:      0  1  2  3  4  5  6  7  8  9  10 11
        chunk_idx:  0  0  0  0  0  1  1  1  1  2  2  2
        t:          0  1  2  3  4  1  2  3  4  1  2  3
    """

    def __init__(self):
        self.cache = {}

    @staticmethod
    def _chunk_coords(frame_idx, chunk_size):
        """Return ``[chunk_idx, t]`` for one frame."""
        if frame_idx == 0:
            return [0, 0]  # first frame (anchor of the first chunk)
        return [(frame_idx - 1) // chunk_size, (frame_idx - 1) % chunk_size + 1]

    def __call__(self, b, n, h, w, device, chunk_size=None, ref_ids=None):
        """
        Args:
            b: batch size
            n: total number of frames
            h, w: patch grid size (height, width in patches)
            chunk_size: size of each chunk (default: n//2 for backward compatibility)
            ref_ids: Deprecated, ignored. Use chunk_size instead.
        Returns:
            pos: (b, n*h*w, 4) with (chunk_idx, t, y, x)
        """
        if chunk_size is None:
            chunk_size = n // 2  # backward compatible

        key = (n, h, w, chunk_size)
        if key not in self.cache:
            frame_coords = [self._chunk_coords(frame_idx, chunk_size) for frame_idx in range(n)]
            self.cache[key] = _tile_over_patch_grid(frame_coords, 2, h, w, device)

        # Clone and expand for batch
        pos = self.cache[key].to(device).view(1, n * h * w, 4).expand(b, -1, 4).clone()
        return pos

    def get_position_for_single_frame(self, frame_idx, h, w, device, chunk_size):
        """
        Get position encoding for a single frame (useful for inference).

        Args:
            frame_idx: index of the frame
            h, w: patch grid size
            chunk_size: size of each chunk
        Returns:
            pos: (1, h*w, 4) with (chunk_idx, t, y, x)
        """
        frame_coords = [self._chunk_coords(frame_idx, chunk_size)]
        return _tile_over_patch_grid(frame_coords, 2, h, w, device).view(1, h * w, 4)


# Legacy aliases for backward compatibility
class PositionGetter3D_RefAware(PositionGetter4D_ChunkAware):
    """Backward compatibility alias"""
    pass


PositionGetter3D_ChunkAware = PositionGetter4D_ChunkAware


#----------------------------------------------------------
# ALiBi2D_Temporal: ALiBi for temporal dimensions only
# RoPE2D handles spatial (y, x), ALiBi handles (chunk_id, t)
#----------------------------------------------------------
+ + This is used in HYBRID mode with RoPE2D: + - RoPE2D encodes spatial positions (y, x) via rotary embeddings + - ALiBi2D_Temporal encodes temporal positions (chunk_id, t) via attention bias + + Key advantage: Perfect extrapolation for inference + - Training: chunk_id ∈ [0, 3] + - Inference: chunk_id ∈ [0, 200+] + - ALiBi uses linear distance, so no distribution shift! + + Formula: + bias(i, j) = -m * (chunk_weight * |chunk_i - chunk_j| + temporal_weight * |t_i - t_j|) + + where m is the slope for each attention head (different per head). + """ + + def __init__(self, num_heads, chunk_weight=1.0, temporal_weight=0.5): + """ + Args: + num_heads: number of attention heads + chunk_weight: weight for chunk_id distance (default 1.0, most important) + temporal_weight: weight for temporal t distance (default 0.5) + """ + super().__init__() + self.num_heads = num_heads + # Store as Python floats to avoid tensor conversion issues + self.chunk_weight = float(chunk_weight) + self.temporal_weight = float(temporal_weight) + + # Generate slopes for each head using geometric sequence + # Paper recommendation: slopes = 2^(-8/n * i) for i in [1, 2, ..., n] + slopes = torch.pow(2.0, -torch.linspace(0, 8, num_heads)) + self.register_buffer('slopes', slopes) + + def forward(self, positions_4d): + """ + Compute ALiBi temporal bias from 4D positions. + + Args: + positions_4d: (B, L, 4) with (chunk_id, t, y, x) + We only use chunk_id and t here. + + Returns: + bias: (B, num_heads, L, L) attention bias matrix + Negative values penalize distant frames. 
+ """ + B, L, _ = positions_4d.shape + device = positions_4d.device + + # Extract temporal dimensions only (chunk_id, t) + chunk_ids = positions_4d[:, :, 0:1].float() # (B, L, 1) + t_ids = positions_4d[:, :, 1:2].float() # (B, L, 1) + + # Compute pairwise distance matrices + # chunk_dist[b, i, j] = |chunk_i - chunk_j| + chunk_dist = torch.abs( + chunk_ids - chunk_ids.transpose(1, 2) + ) # (B, L, L) + + t_dist = torch.abs( + t_ids - t_ids.transpose(1, 2) + ) # (B, L, L) + + # Weighted combination of distances + total_dist = ( + self.chunk_weight * chunk_dist + + self.temporal_weight * t_dist + ) # (B, L, L) + + # Apply per-head slopes: bias = -slope * distance + # slopes: (num_heads,) -> (1, num_heads, 1, 1) + # total_dist: (B, L, L) -> (B, 1, L, L) + bias = -self.slopes.view(1, -1, 1, 1) * total_dist.unsqueeze(1) + + return bias # (B, num_heads, L, L) \ No newline at end of file diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/layers/transformer_head.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/layers/transformer_head.py new file mode 100644 index 0000000000000000000000000000000000000000..8b03892d1629f995151fc06e1c5299f9f6b4a6f2 --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/layers/transformer_head.py @@ -0,0 +1,81 @@ +from .attention import FlashAttentionRope +from .block import BlockRope +from ..dinov2.layers import Mlp +import torch.nn as nn +from functools import partial +from torch.utils.checkpoint import checkpoint +import torch.nn.functional as F + +class TransformerDecoder(nn.Module): + def __init__( + self, + in_dim, + out_dim, + dec_embed_dim=512, + depth=5, + dec_num_heads=8, + mlp_ratio=4, + rope=None, + need_project=True, + use_checkpoint=False, + ): + super().__init__() + + self.projects = nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity() + self.use_checkpoint = use_checkpoint + + self.blocks = nn.ModuleList([ + BlockRope( + dim=dec_embed_dim, + 
num_heads=dec_num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + drop_path=0.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + ffn_layer=Mlp, + init_values=None, + qk_norm=False, + # attn_class=MemEffAttentionRope, + attn_class=FlashAttentionRope, + rope=rope + ) for _ in range(depth)]) + + self.linear_out = nn.Linear(dec_embed_dim, out_dim) + + def forward(self, hidden, xpos=None): + hidden = self.projects(hidden) + for i, blk in enumerate(self.blocks): + if self.use_checkpoint and self.training: + hidden = checkpoint(blk, hidden, xpos=xpos, use_reentrant=False) + else: + hidden = blk(hidden, xpos=xpos) + out = self.linear_out(hidden) + return out + +class LinearPts3d (nn.Module): + """ + Linear head for dust3r + Each token outputs: - 16x16 3D points (+ confidence) + """ + + def __init__(self, patch_size, dec_embed_dim, output_dim=3,): + super().__init__() + self.patch_size = patch_size + + self.proj = nn.Linear(dec_embed_dim, (output_dim)*self.patch_size**2) + + def forward(self, decout, img_shape): + H, W = img_shape + tokens = decout[-1] + B, S, D = tokens.shape + + # extract 3D points + feat = self.proj(tokens) # B,S,D + feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size) + feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W + + # permute + norm depth + return feat.permute(0, 2, 3, 1) \ No newline at end of file diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/slamformer.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/slamformer.py new file mode 100644 index 0000000000000000000000000000000000000000..be601fe006038e86ca755c1493070dcfc737f709 --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/models/slamformer.py @@ -0,0 +1,304 @@ +import torch +import torch.nn as nn +from functools import partial +from copy import deepcopy + +from .dinov2.layers import Mlp +from ..utils.geometry import homogenize_points +from 
.layers.pos_embed import RoPE2D, PositionGetter +from .layers.block import BlockRope +from .layers.attention import FlashAttentionRope +from .layers.transformer_head import TransformerDecoder, LinearPts3d +from .layers.camera_head import CameraHead +from .dinov2.hub.backbones import dinov2_vitl14, dinov2_vitl14_reg +from huggingface_hub import PyTorchModelHubMixin +from torch.utils.checkpoint import checkpoint +from typing import Optional, Tuple, List, Any +from dataclasses import dataclass +from transformers.file_utils import ModelOutput + + +from .layers.dpt_head import DPTHead + +@dataclass +class StreamVGGTOutput(ModelOutput): + ress: Optional[List[dict]] = None + views: Optional[torch.Tensor] = None + + + + +class SLAMFormer(nn.Module, PyTorchModelHubMixin): + def __init__( + self, + pos_type='rope100', + decoder_size='large', + ): + super().__init__() + + # ---------------------- + # Encoder + # ---------------------- + self.encoder = dinov2_vitl14_reg(pretrained=False) + self.patch_size = 14 + del self.encoder.mask_token + + # ---------------------- + # Positonal Encoding + # ---------------------- + self.pos_type = pos_type if pos_type is not None else 'none' + self.rope=None + if self.pos_type.startswith('rope'): # eg rope100 + if RoPE2D is None: raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions") + freq = float(self.pos_type[len('rope'):]) + self.rope = RoPE2D(freq=freq) + self.position_getter = PositionGetter() + else: + raise NotImplementedError + + + # ---------------------- + # Decoder + # ---------------------- + enc_embed_dim = self.encoder.blocks[0].attn.qkv.in_features # 1024 + if decoder_size == 'small': + dec_embed_dim = 384 + dec_num_heads = 6 + mlp_ratio = 4 + dec_depth = 24 + elif decoder_size == 'base': + dec_embed_dim = 768 + dec_num_heads = 12 + mlp_ratio = 4 + dec_depth = 24 + elif decoder_size == 'large': + dec_embed_dim = 1024 + dec_num_heads = 16 + mlp_ratio = 4 + dec_depth = 36 + else: + 
raise NotImplementedError + self.decoder = nn.ModuleList([ + BlockRope( + dim=dec_embed_dim, + num_heads=dec_num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + drop_path=0.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + ffn_layer=Mlp, + init_values=0.01, + qk_norm=True, + attn_class=FlashAttentionRope, + rope=self.rope + ) for _ in range(dec_depth)]) + self.dec_embed_dim = dec_embed_dim + + # ---------------------- + # Register_token + # ---------------------- + num_register_tokens = 5 + self.patch_start_idx = num_register_tokens + self.register_token = nn.Parameter(torch.randn(1, 1, num_register_tokens, self.dec_embed_dim)) + nn.init.normal_(self.register_token, std=1e-6) + + # ---------------------- + # Local Points Decoder + # ---------------------- + self.point_decoder = TransformerDecoder( + in_dim=2*self.dec_embed_dim, + dec_embed_dim=1024, + dec_num_heads=16, + out_dim=1024, + rope=self.rope, + use_checkpoint=True + ) + self.point_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=3) + + # ---------------------- + # Conf Decoder + # ---------------------- + self.conf_decoder = deepcopy(self.point_decoder) + self.conf_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=1) + ''' + + self.point_head = DPTHead(dim_in=1024, output_dim=4, activation="inv_log", conf_activation="expp1", intermediate_layer_idx=[0]) + ''' + + + # ---------------------- + # Camera Pose Decoder + # ---------------------- + self.camera_decoder = TransformerDecoder( + in_dim=2*self.dec_embed_dim, + dec_embed_dim=1024, + dec_num_heads=16, # 8 + out_dim=512, + rope=self.rope, + use_checkpoint=True + ) + self.camera_head = CameraHead(dim=512) + #self.intrin_head = CameraHead(dim=512) + + # For ImageNet Normalize + image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) + image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) + + self.register_buffer("image_mean", image_mean) + 
self.register_buffer("image_std", image_std) + + + def decode(self, N, H, W, hidden_I=None, hidden_F=None): + # branch 1: frontend + # branch 2: backend + # branch 3: mix-mode + + branch = 1 + if hidden_I is not None: + branch = 1 + hidden = hidden_I + + BN, hw, C = hidden.shape + B = BN // N + + hidden = hidden.reshape(B*N, hw, -1) + + register_token = self.register_token.repeat(B, N, 1, 1).reshape(B*N, *self.register_token.shape[-2:]) + + # Concatenate special tokens with patch tokens + hidden = torch.cat([register_token, hidden], dim=1) + hw = hidden.shape[1] + + if hidden_F is None: + branch = 1 + else: + branch = 3 + hidden = hidden.view(B,N,-1,C) + hidden = torch.cat([hidden_F.view(B,N,-1,C)[:,:int(N//2)], hidden[:,int(N//2):]],axis=1).reshape(B*N,-1,C) + + + elif hidden_I is None and hidden_F is not None: + branch = 2 + hidden = hidden_F + BN, hw, C = hidden.shape + B = BN // N + + + if self.pos_type.startswith('rope'): + pos = self.position_getter(B * N, H//self.patch_size, W//self.patch_size, hidden.device) + + if self.patch_start_idx > 0: + # do not use position embedding for special tokens (camera and register tokens) + # so set pos to 0 for the special tokens + pos = pos + 1 + pos_special = torch.zeros(B * N, self.patch_start_idx, 2).to(hidden.device).to(pos.dtype) + pos = torch.cat([pos_special, pos], dim=1) + + final_output = [] + for i in range(len(self.decoder)): + blk = self.decoder[i] + + if i % 2 == 0: + pos = pos.reshape(B*N, hw, -1) + hidden = hidden.reshape(B*N, hw, -1) + global_ = False + else: + pos = pos.reshape(B, N*hw, -1) + hidden = hidden.reshape(B, N*hw, -1) + global_ = True + + + hidden = checkpoint(blk, hidden, xpos=pos, N=N, branch=branch, global_=global_, use_reentrant=False) + + if i+1 in [len(self.decoder)-1, len(self.decoder)]: + final_output.append(hidden.reshape(B*N, hw, -1)) + + return torch.cat([final_output[0], final_output[1]], dim=-1), pos.reshape(B*N, hw, -1) + + def forward(self, views, query_points): + imgs = 
torch.stack( + [view["img"] for view in views], dim=0 + ).permute(1, 0, 2, 3, 4) # B S C H W + + imgs = (imgs - self.image_mean) / self.image_std + + B, N, _, H, W = imgs.shape + patch_h, patch_w = H // 14, W // 14 + + shape_ = (B,N,H,W,patch_h, patch_w) + # encode by dinov2 + imgs = imgs.reshape(B*N, _, H, W) + hidden_I = self.encoder(imgs, is_training=True) + + if isinstance(hidden_I, dict): + hidden_I = hidden_I["x_norm_patchtokens"] + + hidden_F, pos = self.decode(N, H, W, hidden_I, hidden_F=None) + res_F = self.extract(hidden_F,pos,shape_) + + + C2 = hidden_F.shape[-1] + C = int(C2//2) + ''' + hidden_B, pos = self.decode(N, H, W, hidden_I=None, hidden_F=hidden_F[:,:,:int(C2//2)]) + res_B = self.extract(hidden_B,pos,shape_) + ''' + + hidden_M, pos = self.decode(N, H, W, hidden_I,hidden_F[:,:,:int(C2//2)]) + res_M = self.extract(hidden_M,pos,shape_) + + + N_half = int(N//2) + hidden_input_B = torch.cat([hidden_F.view(B,N,-1,C2)[:,:N_half,:,:int(C2//2)], hidden_M.view(B,N,-1,C2)[:,N_half:,:,:int(C2//2)]],axis=1).view(B*N,-1,C) + hidden_B, pos = self.decode(N, H, W, hidden_I=None,hidden_F=hidden_input_B) + res_B = self.extract(hidden_B,pos,shape_) + + + + output_F=StreamVGGTOutput(ress=[res_F],views=views) + output_M=StreamVGGTOutput(ress=[res_M],views=views) + output_B=StreamVGGTOutput(ress=[res_B],views=views) + + + return output_F, output_M, output_B + + + def extract(self, hidden, pos, shape_): + B,N,H,W,patch_h, patch_w = shape_ + + point_hidden = self.point_decoder(hidden, xpos=pos) # BN, P, 1024 + conf_hidden = self.conf_decoder(hidden, xpos=pos) + camera_hidden = self.camera_decoder(hidden, xpos=pos) + + + + with torch.amp.autocast(device_type='cuda', enabled=False): + # local points + point_hidden = point_hidden.float() + ret = self.point_head([point_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1) + xy, z = ret.split([2, 1], dim=-1) + z = torch.exp(z) + local_points = torch.cat([xy * z, z], dim=-1) + + # confidence + conf_hidden = 
conf_hidden.float() + conf = self.conf_head([conf_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1) + + # camera + camera_hidden = camera_hidden.float() + camera_poses = self.camera_head(camera_hidden[:, self.patch_start_idx:], patch_h, patch_w).reshape(B, N, 4, 4) + + + # unproject local points using camera poses + points = torch.einsum('bnij, bnhwj -> bnhwi', camera_poses, homogenize_points(local_points))[..., :3] + + output = dict(points=points, + local_points=local_points, + conf=conf, + camera_poses=camera_poses, + ) + + return output diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/utils/basic.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/utils/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..9ac73492409b1f2441a84e2f9de9681b3cf3ca9f --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/utils/basic.py @@ -0,0 +1,223 @@ +import os +import os.path as osp +import math +import cv2 +from PIL import Image +import torch +from torchvision import transforms +from plyfile import PlyData, PlyElement +import numpy as np + +def load_images_as_tensor(path='data/truck', interval=1, PIXEL_LIMIT=255000): + """ + Loads images from a directory or video, resizes them to a uniform size, + then converts and stacks them into a single [N, 3, H, W] PyTorch tensor. + """ + sources = [] + + # --- 1. 
Load image paths or video frames --- + if osp.isdir(path): + print(f"Loading images from directory: {path}") + filenames = sorted([x for x in os.listdir(path) if x.lower().endswith(('.png', '.jpg', '.jpeg'))]) + for i in range(0, len(filenames), interval): + img_path = osp.join(path, filenames[i]) + try: + sources.append(Image.open(img_path).convert('RGB')) + except Exception as e: + print(f"Could not load image {filenames[i]}: {e}") + elif path.lower().endswith('.mp4'): + print(f"Loading frames from video: {path}") + cap = cv2.VideoCapture(path) + if not cap.isOpened(): raise IOError(f"Cannot open video file: {path}") + frame_idx = 0 + while True: + ret, frame = cap.read() + if not ret: break + if frame_idx % interval == 0: + rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + sources.append(Image.fromarray(rgb_frame)) + frame_idx += 1 + cap.release() + else: + raise ValueError(f"Unsupported path. Must be a directory or a .mp4 file: {path}") + + if not sources: + print("No images found or loaded.") + return torch.empty(0) + + print(f"Found {len(sources)} images/frames. Processing...") + + # --- 2. Determine a uniform target size for all images based on the first image --- + # This is necessary to ensure all tensors have the same dimensions for stacking. + first_img = sources[0] + W_orig, H_orig = first_img.size + scale = math.sqrt(PIXEL_LIMIT / (W_orig * H_orig)) if W_orig * H_orig > 0 else 1 + W_target, H_target = W_orig * scale, H_orig * scale + k, m = round(W_target / 14), round(H_target / 14) + while (k * 14) * (m * 14) > PIXEL_LIMIT: + if k / m > W_target / H_target: k -= 1 + else: m -= 1 + TARGET_W, TARGET_H = max(1, k) * 14, max(1, m) * 14 + print(f"All images will be resized to a uniform size: ({TARGET_W}, {TARGET_H})") + + # --- 3. 
Resize images and convert them to tensors in the [0, 1] range --- + tensor_list = [] + # Define a transform to convert a PIL Image to a CxHxW tensor and normalize to [0,1] + to_tensor_transform = transforms.ToTensor() + + for img_pil in sources: + try: + # Resize to the uniform target size + resized_img = img_pil.resize((TARGET_W, TARGET_H), Image.Resampling.LANCZOS) + # Convert to tensor + img_tensor = to_tensor_transform(resized_img) + tensor_list.append(img_tensor) + except Exception as e: + print(f"Error processing an image: {e}") + + if not tensor_list: + print("No images were successfully processed.") + return torch.empty(0) + + # --- 4. Stack the list of tensors into a single [N, C, H, W] batch tensor --- + return torch.stack(tensor_list, dim=0) + + +def tensor_to_pil(tensor): + """ + Converts a PyTorch tensor to a PIL image. Automatically moves the channel dimension + (if it has size 3) to the last axis before converting. + + Args: + tensor (torch.Tensor): Input tensor. Expected shape can be [C, H, W], [H, W, C], or [H, W]. + + Returns: + PIL.Image: The converted PIL image. + """ + if torch.is_tensor(tensor): + array = tensor.detach().cpu().numpy() + else: + array = tensor + + return array_to_pil(array) + + +def array_to_pil(array): + """ + Converts a NumPy array to a PIL image. Automatically: + - Squeezes dimensions of size 1. + - Moves the channel dimension (if it has size 3) to the last axis. + + Args: + array (np.ndarray): Input array. Expected shape can be [C, H, W], [H, W, C], or [H, W]. + + Returns: + PIL.Image: The converted PIL image. 
+ """ + # Remove singleton dimensions + array = np.squeeze(array) + + # Ensure the array has the channel dimension as the last axis + if array.ndim == 3 and array.shape[0] == 3: # If the channel is the first axis + array = np.transpose(array, (1, 2, 0)) # Move channel to the last axis + + # Handle single-channel grayscale images + if array.ndim == 2: # [H, W] + return Image.fromarray((array * 255).astype(np.uint8), mode="L") + elif array.ndim == 3 and array.shape[2] == 3: # [H, W, C] with 3 channels + return Image.fromarray((array * 255).astype(np.uint8), mode="RGB") + else: + raise ValueError(f"Unsupported array shape for PIL conversion: {array.shape}") + + +def rotate_target_dim_to_last_axis(x, target_dim=3): + shape = x.shape + axis_to_move = -1 + # Iterate backwards to find the first occurrence from the end + # (which corresponds to the last dimension of size 3 in the original order). + for i in range(len(shape) - 1, -1, -1): + if shape[i] == target_dim: + axis_to_move = i + break + + # 2. If the axis is found and it's not already in the last position, move it. + if axis_to_move != -1 and axis_to_move != len(shape) - 1: + # Create the new dimension order. + dims_order = list(range(len(shape))) + dims_order.pop(axis_to_move) + dims_order.append(axis_to_move) + + # Use permute to reorder the dimensions. + ret = x.transpose(*dims_order) + else: + ret = x + + return ret + + +def write_ply( + xyz, + rgb=None, + path='output.ply', +) -> None: + if torch.is_tensor(xyz): + xyz = xyz.detach().cpu().numpy() + + if torch.is_tensor(rgb): + rgb = rgb.detach().cpu().numpy() + + if rgb is not None and rgb.max() > 1: + rgb = rgb / 255. 
+ + xyz = rotate_target_dim_to_last_axis(xyz, 3) + xyz = xyz.reshape(-1, 3) + + if rgb is not None: + rgb = rotate_target_dim_to_last_axis(rgb, 3) + rgb = rgb.reshape(-1, 3) + + if rgb is None: + min_coord = np.min(xyz, axis=0) + max_coord = np.max(xyz, axis=0) + normalized_coord = (xyz - min_coord) / (max_coord - min_coord + 1e-8) + + hue = 0.7 * normalized_coord[:,0] + 0.2 * normalized_coord[:,1] + 0.1 * normalized_coord[:,2] + hsv = np.stack([hue, 0.9*np.ones_like(hue), 0.8*np.ones_like(hue)], axis=1) + + c = hsv[:,2:] * hsv[:,1:2] + x = c * (1 - np.abs( (hsv[:,0:1]*6) % 2 - 1 )) + m = hsv[:,2:] - c + + rgb = np.zeros_like(hsv) + cond = (0 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 1) + rgb[cond] = np.hstack([c[cond], x[cond], np.zeros_like(x[cond])]) + cond = (1 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 2) + rgb[cond] = np.hstack([x[cond], c[cond], np.zeros_like(x[cond])]) + cond = (2 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 3) + rgb[cond] = np.hstack([np.zeros_like(x[cond]), c[cond], x[cond]]) + cond = (3 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 4) + rgb[cond] = np.hstack([np.zeros_like(x[cond]), x[cond], c[cond]]) + cond = (4 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 5) + rgb[cond] = np.hstack([x[cond], np.zeros_like(x[cond]), c[cond]]) + cond = (5 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 6) + rgb[cond] = np.hstack([c[cond], np.zeros_like(x[cond]), x[cond]]) + rgb = (rgb + m) + + dtype = [ + ("x", "f4"), + ("y", "f4"), + ("z", "f4"), + ("nx", "f4"), + ("ny", "f4"), + ("nz", "f4"), + ("red", "u1"), + ("green", "u1"), + ("blue", "u1"), + ] + normals = np.zeros_like(xyz) + elements = np.empty(xyz.shape[0], dtype=dtype) + attributes = np.concatenate((xyz, normals, rgb * 255), axis=1) + elements[:] = list(map(tuple, attributes)) + vertex_element = PlyElement.describe(elements, "vertex") + ply_data = PlyData([vertex_element]) + ply_data.write(path) \ No newline at end of file diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/utils/debug.py 
b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/utils/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..f3da8f3d6caece7c828cab2574bf1cbdd207779e --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/utils/debug.py @@ -0,0 +1,63 @@ +import os +import json +import debugpy +import socket +import random + +def update_vscode_launch_file(host: str, port: int): + """Update the .vscode/launch.json file with the new host and port.""" + launch_file_path = ".vscode/launch.json" + # Desired configuration + new_config = { + "version": "0.2.0", + "configurations": [ + { + "name": "bash_debug", + "type": "debugpy", + "request": "attach", + "connect": { + "host": host, + "port": port + }, + "justMyCode": False + }, + ] + } + + # Ensure the .vscode directory exists + if not os.path.exists(".vscode"): + os.makedirs(".vscode") + + # Write the updated configuration to launch.json + with open(launch_file_path, "w") as f: + json.dump(new_config, f, indent=4) + print(f"Updated {launch_file_path} with host: {host} and port: {port}") + +def is_port_in_use(host, port): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex((host, port)) == 0 + +def setup_debug(is_main_process=True, max_retries=10, port_range=(10000, 20000)): + if is_main_process: + host = os.environ['SLURM_NODELIST'].split(',')[0] + + for _ in range(max_retries): + port = random.randint(*port_range) + try: + if is_port_in_use(host, port): + print(f"Port {port} is already in use, trying another...") + continue + + # 更新 launch.json + update_vscode_launch_file(host, port) + + print("master_addr = ", host) + debugpy.listen((host, port)) + print(f"Waiting for debugger attach at port {port}...", flush=True) + debugpy.wait_for_client() + print("Debugger attached", flush=True) + return + except Exception as e: + print(f"Failed to bind to port {port}: {e}") + + raise RuntimeError("Could not find a free port for debugpy after several attempts.") \ No 
newline at end of file diff --git a/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/utils/geometry.py b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..903b31d62f3fc9036169e2eba60cf21491a3791c --- /dev/null +++ b/outdoor_v48_4gpu_v2/code/05_02-14:21:58/slamformer/utils/geometry.py @@ -0,0 +1,414 @@ +import numpy as np +import torch +import torch.nn.functional as F + +def se3_inverse(T): + """ + Computes the inverse of a batch of SE(3) matrices. + T: Tensor of shape (B, 4, 4) + """ + if len(T.shape) == 2: + T = T[None] + unseq_flag = True + else: + unseq_flag = False + + if torch.is_tensor(T): + R = T[:, :3, :3] + t = T[:, :3, 3].unsqueeze(-1) + R_inv = R.transpose(-2, -1) + t_inv = -torch.matmul(R_inv, t) + T_inv = torch.cat([ + torch.cat([R_inv, t_inv], dim=-1), + torch.tensor([0, 0, 0, 1], device=T.device, dtype=T.dtype).repeat(T.shape[0], 1, 1) + ], dim=1) + else: + R = T[:, :3, :3] + t = T[:, :3, 3, np.newaxis] + + R_inv = np.swapaxes(R, -2, -1) + t_inv = -R_inv @ t + + bottom_row = np.zeros((T.shape[0], 1, 4), dtype=T.dtype) + bottom_row[:, :, 3] = 1 + + top_part = np.concatenate([R_inv, t_inv], axis=-1) + T_inv = np.concatenate([top_part, bottom_row], axis=1) + + if unseq_flag: + T_inv = T_inv[0] + return T_inv + +def get_pixel(H, W): + # get 2D pixels (u, v) for image_a in cam_a pixel space + u_a, v_a = np.meshgrid(np.arange(W), np.arange(H)) + # u_a = np.flip(u_a, axis=1) + # v_a = np.flip(v_a, axis=0) + pixels_a = np.stack([ + u_a.flatten() + 0.5, + v_a.flatten() + 0.5, + np.ones_like(u_a.flatten()) + ], axis=0) + + return pixels_a + +def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, z_far=0, **kw): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + - camera_pose: a 4x3 or 4x4 cam2world matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.""" + X_cam, 
# ---------------------------------------------------------------------------
# Tail of the definition that starts above this section (its header is not
# part of the reviewed range). Text kept as-is for reference, not modified;
# it belongs to the body of that preceding definition:
#
#     valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
#     if z_far > 0:
#         valid_mask = valid_mask & (depthmap < z_far)
#
#     X_world = X_cam # default
#     if camera_pose is not None:
#         # R_cam2world = np.float32(camera_params["R_cam2world"])
#         # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze()
#         R_cam2world = camera_pose[:3, :3]
#         t_cam2world = camera_pose[:3, 3]
#
#         # Express in absolute coordinates (invalid depth values)
#         X_world = np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
#
#     return X_world, valid_mask
# ---------------------------------------------------------------------------


def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
    """Back-project a depth map into a point map expressed in the camera frame.

    Args:
        depthmap: (H, W) array holding the z coordinate of every pixel.
        camera_intrinsics: 3x3 matrix. Only the focal lengths ([0, 0], [1, 1])
            and the principal point ([0, 2], [1, 2]) are read; skew terms are
            ignored.
        pseudo_focal: optional (H, W) array of per-pixel focal lengths that
            replaces both focal lengths of ``camera_intrinsics``.

    Returns:
        X_cam: (H, W, 3) float32 array of camera-frame coordinates (x, y, z).
        valid_mask: (H, W) boolean array, True where ``depthmap > 0``.
    """
    camera_intrinsics = np.float32(camera_intrinsics)
    H, W = depthmap.shape

    # Strong assumption: there are no skew terms in the intrinsics.
    if pseudo_focal is None:
        fu = camera_intrinsics[0, 0]
        fv = camera_intrinsics[1, 1]
    else:
        assert pseudo_focal.shape == (H, W)
        fu = fv = pseudo_focal
    cu = camera_intrinsics[0, 2]
    cv = camera_intrinsics[1, 2]

    # Pinhole back-projection of every pixel (u, v) at depth z.
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    z_cam = depthmap
    x_cam = (u - cu) * z_cam / fu
    y_cam = (v - cv) * z_cam / fv
    X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)

    # Only strictly positive depths are valid; no far-plane clipping happens here.
    valid_mask = depthmap > 0.0
    return X_cam, valid_mask


def homogenize_points(
    points,
):
    """Append a trailing 1 to the last dimension: (..., xyz) -> (..., xyz1)."""
    ones = torch.ones_like(points[..., :1])
    return torch.cat([points, ones], dim=-1)


def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode='bilinear',
                relative_depth_error_threshold=0.05, H=None, W=None):
    """Warp a dense grid of image-1 pixel centres into image 2 with ``warp_kpts``.

    Args:
        depth1, depth2: depth maps indexed as [B, H, W].
        T_1to2: batched rigid transforms from camera 1 to camera 2.
        K1, K2: batched 3x3 intrinsics.
        depth_interpolation_mode: forwarded to ``warp_kpts``.
        relative_depth_error_threshold: forwarded to ``warp_kpts``.
        H, W: size of the sampling grid. When ``H`` is None both are taken
            from ``depth1``.

    Returns:
        x2: (B, H, W, 2) warped coordinates, normalized like the input grid.
        prob: (B, H, W) float tensor, the validity mask cast to float.
    """
    if H is None:
        B, H, W = depth1.shape
    else:
        B = depth1.shape[0]
    with torch.no_grad():
        # Normalized coordinates of the pixel centres: [-1 + 1/n, 1 - 1/n].
        axes = [
            torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=depth1.device)
            for n in (B, H, W)
        ]
        grids = torch.meshgrid(*axes, indexing='ij')
        # Keypoints are stored (x, y), i.e. (column, row).
        x1_n = torch.stack((grids[2], grids[1]), dim=-1).reshape(B, H * W, 2)
        mask, x2 = warp_kpts(
            x1_n.double(),
            depth1.double(),
            depth2.double(),
            T_1to2.double(),
            K1.double(),
            K2.double(),
            depth_interpolation_mode=depth_interpolation_mode,
            relative_depth_error_threshold=relative_depth_error_threshold,
        )
        prob = mask.float().reshape(B, H, W)
        x2 = x2.reshape(B, H, W, 2)
        return x2, prob


def _warp_kpts_single(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask, mode,
                      relative_depth_error_threshold):
    """Single-pass warp with one grid_sample mode.

    Returns ``(valid_mask, relative_depth_error, warped_keypoints)``.
    """
    if mode == "nearest-exact":
        # F.grid_sample only knows 'bilinear', 'nearest' and 'bicubic'.
        mode = "nearest"

    _, h0, w0 = depth0.shape
    kpts0_depth = F.grid_sample(
        depth0[:, None], kpts0[:, :, None], mode=mode, align_corners=False
    )[:, 0, :, 0]
    # Normalized -> pixel coordinates: [-1+1/h, 1-1/h] -> [0.5, h-0.5]
    kpts0_px = torch.stack(
        (w0 * (kpts0[..., 0] + 1) / 2, h0 * (kpts0[..., 1] + 1) / 2), dim=-1
    )
    # A keypoint is usable only where the sampled depth is positive.
    nonzero_mask = kpts0_depth > 0

    # Unproject
    kpts0_h = (
        torch.cat([kpts0_px, torch.ones_like(kpts0_px[:, :, [0]])], dim=-1)
        * kpts0_depth[..., None]
    )  # (N, L, 3)
    kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1)  # (N, 3, L)

    # Rigid transform
    w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]]  # (N, 3, L)
    w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]

    # Project
    w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1)  # (N, L, 3)
    w_kpts0 = w_kpts0_h[:, :, :2] / (
        w_kpts0_h[:, :, [2]] + 1e-4
    )  # (N, L, 2), +1e-4 to avoid zero depth

    # Covisibility check against the size of the second depth map
    h1, w1 = depth1.shape[1:3]
    covisible_mask = (
        (w_kpts0[:, :, 0] > 0)
        * (w_kpts0[:, :, 0] < w1 - 1)
        * (w_kpts0[:, :, 1] > 0)
        * (w_kpts0[:, :, 1] < h1 - 1)
    )
    # Pixel -> normalized coordinates: [0.5, h-0.5] -> [-1+1/h, 1-1/h]
    w_kpts0 = torch.stack(
        (2 * w_kpts0[..., 0] / w1 - 1, 2 * w_kpts0[..., 1] / h1 - 1), dim=-1
    )

    w_kpts0_depth = F.grid_sample(
        depth1[:, None], w_kpts0[:, :, None], mode=mode, align_corners=False
    )[:, 0, :, 0]

    relative_depth_error = (
        (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
    ).abs()
    if not smooth_mask:
        consistent_mask = relative_depth_error < relative_depth_error_threshold
    else:
        # Soft consistency: smooth_mask acts as the decay scale.
        consistent_mask = (-relative_depth_error / smooth_mask).exp()
    valid_mask = nonzero_mask * covisible_mask * consistent_mask
    return valid_mask, relative_depth_error, w_kpts0


@torch.no_grad()
def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask=False,
              return_relative_depth_error=False, depth_interpolation_mode="bilinear",
              relative_depth_error_threshold=0.05):
    """Warp kpts0 from I0 to I1 with depth, K and Rt.

    Also checks covisibility and depth consistency. Depth is consistent when
    the relative error is below ``relative_depth_error_threshold``.
    Adapted from
    https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py

    Args:
        kpts0 (torch.Tensor): [N, L, 2], normalized in (-1, 1).
        depth0 (torch.Tensor): [N, H, W].
        depth1 (torch.Tensor): [N, H, W].
        T_0to1 (torch.Tensor): [N, 3, 4] (only the first three rows are read).
        K0 (torch.Tensor): [N, 3, 3].
        K1 (torch.Tensor): [N, 3, 3].
        smooth_mask: falsy for a hard consistency mask; otherwise the scale
            ``s`` of the soft mask ``exp(-relative_error / s)``.
        return_relative_depth_error: return the relative depth error instead
            of the validity mask.
        depth_interpolation_mode: a ``F.grid_sample`` mode, or "combined" to
            fill points that are invalid under bilinear sampling with the
            nearest-neighbour result.
        relative_depth_error_threshold: threshold of the hard consistency mask.

    Returns:
        calculable_mask (torch.Tensor): [N, L] (or the relative depth error
            when ``return_relative_depth_error`` is set).
        warped_keypoints0 (torch.Tensor): [N, L, 2], normalized coordinates.

    Raises:
        NotImplementedError: "combined" mode together with ``smooth_mask``.
    """
    if depth_interpolation_mode == "combined":
        # Inspired by approach in inloc, try to fill holes from bilinear
        # interpolation by nearest neighbour interpolation.
        if smooth_mask:
            raise NotImplementedError("Combined bilinear and NN warp not implemented")
        valid_bilinear, err_bilinear, warp_bilinear = _warp_kpts_single(
            kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask, "bilinear",
            relative_depth_error_threshold)
        valid_nearest, err_nearest, warp_nearest = _warp_kpts_single(
            kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask, "nearest",
            relative_depth_error_threshold)
        fill = (~valid_bilinear).logical_and(valid_nearest)
        warp = warp_bilinear.clone()
        warp[fill] = warp_nearest[fill]
        if return_relative_depth_error:
            # Report the error that belongs to the warp actually selected.
            error = err_bilinear.clone()
            error[fill] = err_nearest[fill]
            return error, warp
        return valid_bilinear | valid_nearest, warp

    valid_mask, relative_depth_error, w_kpts0 = _warp_kpts_single(
        kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask,
        depth_interpolation_mode, relative_depth_error_threshold)
    if return_relative_depth_error:
        return relative_depth_error, w_kpts0
    return valid_mask, w_kpts0


def geotrf(Trf, pts, ncol=None, norm=False):
    """Apply a geometric transformation to a set of points.

    Args:
        Trf: numpy array or torch tensor with at least 2 dimensions; the last
            two hold the transformation matrix (e.g. 3x3 or 4x4), any leading
            dimensions are batch dimensions.
        pts: numpy/torch/tuple of coordinates. When the last dimension of
            ``pts`` is one less than that of ``Trf`` the last row/column of
            ``Trf`` is applied as a translation.
        ncol: int, number of columns kept in the result (defaults to the last
            dimension of ``pts``).
        norm: float. If != 0, the result is divided by its last coordinate
            and multiplied by ``norm`` (projection on the plane z=norm).

    Returns:
        The transformed points, shaped ``pts.shape[:-1] + (ncol,)``.

    Raises:
        ValueError: in the batched torch path, when the matrix size matches
            neither the point dimension nor the point dimension + 1.
    """
    assert Trf.ndim >= 2
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # Remember the leading shape so the result can be restored at the end.
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # Fast path: (B, d, d) transforms applied to (B, H, W, d) point maps.
    if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
            Trf.ndim == 3 and pts.ndim == 4):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
        else:
            raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
    else:
        if Trf.ndim >= 3:
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1] + 1 == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res


def inv(mat):
    """Invert a torch or numpy matrix.

    Raises:
        ValueError: if ``mat`` is neither a torch tensor nor a numpy array.
    """
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    raise ValueError(f'bad matrix type = {type(mat)}')


def opencv_camera_to_plucker(poses, K, H, W):
    """Compute per-pixel Plucker ray coordinates (direction, moment).

    Args:
        poses: batched matrices; the [:3, :3] block rotates ray directions and
            the [:3, 3] column is used as the ray origin.
        K: batched 3x3 intrinsics, inverted to lift pixels to rays.
        H, W: image size.

    Returns:
        (B, H, W, 6) tensor: unit ray direction followed by
        ``cross(origin, direction)``.
    """
    device = poses.device
    B = poses.shape[0]

    # NOTE(review): get_pixel is defined elsewhere; presumably it returns a
    # 2-D array with 3 * H * W homogeneous pixel coordinates - confirm.
    pixel = torch.from_numpy(get_pixel(H, W).astype(np.float32)).to(device).T.reshape(H, W, 3)[None].repeat(B, 1, 1, 1)  # (B, H, W, 3)
    pixel = torch.einsum('bij, bhwj -> bhwi', torch.inverse(K), pixel)
    ray_directions = torch.einsum('bij, bhwj -> bhwi', poses[..., :3, :3], pixel)

    ray_origins = poses[..., :3, 3][:, None, None].repeat(1, H, W, 1)

    ray_directions = ray_directions / ray_directions.norm(dim=-1, keepdim=True)
    plucker_normal = torch.cross(ray_origins, ray_directions, dim=-1)
    plucker_ray = torch.cat([ray_directions, plucker_normal], dim=-1)

    return plucker_ray


def depth_edge(depth: torch.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch.Tensor = None) -> torch.BoolTensor:
    """
    Compute the edge mask of a depth map. The edge is defined as the pixels whose neighbors have a large difference in depth.

    Args:
        depth (torch.Tensor): shape (..., height, width), linear depth map
        atol (float): absolute tolerance on the (max - min) depth inside the window
        rtol (float): relative tolerance, compared against (max - min) / depth
        kernel_size (int): side of the square neighbourhood
        mask (torch.Tensor): optional boolean tensor shaped like ``depth``;
            pixels where it is False are ignored when taking the window max/min

    Returns:
        edge (torch.Tensor): shape (..., height, width) of dtype torch.bool
    """
    shape = depth.shape
    depth = depth.reshape(-1, 1, *shape[-2:])
    if mask is not None:
        mask = mask.reshape(-1, 1, *shape[-2:])

    pad = kernel_size // 2
    # max(depth) + max(-depth) over the window == max - min of the window.
    if mask is None:
        diff = (F.max_pool2d(depth, kernel_size, stride=1, padding=pad)
                + F.max_pool2d(-depth, kernel_size, stride=1, padding=pad))
    else:
        # Masked-out pixels are replaced by -inf so they never win the max.
        diff = (F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=pad)
                + F.max_pool2d(torch.where(mask, -depth, -torch.inf), kernel_size, stride=1, padding=pad))

    edge = torch.zeros_like(depth, dtype=torch.bool)
    if atol is not None:
        edge |= diff > atol
    if rtol is not None:
        edge |= (diff / depth).nan_to_num_() > rtol
    edge = edge.reshape(*shape)
    return edge


def weighted_procrustes(A, B, w, use_weights=True, eps=1e-16, return_T=False):
    """Estimate the rigid transform (R, t) such that ``B ~= R @ A + t``.

    Args:
        A: torch tensor B x N x 3, source points.
        B: torch tensor B x N x 3, target points.
        w: torch tensor B x N, per-correspondence weights (only read when
            ``use_weights`` is True).
        use_weights: weight the centroids and the covariance with ``w``
            normalized by the sum of its absolute values.
        eps: added to the weight sum to avoid a division by zero.
        return_T: return a single B x 4 x 4 homogeneous matrix instead of
            ``(R, t)``.

    Returns:
        ``T`` (B x 4 x 4) when ``return_T`` is set, otherwise ``(R, t)`` with
        R of shape B x 3 x 3 and ``t`` squeezed: B x 3 in general, but every
        size-1 dimension is dropped, so a batch of one yields shape (3,).
    """
    assert len(A) == len(B)
    if use_weights:
        W1 = torch.abs(w).sum(1, keepdim=True)
        w_norm = (w / (W1 + eps)).unsqueeze(-1)
        a_mean = (w_norm * A).sum(dim=1, keepdim=True)
        b_mean = (w_norm * B).sum(dim=1, keepdim=True)

        A_c = A - a_mean
        B_c = B - b_mean

        H = torch.einsum("bni,bnj->bij", A_c, w_norm * B_c)
    else:
        a_mean = A.mean(dim=1, keepdim=True)
        b_mean = B.mean(dim=1, keepdim=True)

        A_c = A - a_mean
        B_c = B - b_mean

        H = torch.einsum("bij,bik->bjk", A_c, B_c)

    # torch.svd is deprecated; torch.linalg.svd returns V already transposed.
    U, S, Vh = torch.linalg.svd(H)  # U: B x 3 x 3, S: B x 3, Vh: B x 3 x 3
    V = Vh.transpose(1, 2)
    batch = A.shape[0]
    # Build the helpers in the dtype/device of H so float64 inputs neither
    # fail in the matmul below nor get silently downcast.
    Z = torch.eye(3, dtype=H.dtype, device=H.device).unsqueeze(0).repeat(batch, 1, 1)
    # Flip the last axis when needed so that R is a proper rotation (det = +1).
    Z[:, -1, -1] = torch.sign(torch.linalg.det(U @ V.transpose(1, 2)))
    R = V @ Z @ U.transpose(1, 2)  # B x 3 x 3
    t = b_mean - torch.einsum("bij,bjk->bik", R, a_mean.transpose(-2, -1)).transpose(-2, -1)
    if return_T:
        T = torch.eye(4, dtype=H.dtype, device=H.device).unsqueeze(0).repeat(batch, 1, 1)
        T[:, :3, :3] = R
        T[:, :3, 3] = t.squeeze(1)
        return T
    # NOTE(review): the bare squeeze also drops the batch dimension when the
    # batch size is 1; kept as-is because callers may rely on that shape.
    return R, t.squeeze()