diff --git a/config/dataset_depth/data_diode_all.yaml b/config/dataset_depth/data_diode_all.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c2284fe699778b6ef4c23270d8f49c983c57306d --- /dev/null +++ b/config/dataset_depth/data_diode_all.yaml @@ -0,0 +1,4 @@ +name: diode_depth +disp_name: diode_depth_val_all +dir: diode/diode_val.tar +filenames: data_split/diode_depth/diode_val_all_filename_list.txt diff --git a/config/dataset_depth/data_eth3d.yaml b/config/dataset_depth/data_eth3d.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7175e2b583b2172feb1ba4630b98256de453ade7 --- /dev/null +++ b/config/dataset_depth/data_eth3d.yaml @@ -0,0 +1,4 @@ +name: eth3d_depth +disp_name: eth3d_depth_full +dir: eth3d/eth3d.tar +filenames: data_split/eth3d_depth/eth3d_filename_list.txt diff --git a/config/dataset_depth/data_hypersim_train.yaml b/config/dataset_depth/data_hypersim_train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..400e7f45d7f64303a94e5c60fcb0c6792f4ee36d --- /dev/null +++ b/config/dataset_depth/data_hypersim_train.yaml @@ -0,0 +1,4 @@ +name: hypersim_depth +disp_name: hypersim_depth_train +dir: hypersim/hypersim_processed_train.tar +filenames: data_split/hypersim_depth/filename_list_train_filtered.txt diff --git a/config/dataset_depth/data_hypersim_val.yaml b/config/dataset_depth/data_hypersim_val.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2edd0fbd1b107a0a5af6ac324a34286e56d2926e --- /dev/null +++ b/config/dataset_depth/data_hypersim_val.yaml @@ -0,0 +1,4 @@ +name: hypersim_depth +disp_name: hypersim_depth_val +dir: hypersim/hypersim_processed_val.tar +filenames: data_split/hypersim_depth/filename_list_val_filtered.txt diff --git a/config/dataset_depth/data_kitti_eigen_test.yaml b/config/dataset_depth/data_kitti_eigen_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a2ef3f8766dc22597188f7b8911a2d3ee54eb297 --- /dev/null +++ 
b/config/dataset_depth/data_kitti_eigen_test.yaml @@ -0,0 +1,6 @@ +name: kitti_depth +disp_name: kitti_depth_eigen_test_full +dir: kitti/kitti_eigen_split_test.tar +filenames: data_split/kitti_depth/eigen_test_files_with_gt.txt +kitti_bm_crop: true +valid_mask_crop: eigen diff --git a/config/dataset_depth/data_kitti_val.yaml b/config/dataset_depth/data_kitti_val.yaml new file mode 100644 index 0000000000000000000000000000000000000000..68bbc54036b142bd5867a8e5acf4acca36cd268b --- /dev/null +++ b/config/dataset_depth/data_kitti_val.yaml @@ -0,0 +1,6 @@ +name: kitti_depth +disp_name: kitti_depth_val800_from_eigen_train +dir: kitti/kitti_sampled_val_800.tar +filenames: data_split/kitti_depth/eigen_val_from_train_800.txt +kitti_bm_crop: true +valid_mask_crop: eigen diff --git a/config/dataset_depth/data_nyu_test.yaml b/config/dataset_depth/data_nyu_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..23799be1f1a4939422cc5b5ff6217e226fb604c3 --- /dev/null +++ b/config/dataset_depth/data_nyu_test.yaml @@ -0,0 +1,5 @@ +name: nyu_depth +disp_name: nyu_depth_test_full +dir: nyuv2/nyu_labeled_extracted.tar +filenames: data_split/nyu_depth/labeled/filename_list_test.txt +eigen_valid_mask: true diff --git a/config/dataset_depth/data_nyu_train.yaml b/config/dataset_depth/data_nyu_train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d8d7e84a0c6e7bc3698fbe45e8c8241966fab512 --- /dev/null +++ b/config/dataset_depth/data_nyu_train.yaml @@ -0,0 +1,5 @@ +name: nyu_depth +disp_name: nyu_depth_train_full +dir: nyuv2/nyu_labeled_extracted.tar +filenames: data_split/nyu_depth/labeled/filename_list_train.txt +eigen_valid_mask: true diff --git a/config/dataset_depth/data_scannet_val.yaml b/config/dataset_depth/data_scannet_val.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0e99b6cee6e25ac880784d4a5dff9071390180ed --- /dev/null +++ b/config/dataset_depth/data_scannet_val.yaml @@ -0,0 +1,4 @@ +name: scannet_depth 
+disp_name: scannet_depth_val_800_1 +dir: scannet/scannet_val_sampled_800_1.tar +filenames: data_split/scannet_depth/scannet_val_sampled_list_800_1.txt diff --git a/config/dataset_depth/data_vkitti_train.yaml b/config/dataset_depth/data_vkitti_train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..11aea772903c640c815f8d70b1a2263e6e6dbca8 --- /dev/null +++ b/config/dataset_depth/data_vkitti_train.yaml @@ -0,0 +1,6 @@ +name: vkitti_depth +disp_name: vkitti_depth_train +dir: vkitti/vkitti.tar +filenames: data_split/vkitti_depth/vkitti_train.txt +kitti_bm_crop: true +valid_mask_crop: null # no valid_mask_crop for training diff --git a/config/dataset_depth/data_vkitti_val.yaml b/config/dataset_depth/data_vkitti_val.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1c9862b8e1d59203b51398fadcd0b7ac1db511d8 --- /dev/null +++ b/config/dataset_depth/data_vkitti_val.yaml @@ -0,0 +1,6 @@ +name: vkitti_depth +disp_name: vkitti_depth_val +dir: vkitti/vkitti.tar +filenames: data_split/vkitti_depth/vkitti_val.txt +kitti_bm_crop: true +valid_mask_crop: eigen diff --git a/config/dataset_depth/dataset_train.yaml b/config/dataset_depth/dataset_train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..214b3f2febad3d80ca61461c176d1614d7cb68e2 --- /dev/null +++ b/config/dataset_depth/dataset_train.yaml @@ -0,0 +1,18 @@ +dataset: + train: + name: mixed + prob_ls: [0.9, 0.1] + dataset_list: + - name: hypersim_depth + disp_name: hypersim_depth_train + dir: hypersim/hypersim_processed_train.tar + filenames: data_split/hypersim_depth/filename_list_train_filtered.txt + resize_to_hw: + - 480 + - 640 + - name: vkitti_depth + disp_name: vkitti_depth_train + dir: vkitti/vkitti.tar + filenames: data_split/vkitti_depth/vkitti_train.txt + kitti_bm_crop: true + valid_mask_crop: null diff --git a/config/dataset_depth/dataset_val.yaml b/config/dataset_depth/dataset_val.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..67207c01ee2c1a4c52f387b39f9ad6428b694f09 --- /dev/null +++ b/config/dataset_depth/dataset_val.yaml @@ -0,0 +1,45 @@ +dataset: + val: + # - name: hypersim_depth + # disp_name: hypersim_depth_val + # dir: hypersim/hypersim_processed_val.tar + # filenames: data_split/hypersim_depth/filename_list_val_filtered.txt + # resize_to_hw: + # - 480 + # - 640 + + # - name: nyu_depth + # disp_name: nyu_depth_train_full + # dir: nyuv2/nyu_labeled_extracted.tar + # filenames: data_split/nyu_depth/labeled/filename_list_train.txt + # eigen_valid_mask: true + + # - name: kitti_depth + # disp_name: kitti_depth_val800_from_eigen_train + # dir: kitti/kitti_depth_sampled_val_800.tar + # filenames: data_split/kitti_depth/eigen_val_from_train_800.txt + # kitti_bm_crop: true + # valid_mask_crop: eigen + + # Smaller subsets for faster validation during training + # The first dataset is used to calculate main eval metric. + - name: hypersim_depth + disp_name: hypersim_depth_val_small_80 + dir: hypersim/hypersim_processed_val.tar + filenames: data_split/hypersim_depth/filename_list_val_filtered_small_80.txt + resize_to_hw: + - 480 + - 640 + + - name: nyu_depth + disp_name: nyu_depth_train_small_100 + dir: nyuv2/nyu_labeled_extracted.tar + filenames: data_split/nyu_depth/labeled/filename_list_train_small_100.txt + eigen_valid_mask: true + + - name: kitti_depth + disp_name: kitti_depth_val_from_train_sub_100 + dir: kitti/kitti_sampled_val_800.tar + filenames: data_split/kitti_depth/eigen_val_from_train_sub_100.txt + kitti_bm_crop: true + valid_mask_crop: eigen diff --git a/config/dataset_depth/dataset_vis.yaml b/config/dataset_depth/dataset_vis.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9b1dd23748fd784c5e5c2e3210e6b9a8a837eb0e --- /dev/null +++ b/config/dataset_depth/dataset_vis.yaml @@ -0,0 +1,9 @@ +dataset: + vis: + - name: hypersim_depth + disp_name: hypersim_depth_vis + dir: hypersim/hypersim_processed_val.tar + 
filenames: data_split/hypersim_depth/selected_vis_sample.txt + resize_to_hw: + - 480 + - 640 diff --git a/config/dataset_iid/data_appearance_interiorverse_test.yaml b/config/dataset_iid/data_appearance_interiorverse_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..26e1d6cc2b8dba72dc5344dc77e4399bfdecdb4d --- /dev/null +++ b/config/dataset_iid/data_appearance_interiorverse_test.yaml @@ -0,0 +1,4 @@ +name: interiorverse_iid +disp_name: interiorverse_iid_appearance_test +dir: interiorverse/InteriorVerse.tar +filenames: data_split/interiorverse_iid/interiorverse_test_scenes_85.txt diff --git a/config/dataset_iid/data_appearance_synthetic_test.yaml b/config/dataset_iid/data_appearance_synthetic_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..413910624eb238661ee2b1be14fb4d453ce1194b --- /dev/null +++ b/config/dataset_iid/data_appearance_synthetic_test.yaml @@ -0,0 +1,4 @@ +name: interiorverse_iid +disp_name: interiorverse_iid_appearance_test +dir: synthetic +filenames: data_split/osu/osu_test_scenes_85.txt diff --git a/config/dataset_iid/data_art_test.yaml b/config/dataset_iid/data_art_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8c65bdf6b15896f335f3f1234fe8a7403a85b897 --- /dev/null +++ b/config/dataset_iid/data_art_test.yaml @@ -0,0 +1,4 @@ +name: interiorverse_iid +disp_name: interiorverse_iid_appearance_test +dir: art +filenames: data_split/osu/art_test_scenes.txt \ No newline at end of file diff --git a/config/dataset_iid/data_lighting_hypersim_test.yaml b/config/dataset_iid/data_lighting_hypersim_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8a6a6777ff279527a85cf70dcfc204c2073fd05a --- /dev/null +++ b/config/dataset_iid/data_lighting_hypersim_test.yaml @@ -0,0 +1,4 @@ +name: hypersim_iid +disp_name: hypersim_iid_lighting_test +dir: hypersim +filenames: data_split/hypersim_iid/hypersim_test.txt diff --git 
a/config/dataset_iid/dataset_appearance_train.yaml b/config/dataset_iid/dataset_appearance_train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8a77bf692ceab8484c09659456bfcddfaa00ff9f --- /dev/null +++ b/config/dataset_iid/dataset_appearance_train.yaml @@ -0,0 +1,9 @@ +dataset: + train: + name: mixed + prob_ls: [1.0] + dataset_list: + - name: interiorverse_iid + disp_name: interiorverse_iid_appearance_train + dir: osu_albedo_new + filenames: data_split/osu/osu_train_scenes_85.txt diff --git a/config/dataset_iid/dataset_appearance_val.yaml b/config/dataset_iid/dataset_appearance_val.yaml new file mode 100644 index 0000000000000000000000000000000000000000..75558aa513633ac5d4dc72e7532d1907712d8274 --- /dev/null +++ b/config/dataset_iid/dataset_appearance_val.yaml @@ -0,0 +1,6 @@ +dataset: + val: + - name: interiorverse_iid + disp_name: interiorverse_iid_appearance_val + dir: synthetic + filenames: data_split/MatrixCity/matrixcity_val_scenes_small.txt diff --git a/config/dataset_iid/dataset_appearance_vis.yaml b/config/dataset_iid/dataset_appearance_vis.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4a7956fc29be25dc6ca7701275c13dbe6c6e3857 --- /dev/null +++ b/config/dataset_iid/dataset_appearance_vis.yaml @@ -0,0 +1,6 @@ +dataset: + vis: + - name: interiorverse_iid + disp_name: interiorverse_iid_appearance_vis + dir: synthetic + filenames: data_split/MatrixCity/matrixcity_vis_scenes.txt diff --git a/config/dataset_iid/dataset_lighting_train.yaml b/config/dataset_iid/dataset_lighting_train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..031c62264689d512dea92efc51870d9522a88d4b --- /dev/null +++ b/config/dataset_iid/dataset_lighting_train.yaml @@ -0,0 +1,12 @@ +dataset: + train: + name: mixed + prob_ls: [1.0] + dataset_list: + - name: hypersim_iid + disp_name: hypersim_iid_lighting_train + dir: hypersim + filenames: data_split/hypersim_iid/hypersim_train_filtered.txt + resize_to_hw: + - 480 
+ - 640 \ No newline at end of file diff --git a/config/dataset_iid/dataset_lighting_val.yaml b/config/dataset_iid/dataset_lighting_val.yaml new file mode 100644 index 0000000000000000000000000000000000000000..58e2da3facbff138bbe639bf88287f541ddbf281 --- /dev/null +++ b/config/dataset_iid/dataset_lighting_val.yaml @@ -0,0 +1,6 @@ +dataset: + val: + - name: hypersim_iid + disp_name: hypersim_iid_lighting_val + dir: hypersim + filenames: data_split/hypersim_iid/hypersim_val.txt diff --git a/config/dataset_iid/dataset_lighting_vis.yaml b/config/dataset_iid/dataset_lighting_vis.yaml new file mode 100644 index 0000000000000000000000000000000000000000..70608ae85e27f7c3aa5a2069edda3cbc20b33872 --- /dev/null +++ b/config/dataset_iid/dataset_lighting_vis.yaml @@ -0,0 +1,6 @@ +dataset: + vis: + - name: hypersim_iid + disp_name: hypersim_iid_lighting_vis + dir: hypersim + filenames: data_split/hypersim_iid/hypersim_vis.txt diff --git a/config/dataset_iid/osu_data_appearance_interiorverse_test.yaml b/config/dataset_iid/osu_data_appearance_interiorverse_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..413910624eb238661ee2b1be14fb4d453ce1194b --- /dev/null +++ b/config/dataset_iid/osu_data_appearance_interiorverse_test.yaml @@ -0,0 +1,4 @@ +name: interiorverse_iid +disp_name: interiorverse_iid_appearance_test +dir: synthetic +filenames: data_split/osu/osu_test_scenes_85.txt diff --git a/config/dataset_normals/data_diode_test.yaml b/config/dataset_normals/data_diode_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b8de2bbf22bd58792e12fc2d126b8402a80bec64 --- /dev/null +++ b/config/dataset_normals/data_diode_test.yaml @@ -0,0 +1,4 @@ +name: diode_normals +disp_name: diode_normals_test +dir: diode/val +filenames: data_split/diode_normals/diode_test.txt diff --git a/config/dataset_normals/data_ibims_test.yaml b/config/dataset_normals/data_ibims_test.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..41e43ec4f5b31c7d1aaedeca2e90986018121c69 --- /dev/null +++ b/config/dataset_normals/data_ibims_test.yaml @@ -0,0 +1,4 @@ +name: ibims_normals +disp_name: ibims_normals_test +dir: ibims/ibims +filenames: data_split/ibims_normals/ibims_test.txt diff --git a/config/dataset_normals/data_nyu_test.yaml b/config/dataset_normals/data_nyu_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..13263dd0c6a06f89cf69ad94643261477db114ff --- /dev/null +++ b/config/dataset_normals/data_nyu_test.yaml @@ -0,0 +1,4 @@ +name: nyu_normals +disp_name: nyu_normals_test +dir: nyuv2/test +filenames: data_split/nyu_normals/nyuv2_test.txt diff --git a/config/dataset_normals/data_oasis_test.yaml b/config/dataset_normals/data_oasis_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d00eb6d6ad37c4cf5e573fa8b45f0fae3ae73e98 --- /dev/null +++ b/config/dataset_normals/data_oasis_test.yaml @@ -0,0 +1,4 @@ +name: oasis_normals +disp_name: oasis_normals_test +dir: oasis/val +filenames: data_split/oasis_normals/oasis_test.txt diff --git a/config/dataset_normals/data_scannet_test.yaml b/config/dataset_normals/data_scannet_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..28d6e006795610c0772ff1704af664de63afb6b1 --- /dev/null +++ b/config/dataset_normals/data_scannet_test.yaml @@ -0,0 +1,4 @@ +name: scannet_normals +disp_name: scannet_normals_test +dir: scannet +filenames: data_split/scannet_normals/scannet_test.txt diff --git a/config/dataset_normals/dataset_train.yaml b/config/dataset_normals/dataset_train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..053f25228e503f447da502da8e8915f9e22f174e --- /dev/null +++ b/config/dataset_normals/dataset_train.yaml @@ -0,0 +1,25 @@ +dataset: + train: + name: mixed + prob_ls: [0.5, 0.49, 0.01] + dataset_list: + - name: hypersim_normals + disp_name: hypersim_normals_train + dir: hypersim + filenames: 
data_split/hypersim_normals/hypersim_filtered_all.txt + resize_to_hw: + - 480 + - 640 + - name: interiorverse_normals + disp_name: interiorverse_normals_train + dir: interiorverse/scenes_85 + filenames: data_split/interiorverse_normals/interiorverse_filtered_all.txt + resize_to_hw: null + - name: sintel_normals + disp_name: sintel_normals_train + dir: sintel + filenames: data_split/sintel_normals/sintel_filtered.txt + resize_to_hw: + - 480 + - 640 + center_crop: true diff --git a/config/dataset_normals/dataset_val.yaml b/config/dataset_normals/dataset_val.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b4a23d65b489af5e6afb3d6c5b325778bb85e8d7 --- /dev/null +++ b/config/dataset_normals/dataset_val.yaml @@ -0,0 +1,7 @@ +dataset: + val: + - name: hypersim_normals + disp_name: hypersim_normals_val_small_100 + dir: hypersim + filenames: data_split/hypersim_normals/hypersim_filtered_val_100.txt + resize_to_hw: null diff --git a/config/dataset_normals/dataset_vis.yaml b/config/dataset_normals/dataset_vis.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6ff805087bcd423b2d8c66d0d96c250394d61cc3 --- /dev/null +++ b/config/dataset_normals/dataset_vis.yaml @@ -0,0 +1,7 @@ +dataset: + vis: + - name: hypersim_normals + disp_name: hypersim_normals_vis + dir: hypersim + filenames: data_split/hypersim_normals/hypersim_filtered_vis_20.txt + resize_to_hw: null diff --git a/config/logging.yaml b/config/logging.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8cecbaeca91f2ae38677336b4dc7dcc5ad020f60 --- /dev/null +++ b/config/logging.yaml @@ -0,0 +1,5 @@ +logging: + filename: logging.log + format: ' %(asctime)s - %(levelname)s -%(filename)s - %(funcName)s >> %(message)s' + console_level: 20 + file_level: 10 diff --git a/config/model_sdv2.yaml b/config/model_sdv2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4fb702126ae11eb8b74d57f0d5569bd790269af6 --- /dev/null +++ 
b/config/model_sdv2.yaml @@ -0,0 +1,4 @@ +model: + name: marigold_pipeline + pretrained_path: stable-diffusion-2 + latent_scale_factor: 0.18215 diff --git a/config/train_debug_depth.yaml b/config/train_debug_depth.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ba4ddfada72fedd93acf7bc0d4a4f0e04129fabf --- /dev/null +++ b/config/train_debug_depth.yaml @@ -0,0 +1,10 @@ +base_config: + - config/train_marigold_depth.yaml + +trainer: + save_period: 5 + backup_period: 10 + validation_period: 5 + visualization_period: 5 + +max_iter: 50 diff --git a/config/train_debug_iid.yaml b/config/train_debug_iid.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b8b95c038eb22c39f3b23a6329126a381e3ba46b --- /dev/null +++ b/config/train_debug_iid.yaml @@ -0,0 +1,11 @@ +base_config: + # - config/train_marigold_iid_lighting.yaml + - config/train_marigold_iid_appearance.yaml + +trainer: + save_period: 10 + backup_period: 10 + validation_period: 5 + visualization_period: 5 + +max_iter: 50 diff --git a/config/train_debug_normals.yaml b/config/train_debug_normals.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dcfb4e146c63a3a662fd05e65f06f56b39611a67 --- /dev/null +++ b/config/train_debug_normals.yaml @@ -0,0 +1,10 @@ +base_config: + - config/train_marigold_normals.yaml + +trainer: + save_period: 5 + backup_period: 10 + validation_period: 5 + visualization_period: 5 + +max_iter: 50 diff --git a/config/train_marigold_depth.yaml b/config/train_marigold_depth.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f35828b7a51c4c39fb4ded843a4e7f35308b2bd2 --- /dev/null +++ b/config/train_marigold_depth.yaml @@ -0,0 +1,94 @@ +base_config: +- config/logging.yaml +- config/wandb.yaml +- config/dataset_depth/dataset_train.yaml +- config/dataset_depth/dataset_val.yaml +- config/dataset_depth/dataset_vis.yaml +- config/model_sdv2.yaml + +pipeline: + name: MarigoldDepthPipeline + kwargs: + scale_invariant: true 
+ shift_invariant: true + default_denoising_steps: 4 + default_processing_resolution: 768 + +depth_normalization: + type: scale_shift_depth + clip: true + norm_min: -1.0 + norm_max: 1.0 + min_max_quantile: 0.02 + +augmentation: + lr_flip_p: 0.5 + +dataloader: + num_workers: 2 + effective_batch_size: 32 + max_train_batch_size: 2 + seed: 2024 # to ensure continuity when resuming from checkpoint + +trainer: + name: MarigoldDepthTrainer + training_noise_scheduler: + pretrained_path: stable-diffusion-2 + init_seed: 2024 # use null to train w/o seeding + save_period: 50 + backup_period: 2000 + validation_period: 500 + visualization_period: 1000 + +multi_res_noise: + strength: 0.9 + annealed: true + downscale_strategy: original + +gt_depth_type: depth_raw_norm +gt_mask_type: valid_mask_raw + +max_epoch: 10000 # a large enough number +max_iter: 30000 # usually converges at around 20k + +optimizer: + name: Adam + +loss: + name: mse_loss + kwargs: + reduction: mean + +lr: 3.0e-05 +lr_scheduler: + name: IterExponential + kwargs: + total_iter: 25000 + final_ratio: 0.01 + warmup_steps: 100 + +# Light setting for the in-training validation and visualization +validation: + denoising_steps: 1 + ensemble_size: 1 + processing_res: 0 + match_input_res: false + resample_method: bilinear + main_val_metric: abs_relative_difference + main_val_metric_goal: minimize + init_seed: 2024 + +eval: + alignment: least_square + align_max_res: null + eval_metrics: + - abs_relative_difference + - squared_relative_difference + - rmse_linear + - rmse_log + - log10 + - delta1_acc + - delta2_acc + - delta3_acc + - i_rmse + - silog_rmse diff --git a/config/train_marigold_iid_appearance.yaml b/config/train_marigold_iid_appearance.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e2492257c3e5e3559c1952bd606ac7914e495775 --- /dev/null +++ b/config/train_marigold_iid_appearance.yaml @@ -0,0 +1,81 @@ +base_config: +- config/logging.yaml +- config/wandb.yaml +- 
config/dataset_iid/dataset_appearance_train.yaml +- config/dataset_iid/dataset_appearance_val.yaml +- config/dataset_iid/dataset_appearance_vis.yaml +- config/model_sdv2.yaml + +pipeline: + name: MarigoldIIDPipeline + kwargs: + default_denoising_steps: 4 + default_processing_resolution: 768 + target_properties: + target_names: + - albedo + albedo: + prediction_space: srgb + +augmentation: + lr_flip_p: 0.5 + +dataloader: + num_workers: 2 + effective_batch_size: 32 + max_train_batch_size: 8 + seed: 2024 # to ensure continuity when resuming from checkpoint + +trainer: + name: MarigoldIIDTrainer + training_noise_scheduler: + pretrained_path: stable-diffusion-2 + init_seed: 2024 # use null to train w/o seeding + save_period: 50 + backup_period: 2000 + validation_period: 500 + visualization_period: 1000 + +multi_res_noise: + strength: 0.9 + annealed: true + downscale_strategy: original + +gt_mask_type: mask + +max_epoch: 10000 # a large enough number +max_iter: 10000 # usually converges at around 40k + +optimizer: + name: Adam + +loss: + name: mse_loss + kwargs: + reduction: mean + +lr: 2.0e-05 +lr_scheduler: + name: IterExponential + kwargs: + total_iter: 5000 + final_ratio: 0.01 + warmup_steps: 100 + +# Light setting for the in-training validation and visualization +validation: + denoising_steps: 4 + ensemble_size: 1 + processing_res: 0 + match_input_res: true + resample_method: bilinear + main_val_metric: psnr + main_val_metric_goal: maximize + init_seed: 2024 + use_mask: false + +eval: + eval_metrics: + - psnr + targets_to_eval_in_linear_space: + - material diff --git a/config/train_marigold_iid_appearance_finetuned.yaml b/config/train_marigold_iid_appearance_finetuned.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1942b4ed03898a07282f03683edb5866dd6dfb83 --- /dev/null +++ b/config/train_marigold_iid_appearance_finetuned.yaml @@ -0,0 +1,81 @@ +base_config: +- config/logging.yaml +- config/wandb.yaml +- 
config/dataset_iid/dataset_appearance_train.yaml +- config/dataset_iid/dataset_appearance_val.yaml +- config/dataset_iid/dataset_appearance_vis.yaml +- config/model_sdv2.yaml + +pipeline: + name: MarigoldIIDPipeline + kwargs: + default_denoising_steps: 4 + default_processing_resolution: 768 + target_properties: + target_names: + - albedo + albedo: + prediction_space: srgb + +augmentation: + lr_flip_p: 0.5 + +dataloader: + num_workers: 2 + effective_batch_size: 32 + max_train_batch_size: 8 + seed: 2024 # to ensure continuity when resuming from checkpoint + +trainer: + name: MarigoldIIDTrainer + training_noise_scheduler: + pretrained_path: stable-diffusion-2 + init_seed: 2024 # use null to train w/o seeding + save_period: 50 + backup_period: 2000 + validation_period: 177 + visualization_period: 177 + +multi_res_noise: + strength: 0.9 + annealed: true + downscale_strategy: original + +gt_mask_type: null + +max_epoch: 10000 # a large enough number +max_iter: 5000 # usually converges at around 40k + +optimizer: + name: Adam + +loss: + name: mse_loss + kwargs: + reduction: mean + +lr: 5.0e-07 +lr_scheduler: + name: IterExponential + kwargs: + total_iter: 2500 + final_ratio: 0.01 + warmup_steps: 100 + +# Light setting for the in-training validation and visualization +validation: + denoising_steps: 4 + ensemble_size: 1 + processing_res: 1000 + match_input_res: true + resample_method: bilinear + main_val_metric: psnr + main_val_metric_goal: maximize + init_seed: 2024 + use_mask: false + +eval: + eval_metrics: + - psnr + targets_to_eval_in_linear_space: + - material diff --git a/config/train_marigold_iid_lighting.yaml b/config/train_marigold_iid_lighting.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1194dc773ad62387e688efdca6a5f7ffb28a4cdd --- /dev/null +++ b/config/train_marigold_iid_lighting.yaml @@ -0,0 +1,82 @@ +base_config: +- config/logging.yaml +- config/wandb.yaml +- config/dataset_iid/dataset_lighting_train.yaml +- 
config/dataset_iid/dataset_lighting_val.yaml +- config/dataset_iid/dataset_lighting_vis.yaml +- config/model_sdv2.yaml + +pipeline: + name: MarigoldIIDPipeline + kwargs: + default_denoising_steps: 4 + default_processing_resolution: 768 + target_properties: + target_names: + - albedo + albedo: + prediction_space: linear + up_to_scale: false + +augmentation: + lr_flip_p: 0.5 + +dataloader: + num_workers: 2 + effective_batch_size: 32 + max_train_batch_size: 8 + seed: 2024 # to ensure continuity when resuming from checkpoint + +trainer: + name: MarigoldIIDTrainer + training_noise_scheduler: + pretrained_path: stable-diffusion-2 + init_seed: 2024 # use null to train w/o seeding + save_period: 50 + backup_period: 2000 + validation_period: 500 + visualization_period: 1000 + +multi_res_noise: + strength: 0.9 + annealed: true + downscale_strategy: original + +gt_mask_type: mask + +max_epoch: 10000 # a large enough number +max_iter: 50000 # usually converges at around 34k + +optimizer: + name: Adam + +loss: + name: mse_loss + kwargs: + reduction: mean + +lr: 8e-05 +lr_scheduler: + name: IterExponential + kwargs: + total_iter: 45000 + final_ratio: 0.01 + warmup_steps: 100 + +# Light setting for the in-training validation and visualization +validation: + denoising_steps: 4 + ensemble_size: 1 + processing_res: 0 + match_input_res: true + resample_method: bilinear + main_val_metric: psnr + main_val_metric_goal: maximize + init_seed: 2024 + use_mask: false + +eval: + eval_metrics: + - psnr + targets_to_eval_in_linear_space: + - None diff --git a/config/train_marigold_normals.yaml b/config/train_marigold_normals.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0a3a321e4a872a4334c4eedd4356de88a4d5619b --- /dev/null +++ b/config/train_marigold_normals.yaml @@ -0,0 +1,86 @@ +base_config: +- config/logging.yaml +- config/wandb.yaml +- config/dataset_normals/dataset_train.yaml +- config/dataset_normals/dataset_val.yaml +- config/dataset_normals/dataset_vis.yaml 
+- config/model_sdv2.yaml + +pipeline: + name: MarigoldNormalsPipeline + kwargs: + default_denoising_steps: 4 + default_processing_resolution: 768 + +augmentation: + lr_flip_p: 0.5 + color_jitter_p: 0.3 + gaussian_blur_p: 0.3 + motion_blur_p: 0.3 + gaussian_blur_sigma: 4 + motion_blur_kernel_size: 11 + motion_blur_angle_range: 360 + jitter_brightness_factor: 0.5 + jitter_contrast_factor: 0.5 + jitter_saturation_factor: 0.5 + jitter_hue_factor: 0.2 + +dataloader: + num_workers: 2 + effective_batch_size: 32 + max_train_batch_size: 2 + seed: 2024 # to ensure continuity when resuming from checkpoint + +trainer: + name: MarigoldNormalsTrainer + training_noise_scheduler: + pretrained_path: stable-diffusion-2 + init_seed: 2024 # use null to train w/o seeding + save_period: 50 + backup_period: 2000 + validation_period: 500 + visualization_period: 1000 + +multi_res_noise: + strength: 0.9 + annealed: true + downscale_strategy: original + +gt_normals_type: normals +gt_mask_type: null + +max_epoch: 10000 # a large enough number +max_iter: 30000 # usually converges at around 26k + +optimizer: + name: Adam + +loss: + name: mse_loss + kwargs: + reduction: mean + +lr: 6.0e-05 +lr_scheduler: + name: IterExponential + kwargs: + total_iter: 25000 + final_ratio: 0.01 + warmup_steps: 100 + +# Light setting for the in-training validation and visualization +validation: + denoising_steps: 4 + ensemble_size: 1 + processing_res: 768 + match_input_res: true + resample_method: bilinear + main_val_metric: mean_angular_error + main_val_metric_goal: minimize + init_seed: 0 + +eval: + align_max_res: null + eval_metrics: + - mean_angular_error + - sub11_25_error diff --git a/config/wandb.yaml b/config/wandb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1840d26b8cd85da1d922117305dbaa9faaccaee2 --- /dev/null +++ b/config/wandb.yaml @@ -0,0 +1,3 @@ +wandb: + # entity: your_entity + project: marigold diff --git a/marigold/marigold_depth_pipeline.py 
b/marigold/marigold_depth_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..07d02e2bba3ce8b4166541d65dd63d22d21ec8b7 --- /dev/null +++ b/marigold/marigold_depth_pipeline.py @@ -0,0 +1,516 @@ +# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information about Marigold: +# https://marigoldmonodepth.github.io +# https://marigoldcomputervision.github.io +# Efficient inference pipelines are now part of diffusers: +# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage +# https://huggingface.co/docs/diffusers/api/pipelines/marigold +# Examples of trained models and live demos: +# https://huggingface.co/prs-eth +# Related projects: +# https://rollingdepth.github.io/ +# https://marigolddepthcompletion.github.io/ +# Citation (BibTeX): +# https://github.com/prs-eth/Marigold#-citation +# If you find Marigold useful, we kindly ask you to cite our papers. 
# --------------------------------------------------------------------------

