degbo commited on
Commit
7decfe1
·
1 Parent(s): 19a564f
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. config/dataset_depth/data_diode_all.yaml +4 -0
  2. config/dataset_depth/data_eth3d.yaml +4 -0
  3. config/dataset_depth/data_hypersim_train.yaml +4 -0
  4. config/dataset_depth/data_hypersim_val.yaml +4 -0
  5. config/dataset_depth/data_kitti_eigen_test.yaml +6 -0
  6. config/dataset_depth/data_kitti_val.yaml +6 -0
  7. config/dataset_depth/data_nyu_test.yaml +5 -0
  8. config/dataset_depth/data_nyu_train.yaml +5 -0
  9. config/dataset_depth/data_scannet_val.yaml +4 -0
  10. config/dataset_depth/data_vkitti_train.yaml +6 -0
  11. config/dataset_depth/data_vkitti_val.yaml +6 -0
  12. config/dataset_depth/dataset_train.yaml +18 -0
  13. config/dataset_depth/dataset_val.yaml +45 -0
  14. config/dataset_depth/dataset_vis.yaml +9 -0
  15. config/dataset_iid/data_appearance_interiorverse_test.yaml +4 -0
  16. config/dataset_iid/data_appearance_synthetic_test.yaml +4 -0
  17. config/dataset_iid/data_art_test.yaml +4 -0
  18. config/dataset_iid/data_lighting_hypersim_test.yaml +4 -0
  19. config/dataset_iid/dataset_appearance_train.yaml +9 -0
  20. config/dataset_iid/dataset_appearance_val.yaml +6 -0
  21. config/dataset_iid/dataset_appearance_vis.yaml +6 -0
  22. config/dataset_iid/dataset_lighting_train.yaml +12 -0
  23. config/dataset_iid/dataset_lighting_val.yaml +6 -0
  24. config/dataset_iid/dataset_lighting_vis.yaml +6 -0
  25. config/dataset_iid/osu_data_appearance_interiorverse_test.yaml +4 -0
  26. config/dataset_normals/data_diode_test.yaml +4 -0
  27. config/dataset_normals/data_ibims_test.yaml +4 -0
  28. config/dataset_normals/data_nyu_test.yaml +4 -0
  29. config/dataset_normals/data_oasis_test.yaml +4 -0
  30. config/dataset_normals/data_scannet_test.yaml +4 -0
  31. config/dataset_normals/dataset_train.yaml +25 -0
  32. config/dataset_normals/dataset_val.yaml +7 -0
  33. config/dataset_normals/dataset_vis.yaml +7 -0
  34. config/logging.yaml +5 -0
  35. config/model_sdv2.yaml +4 -0
  36. config/train_debug_depth.yaml +10 -0
  37. config/train_debug_iid.yaml +11 -0
  38. config/train_debug_normals.yaml +10 -0
  39. config/train_marigold_depth.yaml +94 -0
  40. config/train_marigold_iid_appearance.yaml +81 -0
  41. config/train_marigold_iid_appearance_finetuned.yaml +81 -0
  42. config/train_marigold_iid_lighting.yaml +82 -0
  43. config/train_marigold_normals.yaml +86 -0
  44. config/wandb.yaml +3 -0
  45. marigold/marigold_depth_pipeline.py +516 -0
  46. marigold/marigold_normals_pipeline.py +479 -0
  47. marigold/util/batchsize.py +90 -0
  48. marigold/util/ensemble.py +270 -0
  49. requirements++.txt +7 -0
  50. requirements+.txt +4 -0