import logging
import numpy as np
import torch
from PIL import Image
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LCMScheduler,
    UNet2DConditionModel,
)
from diffusers.utils import BaseOutput
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import pil_to_tensor, resize
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from typing import Dict, Optional, Union

from .util.batchsize import find_batch_size
from .util.ensemble import ensemble_depth
from .util.image_util import (
    chw2hwc,
    colorize_depth_maps,
    get_tv_resample_method,
    resize_max_res,
)


class MarigoldDepthOutput(BaseOutput):
    """
    Output class for Marigold Monocular Depth Estimation pipeline.

    Args:
        depth_np (`np.ndarray`):
            Predicted depth map, with depth values in the range of [0, 1].
        depth_colored (`PIL.Image.Image`):
            Colorized depth map, with the shape of [H, W, 3] and values in [0, 255].
        uncertainty (`None` or `np.ndarray`):
            Uncalibrated uncertainty (MAD, median absolute deviation) coming from ensembling.
    """

    depth_np: np.ndarray
    depth_colored: Union[None, Image.Image]
    uncertainty: Union[None, np.ndarray]


class MarigoldDepthPipeline(DiffusionPipeline):
    """
    Pipeline for Marigold Monocular Depth Estimation: https://marigoldcomputervision.github.io.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        unet (`UNet2DConditionModel`):
            Conditional U-Net to denoise the prediction latent, conditioned on image latent.
        vae (`AutoencoderKL`):
            Variational Auto-Encoder (VAE) Model to encode and decode images and predictions
            to and from latent representations.
        scheduler (`DDIMScheduler`):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        text_encoder (`CLIPTextModel`):
            Text-encoder, for empty text embedding.
        tokenizer (`CLIPTokenizer`):
            CLIP tokenizer.
        scale_invariant (`bool`, *optional*):
            A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in
            the model config. When used together with the `shift_invariant=True` flag, the model is also called
            "affine-invariant". NB: overriding this value is not supported.
        shift_invariant (`bool`, *optional*):
            A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in
            the model config. When used together with the `scale_invariant=True` flag, the model is also called
            "affine-invariant". NB: overriding this value is not supported.
        default_denoising_steps (`int`, *optional*):
            The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
            quality with the given model. This value must be set in the model config. When the pipeline is called
            without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
            reasonable results with various model flavors compatible with the pipeline, such as those relying on very
            short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
        default_processing_resolution (`int`, *optional*):
            The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
            the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
            default value is used. This is required to ensure reasonable results with various model flavors trained
            with varying optimal processing resolution values.
    """

    # Scaling factor of the Stable Diffusion VAE latent space; applied after
    # encoding (encode_rgb) and undone before decoding (decode_depth).
    latent_scale_factor = 0.18215

    def __init__(
        self,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        scheduler: Union[DDIMScheduler, LCMScheduler],
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        scale_invariant: Optional[bool] = True,
        shift_invariant: Optional[bool] = True,
        default_denoising_steps: Optional[int] = None,
        default_processing_resolution: Optional[int] = None,
    ):
        super().__init__()
        self.register_modules(
            unet=unet,
            vae=vae,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )
        # Persist the model properties into the pipeline config so they are
        # saved/restored together with the checkpoint.
        self.register_to_config(
            scale_invariant=scale_invariant,
            shift_invariant=shift_invariant,
            default_denoising_steps=default_denoising_steps,
            default_processing_resolution=default_processing_resolution,
        )

        self.scale_invariant = scale_invariant
        self.shift_invariant = shift_invariant
        self.default_denoising_steps = default_denoising_steps
        self.default_processing_resolution = default_processing_resolution

        # Lazily computed by encode_empty_text() on first inference.
        self.empty_text_embed = None

    @torch.no_grad()
    def __call__(
        self,
        input_image: Union[Image.Image, torch.Tensor],
        denoising_steps: Optional[int] = None,
        ensemble_size: int = 1,
        processing_res: Optional[int] = None,
        match_input_res: bool = True,
        resample_method: str = "bilinear",
        batch_size: int = 0,
        generator: Union[torch.Generator, None] = None,
        color_map: str = "Spectral",
        show_progress_bar: bool = True,
        ensemble_kwargs: Optional[Dict] = None,
    ) -> MarigoldDepthOutput:
        """
        Function invoked when calling the pipeline.

        Args:
            input_image (`Image`):
                Input RGB (or gray-scale) image.
            denoising_steps (`int`, *optional*, defaults to `None`):
                Number of denoising diffusion steps during inference. The default value `None` results in automatic
                selection.
            ensemble_size (`int`, *optional*, defaults to `1`):
                Number of predictions to be ensembled.
            processing_res (`int`, *optional*, defaults to `None`):
                Effective processing resolution. When set to `0`, processes at the original image resolution. This
                produces crisper predictions, but may also lead to the overall loss of global context. The default
                value `None` resolves to the optimal value from the model config.
            match_input_res (`bool`, *optional*, defaults to `True`):
                Resize the prediction to match the input resolution.
                Only valid if `processing_res` > 0.
            resample_method: (`str`, *optional*, defaults to `bilinear`):
                Resampling method used to resize images and predictions. This can be one of `bilinear`, `bicubic` or
                `nearest`, defaults to: `bilinear`.
            batch_size (`int`, *optional*, defaults to `0`):
                Inference batch size, no bigger than `num_ensemble`.
                If set to 0, the script will automatically decide the proper batch size.
            generator (`torch.Generator`, *optional*, defaults to `None`)
                Random generator for initial noise generation.
            show_progress_bar (`bool`, *optional*, defaults to `True`):
                Display a progress bar of diffusion denoising.
            color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
                Colormap used to colorize the depth map.
            scale_invariant (`str`, *optional*, defaults to `True`):
                Flag of scale-invariant prediction, if True, scale will be adjusted from the raw prediction.
            shift_invariant (`str`, *optional*, defaults to `True`):
                Flag of shift-invariant prediction, if True, shift will be adjusted from the raw prediction, if False,
                near plane will be fixed at 0m.
            ensemble_kwargs (`dict`, *optional*, defaults to `None`):
                Arguments for detailed ensembling settings.
        Returns:
            `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
            - **depth_np** (`np.ndarray`) Predicted depth map with depth values in the range of [0, 1]
            - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [H, W, 3] and values in [0, 255], None if `color_map` is `None`
            - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty (MAD, median absolute deviation)
                    coming from ensembling. None if `ensemble_size = 1`
        """
        # Model-specific optimal default values leading to fast and reasonable results.
        if denoising_steps is None:
            denoising_steps = self.default_denoising_steps
        if processing_res is None:
            processing_res = self.default_processing_resolution

        assert processing_res >= 0
        assert ensemble_size >= 1

        # Check if denoising step is reasonable
        self._check_inference_step(denoising_steps)

        # Map the string name to a torchvision InterpolationMode.
        resample_method: InterpolationMode = get_tv_resample_method(resample_method)

        # ----------------- Image Preprocess -----------------
        # Convert to torch tensor
        if isinstance(input_image, Image.Image):
            input_image = input_image.convert("RGB")
            # convert to torch tensor [H, W, rgb] -> [rgb, H, W]
            rgb = pil_to_tensor(input_image)
            rgb = rgb.unsqueeze(0)  # [1, rgb, H, W]
        elif isinstance(input_image, torch.Tensor):
            # NOTE(review): tensor inputs are assumed to already be [1, 3, H, W]
            # in [0, 255]; the assert below only checks rank and channel count.
            rgb = input_image
        else:
            raise TypeError(f"Unknown input type: {type(input_image) = }")
        input_size = rgb.shape
        assert (
            4 == rgb.dim() and 3 == input_size[-3]
        ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"

        # Resize image
        if processing_res > 0:
            rgb = resize_max_res(
                rgb,
                max_edge_resolution=processing_res,
                resample_method=resample_method,
            )

        # Normalize rgb values
        rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
        rgb_norm = rgb_norm.to(self.dtype)
        assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0

        # ----------------- Predicting depth -----------------
        # Batch repeated input image (expand creates a view, no copy is made).
        duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
        single_rgb_dataset = TensorDataset(duplicated_rgb)
        if batch_size > 0:
            _bs = batch_size
        else:
            # Pick a batch size that fits the available VRAM (lookup table).
            _bs = find_batch_size(
                ensemble_size=ensemble_size,
                input_res=max(rgb_norm.shape[1:]),
                dtype=self.dtype,
            )

        single_rgb_loader = DataLoader(
            single_rgb_dataset, batch_size=_bs, shuffle=False
        )

        # Predict depth maps (batched)
        target_pred_ls = []
        if show_progress_bar:
            iterable = tqdm(
                single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
            )
        else:
            iterable = single_rgb_loader
        for batch in iterable:
            (batched_img,) = batch
            target_pred_raw = self.single_infer(
                rgb_in=batched_img,
                num_inference_steps=denoising_steps,
                show_pbar=show_progress_bar,
                generator=generator,
            )
            target_pred_ls.append(target_pred_raw.detach())
        target_preds = torch.concat(target_pred_ls, dim=0)
        torch.cuda.empty_cache()  # clear vram cache for ensembling

        # ----------------- Test-time ensembling -----------------
        if ensemble_size > 1:
            # Align the affine-ambiguous predictions before averaging.
            final_pred, pred_uncert = ensemble_depth(
                target_preds,
                scale_invariant=self.scale_invariant,
                shift_invariant=self.shift_invariant,
                **(ensemble_kwargs or {}),
            )
        else:
            final_pred = target_preds
            pred_uncert = None

        # Resize back to original resolution
        if match_input_res:
            final_pred = resize(
                final_pred,
                input_size[-2:],
                interpolation=resample_method,
                antialias=True,
            )

        # Convert to numpy
        final_pred = final_pred.squeeze()
        final_pred = final_pred.cpu().numpy()
        if pred_uncert is not None:
            pred_uncert = pred_uncert.squeeze().cpu().numpy()

        # Clip output range
        final_pred = final_pred.clip(0, 1)

        # Colorize
        if color_map is not None:
            depth_colored = colorize_depth_maps(
                final_pred, 0, 1, cmap=color_map
            ).squeeze()  # [3, H, W], value in (0, 1)
            depth_colored = (depth_colored * 255).astype(np.uint8)
            depth_colored_hwc = chw2hwc(depth_colored)
            depth_colored_img = Image.fromarray(depth_colored_hwc)
        else:
            depth_colored_img = None

        return MarigoldDepthOutput(
            depth_np=final_pred,
            depth_colored=depth_colored_img,
            uncertainty=pred_uncert,
        )

    def _check_inference_step(self, n_step: int) -> None:
        """
        Check if denoising step is reasonable and warn about suboptimal
        scheduler configurations.

        Args:
            n_step (`int`): denoising steps
        """
        assert n_step >= 1

        if isinstance(self.scheduler, DDIMScheduler):
            if "trailing" != self.scheduler.config.timestep_spacing:
                logging.warning(
                    f"The loaded `DDIMScheduler` is configured with `timestep_spacing="
                    f'"{self.scheduler.config.timestep_spacing}"`; the recommended setting is `"trailing"`. '
                    f"This change is backward-compatible and yields better results. "
                    f"Consider using `prs-eth/marigold-depth-v1-1` for the best experience."
                )
            else:
                # With "trailing" spacing the model is tuned for few steps;
                # many steps can hurt rather than help.
                if n_step > 10:
                    logging.warning(
                        f"Setting too many denoising steps ({n_step}) may degrade the prediction; consider relying on "
                        f"the default values."
                    )
            if not self.scheduler.config.rescale_betas_zero_snr:
                logging.warning(
                    f"The loaded `DDIMScheduler` is configured with `rescale_betas_zero_snr="
                    f"{self.scheduler.config.rescale_betas_zero_snr}`; the recommended setting is True. "
                    f"Consider using `prs-eth/marigold-depth-v1-1` for the best experience."
                )
        elif isinstance(self.scheduler, LCMScheduler):
            logging.warning(
                "DeprecationWarning: LCMScheduler will not be supported in the future. "
                "Consider using `prs-eth/marigold-depth-v1-1` for the best experience."
            )
            if n_step > 10:
                logging.warning(
                    f"Setting too many denoising steps ({n_step}) may degrade the prediction; consider relying on "
                    f"the default values."
                )
        else:
            raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")

    def encode_empty_text(self):
        """
        Encode text embedding for empty prompt; cached on `self.empty_text_embed`.
        """
        prompt = ""
        text_inputs = self.tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
        self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)

    @torch.no_grad()
    def single_infer(
        self,
        rgb_in: torch.Tensor,
        num_inference_steps: int,
        generator: Union[torch.Generator, None],
        show_pbar: bool,
    ) -> torch.Tensor:
        """
        Perform a single prediction without ensembling.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image.
            num_inference_steps (`int`):
                Number of diffusion denoising steps (DDIM) during inference.
            show_pbar (`bool`):
                Display a progress bar of diffusion denoising.
            generator (`torch.Generator`)
                Random generator for initial noise generation.
        Returns:
            `torch.Tensor`: Predicted targets.
        """
        device = self.device
        rgb_in = rgb_in.to(device)

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps  # [T]

        # Encode image
        rgb_latent = self.encode_rgb(rgb_in)  # [B, 4, h, w]

        # Noisy latent for outputs; pure Gaussian noise, reproducible via `generator`.
        target_latent = torch.randn(
            rgb_latent.shape,
            device=device,
            dtype=self.dtype,
            generator=generator,
        )  # [B, 4, h, w]

        # Batched empty text embedding
        if self.empty_text_embed is None:
            self.encode_empty_text()
        batch_empty_text_embed = self.empty_text_embed.repeat(
            (rgb_latent.shape[0], 1, 1)
        ).to(device)  # [B, 2, 1024]

        # Denoising loop
        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising",
            )
        else:
            iterable = enumerate(timesteps)

        for i, t in iterable:
            unet_input = torch.cat(
                [rgb_latent, target_latent], dim=1
            )  # this order is important

            # predict the noise residual
            noise_pred = self.unet(
                unet_input, t, encoder_hidden_states=batch_empty_text_embed
            ).sample  # [B, 4, h, w]

            # compute the previous noisy sample x_t -> x_t-1
            target_latent = self.scheduler.step(
                noise_pred, t, target_latent, generator=generator
            ).prev_sample

        # decode_depth averages VAE output channels with keepdim=True -> [B, 1, H, W]
        depth = self.decode_depth(target_latent)

        # clip prediction
        depth = torch.clip(depth, -1.0, 1.0)
        # shift to [0, 1]
        depth = (depth + 1.0) / 2.0

        return depth

    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
        """
        Encode RGB image into latent.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image to be encoded.

        Returns:
            `torch.Tensor`: Image latent.
        """
        # encode; only the distribution mean is used (no sampling from the VAE posterior)
        h = self.vae.encoder(rgb_in)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        # scale latent
        rgb_latent = mean * self.latent_scale_factor
        return rgb_latent

    def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode depth latent into depth map.

        Args:
            depth_latent (`torch.Tensor`):
                Depth latent to be decoded.

        Returns:
            `torch.Tensor`: Decoded depth map, single channel ([B, 1, H, W]).
        """
        # scale latent
        depth_latent = depth_latent / self.latent_scale_factor
        # decode
        z = self.vae.post_quant_conv(depth_latent)
        stacked = self.vae.decoder(z)
        # mean of output channels
        depth_mean = stacked.mean(dim=1, keepdim=True)
        return depth_mean
diff --git a/marigold/marigold_normals_pipeline.py b/marigold/marigold_normals_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f1a06a1fc5926f35683edfd102246078f9b0d29
--- /dev/null
+++ b/marigold/marigold_normals_pipeline.py
@@ -0,0 +1,479 @@
# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# More information about Marigold:
#   https://marigoldmonodepth.github.io
#   https://marigoldcomputervision.github.io
# Efficient inference pipelines are now part of diffusers:
#   https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
#   https://huggingface.co/docs/diffusers/api/pipelines/marigold
# Examples of trained models and live demos:
#   https://huggingface.co/prs-eth
# Related projects:
#   https://rollingdepth.github.io/
#   https://marigolddepthcompletion.github.io/
# Citation (BibTeX):
#   https://github.com/prs-eth/Marigold#-citation
# If you find Marigold useful, we kindly ask you to cite our papers.
# --------------------------------------------------------------------------

import logging
import numpy as np
import torch
from PIL import Image
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LCMScheduler,
    UNet2DConditionModel,
)
from diffusers.utils import BaseOutput
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import pil_to_tensor, resize
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from typing import Dict, Optional, Union

from .util.batchsize import find_batch_size
from .util.ensemble import ensemble_normals
from .util.image_util import (
    chw2hwc,
    get_tv_resample_method,
    resize_max_res,
)


class MarigoldNormalsOutput(BaseOutput):
    """
    Output class for Marigold Surface Normals Estimation pipeline.

    Args:
        normals_np (`np.ndarray`):
            Predicted normals map of shape [3, H, W] with values in the range of [-1, 1] (unit length vectors).
        normals_img (`PIL.Image.Image`):
            Normals image, with the shape of [H, W, 3] and values in [0, 255].
        uncertainty (`None` or `np.ndarray`):
            Uncalibrated uncertainty (MAD, median absolute deviation) coming from ensembling.
    """

    normals_np: np.ndarray
    normals_img: Image.Image
    uncertainty: Union[None, np.ndarray]


class MarigoldNormalsPipeline(DiffusionPipeline):
    """
    Pipeline for Marigold Surface Normals Estimation: https://marigoldcomputervision.github.io.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        unet (`UNet2DConditionModel`):
            Conditional U-Net to denoise the prediction latent, conditioned on image latent.
        vae (`AutoencoderKL`):
            Variational Auto-Encoder (VAE) Model to encode and decode images and predictions
            to and from latent representations.
        scheduler (`DDIMScheduler`):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        text_encoder (`CLIPTextModel`):
            Text-encoder, for empty text embedding.
        tokenizer (`CLIPTokenizer`):
            CLIP tokenizer.
        default_denoising_steps (`int`, *optional*):
            The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
            quality with the given model. This value must be set in the model config. When the pipeline is called
            without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
            reasonable results with various model flavors compatible with the pipeline, such as those relying on very
            short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
        default_processing_resolution (`int`, *optional*):
            The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
            the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
            default value is used. This is required to ensure reasonable results with various model flavors trained
            with varying optimal processing resolution values.
    """

    # Scaling factor of the Stable Diffusion VAE latent space; applied after
    # encoding (encode_rgb) and undone before decoding (decode_normals).
    latent_scale_factor = 0.18215

    def __init__(
        self,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        scheduler: Union[DDIMScheduler, LCMScheduler],
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        default_denoising_steps: Optional[int] = None,
        default_processing_resolution: Optional[int] = None,
    ):
        super().__init__()
        self.register_modules(
            unet=unet,
            vae=vae,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )
        # Persist the model properties into the pipeline config so they are
        # saved/restored together with the checkpoint.
        self.register_to_config(
            default_denoising_steps=default_denoising_steps,
            default_processing_resolution=default_processing_resolution,
        )

        self.default_denoising_steps = default_denoising_steps
        self.default_processing_resolution = default_processing_resolution

        # Lazily computed by encode_empty_text() on first inference.
        self.empty_text_embed = None

    @torch.no_grad()
    def __call__(
        self,
        input_image: Union[Image.Image, torch.Tensor],
        denoising_steps: Optional[int] = None,
        ensemble_size: int = 1,
        processing_res: Optional[int] = None,
        match_input_res: bool = True,
        resample_method: str = "bilinear",
        batch_size: int = 0,
        generator: Union[torch.Generator, None] = None,
        show_progress_bar: bool = True,
        ensemble_kwargs: Optional[Dict] = None,
    ) -> MarigoldNormalsOutput:
        """
        Function invoked when calling the pipeline.

        Args:
            input_image (`Image`):
                Input RGB (or gray-scale) image.
            denoising_steps (`int`, *optional*, defaults to `None`):
                Number of denoising diffusion steps during inference. The default value `None` results in automatic
                selection.
            ensemble_size (`int`, *optional*, defaults to `1`):
                Number of predictions to be ensembled.
            processing_res (`int`, *optional*, defaults to `None`):
                Effective processing resolution. When set to `0`, processes at the original image resolution. This
                produces crisper predictions, but may also lead to the overall loss of global context. The default
                value `None` resolves to the optimal value from the model config.
            match_input_res (`bool`, *optional*, defaults to `True`):
                Resize the prediction to match the input resolution.
                Only valid if `processing_res` > 0.
            resample_method: (`str`, *optional*, defaults to `bilinear`):
                Resampling method used to resize images and predictions. This can be one of `bilinear`, `bicubic` or
                `nearest`, defaults to: `bilinear`.
            batch_size (`int`, *optional*, defaults to `0`):
                Inference batch size, no bigger than `num_ensemble`.
                If set to 0, the script will automatically decide the proper batch size.
            generator (`torch.Generator`, *optional*, defaults to `None`)
                Random generator for initial noise generation.
            show_progress_bar (`bool`, *optional*, defaults to `True`):
                Display a progress bar of diffusion denoising.
            ensemble_kwargs (`dict`, *optional*, defaults to `None`):
                Arguments for detailed ensembling settings.
        Returns:
            `MarigoldNormalsOutput`: Output class for Marigold monocular surface normals estimation pipeline, including:
            - **normals_np** (`np.ndarray`) Predicted normals map of shape [3, H, W] with values in the range of [-1, 1]
                (unit length vectors)
            - **normals_img** (`PIL.Image.Image`) Normals image, with the shape of [H, W, 3] and values in [0, 255]
            - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty (MAD, median absolute deviation)
                    coming from ensembling. None if `ensemble_size = 1`
        """
        # Model-specific optimal default values leading to fast and reasonable results.
        if denoising_steps is None:
            denoising_steps = self.default_denoising_steps
        if processing_res is None:
            processing_res = self.default_processing_resolution

        assert processing_res >= 0
        assert ensemble_size >= 1

        # Check if denoising step is reasonable
        self._check_inference_step(denoising_steps)

        # Map the string name to a torchvision InterpolationMode.
        resample_method: InterpolationMode = get_tv_resample_method(resample_method)

        # ----------------- Image Preprocess -----------------
        # Convert to torch tensor
        if isinstance(input_image, Image.Image):
            input_image = input_image.convert("RGB")
            # convert to torch tensor [H, W, rgb] -> [rgb, H, W]
            rgb = pil_to_tensor(input_image)
            rgb = rgb.unsqueeze(0)  # [1, rgb, H, W]
        elif isinstance(input_image, torch.Tensor):
            # NOTE(review): tensor inputs are assumed to already be [1, 3, H, W]
            # in [0, 255]; the assert below only checks rank and channel count.
            rgb = input_image
        else:
            raise TypeError(f"Unknown input type: {type(input_image) = }")
        input_size = rgb.shape
        assert (
            4 == rgb.dim() and 3 == input_size[-3]
        ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"

        # Resize image
        if processing_res > 0:
            rgb = resize_max_res(
                rgb,
                max_edge_resolution=processing_res,
                resample_method=resample_method,
            )

        # Normalize rgb values
        rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
        rgb_norm = rgb_norm.to(self.dtype)
        assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0

        # ----------------- Predicting normals -----------------
        # Batch repeated input image (expand creates a view, no copy is made).
        duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
        single_rgb_dataset = TensorDataset(duplicated_rgb)
        if batch_size > 0:
            _bs = batch_size
        else:
            # Pick a batch size that fits the available VRAM (lookup table).
            _bs = find_batch_size(
                ensemble_size=ensemble_size,
                input_res=max(rgb_norm.shape[1:]),
                dtype=self.dtype,
            )

        single_rgb_loader = DataLoader(
            single_rgb_dataset, batch_size=_bs, shuffle=False
        )

        # Predict normals maps (batched)
        target_pred_ls = []
        if show_progress_bar:
            iterable = tqdm(
                single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
            )
        else:
            iterable = single_rgb_loader
        for batch in iterable:
            (batched_img,) = batch
            target_pred_raw = self.single_infer(
                rgb_in=batched_img,
                num_inference_steps=denoising_steps,
                show_pbar=show_progress_bar,
                generator=generator,
            )
            target_pred_ls.append(target_pred_raw.detach())
        target_preds = torch.concat(target_pred_ls, dim=0)
        torch.cuda.empty_cache()  # clear vram cache for ensembling

        # ----------------- Test-time ensembling -----------------
        if ensemble_size > 1:
            final_pred, pred_uncert = ensemble_normals(
                target_preds,
                **(ensemble_kwargs or {}),
            )
        else:
            final_pred = target_preds
            pred_uncert = None

        # Resize back to original resolution
        if match_input_res:
            final_pred = resize(
                final_pred,
                input_size[-2:],
                interpolation=resample_method,
                antialias=True,
            )

        # Convert to numpy
        final_pred = final_pred.squeeze()
        final_pred = final_pred.cpu().numpy()
        if pred_uncert is not None:
            pred_uncert = pred_uncert.squeeze().cpu().numpy()

        # Clip output range
        final_pred = final_pred.clip(-1, 1)

        # Colorize: map unit vectors in [-1, 1] to RGB bytes in [0, 255]
        normals_img = ((final_pred + 1) * 127.5).astype(np.uint8)
        normals_img = chw2hwc(normals_img)
        normals_img = Image.fromarray(normals_img)

        return MarigoldNormalsOutput(
            normals_np=final_pred,
            normals_img=normals_img,
            uncertainty=pred_uncert,
        )

    def _check_inference_step(self, n_step: int) -> None:
        """
        Check if denoising step is reasonable and warn about suboptimal
        scheduler configurations.

        Args:
            n_step (`int`): denoising steps
        """
        assert n_step >= 1

        if isinstance(self.scheduler, DDIMScheduler):
            if "trailing" != self.scheduler.config.timestep_spacing:
                logging.warning(
                    f"The loaded `DDIMScheduler` is configured with `timestep_spacing="
                    f'"{self.scheduler.config.timestep_spacing}"`; the recommended setting is `"trailing"`. '
                    f"This change is backward-compatible and yields better results. "
                    f"Consider using `prs-eth/marigold-normals-v1-1` for the best experience."
                )
            else:
                # With "trailing" spacing the model is tuned for few steps;
                # many steps can hurt rather than help.
                if n_step > 10:
                    logging.warning(
                        f"Setting too many denoising steps ({n_step}) may degrade the prediction; consider relying on "
                        f"the default values."
                    )
            if not self.scheduler.config.rescale_betas_zero_snr:
                logging.warning(
                    f"The loaded `DDIMScheduler` is configured with `rescale_betas_zero_snr="
                    f"{self.scheduler.config.rescale_betas_zero_snr}`; the recommended setting is True. "
                    f"Consider using `prs-eth/marigold-normals-v1-1` for the best experience."
                )
        elif isinstance(self.scheduler, LCMScheduler):
            # Unlike the depth pipeline, LCM is hard-rejected here.
            raise RuntimeError(
                "This pipeline implementation does not support the LCMScheduler. Please refer to the project "
                "README.md for instructions about using LCM."
            )
        else:
            raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")

    def encode_empty_text(self):
        """
        Encode text embedding for empty prompt; cached on `self.empty_text_embed`.
        """
        prompt = ""
        text_inputs = self.tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
        self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)

    @torch.no_grad()
    def single_infer(
        self,
        rgb_in: torch.Tensor,
        num_inference_steps: int,
        generator: Union[torch.Generator, None],
        show_pbar: bool,
    ) -> torch.Tensor:
        """
        Perform a single prediction without ensembling.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image.
            num_inference_steps (`int`):
                Number of diffusion denoising steps (DDIM) during inference.
            show_pbar (`bool`):
                Display a progress bar of diffusion denoising.
            generator (`torch.Generator`)
                Random generator for initial noise generation.
        Returns:
            `torch.Tensor`: Predicted targets.
        """
        device = self.device
        rgb_in = rgb_in.to(device)

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps  # [T]

        # Encode image
        rgb_latent = self.encode_rgb(rgb_in)  # [B, 4, h, w]

        # Noisy latent for outputs; pure Gaussian noise, reproducible via `generator`.
        target_latent = torch.randn(
            rgb_latent.shape,
            device=device,
            dtype=self.dtype,
            generator=generator,
        )  # [B, 4, h, w]

        # Batched empty text embedding
        if self.empty_text_embed is None:
            self.encode_empty_text()
        batch_empty_text_embed = self.empty_text_embed.repeat(
            (rgb_latent.shape[0], 1, 1)
        ).to(device)  # [B, 2, 1024]

        # Denoising loop
        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising",
            )
        else:
            iterable = enumerate(timesteps)

        for i, t in iterable:
            unet_input = torch.cat(
                [rgb_latent, target_latent], dim=1
            )  # this order is important

            # predict the noise residual
            noise_pred = self.unet(
                unet_input, t, encoder_hidden_states=batch_empty_text_embed
            ).sample  # [B, 4, h, w]

            # compute the previous noisy sample x_t -> x_t-1
            target_latent = self.scheduler.step(
                noise_pred, t, target_latent, generator=generator
            ).prev_sample

        normals = self.decode_normals(target_latent)  # [B,3,H,W]

        # clip prediction, then renormalize to unit-length vectors;
        # clamp(min=1e-6) guards against division by a zero-length vector
        normals = torch.clip(normals, -1.0, 1.0)
        norm = torch.norm(normals, dim=1, keepdim=True)
        normals /= norm.clamp(min=1e-6)  # [B,3,H,W]

        return normals

    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
        """
        Encode RGB image into latent.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image to be encoded.

        Returns:
            `torch.Tensor`: Image latent.
        """
        # encode; only the distribution mean is used (no sampling from the VAE posterior)
        h = self.vae.encoder(rgb_in)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        # scale latent
        rgb_latent = mean * self.latent_scale_factor
        return rgb_latent

    def decode_normals(self, normals_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode normals latent into normals map.

        Args:
            normals_latent (`torch.Tensor`):
                Normals latent to be decoded.

        Returns:
            `torch.Tensor`: Decoded normals map (all three decoder output channels).
        """
        # scale latent
        normals_latent = normals_latent / self.latent_scale_factor
        # decode
        z = self.vae.post_quant_conv(normals_latent)
        stacked = self.vae.decoder(z)
        return stacked
diff --git a/marigold/util/batchsize.py b/marigold/util/batchsize.py
new file mode 100644
index 0000000000000000000000000000000000000000..888f0f6da708e57a827ed9e9ba1f94c0478ac551
--- /dev/null
+++ b/marigold/util/batchsize.py
@@ -0,0 +1,90 @@
# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Search table for suggested max. inference batch size
bs_search_table = [
    # tested on A100-PCIE-80GB
    {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
    {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
    # tested on A100-PCIE-40GB
    {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
    {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
    {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
    {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
    # tested on RTX3090, RTX4090
    {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
    {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
    {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
    {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
    {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
    {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
    # tested on GTX1080Ti
    {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
    {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
    {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
    {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
    {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
]


def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
    """
    Automatically search for a suitable operating batch size.

    Args:
        ensemble_size (`int`):
            Number of predictions to be ensembled.
        input_res (`int`):
            Operating resolution of the input image.
        dtype (`torch.dtype`):
            Inference dtype used to filter the lookup table.

    Returns:
        `int`: Operating batch size.
    """
    # Without CUDA there is no VRAM budget to consult; fall back to 1.
    if not torch.cuda.is_available():
        return 1

    total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3

    # Consider only table entries measured at the requested dtype, smallest
    # resolution and largest VRAM first, and take the first one that fits.
    candidates = sorted(
        (entry for entry in bs_search_table if entry["dtype"] == dtype),
        key=lambda entry: (entry["res"], -entry["total_vram"]),
    )
    for entry in candidates:
        if input_res <= entry["res"] and total_vram >= entry["total_vram"]:
            # Never exceed the ensemble size; if more than half the ensemble
            # fits but not all of it, balance the two passes evenly.
            suggested = min(entry["bs"], ensemble_size)
            half = math.ceil(ensemble_size / 2)
            if half < suggested < ensemble_size:
                suggested = half
            return suggested

    return 1
def ensemble_depth(
    depth: torch.Tensor,
    scale_invariant: bool = True,
    shift_invariant: bool = True,
    output_uncertainty: bool = False,
    reduction: str = "median",
    regularizer_strength: float = 0.02,
    max_iter: int = 50,
    tol: float = 1e-6,
    max_res: int = 1024,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Ensembles depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the
    number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed for
    depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative. The
    alignment happens when the predictions have one or more degrees of freedom, that is when they are either
    affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only
    `scale_invariant=True`). For absolute predictions (`scale_invariant=False` and `shift_invariant=False`)
    alignment is skipped and only ensembling is performed.

    Args:
        depth (`torch.Tensor`):
            Input ensemble depth maps.
        scale_invariant (`bool`, *optional*, defaults to `True`):
            Whether to treat predictions as scale-invariant.
        shift_invariant (`bool`, *optional*, defaults to `True`):
            Whether to treat predictions as shift-invariant.
        output_uncertainty (`bool`, *optional*, defaults to `False`):
            Whether to output uncertainty map.
        reduction (`str`, *optional*, defaults to `"median"`):
            Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and
            `"median"`.
        regularizer_strength (`float`, *optional*, defaults to `0.02`):
            Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1.
        max_iter (`int`, *optional*, defaults to `50`):
            Maximum number of the alignment solver steps. Refer to `scipy.optimize.minimize` function, `options`
            argument.
        tol (`float`, *optional*, defaults to `1e-6`):
            Alignment solver tolerance. The solver stops when the tolerance is reached.
        max_res (`int`, *optional*, defaults to `1024`):
            Resolution at which the alignment is performed; `None` matches the `processing_resolution`.
    Returns:
        A tensor of aligned and ensembled depth maps and optionally a tensor of uncertainties of the same shape:
        `(1, 1, H, W)`.
    """
    if depth.dim() != 4 or depth.shape[1] != 1:
        raise ValueError(f"Expecting 4D tensor of shape [B,1,H,W]; got {depth.shape}.")
    if reduction not in ("mean", "median"):
        raise ValueError(f"Unrecognized reduction method: {reduction}.")
    if not scale_invariant and shift_invariant:
        raise ValueError("Pure shift-invariant ensembling is not supported.")

    def init_param(depth: torch.Tensor):
        # Initialize the per-member parameters so that each member is mapped
        # into [0, 1]: s = 1/(max-min), t = -s*min (affine case).
        init_min = depth.reshape(ensemble_size, -1).min(dim=1).values
        init_max = depth.reshape(ensemble_size, -1).max(dim=1).values

        if scale_invariant and shift_invariant:
            init_s = 1.0 / (init_max - init_min).clamp(min=1e-6)
            init_t = -init_s * init_min
            # Parameter vector layout: [s_1..s_B, t_1..t_B].
            param = torch.cat((init_s, init_t)).cpu().numpy()
        elif scale_invariant:
            init_s = 1.0 / init_max.clamp(min=1e-6)
            param = init_s.cpu().numpy()
        else:
            raise ValueError("Unrecognized alignment.")

        # scipy.optimize works in float64 on the CPU.
        return param.astype(np.float64)

    def align(depth: torch.Tensor, param: np.ndarray) -> torch.Tensor:
        # Apply per-member scale (and shift) parameters to the ensemble stack.
        if scale_invariant and shift_invariant:
            s, t = np.split(param, 2)
            s = torch.from_numpy(s).to(depth).view(ensemble_size, 1, 1, 1)
            t = torch.from_numpy(t).to(depth).view(ensemble_size, 1, 1, 1)
            out = depth * s + t
        elif scale_invariant:
            s = torch.from_numpy(param).to(depth).view(ensemble_size, 1, 1, 1)
            out = depth * s
        else:
            raise ValueError("Unrecognized alignment.")
        return out

    def ensemble(
        depth_aligned: torch.Tensor, return_uncertainty: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Reduce the aligned stack to a single prediction; uncertainty is the
        # matching dispersion measure (std for mean, MAD for median).
        uncertainty = None
        if reduction == "mean":
            prediction = torch.mean(depth_aligned, dim=0, keepdim=True)
            if return_uncertainty:
                uncertainty = torch.std(depth_aligned, dim=0, keepdim=True)
        elif reduction == "median":
            prediction = torch.median(depth_aligned, dim=0, keepdim=True).values
            if return_uncertainty:
                uncertainty = torch.median(
                    torch.abs(depth_aligned - prediction), dim=0, keepdim=True
                ).values
        else:
            raise ValueError(f"Unrecognized reduction method: {reduction}.")
        return prediction, uncertainty

    def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float:
        # Pairwise RMSE between every two aligned members ...
        cost = 0.0
        depth_aligned = align(depth, param)

        for i, j in torch.combinations(torch.arange(ensemble_size)):
            diff = depth_aligned[i] - depth_aligned[j]
            cost += (diff**2).mean().sqrt().item()

        # ... plus a regularizer pulling the ensembled range towards [0, 1],
        # which pins down the gauge freedom of the affine alignment.
        if regularizer_strength > 0:
            prediction, _ = ensemble(depth_aligned, return_uncertainty=False)
            err_near = (0.0 - prediction.min()).abs().item()
            err_far = (1.0 - prediction.max()).abs().item()
            cost += (err_near + err_far) * regularizer_strength

        return cost

    def compute_param(depth: torch.Tensor):
        # Imported lazily so scipy is only required when alignment is used.
        import scipy

        # Solve at reduced resolution to keep the BFGS iterations cheap;
        # the resulting per-member scalars are resolution-independent.
        depth_to_align = depth.to(torch.float32)
        if max_res is not None and max(depth_to_align.shape[2:]) > max_res:
            depth_to_align = resize_max_res(
                depth_to_align, max_res, get_tv_resample_method("nearest-exact")
            )

        param = init_param(depth_to_align)

        res = scipy.optimize.minimize(
            partial(cost_fn, depth=depth_to_align),
            param,
            method="BFGS",
            tol=tol,
            options={"maxiter": max_iter, "disp": False},
        )

        return res.x

    requires_aligning = scale_invariant or shift_invariant
    ensemble_size = depth.shape[0]

    if requires_aligning:
        param = compute_param(depth)
        depth = align(depth, param)

    depth, uncertainty = ensemble(depth, return_uncertainty=output_uncertainty)

    # Renormalize the ensembled prediction (and uncertainty) to [0, 1]; in
    # the scale-only case the minimum is pinned at 0 by construction.
    depth_max = depth.max()
    if scale_invariant and shift_invariant:
        depth_min = depth.min()
    elif scale_invariant:
        depth_min = 0
    else:
        raise ValueError("Unrecognized alignment.")
    depth_range = (depth_max - depth_min).clamp(min=1e-6)
    depth = (depth - depth_min) / depth_range
    if output_uncertainty:
        uncertainty /= depth_range

    return depth, uncertainty  # [1,1,H,W], [1,1,H,W]
def ensemble_normals(
    normals: torch.Tensor,
    output_uncertainty: bool = False,
    reduction: str = "closest",
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Ensemble a stack of surface-normals predictions of shape `(B, 3, H, W)`,
    where B is the number of ensemble members for one `(H x W)` prediction.

    Args:
        normals (`torch.Tensor`):
            Input ensemble normals maps.
        output_uncertainty (`bool`, *optional*, defaults to `False`):
            Whether to also return an angular-uncertainty map.
        reduction (`str`, *optional*, defaults to `"closest"`):
            Either `"closest"` (per pixel, pick the member most aligned with
            the normalized mean direction) or `"mean"` (return the normalized
            mean direction itself).

    Returns:
        The ensembled normals map of shape `(1, 3, H, W)` and optionally an
        uncertainty map of shape `(1, 1, H, W)`.
    """
    if normals.dim() != 4 or normals.shape[1] != 3:
        raise ValueError(
            f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}."
        )
    if reduction not in ("closest", "mean"):
        raise ValueError(f"Unrecognized reduction method: {reduction}.")

    # Per-pixel average direction, re-normalized to unit length.
    mean_normals = normals.mean(dim=0, keepdim=True)  # [1,3,H,W]
    mean_normals = mean_normals / torch.norm(
        mean_normals, dim=1, keepdim=True
    ).clamp(min=1e-6)

    # Cosine similarity of each member to the mean direction; needed for both
    # the "closest" reduction and the uncertainty estimate.
    sim_cos = None
    if output_uncertainty or reduction != "mean":
        sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True)  # [E,1,H,W]
        sim_cos = sim_cos.clamp(-1, 1)  # avoid NaN from arccos with fp16

    uncertainty = None
    if output_uncertainty:
        # Mean angular deviation from the ensemble direction, scaled to [0,1].
        uncertainty = sim_cos.arccos().mean(dim=0, keepdim=True) / np.pi

    if reduction == "mean":
        return mean_normals, uncertainty  # [1,3,H,W], [1,1,H,W]

    # "closest": per pixel, keep the member most aligned with the mean.
    winner = sim_cos.argmax(dim=0, keepdim=True).repeat(1, 3, 1, 1)  # [1,3,H,W]
    closest_normals = torch.gather(normals, 0, winner)  # [1,3,H,W]
    return closest_normals, uncertainty  # [1,3,H,W], [1,1,H,W]


def ensemble_iid(
    targets: torch.Tensor,
    output_uncertainty: bool = False,
    reduction: str = "median",
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Ensemble a stack of IID predictions along dim 0 by `"mean"` or `"median"`,
    optionally returning a matching dispersion map (std / MAD).
    """
    if reduction == "mean":
        prediction = targets.mean(dim=0, keepdim=True)
        uncertainty = (
            targets.std(dim=0, keepdim=True) if output_uncertainty else None
        )
    elif reduction == "median":
        prediction = targets.median(dim=0, keepdim=True).values
        uncertainty = (
            (targets - prediction).abs().median(dim=0, keepdim=True).values
            if output_uncertainty
            else None
        )
    else:
        raise ValueError(f"Unrecognized reduction method: {reduction}.")
    return prediction, uncertainty
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information about Marigold: +# https://marigoldmonodepth.github.io +# https://marigoldcomputervision.github.io +# Efficient inference pipelines are now part of diffusers: +# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage +# https://huggingface.co/docs/diffusers/api/pipelines/marigold +# Examples of trained models and live demos: +# https://huggingface.co/prs-eth +# Related projects: +# https://rollingdepth.github.io/ +# https://marigolddepthcompletion.github.io/ +# Citation (BibTeX): +# https://github.com/prs-eth/Marigold#-citation +# If you find Marigold useful, we kindly ask you to cite our papers. +# -------------------------------------------------------------------------- diff --git a/src/dataset/__init__.py b/src/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ef12299f659d527b0b4ddec1ce361a499d7a957 --- /dev/null +++ b/src/dataset/__init__.py @@ -0,0 +1,107 @@ +# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information about Marigold: +# https://marigoldmonodepth.github.io +# https://marigoldcomputervision.github.io +# Efficient inference pipelines are now part of diffusers: +# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage +# https://huggingface.co/docs/diffusers/api/pipelines/marigold +# Examples of trained models and live demos: +# https://huggingface.co/prs-eth +# Related projects: +# https://rollingdepth.github.io/ +# https://marigolddepthcompletion.github.io/ +# Citation (BibTeX): +# https://github.com/prs-eth/Marigold#-citation +# If you find Marigold useful, we kindly ask you to cite our papers. +# -------------------------------------------------------------------------- + +import os +from typing import Union, List + +from .base_depth_dataset import ( + BaseDepthDataset, + get_pred_name, # noqa: F401 + DatasetMode, +) # noqa: F401 +from .base_iid_dataset import BaseIIDDataset # noqa: F401 +from .base_normals_dataset import BaseNormalsDataset # noqa: F401 +from .diode_dataset import DIODEDepthDataset, DIODENormalsDataset +from .eth3d_dataset import ETH3DDepthDataset +from .hypersim_dataset import ( + HypersimDepthDataset, + HypersimNormalsDataset, + HypersimIIDDataset, +) +from .ibims_dataset import IBimsNormalsDataset +from .interiorverse_dataset import InteriorVerseNormalsDataset, InteriorVerseIIDDataset +from .kitti_dataset import KITTIDepthDataset +from .nyu_dataset import NYUDepthDataset, NYUNormalsDataset +from .oasis_dataset import OasisNormalsDataset +from .scannet_dataset import ScanNetDepthDataset, ScanNetNormalsDataset +from .sintel_dataset import SintelNormalsDataset +from .vkitti_dataset import VirtualKITTIDepthDataset + +dataset_name_class_dict = { + "hypersim_depth": HypersimDepthDataset, + "vkitti_depth": VirtualKITTIDepthDataset, + 
"nyu_depth": NYUDepthDataset, + "kitti_depth": KITTIDepthDataset, + "eth3d_depth": ETH3DDepthDataset, + "diode_depth": DIODEDepthDataset, + "scannet_depth": ScanNetDepthDataset, + "hypersim_normals": HypersimNormalsDataset, + "interiorverse_normals": InteriorVerseNormalsDataset, + "sintel_normals": SintelNormalsDataset, + "ibims_normals": IBimsNormalsDataset, + "nyu_normals": NYUNormalsDataset, + "scannet_normals": ScanNetNormalsDataset, + "diode_normals": DIODENormalsDataset, + "oasis_normals": OasisNormalsDataset, + "interiorverse_iid": InteriorVerseIIDDataset, + "hypersim_iid": HypersimIIDDataset, +} + + +def get_dataset( + cfg_data_split, base_data_dir: str, mode: DatasetMode, **kwargs +) -> Union[ + BaseDepthDataset, + BaseIIDDataset, + BaseNormalsDataset, + List[BaseDepthDataset], + List[BaseIIDDataset], + List[BaseNormalsDataset], +]: + if "mixed" == cfg_data_split.name: + assert DatasetMode.TRAIN == mode, "Only training mode supports mixed datasets." + dataset_ls = [ + get_dataset(_cfg, base_data_dir, mode, **kwargs) + for _cfg in cfg_data_split.dataset_list + ] + return dataset_ls + elif cfg_data_split.name in dataset_name_class_dict.keys(): + dataset_class = dataset_name_class_dict[cfg_data_split.name] + dataset = dataset_class( + mode=mode, + filename_ls_path=cfg_data_split.filenames, + dataset_dir=os.path.join(base_data_dir, cfg_data_split.dir), + **cfg_data_split, + **kwargs, + ) + else: + raise NotImplementedError + + return dataset diff --git a/src/dataset/base_depth_dataset.py b/src/dataset/base_depth_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5b7a9978700a449f84a7c928f1e44a63f9d2062d --- /dev/null +++ b/src/dataset/base_depth_dataset.py @@ -0,0 +1,285 @@ +# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information about Marigold: +# https://marigoldmonodepth.github.io +# https://marigoldcomputervision.github.io +# Efficient inference pipelines are now part of diffusers: +# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage +# https://huggingface.co/docs/diffusers/api/pipelines/marigold +# Examples of trained models and live demos: +# https://huggingface.co/prs-eth +# Related projects: +# https://rollingdepth.github.io/ +# https://marigolddepthcompletion.github.io/ +# Citation (BibTeX): +# https://github.com/prs-eth/Marigold#-citation +# If you find Marigold useful, we kindly ask you to cite our papers. 
class DatasetMode(Enum):
    """How a dataset is consumed: inference-only, evaluation, or training."""

    RGB_ONLY = "rgb_only"
    EVAL = "evaluate"
    TRAIN = "train"


class DepthFileNameMode(Enum):
    """Prediction file naming modes"""

    id = 1  # id.png
    rgb_id = 2  # rgb_id.png
    i_d_rgb = 3  # i_d_1_rgb.png
    rgb_i_d = 4


class BaseDepthDataset(Dataset):
    """
    Base class for depth datasets backed by a directory or a single tar archive.

    Subclasses typically override `_read_depth_file` to decode their own depth
    encoding. Each sample yields RGB tensors plus (outside RGB_ONLY mode) raw
    and filled depth with validity masks.
    """

    def __init__(
        self,
        mode: DatasetMode,
        filename_ls_path: str,
        dataset_dir: str,
        disp_name: str,
        min_depth: float,
        max_depth: float,
        has_filled_depth: bool,
        name_mode: DepthFileNameMode,
        depth_transform: Union[DepthNormalizerBase, None] = None,
        augmentation_args: dict = None,
        resize_to_hw=None,
        move_invalid_to_far_plane: bool = True,
        rgb_transform=lambda x: x / 255.0 * 2 - 1,  # [0, 255] -> [-1, 1],
        **kwargs,
    ) -> None:
        """
        Args:
            mode: Dataset consumption mode (RGB_ONLY / EVAL / TRAIN).
            filename_ls_path: Text file listing per-sample relative paths
                (whitespace-separated: rgb [depth [filled_depth]]).
            dataset_dir: Directory root, or a tar archive containing the data.
            disp_name: Display name used for reporting.
            min_depth: Lower bound (exclusive) for valid depth values.
            max_depth: Upper bound (exclusive) for valid depth values.
            has_filled_depth: Whether a third column with inpainted depth exists.
            name_mode: Naming scheme used to derive prediction filenames.
            depth_transform: Normalizer applied to depth during training.
            augmentation_args: Training augmentation settings (e.g. lr_flip_p).
            resize_to_hw: Optional (H, W) to resize all rasters to in training.
            move_invalid_to_far_plane: Whether invalid pixels in the normalized
                filled depth are pushed to the far plane.
            rgb_transform: NOTE(review): stored but not used — `_load_rgb_data`
                normalizes RGB inline with the same formula; confirm intent.
        """
        super().__init__()
        self.mode = mode
        # dataset info
        self.filename_ls_path = filename_ls_path
        self.dataset_dir = dataset_dir
        assert os.path.exists(
            self.dataset_dir
        ), f"Dataset does not exist at: {self.dataset_dir}"
        self.disp_name = disp_name
        self.has_filled_depth = has_filled_depth
        self.name_mode: DepthFileNameMode = name_mode
        self.min_depth = min_depth
        self.max_depth = max_depth

        # training arguments
        self.depth_transform: DepthNormalizerBase = depth_transform
        self.augm_args = augmentation_args
        self.resize_to_hw = resize_to_hw
        self.rgb_transform = rgb_transform
        self.move_invalid_to_far_plane = move_invalid_to_far_plane

        # Load filenames
        with open(self.filename_ls_path, "r") as f:
            self.filenames = [
                s.split() for s in f.readlines()
            ]  # [['rgb.png', 'depth.tif'], [], ...]

        # Tar dataset: the archive is opened lazily in `_read_image` on first
        # access (self.tar_obj stays None until then).
        self.tar_obj = None
        self.is_tar = (
            True
            if os.path.isfile(dataset_dir) and tarfile.is_tarfile(dataset_dir)
            else False
        )

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        """Return a dict of raster tensors plus index/path metadata."""
        rasters, other = self._get_data_item(index)
        if DatasetMode.TRAIN == self.mode:
            rasters = self._training_preprocess(rasters)
        # merge
        outputs = rasters
        outputs.update(other)
        return outputs

    def _get_data_item(self, index):
        """Load RGB (always) and depth + validity masks (non-RGB_ONLY modes)."""
        rgb_rel_path, depth_rel_path, filled_rel_path = self._get_data_path(index=index)

        rasters = {}

        # RGB data
        rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path))

        # Depth data
        if DatasetMode.RGB_ONLY != self.mode:
            # load data
            depth_data = self._load_depth_data(
                depth_rel_path=depth_rel_path, filled_rel_path=filled_rel_path
            )
            rasters.update(depth_data)
            # valid mask
            rasters["valid_mask_raw"] = self._get_valid_mask(
                rasters["depth_raw_linear"]
            ).clone()
            rasters["valid_mask_filled"] = self._get_valid_mask(
                rasters["depth_filled_linear"]
            ).clone()

        other = {"index": index, "rgb_relative_path": rgb_rel_path}

        return rasters, other

    def _load_rgb_data(self, rgb_rel_path):
        """Read RGB and return both integer [0,255] and normalized [-1,1] tensors."""
        # Read RGB data
        rgb = self._read_rgb_file(rgb_rel_path)
        rgb_norm = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]

        outputs = {
            "rgb_int": torch.from_numpy(rgb).int(),
            "rgb_norm": torch.from_numpy(rgb_norm).float(),
        }
        return outputs

    def _load_depth_data(self, depth_rel_path, filled_rel_path):
        """Read raw (and, if available, filled) depth as [1, H, W] float tensors."""
        # Read depth data
        outputs = {}
        depth_raw = self._read_depth_file(depth_rel_path).squeeze()
        depth_raw_linear = torch.from_numpy(depth_raw).float().unsqueeze(0)  # [1, H, W]
        outputs["depth_raw_linear"] = depth_raw_linear.clone()

        if self.has_filled_depth:
            depth_filled = self._read_depth_file(filled_rel_path).squeeze()
            depth_filled_linear = torch.from_numpy(depth_filled).float().unsqueeze(0)
            outputs["depth_filled_linear"] = depth_filled_linear
        else:
            # No inpainted depth available: fall back to the raw map.
            outputs["depth_filled_linear"] = depth_raw_linear.clone()

        return outputs

    def _get_data_path(self, index):
        """Split the filename-list row into (rgb, depth, filled_depth) rel paths."""
        filename_line = self.filenames[index]

        # Get data path
        rgb_rel_path = filename_line[0]

        depth_rel_path, filled_rel_path = None, None
        if DatasetMode.RGB_ONLY != self.mode:
            depth_rel_path = filename_line[1]
            if self.has_filled_depth:
                filled_rel_path = filename_line[2]
        return rgb_rel_path, depth_rel_path, filled_rel_path

    def _read_image(self, img_rel_path) -> np.ndarray:
        """Read an image from the tar archive or directory as a numpy array."""
        if self.is_tar:
            if self.tar_obj is None:
                self.tar_obj = tarfile.open(self.dataset_dir)
            # "./" prefix matches how members are named inside these archives
            # — presumably created with `tar -cf archive ./...`; verify per tar.
            image_to_read = self.tar_obj.extractfile("./" + img_rel_path)
            image_to_read = image_to_read.read()
            image_to_read = io.BytesIO(image_to_read)
        else:
            image_to_read = os.path.join(self.dataset_dir, img_rel_path)
        image = Image.open(image_to_read)  # [H, W, rgb]
        image = np.asarray(image)
        return image

    def _read_rgb_file(self, rel_path) -> np.ndarray:
        """Read RGB and move channels first."""
        rgb = self._read_image(rel_path)
        rgb = np.transpose(rgb, (2, 0, 1)).astype(int)  # [rgb, H, W]
        return rgb

    def _read_depth_file(self, rel_path):
        """Read a depth file; subclasses override to decode dataset encodings."""
        depth_in = self._read_image(rel_path)
        # Replace code below to decode depth according to dataset definition
        depth_decoded = depth_in

        return depth_decoded

    def _get_valid_mask(self, depth: torch.Tensor):
        """Pixels strictly inside (min_depth, max_depth) are considered valid."""
        valid_mask = torch.logical_and(
            (depth > self.min_depth), (depth < self.max_depth)
        ).bool()
        return valid_mask

    def _training_preprocess(self, rasters):
        """Augment, normalize depth, fill invalid pixels, and resize for training."""
        # Augmentation
        if self.augm_args is not None:
            rasters = self._augment_data(rasters)

        # Normalization
        rasters["depth_raw_norm"] = self.depth_transform(
            rasters["depth_raw_linear"], rasters["valid_mask_raw"]
        ).clone()
        rasters["depth_filled_norm"] = self.depth_transform(
            rasters["depth_filled_linear"], rasters["valid_mask_filled"]
        ).clone()

        # Set invalid pixel to far plane
        if self.move_invalid_to_far_plane:
            # The far plane is at norm_max or norm_min depending on the
            # orientation of the normalizer.
            if self.depth_transform.far_plane_at_max:
                rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = (
                    self.depth_transform.norm_max
                )
            else:
                rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = (
                    self.depth_transform.norm_min
                )

        # Resize
        if self.resize_to_hw is not None:
            # Nearest-exact keeps depth values and masks free of interpolation
            # artifacts.
            resize_transform = Resize(
                size=self.resize_to_hw, interpolation=InterpolationMode.NEAREST_EXACT
            )
            rasters = {k: resize_transform(v) for k, v in rasters.items()}

        return rasters

    def _augment_data(self, rasters_dict):
        """Apply training-time augmentation (currently horizontal flip only)."""
        # lr flipping
        lr_flip_p = self.augm_args.lr_flip_p
        if random.random() < lr_flip_p:
            rasters_dict = {k: v.flip(-1) for k, v in rasters_dict.items()}

        return rasters_dict

    def __del__(self):
        # Close the lazily opened tar handle, if any.
        if hasattr(self, "tar_obj") and self.tar_obj is not None:
            self.tar_obj.close()
            self.tar_obj = None


def get_pred_name(rgb_basename, name_mode, suffix=".png"):
    """Derive the prediction filename for `rgb_basename` under `name_mode`,
    replacing the extension with `suffix`."""
    if DepthFileNameMode.rgb_id == name_mode:
        pred_basename = "pred_" + rgb_basename.split("_")[1]
    elif DepthFileNameMode.i_d_rgb == name_mode:
        pred_basename = rgb_basename.replace("_rgb.", "_pred.")
    elif DepthFileNameMode.id == name_mode:
        pred_basename = "pred_" + rgb_basename
    elif DepthFileNameMode.rgb_i_d == name_mode:
        pred_basename = "pred_" + "_".join(rgb_basename.split("_")[1:])
    else:
        raise NotImplementedError
    # change suffix
    pred_basename = os.path.splitext(pred_basename)[0] + suffix

    return pred_basename
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information about Marigold: +# https://marigoldmonodepth.github.io +# https://marigoldcomputervision.github.io +# Efficient inference pipelines are now part of diffusers: +# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage +# https://huggingface.co/docs/diffusers/api/pipelines/marigold +# Examples of trained models and live demos: +# https://huggingface.co/prs-eth +# Related projects: +# https://rollingdepth.github.io/ +# https://marigolddepthcompletion.github.io/ +# Citation (BibTeX): +# https://github.com/prs-eth/Marigold#-citation +# If you find Marigold useful, we kindly ask you to cite our papers. 
class BaseIIDDataset(Dataset):
    """Base dataset for Intrinsic Image Decomposition (IID).

    Loads an RGB image plus task-specific target rasters (provided by
    subclasses via `_load_targets_data`) either from a plain directory or
    from a tar archive. Images are expected in [0, 1]; HDR inputs are
    converted to sRGB because SD is pretrained in sRGB space.
    """

    def __init__(
        self,
        mode: "DatasetMode",
        filename_ls_path: str,
        dataset_dir: str,
        disp_name: str,
        augmentation_args: dict = None,
        resize_to_hw=None,
        **kwargs,
    ) -> None:
        """
        Args:
            mode: Dataset mode (e.g. TRAIN / RGB_ONLY); controls preprocessing
                and whether targets are loaded.
            filename_ls_path: Text file listing samples, one whitespace-separated
                record per line; the first token is the RGB path.
            dataset_dir: Directory, or tar archive, containing the data.
            disp_name: Display name used for logging/identification.
            augmentation_args: Augmentation settings (training only); None disables.
            resize_to_hw: Optional (H, W) all rasters are resized to during training.
        """
        super().__init__()
        self.mode = mode
        # dataset info
        self.filename_ls_path = filename_ls_path
        self.dataset_dir = dataset_dir
        assert os.path.exists(
            self.dataset_dir
        ), f"Dataset does not exist at: {self.dataset_dir}"
        self.disp_name = disp_name

        # training arguments
        self.augm_args = augmentation_args
        self.resize_to_hw = resize_to_hw

        # Load filenames (iterate the file directly; no need for readlines())
        with open(self.filename_ls_path, "r") as f:
            self.filenames = [s.split() for s in f]

        # Tar dataset. The handle is opened lazily inside the readers so each
        # DataLoader worker gets its own file object.
        self.tar_obj = None
        self.is_tar = os.path.isfile(self.dataset_dir) and tarfile.is_tarfile(
            self.dataset_dir
        )

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        rasters, other = self._get_data_item(index)
        if DatasetMode.TRAIN == self.mode:
            rasters = self._training_preprocess(rasters)
        # merge raster tensors and metadata into one flat dict
        outputs = rasters
        outputs.update(other)
        return outputs

    def _get_data_item(self, index):
        """Load one sample: RGB always; targets unless in RGB_ONLY mode."""
        rgb_rel_path, targets_rel_path = self._get_data_path(index=index)

        rasters = {}

        # RGB data
        rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path))

        # Target loading is specialized in subclass `_load_targets_data`
        if DatasetMode.RGB_ONLY != self.mode:
            targets_data = self._load_targets_data(rel_paths=targets_rel_path)
            rasters.update(targets_data)

        other = {"index": index, "rgb_relative_path": rgb_rel_path}

        return rasters, other

    def _read_image(self, rel_path) -> np.ndarray:
        """Read an image (from tar or directory) as a [0,1]-valued array, CHW for 3D."""
        if self.is_tar:
            if self.tar_obj is None:
                # Lazy open: one handle per worker process
                self.tar_obj = tarfile.open(self.dataset_dir)
            img = read_img_from_tar(self.tar_obj, rel_path)
        else:
            img = read_img_from_file(os.path.join(self.dataset_dir, rel_path))

        if len(img.shape) == 3:  # hwc->chw, except for single-channel images
            img = img_hwc2chw(img)

        assert img.min() >= 0 and img.max() <= 1
        return img

    def _load_rgb_data(self, rgb_rel_path):
        """Return {"rgb": float tensor in [0,1]}; converts HDR to sRGB, drops alpha."""
        rgb = self._read_image(rgb_rel_path)

        # SD is pretrained in sRGB space. If we load HDR data, we should also convert to sRGB space.
        if is_hdr(rgb_rel_path):
            rgb = img_linear2srgb(rgb)

        if rgb.shape[0] == 4:
            rgb = rgb[:3, :, :]

        outputs = {"rgb": torch.from_numpy(rgb).float()}  # [0,1]

        return outputs

    def _read_numpy(self, rel_path):
        """Load a .npy raster (from tar or directory) and return it as [C,H,W]."""
        if self.is_tar:
            if self.tar_obj is None:
                self.tar_obj = tarfile.open(self.dataset_dir)
            image_to_read = self.tar_obj.extractfile("./" + rel_path)
            image_to_read = image_to_read.read()
            image_to_read = io.BytesIO(image_to_read)
        else:
            image_to_read = os.path.join(self.dataset_dir, rel_path)

        image = np.load(image_to_read).transpose((2, 0, 1))  # [3,H,W]
        return image

    def _load_targets_data(self, rel_paths):
        """Hook for subclasses; the base dataset has no targets."""
        outputs = {}
        return outputs

    def _get_data_path(self, index):
        filename_line = self.filenames[index]

        # Only the first is the input image, the rest should be specialized by each dataset.
        rgb_rel_path = filename_line[0]
        targets_rel_path = filename_line[1:]

        return rgb_rel_path, targets_rel_path

    def _training_preprocess(self, rasters):
        # Augmentation
        if self.augm_args is not None:
            rasters = self._augment_data(rasters)

        # Resize: nearest for validity masks (keep them binary), bilinear otherwise
        if self.resize_to_hw is not None:
            resize_bilinear = Resize(
                size=self.resize_to_hw, interpolation=InterpolationMode.BILINEAR
            )
            resize_nearest = Resize(
                size=self.resize_to_hw, interpolation=InterpolationMode.NEAREST_EXACT
            )

            rasters = {
                k: (resize_nearest(v) if "valid_mask" in k else resize_bilinear(v))
                for k, v in rasters.items()
            }

        return rasters

    def _augment_data(self, rasters):
        # horizontal flip (applied consistently to every raster)
        if random.random() < self.augm_args.lr_flip_p:
            rasters = {k: v.flip(-1) for k, v in rasters.items()}
        return rasters

    def __del__(self):
        # Close a lazily-opened tar handle, if any
        if hasattr(self, "tar_obj") and self.tar_obj is not None:
            self.tar_obj.close()
            self.tar_obj = None
class BaseNormalsDataset(Dataset):
    """Base dataset for surface-normals prediction.

    Loads an RGB image plus a per-pixel normals map ([3,H,W] .npy) from a
    directory or a tar archive, and applies training-time augmentation
    (flip, blur, motion blur, color jitter).
    """

    def __init__(
        self,
        mode: "DatasetMode",
        filename_ls_path: str,
        dataset_dir: str,
        disp_name: str,
        augmentation_args: dict = None,
        resize_to_hw=None,
        **kwargs,
    ) -> None:
        """
        Args:
            mode: Dataset mode (e.g. TRAIN / RGB_ONLY).
            filename_ls_path: Text file listing samples; each line is
                "<rgb_path> <normals_path>".
            dataset_dir: Directory, or tar archive, containing the data.
            disp_name: Display name used for logging/identification.
            augmentation_args: Augmentation settings (training only); None disables.
            resize_to_hw: Optional (H, W) all rasters are resized to during training.
        """
        super().__init__()
        self.mode = mode
        # dataset info
        self.filename_ls_path = filename_ls_path
        self.dataset_dir = dataset_dir
        assert os.path.exists(
            self.dataset_dir
        ), f"Dataset does not exist at: {self.dataset_dir}"
        self.disp_name = disp_name

        # training arguments
        self.augm_args = augmentation_args
        self.resize_to_hw = resize_to_hw

        # Load filenames
        with open(self.filename_ls_path, "r") as f:
            self.filenames = [s.split() for s in f]

        # Tar dataset. Do NOT open the archive here: a handle created in the
        # parent process would be shared across forked DataLoader workers and
        # is not picklable. Both readers below open it lazily per process
        # (same scheme as BaseIIDDataset in this file).
        self.tar_obj = None
        self.is_tar = os.path.isfile(dataset_dir) and tarfile.is_tarfile(dataset_dir)

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        rasters, other = self._get_data_item(index)
        if DatasetMode.TRAIN == self.mode:
            rasters = self._training_preprocess(rasters)
        # merge raster tensors and metadata into one flat dict
        outputs = rasters
        outputs.update(other)
        return outputs

    def _get_data_item(self, index):
        """Load one sample: RGB always; normals unless in RGB_ONLY mode."""
        rgb_rel_path, normals_rel_path = self._get_data_path(index=index)

        rasters = {}

        # RGB data
        rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path))

        # Normals data
        if DatasetMode.RGB_ONLY != self.mode:
            normals_data = self._load_normals_data(normals_rel_path=normals_rel_path)
            rasters.update(normals_data)

        other = {"index": index, "rgb_relative_path": rgb_rel_path}

        return rasters, other

    def _load_rgb_data(self, rgb_rel_path):
        """Return integer RGB ("rgb_int", [0,255]) and normalized RGB ("rgb_norm", [-1,1])."""
        rgb = self._read_rgb_file(rgb_rel_path)
        rgb_norm = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]

        outputs = {
            "rgb_int": torch.from_numpy(rgb).int(),
            "rgb_norm": torch.from_numpy(rgb_norm).float(),
        }
        return outputs

    def _load_normals_data(self, normals_rel_path):
        """Return {"normals": float tensor [3,H,W]}."""
        outputs = {}
        normals = torch.from_numpy(
            self._read_normals_file(normals_rel_path)
        ).float()  # [3,H,W]
        outputs["normals"] = normals

        return outputs

    def _get_data_path(self, index):
        filename_line = self.filenames[index]

        # First token: RGB image; second token: normals map
        rgb_rel_path = filename_line[0]
        normals_rel_path = filename_line[1]

        return rgb_rel_path, normals_rel_path

    def _read_image(self, img_rel_path) -> np.ndarray:
        """Read an image (from tar or directory) as an HWC numpy array."""
        if self.is_tar:
            if self.tar_obj is None:
                # Lazy open: one handle per worker process
                self.tar_obj = tarfile.open(self.dataset_dir)
            image_to_read = self.tar_obj.extractfile("./" + img_rel_path)
            image_to_read = image_to_read.read()
            image_to_read = io.BytesIO(image_to_read)
        else:
            image_to_read = os.path.join(self.dataset_dir, img_rel_path)
        image = Image.open(image_to_read)  # [H, W, rgb]
        image = np.asarray(image)
        return image

    def _read_rgb_file(self, rel_path) -> np.ndarray:
        rgb = self._read_image(rel_path)
        rgb = np.transpose(rgb, (2, 0, 1)).astype(int)  # [rgb, H, W]
        return rgb

    def _read_normals_file(self, rel_path):
        """Read a normals .npy stored as [H,W,3]; return [3,H,W]."""
        if self.is_tar:
            if self.tar_obj is None:
                self.tar_obj = tarfile.open(self.dataset_dir)
            normal = self.tar_obj.extractfile("./" + rel_path)
            normal = normal.read()
            normal = np.load(io.BytesIO(normal))  # [H, W, 3]
        else:
            normal_path = os.path.join(self.dataset_dir, rel_path)
            normal = np.load(normal_path)
        normal = np.transpose(normal, (2, 0, 1))  # [3, H, W]
        return normal

    def _training_preprocess(self, rasters):
        # Augmentation
        if self.augm_args is not None:
            rasters = self._augment_data(rasters)

        # Resize
        # NOTE(review): bilinear resize is also applied to the normals map,
        # which does not re-normalize the vectors — confirm downstream
        # consumers tolerate non-unit normals.
        if self.resize_to_hw is not None:
            resize_transform = Resize(
                size=self.resize_to_hw, interpolation=InterpolationMode.BILINEAR
            )
            rasters = {k: resize_transform(v) for k, v in rasters.items()}

        return rasters

    def _augment_data(self, rasters):
        """Apply flip / blur / motion-blur / color-jitter augmentation in place-ish."""
        # horizontal flip (gt normals have to be flipped too: negate the x component)
        if random.random() < self.augm_args.lr_flip_p:
            rasters = {k: v.flip(-1) for k, v in rasters.items()}
            rasters["normals"][0, :, :] *= -1

        # On the main process we can offload the augmentation to the GPU —
        # but only if one is actually present (FIX: the original called
        # .cuda() unconditionally and crashed on CPU-only machines).
        use_gpu = get_worker_info() is None and torch.cuda.is_available()
        if use_gpu:
            rasters = {k: v.cuda() for k, v in rasters.items()}

        # random gaussian blur
        if (
            random.random() < self.augm_args.gaussian_blur_p
            and rasters["rgb_int"].shape[-2] == 768
        ):  # only blur if Hypersim sample
            random_rgb_sigma = random.uniform(0.0, self.augm_args.gaussian_blur_sigma)
            rasters["rgb_int"] = TF.gaussian_blur(
                rasters["rgb_int"], kernel_size=33, sigma=random_rgb_sigma
            ).int()

        # motion blur: convolve with a rotated line kernel
        if (
            random.random() < self.augm_args.motion_blur_p
            and rasters["rgb_int"].shape[-2] == 768
        ):  # only blur if Hypersim sample
            # odd kernel sizes only, so the blur stays centered
            random_kernel_size = random.choice(
                [
                    x
                    for x in range(3, self.augm_args.motion_blur_kernel_size + 1)
                    if x % 2 == 1
                ]
            )
            kernel = torch.zeros(
                random_kernel_size,
                random_kernel_size,
                dtype=rasters["rgb_int"].dtype,
                device=rasters["rgb_int"].device,
            )
            kernel[random_kernel_size // 2, :] = torch.ones(random_kernel_size)
            kernel = TF.rotate(
                kernel.unsqueeze(0),
                random.uniform(0.0, self.augm_args.motion_blur_angle_range),
            )
            kernel = kernel / kernel.sum()  # normalize so brightness is preserved
            channels = rasters["rgb_int"].shape[0]
            kernel = kernel.expand(channels, 1, random_kernel_size, random_kernel_size)
            rasters["rgb_int"] = (
                torch.conv2d(
                    rasters["rgb_int"].unsqueeze(0).float(),
                    kernel,
                    stride=1,
                    padding=random_kernel_size // 2,
                    groups=channels,  # depthwise: blur each channel independently
                )
                .squeeze(0)
                .int()
            )
        # color jitter
        if random.random() < self.augm_args.color_jitter_p:
            color_jitter = ColorJitter(
                brightness=self.augm_args.jitter_brightness_factor,
                contrast=self.augm_args.jitter_contrast_factor,
                saturation=self.augm_args.jitter_saturation_factor,
                hue=self.augm_args.jitter_hue_factor,
            )
            rgb_int_temp = rasters["rgb_int"].float() / 255.0
            rgb_int_temp = color_jitter(rgb_int_temp)
            rasters["rgb_int"] = (rgb_int_temp * 255.0).int()

        # update normalized rgb so it stays consistent with the augmented rgb_int
        rasters["rgb_norm"] = rasters["rgb_int"].float() / 255.0 * 2.0 - 1.0
        return rasters

    def __del__(self):
        # Close a lazily-opened tar handle, if any
        if hasattr(self, "tar_obj") and self.tar_obj is not None:
            self.tar_obj.close()
            self.tar_obj = None
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information about Marigold: +# https://marigoldmonodepth.github.io +# https://marigoldcomputervision.github.io +# Efficient inference pipelines are now part of diffusers: +# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage +# https://huggingface.co/docs/diffusers/api/pipelines/marigold +# Examples of trained models and live demos: +# https://huggingface.co/prs-eth +# Related projects: +# https://rollingdepth.github.io/ +# https://marigolddepthcompletion.github.io/ +# Citation (BibTeX): +# https://github.com/prs-eth/Marigold#-citation +# If you find Marigold useful, we kindly ask you to cite our papers. 
class DIODEDepthDataset(BaseDepthDataset):
    """DIODE depth dataset (indoor + outdoor validation scans).

    Depth and its validity mask are both stored as .npy files; the mask is
    read directly from the data instead of being derived from depth values.
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__(
            # DIODE data parameters
            min_depth=0.6,
            max_depth=350,
            has_filled_depth=False,
            name_mode=DepthFileNameMode.id,
            **kwargs,
        )

    def _read_npy_file(self, rel_path):
        """Load a .npy raster (from tar or directory) as [1,H,W]."""
        if self.is_tar:
            if self.tar_obj is None:
                # Lazy open: one handle per worker process
                self.tar_obj = tarfile.open(self.dataset_dir)
            fileobj = self.tar_obj.extractfile("./" + rel_path)
            npy_path_or_content = BytesIO(fileobj.read())
        else:
            npy_path_or_content = os.path.join(self.dataset_dir, rel_path)
        data = np.load(npy_path_or_content).squeeze()[np.newaxis, :, :]
        return data

    def _read_depth_file(self, rel_path):
        depth = self._read_npy_file(rel_path)
        return depth

    def _get_data_path(self, index):
        # Each record is (rgb_path, depth_path, mask_path)
        return self.filenames[index]

    def _get_data_item(self, index):
        # Special: the depth validity mask is read from the data itself

        rgb_rel_path, depth_rel_path, mask_rel_path = self._get_data_path(index=index)

        rasters = {}

        # RGB data
        rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path))

        # Depth data
        if DatasetMode.RGB_ONLY != self.mode:
            # load data
            depth_data = self._load_depth_data(
                depth_rel_path=depth_rel_path, filled_rel_path=None
            )
            rasters.update(depth_data)

            # valid mask comes from the dataset; use it for both raw and filled
            mask = self._read_npy_file(mask_rel_path).astype(bool)
            mask = torch.from_numpy(mask).bool()
            rasters["valid_mask_raw"] = mask.clone()
            rasters["valid_mask_filled"] = mask.clone()

        other = {"index": index, "rgb_relative_path": rgb_rel_path}

        return rasters, other


class DIODENormalsDataset(BaseNormalsDataset):
    """DIODE surface-normals dataset; inherits all behavior from BaseNormalsDataset.

    (The previous `__getitem__` override only delegated to `super()` and was
    removed as redundant.)
    """

    pass
a/src/dataset/eth3d_dataset.py b/src/dataset/eth3d_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..edd84f5204f7c29e77e7aa45ce5153d68925e97d --- /dev/null +++ b/src/dataset/eth3d_dataset.py @@ -0,0 +1,73 @@ +# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information about Marigold: +# https://marigoldmonodepth.github.io +# https://marigoldcomputervision.github.io +# Efficient inference pipelines are now part of diffusers: +# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage +# https://huggingface.co/docs/diffusers/api/pipelines/marigold +# Examples of trained models and live demos: +# https://huggingface.co/prs-eth +# Related projects: +# https://rollingdepth.github.io/ +# https://marigolddepthcompletion.github.io/ +# Citation (BibTeX): +# https://github.com/prs-eth/Marigold#-citation +# If you find Marigold useful, we kindly ask you to cite our papers. 
class ETH3DDepthDataset(BaseDepthDataset):
    """ETH3D high-resolution multi-view depth benchmark.

    Ground truth is a raw little-endian float32 buffer of fixed resolution;
    missing pixels are encoded as +inf.
    """

    # Fixed resolution of the raw binary depth maps
    HEIGHT, WIDTH = 4032, 6048

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__(
            # ETH3D data parameters
            min_depth=1e-5,
            max_depth=torch.inf,  # no finite upper bound on scene depth
            has_filled_depth=False,
            name_mode=DepthFileNameMode.id,
            **kwargs,
        )

    def _read_depth_file(self, rel_path):
        """Decode ETH3D raw float32 depth into a [H,W] array.

        Format reference:
        https://www.eth3d.net/documentation#format-of-multi-view-data-image-formats
        """
        if self.is_tar:
            if self.tar_obj is None:
                # Lazy open: one handle per worker process
                self.tar_obj = tarfile.open(self.dataset_dir)
            binary_data = self.tar_obj.extractfile("./" + rel_path)
            binary_data = binary_data.read()
        else:
            depth_path = os.path.join(self.dataset_dir, rel_path)
            with open(depth_path, "rb") as file:
                binary_data = file.read()
        # frombuffer yields a read-only view; copy to get a writable array
        depth_decoded = np.frombuffer(binary_data, dtype=np.float32).copy()

        # Missing pixels are stored as +inf; zero them out. Stay in numpy here
        # (the original compared a numpy array against torch.inf).
        depth_decoded[depth_decoded == np.inf] = 0.0

        depth_decoded = depth_decoded.reshape((self.HEIGHT, self.WIDTH))
        return depth_decoded
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information about Marigold: +# https://marigoldmonodepth.github.io +# https://marigoldcomputervision.github.io +# Efficient inference pipelines are now part of diffusers: +# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage +# https://huggingface.co/docs/diffusers/api/pipelines/marigold +# Examples of trained models and live demos: +# https://huggingface.co/prs-eth +# Related projects: +# https://rollingdepth.github.io/ +# https://marigolddepthcompletion.github.io/ +# Citation (BibTeX): +# https://github.com/prs-eth/Marigold#-citation +# If you find Marigold useful, we kindly ask you to cite our papers. 
class HypersimDepthDataset(BaseDepthDataset):
    """Hypersim synthetic indoor depth dataset."""

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__(
            # Hypersim data parameters
            min_depth=1e-5,
            max_depth=65.0,
            has_filled_depth=False,
            name_mode=DepthFileNameMode.rgb_i_d,
            **kwargs,
        )

    def _read_depth_file(self, rel_path):
        depth_in = self._read_image(rel_path)
        # Decode Hypersim depth: stored value / 1000 -> meters
        depth_decoded = depth_in / 1000.0
        return depth_decoded


class HypersimNormalsDataset(BaseNormalsDataset):
    """Hypersim surface-normals dataset; inherits all behavior from BaseNormalsDataset."""

    pass


class HypersimIIDDataset(BaseIIDDataset):
    """Hypersim intrinsic decomposition dataset (albedo / shading / residual)."""

    def _load_targets_data(self, rel_paths):
        """Load albedo, shading, residual rasters and their validity masks.

        Args:
            rel_paths: (albedo_path, shading_path, residual_path), e.g.
                albedo_cam_00_fr0000.npy / shading_... / residual_...

        Returns:
            dict with keys albedo, shading_raw/residual_raw, shading/residual
            (clipped + normalized to a shared scale) and mask_* boolean masks.
        """
        albedo_path = rel_paths[0]  # albedo_cam_00_fr0000.npy
        shading_path = rel_paths[1]  # shading_cam_00_fr0000.npy
        residual_path = rel_paths[2]  # residual_cam_00_fr0000.npy

        albedo_raw = self._read_numpy(albedo_path)  # in linear space
        shading_raw = self._read_numpy(shading_path)
        residual_raw = self._read_numpy(residual_path)

        rasters = {
            "albedo": torch.from_numpy(albedo_raw).float(),  # [0,1] linear space
            "shading_raw": torch.from_numpy(shading_raw),
            "residual_raw": torch.from_numpy(residual_raw),
        }

        # get the cut off value for shading/residual
        cut_off_value = self._get_cut_off_value(rasters)
        # clip and normalize shading and residual (to bring them to the same scale and value range)
        rasters = self._process_shading_residual(rasters, cut_off_value)

        # Load masks
        valid_mask_albedo, valid_mask_shading, valid_mask_residual = (
            self._get_valid_masks(rasters)
        )
        rasters.update(
            {
                "mask_albedo": valid_mask_albedo.bool(),
                "mask_shading": valid_mask_shading.bool(),
                "mask_residual": valid_mask_residual.bool(),
            }
        )
        return rasters

    def _process_shading_residual(self, rasters, cut_off_value):
        """Clip shading/residual at cut_off_value and normalize both to [0,1]."""
        # NOTE(review): cut_off_value of 0 (all-black shading AND residual)
        # would make these divisions produce NaN — confirm inputs exclude
        # such frames.
        shading_clipped = torch.clip(rasters["shading_raw"], 0, cut_off_value)
        residual_clipped = torch.clip(rasters["residual_raw"], 0, cut_off_value)
        # Divide by the same cut off value to bring them to the same scale
        shading_norm = shading_clipped / cut_off_value  # [0,1]
        residual_norm = residual_clipped / cut_off_value  # [0,1]

        rasters.update(
            {
                "shading": shading_norm.float(),
                "residual": residual_norm.float(),
            }
        )
        return rasters

    def _get_cut_off_value(self, rasters):
        """Shared clipping threshold: max of the 98th percentiles of shading and residual."""
        shading_raw = rasters["shading_raw"]
        residual_raw = rasters["residual_raw"]

        residual_98 = torch.quantile(residual_raw, 0.98)
        shading_98 = torch.quantile(shading_raw, 0.98)
        # FIX: torch.maximum on the two 0-dim tensors instead of building a
        # new tensor from tensors (which copies and triggers a warning).
        cut_off_value = torch.maximum(residual_98, shading_98)

        return cut_off_value

    def _get_valid_masks(self, rasters):
        """Boolean validity masks: finite values; albedo additionally excludes all-zero pixels."""
        albedo_gt_ts = rasters["albedo"]  # [3,H,W]
        invalid_mask_albedo = torch.isnan(albedo_gt_ts) | torch.isinf(albedo_gt_ts)
        # A pixel that is zero in all channels carries no albedo information
        zero_mask = (albedo_gt_ts == 0).all(dim=0, keepdim=True)
        zero_mask = zero_mask.expand_as(albedo_gt_ts)
        invalid_mask_albedo |= zero_mask
        valid_mask_albedo = ~invalid_mask_albedo

        shading_gt_ts = rasters["shading"]
        invalid_mask_shading = torch.isnan(shading_gt_ts) | torch.isinf(shading_gt_ts)
        valid_mask_shading = ~invalid_mask_shading

        residual_gt_ts = rasters["residual"]
        invalid_mask_residual = torch.isnan(residual_gt_ts) | torch.isinf(
            residual_gt_ts
        )
        valid_mask_residual = ~invalid_mask_residual

        return valid_mask_albedo, valid_mask_shading, valid_mask_residual
class IBimsNormalsDataset(BaseNormalsDataset):
    """iBims-1 surface-normals dataset; inherits all behavior from BaseNormalsDataset."""

    pass
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information about Marigold: +# https://marigoldmonodepth.github.io +# https://marigoldcomputervision.github.io +# Efficient inference pipelines are now part of diffusers: +# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage +# https://huggingface.co/docs/diffusers/api/pipelines/marigold +# Examples of trained models and live demos: +# https://huggingface.co/prs-eth +# Related projects: +# https://rollingdepth.github.io/ +# https://marigolddepthcompletion.github.io/ +# Citation (BibTeX): +# https://github.com/prs-eth/Marigold#-citation +# If you find Marigold useful, we kindly ask you to cite our papers. 
class InteriorVerseNormalsDataset(BaseNormalsDataset):
    """InteriorVerse surface-normals dataset; inherits all behavior from BaseNormalsDataset."""

    pass


# https://github.com/jingsenzhu/IndoorInverseRendering/tree/main/interiorverse
class InteriorVerseIIDDataset(BaseIIDDataset):
    """InteriorVerse intrinsic decomposition dataset (albedo + optional validity mask)."""

    def _load_targets_data(self, rel_paths):
        """Load the albedo target and, when present, its validity mask.

        Args:
            rel_paths: [albedo_path] or [albedo_path, mask_path]
                (e.g. 000_albedo.exr, 000_mask.png).

        Returns:
            dict with "albedo" (float tensor), plus "mask" [1,H,W] when a mask
            path is given, and "mask_albedo" [3,H,W] in EVAL mode.
        """
        albedo_path = rel_paths[0]  # 000_albedo.exr
        albedo_img = self._read_image(albedo_path)

        has_mask = len(rel_paths) == 2
        mask_img = None
        mask_img_squeezed = None

        if has_mask:
            mask_path = rel_paths[1]  # 000_mask.png
            mask_img = self._read_image(mask_path)
            mask_img = mask_img != 0  # Convert to a boolean np array
            if mask_img.ndim == 3:
                # Collapse 3 channels to 1: valid only where all channels are set
                mask_img_squeezed = np.expand_dims(np.all(mask_img, axis=0), axis=0)
            elif mask_img.ndim == 2:
                mask_img_squeezed = np.expand_dims(mask_img, axis=0)
                mask_img = np.stack([mask_img] * 3, axis=0)  # (3, H, W)
            else:
                # FIX: previously an unexpected rank left the mask as None and
                # crashed later in torch.from_numpy with an opaque TypeError.
                raise ValueError(f"Unexpected mask shape: {mask_img.shape}")

        # SD is pretrained in sRGB space. If we load HDR data, we should also convert to sRGB space.
        if is_hdr(albedo_path):
            albedo_img = img_linear2srgb(albedo_img)

        # Drop the alpha channel, if present
        if albedo_img.shape[0] == 4:
            albedo_img = albedo_img[:3, :, :]

        outputs = {"albedo": torch.from_numpy(albedo_img)}
        if has_mask:
            outputs["mask"] = torch.from_numpy(mask_img_squeezed)
            # add three channel mask for evaluation
            if self.mode == DatasetMode.EVAL:
                outputs["mask_albedo"] = torch.from_numpy(mask_img).bool()

        return outputs
class KITTIDepthDataset(BaseDepthDataset):
    """KITTI depth dataset (Eigen split) with benchmark cropping and evaluation masks."""

    def __init__(
        self,
        kitti_bm_crop,  # Crop to KITTI benchmark size
        valid_mask_crop,  # Evaluation mask. [None, garg or eigen]
        **kwargs,
    ) -> None:
        super().__init__(
            # KITTI data parameters
            min_depth=1e-5,
            max_depth=80,
            has_filled_depth=False,
            name_mode=DepthFileNameMode.id,
            **kwargs,
        )
        self.kitti_bm_crop = kitti_bm_crop
        self.valid_mask_crop = valid_mask_crop
        assert self.valid_mask_crop in [
            None,
            "garg",  # set evaluation mask according to Garg ECCV16
            "eigen",  # set evaluation mask according to Eigen NIPS14
        ], f"Unknown crop type: {self.valid_mask_crop}"

        # Filter out samples without ground truth (marked "None" in the list)
        self.filenames = [f for f in self.filenames if "None" != f[1]]

    def _read_depth_file(self, rel_path):
        """Decode KITTI 16-bit PNG depth: stored value / 256 -> meters."""
        depth_in = self._read_image(rel_path)
        depth_decoded = depth_in / 256.0
        return depth_decoded

    def _load_rgb_data(self, rgb_rel_path):
        rgb_data = super()._load_rgb_data(rgb_rel_path)
        if self.kitti_bm_crop:
            rgb_data = {k: self.kitti_benchmark_crop(v) for k, v in rgb_data.items()}
        return rgb_data

    def _load_depth_data(self, depth_rel_path, filled_rel_path):
        depth_data = super()._load_depth_data(depth_rel_path, filled_rel_path)
        if self.kitti_bm_crop:
            depth_data = {
                k: self.kitti_benchmark_crop(v) for k, v in depth_data.items()
            }
        return depth_data

    @staticmethod
    def kitti_benchmark_crop(input_img):
        """
        Crop images to KITTI benchmark size (352 x 1216, bottom-centered)
        Args:
            `input_img` (torch.Tensor): [H, W] or [C, H, W] image to be cropped.

        Returns:
            torch.Tensor: Cropped image.
        """
        KB_CROP_HEIGHT = 352
        KB_CROP_WIDTH = 1216

        height, width = input_img.shape[-2:]
        top_margin = int(height - KB_CROP_HEIGHT)  # keep the bottom rows
        left_margin = int((width - KB_CROP_WIDTH) / 2)  # center horizontally
        if 2 == len(input_img.shape):
            out = input_img[
                top_margin : top_margin + KB_CROP_HEIGHT,
                left_margin : left_margin + KB_CROP_WIDTH,
            ]
        elif 3 == len(input_img.shape):
            out = input_img[
                :,
                top_margin : top_margin + KB_CROP_HEIGHT,
                left_margin : left_margin + KB_CROP_WIDTH,
            ]
        else:
            # FIX: previously `out` was unbound here, raising an opaque
            # UnboundLocalError for unsupported ranks.
            raise ValueError(f"Unsupported tensor shape: {tuple(input_img.shape)}")
        return out

    def _get_valid_mask(self, depth: torch.Tensor):
        # reference: https://github.com/cleinc/bts/blob/master/pytorch/bts_eval.py
        valid_mask = super()._get_valid_mask(depth)  # [1, H, W]

        if self.valid_mask_crop is not None:
            eval_mask = torch.zeros_like(valid_mask.squeeze()).bool()
            gt_height, gt_width = eval_mask.shape

            if "garg" == self.valid_mask_crop:
                eval_mask[
                    int(0.40810811 * gt_height) : int(0.99189189 * gt_height),
                    int(0.03594771 * gt_width) : int(0.96405229 * gt_width),
                ] = 1
            elif "eigen" == self.valid_mask_crop:
                eval_mask[
                    int(0.3324324 * gt_height) : int(0.91351351 * gt_height),
                    int(0.0359477 * gt_width) : int(0.96405229 * gt_width),
                ] = 1

            # FIX: Tensor.reshape is out-of-place; the original discarded the
            # result, leaving eval_mask [H, W] (broadcasting in logical_and
            # masked the bug). Assign the reshaped tensor explicitly.
            eval_mask = eval_mask.reshape(valid_mask.shape)
            valid_mask = torch.logical_and(valid_mask, eval_mask)
        return valid_mask
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# More information about Marigold:
# https://marigoldmonodepth.github.io
# https://marigoldcomputervision.github.io
# Efficient inference pipelines are now part of diffusers:
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
# Examples of trained models and live demos:
# https://huggingface.co/prs-eth
# Related projects:
# https://rollingdepth.github.io/
# https://marigolddepthcompletion.github.io/
# Citation (BibTeX):
# https://github.com/prs-eth/Marigold#-citation
# If you find Marigold useful, we kindly ask you to cite our papers.
# --------------------------------------------------------------------------

import torch
from torch.utils.data import (
    BatchSampler,
    RandomSampler,
    SequentialSampler,
)


class MixedBatchSampler(BatchSampler):
    """Sample each batch from one source dataset, chosen with given probability.

    Every yielded batch contains indices from a single source dataset, which
    keeps the sampler compatible with datasets at different resolutions.
    Indices are shifted to address the ConcatDataset of `src_dataset_ls`.
    """

    def __init__(
        self, src_dataset_ls, batch_size, drop_last, shuffle, prob=None, generator=None
    ):
        self.base_sampler = None
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.generator = generator

        self.src_dataset_ls = src_dataset_ls
        self.n_dataset = len(self.src_dataset_ls)

        # Dataset lengths and cumulative offsets into the concatenated dataset
        self.dataset_length = [len(ds) for ds in self.src_dataset_ls]
        self.cum_dataset_length = [
            sum(self.dataset_length[:i]) for i in range(self.n_dataset)
        ]  # cumulative dataset length

        # BatchSamplers for each source dataset
        if self.shuffle:
            self.src_batch_samplers = [
                BatchSampler(
                    sampler=RandomSampler(
                        ds, replacement=False, generator=self.generator
                    ),
                    batch_size=self.batch_size,
                    drop_last=self.drop_last,
                )
                for ds in self.src_dataset_ls
            ]
        else:
            self.src_batch_samplers = [
                BatchSampler(
                    sampler=SequentialSampler(ds),
                    batch_size=self.batch_size,
                    drop_last=self.drop_last,
                )
                for ds in self.src_dataset_ls
            ]
        self.raw_batches = [
            list(bs) for bs in self.src_batch_samplers
        ]  # index in original dataset
        self.n_batches = [len(b) for b in self.raw_batches]
        self.n_total_batch = sum(self.n_batches)

        # sampling probability
        if prob is None:
            # if not given, weight by the number of batches per dataset
            self.prob = torch.tensor(self.n_batches) / self.n_total_batch
        else:
            # Fix: validate user-supplied probabilities (a silent length
            # mismatch would sample from the wrong datasets or crash later),
            # and force a float tensor — `torch.multinomial` rejects the
            # integer tensor that `torch.as_tensor([1, 1])` would produce.
            if len(prob) != self.n_dataset:
                raise ValueError(
                    f"Expected {self.n_dataset} probabilities, got {len(prob)}"
                )
            self.prob = torch.as_tensor(prob, dtype=torch.float)

    def __iter__(self):
        """Yield batches of indices into the ConcatDataset of `src_dataset_ls`.

        Yields:
            list(int): a batch of indices, all drawn from one source dataset,
                shifted by that dataset's cumulative offset.
        """
        for _ in range(self.n_total_batch):
            idx_ds = torch.multinomial(
                self.prob, 1, replacement=True, generator=self.generator
            ).item()
            # if this dataset's batch list is exhausted, regenerate it
            if 0 == len(self.raw_batches[idx_ds]):
                self.raw_batches[idx_ds] = list(self.src_batch_samplers[idx_ds])
            # take one batch and shift it by the cumulative dataset length
            batch_raw = self.raw_batches[idx_ds].pop()
            shift = self.cum_dataset_length[idx_ds]
            batch = [n + shift for n in batch_raw]

            yield batch

    def __len__(self):
        return self.n_total_batch


# Unit test
if "__main__" == __name__:
    from torch.utils.data import ConcatDataset, DataLoader, Dataset

    class SimpleDataset(Dataset):
        def __init__(self, start, len) -> None:
            super().__init__()
            self.start = start
            self.len = len

        def __len__(self):
            return self.len

        def __getitem__(self, index):
            return self.start + index

    dataset_1 = SimpleDataset(0, 10)
    dataset_2 = SimpleDataset(200, 20)
    dataset_3 = SimpleDataset(1000, 50)

    concat_dataset = ConcatDataset(
        [dataset_1, dataset_2, dataset_3]
    )  # will directly concatenate

    mixed_sampler = MixedBatchSampler(
        src_dataset_ls=[dataset_1, dataset_2, dataset_3],
        batch_size=4,
        drop_last=True,
        shuffle=False,
        prob=[0.6, 0.3, 0.1],
        generator=torch.Generator().manual_seed(0),
    )

    loader = DataLoader(concat_dataset, batch_sampler=mixed_sampler)

    for d in loader:
        print(d)
+# -------------------------------------------------------------------------- +# More information about Marigold: +# https://marigoldmonodepth.github.io +# https://marigoldcomputervision.github.io +# Efficient inference pipelines are now part of diffusers: +# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage +# https://huggingface.co/docs/diffusers/api/pipelines/marigold +# Examples of trained models and live demos: +# https://huggingface.co/prs-eth +# Related projects: +# https://rollingdepth.github.io/ +# https://marigolddepthcompletion.github.io/ +# Citation (BibTeX): +# https://github.com/prs-eth/Marigold#-citation +# If you find Marigold useful, we kindly ask you to cite our papers. +# -------------------------------------------------------------------------- + +import torch + +from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode +from .base_normals_dataset import BaseNormalsDataset + + +class NYUDepthDataset(BaseDepthDataset): + def __init__( + self, + eigen_valid_mask: bool, + **kwargs, + ) -> None: + super().__init__( + # NYUv2 dataset parameter + min_depth=1e-3, + max_depth=10.0, + has_filled_depth=True, + name_mode=DepthFileNameMode.rgb_id, + **kwargs, + ) + + self.eigen_valid_mask = eigen_valid_mask + + def _read_depth_file(self, rel_path): + depth_in = self._read_image(rel_path) + # Decode NYU depth + depth_decoded = depth_in / 1000.0 + return depth_decoded + + def _get_valid_mask(self, depth: torch.Tensor): + valid_mask = super()._get_valid_mask(depth) + + # Eigen crop for evaluation + if self.eigen_valid_mask: + eval_mask = torch.zeros_like(valid_mask.squeeze()).bool() + eval_mask[45:471, 41:601] = 1 + eval_mask.reshape(valid_mask.shape) + valid_mask = torch.logical_and(valid_mask, eval_mask) + + return valid_mask + + +class NYUNormalsDataset(BaseNormalsDataset): + pass diff --git a/src/dataset/oasis_dataset.py b/src/dataset/oasis_dataset.py new file mode 100644 index 
0000000000000000000000000000000000000000..29476e70b5110663d1fc1468532b271676f87493 --- /dev/null +++ b/src/dataset/oasis_dataset.py @@ -0,0 +1,35 @@ +# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information about Marigold: +# https://marigoldmonodepth.github.io +# https://marigoldcomputervision.github.io +# Efficient inference pipelines are now part of diffusers: +# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage +# https://huggingface.co/docs/diffusers/api/pipelines/marigold +# Examples of trained models and live demos: +# https://huggingface.co/prs-eth +# Related projects: +# https://rollingdepth.github.io/ +# https://marigolddepthcompletion.github.io/ +# Citation (BibTeX): +# https://github.com/prs-eth/Marigold#-citation +# If you find Marigold useful, we kindly ask you to cite our papers. +# -------------------------------------------------------------------------- + +from .base_normals_dataset import BaseNormalsDataset + + +class OasisNormalsDataset(BaseNormalsDataset): + pass diff --git a/src/dataset/scannet_dataset.py b/src/dataset/scannet_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ba8ed94cc6ec1d8b46aec3522d241979fe38f6b3 --- /dev/null +++ b/src/dataset/scannet_dataset.py @@ -0,0 +1,57 @@ +# Copyright 2023-2025 Marigold Team, ETH Zürich. 
# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# More information about Marigold:
# https://marigoldmonodepth.github.io
# https://marigoldcomputervision.github.io
# Efficient inference pipelines are now part of diffusers:
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
# Examples of trained models and live demos:
# https://huggingface.co/prs-eth
# Related projects:
# https://rollingdepth.github.io/
# https://marigolddepthcompletion.github.io/
# Citation (BibTeX):
# https://github.com/prs-eth/Marigold#-citation
# If you find Marigold useful, we kindly ask you to cite our papers.
# --------------------------------------------------------------------------

from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode
from .base_normals_dataset import BaseNormalsDataset


class ScanNetDepthDataset(BaseDepthDataset):
    """ScanNet depth dataset; depth PNGs store millimeters (value / 1000 = meters)."""

    # ScanNet stores depth * 1000 in 16-bit PNGs
    _DEPTH_SCALE = 1000.0

    def __init__(self, **kwargs) -> None:
        super().__init__(
            # ScanNet data parameter
            min_depth=1e-3,
            max_depth=10,
            has_filled_depth=False,
            name_mode=DepthFileNameMode.id,
            **kwargs,
        )

    def _read_depth_file(self, rel_path):
        """Read a ScanNet depth PNG and convert raw values to meters."""
        raw_depth = self._read_image(rel_path)
        return raw_depth / self._DEPTH_SCALE


class ScanNetNormalsDataset(BaseNormalsDataset):
    pass
# --------------------------------------------------------------------------
# More information about Marigold:
# https://marigoldmonodepth.github.io
# https://marigoldcomputervision.github.io
# Efficient inference pipelines are now part of diffusers:
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
# Examples of trained models and live demos:
# https://huggingface.co/prs-eth
# Related projects:
# https://rollingdepth.github.io/
# https://marigolddepthcompletion.github.io/
# Citation (BibTeX):
# https://github.com/prs-eth/Marigold#-citation
# If you find Marigold useful, we kindly ask you to cite our papers.
# --------------------------------------------------------------------------

import torch

from .base_normals_dataset import BaseNormalsDataset

# Sintel original resolution
H, W = 436, 1024


# crop to [436,582] --> later upsample with factor 1.1 to [480,640]
# crop off 221 pixels on both sides (1024 - 2*221 = 582)
def center_crop(img):
    """Center-crop a CHW image from width 1024 to 582 (221 px off each side).

    Args:
        img: array/tensor of shape [C, H, W] with C in {1, 3} and W == 1024.

    Returns:
        The cropped view of shape [C, H, 582].
    """
    assert img.shape[0] == 3 or img.shape[0] == 1, "Channel dim should be first dim"
    # Fix: the crop hard-codes the Sintel width via the module constant W;
    # guard against silently mis-cropping input of a different resolution.
    assert img.shape[-1] == W, f"Expected width {W}, got {img.shape[-1]}"
    crop = 221

    out = img[:, :, crop : W - crop]  # [C, H, W-2*crop]

    return out  # [C, 436, 582]


class SintelNormalsDataset(BaseNormalsDataset):
    """Sintel surface-normals dataset with center cropping and sky in-filling."""

    def _load_rgb_data(self, rgb_rel_path):
        """Read, crop, and normalize RGB to both int [0, 255] and float [-1, 1]."""
        # Read RGB data
        rgb = self._read_rgb_file(rgb_rel_path)
        rgb = center_crop(rgb)
        rgb_norm = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]

        outputs = {
            "rgb_int": torch.from_numpy(rgb).int(),
            "rgb_norm": torch.from_numpy(rgb_norm).float(),
        }
        return outputs

    def _load_normals_data(self, normals_rel_path):
        """Read normals, replace invalid (sky) vectors, and center-crop."""
        outputs = {}
        normals = torch.from_numpy(
            self._read_normals_file(normals_rel_path)
        ).float()  # [3,H,W]

        # replace invalid sky values with camera facing normals
        valid_normal_mask = torch.norm(normals, p=2, dim=0) > 0.1
        normals[:, ~valid_normal_mask] = torch.tensor(
            [0.0, 0.0, 1.0], dtype=normals.dtype
        ).view(3, 1)
        # crop on both sides
        outputs["normals"] = center_crop(normals)

        return outputs
# --------------------------------------------------------------------------
# More information about Marigold:
# https://marigoldmonodepth.github.io
# https://marigoldcomputervision.github.io
# Efficient inference pipelines are now part of diffusers:
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
# Examples of trained models and live demos:
# https://huggingface.co/prs-eth
# Related projects:
# https://rollingdepth.github.io/
# https://marigolddepthcompletion.github.io/
# Citation (BibTeX):
# https://github.com/prs-eth/Marigold#-citation
# If you find Marigold useful, we kindly ask you to cite our papers.
# --------------------------------------------------------------------------

import torch

from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode
from .kitti_dataset import KITTIDepthDataset


class VirtualKITTIDepthDataset(BaseDepthDataset):
    """Virtual KITTI depth dataset; reuses the real-KITTI benchmark crop."""

    def __init__(
        self,
        kitti_bm_crop,  # Crop to KITTI benchmark size
        valid_mask_crop,  # Evaluation mask. [None, garg or eigen]
        **kwargs,
    ) -> None:
        super().__init__(
            # virtual KITTI data parameter
            min_depth=1e-5,
            max_depth=80,  # 655.35
            has_filled_depth=False,
            name_mode=DepthFileNameMode.id,
            **kwargs,
        )
        self.kitti_bm_crop = kitti_bm_crop
        self.valid_mask_crop = valid_mask_crop
        assert self.valid_mask_crop in [
            None,
            "garg",  # set evaluation mask according to Garg ECCV16
            "eigen",  # set evaluation mask according to Eigen NIPS14
        ], f"Unknown crop type: {self.valid_mask_crop}"

        # Filter out pairs without ground-truth depth (marked "None" in the split file)
        self.filenames = [f for f in self.filenames if "None" != f[1]]

    def _read_depth_file(self, rel_path):
        """Read a vKITTI depth map and decode raw values to meters."""
        depth_in = self._read_image(rel_path)
        # Decode vKITTI depth: stored as depth * 100 (centimeters)
        depth_decoded = depth_in / 100.0
        return depth_decoded

    def _load_rgb_data(self, rgb_rel_path):
        """Load RGB tensors, applying the KITTI benchmark crop when enabled."""
        rgb_data = super()._load_rgb_data(rgb_rel_path)
        if self.kitti_bm_crop:
            rgb_data = {
                k: KITTIDepthDataset.kitti_benchmark_crop(v)
                for k, v in rgb_data.items()
            }
        return rgb_data

    def _load_depth_data(self, depth_rel_path, filled_rel_path):
        """Load depth tensors, applying the KITTI benchmark crop when enabled."""
        depth_data = super()._load_depth_data(depth_rel_path, filled_rel_path)
        if self.kitti_bm_crop:
            depth_data = {
                k: KITTIDepthDataset.kitti_benchmark_crop(v)
                for k, v in depth_data.items()
            }
        return depth_data

    def _get_valid_mask(self, depth: torch.Tensor):
        """Intersect the base validity mask with the configured evaluation crop."""
        # reference: https://github.com/cleinc/bts/blob/master/pytorch/bts_eval.py
        valid_mask = super()._get_valid_mask(depth)  # [1, H, W]

        if self.valid_mask_crop is not None:
            eval_mask = torch.zeros_like(valid_mask.squeeze()).bool()
            gt_height, gt_width = eval_mask.shape

            if "garg" == self.valid_mask_crop:
                eval_mask[
                    int(0.40810811 * gt_height) : int(0.99189189 * gt_height),
                    int(0.03594771 * gt_width) : int(0.96405229 * gt_width),
                ] = 1
            elif "eigen" == self.valid_mask_crop:
                eval_mask[
                    int(0.3324324 * gt_height) : int(0.91351351 * gt_height),
                    int(0.0359477 * gt_width) : int(0.96405229 * gt_width),
                ] = 1

            # Fix: `reshape` is not in-place; the original discarded its
            # result (a no-op statement, duplicated from kitti_dataset.py).
            eval_mask = eval_mask.reshape(valid_mask.shape)
            valid_mask = torch.logical_and(valid_mask, eval_mask)
        return valid_mask
+# -------------------------------------------------------------------------- +# More information about Marigold: +# https://marigoldmonodepth.github.io +# https://marigoldcomputervision.github.io +# Efficient inference pipelines are now part of diffusers: +# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage +# https://huggingface.co/docs/diffusers/api/pipelines/marigold +# Examples of trained models and live demos: +# https://huggingface.co/prs-eth +# Related projects: +# https://rollingdepth.github.io/ +# https://marigolddepthcompletion.github.io/ +# Citation (BibTeX): +# https://github.com/prs-eth/Marigold#-citation +# If you find Marigold useful, we kindly ask you to cite our papers. +# -------------------------------------------------------------------------- + +from .marigold_depth_trainer import MarigoldDepthTrainer +from .marigold_iid_trainer import MarigoldIIDTrainer +from .marigold_normals_trainer import MarigoldNormalsTrainer + + +trainer_cls_name_dict = { + "MarigoldDepthTrainer": MarigoldDepthTrainer, + "MarigoldIIDTrainer": MarigoldIIDTrainer, + "MarigoldNormalsTrainer": MarigoldNormalsTrainer, +} + + +def get_trainer_cls(trainer_name): + return trainer_cls_name_dict[trainer_name] diff --git a/src/trainer/marigold_depth_trainer.py b/src/trainer/marigold_depth_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..749f78dfcaa8b96cde946410c71ccd2abad20852 --- /dev/null +++ b/src/trainer/marigold_depth_trainer.py @@ -0,0 +1,699 @@ +# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# More information about Marigold:
# https://marigoldmonodepth.github.io
# https://marigoldcomputervision.github.io
# Efficient inference pipelines are now part of diffusers:
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
# Examples of trained models and live demos:
# https://huggingface.co/prs-eth
# Related projects:
# https://rollingdepth.github.io/
# https://marigolddepthcompletion.github.io/
# Citation (BibTeX):
# https://github.com/prs-eth/Marigold#-citation
# If you find Marigold useful, we kindly ask you to cite our papers.
# --------------------------------------------------------------------------

import logging
import numpy as np
import os
import shutil
import torch
from PIL import Image
from datetime import datetime
from diffusers import DDPMScheduler, DDIMScheduler
from omegaconf import OmegaConf
from torch.nn import Conv2d
from torch.nn.parameter import Parameter
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import List, Union

from marigold.marigold_depth_pipeline import MarigoldDepthPipeline, MarigoldDepthOutput
from src.util import metric
from src.util.alignment import align_depth_least_square
from src.util.data_loader import skip_first_batches
from src.util.logging_util import tb_logger, eval_dict_to_text
from src.util.loss import get_loss
from src.util.lr_scheduler import IterExponential
from src.util.metric import MetricTracker
from src.util.multi_res_noise import multi_res_noise_like
from src.util.seeding import generate_seed_sequence


class MarigoldDepthTrainer:
    """Training loop for the Marigold depth diffusion model.

    Adapts the UNet input layer to 8 channels (RGB + depth latents), trains
    with gradient accumulation, and periodically validates, visualizes, and
    checkpoints.
    """

    def __init__(
        self,
        cfg: OmegaConf,
        model: MarigoldDepthPipeline,
        train_dataloader: DataLoader,
        device,
        out_dir_ckpt,
        out_dir_eval,
        out_dir_vis,
        accumulation_steps: int,
        val_dataloaders: List[DataLoader] = None,
        vis_dataloaders: List[DataLoader] = None,
    ):
        self.cfg: OmegaConf = cfg
        self.model: MarigoldDepthPipeline = model
        self.device = device
        self.seed: Union[int, None] = (
            self.cfg.trainer.init_seed
        )  # used to generate seed sequence, set to `None` to train w/o seeding
        self.out_dir_ckpt = out_dir_ckpt
        self.out_dir_eval = out_dir_eval
        self.out_dir_vis = out_dir_vis
        self.train_loader: DataLoader = train_dataloader
        self.val_loaders: List[DataLoader] = val_dataloaders
        self.vis_loaders: List[DataLoader] = vis_dataloaders
        self.accumulation_steps: int = accumulation_steps

        # Adapt input layers
        if 8 != self.model.unet.config["in_channels"]:
            self._replace_unet_conv_in()

        # Encode empty text prompt
        self.model.encode_empty_text()
        self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device)

        self.model.unet.enable_xformers_memory_efficient_attention()

        # Trainability
        self.model.vae.requires_grad_(False)
        self.model.text_encoder.requires_grad_(False)
        self.model.unet.requires_grad_(True)

        # Optimizer !should be defined after input layer is adapted
        lr = self.cfg.lr
        self.optimizer = Adam(self.model.unet.parameters(), lr=lr)

        # LR scheduler
        lr_func = IterExponential(
            total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter,
            final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio,
            warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps,
        )
        self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func)

        # Loss
        self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs)

        # Training noise scheduler
        self.training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_config(
            self.model.scheduler.config,
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
        )

        logging.info(
            "DDPM training noise scheduler config is updated: "
            f"rescale_betas_zero_snr = {self.training_noise_scheduler.config.rescale_betas_zero_snr}, "
            f"timestep_spacing = {self.training_noise_scheduler.config.timestep_spacing}"
        )

        self.prediction_type = self.training_noise_scheduler.config.prediction_type
        assert (
            self.prediction_type == self.model.scheduler.config.prediction_type
        ), "Different prediction types"
        self.scheduler_timesteps = (
            self.training_noise_scheduler.config.num_train_timesteps
        )

        # Inference DDIM scheduler (used for validation)
        self.model.scheduler = DDIMScheduler.from_config(
            self.training_noise_scheduler.config,
        )

        # Eval metrics
        self.metric_funcs = [getattr(metric, _met) for _met in cfg.eval.eval_metrics]

        self.train_metrics = MetricTracker(*["loss"])
        self.val_metrics = MetricTracker(*[m.__name__ for m in self.metric_funcs])

        # main metric for best checkpoint saving
        self.main_val_metric = cfg.validation.main_val_metric
        self.main_val_metric_goal = cfg.validation.main_val_metric_goal

        assert (
            self.main_val_metric in cfg.eval.eval_metrics
        ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics."

        self.best_metric = 1e8 if "minimize" == self.main_val_metric_goal else -1e8

        # Settings
        self.max_epoch = self.cfg.max_epoch
        self.max_iter = self.cfg.max_iter
        self.gradient_accumulation_steps = accumulation_steps
        self.gt_depth_type = self.cfg.gt_depth_type
        self.gt_mask_type = self.cfg.gt_mask_type
        self.save_period = self.cfg.trainer.save_period
        self.backup_period = self.cfg.trainer.backup_period
        self.val_period = self.cfg.trainer.validation_period
        self.vis_period = self.cfg.trainer.visualization_period

        # Multi-resolution noise
        self.apply_multi_res_noise = self.cfg.multi_res_noise is not None
        if self.apply_multi_res_noise:
            self.mr_noise_strength = self.cfg.multi_res_noise.strength
            self.annealed_mr_noise = self.cfg.multi_res_noise.annealed
            self.mr_noise_downscale_strategy = (
                self.cfg.multi_res_noise.downscale_strategy
            )

        # Internal variables
        self.epoch = 1
        self.n_batch_in_epoch = 0  # batch index in the epoch, used when resume training
        self.effective_iter = 0  # how many times optimizer.step() is called
        self.in_evaluation = False
        self.global_seed_sequence: List = []  # consistent global seed sequence, used to seed random generator, to ensure consistency when resuming

    def _replace_unet_conv_in(self):
        """Replace the UNet's first conv to accept 8 in_channels (RGB + depth latents)."""
        # replace the first layer to accept 8 in_channels
        _weight = self.model.unet.conv_in.weight.clone()  # [320, 4, 3, 3]
        _bias = self.model.unet.conv_in.bias.clone()  # [320]
        _weight = _weight.repeat((1, 2, 1, 1))  # Keep selected channel(s)
        # half the activation magnitude
        _weight *= 0.5
        # new conv_in channel
        _n_convin_out_channel = self.model.unet.conv_in.out_channels
        _new_conv_in = Conv2d(
            8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        _new_conv_in.weight = Parameter(_weight)
        _new_conv_in.bias = Parameter(_bias)
        self.model.unet.conv_in = _new_conv_in
        logging.info("Unet conv_in layer is replaced")
        # replace config
        self.model.unet.config["in_channels"] = 8
        logging.info("Unet config is updated")
        return

    def train(self, t_end=None):
        """Run the training loop until max_iter/max_epoch or until `t_end`.

        Args:
            t_end (datetime, optional): wall-clock deadline; when reached, the
                latest resumable checkpoint is saved and training pauses.
        """
        logging.info("Start training")

        device = self.device
        self.model.to(device)

        if self.in_evaluation:
            logging.info(
                "Last evaluation was not finished, will do evaluation before continue training."
            )
            self.validate()

        self.train_metrics.reset()
        accumulated_step = 0

        for epoch in range(self.epoch, self.max_epoch + 1):
            self.epoch = epoch
            logging.debug(f"epoch: {self.epoch}")

            # Skip previous batches when resume
            for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch):
                self.model.unet.train()

                # globally consistent random generators
                if self.seed is not None:
                    local_seed = self._get_next_seed()
                    rand_num_generator = torch.Generator(device=device)
                    rand_num_generator.manual_seed(local_seed)
                else:
                    rand_num_generator = None

                # >>> With gradient accumulation >>>

                # Get data
                rgb = batch["rgb_norm"].to(device)
                depth_gt_for_latent = batch[self.gt_depth_type].to(device)

                if self.gt_mask_type is not None:
                    valid_mask_for_latent = batch[self.gt_mask_type].to(device)
                    invalid_mask = ~valid_mask_for_latent
                    # downsample mask to latent resolution (factor 8); a latent
                    # pixel is valid only if its whole 8x8 patch is valid
                    valid_mask_down = ~torch.max_pool2d(
                        invalid_mask.float(), 8, 8
                    ).bool()
                    valid_mask_down = valid_mask_down.repeat((1, 4, 1, 1))

                batch_size = rgb.shape[0]

                with torch.no_grad():
                    # Encode image
                    rgb_latent = self.encode_rgb(rgb)  # [B, 4, h, w]
                    # Encode GT depth
                    gt_target_latent = self.encode_depth(
                        depth_gt_for_latent
                    )  # [B, 4, h, w]

                # Sample a random timestep for each image
                timesteps = torch.randint(
                    0,
                    self.scheduler_timesteps,
                    (batch_size,),
                    device=device,
                    generator=rand_num_generator,
                ).long()  # [B]

                # Sample noise
                if self.apply_multi_res_noise:
                    strength = self.mr_noise_strength
                    if self.annealed_mr_noise:
                        # calculate strength depending on t
                        strength = strength * (timesteps / self.scheduler_timesteps)
                    noise = multi_res_noise_like(
                        gt_target_latent,
                        strength=strength,
                        downscale_strategy=self.mr_noise_downscale_strategy,
                        generator=rand_num_generator,
                        device=device,
                    )
                else:
                    noise = torch.randn(
                        gt_target_latent.shape,
                        device=device,
                        generator=rand_num_generator,
                    )  # [B, 4, h, w]

                # Add noise to the latents (diffusion forward process)
                noisy_latents = self.training_noise_scheduler.add_noise(
                    gt_target_latent, noise, timesteps
                )  # [B, 4, h, w]

                # Text embedding
                text_embed = self.empty_text_embed.to(device).repeat(
                    (batch_size, 1, 1)
                )  # [B, 77, 1024]

                # Concat rgb and target latents
                cat_latents = torch.cat(
                    [rgb_latent, noisy_latents], dim=1
                )  # [B, 8, h, w]
                cat_latents = cat_latents.float()

                # Predict the noise residual
                model_pred = self.model.unet(
                    cat_latents, timesteps, text_embed
                ).sample  # [B, 4, h, w]
                if torch.isnan(model_pred).any():
                    logging.warning("model_pred contains NaN.")

                # Get the target for loss depending on the prediction type
                if "sample" == self.prediction_type:
                    target = gt_target_latent
                elif "epsilon" == self.prediction_type:
                    target = noise
                elif "v_prediction" == self.prediction_type:
                    target = self.training_noise_scheduler.get_velocity(
                        gt_target_latent, noise, timesteps
                    )
                else:
                    raise ValueError(f"Unknown prediction type {self.prediction_type}")

                # Masked latent loss
                if self.gt_mask_type is not None:
                    latent_loss = self.loss(
                        model_pred[valid_mask_down].float(),
                        target[valid_mask_down].float(),
                    )
                else:
                    latent_loss = self.loss(model_pred.float(), target.float())

                loss = latent_loss.mean()

                self.train_metrics.update("loss", loss.item())

                loss = loss / self.gradient_accumulation_steps
                loss.backward()
                accumulated_step += 1

                self.n_batch_in_epoch += 1
                # Practical batch end

                # Perform optimization step
                if accumulated_step >= self.gradient_accumulation_steps:
                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()
                    accumulated_step = 0

                    self.effective_iter += 1

                    # Log to tensorboard
                    accumulated_loss = self.train_metrics.result()["loss"]
                    tb_logger.log_dict(
                        {
                            f"train/{k}": v
                            for k, v in self.train_metrics.result().items()
                        },
                        global_step=self.effective_iter,
                    )
                    tb_logger.writer.add_scalar(
                        "lr",
                        self.lr_scheduler.get_last_lr()[0],
                        global_step=self.effective_iter,
                    )
                    tb_logger.writer.add_scalar(
                        "n_batch_in_epoch",
                        self.n_batch_in_epoch,
                        global_step=self.effective_iter,
                    )
                    logging.info(
                        f"iter {self.effective_iter:5d} (epoch {epoch:2d}): loss={accumulated_loss:.5f}"
                    )
                    self.train_metrics.reset()

                    # Per-step callback
                    self._train_step_callback()

                    # End of training
                    if self.max_iter > 0 and self.effective_iter >= self.max_iter:
                        self.save_checkpoint(
                            ckpt_name=self._get_backup_ckpt_name(),
                            save_train_state=False,
                        )
                        logging.info("Training ended.")
                        return
                    # Time's up
                    elif t_end is not None and datetime.now() >= t_end:
                        self.save_checkpoint(ckpt_name="latest", save_train_state=True)
                        logging.info("Time is up, training paused.")
                        return

                    torch.cuda.empty_cache()
                    # <<< Effective batch end <<<

            # Epoch end
            self.n_batch_in_epoch = 0

    def encode_rgb(self, image_in):
        """Encode an RGB batch [B, 3, H, W] into VAE latents [B, 4, h, w]."""
        assert len(image_in.shape) == 4 and image_in.shape[1] == 3
        latent = self.model.encode_rgb(image_in)
        return latent

    def encode_depth(self, depth_in):
        """Encode depth by replicating it to 3 channels and reusing the RGB VAE."""
        # stack depth into 3-channel
        stacked = self.stack_depth_images(depth_in)
        # encode using VAE encoder
        depth_latent = self.model.encode_rgb(stacked)
        return depth_latent

    @staticmethod
    def stack_depth_images(depth_in):
        """Replicate single-channel depth to 3 channels for the RGB VAE encoder.

        Accepts [B, 1, H, W] or [B, H, W]. Fix: any other rank now raises a
        clear ValueError — the original left `stacked` unbound and crashed
        with NameError.
        """
        if 4 == len(depth_in.shape):
            stacked = depth_in.repeat(1, 3, 1, 1)
        elif 3 == len(depth_in.shape):
            stacked = depth_in.unsqueeze(1).repeat(1, 3, 1, 1)
        else:
            raise ValueError(
                f"Unexpected depth shape {tuple(depth_in.shape)}; "
                "expected [B, 1, H, W] or [B, H, W]"
            )
        return stacked

    def _train_step_callback(self):
        """Executed after every effective iteration (optimizer step)."""
        # Save backup (with a larger interval, without training states)
        if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period:
            self.save_checkpoint(
                ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
            )

        _is_latest_saved = False
        # Validation
        if self.val_period > 0 and 0 == self.effective_iter % self.val_period:
            self.in_evaluation = True  # flag to do evaluation in resume run if validation is not finished
            self.save_checkpoint(ckpt_name="latest", save_train_state=True)
            _is_latest_saved = True
            self.validate()
            self.in_evaluation = False
            self.save_checkpoint(ckpt_name="latest", save_train_state=True)

        # Save training checkpoint (can be resumed)
        if (
            self.save_period > 0
            and 0 == self.effective_iter % self.save_period
            and not _is_latest_saved
        ):
            self.save_checkpoint(ckpt_name="latest", save_train_state=True)

        # Visualization
        if self.vis_period > 0 and 0 == self.effective_iter % self.vis_period:
            self.visualize()

    def validate(self):
        """Evaluate on all validation loaders; track best checkpoint on the first."""
        for i, val_loader in enumerate(self.val_loaders):
            val_dataset_name = val_loader.dataset.disp_name
            val_metric_dict = self.validate_single_dataset(
                data_loader=val_loader, metric_tracker=self.val_metrics
            )
            logging.info(
                f"Iter {self.effective_iter}. Validation metrics on `{val_dataset_name}`: {val_metric_dict}"
            )
            tb_logger.log_dict(
                {f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dict.items()},
                global_step=self.effective_iter,
            )
            # save to file
            eval_text = eval_dict_to_text(
                val_metrics=val_metric_dict,
                dataset_name=val_dataset_name,
                sample_list_path=val_loader.dataset.filename_ls_path,
            )
            _save_to = os.path.join(
                self.out_dir_eval,
                f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt",
            )
            with open(_save_to, "w+") as f:
                f.write(eval_text)

            # Update main eval metric (only the first loader drives best-ckpt)
            if 0 == i:
                main_eval_metric = val_metric_dict[self.main_val_metric]
                if (
                    "minimize" == self.main_val_metric_goal
                    and main_eval_metric < self.best_metric
                    or "maximize" == self.main_val_metric_goal
                    and main_eval_metric > self.best_metric
                ):
                    self.best_metric = main_eval_metric
                    logging.info(
                        f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}"
                    )
                    # Save a checkpoint
                    self.save_checkpoint(
                        ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
                    )

    def visualize(self):
        """Run validation on the visualization loaders, saving outputs to disk."""
        for val_loader in self.vis_loaders:
            vis_dataset_name = val_loader.dataset.disp_name
            vis_out_dir = os.path.join(
                self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name
            )
            os.makedirs(vis_out_dir, exist_ok=True)
            _ = self.validate_single_dataset(
                data_loader=val_loader,
                metric_tracker=self.val_metrics,
                save_to_dir=vis_out_dir,
            )
1 == data_loader.batch_size + # Read input image + rgb_int = batch["rgb_int"] # [B, 3, H, W] + # GT depth + depth_raw_ts = batch["depth_raw_linear"].squeeze() + depth_raw = depth_raw_ts.numpy() + depth_raw_ts = depth_raw_ts.to(self.device) + valid_mask_ts = batch["valid_mask_raw"].squeeze() + valid_mask = valid_mask_ts.numpy() + valid_mask_ts = valid_mask_ts.to(self.device) + + # Random number generator + seed = val_seed_ls.pop() + if seed is None: + generator = None + else: + generator = torch.Generator(device=self.device) + generator.manual_seed(seed) + + # Predict depth + pipe_out: MarigoldDepthOutput = self.model( + rgb_int, + denoising_steps=self.cfg.validation.denoising_steps, + ensemble_size=self.cfg.validation.ensemble_size, + processing_res=self.cfg.validation.processing_res, + match_input_res=self.cfg.validation.match_input_res, + generator=generator, + batch_size=1, # use batch size 1 to increase reproducibility + color_map=None, + show_progress_bar=False, + resample_method=self.cfg.validation.resample_method, + ) + + depth_pred: np.ndarray = pipe_out.depth_np + + if "least_square" == self.cfg.eval.alignment: + depth_pred, scale, shift = align_depth_least_square( + gt_arr=depth_raw, + pred_arr=depth_pred, + valid_mask_arr=valid_mask, + return_scale_shift=True, + max_resolution=self.cfg.eval.align_max_res, + ) + else: + raise RuntimeError(f"Unknown alignment type: {self.cfg.eval.alignment}") + + # Clip to dataset min max + depth_pred = np.clip( + depth_pred, + a_min=data_loader.dataset.min_depth, + a_max=data_loader.dataset.max_depth, + ) + + # clip to d > 0 for evaluation + depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None) + + # Evaluate + sample_metric = [] + depth_pred_ts = torch.from_numpy(depth_pred).to(self.device) + + for met_func in self.metric_funcs: + _metric_name = met_func.__name__ + _metric = met_func(depth_pred_ts, depth_raw_ts, valid_mask_ts).item() + sample_metric.append(_metric.__str__()) + metric_tracker.update(_metric_name, 
_metric) + + # Save as 16-bit uint png + if save_to_dir is not None: + img_name = batch["rgb_relative_path"][0].replace("/", "_") + png_save_path = os.path.join(save_to_dir, f"{img_name}.png") + depth_to_save = (pipe_out.depth_np * 65535.0).astype(np.uint16) + Image.fromarray(depth_to_save).save(png_save_path, mode="I;16") + + return metric_tracker.result() + + def _get_next_seed(self): + if 0 == len(self.global_seed_sequence): + self.global_seed_sequence = generate_seed_sequence( + initial_seed=self.seed, + length=self.max_iter * self.gradient_accumulation_steps, + ) + logging.info( + f"Global seed sequence is generated, length={len(self.global_seed_sequence)}" + ) + return self.global_seed_sequence.pop() + + def save_checkpoint(self, ckpt_name, save_train_state): + ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name) + logging.info(f"Saving checkpoint to: {ckpt_dir}") + # Backup previous checkpoint + temp_ckpt_dir = None + if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir): + temp_ckpt_dir = os.path.join( + os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}" + ) + if os.path.exists(temp_ckpt_dir): + shutil.rmtree(temp_ckpt_dir, ignore_errors=True) + os.rename(ckpt_dir, temp_ckpt_dir) + logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}") + + # Save UNet + unet_path = os.path.join(ckpt_dir, "unet") + self.model.unet.save_pretrained(unet_path, safe_serialization=True) + logging.info(f"UNet is saved to: {unet_path}") + + # Save scheduler + scheduelr_path = os.path.join(ckpt_dir, "scheduler") + self.model.scheduler.save_pretrained(scheduelr_path) + logging.info(f"Scheduler is saved to: {scheduelr_path}") + + if save_train_state: + state = { + "optimizer": self.optimizer.state_dict(), + "lr_scheduler": self.lr_scheduler.state_dict(), + "config": self.cfg, + "effective_iter": self.effective_iter, + "epoch": self.epoch, + "n_batch_in_epoch": self.n_batch_in_epoch, + "best_metric": self.best_metric, + "in_evaluation": self.in_evaluation, + 
"global_seed_sequence": self.global_seed_sequence, + } + train_state_path = os.path.join(ckpt_dir, "trainer.ckpt") + torch.save(state, train_state_path) + # iteration indicator + f = open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w") + f.close() + + logging.info(f"Trainer state is saved to: {train_state_path}") + + # Remove temp ckpt + if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir): + shutil.rmtree(temp_ckpt_dir, ignore_errors=True) + logging.debug("Old checkpoint backup is removed.") + + def load_checkpoint( + self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True + ): + logging.info(f"Loading checkpoint from: {ckpt_path}") + # Load UNet + _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin") + self.model.unet.load_state_dict( + torch.load(_model_path, map_location=self.device) + ) + self.model.unet.to(self.device) + logging.info(f"UNet parameters are loaded from {_model_path}") + + # Load training states + if load_trainer_state: + checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt")) + self.effective_iter = checkpoint["effective_iter"] + self.epoch = checkpoint["epoch"] + self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"] + self.in_evaluation = checkpoint["in_evaluation"] + self.global_seed_sequence = checkpoint["global_seed_sequence"] + + self.best_metric = checkpoint["best_metric"] + + self.optimizer.load_state_dict(checkpoint["optimizer"]) + logging.info(f"optimizer state is loaded from {ckpt_path}") + + if resume_lr_scheduler: + self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + logging.info(f"LR scheduler state is loaded from {ckpt_path}") + + logging.info( + f"Checkpoint loaded from: {ckpt_path}. 
Resume from iteration {self.effective_iter} (epoch {self.epoch})" + ) + return + + def _get_backup_ckpt_name(self): + return f"iter_{self.effective_iter:06d}" diff --git a/src/trainer/marigold_iid_trainer.py b/src/trainer/marigold_iid_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5250af1897aa11518b0eaa07eab9d3f647229535 --- /dev/null +++ b/src/trainer/marigold_iid_trainer.py @@ -0,0 +1,993 @@ +# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information about Marigold: +# https://marigoldmonodepth.github.io +# https://marigoldcomputervision.github.io +# Efficient inference pipelines are now part of diffusers: +# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage +# https://huggingface.co/docs/diffusers/api/pipelines/marigold +# Examples of trained models and live demos: +# https://huggingface.co/prs-eth +# Related projects: +# https://rollingdepth.github.io/ +# https://marigolddepthcompletion.github.io/ +# Citation (BibTeX): +# https://github.com/prs-eth/Marigold#-citation +# If you find Marigold useful, we kindly ask you to cite our papers. 
+# -------------------------------------------------------------------------- + +import logging +import os +import shutil +import torch +from datetime import datetime +from diffusers import DDPMScheduler, DDIMScheduler +from omegaconf import OmegaConf +from torchmetrics.image import PeakSignalNoiseRatio +from torch.nn import Conv2d +from torch.nn.parameter import Parameter +from torch.optim import Adam +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader +from tqdm import tqdm +from typing import List, Union, Optional + +from marigold.marigold_iid_pipeline import MarigoldIIDPipeline, MarigoldIIDOutput +from src.util.data_loader import skip_first_batches +from src.util.image_util import ( + img_normalize, + img_float2int, + img_srgb2linear, + img_linear2srgb, +) +from src.util.logging_util import tb_logger, eval_dict_to_text +from src.util.loss import get_loss +from src.util.lr_scheduler import IterExponential +from src.util.metric import MetricTracker +from src.util.multi_res_noise import multi_res_noise_like +from src.util.seeding import generate_seed_sequence +from src.util.metric import compute_iid_metric +import torch.nn as nn +from peft import LoraConfig, get_peft_model, PeftModel + +from diffusers.loaders import ( + LoraLoaderMixin, + TextualInversionLoaderMixin, +) + +class MarigoldIIDTrainer: + def __init__( + self, + cfg: OmegaConf, + model: MarigoldIIDPipeline, + train_dataloader: DataLoader, + device, + out_dir_ckpt, + out_dir_eval, + out_dir_vis, + accumulation_steps: int, + val_dataloaders: List[DataLoader] = None, + vis_dataloaders: List[DataLoader] = None, + ): + self.cfg: OmegaConf = cfg + self.model: MarigoldIIDPipeline = model + self.device = device + self.seed: Union[int, None] = ( + self.cfg.trainer.init_seed + ) # used to generate seed sequence, set to `None` to train w/o seeding + self.out_dir_ckpt = out_dir_ckpt + self.out_dir_eval = out_dir_eval + self.out_dir_vis = out_dir_vis + self.train_loader: 
DataLoader = train_dataloader + self.val_loaders: List[DataLoader] = val_dataloaders + self.vis_loaders: List[DataLoader] = vis_dataloaders + self.accumulation_steps: int = accumulation_steps + + # Adapt input layers + if 4 * (model.n_targets + 1) != self.model.unet.config["in_channels"]: + self._replace_unet_conv_in_out_multimodal() + + # Encode empty text prompt + self.model.encode_empty_text() + self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device) + + self.model.unet.enable_xformers_memory_efficient_attention() + + # Trainability + self.model.vae.requires_grad_(False) + self.model.text_encoder.requires_grad_(False) + self.model.unet.requires_grad_(False) + + unet_lora_config = LoraConfig( + r=8, + lora_alpha=16, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + self.model.unet = get_peft_model(self.model.unet, unet_lora_config) + lora_layers = filter(lambda p: p.requires_grad, self.model.unet.parameters()) + + trainable, frozen = 0, 0 + for name, param in self.model.unet.named_parameters(): + if param.requires_grad: + print(f"✅ Trainable: {name}, shape={param.shape}") + trainable += param.numel() + else: + frozen += param.numel() + + print(f"\nTotal trainable params: {trainable:,}") + print(f"Total frozen params: {frozen:,}") + print(f"Trainable ratio: {trainable / (trainable + frozen) * 100:.4f}%") + + + # Optimizer !should be defined after input layer is adapted + lr = self.cfg.lr + # self.optimizer = Adam(self.model.unet.parameters(), lr=lr) + self.optimizer = Adam(lora_layers, lr=lr) + + self.targets_to_eval_in_linear_space = ( + self.cfg.eval.targets_to_eval_in_linear_space + ) + + # LR scheduler + # lr_func = IterExponential( + # total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter, + # final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio, + # warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps, + # ) + + lr_func = lambda step: 1.0 + self.lr_scheduler = 
LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func) + + # Loss + self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs) + + # Training noise scheduler + self.training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_config( + self.model.scheduler.config, + rescale_betas_zero_snr=True, + timestep_spacing="trailing", + ) + + logging.info( + "DDPM training noise scheduler config is updated: " + f"rescale_betas_zero_snr = {self.training_noise_scheduler.config.rescale_betas_zero_snr}, " + f"timestep_spacing = {self.training_noise_scheduler.config.timestep_spacing}" + ) + + self.prediction_type = self.training_noise_scheduler.config.prediction_type + assert ( + self.prediction_type == self.model.scheduler.config.prediction_type + ), "Different prediction types" + self.scheduler_timesteps = ( + self.training_noise_scheduler.config.num_train_timesteps + ) + + # Inference DDIM scheduler (used for validation) + self.model.scheduler = DDIMScheduler.from_config( + self.training_noise_scheduler.config, + ) + + self.train_metrics = MetricTracker(*["loss"]) + val_metric_names = [ + f"{target}_{cfg.validation.main_val_metric}" + for target in model.target_names + ] + + self.val_metrics = MetricTracker(*val_metric_names) + self._val_metric = PeakSignalNoiseRatio(data_range=1.0).to(device) + + # main metric for best checkpoint saving + if "albedo" in model.target_names: + self.main_val_metric = "albedo_" + cfg.validation.main_val_metric + else: + self.main_val_metric = ( + model.target_names[0] + "_" + cfg.validation.main_val_metric + ) + + self.main_val_metric_goal = cfg.validation.main_val_metric_goal + + assert ( + self.main_val_metric in val_metric_names + ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics." 
+ + self.best_metric = 1e8 if "minimize" == self.main_val_metric_goal else -1e8 + + # Settings + self.max_epoch = self.cfg.max_epoch + self.max_iter = self.cfg.max_iter + self.gradient_accumulation_steps = accumulation_steps + self.gt_mask_type = self.cfg.gt_mask_type + self.save_period = self.cfg.trainer.save_period + self.backup_period = self.cfg.trainer.backup_period + self.val_period = self.cfg.trainer.validation_period + self.vis_period = self.cfg.trainer.visualization_period + self.tokenizer = self.model.tokenizer + self.text_encoder = self.model.text_encoder.to(device) + + # Multi-resolution noise + self.apply_multi_res_noise = self.cfg.multi_res_noise is not None + if self.apply_multi_res_noise: + self.mr_noise_strength = self.cfg.multi_res_noise.strength + self.annealed_mr_noise = self.cfg.multi_res_noise.annealed + self.mr_noise_downscale_strategy = ( + self.cfg.multi_res_noise.downscale_strategy + ) + + # Internal variables + self.epoch = 1 + self.n_batch_in_epoch = 0 # batch index in the epoch, used when resume training + self.effective_iter = 0 # how many times optimizer.step() is called + self.in_evaluation = False + self.global_seed_sequence: List = [] # consistent global seed sequence, used to seed random generator, to ensure consistency when resuming + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_ prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1 + ] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logging.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + 
uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=self.text_encoder.dtype, device=device + ) + + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + # For classifier free guidance, we need to do 
two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] + prompt_embeds = torch.cat( + [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] + ) + + return prompt_embeds + + def _replace_unet_conv_in_out_multimodal(self): + n_outputs = self.model.n_targets + # replace the first layer to accept (n+1)*4 in_channels + _weight = self.model.unet.conv_in.weight.clone() # [320, 4, 3, 3] + _bias = self.model.unet.conv_in.bias.clone() # [320] + if _weight.shape[1] == 12 or _weight.shape[1] == 16: + _weight = _weight[:, :(n_outputs + 1)*4, :, :] + _n_convin_out_channel = self.model.unet.conv_in.out_channels + _new_conv_in = Conv2d( + (n_outputs + 1) * 4, + _n_convin_out_channel, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + ) + _new_conv_in.weight = Parameter(_weight) + _new_conv_in.bias = Parameter(_bias) + self.model.unet.conv_in = _new_conv_in + logging.info("Unet conv_in layer is replaced by selecting first {} channels".format((n_outputs + 1)*4)) + + elif _weight.shape[1] == 4: + _weight = _weight.repeat((1, n_outputs + 1, 1, 1)) # Keep selected channel(s) + # scale the activation magnitude + _weight /= n_outputs + 1 + # new conv_in channel + _n_convin_out_channel = self.model.unet.conv_in.out_channels + _new_conv_in = Conv2d( + (n_outputs + 1) * 4, + _n_convin_out_channel, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + ) + _new_conv_in.weight = Parameter(_weight) + _new_conv_in.bias = Parameter(_bias) + self.model.unet.conv_in = _new_conv_in + logging.info("Unet conv_in layer is replaced by repeating the weights") + + # replace the last layer to output n*4 in_channels + _weight = self.model.unet.conv_out.weight.clone() # [4, 320, 3, 3] + _bias = self.model.unet.conv_out.bias.clone() # [4] + if 
(_weight.shape[0] == 8 and _bias.shape[0] == 8) or (_weight.shape[0] == 12 and _bias.shape[0] == 12): + _weight = _weight[:(n_outputs * 4), :, :, :] + _bias = _bias[:(n_outputs * 4)] + # Since we are repeating output channels, no need to scale the weights here. + _n_convout_in_channel = self.model.unet.conv_out.in_channels + _new_conv_out = Conv2d( + _n_convout_in_channel, + n_outputs * 4, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + ) + _new_conv_out.weight = Parameter(_weight) + _new_conv_out.bias = Parameter(_bias) + self.model.unet.conv_out = _new_conv_out + logging.info("Unet conv_out layer is replaced by selecting first {} channels".format(n_outputs * 4)) + elif _weight.shape[0] == 4: + _weight = _weight.repeat((n_outputs, 1, 1, 1)) + _bias = _bias.repeat(n_outputs) + # Since we are repeating output channels, no need to scale the weights here. + _n_convout_in_channel = self.model.unet.conv_out.in_channels + _new_conv_out = Conv2d( + _n_convout_in_channel, + n_outputs * 4, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + ) + _new_conv_out.weight = Parameter(_weight) + _new_conv_out.bias = Parameter(_bias) + self.model.unet.conv_out = _new_conv_out + logging.info("Unet conv_out layer is replaced by repeating the weights") + + # replace config + self.model.unet.config["in_channels"] = (n_outputs + 1) * 4 + self.model.unet.config["out_channels"] = n_outputs * 4 + logging.info("Unet config is updated") + return + + def train(self, t_end=None): + logging.info("Start training") + + device = self.device + self.model.to(device) + + if self.in_evaluation: + logging.info( + "Last evaluation was not finished, will do evaluation before continue training." 
+ ) + self.validate() + + self.validate() + self.visualize() + + self.train_metrics.reset() + accumulated_step = 0 + + ch_target_latent = 4 * self.model.n_targets + + for epoch in range(self.epoch, self.max_epoch + 1): + self.epoch = epoch + logging.debug(f"epoch: {self.epoch}") + + # Skip previous batches when resume + for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch): + self.model.unet.train() + + # globally consistent random generators + if self.seed is not None: + local_seed = self._get_next_seed() + rand_num_generator = torch.Generator(device=device) + rand_num_generator.manual_seed(local_seed) + else: + rand_num_generator = None + + # >>> With gradient accumulation >>> + + # Get data, send to device, normalize + batch["rgb"] = img_normalize(batch["rgb"].to(device)) + for modality in self.model.target_names: + batch[modality] = img_normalize(batch[modality].to(device)) + + if self.gt_mask_type is not None: + valid_mask_for_latent = batch[self.gt_mask_type].to(device) + invalid_mask = ~valid_mask_for_latent + valid_mask_down = ~torch.max_pool2d( + invalid_mask.float(), 8, 8 + ).bool() + valid_mask_down = valid_mask_down.repeat( + (1, ch_target_latent, 1, 1) + ) + + batch_size = batch["rgb"].shape[0] + + with torch.no_grad(): + # Encode image + rgb_latent = self.encode_rgb(batch["rgb"]) # [B, 4, h, w] + # Encode iid properties + gt_target_latent = torch.cat( + [ + self.encode_rgb(batch[target_name]) + for target_name in self.model.target_names + ], + dim=1, + ) # [B, 4*n_targets, h, w] + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + self.scheduler_timesteps, + (batch_size,), + device=device, + generator=rand_num_generator, + ).long() # [B] + + # Sample noise + if self.apply_multi_res_noise: + strength = self.mr_noise_strength + if self.annealed_mr_noise: + # calculate strength depending on t + strength = strength * (timesteps / self.scheduler_timesteps) + noise = multi_res_noise_like( + gt_target_latent, 
+ strength=strength, + downscale_strategy=self.mr_noise_downscale_strategy, + generator=rand_num_generator, + device=device, + ) + else: + noise = torch.randn( + gt_target_latent.shape, + device=device, + generator=rand_num_generator, + ) # [B, 4*n_targets, h, w] + + # Add noise to the latents (diffusion forward process) + noisy_latents = self.training_noise_scheduler.add_noise( + gt_target_latent, noise, timesteps + ) # [B, 4*n_targets, h, w] + + # Text embedding + # text_embed = self.empty_text_embed.to(device).repeat( + # (batch_size, 1, 1) + # ) # [B, 77, 1024] + + prompt_embeds = None + prompt_embeds = self._encode_prompt( + "Albedo (diffuse basecolor)", + device, + num_images_per_prompt=1, + do_classifier_free_guidance=False, + negative_prompt=None, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=None, + ) + text_embed = prompt_embeds.to(device).repeat((batch_size, 1, 1)) + + # Concat rgb and target latents + cat_latents = torch.cat( + # [rgb_latent, noisy_latents], dim=1 + [noisy_latents, rgb_latent], dim=1 + ) # [B, 4*n_targets + 4, h, w] + cat_latents = cat_latents.float() + + # Predict the noise residual + model_pred = self.model.unet( + cat_latents, timesteps, text_embed + ).sample # [B, 4*n_targets, h, w] + if torch.isnan(model_pred).any(): + logging.warning("model_pred contains NaN.") + + # Get the target for loss depending on the prediction type + if "sample" == self.prediction_type: + target = gt_target_latent + elif "epsilon" == self.prediction_type: + target = noise + elif "v_prediction" == self.prediction_type: + target = self.training_noise_scheduler.get_velocity( + gt_target_latent, noise, timesteps + ) + else: + raise ValueError(f"Unknown prediction type {self.prediction_type}") + + # Masked latent loss + if self.gt_mask_type is not None: + latent_loss = self.loss( + model_pred[valid_mask_down].float(), + target[valid_mask_down].float(), + ) + else: + latent_loss = self.loss(model_pred.float(), target.float()) + + loss = 
latent_loss.mean() + + self.train_metrics.update("loss", loss.item()) + + loss = loss / self.gradient_accumulation_steps + loss.backward() + accumulated_step += 1 + + self.n_batch_in_epoch += 1 + # Practical batch end + + # Perform optimization step + if accumulated_step >= self.gradient_accumulation_steps: + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + accumulated_step = 0 + + self.effective_iter += 1 + + # Log to tensorboard + accumulated_loss = self.train_metrics.result()["loss"] + tb_logger.log_dict( + { + f"train/{k}": v + for k, v in self.train_metrics.result().items() + }, + global_step=self.effective_iter, + ) + tb_logger.writer.add_scalar( + "lr", + self.lr_scheduler.get_last_lr()[0], + global_step=self.effective_iter, + ) + tb_logger.writer.add_scalar( + "n_batch_in_epoch", + self.n_batch_in_epoch, + global_step=self.effective_iter, + ) + logging.info( + f"iter {self.effective_iter:5d} (epoch {epoch:2d}): loss={accumulated_loss:.5f}" + ) + self.train_metrics.reset() + + # Per-step callback + self._train_step_callback() + + # End of training + if self.max_iter > 0 and self.effective_iter >= self.max_iter: + self.save_checkpoint( + ckpt_name=self._get_backup_ckpt_name(), + save_train_state=False, + ) + logging.info("Training ended.") + return + # Time's up + elif t_end is not None and datetime.now() >= t_end: + self.save_checkpoint(ckpt_name="latest", save_train_state=True) + logging.info("Time is up, training paused.") + return + + torch.cuda.empty_cache() + # <<< Effective batch end <<< + + # Epoch end + self.n_batch_in_epoch = 0 + + def encode_rgb(self, image_in): + assert len(image_in.shape) == 4 and image_in.shape[1] == 3 + latent = self.model.encode_rgb(image_in) + return latent + + def _train_step_callback(self): + """Executed after every iteration""" + # Save backup (with a larger interval, without training states) + if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period: + 
self.save_checkpoint( + ckpt_name=self._get_backup_ckpt_name(), save_train_state=False + ) + + _is_latest_saved = False + # Validation + if self.val_period > 0 and 0 == self.effective_iter % self.val_period: + self.in_evaluation = True # flag to do evaluation in resume run if validation is not finished + self.save_checkpoint(ckpt_name="latest", save_train_state=True) + _is_latest_saved = True + self.validate() + self.in_evaluation = False + self.save_checkpoint(ckpt_name="latest", save_train_state=True) + + # Save training checkpoint (can be resumed) + if ( + self.save_period > 0 + and 0 == self.effective_iter % self.save_period + and not _is_latest_saved + ): + self.save_checkpoint(ckpt_name="latest", save_train_state=True) + + # Visualization + if self.vis_period > 0 and 0 == self.effective_iter % self.vis_period: + self.visualize() + + def validate(self): + for i, val_loader in enumerate(self.val_loaders): + val_dataset_name = val_loader.dataset.disp_name + val_metric_dict = self.validate_single_dataset( + data_loader=val_loader, metric_tracker=self.val_metrics + ) + logging.info( + f"Iter {self.effective_iter}. 
Validation metrics on `{val_dataset_name}`: {val_metric_dict}" + ) + tb_logger.log_dict( + {f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dict.items()}, + global_step=self.effective_iter, + ) + # save to file + eval_text = eval_dict_to_text( + val_metrics=val_metric_dict, + dataset_name=val_dataset_name, + sample_list_path=val_loader.dataset.filename_ls_path, + ) + _save_to = os.path.join( + self.out_dir_eval, + f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt", + ) + with open(_save_to, "w+") as f: + f.write(eval_text) + + # Update main eval metric + if 0 == i: + main_eval_metric = val_metric_dict[self.main_val_metric] + if ( + "minimize" == self.main_val_metric_goal + and main_eval_metric < self.best_metric + or "maximize" == self.main_val_metric_goal + and main_eval_metric > self.best_metric + ): + self.best_metric = main_eval_metric + logging.info( + f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}" + ) + # Save a checkpoint + self.save_checkpoint( + ckpt_name=self._get_backup_ckpt_name(), save_train_state=False + ) + else: + self.save_checkpoint( + ckpt_name=self._get_backup_ckpt_name(), save_train_state=False + ) + + def visualize(self): + for val_loader in self.vis_loaders: + vis_dataset_name = val_loader.dataset.disp_name + vis_out_dir = os.path.join( + self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name + ) + os.makedirs(vis_out_dir, exist_ok=True) + _ = self.validate_single_dataset( + data_loader=val_loader, + metric_tracker=self.val_metrics, + save_to_dir=vis_out_dir, + ) + + @torch.no_grad() + def validate_single_dataset( + self, + data_loader: DataLoader, + metric_tracker: MetricTracker, + save_to_dir: str = None, + ): + if hasattr(self.model.unet, "peft_config"): + logging.info("Validation using LoRA-adapted UNet.") + else: + logging.info("Validation using full UNet (no LoRA).") + self.model.to(self.device) + metric_tracker.reset() + + # Generate seed sequence for 
consistent evaluation + val_init_seed = self.cfg.validation.init_seed + val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader)) + + for i, batch in enumerate( + tqdm(data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"), + start=1, + ): + assert 1 == data_loader.batch_size + # Read input image + img_int = img_float2int(batch["rgb"]) # [3, H, W] in [0, 255], sRGB space + # GT targets + for target_name in self.model.target_names: + batch[target_name] = batch[target_name].squeeze().to(self.device) + + # Random number generator + seed = val_seed_ls.pop() + if seed is None: + generator = None + else: + generator = torch.Generator(device=self.device) + generator.manual_seed(seed) + + # Predict materials + pipe_out: MarigoldIIDOutput = self.model( + img_int, + denoising_steps=self.cfg.validation.denoising_steps, + ensemble_size=self.cfg.validation.ensemble_size, + processing_res=self.cfg.validation.processing_res, + match_input_res=self.cfg.validation.match_input_res, + generator=generator, + batch_size=1, # use batch size 1 to increase reproducibility + show_progress_bar=False, + resample_method=self.cfg.validation.resample_method, + ) + + for target_name in self.model.target_names: + target_pred = pipe_out[target_name].array + target_pred_ts = ( + torch.from_numpy(target_pred).to(self.device).unsqueeze(0) + ) + target_gt = batch[target_name].to(self.device) + if self.cfg.validation.use_mask: + _mask_name = "mask_" + target_name + valid_mask = batch[_mask_name].to(self.device) + else: + valid_mask = None + if target_name in self.targets_to_eval_in_linear_space: + target_pred_ts = img_srgb2linear(target_pred_ts) + # Hypersim GT and IID Lighting model predictions are in linear space + # We evaluate albedo in sRGB space + if len(self.model.target_names) == 3 and target_name == "albedo": + # linear --> sRGB + target_gt = img_linear2srgb(target_gt) + target_pred_ts = img_linear2srgb(target_pred_ts) + + # eval pnsr + _metric_name = 
self.cfg.validation.main_val_metric + _metric_target = compute_iid_metric( + target_pred_ts, + target_gt, + target_name, + "psnr", + self._val_metric, + valid_mask=valid_mask, + ) + metric_tracker.update(f"{target_name}_{_metric_name}", _metric_target) + + # Save target as image + if save_to_dir is not None: + img_name = batch["rgb_relative_path"][0].replace("/", "_") + img_name_without_ext = os.path.splitext(img_name)[0] + + # Save target + target_save_path = os.path.join( + save_to_dir, f"{img_name_without_ext}_{target_name}.png" + ) + target_img = pipe_out[target_name].image + target_img.save(target_save_path) + + return metric_tracker.result() + + def _get_next_seed(self): + if 0 == len(self.global_seed_sequence): + self.global_seed_sequence = generate_seed_sequence( + initial_seed=self.seed, + length=self.max_iter * self.gradient_accumulation_steps, + ) + logging.info( + f"Global seed sequence is generated, length={len(self.global_seed_sequence)}" + ) + return self.global_seed_sequence.pop() + + def save_checkpoint(self, ckpt_name, save_train_state): + ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name) + logging.info(f"Saving checkpoint to: {ckpt_dir}") + # Backup previous checkpoint + temp_ckpt_dir = None + if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir): + temp_ckpt_dir = os.path.join( + os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}" + ) + if os.path.exists(temp_ckpt_dir): + shutil.rmtree(temp_ckpt_dir, ignore_errors=True) + os.rename(ckpt_dir, temp_ckpt_dir) + logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}") + + # Save UNet + unet_path = os.path.join(ckpt_dir, "unet") + if isinstance(self.model.unet, PeftModel): + print('save lora model') + self.model.unet.save_pretrained(unet_path) + else: + self.model.unet.save_pretrained(unet_path, safe_serialization=True) + logging.info(f"UNet is saved to: {unet_path}") + + # Save scheduler + scheduelr_path = os.path.join(ckpt_dir, "scheduler") + 
self.model.scheduler.save_pretrained(scheduelr_path) + logging.info(f"Scheduler is saved to: {scheduelr_path}") + + if save_train_state: + state = { + "optimizer": self.optimizer.state_dict(), + "lr_scheduler": self.lr_scheduler.state_dict(), + "config": self.cfg, + "effective_iter": self.effective_iter, + "epoch": self.epoch, + "n_batch_in_epoch": self.n_batch_in_epoch, + "best_metric": self.best_metric, + "in_evaluation": self.in_evaluation, + "global_seed_sequence": self.global_seed_sequence, + } + train_state_path = os.path.join(ckpt_dir, "trainer.ckpt") + torch.save(state, train_state_path) + # iteration indicator + f = open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w") + f.close() + + logging.info(f"Trainer state is saved to: {train_state_path}") + + # Remove temp ckpt + if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir): + shutil.rmtree(temp_ckpt_dir, ignore_errors=True) + logging.debug("Old checkpoint backup is removed.") + + def load_checkpoint( + self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True + ): + logging.info(f"Loading checkpoint from: {ckpt_path}") + # Load UNet + _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin") + self.model.unet.load_state_dict( + torch.load(_model_path, map_location=self.device) + ) + self.model.unet.to(self.device) + logging.info(f"UNet parameters are loaded from {_model_path}") + + # Load training states + if load_trainer_state: + checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt")) + self.effective_iter = checkpoint["effective_iter"] + self.epoch = checkpoint["epoch"] + self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"] + self.in_evaluation = checkpoint["in_evaluation"] + self.global_seed_sequence = checkpoint["global_seed_sequence"] + + self.best_metric = checkpoint["best_metric"] + + self.optimizer.load_state_dict(checkpoint["optimizer"]) + logging.info(f"optimizer state is loaded from {ckpt_path}") + + if resume_lr_scheduler: + 
self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + logging.info(f"LR scheduler state is loaded from {ckpt_path}") + + logging.info( + f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})" + ) + return + + def _get_backup_ckpt_name(self): + return f"iter_{self.effective_iter:06d}" diff --git a/src/trainer/marigold_normals_trainer.py b/src/trainer/marigold_normals_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..79f4691a3fb19ea98d83f9fa2e5b3d8f05023f0e --- /dev/null +++ b/src/trainer/marigold_normals_trainer.py @@ -0,0 +1,666 @@ +# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information about Marigold: +# https://marigoldmonodepth.github.io +# https://marigoldcomputervision.github.io +# Efficient inference pipelines are now part of diffusers: +# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage +# https://huggingface.co/docs/diffusers/api/pipelines/marigold +# Examples of trained models and live demos: +# https://huggingface.co/prs-eth +# Related projects: +# https://rollingdepth.github.io/ +# https://marigolddepthcompletion.github.io/ +# Citation (BibTeX): +# https://github.com/prs-eth/Marigold#-citation +# If you find Marigold useful, we kindly ask you to cite our papers. 
+# -------------------------------------------------------------------------- + +import logging +import numpy as np +import os +import shutil +import torch +from PIL import Image +from datetime import datetime +from diffusers import DDPMScheduler, DDIMScheduler +from omegaconf import OmegaConf +from torch.nn import Conv2d +from torch.nn.parameter import Parameter +from torch.optim import Adam +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader +from tqdm import tqdm +from typing import List, Union + +from marigold.marigold_normals_pipeline import ( + MarigoldNormalsPipeline, + MarigoldNormalsOutput, +) +from src.util.image_util import img_chw2hwc +from src.util import metric +from src.util.data_loader import skip_first_batches +from src.util.logging_util import tb_logger, eval_dict_to_text +from src.util.loss import get_loss +from src.util.lr_scheduler import IterExponential +from src.util.metric import MetricTracker, compute_cosine_error +from src.util.multi_res_noise import multi_res_noise_like +from src.util.seeding import generate_seed_sequence + + +class MarigoldNormalsTrainer: + def __init__( + self, + cfg: OmegaConf, + model: MarigoldNormalsPipeline, + train_dataloader: DataLoader, + device, + out_dir_ckpt, + out_dir_eval, + out_dir_vis, + accumulation_steps: int, + val_dataloaders: List[DataLoader] = None, + vis_dataloaders: List[DataLoader] = None, + ): + self.cfg: OmegaConf = cfg + self.model: MarigoldNormalsPipeline = model + self.device = device + self.seed: Union[int, None] = ( + self.cfg.trainer.init_seed + ) # used to generate seed sequence, set to `None` to train w/o seeding + self.out_dir_ckpt = out_dir_ckpt + self.out_dir_eval = out_dir_eval + self.out_dir_vis = out_dir_vis + self.train_loader: DataLoader = train_dataloader + self.val_loaders: List[DataLoader] = val_dataloaders + self.vis_loaders: List[DataLoader] = vis_dataloaders + self.accumulation_steps: int = accumulation_steps + + # Adapt input layers + if 
8 != self.model.unet.config["in_channels"]: + self._replace_unet_conv_in() + + # Encode empty text prompt + self.model.encode_empty_text() + self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device) + + self.model.unet.enable_xformers_memory_efficient_attention() + + # Trainability + self.model.vae.requires_grad_(False) + self.model.text_encoder.requires_grad_(False) + self.model.unet.requires_grad_(True) + + # Optimizer !should be defined after input layer is adapted + lr = self.cfg.lr + self.optimizer = Adam(self.model.unet.parameters(), lr=lr) + + # LR scheduler + lr_func = IterExponential( + total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter, + final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio, + warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps, + ) + self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func) + + # Loss + self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs) + + # Training noise scheduler + self.training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_config( + self.model.scheduler.config, + rescale_betas_zero_snr=True, + timestep_spacing="trailing", + ) + + logging.info( + "DDPM training noise scheduler config is updated: " + f"rescale_betas_zero_snr = {self.training_noise_scheduler.config.rescale_betas_zero_snr}, " + f"timestep_spacing = {self.training_noise_scheduler.config.timestep_spacing}" + ) + + self.prediction_type = self.training_noise_scheduler.config.prediction_type + assert ( + self.prediction_type == self.model.scheduler.config.prediction_type + ), "Different prediction types" + self.scheduler_timesteps = ( + self.training_noise_scheduler.config.num_train_timesteps + ) + + # Inference DDIM scheduler (used for validation) + self.model.scheduler = DDIMScheduler.from_config( + self.training_noise_scheduler.config, + ) + + # Eval metrics + self.metric_funcs = [getattr(metric, _met) for _met in cfg.eval.eval_metrics] + + self.train_metrics = 
MetricTracker(*["loss"]) + self.val_metrics = MetricTracker(*[m.__name__ for m in self.metric_funcs]) + + # main metric for best checkpoint saving + self.main_val_metric = cfg.validation.main_val_metric + self.main_val_metric_goal = cfg.validation.main_val_metric_goal + + assert ( + self.main_val_metric in cfg.eval.eval_metrics + ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics." + + self.best_metric = 1e8 if "minimize" == self.main_val_metric_goal else -1e8 + + # Settings + self.max_epoch = self.cfg.max_epoch + self.max_iter = self.cfg.max_iter + self.gradient_accumulation_steps = accumulation_steps + self.gt_normals_type = self.cfg.gt_normals_type + self.gt_mask_type = self.cfg.gt_mask_type + self.save_period = self.cfg.trainer.save_period + self.backup_period = self.cfg.trainer.backup_period + self.val_period = self.cfg.trainer.validation_period + self.vis_period = self.cfg.trainer.visualization_period + + # Multi-resolution noise + self.apply_multi_res_noise = self.cfg.multi_res_noise is not None + if self.apply_multi_res_noise: + self.mr_noise_strength = self.cfg.multi_res_noise.strength + self.annealed_mr_noise = self.cfg.multi_res_noise.annealed + self.mr_noise_downscale_strategy = ( + self.cfg.multi_res_noise.downscale_strategy + ) + + # Internal variables + self.epoch = 1 + self.n_batch_in_epoch = 0 # batch index in the epoch, used when resume training + self.effective_iter = 0 # how many times optimizer.step() is called + self.in_evaluation = False + self.global_seed_sequence: List = [] # consistent global seed sequence, used to seed random generator, to ensure consistency when resuming + + def _replace_unet_conv_in(self): + # replace the first layer to accept 8 in_channels + _weight = self.model.unet.conv_in.weight.clone() # [320, 4, 3, 3] + _bias = self.model.unet.conv_in.bias.clone() # [320] + _weight = _weight.repeat((1, 2, 1, 1)) # Keep selected channel(s) + # half the activation magnitude + _weight *= 0.5 + # new 
conv_in channel + _n_convin_out_channel = self.model.unet.conv_in.out_channels + _new_conv_in = Conv2d( + 8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) + ) + _new_conv_in.weight = Parameter(_weight) + _new_conv_in.bias = Parameter(_bias) + self.model.unet.conv_in = _new_conv_in + logging.info("Unet conv_in layer is replaced") + # replace config + self.model.unet.config["in_channels"] = 8 + logging.info("Unet config is updated") + return + + def train(self, t_end=None): + logging.info("Start training") + + device = self.device + self.model.to(device) + + if self.in_evaluation: + logging.info( + "Last evaluation was not finished, will do evaluation before continue training." + ) + self.validate() + + self.train_metrics.reset() + accumulated_step = 0 + + for epoch in range(self.epoch, self.max_epoch + 1): + self.epoch = epoch + logging.debug(f"epoch: {self.epoch}") + + # Skip previous batches when resume + for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch): + self.model.unet.train() + + # globally consistent random generators + if self.seed is not None: + local_seed = self._get_next_seed() + rand_num_generator = torch.Generator(device=device) + rand_num_generator.manual_seed(local_seed) + else: + rand_num_generator = None + + # >>> With gradient accumulation >>> + + # Get data + rgb = batch["rgb_norm"].to(device) + normals_gt_for_latent = batch[self.gt_normals_type].to(device) + + if self.gt_mask_type is not None: + valid_mask_for_latent = batch[self.gt_mask_type].to(device) + invalid_mask = ~valid_mask_for_latent + valid_mask_down = ~torch.max_pool2d( + invalid_mask.float(), 8, 8 + ).bool() + valid_mask_down = valid_mask_down.repeat((1, 4, 1, 1)) + + batch_size = rgb.shape[0] + + with torch.no_grad(): + # Encode image + rgb_latent = self.encode_rgb(rgb) # [B, 4, h, w] + # Encode GT normals + gt_target_latent = self.encode_rgb( + normals_gt_for_latent + ) # [B, 4, h, w] + + # Sample a random timestep for each image 
+ timesteps = torch.randint( + 0, + self.scheduler_timesteps, + (batch_size,), + device=device, + generator=rand_num_generator, + ).long() # [B] + + # Sample noise + if self.apply_multi_res_noise: + strength = self.mr_noise_strength + if self.annealed_mr_noise: + # calculate strength depending on t + strength = strength * (timesteps / self.scheduler_timesteps) + noise = multi_res_noise_like( + gt_target_latent, + strength=strength, + downscale_strategy=self.mr_noise_downscale_strategy, + generator=rand_num_generator, + device=device, + ) + else: + noise = torch.randn( + gt_target_latent.shape, + device=device, + generator=rand_num_generator, + ) # [B, 4, h, w] + + # Add noise to the latents (diffusion forward process) + noisy_latents = self.training_noise_scheduler.add_noise( + gt_target_latent, noise, timesteps + ) # [B, 4, h, w] + + # Text embedding + text_embed = self.empty_text_embed.to(device).repeat( + (batch_size, 1, 1) + ) # [B, 77, 1024] + + # Concat rgb and target latents + cat_latents = torch.cat( + [rgb_latent, noisy_latents], dim=1 + ) # [B, 8, h, w] + cat_latents = cat_latents.float() + + # Predict the noise residual + model_pred = self.model.unet( + cat_latents, timesteps, text_embed + ).sample # [B, 4, h, w] + if torch.isnan(model_pred).any(): + logging.warning("model_pred contains NaN.") + + # Get the target for loss depending on the prediction type + if "sample" == self.prediction_type: + target = gt_target_latent + elif "epsilon" == self.prediction_type: + target = noise + elif "v_prediction" == self.prediction_type: + target = self.training_noise_scheduler.get_velocity( + gt_target_latent, noise, timesteps + ) + else: + raise ValueError(f"Unknown prediction type {self.prediction_type}") + + # Masked latent loss + if self.gt_mask_type is not None: + latent_loss = self.loss( + model_pred[valid_mask_down].float(), + target[valid_mask_down].float(), + ) + else: + latent_loss = self.loss(model_pred.float(), target.float()) + + loss = 
latent_loss.mean() + + self.train_metrics.update("loss", loss.item()) + + loss = loss / self.gradient_accumulation_steps + loss.backward() + accumulated_step += 1 + + self.n_batch_in_epoch += 1 + # Practical batch end + + # Perform optimization step + if accumulated_step >= self.gradient_accumulation_steps: + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + accumulated_step = 0 + + self.effective_iter += 1 + + # Log to tensorboard + accumulated_loss = self.train_metrics.result()["loss"] + tb_logger.log_dict( + { + f"train/{k}": v + for k, v in self.train_metrics.result().items() + }, + global_step=self.effective_iter, + ) + tb_logger.writer.add_scalar( + "lr", + self.lr_scheduler.get_last_lr()[0], + global_step=self.effective_iter, + ) + tb_logger.writer.add_scalar( + "n_batch_in_epoch", + self.n_batch_in_epoch, + global_step=self.effective_iter, + ) + logging.info( + f"iter {self.effective_iter:5d} (epoch {epoch:2d}): loss={accumulated_loss:.5f}" + ) + self.train_metrics.reset() + + # Per-step callback + self._train_step_callback() + + # End of training + if self.max_iter > 0 and self.effective_iter >= self.max_iter: + self.save_checkpoint( + ckpt_name=self._get_backup_ckpt_name(), + save_train_state=False, + ) + logging.info("Training ended.") + return + # Time's up + elif t_end is not None and datetime.now() >= t_end: + self.save_checkpoint(ckpt_name="latest", save_train_state=True) + logging.info("Time is up, training paused.") + return + + torch.cuda.empty_cache() + # <<< Effective batch end <<< + + # Epoch end + self.n_batch_in_epoch = 0 + + def encode_rgb(self, image_in): + assert len(image_in.shape) == 4 and image_in.shape[1] == 3 + latent = self.model.encode_rgb(image_in) + return latent + + def _train_step_callback(self): + """Executed after every iteration""" + # Save backup (with a larger interval, without training states) + if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period: + 
self.save_checkpoint( + ckpt_name=self._get_backup_ckpt_name(), save_train_state=False + ) + + _is_latest_saved = False + # Validation + if self.val_period > 0 and 0 == self.effective_iter % self.val_period: + self.in_evaluation = True # flag to do evaluation in resume run if validation is not finished + self.save_checkpoint(ckpt_name="latest", save_train_state=True) + _is_latest_saved = True + self.validate() + self.in_evaluation = False + self.save_checkpoint(ckpt_name="latest", save_train_state=True) + + # Save training checkpoint (can be resumed) + if ( + self.save_period > 0 + and 0 == self.effective_iter % self.save_period + and not _is_latest_saved + ): + self.save_checkpoint(ckpt_name="latest", save_train_state=True) + + # Visualization + if self.vis_period > 0 and 0 == self.effective_iter % self.vis_period: + self.visualize() + + def validate(self): + for i, val_loader in enumerate(self.val_loaders): + val_dataset_name = val_loader.dataset.disp_name + val_metric_dict = self.validate_single_dataset( + data_loader=val_loader, metric_tracker=self.val_metrics + ) + logging.info( + f"Iter {self.effective_iter}. 
Validation metrics on `{val_dataset_name}`: {val_metric_dict}" + ) + tb_logger.log_dict( + {f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dict.items()}, + global_step=self.effective_iter, + ) + # save to file + eval_text = eval_dict_to_text( + val_metrics=val_metric_dict, + dataset_name=val_dataset_name, + sample_list_path=val_loader.dataset.filename_ls_path, + ) + _save_to = os.path.join( + self.out_dir_eval, + f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt", + ) + with open(_save_to, "w+") as f: + f.write(eval_text) + + # Update main eval metric + if 0 == i: + main_eval_metric = val_metric_dict[self.main_val_metric] + if ( + "minimize" == self.main_val_metric_goal + and main_eval_metric < self.best_metric + or "maximize" == self.main_val_metric_goal + and main_eval_metric > self.best_metric + ): + self.best_metric = main_eval_metric + logging.info( + f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}" + ) + # Save a checkpoint + self.save_checkpoint( + ckpt_name=self._get_backup_ckpt_name(), save_train_state=False + ) + + def visualize(self): + for val_loader in self.vis_loaders: + vis_dataset_name = val_loader.dataset.disp_name + vis_out_dir = os.path.join( + self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name + ) + os.makedirs(vis_out_dir, exist_ok=True) + _ = self.validate_single_dataset( + data_loader=val_loader, + metric_tracker=self.val_metrics, + save_to_dir=vis_out_dir, + ) + + @torch.no_grad() + def validate_single_dataset( + self, + data_loader: DataLoader, + metric_tracker: MetricTracker, + save_to_dir: str = None, + ): + self.model.to(self.device) + metric_tracker.reset() + + # Generate seed sequence for consistent evaluation + val_init_seed = self.cfg.validation.init_seed + val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader)) + + for i, batch in enumerate( + tqdm(data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"), + start=1, + ): + assert 
1 == data_loader.batch_size + # Read input image + rgb_int = batch["rgb_int"] # [B, 3, H, W] + # GT normals + normals_gt = batch["normals"].to(self.device) # [B, 3, H, W] + + # Random number generator + seed = val_seed_ls.pop() + if seed is None: + generator = None + else: + generator = torch.Generator(device=self.device) + generator.manual_seed(seed) + + # Predict normals + pipe_out: MarigoldNormalsOutput = self.model( + rgb_int, + denoising_steps=self.cfg.validation.denoising_steps, + ensemble_size=self.cfg.validation.ensemble_size, + processing_res=self.cfg.validation.processing_res, + match_input_res=self.cfg.validation.match_input_res, + generator=generator, + batch_size=1, # use batch size 1 to increase reproducibility + show_progress_bar=False, + resample_method=self.cfg.validation.resample_method, + ) + + normals_pred = pipe_out.normals_np # [3, H, W] + + normals_pred_ts = ( + torch.from_numpy(normals_pred).unsqueeze(0).to(self.device) + ) + cosine_error = compute_cosine_error( + normals_pred_ts, normals_gt, masked=True + ) + sample_metric = [] + + for met_func in self.metric_funcs: + _metric_name = met_func.__name__ + _metric = met_func(cosine_error).item() + sample_metric.append(_metric.__str__()) + metric_tracker.update(_metric_name, _metric) + + # Save predicted normals as images + if save_to_dir is not None: + img_name = batch["rgb_relative_path"][0].replace("/", "_") + png_save_path = os.path.join(save_to_dir, img_name) + normals_to_save = img_chw2hwc(((normals_pred + 1) * 127.5)).astype( + np.uint8 + ) + Image.fromarray(normals_to_save).save(png_save_path) + + return metric_tracker.result() + + def _get_next_seed(self): + if 0 == len(self.global_seed_sequence): + self.global_seed_sequence = generate_seed_sequence( + initial_seed=self.seed, + length=self.max_iter * self.gradient_accumulation_steps, + ) + logging.info( + f"Global seed sequence is generated, length={len(self.global_seed_sequence)}" + ) + return self.global_seed_sequence.pop() + + def 
save_checkpoint(self, ckpt_name, save_train_state): + ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name) + logging.info(f"Saving checkpoint to: {ckpt_dir}") + # Backup previous checkpoint + temp_ckpt_dir = None + if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir): + temp_ckpt_dir = os.path.join( + os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}" + ) + if os.path.exists(temp_ckpt_dir): + shutil.rmtree(temp_ckpt_dir, ignore_errors=True) + os.rename(ckpt_dir, temp_ckpt_dir) + logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}") + + # Save UNet + unet_path = os.path.join(ckpt_dir, "unet") + self.model.unet.save_pretrained(unet_path, safe_serialization=True) + logging.info(f"UNet is saved to: {unet_path}") + + # Save scheduler + scheduelr_path = os.path.join(ckpt_dir, "scheduler") + self.model.scheduler.save_pretrained(scheduelr_path) + logging.info(f"Scheduler is saved to: {scheduelr_path}") + + if save_train_state: + state = { + "optimizer": self.optimizer.state_dict(), + "lr_scheduler": self.lr_scheduler.state_dict(), + "config": self.cfg, + "effective_iter": self.effective_iter, + "epoch": self.epoch, + "n_batch_in_epoch": self.n_batch_in_epoch, + "best_metric": self.best_metric, + "in_evaluation": self.in_evaluation, + "global_seed_sequence": self.global_seed_sequence, + } + train_state_path = os.path.join(ckpt_dir, "trainer.ckpt") + torch.save(state, train_state_path) + # iteration indicator + f = open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w") + f.close() + + logging.info(f"Trainer state is saved to: {train_state_path}") + + # Remove temp ckpt + if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir): + shutil.rmtree(temp_ckpt_dir, ignore_errors=True) + logging.debug("Old checkpoint backup is removed.") + + def load_checkpoint( + self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True + ): + logging.info(f"Loading checkpoint from: {ckpt_path}") + # Load UNet + _model_path = 
os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin") + self.model.unet.load_state_dict( + torch.load(_model_path, map_location=self.device) + ) + self.model.unet.to(self.device) + logging.info(f"UNet parameters are loaded from {_model_path}") + + # Load training states + if load_trainer_state: + checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt")) + self.effective_iter = checkpoint["effective_iter"] + self.epoch = checkpoint["epoch"] + self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"] + self.in_evaluation = checkpoint["in_evaluation"] + self.global_seed_sequence = checkpoint["global_seed_sequence"] + + self.best_metric = checkpoint["best_metric"] + + self.optimizer.load_state_dict(checkpoint["optimizer"]) + logging.info(f"optimizer state is loaded from {ckpt_path}") + + if resume_lr_scheduler: + self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + logging.info(f"LR scheduler state is loaded from {ckpt_path}") + + logging.info( + f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})" + ) + return + + def _get_backup_ckpt_name(self): + return f"iter_{self.effective_iter:06d}" diff --git a/src/util/alignment.py b/src/util/alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..d3a73c13d98b2e314a3bd192440e86c46502e81a --- /dev/null +++ b/src/util/alignment.py @@ -0,0 +1,99 @@ +# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def align_depth_least_square(
    gt_arr: np.ndarray,
    pred_arr: np.ndarray,
    valid_mask_arr: np.ndarray,
    return_scale_shift=True,
    max_resolution=None,
):
    """Least-squares scale-and-shift alignment of a relative depth prediction.

    Solves min_{s,t} || (pred * s + t) - gt ||^2 over the valid pixels, then
    applies the fitted (s, t) to the full-resolution prediction.

    Args:
        gt_arr: Ground-truth depth, squeezable to [H, W].
        pred_arr: Predicted (affine-invariant) depth, same spatial shape.
        valid_mask_arr: Boolean mask of pixels used for the fit.
        return_scale_shift: If True, also return the fitted scale and shift.
        max_resolution: If set, inputs are nearest-neighbor downsampled so no
            spatial side exceeds this value before solving (speed-up only;
            alignment is still applied at full resolution).

    Returns:
        Aligned prediction with the same shape as `pred_arr`, and optionally
        (scale, shift).
    """
    ori_shape = pred_arr.shape  # input shape

    gt = gt_arr.squeeze()  # [H, W]
    pred = pred_arr.squeeze()
    valid_mask = valid_mask_arr.squeeze()

    # Downsample for efficiency before solving the least-squares system
    if max_resolution is not None:
        scale_factor = np.min(max_resolution / np.array(ori_shape[-2:]))
        if scale_factor < 1:
            downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
            # BUGFIX: inputs are made 4D ([1, 1, H, W]) so nn.Upsample performs
            # 2D (spatial) resizing; with the previous 3D input it only
            # rescaled the last dimension, leaving H uncapped.
            gt = downscaler(torch.as_tensor(gt)[None, None]).numpy()
            pred = downscaler(torch.as_tensor(pred)[None, None]).numpy()
            valid_mask = (
                downscaler(torch.as_tensor(valid_mask)[None, None].float())
                .bool()
                .numpy()
            )

    assert (
        gt.shape == pred.shape == valid_mask.shape
    ), f"{gt.shape}, {pred.shape}, {valid_mask.shape}"

    gt_masked = gt[valid_mask].reshape((-1, 1))
    pred_masked = pred[valid_mask].reshape((-1, 1))

    # numpy solver: A @ [scale, shift]^T ~= gt
    _ones = np.ones_like(pred_masked)
    A = np.concatenate([pred_masked, _ones], axis=-1)
    X = np.linalg.lstsq(A, gt_masked, rcond=None)[0]
    scale, shift = X

    aligned_pred = pred_arr * scale + shift

    # restore dimensions
    aligned_pred = aligned_pred.reshape(ori_shape)

    if return_scale_shift:
        return aligned_pred, scale, shift
    else:
        return aligned_pred