config/dataset_depth/data_diode_all.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: diode_depth
2
+ disp_name: diode_depth_val_all
3
+ dir: diode/diode_val.tar
4
+ filenames: data_split/diode_depth/diode_val_all_filename_list.txt
config/dataset_depth/data_eth3d.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: eth3d_depth
2
+ disp_name: eth3d_depth_full
3
+ dir: eth3d/eth3d.tar
4
+ filenames: data_split/eth3d_depth/eth3d_filename_list.txt
config/dataset_depth/data_hypersim_train.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: hypersim_depth
2
+ disp_name: hypersim_depth_train
3
+ dir: hypersim/hypersim_processed_train.tar
4
+ filenames: data_split/hypersim_depth/filename_list_train_filtered.txt
config/dataset_depth/data_hypersim_val.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: hypersim_depth
2
+ disp_name: hypersim_depth_val
3
+ dir: hypersim/hypersim_processed_val.tar
4
+ filenames: data_split/hypersim_depth/filename_list_val_filtered.txt
config/dataset_depth/data_kitti_eigen_test.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ name: kitti_depth
2
+ disp_name: kitti_depth_eigen_test_full
3
+ dir: kitti/kitti_eigen_split_test.tar
4
+ filenames: data_split/kitti_depth/eigen_test_files_with_gt.txt
5
+ kitti_bm_crop: true
6
+ valid_mask_crop: eigen
config/dataset_depth/data_kitti_val.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ name: kitti_depth
2
+ disp_name: kitti_depth_val800_from_eigen_train
3
+ dir: kitti/kitti_sampled_val_800.tar
4
+ filenames: data_split/kitti_depth/eigen_val_from_train_800.txt
5
+ kitti_bm_crop: true
6
+ valid_mask_crop: eigen
config/dataset_depth/data_nyu_test.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ name: nyu_depth
2
+ disp_name: nyu_depth_test_full
3
+ dir: nyuv2/nyu_labeled_extracted.tar
4
+ filenames: data_split/nyu_depth/labeled/filename_list_test.txt
5
+ eigen_valid_mask: true
config/dataset_depth/data_nyu_train.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ name: nyu_depth
2
+ disp_name: nyu_depth_train_full
3
+ dir: nyuv2/nyu_labeled_extracted.tar
4
+ filenames: data_split/nyu_depth/labeled/filename_list_train.txt
5
+ eigen_valid_mask: true
config/dataset_depth/data_scannet_val.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: scannet_depth
2
+ disp_name: scannet_depth_val_800_1
3
+ dir: scannet/scannet_val_sampled_800_1.tar
4
+ filenames: data_split/scannet_depth/scannet_val_sampled_list_800_1.txt
config/dataset_depth/data_vkitti_train.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ name: vkitti_depth
2
+ disp_name: vkitti_depth_train
3
+ dir: vkitti/vkitti.tar
4
+ filenames: data_split/vkitti_depth/vkitti_train.txt
5
+ kitti_bm_crop: true
6
+ valid_mask_crop: null # no valid_mask_crop for training
config/dataset_depth/data_vkitti_val.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ name: vkitti_depth
2
+ disp_name: vkitti_depth_val
3
+ dir: vkitti/vkitti.tar
4
+ filenames: data_split/vkitti_depth/vkitti_val.txt
5
+ kitti_bm_crop: true
6
+ valid_mask_crop: eigen
config/dataset_depth/dataset_train.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ train:
3
+ name: mixed
4
+ prob_ls: [0.9, 0.1]
5
+ dataset_list:
6
+ - name: hypersim_depth
7
+ disp_name: hypersim_depth_train
8
+ dir: hypersim/hypersim_processed_train.tar
9
+ filenames: data_split/hypersim_depth/filename_list_train_filtered.txt
10
+ resize_to_hw:
11
+ - 480
12
+ - 640
13
+ - name: vkitti_depth
14
+ disp_name: vkitti_depth_train
15
+ dir: vkitti/vkitti.tar
16
+ filenames: data_split/vkitti_depth/vkitti_train.txt
17
+ kitti_bm_crop: true
18
+ valid_mask_crop: null
config/dataset_depth/dataset_val.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ val:
3
+ # - name: hypersim_depth
4
+ # disp_name: hypersim_depth_val
5
+ # dir: hypersim/hypersim_processed_val.tar
6
+ # filenames: data_split/hypersim_depth/filename_list_val_filtered.txt
7
+ # resize_to_hw:
8
+ # - 480
9
+ # - 640
10
+
11
+ # - name: nyu_depth
12
+ # disp_name: nyu_depth_train_full
13
+ # dir: nyuv2/nyu_labeled_extracted.tar
14
+ # filenames: data_split/nyu_depth/labeled/filename_list_train.txt
15
+ # eigen_valid_mask: true
16
+
17
+ # - name: kitti_depth
18
+ # disp_name: kitti_depth_val800_from_eigen_train
19
+ # dir: kitti/kitti_depth_sampled_val_800.tar
20
+ # filenames: data_split/kitti_depth/eigen_val_from_train_800.txt
21
+ # kitti_bm_crop: true
22
+ # valid_mask_crop: eigen
23
+
24
+ # Smaller subsets for faster validation during training
25
+ # The first dataset is used to calculate main eval metric.
26
+ - name: hypersim_depth
27
+ disp_name: hypersim_depth_val_small_80
28
+ dir: hypersim/hypersim_processed_val.tar
29
+ filenames: data_split/hypersim_depth/filename_list_val_filtered_small_80.txt
30
+ resize_to_hw:
31
+ - 480
32
+ - 640
33
+
34
+ - name: nyu_depth
35
+ disp_name: nyu_depth_train_small_100
36
+ dir: nyuv2/nyu_labeled_extracted.tar
37
+ filenames: data_split/nyu_depth/labeled/filename_list_train_small_100.txt
38
+ eigen_valid_mask: true
39
+
40
+ - name: kitti_depth
41
+ disp_name: kitti_depth_val_from_train_sub_100
42
+ dir: kitti/kitti_sampled_val_800.tar
43
+ filenames: data_split/kitti_depth/eigen_val_from_train_sub_100.txt
44
+ kitti_bm_crop: true
45
+ valid_mask_crop: eigen
config/dataset_depth/dataset_vis.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ vis:
3
+ - name: hypersim_depth
4
+ disp_name: hypersim_depth_vis
5
+ dir: hypersim/hypersim_processed_val.tar
6
+ filenames: data_split/hypersim_depth/selected_vis_sample.txt
7
+ resize_to_hw:
8
+ - 480
9
+ - 640
config/dataset_iid/data_appearance_interiorverse_test.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: interiorverse_iid
2
+ disp_name: interiorverse_iid_appearance_test
3
+ dir: interiorverse/InteriorVerse.tar
4
+ filenames: data_split/interiorverse_iid/interiorverse_test_scenes_85.txt
config/dataset_iid/data_appearance_synthetic_test.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: interiorverse_iid
2
+ disp_name: interiorverse_iid_appearance_test
3
+ dir: synthetic
4
+ filenames: data_split/osu/osu_test_scenes_85.txt
config/dataset_iid/data_art_test.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: interiorverse_iid
2
+ disp_name: interiorverse_iid_appearance_test
3
+ dir: art
4
+ filenames: data_split/osu/art_test_scenes.txt
config/dataset_iid/data_lighting_hypersim_test.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: hypersim_iid
2
+ disp_name: hypersim_iid_lighting_test
3
+ dir: hypersim
4
+ filenames: data_split/hypersim_iid/hypersim_test.txt
config/dataset_iid/dataset_appearance_train.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ train:
3
+ name: mixed
4
+ prob_ls: [1.0]
5
+ dataset_list:
6
+ - name: interiorverse_iid
7
+ disp_name: interiorverse_iid_appearance_train
8
+ dir: osu_albedo_new
9
+ filenames: data_split/osu/osu_train_scenes_85.txt
config/dataset_iid/dataset_appearance_val.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ dataset:
2
+ val:
3
+ - name: interiorverse_iid
4
+ disp_name: interiorverse_iid_appearance_val
5
+ dir: synthetic
6
+ filenames: data_split/MatrixCity/matrixcity_val_scenes_small.txt
config/dataset_iid/dataset_appearance_vis.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ dataset:
2
+ vis:
3
+ - name: interiorverse_iid
4
+ disp_name: interiorverse_iid_appearance_vis
5
+ dir: synthetic
6
+ filenames: data_split/MatrixCity/matrixcity_vis_scenes.txt
config/dataset_iid/dataset_lighting_train.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ train:
3
+ name: mixed
4
+ prob_ls: [1.0]
5
+ dataset_list:
6
+ - name: hypersim_iid
7
+ disp_name: hypersim_iid_lighting_train
8
+ dir: hypersim
9
+ filenames: data_split/hypersim_iid/hypersim_train_filtered.txt
10
+ resize_to_hw:
11
+ - 480
12
+ - 640
config/dataset_iid/dataset_lighting_val.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ dataset:
2
+ val:
3
+ - name: hypersim_iid
4
+ disp_name: hypersim_iid_lighting_val
5
+ dir: hypersim
6
+ filenames: data_split/hypersim_iid/hypersim_val.txt
config/dataset_iid/dataset_lighting_vis.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ dataset:
2
+ vis:
3
+ - name: hypersim_iid
4
+ disp_name: hypersim_iid_lighting_vis
5
+ dir: hypersim
6
+ filenames: data_split/hypersim_iid/hypersim_vis.txt
config/dataset_iid/osu_data_appearance_interiorverse_test.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: interiorverse_iid
2
+ disp_name: interiorverse_iid_appearance_test
3
+ dir: synthetic
4
+ filenames: data_split/osu/osu_test_scenes_85.txt
config/dataset_normals/data_diode_test.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: diode_normals
2
+ disp_name: diode_normals_test
3
+ dir: diode/val
4
+ filenames: data_split/diode_normals/diode_test.txt
config/dataset_normals/data_ibims_test.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: ibims_normals
2
+ disp_name: ibims_normals_test
3
+ dir: ibims/ibims
4
+ filenames: data_split/ibims_normals/ibims_test.txt
config/dataset_normals/data_nyu_test.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: nyu_normals
2
+ disp_name: nyu_normals_test
3
+ dir: nyuv2/test
4
+ filenames: data_split/nyu_normals/nyuv2_test.txt
config/dataset_normals/data_oasis_test.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: oasis_normals
2
+ disp_name: oasis_normals_test
3
+ dir: oasis/val
4
+ filenames: data_split/oasis_normals/oasis_test.txt
config/dataset_normals/data_scannet_test.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: scannet_normals
2
+ disp_name: scannet_normals_test
3
+ dir: scannet
4
+ filenames: data_split/scannet_normals/scannet_test.txt
config/dataset_normals/dataset_train.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ train:
3
+ name: mixed
4
+ prob_ls: [0.5, 0.49, 0.01]
5
+ dataset_list:
6
+ - name: hypersim_normals
7
+ disp_name: hypersim_normals_train
8
+ dir: hypersim
9
+ filenames: data_split/hypersim_normals/hypersim_filtered_all.txt
10
+ resize_to_hw:
11
+ - 480
12
+ - 640
13
+ - name: interiorverse_normals
14
+ disp_name: interiorverse_normals_train
15
+ dir: interiorverse/scenes_85
16
+ filenames: data_split/interiorverse_normals/interiorverse_filtered_all.txt
17
+ resize_to_hw: null
18
+ - name: sintel_normals
19
+ disp_name: sintel_normals_train
20
+ dir: sintel
21
+ filenames: data_split/sintel_normals/sintel_filtered.txt
22
+ resize_to_hw:
23
+ - 480
24
+ - 640
25
+ center_crop: true
config/dataset_normals/dataset_val.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ val:
3
+ - name: hypersim_normals
4
+ disp_name: hypersim_normals_val_small_100
5
+ dir: hypersim
6
+ filenames: data_split/hypersim_normals/hypersim_filtered_val_100.txt
7
+ resize_to_hw: null
config/dataset_normals/dataset_vis.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ vis:
3
+ - name: hypersim_normals
4
+ disp_name: hypersim_normals_vis
5
+ dir: hypersim
6
+ filenames: data_split/hypersim_normals/hypersim_filtered_vis_20.txt
7
+ resize_to_hw: null
config/logging.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ logging:
2
+ filename: logging.log
3
+ format: ' %(asctime)s - %(levelname)s -%(filename)s - %(funcName)s >> %(message)s'
4
+ console_level: 20
5
+ file_level: 10
config/model_sdv2.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ model:
2
+ name: marigold_pipeline
3
+ pretrained_path: stable-diffusion-2
4
+ latent_scale_factor: 0.18215
config/train_debug_depth.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ base_config:
2
+ - config/train_marigold_depth.yaml
3
+
4
+ trainer:
5
+ save_period: 5
6
+ backup_period: 10
7
+ validation_period: 5
8
+ visualization_period: 5
9
+
10
+ max_iter: 50
config/train_debug_iid.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_config:
2
+ # - config/train_marigold_iid_lighting.yaml
3
+ - config/train_marigold_iid_appearance.yaml
4
+
5
+ trainer:
6
+ save_period: 10
7
+ backup_period: 10
8
+ validation_period: 5
9
+ visualization_period: 5
10
+
11
+ max_iter: 50
config/train_debug_normals.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ base_config:
2
+ - config/train_marigold_normals.yaml
3
+
4
+ trainer:
5
+ save_period: 5
6
+ backup_period: 10
7
+ validation_period: 5
8
+ visualization_period: 5
9
+
10
+ max_iter: 50
config/train_marigold_depth.yaml ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_config:
2
+ - config/logging.yaml
3
+ - config/wandb.yaml
4
+ - config/dataset_depth/dataset_train.yaml
5
+ - config/dataset_depth/dataset_val.yaml
6
+ - config/dataset_depth/dataset_vis.yaml
7
+ - config/model_sdv2.yaml
8
+
9
+ pipeline:
10
+ name: MarigoldDepthPipeline
11
+ kwargs:
12
+ scale_invariant: true
13
+ shift_invariant: true
14
+ default_denoising_steps: 4
15
+ default_processing_resolution: 768
16
+
17
+ depth_normalization:
18
+ type: scale_shift_depth
19
+ clip: true
20
+ norm_min: -1.0
21
+ norm_max: 1.0
22
+ min_max_quantile: 0.02
23
+
24
+ augmentation:
25
+ lr_flip_p: 0.5
26
+
27
+ dataloader:
28
+ num_workers: 2
29
+ effective_batch_size: 32
30
+ max_train_batch_size: 2
31
+ seed: 2024 # to ensure continuity when resuming from checkpoint
32
+
33
+ trainer:
34
+ name: MarigoldDepthTrainer
35
+ training_noise_scheduler:
36
+ pretrained_path: stable-diffusion-2
37
+ init_seed: 2024 # use null to train w/o seeding
38
+ save_period: 50
39
+ backup_period: 2000
40
+ validation_period: 500
41
+ visualization_period: 1000
42
+
43
+ multi_res_noise:
44
+ strength: 0.9
45
+ annealed: true
46
+ downscale_strategy: original
47
+
48
+ gt_depth_type: depth_raw_norm
49
+ gt_mask_type: valid_mask_raw
50
+
51
+ max_epoch: 10000 # a large enough number
52
+ max_iter: 30000 # usually converges at around 20k
53
+
54
+ optimizer:
55
+ name: Adam
56
+
57
+ loss:
58
+ name: mse_loss
59
+ kwargs:
60
+ reduction: mean
61
+
62
+ lr: 3.0e-05
63
+ lr_scheduler:
64
+ name: IterExponential
65
+ kwargs:
66
+ total_iter: 25000
67
+ final_ratio: 0.01
68
+ warmup_steps: 100
69
+
70
+ # Light setting for the in-training validation and visualization
71
+ validation:
72
+ denoising_steps: 1
73
+ ensemble_size: 1
74
+ processing_res: 0
75
+ match_input_res: false
76
+ resample_method: bilinear
77
+ main_val_metric: abs_relative_difference
78
+ main_val_metric_goal: minimize
79
+ init_seed: 2024
80
+
81
+ eval:
82
+ alignment: least_square
83
+ align_max_res: null
84
+ eval_metrics:
85
+ - abs_relative_difference
86
+ - squared_relative_difference
87
+ - rmse_linear
88
+ - rmse_log
89
+ - log10
90
+ - delta1_acc
91
+ - delta2_acc
92
+ - delta3_acc
93
+ - i_rmse
94
+ - silog_rmse
config/train_marigold_iid_appearance.yaml ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_config:
2
+ - config/logging.yaml
3
+ - config/wandb.yaml
4
+ - config/dataset_iid/dataset_appearance_train.yaml
5
+ - config/dataset_iid/dataset_appearance_val.yaml
6
+ - config/dataset_iid/dataset_appearance_vis.yaml
7
+ - config/model_sdv2.yaml
8
+
9
+ pipeline:
10
+ name: MarigoldIIDPipeline
11
+ kwargs:
12
+ default_denoising_steps: 4
13
+ default_processing_resolution: 768
14
+ target_properties:
15
+ target_names:
16
+ - albedo
17
+ albedo:
18
+ prediction_space: srgb
19
+
20
+ augmentation:
21
+ lr_flip_p: 0.5
22
+
23
+ dataloader:
24
+ num_workers: 2
25
+ effective_batch_size: 32
26
+ max_train_batch_size: 8
27
+ seed: 2024 # to ensure continuity when resuming from checkpoint
28
+
29
+ trainer:
30
+ name: MarigoldIIDTrainer
31
+ training_noise_scheduler:
32
+ pretrained_path: stable-diffusion-2
33
+ init_seed: 2024 # use null to train w/o seeding
34
+ save_period: 50
35
+ backup_period: 2000
36
+ validation_period: 500
37
+ visualization_period: 1000
38
+
39
+ multi_res_noise:
40
+ strength: 0.9
41
+ annealed: true
42
+ downscale_strategy: original
43
+
44
+ gt_mask_type: mask
45
+
46
+ max_epoch: 10000 # a large enough number
47
+ max_iter: 10000 # usually converges at around 40k
48
+
49
+ optimizer:
50
+ name: Adam
51
+
52
+ loss:
53
+ name: mse_loss
54
+ kwargs:
55
+ reduction: mean
56
+
57
+ lr: 2.0e-05
58
+ lr_scheduler:
59
+ name: IterExponential
60
+ kwargs:
61
+ total_iter: 5000
62
+ final_ratio: 0.01
63
+ warmup_steps: 100
64
+
65
+ # Light setting for the in-training validation and visualization
66
+ validation:
67
+ denoising_steps: 4
68
+ ensemble_size: 1
69
+ processing_res: 0
70
+ match_input_res: true
71
+ resample_method: bilinear
72
+ main_val_metric: psnr
73
+ main_val_metric_goal: maximize
74
+ init_seed: 2024
75
+ use_mask: false
76
+
77
+ eval:
78
+ eval_metrics:
79
+ - psnr
80
+ targets_to_eval_in_linear_space:
81
+ - material
config/train_marigold_iid_appearance_finetuned.yaml ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_config:
2
+ - config/logging.yaml
3
+ - config/wandb.yaml
4
+ - config/dataset_iid/dataset_appearance_train.yaml
5
+ - config/dataset_iid/dataset_appearance_val.yaml
6
+ - config/dataset_iid/dataset_appearance_vis.yaml
7
+ - config/model_sdv2.yaml
8
+
9
+ pipeline:
10
+ name: MarigoldIIDPipeline
11
+ kwargs:
12
+ default_denoising_steps: 4
13
+ default_processing_resolution: 768
14
+ target_properties:
15
+ target_names:
16
+ - albedo
17
+ albedo:
18
+ prediction_space: srgb
19
+
20
+ augmentation:
21
+ lr_flip_p: 0.5
22
+
23
+ dataloader:
24
+ num_workers: 2
25
+ effective_batch_size: 32
26
+ max_train_batch_size: 8
27
+ seed: 2024 # to ensure continuity when resuming from checkpoint
28
+
29
+ trainer:
30
+ name: MarigoldIIDTrainer
31
+ training_noise_scheduler:
32
+ pretrained_path: stable-diffusion-2
33
+ init_seed: 2024 # use null to train w/o seeding
34
+ save_period: 50
35
+ backup_period: 2000
36
+ validation_period: 177
37
+ visualization_period: 177
38
+
39
+ multi_res_noise:
40
+ strength: 0.9
41
+ annealed: true
42
+ downscale_strategy: original
43
+
44
+ gt_mask_type: null
45
+
46
+ max_epoch: 10000 # a large enough number
47
+ max_iter: 5000 # usually converges at around 40k
48
+
49
+ optimizer:
50
+ name: Adam
51
+
52
+ loss:
53
+ name: mse_loss
54
+ kwargs:
55
+ reduction: mean
56
+
57
+ lr: 5.0e-07
58
+ lr_scheduler:
59
+ name: IterExponential
60
+ kwargs:
61
+ total_iter: 2500
62
+ final_ratio: 0.01
63
+ warmup_steps: 100
64
+
65
+ # Light setting for the in-training validation and visualization
66
+ validation:
67
+ denoising_steps: 4
68
+ ensemble_size: 1
69
+ processing_res: 1000
70
+ match_input_res: true
71
+ resample_method: bilinear
72
+ main_val_metric: psnr
73
+ main_val_metric_goal: maximize
74
+ init_seed: 2024
75
+ use_mask: false
76
+
77
+ eval:
78
+ eval_metrics:
79
+ - psnr
80
+ targets_to_eval_in_linear_space:
81
+ - material
config/train_marigold_iid_lighting.yaml ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_config:
2
+ - config/logging.yaml
3
+ - config/wandb.yaml
4
+ - config/dataset_iid/dataset_lighting_train.yaml
5
+ - config/dataset_iid/dataset_lighting_val.yaml
6
+ - config/dataset_iid/dataset_lighting_vis.yaml
7
+ - config/model_sdv2.yaml
8
+
9
+ pipeline:
10
+ name: MarigoldIIDPipeline
11
+ kwargs:
12
+ default_denoising_steps: 4
13
+ default_processing_resolution: 768
14
+ target_properties:
15
+ target_names:
16
+ - albedo
17
+ albedo:
18
+ prediction_space: linear
19
+ up_to_scale: false
20
+
21
+ augmentation:
22
+ lr_flip_p: 0.5
23
+
24
+ dataloader:
25
+ num_workers: 2
26
+ effective_batch_size: 32
27
+ max_train_batch_size: 8
28
+ seed: 2024 # to ensure continuity when resuming from checkpoint
29
+
30
+ trainer:
31
+ name: MarigoldIIDTrainer
32
+ training_noise_scheduler:
33
+ pretrained_path: stable-diffusion-2
34
+ init_seed: 2024 # use null to train w/o seeding
35
+ save_period: 50
36
+ backup_period: 2000
37
+ validation_period: 500
38
+ visualization_period: 1000
39
+
40
+ multi_res_noise:
41
+ strength: 0.9
42
+ annealed: true
43
+ downscale_strategy: original
44
+
45
+ gt_mask_type: mask
46
+
47
+ max_epoch: 10000 # a large enough number
48
+ max_iter: 50000 # usually converges at around 34k
49
+
50
+ optimizer:
51
+ name: Adam
52
+
53
+ loss:
54
+ name: mse_loss
55
+ kwargs:
56
+ reduction: mean
57
+
58
+ lr: 8e-05
59
+ lr_scheduler:
60
+ name: IterExponential
61
+ kwargs:
62
+ total_iter: 45000
63
+ final_ratio: 0.01
64
+ warmup_steps: 100
65
+
66
+ # Light setting for the in-training validation and visualization
67
+ validation:
68
+ denoising_steps: 4
69
+ ensemble_size: 1
70
+ processing_res: 0
71
+ match_input_res: true
72
+ resample_method: bilinear
73
+ main_val_metric: psnr
74
+ main_val_metric_goal: maximize
75
+ init_seed: 2024
76
+ use_mask: false
77
+
78
+ eval:
79
+ eval_metrics:
80
+ - psnr
81
+ targets_to_eval_in_linear_space:
82
+ - None
config/train_marigold_normals.yaml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_config:
2
+ - config/logging.yaml
3
+ - config/wandb.yaml
4
+ - config/dataset_normals/dataset_train.yaml
5
+ - config/dataset_normals/dataset_val.yaml
6
+ - config/dataset_normals/dataset_vis.yaml
7
+ - config/model_sdv2.yaml
8
+
9
+ pipeline:
10
+ name: MarigoldNormalsPipeline
11
+ kwargs:
12
+ default_denoising_steps: 4
13
+ default_processing_resolution: 768
14
+
15
+ augmentation:
16
+ lr_flip_p: 0.5
17
+ color_jitter_p: 0.3
18
+ gaussian_blur_p: 0.3
19
+ motion_blur_p: 0.3
20
+ gaussian_blur_sigma: 4
21
+ motion_blur_kernel_size: 11
22
+ motion_blur_angle_range: 360
23
+ jitter_brightness_factor: 0.5
24
+ jitter_contrast_factor: 0.5
25
+ jitter_saturation_factor: 0.5
26
+ jitter_hue_factor: 0.2
27
+
28
+ dataloader:
29
+ num_workers: 2
30
+ effective_batch_size: 32
31
+ max_train_batch_size: 2
32
+ seed: 2024 # to ensure continuity when resuming from checkpoint
33
+
34
+ trainer:
35
+ name: MarigoldNormalsTrainer
36
+ training_noise_scheduler:
37
+ pretrained_path: stable-diffusion-2
38
+ init_seed: 2024 # use null to train w/o seeding
39
+ save_period: 50
40
+ backup_period: 2000
41
+ validation_period: 500
42
+ visualization_period: 1000
43
+
44
+ multi_res_noise:
45
+ strength: 0.9
46
+ annealed: true
47
+ downscale_strategy: original
48
+
49
+ gt_normals_type: normals
50
+ gt_mask_type: null
51
+
52
+ max_epoch: 10000 # a large enough number
53
+ max_iter: 30000 # usually converges at around 26k
54
+
55
+ optimizer:
56
+ name: Adam
57
+
58
+ loss:
59
+ name: mse_loss
60
+ kwargs:
61
+ reduction: mean
62
+
63
+ lr: 6.0e-05
64
+ lr_scheduler:
65
+ name: IterExponential
66
+ kwargs:
67
+ total_iter: 25000
68
+ final_ratio: 0.01
69
+ warmup_steps: 100
70
+
71
+ # Light setting for the in-training validation and visualization
72
+ validation:
73
+ denoising_steps: 4
74
+ ensemble_size: 1
75
+ processing_res: 768
76
+ match_input_res: true
77
+ resample_method: bilinear
78
+ main_val_metric: mean_angular_error
79
+ main_val_metric_goal: minimize
80
+ init_seed: 0
81
+
82
+ eval:
83
+ align_max_res: null
84
+ eval_metrics:
85
+ - mean_angular_error
86
+ - sub11_25_error
config/wandb.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ wandb:
2
+ # entity: your_entity
3
+ project: marigold
marigold/marigold_depth_pipeline.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # More information about Marigold:
16
+ # https://marigoldmonodepth.github.io
17
+ # https://marigoldcomputervision.github.io
18
+ # Efficient inference pipelines are now part of diffusers:
19
+ # https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
20
+ # https://huggingface.co/docs/diffusers/api/pipelines/marigold
21
+ # Examples of trained models and live demos:
22
+ # https://huggingface.co/prs-eth
23
+ # Related projects:
24
+ # https://rollingdepth.github.io/
25
+ # https://marigolddepthcompletion.github.io/
26
+ # Citation (BibTeX):
27
+ # https://github.com/prs-eth/Marigold#-citation
28
+ # If you find Marigold useful, we kindly ask you to cite our papers.
29
+ # --------------------------------------------------------------------------
30
+
31
+ import logging
32
+ import numpy as np
33
+ import torch
34
+ from PIL import Image
35
+ from diffusers import (
36
+ AutoencoderKL,
37
+ DDIMScheduler,
38
+ DiffusionPipeline,
39
+ LCMScheduler,
40
+ UNet2DConditionModel,
41
+ )
42
+ from diffusers.utils import BaseOutput
43
+ from torch.utils.data import DataLoader, TensorDataset
44
+ from torchvision.transforms import InterpolationMode
45
+ from torchvision.transforms.functional import pil_to_tensor, resize
46
+ from tqdm.auto import tqdm
47
+ from transformers import CLIPTextModel, CLIPTokenizer
48
+ from typing import Dict, Optional, Union
49
+
50
+ from .util.batchsize import find_batch_size
51
+ from .util.ensemble import ensemble_depth
52
+ from .util.image_util import (
53
+ chw2hwc,
54
+ colorize_depth_maps,
55
+ get_tv_resample_method,
56
+ resize_max_res,
57
+ )
58
+
59
+
60
+ class MarigoldDepthOutput(BaseOutput):
61
+ """
62
+ Output class for Marigold Monocular Depth Estimation pipeline.
63
+
64
+ Args:
65
+ depth_np (`np.ndarray`):
66
+ Predicted depth map, with depth values in the range of [0, 1].
67
+ depth_colored (`PIL.Image.Image`):
68
+ Colorized depth map, with the shape of [H, W, 3] and values in [0, 255].
69
+ uncertainty (`None` or `np.ndarray`):
70
+ Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling.
71
+ """
72
+
73
+ depth_np: np.ndarray
74
+ depth_colored: Union[None, Image.Image]
75
+ uncertainty: Union[None, np.ndarray]
76
+
77
+
78
class MarigoldDepthPipeline(DiffusionPipeline):
    """
    Pipeline for Marigold Monocular Depth Estimation: https://marigoldcomputervision.github.io.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        unet (`UNet2DConditionModel`):
            Conditional U-Net to denoise the prediction latent, conditioned on image latent.
        vae (`AutoencoderKL`):
            Variational Auto-Encoder (VAE) Model to encode and decode images and predictions
            to and from latent representations.
        scheduler (`DDIMScheduler`):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        text_encoder (`CLIPTextModel`):
            Text-encoder, for empty text embedding.
        tokenizer (`CLIPTokenizer`):
            CLIP tokenizer.
        scale_invariant (`bool`, *optional*):
            A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in
            the model config. When used together with the `shift_invariant=True` flag, the model is also called
            "affine-invariant". NB: overriding this value is not supported.
        shift_invariant (`bool`, *optional*):
            A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in
            the model config. When used together with the `scale_invariant=True` flag, the model is also called
            "affine-invariant". NB: overriding this value is not supported.
        default_denoising_steps (`int`, *optional*):
            The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
            quality with the given model. This value must be set in the model config. When the pipeline is called
            without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
            reasonable results with various model flavors compatible with the pipeline, such as those relying on very
            short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
        default_processing_resolution (`int`, *optional*):
            The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
            the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
            default value is used. This is required to ensure reasonable results with various model flavors trained
            with varying optimal processing resolution values.
    """

    # Scaling constant of the Stable Diffusion VAE latent space (same value used by SD v1/v2).
    latent_scale_factor = 0.18215

    def __init__(
        self,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        scheduler: Union[DDIMScheduler, LCMScheduler],
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        scale_invariant: Optional[bool] = True,
        shift_invariant: Optional[bool] = True,
        default_denoising_steps: Optional[int] = None,
        default_processing_resolution: Optional[int] = None,
    ):
        super().__init__()
        self.register_modules(
            unet=unet,
            vae=vae,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )
        # Persist the non-module properties in the model config so they round-trip
        # through save_pretrained / from_pretrained.
        self.register_to_config(
            scale_invariant=scale_invariant,
            shift_invariant=shift_invariant,
            default_denoising_steps=default_denoising_steps,
            default_processing_resolution=default_processing_resolution,
        )

        self.scale_invariant = scale_invariant
        self.shift_invariant = shift_invariant
        self.default_denoising_steps = default_denoising_steps
        self.default_processing_resolution = default_processing_resolution

        # Cache for the empty-prompt text embedding; computed lazily by encode_empty_text().
        self.empty_text_embed = None

    @torch.no_grad()
    def __call__(
        self,
        input_image: Union[Image.Image, torch.Tensor],
        denoising_steps: Optional[int] = None,
        ensemble_size: int = 1,
        processing_res: Optional[int] = None,
        match_input_res: bool = True,
        resample_method: str = "bilinear",
        batch_size: int = 0,
        generator: Union[torch.Generator, None] = None,
        color_map: str = "Spectral",
        show_progress_bar: bool = True,
        ensemble_kwargs: Dict = None,
    ) -> MarigoldDepthOutput:
        """
        Function invoked when calling the pipeline.

        Args:
            input_image (`Image`):
                Input RGB (or gray-scale) image.
            denoising_steps (`int`, *optional*, defaults to `None`):
                Number of denoising diffusion steps during inference. The default value `None` results in automatic
                selection.
            ensemble_size (`int`, *optional*, defaults to `1`):
                Number of predictions to be ensembled.
            processing_res (`int`, *optional*, defaults to `None`):
                Effective processing resolution. When set to `0`, processes at the original image resolution. This
                produces crisper predictions, but may also lead to the overall loss of global context. The default
                value `None` resolves to the optimal value from the model config.
            match_input_res (`bool`, *optional*, defaults to `True`):
                Resize the prediction to match the input resolution.
                Only valid if `processing_res` > 0.
            resample_method: (`str`, *optional*, defaults to `bilinear`):
                Resampling method used to resize images and predictions. This can be one of `bilinear`, `bicubic` or
                `nearest`, defaults to: `bilinear`.
            batch_size (`int`, *optional*, defaults to `0`):
                Inference batch size, no bigger than `num_ensemble`.
                If set to 0, the script will automatically decide the proper batch size.
            generator (`torch.Generator`, *optional*, defaults to `None`)
                Random generator for initial noise generation.
            show_progress_bar (`bool`, *optional*, defaults to `True`):
                Display a progress bar of diffusion denoising.
            color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
                Colormap used to colorize the depth map.
            ensemble_kwargs (`dict`, *optional*, defaults to `None`):
                Arguments for detailed ensembling settings.
        Returns:
            `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
            - **depth_np** (`np.ndarray`) Predicted depth map with depth values in the range of [0, 1]
            - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [H, W, 3] and values in [0, 255], None if `color_map` is `None`
            - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
                    coming from ensembling. None if `ensemble_size = 1`
        """
        # Model-specific optimal default values leading to fast and reasonable results.
        if denoising_steps is None:
            denoising_steps = self.default_denoising_steps
        if processing_res is None:
            processing_res = self.default_processing_resolution

        # NOTE(review): if the model config also leaves these unset, the values stay None
        # and the assertions below raise a TypeError rather than a clear error message.
        assert processing_res >= 0
        assert ensemble_size >= 1

        # Check if denoising step is reasonable
        self._check_inference_step(denoising_steps)

        # Map the user-facing string to a torchvision InterpolationMode.
        resample_method: InterpolationMode = get_tv_resample_method(resample_method)

        # ----------------- Image Preprocess -----------------
        # Convert to torch tensor
        if isinstance(input_image, Image.Image):
            input_image = input_image.convert("RGB")
            # convert to torch tensor [H, W, rgb] -> [rgb, H, W]
            rgb = pil_to_tensor(input_image)
            rgb = rgb.unsqueeze(0)  # [1, rgb, H, W]
        elif isinstance(input_image, torch.Tensor):
            rgb = input_image
        else:
            raise TypeError(f"Unknown input type: {type(input_image) = }")
        input_size = rgb.shape
        assert (
            4 == rgb.dim() and 3 == input_size[-3]
        ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"

        # Resize image so the longer edge equals processing_res (0 keeps original size)
        if processing_res > 0:
            rgb = resize_max_res(
                rgb,
                max_edge_resolution=processing_res,
                resample_method=resample_method,
            )

        # Normalize rgb values
        rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
        rgb_norm = rgb_norm.to(self.dtype)
        assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0

        # ----------------- Predicting depth -----------------
        # Batch repeated input image (expand shares storage; no copy is made)
        duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
        single_rgb_dataset = TensorDataset(duplicated_rgb)
        if batch_size > 0:
            _bs = batch_size
        else:
            # Heuristic batch size based on resolution and dtype (see util/batchsize.py)
            _bs = find_batch_size(
                ensemble_size=ensemble_size,
                input_res=max(rgb_norm.shape[1:]),
                dtype=self.dtype,
            )

        single_rgb_loader = DataLoader(
            single_rgb_dataset, batch_size=_bs, shuffle=False
        )

        # Predict depth maps (batched)
        target_pred_ls = []
        if show_progress_bar:
            iterable = tqdm(
                single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
            )
        else:
            iterable = single_rgb_loader
        for batch in iterable:
            (batched_img,) = batch
            target_pred_raw = self.single_infer(
                rgb_in=batched_img,
                num_inference_steps=denoising_steps,
                show_pbar=show_progress_bar,
                generator=generator,
            )
            target_pred_ls.append(target_pred_raw.detach())
        target_preds = torch.concat(target_pred_ls, dim=0)
        torch.cuda.empty_cache()  # clear vram cache for ensembling

        # ----------------- Test-time ensembling -----------------
        if ensemble_size > 1:
            final_pred, pred_uncert = ensemble_depth(
                target_preds,
                scale_invariant=self.scale_invariant,
                shift_invariant=self.shift_invariant,
                **(ensemble_kwargs or {}),
            )
        else:
            final_pred = target_preds
            pred_uncert = None

        # Resize back to original resolution
        if match_input_res:
            final_pred = resize(
                final_pred,
                input_size[-2:],
                interpolation=resample_method,
                antialias=True,
            )

        # Convert to numpy
        final_pred = final_pred.squeeze()
        final_pred = final_pred.cpu().numpy()
        if pred_uncert is not None:
            pred_uncert = pred_uncert.squeeze().cpu().numpy()

        # Clip output range
        final_pred = final_pred.clip(0, 1)

        # Colorize
        if color_map is not None:
            depth_colored = colorize_depth_maps(
                final_pred, 0, 1, cmap=color_map
            ).squeeze()  # [3, H, W], value in (0, 1)
            depth_colored = (depth_colored * 255).astype(np.uint8)
            depth_colored_hwc = chw2hwc(depth_colored)
            depth_colored_img = Image.fromarray(depth_colored_hwc)
        else:
            depth_colored_img = None

        return MarigoldDepthOutput(
            depth_np=final_pred,
            depth_colored=depth_colored_img,
            uncertainty=pred_uncert,
        )

    def _check_inference_step(self, n_step: int) -> None:
        """
        Check if denoising step is reasonable
        Args:
            n_step (`int`): denoising steps
        """
        assert n_step >= 1

        if isinstance(self.scheduler, DDIMScheduler):
            if "trailing" != self.scheduler.config.timestep_spacing:
                logging.warning(
                    f"The loaded `DDIMScheduler` is configured with `timestep_spacing="
                    f'"{self.scheduler.config.timestep_spacing}"`; the recommended setting is `"trailing"`. '
                    f"This change is backward-compatible and yields better results. "
                    f"Consider using `prs-eth/marigold-depth-v1-1` for the best experience."
                )
            else:
                # "trailing" spacing is characteristic of few-step checkpoints;
                # many steps usually hurt rather than help there.
                if n_step > 10:
                    logging.warning(
                        f"Setting too many denoising steps ({n_step}) may degrade the prediction; consider relying on "
                        f"the default values."
                    )
            if not self.scheduler.config.rescale_betas_zero_snr:
                logging.warning(
                    f"The loaded `DDIMScheduler` is configured with `rescale_betas_zero_snr="
                    f"{self.scheduler.config.rescale_betas_zero_snr}`; the recommended setting is True. "
                    f"Consider using `prs-eth/marigold-depth-v1-1` for the best experience."
                )
        elif isinstance(self.scheduler, LCMScheduler):
            logging.warning(
                "DeprecationWarning: LCMScheduler will not be supported in the future. "
                "Consider using `prs-eth/marigold-depth-v1-1` for the best experience."
            )
            if n_step > 10:
                logging.warning(
                    f"Setting too many denoising steps ({n_step}) may degrade the prediction; consider relying on "
                    f"the default values."
                )
        else:
            raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")

    def encode_empty_text(self):
        """
        Encode text embedding for empty prompt
        """
        prompt = ""
        text_inputs = self.tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
        # Cache the embedding so repeated calls reuse it.
        self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)

    @torch.no_grad()
    def single_infer(
        self,
        rgb_in: torch.Tensor,
        num_inference_steps: int,
        generator: Union[torch.Generator, None],
        show_pbar: bool,
    ) -> torch.Tensor:
        """
        Perform a single prediction without ensembling.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image.
            num_inference_steps (`int`):
                Number of diffusion denoising steps (DDIM) during inference.
            show_pbar (`bool`):
                Display a progress bar of diffusion denoising.
            generator (`torch.Generator`)
                Random generator for initial noise generation.
        Returns:
            `torch.Tensor`: Predicted targets.
        """
        device = self.device
        rgb_in = rgb_in.to(device)

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps  # [T]

        # Encode image
        rgb_latent = self.encode_rgb(rgb_in)  # [B, 4, h, w]

        # Noisy latent for outputs
        target_latent = torch.randn(
            rgb_latent.shape,
            device=device,
            dtype=self.dtype,
            generator=generator,
        )  # [B, 4, h, w]

        # Batched empty text embedding
        if self.empty_text_embed is None:
            self.encode_empty_text()
        batch_empty_text_embed = self.empty_text_embed.repeat(
            (rgb_latent.shape[0], 1, 1)
        ).to(device)  # [B, 2, 1024]

        # Denoising loop
        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising",
            )
        else:
            iterable = enumerate(timesteps)

        for i, t in iterable:
            unet_input = torch.cat(
                [rgb_latent, target_latent], dim=1
            )  # this order is important

            # predict the noise residual
            noise_pred = self.unet(
                unet_input, t, encoder_hidden_states=batch_empty_text_embed
            ).sample  # [B, 4, h, w]

            # compute the previous noisy sample x_t -> x_t-1
            target_latent = self.scheduler.step(
                noise_pred, t, target_latent, generator=generator
            ).prev_sample

        depth = self.decode_depth(target_latent)  # [B, 1, H, W]

        # clip prediction
        depth = torch.clip(depth, -1.0, 1.0)
        # shift to [0, 1]
        depth = (depth + 1.0) / 2.0

        return depth

    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
        """
        Encode RGB image into latent.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image to be encoded.

        Returns:
            `torch.Tensor`: Image latent.
        """
        # encode; only the distribution mean is used (no sampling from the VAE posterior)
        h = self.vae.encoder(rgb_in)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        # scale latent
        rgb_latent = mean * self.latent_scale_factor
        return rgb_latent

    def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode depth latent into depth map.

        Args:
            depth_latent (`torch.Tensor`):
                Depth latent to be decoded.

        Returns:
            `torch.Tensor`: Decoded depth map.
        """
        # scale latent
        depth_latent = depth_latent / self.latent_scale_factor
        # decode
        z = self.vae.post_quant_conv(depth_latent)
        stacked = self.vae.decoder(z)
        # mean of output channels -> single-channel depth [B, 1, H, W]
        depth_mean = stacked.mean(dim=1, keepdim=True)
        return depth_mean
marigold/marigold_normals_pipeline.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # More information about Marigold:
16
+ # https://marigoldmonodepth.github.io
17
+ # https://marigoldcomputervision.github.io
18
+ # Efficient inference pipelines are now part of diffusers:
19
+ # https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
20
+ # https://huggingface.co/docs/diffusers/api/pipelines/marigold
21
+ # Examples of trained models and live demos:
22
+ # https://huggingface.co/prs-eth
23
+ # Related projects:
24
+ # https://rollingdepth.github.io/
25
+ # https://marigolddepthcompletion.github.io/
26
+ # Citation (BibTeX):
27
+ # https://github.com/prs-eth/Marigold#-citation
28
+ # If you find Marigold useful, we kindly ask you to cite our papers.
29
+ # --------------------------------------------------------------------------
30
+
31
+ import logging
32
+ import numpy as np
33
+ import torch
34
+ from PIL import Image
35
+ from diffusers import (
36
+ AutoencoderKL,
37
+ DDIMScheduler,
38
+ DiffusionPipeline,
39
+ LCMScheduler,
40
+ UNet2DConditionModel,
41
+ )
42
+ from diffusers.utils import BaseOutput
43
+ from torch.utils.data import DataLoader, TensorDataset
44
+ from torchvision.transforms import InterpolationMode
45
+ from torchvision.transforms.functional import pil_to_tensor, resize
46
+ from tqdm.auto import tqdm
47
+ from transformers import CLIPTextModel, CLIPTokenizer
48
+ from typing import Dict, Optional, Union
49
+
50
+ from .util.batchsize import find_batch_size
51
+ from .util.ensemble import ensemble_normals
52
+ from .util.image_util import (
53
+ chw2hwc,
54
+ get_tv_resample_method,
55
+ resize_max_res,
56
+ )
57
+
58
+
59
class MarigoldNormalsOutput(BaseOutput):
    """
    Output class for Marigold Surface Normals Estimation pipeline.

    Args:
        normals_np (`np.ndarray`):
            Predicted normals map of shape [3, H, W] with values in the range of [-1, 1] (unit length vectors).
        normals_img (`PIL.Image.Image`):
            Normals image, with the shape of [H, W, 3] and values in [0, 255].
        uncertainty (`None` or `np.ndarray`):
            Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling.
            `None` if `ensemble_size = 1`.
    """

    # Raw prediction, [3, H, W] float array in [-1, 1].
    normals_np: np.ndarray
    # Visualization of the prediction as an RGB image.
    normals_img: Image.Image
    # Per-pixel MAD across the ensemble members, or None for a single prediction.
    uncertainty: Union[None, np.ndarray]
75
+
76
+
77
class MarigoldNormalsPipeline(DiffusionPipeline):
    """
    Pipeline for Marigold Surface Normals Estimation: https://marigoldcomputervision.github.io.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        unet (`UNet2DConditionModel`):
            Conditional U-Net to denoise the prediction latent, conditioned on image latent.
        vae (`AutoencoderKL`):
            Variational Auto-Encoder (VAE) Model to encode and decode images and predictions
            to and from latent representations.
        scheduler (`DDIMScheduler`):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        text_encoder (`CLIPTextModel`):
            Text-encoder, for empty text embedding.
        tokenizer (`CLIPTokenizer`):
            CLIP tokenizer.
        default_denoising_steps (`int`, *optional*):
            The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
            quality with the given model. This value must be set in the model config. When the pipeline is called
            without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
            reasonable results with various model flavors compatible with the pipeline, such as those relying on very
            short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
        default_processing_resolution (`int`, *optional*):
            The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
            the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
            default value is used. This is required to ensure reasonable results with various model flavors trained
            with varying optimal processing resolution values.
    """

    # Scaling constant of the Stable Diffusion VAE latent space (same value used by SD v1/v2).
    latent_scale_factor = 0.18215

    def __init__(
        self,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        scheduler: Union[DDIMScheduler, LCMScheduler],
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        default_denoising_steps: Optional[int] = None,
        default_processing_resolution: Optional[int] = None,
    ):
        super().__init__()
        self.register_modules(
            unet=unet,
            vae=vae,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )
        # Persist the non-module properties in the model config so they round-trip
        # through save_pretrained / from_pretrained.
        self.register_to_config(
            default_denoising_steps=default_denoising_steps,
            default_processing_resolution=default_processing_resolution,
        )

        self.default_denoising_steps = default_denoising_steps
        self.default_processing_resolution = default_processing_resolution

        # Cache for the empty-prompt text embedding; computed lazily by encode_empty_text().
        self.empty_text_embed = None

    @torch.no_grad()
    def __call__(
        self,
        input_image: Union[Image.Image, torch.Tensor],
        denoising_steps: Optional[int] = None,
        ensemble_size: int = 1,
        processing_res: Optional[int] = None,
        match_input_res: bool = True,
        resample_method: str = "bilinear",
        batch_size: int = 0,
        generator: Union[torch.Generator, None] = None,
        show_progress_bar: bool = True,
        ensemble_kwargs: Dict = None,
    ) -> MarigoldNormalsOutput:
        """
        Function invoked when calling the pipeline.

        Args:
            input_image (`Image`):
                Input RGB (or gray-scale) image.
            denoising_steps (`int`, *optional*, defaults to `None`):
                Number of denoising diffusion steps during inference. The default value `None` results in automatic
                selection.
            ensemble_size (`int`, *optional*, defaults to `1`):
                Number of predictions to be ensembled.
            processing_res (`int`, *optional*, defaults to `None`):
                Effective processing resolution. When set to `0`, processes at the original image resolution. This
                produces crisper predictions, but may also lead to the overall loss of global context. The default
                value `None` resolves to the optimal value from the model config.
            match_input_res (`bool`, *optional*, defaults to `True`):
                Resize the prediction to match the input resolution.
                Only valid if `processing_res` > 0.
            resample_method: (`str`, *optional*, defaults to `bilinear`):
                Resampling method used to resize images and predictions. This can be one of `bilinear`, `bicubic` or
                `nearest`, defaults to: `bilinear`.
            batch_size (`int`, *optional*, defaults to `0`):
                Inference batch size, no bigger than `num_ensemble`.
                If set to 0, the script will automatically decide the proper batch size.
            generator (`torch.Generator`, *optional*, defaults to `None`)
                Random generator for initial noise generation.
            show_progress_bar (`bool`, *optional*, defaults to `True`):
                Display a progress bar of diffusion denoising.
            ensemble_kwargs (`dict`, *optional*, defaults to `None`):
                Arguments for detailed ensembling settings.
        Returns:
            `MarigoldNormalsOutput`: Output class for Marigold monocular surface normals estimation pipeline, including:
            - **normals_np** (`np.ndarray`) Predicted normals map of shape [3, H, W] with values in the range of [-1, 1]
                    (unit length vectors)
            - **normals_img** (`PIL.Image.Image`) Normals image, with the shape of [H, W, 3] and values in [0, 255]
            - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
                    coming from ensembling. None if `ensemble_size = 1`
        """
        # Model-specific optimal default values leading to fast and reasonable results.
        if denoising_steps is None:
            denoising_steps = self.default_denoising_steps
        if processing_res is None:
            processing_res = self.default_processing_resolution

        # NOTE(review): if the model config also leaves these unset, the values stay None
        # and the assertions below raise a TypeError rather than a clear error message.
        assert processing_res >= 0
        assert ensemble_size >= 1

        # Check if denoising step is reasonable
        self._check_inference_step(denoising_steps)

        # Map the user-facing string to a torchvision InterpolationMode.
        resample_method: InterpolationMode = get_tv_resample_method(resample_method)

        # ----------------- Image Preprocess -----------------
        # Convert to torch tensor
        if isinstance(input_image, Image.Image):
            input_image = input_image.convert("RGB")
            # convert to torch tensor [H, W, rgb] -> [rgb, H, W]
            rgb = pil_to_tensor(input_image)
            rgb = rgb.unsqueeze(0)  # [1, rgb, H, W]
        elif isinstance(input_image, torch.Tensor):
            rgb = input_image
        else:
            raise TypeError(f"Unknown input type: {type(input_image) = }")
        input_size = rgb.shape
        assert (
            4 == rgb.dim() and 3 == input_size[-3]
        ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"

        # Resize image so the longer edge equals processing_res (0 keeps original size)
        if processing_res > 0:
            rgb = resize_max_res(
                rgb,
                max_edge_resolution=processing_res,
                resample_method=resample_method,
            )

        # Normalize rgb values
        rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
        rgb_norm = rgb_norm.to(self.dtype)
        assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0

        # ----------------- Predicting normals -----------------
        # Batch repeated input image (expand shares storage; no copy is made)
        duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
        single_rgb_dataset = TensorDataset(duplicated_rgb)
        if batch_size > 0:
            _bs = batch_size
        else:
            # Heuristic batch size based on resolution and dtype (see util/batchsize.py)
            _bs = find_batch_size(
                ensemble_size=ensemble_size,
                input_res=max(rgb_norm.shape[1:]),
                dtype=self.dtype,
            )

        single_rgb_loader = DataLoader(
            single_rgb_dataset, batch_size=_bs, shuffle=False
        )

        # Predict normals maps (batched)
        target_pred_ls = []
        if show_progress_bar:
            iterable = tqdm(
                single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
            )
        else:
            iterable = single_rgb_loader
        for batch in iterable:
            (batched_img,) = batch
            target_pred_raw = self.single_infer(
                rgb_in=batched_img,
                num_inference_steps=denoising_steps,
                show_pbar=show_progress_bar,
                generator=generator,
            )
            target_pred_ls.append(target_pred_raw.detach())
        target_preds = torch.concat(target_pred_ls, dim=0)
        torch.cuda.empty_cache()  # clear vram cache for ensembling

        # ----------------- Test-time ensembling -----------------
        if ensemble_size > 1:
            final_pred, pred_uncert = ensemble_normals(
                target_preds,
                **(ensemble_kwargs or {}),
            )
        else:
            final_pred = target_preds
            pred_uncert = None

        # Resize back to original resolution
        if match_input_res:
            final_pred = resize(
                final_pred,
                input_size[-2:],
                interpolation=resample_method,
                antialias=True,
            )

        # Convert to numpy
        final_pred = final_pred.squeeze()
        final_pred = final_pred.cpu().numpy()
        if pred_uncert is not None:
            pred_uncert = pred_uncert.squeeze().cpu().numpy()

        # Clip output range
        final_pred = final_pred.clip(-1, 1)

        # Colorize: map [-1, 1] components to [0, 255] RGB
        normals_img = ((final_pred + 1) * 127.5).astype(np.uint8)
        normals_img = chw2hwc(normals_img)
        normals_img = Image.fromarray(normals_img)

        return MarigoldNormalsOutput(
            normals_np=final_pred,
            normals_img=normals_img,
            uncertainty=pred_uncert,
        )

    def _check_inference_step(self, n_step: int) -> None:
        """
        Check if denoising step is reasonable
        Args:
            n_step (`int`): denoising steps
        """
        assert n_step >= 1

        if isinstance(self.scheduler, DDIMScheduler):
            if "trailing" != self.scheduler.config.timestep_spacing:
                logging.warning(
                    f"The loaded `DDIMScheduler` is configured with `timestep_spacing="
                    f'"{self.scheduler.config.timestep_spacing}"`; the recommended setting is `"trailing"`. '
                    f"This change is backward-compatible and yields better results. "
                    f"Consider using `prs-eth/marigold-normals-v1-1` for the best experience."
                )
            else:
                # "trailing" spacing is characteristic of few-step checkpoints;
                # many steps usually hurt rather than help there.
                if n_step > 10:
                    logging.warning(
                        f"Setting too many denoising steps ({n_step}) may degrade the prediction; consider relying on "
                        f"the default values."
                    )
            if not self.scheduler.config.rescale_betas_zero_snr:
                logging.warning(
                    f"The loaded `DDIMScheduler` is configured with `rescale_betas_zero_snr="
                    f"{self.scheduler.config.rescale_betas_zero_snr}`; the recommended setting is True. "
                    f"Consider using `prs-eth/marigold-normals-v1-1` for the best experience."
                )
        elif isinstance(self.scheduler, LCMScheduler):
            # Unlike the depth pipeline, LCM is rejected outright here.
            raise RuntimeError(
                "This pipeline implementation does not support the LCMScheduler. Please refer to the project "
                "README.md for instructions about using LCM."
            )
        else:
            raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")

    def encode_empty_text(self):
        """
        Encode text embedding for empty prompt
        """
        prompt = ""
        text_inputs = self.tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
        # Cache the embedding so repeated calls reuse it.
        self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)

    @torch.no_grad()
    def single_infer(
        self,
        rgb_in: torch.Tensor,
        num_inference_steps: int,
        generator: Union[torch.Generator, None],
        show_pbar: bool,
    ) -> torch.Tensor:
        """
        Perform a single prediction without ensembling.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image.
            num_inference_steps (`int`):
                Number of diffusion denoising steps (DDIM) during inference.
            show_pbar (`bool`):
                Display a progress bar of diffusion denoising.
            generator (`torch.Generator`)
                Random generator for initial noise generation.
        Returns:
            `torch.Tensor`: Predicted targets.
        """
        device = self.device
        rgb_in = rgb_in.to(device)

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps  # [T]

        # Encode image
        rgb_latent = self.encode_rgb(rgb_in)  # [B, 4, h, w]

        # Noisy latent for outputs
        target_latent = torch.randn(
            rgb_latent.shape,
            device=device,
            dtype=self.dtype,
            generator=generator,
        )  # [B, 4, h, w]

        # Batched empty text embedding
        if self.empty_text_embed is None:
            self.encode_empty_text()
        batch_empty_text_embed = self.empty_text_embed.repeat(
            (rgb_latent.shape[0], 1, 1)
        ).to(device)  # [B, 2, 1024]

        # Denoising loop
        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising",
            )
        else:
            iterable = enumerate(timesteps)

        for i, t in iterable:
            unet_input = torch.cat(
                [rgb_latent, target_latent], dim=1
            )  # this order is important

            # predict the noise residual
            noise_pred = self.unet(
                unet_input, t, encoder_hidden_states=batch_empty_text_embed
            ).sample  # [B, 4, h, w]

            # compute the previous noisy sample x_t -> x_t-1
            target_latent = self.scheduler.step(
                noise_pred, t, target_latent, generator=generator
            ).prev_sample

        normals = self.decode_normals(target_latent)  # [B,3,H,W]

        # clip prediction, then renormalize to unit-length vectors
        normals = torch.clip(normals, -1.0, 1.0)
        norm = torch.norm(normals, dim=1, keepdim=True)
        normals /= norm.clamp(min=1e-6)  # [B,3,H,W]

        return normals

    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
        """
        Encode RGB image into latent.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image to be encoded.

        Returns:
            `torch.Tensor`: Image latent.
        """
        # encode; only the distribution mean is used (no sampling from the VAE posterior)
        h = self.vae.encoder(rgb_in)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        # scale latent
        rgb_latent = mean * self.latent_scale_factor
        return rgb_latent

    def decode_normals(self, normals_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode normals latent into normals map.

        Args:
            normals_latent (`torch.Tensor`):
                Normals latent to be decoded.

        Returns:
            `torch.Tensor`: Decoded normals map.
        """
        # scale latent
        normals_latent = normals_latent / self.latent_scale_factor
        # decode; all decoder output channels are kept (3-channel normals)
        z = self.vae.post_quant_conv(normals_latent)
        stacked = self.vae.decoder(z)
        return stacked