def depth2disparity(depth, return_mask=False):
    """Convert depth to disparity (1/depth); non-positive depths map to 0.

    Args:
        depth: torch.Tensor or np.ndarray of depth values.
        return_mask: If True, also return the boolean mask of positive depths.
    """
    if isinstance(depth, torch.Tensor):
        disparity = torch.zeros_like(depth)
    elif isinstance(depth, np.ndarray):
        disparity = np.zeros_like(depth)
    else:
        # Fail fast with a clear error instead of the NameError the unbound
        # local `disparity` would otherwise raise below.
        raise TypeError("depth should be np.ndarray or torch.Tensor")
    non_negative_mask = depth > 0  # fixed typo: non_negtive -> non_negative
    disparity[non_negative_mask] = 1.0 / depth[non_negative_mask]
    if return_mask:
        return disparity, non_negative_mask
    else:
        return disparity


def disparity2depth(disparity, **kwargs):
    """Inverse of `depth2disparity` (the reciprocal map is an involution)."""
    return depth2disparity(disparity, **kwargs)
def recursive_load_config(config_path: str) -> OmegaConf:
    """Load an OmegaConf YAML config, recursively merging its base configs.

    Configs listed under the ``base_config`` key are loaded first, in order,
    with later entries overriding earlier ones; the config file itself is
    merged last and therefore wins on any conflicting key.
    """
    conf = OmegaConf.load(config_path)
    merged = OmegaConf.create({})

    base_configs = conf.get("base_config", default_value=None)
    if base_configs is not None:
        assert isinstance(base_configs, omegaconf.listconfig.ListConfig)
        for base_path in base_configs:
            # Guard against a config directly including itself
            assert (
                base_path != config_path
            ), "Circulate merging, base_config should not include itself."
            merged = OmegaConf.merge(merged, recursive_load_config(base_path))

    # The current file's own values take precedence over all bases
    return OmegaConf.merge(merged, conf)


def find_value_in_omegaconf(search_key, config):
    """Collect every value stored under `search_key` anywhere in `config`.

    Walks nested DictConfig/ListConfig containers depth-first; a matching
    key's value is collected but not descended into.
    """
    found = []

    if isinstance(config, omegaconf.DictConfig):
        for key, value in config.items():
            if key == search_key:
                found.append(value)
            elif isinstance(value, (omegaconf.DictConfig, omegaconf.ListConfig)):
                found.extend(find_value_in_omegaconf(search_key, value))
    elif isinstance(config, omegaconf.ListConfig):
        for entry in config:
            if isinstance(entry, (omegaconf.DictConfig, omegaconf.ListConfig)):
                found.extend(find_value_in_omegaconf(search_key, entry))

    return found


if "__main__" == __name__:
    conf = recursive_load_config("config/train_base.yaml")
    print(OmegaConf.to_yaml(conf))
# Adapted from https://github.com/huggingface/accelerate/blob/e2ae254008061b3e53fc1c97f88d65743a857e75/src/accelerate/data_loader.py

# kwargs of the DataLoader in min version 1.4.0.
_PYTORCH_DATALOADER_KWARGS = {
    "batch_size": 1,
    "shuffle": False,
    "sampler": None,
    "batch_sampler": None,
    "num_workers": 0,
    "collate_fn": None,
    "pin_memory": False,
    "drop_last": False,
    "timeout": 0,
    "worker_init_fn": None,
    "multiprocessing_context": None,
    "generator": None,
    "prefetch_factor": 2,
    "persistent_workers": False,
}