marigold/util/batchsize.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # More information about Marigold:
16
+ # https://marigoldmonodepth.github.io
17
+ # https://marigoldcomputervision.github.io
18
+ # Efficient inference pipelines are now part of diffusers:
19
+ # https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
20
+ # https://huggingface.co/docs/diffusers/api/pipelines/marigold
21
+ # Examples of trained models and live demos:
22
+ # https://huggingface.co/prs-eth
23
+ # Related projects:
24
+ # https://rollingdepth.github.io/
25
+ # https://marigolddepthcompletion.github.io/
26
+ # Citation (BibTeX):
27
+ # https://github.com/prs-eth/Marigold#-citation
28
+ # If you find Marigold useful, we kindly ask you to cite our papers.
29
+ # --------------------------------------------------------------------------
30
+
31
+ import math
32
+ import torch
33
+
34
# Search table for suggested max. inference batch size
# Each entry maps an operating resolution ("res", longer image side in pixels)
# and a device memory tier ("total_vram", GiB) to an empirically measured
# maximum batch size ("bs") for the given inference precision ("dtype").
# Entries were measured on the GPUs named in the comments below.
bs_search_table = [
    # tested on A100-PCIE-80GB
    {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
    {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
    # tested on A100-PCIE-40GB
    {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
    {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
    {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
    {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
    # tested on RTX3090, RTX4090
    {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
    {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
    {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
    {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
    {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
    {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
    # tested on GTX1080Ti
    {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
    {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
    {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
    {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
    {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
]
58
+
59
+
60
def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
    """
    Automatically search for suitable operating batch size.

    Args:
        ensemble_size (`int`):
            Number of predictions to be ensembled.
        input_res (`int`):
            Operating resolution of the input image.
        dtype (`torch.dtype`):
            Inference precision, used to filter the lookup table.

    Returns:
        `int`: Operating batch size.
    """
    # Without CUDA there is no VRAM table to consult; process one at a time.
    if not torch.cuda.is_available():
        return 1

    # Total device memory in GiB.
    total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3

    # Consider only entries measured at the requested precision, scanning
    # smallest resolution first (and largest VRAM tier first within a
    # resolution) so the first matching entry wins.
    candidates = sorted(
        (entry for entry in bs_search_table if entry["dtype"] == dtype),
        key=lambda entry: (entry["res"], -entry["total_vram"]),
    )
    for entry in candidates:
        if input_res <= entry["res"] and total_vram >= entry["total_vram"]:
            bs = entry["bs"]
            # Never exceed the ensemble size; when the whole ensemble fits in
            # under two passes anyway, balance the passes by halving
            # (rounded up).
            if bs >= ensemble_size:
                return ensemble_size
            half = math.ceil(ensemble_size / 2)
            if bs > half:
                return half
            return bs

    # No tabulated configuration fits this device.
    return 1
marigold/util/ensemble.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # More information about Marigold:
16
+ # https://marigoldmonodepth.github.io
17
+ # https://marigoldcomputervision.github.io
18
+ # Efficient inference pipelines are now part of diffusers:
19
+ # https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
20
+ # https://huggingface.co/docs/diffusers/api/pipelines/marigold
21
+ # Examples of trained models and live demos:
22
+ # https://huggingface.co/prs-eth
23
+ # Related projects:
24
+ # https://rollingdepth.github.io/
25
+ # https://marigolddepthcompletion.github.io/
26
+ # Citation (BibTeX):
27
+ # https://github.com/prs-eth/Marigold#-citation
28
+ # If you find Marigold useful, we kindly ask you to cite our papers.
29
+ # --------------------------------------------------------------------------
30
+
31
+ import numpy as np
32
+ import torch
33
+ from functools import partial
34
+ from typing import Optional, Tuple
35
+
36
+ from .image_util import get_tv_resample_method, resize_max_res
37
+
38
+
39
def ensemble_depth(
    depth: torch.Tensor,
    scale_invariant: bool = True,
    shift_invariant: bool = True,
    output_uncertainty: bool = False,
    reduction: str = "median",
    regularizer_strength: float = 0.02,
    max_iter: int = 50,
    tol: float = 1e-6,
    max_res: int = 1024,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Ensembles depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the
    number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed for
    depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative. The
    alignment happens when the predictions have one or more degrees of freedom, that is when they are either
    affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only
    `scale_invariant=True`). For absolute predictions (`scale_invariant=False` and `shift_invariant=False`)
    alignment is skipped and only ensembling is performed.

    Args:
        depth (`torch.Tensor`):
            Input ensemble depth maps.
        scale_invariant (`bool`, *optional*, defaults to `True`):
            Whether to treat predictions as scale-invariant.
        shift_invariant (`bool`, *optional*, defaults to `True`):
            Whether to treat predictions as shift-invariant.
        output_uncertainty (`bool`, *optional*, defaults to `False`):
            Whether to output uncertainty map.
        reduction (`str`, *optional*, defaults to `"median"`):
            Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and
            `"median"`.
        regularizer_strength (`float`, *optional*, defaults to `0.02`):
            Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1.
        max_iter (`int`, *optional*, defaults to `50`):
            Maximum number of the alignment solver steps. Refer to `scipy.optimize.minimize` function, `options`
            argument.
        tol (`float`, *optional*, defaults to `1e-6`):
            Alignment solver tolerance. The solver stops when the tolerance is reached.
        max_res (`int`, *optional*, defaults to `1024`):
            Resolution at which the alignment is performed; `None` matches the `processing_resolution`.
    Returns:
        A tensor of aligned and ensembled depth maps and optionally a tensor of uncertainties of the same shape:
        `(1, 1, H, W)`.
    """
    # Input validation: the alignment math below assumes [B,1,H,W] and one of
    # the two supported invariance modes.
    if depth.dim() != 4 or depth.shape[1] != 1:
        raise ValueError(f"Expecting 4D tensor of shape [B,1,H,W]; got {depth.shape}.")
    if reduction not in ("mean", "median"):
        raise ValueError(f"Unrecognized reduction method: {reduction}.")
    if not scale_invariant and shift_invariant:
        raise ValueError("Pure shift-invariant ensembling is not supported.")

    # NOTE: the closures below capture `ensemble_size` from the enclosing
    # scope; it is assigned after their definitions but before any of them
    # is called.

    def init_param(depth: torch.Tensor):
        # Initial guess for the solver: per-member scale (and shift) that maps
        # each member's value range onto [0, 1].
        init_min = depth.reshape(ensemble_size, -1).min(dim=1).values
        init_max = depth.reshape(ensemble_size, -1).max(dim=1).values

        if scale_invariant and shift_invariant:
            init_s = 1.0 / (init_max - init_min).clamp(min=1e-6)
            init_t = -init_s * init_min
            # Parameter vector layout: [s_1..s_B, t_1..t_B].
            param = torch.cat((init_s, init_t)).cpu().numpy()
        elif scale_invariant:
            init_s = 1.0 / init_max.clamp(min=1e-6)
            param = init_s.cpu().numpy()
        else:
            raise ValueError("Unrecognized alignment.")

        # float64 for the scipy solver.
        return param.astype(np.float64)

    def align(depth: torch.Tensor, param: np.ndarray) -> torch.Tensor:
        # Apply per-member affine (or scale-only) correction encoded in `param`.
        if scale_invariant and shift_invariant:
            s, t = np.split(param, 2)
            s = torch.from_numpy(s).to(depth).view(ensemble_size, 1, 1, 1)
            t = torch.from_numpy(t).to(depth).view(ensemble_size, 1, 1, 1)
            out = depth * s + t
        elif scale_invariant:
            s = torch.from_numpy(param).to(depth).view(ensemble_size, 1, 1, 1)
            out = depth * s
        else:
            raise ValueError("Unrecognized alignment.")
        return out

    def ensemble(
        depth_aligned: torch.Tensor, return_uncertainty: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Collapse the ensemble dimension with the chosen reduction; the
        # uncertainty is std for "mean" and median absolute deviation for
        # "median".
        uncertainty = None
        if reduction == "mean":
            prediction = torch.mean(depth_aligned, dim=0, keepdim=True)
            if return_uncertainty:
                uncertainty = torch.std(depth_aligned, dim=0, keepdim=True)
        elif reduction == "median":
            prediction = torch.median(depth_aligned, dim=0, keepdim=True).values
            if return_uncertainty:
                uncertainty = torch.median(
                    torch.abs(depth_aligned - prediction), dim=0, keepdim=True
                ).values
        else:
            raise ValueError(f"Unrecognized reduction method: {reduction}.")
        return prediction, uncertainty

    def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float:
        # Solver objective: sum of pairwise RMSE between aligned members, plus
        # an optional regularizer pulling the ensembled prediction to [0, 1].
        cost = 0.0
        depth_aligned = align(depth, param)

        for i, j in torch.combinations(torch.arange(ensemble_size)):
            diff = depth_aligned[i] - depth_aligned[j]
            cost += (diff**2).mean().sqrt().item()

        if regularizer_strength > 0:
            prediction, _ = ensemble(depth_aligned, return_uncertainty=False)
            err_near = (0.0 - prediction.min()).abs().item()
            err_far = (1.0 - prediction.max()).abs().item()
            cost += (err_near + err_far) * regularizer_strength

        return cost

    def compute_param(depth: torch.Tensor):
        # Imported lazily so scipy is only required when alignment runs.
        import scipy

        depth_to_align = depth.to(torch.float32)
        # Downsample before optimizing to bound the solver cost.
        if max_res is not None and max(depth_to_align.shape[2:]) > max_res:
            depth_to_align = resize_max_res(
                depth_to_align, max_res, get_tv_resample_method("nearest-exact")
            )

        param = init_param(depth_to_align)

        res = scipy.optimize.minimize(
            partial(cost_fn, depth=depth_to_align),
            param,
            method="BFGS",
            tol=tol,
            options={"maxiter": max_iter, "disp": False},
        )

        return res.x

    requires_aligning = scale_invariant or shift_invariant
    ensemble_size = depth.shape[0]

    if requires_aligning:
        param = compute_param(depth)
        depth = align(depth, param)

    depth, uncertainty = ensemble(depth, return_uncertainty=output_uncertainty)

    # Renormalize the ensembled prediction (and uncertainty) into [0, 1]. For
    # scale-only alignment the minimum is pinned at 0 to preserve the origin.
    depth_max = depth.max()
    if scale_invariant and shift_invariant:
        depth_min = depth.min()
    elif scale_invariant:
        depth_min = 0
    else:
        raise ValueError("Unrecognized alignment.")
    depth_range = (depth_max - depth_min).clamp(min=1e-6)
    depth = (depth - depth_min) / depth_range
    if output_uncertainty:
        uncertainty /= depth_range

    return depth, uncertainty  # [1,1,H,W], [1,1,H,W]
197
+
198
+
199
def ensemble_normals(
    normals: torch.Tensor,
    output_uncertainty: bool = False,
    reduction: str = "closest",
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Ensembles the normals maps represented by the `normals` tensor with expected shape `(B, 3, H, W)`, where B is
    the number of ensemble members for a given prediction of size `(H x W)`.

    Args:
        normals (`torch.Tensor`):
            Input ensemble normals maps.
        output_uncertainty (`bool`, *optional*, defaults to `False`):
            Whether to output uncertainty map.
        reduction (`str`, *optional*, defaults to `"closest"`):
            Reduction method used to ensemble aligned predictions. The accepted values are: `"closest"` and
            `"mean"`.

    Returns:
        A tensor of aligned and ensembled normals maps with shape `(1, 3, H, W)` and optionally a tensor of
        uncertainties of shape `(1, 1, H, W)`.
    """
    if normals.dim() != 4 or normals.shape[1] != 3:
        raise ValueError(
            f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}."
        )
    if reduction not in ("closest", "mean"):
        raise ValueError(f"Unrecognized reduction method: {reduction}.")

    # Average over the ensemble, then renormalize to unit length.
    avg = normals.mean(dim=0, keepdim=True)  # [1,3,H,W]
    avg = avg / torch.norm(avg, dim=1, keepdim=True).clamp(min=1e-6)

    # Per-pixel cosine similarity of each member against the mean direction;
    # needed for the "closest" reduction and for the uncertainty map.
    cosine = None
    if output_uncertainty or reduction != "mean":
        cosine = (avg * normals).sum(dim=1, keepdim=True)  # [E,1,H,W]
        # clamp is required to avoid NaN in arccos under fp16 rounding
        cosine = cosine.clamp(-1, 1)

    # Uncertainty is the mean angular deviation, normalized to [0, 1] by pi.
    uncertainty = None
    if output_uncertainty:
        uncertainty = cosine.arccos().mean(dim=0, keepdim=True) / np.pi

    if reduction == "mean":
        return avg, uncertainty  # [1,3,H,W], [1,1,H,W]

    # Pick, per pixel, the ensemble member most aligned with the mean.
    winner = cosine.argmax(dim=0, keepdim=True)  # [1,1,H,W]
    winner = winner.repeat(1, 3, 1, 1)  # [1,3,H,W]
    picked = torch.gather(normals, 0, winner)  # [1,3,H,W]

    return picked, uncertainty  # [1,3,H,W], [1,1,H,W]
250
+
251
+
252
def ensemble_iid(
    targets: torch.Tensor,
    output_uncertainty: bool = False,
    reduction: str = "median",
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Ensembles prediction maps stacked along dimension 0 of `targets`.

    Args:
        targets (`torch.Tensor`):
            Input ensemble maps, one member per entry along dim 0.
        output_uncertainty (`bool`, *optional*, defaults to `False`):
            Whether to also compute an uncertainty map.
        reduction (`str`, *optional*, defaults to `"median"`):
            Reduction method; the accepted values are `"mean"` and `"median"`.

    Returns:
        The ensembled prediction (keeping dim 0 with size 1) and the
        uncertainty map, or `None` when `output_uncertainty` is disabled.
    """
    uncertainty = None
    if reduction == "mean":
        prediction = targets.mean(dim=0, keepdim=True)
        if output_uncertainty:
            # Per-pixel standard deviation across ensemble members.
            uncertainty = targets.std(dim=0, keepdim=True)
    elif reduction == "median":
        prediction = targets.median(dim=0, keepdim=True).values
        if output_uncertainty:
            # Median absolute deviation from the ensembled prediction.
            deviations = (targets - prediction).abs()
            uncertainty = deviations.median(dim=0, keepdim=True).values
    else:
        raise ValueError(f"Unrecognized reduction method: {reduction}.")
    return prediction, uncertainty
requirements++.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ h5py
2
+ opencv-python
3
+ tensorboard
4
+ wandb
5
+ scikit-learn
6
+ xformers==0.0.28
7
+
requirements+.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ omegaconf
2
+ pandas
3
+ tabulate
4
+ torchmetrics