class SkipBatchSampler(BatchSampler):
    """
    A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
    """

    def __init__(self, batch_sampler, skip_batches=0):
        # Intentionally does not call BatchSampler.__init__: this class only
        # wraps an already-validated sampler and filters its output.
        self.batch_sampler = batch_sampler
        self.skip_batches = skip_batches

    def __iter__(self):
        for index, samples in enumerate(self.batch_sampler):
            if index >= self.skip_batches:
                yield samples

    @property
    def total_length(self):
        # Length of the underlying sampler, i.e. without skipping
        return len(self.batch_sampler)

    def __len__(self):
        # BUGFIX: clamp at zero -- a negative length (skip_batches > total)
        # would make len() raise ValueError.
        return max(len(self.batch_sampler) - self.skip_batches, 0)


class SkipDataLoader(DataLoader):
    """
    Subclass of a PyTorch `DataLoader` that will skip the first batches.

    Args:
        dataset (`torch.utils.data.dataset.Dataset`):
            The dataset to use to build this dataloader.
        skip_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning.
        kwargs:
            All other keyword arguments to pass to the regular `DataLoader` initialization.
    """

    def __init__(self, dataset, skip_batches=0, **kwargs):
        super().__init__(dataset, **kwargs)
        self.skip_batches = skip_batches

    def __iter__(self):
        # NOTE: __len__ is inherited and still reports the unskipped length.
        for index, batch in enumerate(super().__iter__()):
            if index >= self.skip_batches:
                yield batch


def skip_first_batches(dataloader, num_batches=0):
    """
    Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.

    For map-style datasets the skipping happens in a wrapped batch sampler
    (no data is loaded for skipped batches); for `IterableDataset`s the
    batches must be produced and discarded via `SkipDataLoader`.
    """
    dataset = dataloader.dataset
    sampler_is_batch_sampler = False
    if isinstance(dataset, IterableDataset):
        new_batch_sampler = None
    else:
        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
        batch_sampler = (
            dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
        )
        new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)

    # We ignore all of those since they are all dealt with by our new_batch_sampler
    ignore_kwargs = [
        "batch_size",
        "shuffle",
        "sampler",
        "batch_sampler",
        "drop_last",
    ]

    # Carry over every other DataLoader option from the original loader
    kwargs = {
        k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
        for k in _PYTORCH_DATALOADER_KWARGS
        if k not in ignore_kwargs
    }

    # Need to provide batch_size as batch_sampler is None for Iterable dataset
    if new_batch_sampler is None:
        kwargs["drop_last"] = dataloader.drop_last
        kwargs["batch_size"] = dataloader.batch_size

    if new_batch_sampler is None:
        # Need to manually skip batches in the dataloader
        dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
    else:
        dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)

    return dataloader
def get_depth_normalizer(cfg_normalizer):
    """Factory: build the depth normalizer described by the config node.

    Returns an identity transform when `cfg_normalizer` is None, and a
    `ScaleShiftDepthNormalizer` for type "scale_shift_depth".
    """
    if cfg_normalizer is None:
        # No-op transform when normalization is not configured
        def identical(x):
            return x

        return identical

    if "scale_shift_depth" == cfg_normalizer.type:
        return ScaleShiftDepthNormalizer(
            norm_min=cfg_normalizer.norm_min,
            norm_max=cfg_normalizer.norm_max,
            min_max_quantile=cfg_normalizer.min_max_quantile,
            clip=cfg_normalizer.clip,
        )

    raise NotImplementedError


class DepthNormalizerBase:
    """Abstract interface for depth normalizers (not instantiable)."""

    # Whether normalized values carry absolute (metric) meaning
    is_absolute = None
    # Whether the far plane maps to norm_max
    far_plane_at_max = None

    def __init__(
        self,
        norm_min=-1.0,
        norm_max=1.0,
    ) -> None:
        self.norm_min = norm_min
        self.norm_max = norm_max
        # Base class is deliberately not instantiable
        raise NotImplementedError

    def __call__(self, depth, valid_mask=None, clip=None):
        raise NotImplementedError

    def denormalize(self, depth_norm, **kwargs):
        # For metric depth: convert prediction back to metric depth
        # For relative depth: convert prediction to [0, 1]
        raise NotImplementedError


class ScaleShiftDepthNormalizer(DepthNormalizerBase):
    """
    Use near and far plane to linearly normalize depth,
    i.e. d' = d * s + t,
    where near plane is mapped to `norm_min`, and far plane is mapped to `norm_max`
    Near and far planes are determined by taking quantile values.
    """

    is_absolute = False
    far_plane_at_max = True

    def __init__(
        self, norm_min=-1.0, norm_max=1.0, min_max_quantile=0.02, clip=True
    ) -> None:
        # Does not call super().__init__() -- the base constructor raises.
        self.norm_min = norm_min
        self.norm_max = norm_max
        self.norm_range = self.norm_max - self.norm_min
        self.min_quantile = min_max_quantile
        self.max_quantile = 1.0 - self.min_quantile
        self.clip = clip

    def __call__(self, depth_linear, valid_mask=None, clip=None):
        """Normalize a linear-depth tensor into [norm_min, norm_max]."""
        do_clip = self.clip if clip is None else clip

        # Only strictly positive depths inside the mask are considered valid
        mask = (
            torch.ones_like(depth_linear).bool() if valid_mask is None else valid_mask
        )
        mask = mask & (depth_linear > 0)

        # Near/far planes from quantiles of the valid depths
        near_plane, far_plane = torch.quantile(
            depth_linear[mask],
            torch.tensor([self.min_quantile, self.max_quantile]),
        )

        # Affine map: near_plane -> norm_min, far_plane -> norm_max
        depth_norm_linear = (depth_linear - near_plane) / (
            far_plane - near_plane
        ) * self.norm_range + self.norm_min

        if do_clip:
            depth_norm_linear = torch.clip(
                depth_norm_linear, self.norm_min, self.norm_max
            )

        return depth_norm_linear

    def scale_back(self, depth_norm):
        # scale to [0, 1]
        return (depth_norm - self.norm_min) / self.norm_range

    def denormalize(self, depth_norm, **kwargs):
        # Quantile normalization discards absolute scale; only a relative
        # [0, 1] map can be recovered without ground truth.
        logging.warning(f"{self.__class__} is not revertible without GT")
        return self.scale_back(depth_norm=depth_norm)
def img_hwc2chw(img: Union[np.ndarray, torch.Tensor]):
    """Reorder an image from [H, W, C] to [C, H, W]."""
    assert len(img.shape) == 3
    if isinstance(img, torch.Tensor):
        return img.permute(2, 0, 1)
    if isinstance(img, np.ndarray):
        return np.transpose(img, (2, 0, 1))
    raise TypeError("img should be np.ndarray or torch.Tensor")


def img_chw2hwc(chw):
    """Reorder an image from [C, H, W] back to [H, W, C]."""
    assert 3 == len(chw.shape)
    if isinstance(chw, np.ndarray):
        return np.moveaxis(chw, 0, -1)
    if isinstance(chw, torch.Tensor):
        return torch.permute(chw, (1, 2, 0))
    raise TypeError("img should be np.ndarray or torch.Tensor")


def img_int2float(img, dtype=None):
    """Map integer pixel values [0, 255] to floats in [0, 1].

    NOTE(review): assumes 8-bit content (divides by 255); 16-bit images
    would need /65535 -- confirm callers only pass uint8 data.
    """
    if dtype is not None:
        img = img.astype(dtype) if isinstance(img, np.ndarray) else img.to(dtype)
    return img / 255.0


def img_float2int(img):
    """Map float pixel values [0, 1] to 8-bit integers [0, 255]."""
    scaled = img * 255.0
    if isinstance(img, np.ndarray):
        return scaled.astype(np.uint8)
    return scaled.to(torch.uint8)


def img_normalize(img):
    """[0, 1] -> [-1, 1]."""
    return 2.0 * img - 1.0


def img_denormalize(img):
    """[-1, 1] -> [0, 1]."""
    return 0.5 * img + 0.5


def img_linear2srgb(img):
    """Approximate linear -> sRGB transfer (pure gamma 2.2)."""
    return img ** (1 / 2.2)


def img_srgb2linear(img):
    """Approximate sRGB -> linear transfer (pure gamma 2.2)."""
    return img**2.2


def write_img(img: np.ndarray, path):
    """Write a float image in [0, 1] to disk as 8-bit via OpenCV."""
    out = img_float2int(img)
    if len(out.shape) == 3:
        out = out[:, :, ::-1]  # RGB->BGR
    cv2.imwrite(path, out)


def _read_image_from_buffer(buffer: BytesIO, is_hdr: bool) -> np.ndarray:
    """Decode an in-memory image buffer into a float RGB array."""
    if is_hdr:
        # OpenCV decodes the HDR payload; values are clamped to [0, 1]
        raw = np.frombuffer(buffer.read(), dtype=np.uint8)
        img = cv2.imdecode(raw, cv2.IMREAD_UNCHANGED)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return np.clip(img, 0, 1)
    # LDR path: decode with PIL, then scale integer values to [0, 1]
    img = np.asarray(Image.open(buffer))  # [H, W, rgb]
    return img_int2float(img)


def is_hdr(path: str):
    """True when the path denotes an HDR (OpenEXR) image."""
    return path.endswith(".exr")


def read_img_from_tar(tar_file: tarfile.TarFile, rel_path: str) -> np.ndarray:
    """Read and decode an image stored inside an open tar archive."""
    member = tar_file.extractfile(rel_path)
    return _read_image_from_buffer(BytesIO(member.read()), is_hdr(rel_path))


def read_img_from_file(path: str) -> np.ndarray:
    """Read and decode an image from the filesystem."""
    with open(path, "rb") as f:
        payload = BytesIO(f.read())
    return _read_image_from_buffer(payload, is_hdr(path))
# --------------------------------------------------------------------------
# More information about Marigold:
# https://marigoldmonodepth.github.io
# https://marigoldcomputervision.github.io
# Efficient inference pipelines are now part of diffusers:
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
# Examples of trained models and live demos:
# https://huggingface.co/prs-eth
# Related projects:
# https://rollingdepth.github.io/
# https://marigolddepthcompletion.github.io/
# Citation (BibTeX):
# https://github.com/prs-eth/Marigold#-citation
# If you find Marigold useful, we kindly ask you to cite our papers.
# --------------------------------------------------------------------------

import logging
import os
import sys
import wandb
from tabulate import tabulate
from torch.utils.tensorboard import SummaryWriter


def config_logging(cfg_logging, out_dir=None):
    """Configure the root logger with a console handler and optional file handler.

    Args:
        cfg_logging: Mapping with required "format" and optional "file_level",
            "console_level" (numeric logging levels; default 10 == DEBUG)
            and "filename" (default "logging.log").
        out_dir: Directory for the log file; if None, no file handler is added.
    """
    file_level = cfg_logging.get("file_level", 10)  # 10 == logging.DEBUG
    console_level = cfg_logging.get("console_level", 10)

    log_formatter = logging.Formatter(cfg_logging["format"])

    root_logger = logging.getLogger()
    # Drop any handlers installed earlier (e.g. by basicConfig) to avoid duplicates
    root_logger.handlers.clear()

    # Root must be at least as verbose as the most verbose handler
    root_logger.setLevel(min(file_level, console_level))

    if out_dir is not None:
        _logging_file = os.path.join(
            out_dir, cfg_logging.get("filename", "logging.log")
        )
        file_handler = logging.FileHandler(_logging_file)
        file_handler.setFormatter(log_formatter)
        file_handler.setLevel(file_level)
        root_logger.addHandler(file_handler)

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(log_formatter)
    console_handler.setLevel(console_level)
    root_logger.addHandler(console_handler)

    # Avoid pollution by packages
    logging.getLogger("PIL").setLevel(logging.INFO)
    logging.getLogger("matplotlib").setLevel(logging.INFO)


class MyTrainingLogger:
    """Tensorboard + wandb logger"""

    # Set once by set_dir(); wandb mirrors it via sync_tensorboard in init_wandb()
    writer: SummaryWriter
    # Guards against double initialization of the shared module-level instance
    is_initialized = False

    def __init__(self) -> None:
        pass

    def set_dir(self, tb_log_dir):
        """One-shot initialization of the tensorboard SummaryWriter."""
        if self.is_initialized:
            raise ValueError("Do not initialize writer twice")
        self.writer = SummaryWriter(tb_log_dir)
        self.is_initialized = True

    def log_dict(self, scalar_dict, global_step, walltime=None):
        """Write every (name, scalar) pair at the given global step."""
        for k, v in scalar_dict.items():
            self.writer.add_scalar(k, v, global_step=global_step, walltime=walltime)
        return


# global instance
tb_logger = MyTrainingLogger()


# -------------- wandb tools --------------
def init_wandb(enable: bool, **kwargs):
    """Start a wandb run; when disabled, a no-op run (mode="disabled") is returned."""
    if enable:
        run = wandb.init(sync_tensorboard=True, **kwargs)
    else:
        run = wandb.init(mode="disabled")
    return run


def log_slurm_job_id(step):
    """Record the SLURM job id (-1 outside SLURM) as a tensorboard scalar."""
    global tb_logger  # NOTE: not needed for read-only access; kept as-is
    _jobid = os.getenv("SLURM_JOB_ID")
    if _jobid is None:
        _jobid = -1
    tb_logger.writer.add_scalar("job_id", int(_jobid), global_step=step)
    logging.debug(f"Slurm job_id: {_jobid}")


def load_wandb_job_id(out_dir):
    """Read the persisted wandb run id from `out_dir` (used to resume a run)."""
    with open(os.path.join(out_dir, "WANDB_ID"), "r") as f:
        wandb_id = f.read()
    return wandb_id


def save_wandb_job_id(run, out_dir):
    """Persist the wandb run id into `out_dir` so the run can be resumed later."""
    with open(os.path.join(out_dir, "WANDB_ID"), "w+") as f:
        f.write(run.id)


def eval_dict_to_text(val_metrics: dict, dataset_name: str, sample_list_path: str):
    """Format evaluation metrics as human-readable text (header + tabulate table)."""
    eval_text = f"Evaluation metrics:\n\
    on dataset: {dataset_name}\n\
    over samples in: {sample_list_path}\n"

    eval_text += tabulate([val_metrics.keys(), val_metrics.values()])
    return eval_text
def get_loss(loss_name, **kwargs):
    """Factory: map a config loss name to a loss callable.

    Raises:
        NotImplementedError: for an unknown `loss_name`.
    """
    if "silog_mse" == loss_name:
        criterion = SILogMSELoss(**kwargs)
    elif "silog_rmse" == loss_name:
        criterion = SILogRMSELoss(**kwargs)
    elif "mse_loss" == loss_name:
        criterion = torch.nn.MSELoss(**kwargs)
    elif "l1_loss" == loss_name:
        criterion = torch.nn.L1Loss(**kwargs)
    elif "l1_loss_with_mask" == loss_name:
        criterion = L1LossWithMask(**kwargs)
    elif "mean_abs_rel" == loss_name:
        criterion = MeanAbsRelLoss()
    else:
        raise NotImplementedError

    return criterion


class L1LossWithMask:
    """Masked L1 loss, normalized per image by the number of valid pixels.

    Args:
        batch_reduction (bool): if True, return the mean over the batch;
            otherwise a per-sample loss tensor.
    """

    def __init__(self, batch_reduction=False):
        self.batch_reduction = batch_reduction

    def __call__(self, depth_pred, depth_gt, valid_mask=None):
        diff = depth_pred - depth_gt
        if valid_mask is not None:
            diff[~valid_mask] = 0
            n = valid_mask.sum((-1, -2))
        else:
            n = depth_gt.shape[-2] * depth_gt.shape[-1]

        # BUGFIX: sum over the spatial dims only (was a global sum over all
        # dims), so each sample is normalized by its own valid-pixel count --
        # consistent with the SILog losses below.
        loss = torch.sum(torch.abs(diff), (-1, -2)) / n
        if self.batch_reduction:
            loss = loss.mean()
        return loss


class MeanAbsRelLoss:
    """Mean absolute relative error: |pred - gt| / gt, averaged over dim 0."""

    def __init__(self) -> None:
        pass

    def __call__(self, pred, gt):
        diff = pred - gt
        rel_abs = torch.abs(diff / gt)
        # NOTE(review): dim=0 averages over the leading (batch) dimension
        # only, leaving a per-pixel map for >1D inputs -- confirm intended.
        loss = torch.mean(rel_abs, dim=0)
        return loss


class SILogMSELoss:
    def __init__(self, lamb, log_pred=True, batch_reduction=True):
        """Scale Invariant Log MSE Loss

        Args:
            lamb (_type_): lambda, lambda=1 -> scale invariant, lambda=0 -> L2 loss
            log_pred (bool, optional): True if model prediction is logarithmic depth. Will not do log for depth_pred
            batch_reduction (bool, optional): If True, return the batch mean.
        """
        super(SILogMSELoss, self).__init__()
        self.lamb = lamb
        self.pred_in_log = log_pred
        self.batch_reduction = batch_reduction

    def __call__(self, depth_pred, depth_gt, valid_mask=None):
        # Clamp away zeros before the log to avoid -inf
        log_depth_pred = (
            depth_pred if self.pred_in_log else torch.log(torch.clip(depth_pred, 1e-8))
        )
        log_depth_gt = torch.log(depth_gt)

        diff = log_depth_pred - log_depth_gt
        if valid_mask is not None:
            diff[~valid_mask] = 0
            n = valid_mask.sum((-1, -2))
        else:
            n = depth_gt.shape[-2] * depth_gt.shape[-1]

        diff2 = torch.pow(diff, 2)

        # E[d^2] - lambda * (E[d])^2, per sample
        first_term = torch.sum(diff2, (-1, -2)) / n
        second_term = self.lamb * torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2)
        loss = first_term - second_term
        if self.batch_reduction:
            loss = loss.mean()
        return loss


class SILogRMSELoss:
    def __init__(self, lamb, alpha, log_pred=True):
        """Scale Invariant Log RMSE Loss

        Args:
            lamb (_type_): lambda, lambda=1 -> scale invariant, lambda=0 -> L2 loss
            alpha: scaling factor applied to the final loss value
            log_pred (bool, optional): True if model prediction is logarithmic depth. Will not do log for depth_pred
        """
        super(SILogRMSELoss, self).__init__()
        self.lamb = lamb
        self.alpha = alpha
        self.pred_in_log = log_pred

    def __call__(self, depth_pred, depth_gt, valid_mask):
        log_depth_pred = depth_pred if self.pred_in_log else torch.log(depth_pred)
        log_depth_gt = torch.log(depth_gt)
        # borrowed from https://github.com/aliyun/NeWCRFs
        # diff = log_depth_pred[valid_mask] - log_depth_gt[valid_mask]
        # return torch.sqrt((diff ** 2).mean() - self.lamb * (diff.mean() ** 2)) * self.alpha

        diff = log_depth_pred - log_depth_gt
        if valid_mask is not None:
            diff[~valid_mask] = 0
            n = valid_mask.sum((-1, -2))
        else:
            n = depth_gt.shape[-2] * depth_gt.shape[-1]

        diff2 = torch.pow(diff, 2)
        first_term = torch.sum(diff2, (-1, -2)) / n
        second_term = self.lamb * torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2)
        loss = torch.sqrt(first_term - second_term).mean() * self.alpha
        return loss
class IterExponential:
    """Iteration-wise LR multiplier: linear warmup, then exponential decay.

    The multiplier is recomputed from scratch at every step (no incremental
    update), which avoids floating-point error accumulation.
    """

    def __init__(self, total_iter_length, final_ratio, warmup_steps=0) -> None:
        """
        Args:
            total_iter_length (int): Expected total iteration number
            final_ratio (float): Expected LR ratio at n_iter = total_iter_length
            warmup_steps (int): Steps of linear warmup from 0 to 1.
        """
        self.total_length = total_iter_length
        self.effective_length = total_iter_length - warmup_steps
        self.final_ratio = final_ratio
        self.warmup_steps = warmup_steps

    def __call__(self, n_iter) -> float:
        # Linear warmup: 0 -> 1 over warmup_steps
        if n_iter < self.warmup_steps:
            return 1.0 * n_iter / self.warmup_steps
        # After the schedule ends, hold at the final ratio
        if n_iter >= self.total_length:
            return self.final_ratio
        # Exponential decay: 1 -> final_ratio over the effective length
        progress = (n_iter - self.warmup_steps) / self.effective_length
        return np.exp(progress * np.log(self.final_ratio))


if "__main__" == __name__:
    lr_scheduler = IterExponential(
        total_iter_length=50000, final_ratio=0.01, warmup_steps=200
    )

    steps = np.arange(100000)
    alphas = [lr_scheduler(step) for step in steps]
    import matplotlib.pyplot as plt

    plt.plot(alphas)
    plt.savefig("lr_scheduler.png")
+# -------------------------------------------------------------------------- + +import numpy as np +import pandas as pd +import torch + + +# Adapted from: https://github.com/victoresque/pytorch-template/blob/master/utils/util.py +class MetricTracker: + def __init__(self, *keys, writer=None): + self.writer = writer + self._data = pd.DataFrame(index=keys, columns=["total", "counts", "average"]) + self.reset() + + def reset(self): + for col in self._data.columns: + self._data[col].values[:] = 0 + + def update(self, key, value, n=1): + if self.writer is not None: + self.writer.add_scalar(key, value) + self._data.loc[key, "total"] += value * n + self._data.loc[key, "counts"] += n + self._data.loc[key, "average"] = self._data.total[key] / self._data.counts[key] + + def avg(self, key): + return self._data.average[key] + + def result(self): + return dict(self._data.average) + + +# -------------------- Depth Metrics -------------------- + + +def abs_relative_difference(output, target, valid_mask=None): + actual_output = output + actual_target = target + abs_relative_diff = torch.abs(actual_output - actual_target) / actual_target + if valid_mask is not None: + abs_relative_diff[~valid_mask] = 0 + n = valid_mask.sum((-1, -2)) + else: + n = output.shape[-1] * output.shape[-2] + abs_relative_diff = torch.sum(abs_relative_diff, (-1, -2)) / n + return abs_relative_diff.mean() + + +def squared_relative_difference(output, target, valid_mask=None): + actual_output = output + actual_target = target + square_relative_diff = ( + torch.pow(torch.abs(actual_output - actual_target), 2) / actual_target + ) + if valid_mask is not None: + square_relative_diff[~valid_mask] = 0 + n = valid_mask.sum((-1, -2)) + else: + n = output.shape[-1] * output.shape[-2] + square_relative_diff = torch.sum(square_relative_diff, (-1, -2)) / n + return square_relative_diff.mean() + + +def rmse_linear(output, target, valid_mask=None): + actual_output = output + actual_target = target + diff = actual_output - 
actual_target + if valid_mask is not None: + diff[~valid_mask] = 0 + n = valid_mask.sum((-1, -2)) + else: + n = output.shape[-1] * output.shape[-2] + diff2 = torch.pow(diff, 2) + mse = torch.sum(diff2, (-1, -2)) / n + rmse = torch.sqrt(mse) + return rmse.mean() + + +def rmse_log(output, target, valid_mask=None): + diff = torch.log(output) - torch.log(target) + if valid_mask is not None: + diff[~valid_mask] = 0 + n = valid_mask.sum((-1, -2)) + else: + n = output.shape[-1] * output.shape[-2] + diff2 = torch.pow(diff, 2) + mse = torch.sum(diff2, (-1, -2)) / n # [B] + rmse = torch.sqrt(mse) + return rmse.mean() + + +def log10(output, target, valid_mask=None): + if valid_mask is not None: + diff = torch.abs( + torch.log10(output[valid_mask]) - torch.log10(target[valid_mask]) + ) + else: + diff = torch.abs(torch.log10(output) - torch.log10(target)) + return diff.mean() + + +# Adapted from: https://github.com/imran3180/depth-map-prediction/blob/master/main.py +def threshold_percentage(output, target, threshold_val, valid_mask=None): + d1 = output / target + d2 = target / output + max_d1_d2 = torch.max(d1, d2) + zero = torch.zeros(*output.shape) + one = torch.ones(*output.shape) + bit_mat = torch.where(max_d1_d2.cpu() < threshold_val, one, zero) + if valid_mask is not None: + bit_mat[~valid_mask] = 0 + n = valid_mask.sum((-1, -2)) + else: + n = output.shape[-1] * output.shape[-2] + count_mat = torch.sum(bit_mat, (-1, -2)) + threshold_mat = count_mat / n.cpu() + return threshold_mat.mean() + + +def delta1_acc(pred, gt, valid_mask): + return threshold_percentage(pred, gt, 1.25, valid_mask) + + +def delta2_acc(pred, gt, valid_mask): + return threshold_percentage(pred, gt, 1.25**2, valid_mask) + + +def delta3_acc(pred, gt, valid_mask): + return threshold_percentage(pred, gt, 1.25**3, valid_mask) + + +def i_rmse(output, target, valid_mask=None): + output_inv = 1.0 / output + target_inv = 1.0 / target + diff = output_inv - target_inv + if valid_mask is not None: + 
diff[~valid_mask] = 0 + n = valid_mask.sum((-1, -2)) + else: + n = output.shape[-1] * output.shape[-2] + diff2 = torch.pow(diff, 2) + mse = torch.sum(diff2, (-1, -2)) / n # [B] + rmse = torch.sqrt(mse) + return rmse.mean() + + +def silog_rmse(depth_pred, depth_gt, valid_mask=None): + diff = torch.log(depth_pred) - torch.log(depth_gt) + if valid_mask is not None: + diff[~valid_mask] = 0 + n = valid_mask.sum((-1, -2)) + else: + n = depth_gt.shape[-2] * depth_gt.shape[-1] + + diff2 = torch.pow(diff, 2) + + first_term = torch.sum(diff2, (-1, -2)) / n + second_term = torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2) + loss = torch.sqrt(torch.mean(first_term - second_term)) * 100 + return loss + + +# -------------------- Normals Metrics -------------------- + + +def compute_cosine_error(pred_norm, gt_norm, masked=False): + if len(pred_norm.shape) == 4: + pred_norm = pred_norm.squeeze(0) + if len(gt_norm.shape) == 4: + gt_norm = gt_norm.squeeze(0) + + # shape must be [3,H,W] + assert (gt_norm.shape[0] == 3) and ( + pred_norm.shape[0] == 3 + ), "Channel dim should be the first dimension!" 
+ # mask out the zero vectors, otherwise torch.cosine_similarity computes 90° as error + if masked: + ch, h, w = gt_norm.shape + + mask = torch.norm(gt_norm, dim=0) > 0 + + pred_norm = pred_norm[:, mask.view(h, w)] + gt_norm = gt_norm[:, mask.view(h, w)] + + pred_error = torch.cosine_similarity(pred_norm, gt_norm, dim=0) + pred_error = torch.clamp(pred_error, min=-1.0, max=1.0) + pred_error = torch.acos(pred_error) * 180.0 / np.pi # (H, W) + + return ( + pred_error.view(-1).detach().cpu().numpy() + ) # flatten so can directly input to compute_normal_metrics() + + +def mean_angular_error(cosine_error): + return round(np.average(cosine_error), 4) + + +def median_angular_error(cosine_error): + return round(np.median(cosine_error), 4) + + +def rmse_angular_error(cosine_error): + num_pixels = cosine_error.shape[0] + return round(np.sqrt(np.sum(cosine_error * cosine_error) / num_pixels), 4) + + +def sub5_error(cosine_error): + num_pixels = cosine_error.shape[0] + return round(100.0 * (np.sum(cosine_error < 5) / num_pixels), 4) + + +def sub7_5_error(cosine_error): + num_pixels = cosine_error.shape[0] + return round(100.0 * (np.sum(cosine_error < 7.5) / num_pixels), 4) + + +def sub11_25_error(cosine_error): + num_pixels = cosine_error.shape[0] + return round(100.0 * (np.sum(cosine_error < 11.25) / num_pixels), 4) + + +def sub22_5_error(cosine_error): + num_pixels = cosine_error.shape[0] + return round(100.0 * (np.sum(cosine_error < 22.5) / num_pixels), 4) + + +def sub30_error(cosine_error): + num_pixels = cosine_error.shape[0] + return round(100.0 * (np.sum(cosine_error < 30) / num_pixels), 4) + + +# -------------------- IID Metrics -------------------- + + +def compute_iid_metric(pred, gt, target_name, metric_name, metric, valid_mask=None): + # Shading and residual are up-to-scale. 
We first scale-align them to the gt + # and map them to the range [0,1] for metric computation + if target_name == "shading" or target_name == "residual": + alignment_scale = compute_alignment_scale(pred, gt, valid_mask) + pred = alignment_scale * pred + # map to [0,1] + pred, gt = quantile_map(pred, gt, valid_mask) + + if len(pred.shape) == 3: + pred = pred.unsqueeze(0) + if len(gt.shape) == 3: + gt = gt.unsqueeze(0) + if valid_mask is not None: + if len(valid_mask.shape) == 3: + valid_mask = valid_mask.unsqueeze(0) + if metric_name == "psnr": + return metric(pred[valid_mask], gt[valid_mask]).item() + # for SSIM and LPIPs set the invalid pixels to zero + else: + invalid_mask = ~valid_mask + pred[invalid_mask] = 0 + gt[invalid_mask] = 0 + + return metric(pred, gt).item() + + +# compute least-squares alignment scale to align shading/residual prediction to gt +def compute_alignment_scale(pred, gt, valid_mask=None): + pred = pred.squeeze() + gt = gt.squeeze() + assert pred.shape[0] == 3 and gt.shape[0] == 3, "First dim should be channel dim" + + if valid_mask is not None: + valid_mask = valid_mask.squeeze() + pred = pred[valid_mask] + gt = gt[valid_mask] + + A_flattened = pred.view(-1, 1) + b_flattened = gt.view(-1, 1) + # Solve the least squares problem + x, residuals, rank, s = torch.linalg.lstsq(A_flattened.float(), b_flattened.float()) + return x + + +def quantile_map(pred, gt, valid_mask=None): + pred = pred.squeeze() + gt = gt.squeeze() + assert gt.shape[0] == 3, "channel dim must be first dim" + + percentile = 90 + brightness_nth_percentile_desired = 0.8 + brightness = 0.3 * gt[0, :, :] + 0.59 * gt[1, :, :] + 0.11 * gt[2, :, :] + + if valid_mask is not None: + valid_mask = valid_mask.squeeze() + brightness = brightness[valid_mask[0]] + else: + brightness = brightness.flatten() + + eps = 0.0001 + + brightness_nth_percentile_current = torch.quantile(brightness, percentile / 100.0) + + if brightness_nth_percentile_current < eps: + scale = 0 + else: + scale = 
float( + brightness_nth_percentile_desired / brightness_nth_percentile_current + ) + + # Apply scaling to ground truth and prediction + gt_mapped = torch.clamp(scale * gt, 0, 1).unsqueeze(0) # [1,3,H,W] + pred_mapped = torch.clamp(scale * pred, 0, 1).unsqueeze(0) # [1,3,H,W] + + return pred_mapped, gt_mapped diff --git a/src/util/multi_res_noise.py b/src/util/multi_res_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..a92d14c6cb0669b464f92af04560a963ea327af3 --- /dev/null +++ b/src/util/multi_res_noise.py @@ -0,0 +1,103 @@ +# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information about Marigold: +# https://marigoldmonodepth.github.io +# https://marigoldcomputervision.github.io +# Efficient inference pipelines are now part of diffusers: +# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage +# https://huggingface.co/docs/diffusers/api/pipelines/marigold +# Examples of trained models and live demos: +# https://huggingface.co/prs-eth +# Related projects: +# https://rollingdepth.github.io/ +# https://marigolddepthcompletion.github.io/ +# Citation (BibTeX): +# https://github.com/prs-eth/Marigold#-citation +# If you find Marigold useful, we kindly ask you to cite our papers. 
# Adapted from: https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2?s=31

import logging
import math
import random

import numpy as np
import torch


def multi_res_noise_like(
    x, strength=0.9, downscale_strategy="original", generator=None, device=None
):
    """
    Draw multi-resolution (pyramid) noise shaped like `x` ([B,C,W,H]).

    Gaussian noise is sampled at progressively coarser resolutions, upsampled
    back to the input size, and accumulated with weights `strength**level`;
    the accumulated noise is finally rescaled to roughly unit variance.
    """
    if torch.is_tensor(strength):
        # Allow a per-sample strength tensor, broadcastable over [B,C,W,H].
        strength = strength.reshape((-1, 1, 1, 1))
    b, c, w, h = x.shape

    if device is None:
        device = x.device

    up_sampler = torch.nn.Upsample(size=(w, h), mode="bilinear")
    noise = torch.randn(x.shape, device=x.device, generator=generator)

    def _coarse_noise(cur_w, cur_h):
        # One low-resolution Gaussian draw, upsampled to the input size.
        sample = torch.randn(b, c, cur_w, cur_h, generator=generator, device=device)
        return up_sampler(sample.to(x))

    if downscale_strategy == "original":
        for level in range(10):
            ratio = (
                torch.rand(1, generator=generator, device=device) * 2 + 2
            )  # Rather than always going 2x,
            w, h = max(1, int(w / (ratio**level))), max(1, int(h / (ratio**level)))
            noise += _coarse_noise(w, h) * strength**level
            if min(w, h) == 1:
                break  # Lowest resolution is 1x1
    elif downscale_strategy == "every_layer":
        for level in range(int(math.log2(min(w, h)))):
            w, h = max(1, int(w / 2)), max(1, int(h / 2))
            noise += _coarse_noise(w, h) * strength**level
    elif downscale_strategy == "power_of_two":
        for level in range(10):
            w, h = max(1, int(w / (2**level))), max(1, int(h / (2**level)))
            noise += _coarse_noise(w, h) * strength**level
            if min(w, h) == 1:
                break  # Lowest resolution is 1x1
    elif downscale_strategy == "random_step":
        for level in range(10):
            ratio = (
                torch.rand(1, generator=generator, device=device) * 2 + 2
            )  # Rather than always going 2x,
            w, h = max(1, int(w / ratio)), max(1, int(h / ratio))
            noise += _coarse_noise(w, h) * strength**level
            if min(w, h) == 1:
                break  # Lowest resolution is 1x1
    else:
        raise ValueError(f"unknown downscale strategy: {downscale_strategy}")

    # Scale back to roughly unit variance.
    return noise / noise.std()


def seed_all(seed: int = 0):
    """
    Seed every RNG in use (Python, NumPy, PyTorch CPU and all CUDA devices).
    """
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)


def generate_seed_sequence(
    initial_seed: int,
    length: int,
    min_val=-0x8000_0000_0000_0000,
    max_val=0xFFFF_FFFF_FFFF_FFFF,
):
    """Deterministically derive `length` integer seeds from `initial_seed`."""
    if initial_seed is None:
        logging.warning("initial_seed is None, reproducibility is not guaranteed")
    random.seed(initial_seed)

    return [random.randint(min_val, max_val) for _ in range(length)]
+# -------------------------------------------------------------------------- +# More information about Marigold: +# https://marigoldmonodepth.github.io +# https://marigoldcomputervision.github.io +# Efficient inference pipelines are now part of diffusers: +# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage +# https://huggingface.co/docs/diffusers/api/pipelines/marigold +# Examples of trained models and live demos: +# https://huggingface.co/prs-eth +# Related projects: +# https://rollingdepth.github.io/ +# https://marigolddepthcompletion.github.io/ +# Citation (BibTeX): +# https://github.com/prs-eth/Marigold#-citation +# If you find Marigold useful, we kindly ask you to cite our papers. +# -------------------------------------------------------------------------- + +import os + + +def is_on_slurm(): + cluster_name = os.getenv("SLURM_CLUSTER_NAME") + is_on_slurm = cluster_name is not None + return is_on_slurm + + +def get_local_scratch_dir(): + local_scratch_dir = os.getenv("TMPDIR") + return local_scratch_dir