degbo committed on
Commit
f2dd2b8
·
1 Parent(s): 5e75ca4

update with new code

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +5 -8
  2. config/dataset_depth/data_diode_all.yaml +0 -4
  3. config/dataset_depth/data_eth3d.yaml +0 -4
  4. config/dataset_depth/data_hypersim_train.yaml +0 -4
  5. config/dataset_depth/data_hypersim_val.yaml +0 -4
  6. config/dataset_depth/data_kitti_eigen_test.yaml +0 -6
  7. config/dataset_depth/data_kitti_val.yaml +0 -6
  8. config/dataset_depth/data_nyu_test.yaml +0 -5
  9. config/dataset_depth/data_nyu_train.yaml +0 -5
  10. config/dataset_depth/data_scannet_val.yaml +0 -4
  11. config/dataset_depth/data_vkitti_train.yaml +0 -6
  12. config/dataset_depth/data_vkitti_val.yaml +0 -6
  13. config/dataset_depth/dataset_train.yaml +0 -18
  14. config/dataset_depth/dataset_val.yaml +0 -45
  15. config/dataset_depth/dataset_vis.yaml +0 -9
  16. config/dataset_iid/data_appearance_interiorverse_test.yaml +0 -4
  17. config/dataset_iid/data_appearance_synthetic_test.yaml +0 -4
  18. config/dataset_iid/data_art_test.yaml +0 -4
  19. config/dataset_iid/data_lighting_hypersim_test.yaml +0 -4
  20. config/dataset_iid/dataset_appearance_train.yaml +0 -9
  21. config/dataset_iid/dataset_appearance_val.yaml +0 -6
  22. config/dataset_iid/dataset_appearance_vis.yaml +0 -6
  23. config/dataset_iid/dataset_lighting_train.yaml +0 -12
  24. config/dataset_iid/dataset_lighting_val.yaml +0 -6
  25. config/dataset_iid/dataset_lighting_vis.yaml +0 -6
  26. config/dataset_iid/osu_data_appearance_interiorverse_test.yaml +0 -4
  27. config/dataset_normals/data_diode_test.yaml +0 -4
  28. config/dataset_normals/data_ibims_test.yaml +0 -4
  29. config/dataset_normals/data_nyu_test.yaml +0 -4
  30. config/dataset_normals/data_oasis_test.yaml +0 -4
  31. config/dataset_normals/data_scannet_test.yaml +0 -4
  32. config/dataset_normals/dataset_train.yaml +0 -25
  33. config/dataset_normals/dataset_val.yaml +0 -7
  34. config/dataset_normals/dataset_vis.yaml +0 -7
  35. config/logging.yaml +0 -5
  36. config/model_sdv2.yaml +0 -4
  37. config/train_debug_depth.yaml +0 -10
  38. config/train_debug_iid.yaml +0 -11
  39. config/train_debug_normals.yaml +0 -10
  40. config/train_marigold_depth.yaml +0 -94
  41. config/train_marigold_iid_appearance.yaml +0 -81
  42. config/train_marigold_iid_appearance_finetuned.yaml +0 -81
  43. config/train_marigold_iid_lighting.yaml +0 -82
  44. config/train_marigold_normals.yaml +0 -86
  45. config/wandb.yaml +0 -3
  46. marigold/__init__.py +0 -41
  47. marigold/marigold_depth_pipeline.py +0 -516
  48. marigold/marigold_normals_pipeline.py +0 -479
  49. {src → olbedo}/__init__.py +2 -0
  50. olbedo/__pycache__/__init__.cpython-310.pyc +0 -0
app.py CHANGED
@@ -7,15 +7,13 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..",
7
  import gradio as gr
8
  import numpy as np
9
  import torch
10
- from marigold import MarigoldIIDOutput, MarigoldIIDPipeline
11
  from src.util.image_util import read_img_from_file, img_hwc2chw, img_linear2srgb, is_hdr
12
- from marigold.util.image_util import float2int
13
  from src.util.seeding import seed_all
14
  import logging
15
  from huggingface_hub import snapshot_download
16
 
17
- HF_TOKEN = os.getenv("HF_TOKEN")
18
-
19
  seed = 1234
20
  seed_all(seed)
21
  if torch.cuda.is_available():
@@ -46,12 +44,11 @@ def get_demo():
46
  local_dir = snapshot_download(
47
  repo_id="GDAOSU/olbedo",
48
  allow_patterns=f"{selected_model}/*",
49
- token=HF_TOKEN,
50
  )
51
 
52
  model_path = os.path.join(local_dir, selected_model)
53
 
54
- pipe = MarigoldIIDPipeline.from_pretrained(
55
  model_path,
56
  torch_dtype=torch.float32,
57
  ).to(device)
@@ -102,7 +99,7 @@ def get_demo():
102
  if "rgbx" in selected_model:
103
  pipe.prompt = prompt
104
 
105
- pipe_out: MarigoldIIDOutput = pipe(
106
  input_image,
107
  denoising_steps=inference_step,
108
  ensemble_size=1,
@@ -136,7 +133,7 @@ def get_demo():
136
  block = gr.Blocks()
137
  with block:
138
  with gr.Row():
139
- gr.Markdown("## OSU albedo demo")
140
  with gr.Row():
141
  # Input side
142
  with gr.Column():
 
7
  import gradio as gr
8
  import numpy as np
9
  import torch
10
+ from olbedo import OlbedoIIDOutput, OlbedoIIDPipeline
11
  from src.util.image_util import read_img_from_file, img_hwc2chw, img_linear2srgb, is_hdr
12
+ from olbedo.util.image_util import float2int
13
  from src.util.seeding import seed_all
14
  import logging
15
  from huggingface_hub import snapshot_download
16
 
 
 
17
  seed = 1234
18
  seed_all(seed)
19
  if torch.cuda.is_available():
 
44
  local_dir = snapshot_download(
45
  repo_id="GDAOSU/olbedo",
46
  allow_patterns=f"{selected_model}/*",
 
47
  )
48
 
49
  model_path = os.path.join(local_dir, selected_model)
50
 
51
+ pipe = OlbedoIIDPipeline.from_pretrained(
52
  model_path,
53
  torch_dtype=torch.float32,
54
  ).to(device)
 
99
  if "rgbx" in selected_model:
100
  pipe.prompt = prompt
101
 
102
+ pipe_out: OlbedoIIDOutput = pipe(
103
  input_image,
104
  denoising_steps=inference_step,
105
  ensemble_size=1,
 
133
  block = gr.Blocks()
134
  with block:
135
  with gr.Row():
136
+ gr.Markdown("## Olbedo: An Albedo and Shading Aerial Dataset for Large-Scale Outdoor Environments")
137
  with gr.Row():
138
  # Input side
139
  with gr.Column():
config/dataset_depth/data_diode_all.yaml DELETED
@@ -1,4 +0,0 @@
1
- name: diode_depth
2
- disp_name: diode_depth_val_all
3
- dir: diode/diode_val.tar
4
- filenames: data_split/diode_depth/diode_val_all_filename_list.txt
 
 
 
 
 
config/dataset_depth/data_eth3d.yaml DELETED
@@ -1,4 +0,0 @@
1
- name: eth3d_depth
2
- disp_name: eth3d_depth_full
3
- dir: eth3d/eth3d.tar
4
- filenames: data_split/eth3d_depth/eth3d_filename_list.txt
 
 
 
 
 
config/dataset_depth/data_hypersim_train.yaml DELETED
@@ -1,4 +0,0 @@
1
- name: hypersim_depth
2
- disp_name: hypersim_depth_train
3
- dir: hypersim/hypersim_processed_train.tar
4
- filenames: data_split/hypersim_depth/filename_list_train_filtered.txt
 
 
 
 
 
config/dataset_depth/data_hypersim_val.yaml DELETED
@@ -1,4 +0,0 @@
1
- name: hypersim_depth
2
- disp_name: hypersim_depth_val
3
- dir: hypersim/hypersim_processed_val.tar
4
- filenames: data_split/hypersim_depth/filename_list_val_filtered.txt
 
 
 
 
 
config/dataset_depth/data_kitti_eigen_test.yaml DELETED
@@ -1,6 +0,0 @@
1
- name: kitti_depth
2
- disp_name: kitti_depth_eigen_test_full
3
- dir: kitti/kitti_eigen_split_test.tar
4
- filenames: data_split/kitti_depth/eigen_test_files_with_gt.txt
5
- kitti_bm_crop: true
6
- valid_mask_crop: eigen
 
 
 
 
 
 
 
config/dataset_depth/data_kitti_val.yaml DELETED
@@ -1,6 +0,0 @@
1
- name: kitti_depth
2
- disp_name: kitti_depth_val800_from_eigen_train
3
- dir: kitti/kitti_sampled_val_800.tar
4
- filenames: data_split/kitti_depth/eigen_val_from_train_800.txt
5
- kitti_bm_crop: true
6
- valid_mask_crop: eigen
 
 
 
 
 
 
 
config/dataset_depth/data_nyu_test.yaml DELETED
@@ -1,5 +0,0 @@
1
- name: nyu_depth
2
- disp_name: nyu_depth_test_full
3
- dir: nyuv2/nyu_labeled_extracted.tar
4
- filenames: data_split/nyu_depth/labeled/filename_list_test.txt
5
- eigen_valid_mask: true
 
 
 
 
 
 
config/dataset_depth/data_nyu_train.yaml DELETED
@@ -1,5 +0,0 @@
1
- name: nyu_depth
2
- disp_name: nyu_depth_train_full
3
- dir: nyuv2/nyu_labeled_extracted.tar
4
- filenames: data_split/nyu_depth/labeled/filename_list_train.txt
5
- eigen_valid_mask: true
 
 
 
 
 
 
config/dataset_depth/data_scannet_val.yaml DELETED
@@ -1,4 +0,0 @@
1
- name: scannet_depth
2
- disp_name: scannet_depth_val_800_1
3
- dir: scannet/scannet_val_sampled_800_1.tar
4
- filenames: data_split/scannet_depth/scannet_val_sampled_list_800_1.txt
 
 
 
 
 
config/dataset_depth/data_vkitti_train.yaml DELETED
@@ -1,6 +0,0 @@
1
- name: vkitti_depth
2
- disp_name: vkitti_depth_train
3
- dir: vkitti/vkitti.tar
4
- filenames: data_split/vkitti_depth/vkitti_train.txt
5
- kitti_bm_crop: true
6
- valid_mask_crop: null # no valid_mask_crop for training
 
 
 
 
 
 
 
config/dataset_depth/data_vkitti_val.yaml DELETED
@@ -1,6 +0,0 @@
1
- name: vkitti_depth
2
- disp_name: vkitti_depth_val
3
- dir: vkitti/vkitti.tar
4
- filenames: data_split/vkitti_depth/vkitti_val.txt
5
- kitti_bm_crop: true
6
- valid_mask_crop: eigen
 
 
 
 
 
 
 
config/dataset_depth/dataset_train.yaml DELETED
@@ -1,18 +0,0 @@
1
- dataset:
2
- train:
3
- name: mixed
4
- prob_ls: [0.9, 0.1]
5
- dataset_list:
6
- - name: hypersim_depth
7
- disp_name: hypersim_depth_train
8
- dir: hypersim/hypersim_processed_train.tar
9
- filenames: data_split/hypersim_depth/filename_list_train_filtered.txt
10
- resize_to_hw:
11
- - 480
12
- - 640
13
- - name: vkitti_depth
14
- disp_name: vkitti_depth_train
15
- dir: vkitti/vkitti.tar
16
- filenames: data_split/vkitti_depth/vkitti_train.txt
17
- kitti_bm_crop: true
18
- valid_mask_crop: null
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config/dataset_depth/dataset_val.yaml DELETED
@@ -1,45 +0,0 @@
1
- dataset:
2
- val:
3
- # - name: hypersim_depth
4
- # disp_name: hypersim_depth_val
5
- # dir: hypersim/hypersim_processed_val.tar
6
- # filenames: data_split/hypersim_depth/filename_list_val_filtered.txt
7
- # resize_to_hw:
8
- # - 480
9
- # - 640
10
-
11
- # - name: nyu_depth
12
- # disp_name: nyu_depth_train_full
13
- # dir: nyuv2/nyu_labeled_extracted.tar
14
- # filenames: data_split/nyu_depth/labeled/filename_list_train.txt
15
- # eigen_valid_mask: true
16
-
17
- # - name: kitti_depth
18
- # disp_name: kitti_depth_val800_from_eigen_train
19
- # dir: kitti/kitti_depth_sampled_val_800.tar
20
- # filenames: data_split/kitti_depth/eigen_val_from_train_800.txt
21
- # kitti_bm_crop: true
22
- # valid_mask_crop: eigen
23
-
24
- # Smaller subsets for faster validation during training
25
- # The first dataset is used to calculate main eval metric.
26
- - name: hypersim_depth
27
- disp_name: hypersim_depth_val_small_80
28
- dir: hypersim/hypersim_processed_val.tar
29
- filenames: data_split/hypersim_depth/filename_list_val_filtered_small_80.txt
30
- resize_to_hw:
31
- - 480
32
- - 640
33
-
34
- - name: nyu_depth
35
- disp_name: nyu_depth_train_small_100
36
- dir: nyuv2/nyu_labeled_extracted.tar
37
- filenames: data_split/nyu_depth/labeled/filename_list_train_small_100.txt
38
- eigen_valid_mask: true
39
-
40
- - name: kitti_depth
41
- disp_name: kitti_depth_val_from_train_sub_100
42
- dir: kitti/kitti_sampled_val_800.tar
43
- filenames: data_split/kitti_depth/eigen_val_from_train_sub_100.txt
44
- kitti_bm_crop: true
45
- valid_mask_crop: eigen
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config/dataset_depth/dataset_vis.yaml DELETED
@@ -1,9 +0,0 @@
1
- dataset:
2
- vis:
3
- - name: hypersim_depth
4
- disp_name: hypersim_depth_vis
5
- dir: hypersim/hypersim_processed_val.tar
6
- filenames: data_split/hypersim_depth/selected_vis_sample.txt
7
- resize_to_hw:
8
- - 480
9
- - 640
 
 
 
 
 
 
 
 
 
 
config/dataset_iid/data_appearance_interiorverse_test.yaml DELETED
@@ -1,4 +0,0 @@
1
- name: interiorverse_iid
2
- disp_name: interiorverse_iid_appearance_test
3
- dir: interiorverse/InteriorVerse.tar
4
- filenames: data_split/interiorverse_iid/interiorverse_test_scenes_85.txt
 
 
 
 
 
config/dataset_iid/data_appearance_synthetic_test.yaml DELETED
@@ -1,4 +0,0 @@
1
- name: interiorverse_iid
2
- disp_name: interiorverse_iid_appearance_test
3
- dir: synthetic
4
- filenames: data_split/osu/osu_test_scenes_85.txt
 
 
 
 
 
config/dataset_iid/data_art_test.yaml DELETED
@@ -1,4 +0,0 @@
1
- name: interiorverse_iid
2
- disp_name: interiorverse_iid_appearance_test
3
- dir: art
4
- filenames: data_split/osu/art_test_scenes.txt
 
 
 
 
 
config/dataset_iid/data_lighting_hypersim_test.yaml DELETED
@@ -1,4 +0,0 @@
1
- name: hypersim_iid
2
- disp_name: hypersim_iid_lighting_test
3
- dir: hypersim
4
- filenames: data_split/hypersim_iid/hypersim_test.txt
 
 
 
 
 
config/dataset_iid/dataset_appearance_train.yaml DELETED
@@ -1,9 +0,0 @@
1
- dataset:
2
- train:
3
- name: mixed
4
- prob_ls: [1.0]
5
- dataset_list:
6
- - name: interiorverse_iid
7
- disp_name: interiorverse_iid_appearance_train
8
- dir: osu_albedo_new
9
- filenames: data_split/osu/osu_train_scenes_85.txt
 
 
 
 
 
 
 
 
 
 
config/dataset_iid/dataset_appearance_val.yaml DELETED
@@ -1,6 +0,0 @@
1
- dataset:
2
- val:
3
- - name: interiorverse_iid
4
- disp_name: interiorverse_iid_appearance_val
5
- dir: synthetic
6
- filenames: data_split/MatrixCity/matrixcity_val_scenes_small.txt
 
 
 
 
 
 
 
config/dataset_iid/dataset_appearance_vis.yaml DELETED
@@ -1,6 +0,0 @@
1
- dataset:
2
- vis:
3
- - name: interiorverse_iid
4
- disp_name: interiorverse_iid_appearance_vis
5
- dir: synthetic
6
- filenames: data_split/MatrixCity/matrixcity_vis_scenes.txt
 
 
 
 
 
 
 
config/dataset_iid/dataset_lighting_train.yaml DELETED
@@ -1,12 +0,0 @@
1
- dataset:
2
- train:
3
- name: mixed
4
- prob_ls: [1.0]
5
- dataset_list:
6
- - name: hypersim_iid
7
- disp_name: hypersim_iid_lighting_train
8
- dir: hypersim
9
- filenames: data_split/hypersim_iid/hypersim_train_filtered.txt
10
- resize_to_hw:
11
- - 480
12
- - 640
 
 
 
 
 
 
 
 
 
 
 
 
 
config/dataset_iid/dataset_lighting_val.yaml DELETED
@@ -1,6 +0,0 @@
1
- dataset:
2
- val:
3
- - name: hypersim_iid
4
- disp_name: hypersim_iid_lighting_val
5
- dir: hypersim
6
- filenames: data_split/hypersim_iid/hypersim_val.txt
 
 
 
 
 
 
 
config/dataset_iid/dataset_lighting_vis.yaml DELETED
@@ -1,6 +0,0 @@
1
- dataset:
2
- vis:
3
- - name: hypersim_iid
4
- disp_name: hypersim_iid_lighting_vis
5
- dir: hypersim
6
- filenames: data_split/hypersim_iid/hypersim_vis.txt
 
 
 
 
 
 
 
config/dataset_iid/osu_data_appearance_interiorverse_test.yaml DELETED
@@ -1,4 +0,0 @@
1
- name: interiorverse_iid
2
- disp_name: interiorverse_iid_appearance_test
3
- dir: synthetic
4
- filenames: data_split/osu/osu_test_scenes_85.txt
 
 
 
 
 
config/dataset_normals/data_diode_test.yaml DELETED
@@ -1,4 +0,0 @@
1
- name: diode_normals
2
- disp_name: diode_normals_test
3
- dir: diode/val
4
- filenames: data_split/diode_normals/diode_test.txt
 
 
 
 
 
config/dataset_normals/data_ibims_test.yaml DELETED
@@ -1,4 +0,0 @@
1
- name: ibims_normals
2
- disp_name: ibims_normals_test
3
- dir: ibims/ibims
4
- filenames: data_split/ibims_normals/ibims_test.txt
 
 
 
 
 
config/dataset_normals/data_nyu_test.yaml DELETED
@@ -1,4 +0,0 @@
1
- name: nyu_normals
2
- disp_name: nyu_normals_test
3
- dir: nyuv2/test
4
- filenames: data_split/nyu_normals/nyuv2_test.txt
 
 
 
 
 
config/dataset_normals/data_oasis_test.yaml DELETED
@@ -1,4 +0,0 @@
1
- name: oasis_normals
2
- disp_name: oasis_normals_test
3
- dir: oasis/val
4
- filenames: data_split/oasis_normals/oasis_test.txt
 
 
 
 
 
config/dataset_normals/data_scannet_test.yaml DELETED
@@ -1,4 +0,0 @@
1
- name: scannet_normals
2
- disp_name: scannet_normals_test
3
- dir: scannet
4
- filenames: data_split/scannet_normals/scannet_test.txt
 
 
 
 
 
config/dataset_normals/dataset_train.yaml DELETED
@@ -1,25 +0,0 @@
1
- dataset:
2
- train:
3
- name: mixed
4
- prob_ls: [0.5, 0.49, 0.01]
5
- dataset_list:
6
- - name: hypersim_normals
7
- disp_name: hypersim_normals_train
8
- dir: hypersim
9
- filenames: data_split/hypersim_normals/hypersim_filtered_all.txt
10
- resize_to_hw:
11
- - 480
12
- - 640
13
- - name: interiorverse_normals
14
- disp_name: interiorverse_normals_train
15
- dir: interiorverse/scenes_85
16
- filenames: data_split/interiorverse_normals/interiorverse_filtered_all.txt
17
- resize_to_hw: null
18
- - name: sintel_normals
19
- disp_name: sintel_normals_train
20
- dir: sintel
21
- filenames: data_split/sintel_normals/sintel_filtered.txt
22
- resize_to_hw:
23
- - 480
24
- - 640
25
- center_crop: true
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config/dataset_normals/dataset_val.yaml DELETED
@@ -1,7 +0,0 @@
1
- dataset:
2
- val:
3
- - name: hypersim_normals
4
- disp_name: hypersim_normals_val_small_100
5
- dir: hypersim
6
- filenames: data_split/hypersim_normals/hypersim_filtered_val_100.txt
7
- resize_to_hw: null
 
 
 
 
 
 
 
 
config/dataset_normals/dataset_vis.yaml DELETED
@@ -1,7 +0,0 @@
1
- dataset:
2
- vis:
3
- - name: hypersim_normals
4
- disp_name: hypersim_normals_vis
5
- dir: hypersim
6
- filenames: data_split/hypersim_normals/hypersim_filtered_vis_20.txt
7
- resize_to_hw: null
 
 
 
 
 
 
 
 
config/logging.yaml DELETED
@@ -1,5 +0,0 @@
1
- logging:
2
- filename: logging.log
3
- format: ' %(asctime)s - %(levelname)s -%(filename)s - %(funcName)s >> %(message)s'
4
- console_level: 20
5
- file_level: 10
 
 
 
 
 
 
config/model_sdv2.yaml DELETED
@@ -1,4 +0,0 @@
1
- model:
2
- name: marigold_pipeline
3
- pretrained_path: stable-diffusion-2
4
- latent_scale_factor: 0.18215
 
 
 
 
 
config/train_debug_depth.yaml DELETED
@@ -1,10 +0,0 @@
1
- base_config:
2
- - config/train_marigold_depth.yaml
3
-
4
- trainer:
5
- save_period: 5
6
- backup_period: 10
7
- validation_period: 5
8
- visualization_period: 5
9
-
10
- max_iter: 50
 
 
 
 
 
 
 
 
 
 
 
config/train_debug_iid.yaml DELETED
@@ -1,11 +0,0 @@
1
- base_config:
2
- # - config/train_marigold_iid_lighting.yaml
3
- - config/train_marigold_iid_appearance.yaml
4
-
5
- trainer:
6
- save_period: 10
7
- backup_period: 10
8
- validation_period: 5
9
- visualization_period: 5
10
-
11
- max_iter: 50
 
 
 
 
 
 
 
 
 
 
 
 
config/train_debug_normals.yaml DELETED
@@ -1,10 +0,0 @@
1
- base_config:
2
- - config/train_marigold_normals.yaml
3
-
4
- trainer:
5
- save_period: 5
6
- backup_period: 10
7
- validation_period: 5
8
- visualization_period: 5
9
-
10
- max_iter: 50
 
 
 
 
 
 
 
 
 
 
 
config/train_marigold_depth.yaml DELETED
@@ -1,94 +0,0 @@
1
- base_config:
2
- - config/logging.yaml
3
- - config/wandb.yaml
4
- - config/dataset_depth/dataset_train.yaml
5
- - config/dataset_depth/dataset_val.yaml
6
- - config/dataset_depth/dataset_vis.yaml
7
- - config/model_sdv2.yaml
8
-
9
- pipeline:
10
- name: MarigoldDepthPipeline
11
- kwargs:
12
- scale_invariant: true
13
- shift_invariant: true
14
- default_denoising_steps: 4
15
- default_processing_resolution: 768
16
-
17
- depth_normalization:
18
- type: scale_shift_depth
19
- clip: true
20
- norm_min: -1.0
21
- norm_max: 1.0
22
- min_max_quantile: 0.02
23
-
24
- augmentation:
25
- lr_flip_p: 0.5
26
-
27
- dataloader:
28
- num_workers: 2
29
- effective_batch_size: 32
30
- max_train_batch_size: 2
31
- seed: 2024 # to ensure continuity when resuming from checkpoint
32
-
33
- trainer:
34
- name: MarigoldDepthTrainer
35
- training_noise_scheduler:
36
- pretrained_path: stable-diffusion-2
37
- init_seed: 2024 # use null to train w/o seeding
38
- save_period: 50
39
- backup_period: 2000
40
- validation_period: 500
41
- visualization_period: 1000
42
-
43
- multi_res_noise:
44
- strength: 0.9
45
- annealed: true
46
- downscale_strategy: original
47
-
48
- gt_depth_type: depth_raw_norm
49
- gt_mask_type: valid_mask_raw
50
-
51
- max_epoch: 10000 # a large enough number
52
- max_iter: 30000 # usually converges at around 20k
53
-
54
- optimizer:
55
- name: Adam
56
-
57
- loss:
58
- name: mse_loss
59
- kwargs:
60
- reduction: mean
61
-
62
- lr: 3.0e-05
63
- lr_scheduler:
64
- name: IterExponential
65
- kwargs:
66
- total_iter: 25000
67
- final_ratio: 0.01
68
- warmup_steps: 100
69
-
70
- # Light setting for the in-training validation and visualization
71
- validation:
72
- denoising_steps: 1
73
- ensemble_size: 1
74
- processing_res: 0
75
- match_input_res: false
76
- resample_method: bilinear
77
- main_val_metric: abs_relative_difference
78
- main_val_metric_goal: minimize
79
- init_seed: 2024
80
-
81
- eval:
82
- alignment: least_square
83
- align_max_res: null
84
- eval_metrics:
85
- - abs_relative_difference
86
- - squared_relative_difference
87
- - rmse_linear
88
- - rmse_log
89
- - log10
90
- - delta1_acc
91
- - delta2_acc
92
- - delta3_acc
93
- - i_rmse
94
- - silog_rmse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config/train_marigold_iid_appearance.yaml DELETED
@@ -1,81 +0,0 @@
1
- base_config:
2
- - config/logging.yaml
3
- - config/wandb.yaml
4
- - config/dataset_iid/dataset_appearance_train.yaml
5
- - config/dataset_iid/dataset_appearance_val.yaml
6
- - config/dataset_iid/dataset_appearance_vis.yaml
7
- - config/model_sdv2.yaml
8
-
9
- pipeline:
10
- name: MarigoldIIDPipeline
11
- kwargs:
12
- default_denoising_steps: 4
13
- default_processing_resolution: 768
14
- target_properties:
15
- target_names:
16
- - albedo
17
- albedo:
18
- prediction_space: srgb
19
-
20
- augmentation:
21
- lr_flip_p: 0.5
22
-
23
- dataloader:
24
- num_workers: 2
25
- effective_batch_size: 32
26
- max_train_batch_size: 8
27
- seed: 2024 # to ensure continuity when resuming from checkpoint
28
-
29
- trainer:
30
- name: MarigoldIIDTrainer
31
- training_noise_scheduler:
32
- pretrained_path: stable-diffusion-2
33
- init_seed: 2024 # use null to train w/o seeding
34
- save_period: 50
35
- backup_period: 2000
36
- validation_period: 500
37
- visualization_period: 1000
38
-
39
- multi_res_noise:
40
- strength: 0.9
41
- annealed: true
42
- downscale_strategy: original
43
-
44
- gt_mask_type: mask
45
-
46
- max_epoch: 10000 # a large enough number
47
- max_iter: 10000 # usually converges at around 40k
48
-
49
- optimizer:
50
- name: Adam
51
-
52
- loss:
53
- name: mse_loss
54
- kwargs:
55
- reduction: mean
56
-
57
- lr: 2.0e-05
58
- lr_scheduler:
59
- name: IterExponential
60
- kwargs:
61
- total_iter: 5000
62
- final_ratio: 0.01
63
- warmup_steps: 100
64
-
65
- # Light setting for the in-training validation and visualization
66
- validation:
67
- denoising_steps: 4
68
- ensemble_size: 1
69
- processing_res: 0
70
- match_input_res: true
71
- resample_method: bilinear
72
- main_val_metric: psnr
73
- main_val_metric_goal: maximize
74
- init_seed: 2024
75
- use_mask: false
76
-
77
- eval:
78
- eval_metrics:
79
- - psnr
80
- targets_to_eval_in_linear_space:
81
- - material
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config/train_marigold_iid_appearance_finetuned.yaml DELETED
@@ -1,81 +0,0 @@
1
- base_config:
2
- - config/logging.yaml
3
- - config/wandb.yaml
4
- - config/dataset_iid/dataset_appearance_train.yaml
5
- - config/dataset_iid/dataset_appearance_val.yaml
6
- - config/dataset_iid/dataset_appearance_vis.yaml
7
- - config/model_sdv2.yaml
8
-
9
- pipeline:
10
- name: MarigoldIIDPipeline
11
- kwargs:
12
- default_denoising_steps: 4
13
- default_processing_resolution: 768
14
- target_properties:
15
- target_names:
16
- - albedo
17
- albedo:
18
- prediction_space: srgb
19
-
20
- augmentation:
21
- lr_flip_p: 0.5
22
-
23
- dataloader:
24
- num_workers: 2
25
- effective_batch_size: 32
26
- max_train_batch_size: 8
27
- seed: 2024 # to ensure continuity when resuming from checkpoint
28
-
29
- trainer:
30
- name: MarigoldIIDTrainer
31
- training_noise_scheduler:
32
- pretrained_path: stable-diffusion-2
33
- init_seed: 2024 # use null to train w/o seeding
34
- save_period: 50
35
- backup_period: 2000
36
- validation_period: 177
37
- visualization_period: 177
38
-
39
- multi_res_noise:
40
- strength: 0.9
41
- annealed: true
42
- downscale_strategy: original
43
-
44
- gt_mask_type: null
45
-
46
- max_epoch: 10000 # a large enough number
47
- max_iter: 5000 # usually converges at around 40k
48
-
49
- optimizer:
50
- name: Adam
51
-
52
- loss:
53
- name: mse_loss
54
- kwargs:
55
- reduction: mean
56
-
57
- lr: 5.0e-07
58
- lr_scheduler:
59
- name: IterExponential
60
- kwargs:
61
- total_iter: 2500
62
- final_ratio: 0.01
63
- warmup_steps: 100
64
-
65
- # Light setting for the in-training validation and visualization
66
- validation:
67
- denoising_steps: 4
68
- ensemble_size: 1
69
- processing_res: 1000
70
- match_input_res: true
71
- resample_method: bilinear
72
- main_val_metric: psnr
73
- main_val_metric_goal: maximize
74
- init_seed: 2024
75
- use_mask: false
76
-
77
- eval:
78
- eval_metrics:
79
- - psnr
80
- targets_to_eval_in_linear_space:
81
- - material
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config/train_marigold_iid_lighting.yaml DELETED
@@ -1,82 +0,0 @@
1
- base_config:
2
- - config/logging.yaml
3
- - config/wandb.yaml
4
- - config/dataset_iid/dataset_lighting_train.yaml
5
- - config/dataset_iid/dataset_lighting_val.yaml
6
- - config/dataset_iid/dataset_lighting_vis.yaml
7
- - config/model_sdv2.yaml
8
-
9
- pipeline:
10
- name: MarigoldIIDPipeline
11
- kwargs:
12
- default_denoising_steps: 4
13
- default_processing_resolution: 768
14
- target_properties:
15
- target_names:
16
- - albedo
17
- albedo:
18
- prediction_space: linear
19
- up_to_scale: false
20
-
21
- augmentation:
22
- lr_flip_p: 0.5
23
-
24
- dataloader:
25
- num_workers: 2
26
- effective_batch_size: 32
27
- max_train_batch_size: 8
28
- seed: 2024 # to ensure continuity when resuming from checkpoint
29
-
30
- trainer:
31
- name: MarigoldIIDTrainer
32
- training_noise_scheduler:
33
- pretrained_path: stable-diffusion-2
34
- init_seed: 2024 # use null to train w/o seeding
35
- save_period: 50
36
- backup_period: 2000
37
- validation_period: 500
38
- visualization_period: 1000
39
-
40
- multi_res_noise:
41
- strength: 0.9
42
- annealed: true
43
- downscale_strategy: original
44
-
45
- gt_mask_type: mask
46
-
47
- max_epoch: 10000 # a large enough number
48
- max_iter: 50000 # usually converges at around 34k
49
-
50
- optimizer:
51
- name: Adam
52
-
53
- loss:
54
- name: mse_loss
55
- kwargs:
56
- reduction: mean
57
-
58
- lr: 8e-05
59
- lr_scheduler:
60
- name: IterExponential
61
- kwargs:
62
- total_iter: 45000
63
- final_ratio: 0.01
64
- warmup_steps: 100
65
-
66
- # Light setting for the in-training validation and visualization
67
- validation:
68
- denoising_steps: 4
69
- ensemble_size: 1
70
- processing_res: 0
71
- match_input_res: true
72
- resample_method: bilinear
73
- main_val_metric: psnr
74
- main_val_metric_goal: maximize
75
- init_seed: 2024
76
- use_mask: false
77
-
78
- eval:
79
- eval_metrics:
80
- - psnr
81
- targets_to_eval_in_linear_space:
82
- - None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config/train_marigold_normals.yaml DELETED
@@ -1,86 +0,0 @@
1
- base_config:
2
- - config/logging.yaml
3
- - config/wandb.yaml
4
- - config/dataset_normals/dataset_train.yaml
5
- - config/dataset_normals/dataset_val.yaml
6
- - config/dataset_normals/dataset_vis.yaml
7
- - config/model_sdv2.yaml
8
-
9
- pipeline:
10
- name: MarigoldNormalsPipeline
11
- kwargs:
12
- default_denoising_steps: 4
13
- default_processing_resolution: 768
14
-
15
- augmentation:
16
- lr_flip_p: 0.5
17
- color_jitter_p: 0.3
18
- gaussian_blur_p: 0.3
19
- motion_blur_p: 0.3
20
- gaussian_blur_sigma: 4
21
- motion_blur_kernel_size: 11
22
- motion_blur_angle_range: 360
23
- jitter_brightness_factor: 0.5
24
- jitter_contrast_factor: 0.5
25
- jitter_saturation_factor: 0.5
26
- jitter_hue_factor: 0.2
27
-
28
- dataloader:
29
- num_workers: 2
30
- effective_batch_size: 32
31
- max_train_batch_size: 2
32
- seed: 2024 # to ensure continuity when resuming from checkpoint
33
-
34
- trainer:
35
- name: MarigoldNormalsTrainer
36
- training_noise_scheduler:
37
- pretrained_path: stable-diffusion-2
38
- init_seed: 2024 # use null to train w/o seeding
39
- save_period: 50
40
- backup_period: 2000
41
- validation_period: 500
42
- visualization_period: 1000
43
-
44
- multi_res_noise:
45
- strength: 0.9
46
- annealed: true
47
- downscale_strategy: original
48
-
49
- gt_normals_type: normals
50
- gt_mask_type: null
51
-
52
- max_epoch: 10000 # a large enough number
53
- max_iter: 30000 # usually converges at around 26k
54
-
55
- optimizer:
56
- name: Adam
57
-
58
- loss:
59
- name: mse_loss
60
- kwargs:
61
- reduction: mean
62
-
63
- lr: 6.0e-05
64
- lr_scheduler:
65
- name: IterExponential
66
- kwargs:
67
- total_iter: 25000
68
- final_ratio: 0.01
69
- warmup_steps: 100
70
-
71
- # Light setting for the in-training validation and visualization
72
- validation:
73
- denoising_steps: 4
74
- ensemble_size: 1
75
- processing_res: 768
76
- match_input_res: true
77
- resample_method: bilinear
78
- main_val_metric: mean_angular_error
79
- main_val_metric_goal: minimize
80
- init_seed: 0
81
-
82
- eval:
83
- align_max_res: null
84
- eval_metrics:
85
- - mean_angular_error
86
- - sub11_25_error
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config/wandb.yaml DELETED
@@ -1,3 +0,0 @@
1
- wandb:
2
- # entity: your_entity
3
- project: marigold
 
 
 
 
marigold/__init__.py DELETED
@@ -1,41 +0,0 @@
1
- # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # --------------------------------------------------------------------------
15
- # More information about Marigold:
16
- # https://marigoldmonodepth.github.io
17
- # https://marigoldcomputervision.github.io
18
- # Efficient inference pipelines are now part of diffusers:
19
- # https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
20
- # https://huggingface.co/docs/diffusers/api/pipelines/marigold
21
- # Examples of trained models and live demos:
22
- # https://huggingface.co/prs-eth
23
- # Related projects:
24
- # https://rollingdepth.github.io/
25
- # https://marigolddepthcompletion.github.io/
26
- # Citation (BibTeX):
27
- # https://github.com/prs-eth/Marigold#-citation
28
- # If you find Marigold useful, we kindly ask you to cite our papers.
29
- # --------------------------------------------------------------------------
30
-
31
- from .marigold_depth_pipeline import (
32
- MarigoldDepthPipeline,
33
- MarigoldDepthOutput, # noqa: F401
34
- )
35
- from .marigold_iid_pipeline import MarigoldIIDPipeline, MarigoldIIDOutput # noqa: F401
36
- from .marigold_normals_pipeline import (
37
- MarigoldNormalsPipeline, # noqa: F401
38
- MarigoldNormalsOutput, # noqa: F401
39
- )
40
-
41
- MarigoldPipeline = MarigoldDepthPipeline # for backward compatibility
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
marigold/marigold_depth_pipeline.py DELETED
@@ -1,516 +0,0 @@
1
- # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # --------------------------------------------------------------------------
15
- # More information about Marigold:
16
- # https://marigoldmonodepth.github.io
17
- # https://marigoldcomputervision.github.io
18
- # Efficient inference pipelines are now part of diffusers:
19
- # https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
20
- # https://huggingface.co/docs/diffusers/api/pipelines/marigold
21
- # Examples of trained models and live demos:
22
- # https://huggingface.co/prs-eth
23
- # Related projects:
24
- # https://rollingdepth.github.io/
25
- # https://marigolddepthcompletion.github.io/
26
- # Citation (BibTeX):
27
- # https://github.com/prs-eth/Marigold#-citation
28
- # If you find Marigold useful, we kindly ask you to cite our papers.
29
- # --------------------------------------------------------------------------
30
-
31
- import logging
32
- import numpy as np
33
- import torch
34
- from PIL import Image
35
- from diffusers import (
36
- AutoencoderKL,
37
- DDIMScheduler,
38
- DiffusionPipeline,
39
- LCMScheduler,
40
- UNet2DConditionModel,
41
- )
42
- from diffusers.utils import BaseOutput
43
- from torch.utils.data import DataLoader, TensorDataset
44
- from torchvision.transforms import InterpolationMode
45
- from torchvision.transforms.functional import pil_to_tensor, resize
46
- from tqdm.auto import tqdm
47
- from transformers import CLIPTextModel, CLIPTokenizer
48
- from typing import Dict, Optional, Union
49
-
50
- from .util.batchsize import find_batch_size
51
- from .util.ensemble import ensemble_depth
52
- from .util.image_util import (
53
- chw2hwc,
54
- colorize_depth_maps,
55
- get_tv_resample_method,
56
- resize_max_res,
57
- )
58
-
59
-
60
- class MarigoldDepthOutput(BaseOutput):
61
- """
62
- Output class for Marigold Monocular Depth Estimation pipeline.
63
-
64
- Args:
65
- depth_np (`np.ndarray`):
66
- Predicted depth map, with depth values in the range of [0, 1].
67
- depth_colored (`PIL.Image.Image`):
68
- Colorized depth map, with the shape of [H, W, 3] and values in [0, 255].
69
- uncertainty (`None` or `np.ndarray`):
70
- Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling.
71
- """
72
-
73
- depth_np: np.ndarray
74
- depth_colored: Union[None, Image.Image]
75
- uncertainty: Union[None, np.ndarray]
76
-
77
-
78
- class MarigoldDepthPipeline(DiffusionPipeline):
79
- """
80
- Pipeline for Marigold Monocular Depth Estimation: https://marigoldcomputervision.github.io.
81
-
82
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
83
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
84
-
85
- Args:
86
- unet (`UNet2DConditionModel`):
87
- Conditional U-Net to denoise the prediction latent, conditioned on image latent.
88
- vae (`AutoencoderKL`):
89
- Variational Auto-Encoder (VAE) Model to encode and decode images and predictions
90
- to and from latent representations.
91
- scheduler (`DDIMScheduler`):
92
- A scheduler to be used in combination with `unet` to denoise the encoded image latents.
93
- text_encoder (`CLIPTextModel`):
94
- Text-encoder, for empty text embedding.
95
- tokenizer (`CLIPTokenizer`):
96
- CLIP tokenizer.
97
- scale_invariant (`bool`, *optional*):
98
- A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in
99
- the model config. When used together with the `shift_invariant=True` flag, the model is also called
100
- "affine-invariant". NB: overriding this value is not supported.
101
- shift_invariant (`bool`, *optional*):
102
- A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in
103
- the model config. When used together with the `scale_invariant=True` flag, the model is also called
104
- "affine-invariant". NB: overriding this value is not supported.
105
- default_denoising_steps (`int`, *optional*):
106
- The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
107
- quality with the given model. This value must be set in the model config. When the pipeline is called
108
- without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
109
- reasonable results with various model flavors compatible with the pipeline, such as those relying on very
110
- short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
111
- default_processing_resolution (`int`, *optional*):
112
- The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
113
- the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
114
- default value is used. This is required to ensure reasonable results with various model flavors trained
115
- with varying optimal processing resolution values.
116
- """
117
-
118
- latent_scale_factor = 0.18215
119
-
120
- def __init__(
121
- self,
122
- unet: UNet2DConditionModel,
123
- vae: AutoencoderKL,
124
- scheduler: Union[DDIMScheduler, LCMScheduler],
125
- text_encoder: CLIPTextModel,
126
- tokenizer: CLIPTokenizer,
127
- scale_invariant: Optional[bool] = True,
128
- shift_invariant: Optional[bool] = True,
129
- default_denoising_steps: Optional[int] = None,
130
- default_processing_resolution: Optional[int] = None,
131
- ):
132
- super().__init__()
133
- self.register_modules(
134
- unet=unet,
135
- vae=vae,
136
- scheduler=scheduler,
137
- text_encoder=text_encoder,
138
- tokenizer=tokenizer,
139
- )
140
- self.register_to_config(
141
- scale_invariant=scale_invariant,
142
- shift_invariant=shift_invariant,
143
- default_denoising_steps=default_denoising_steps,
144
- default_processing_resolution=default_processing_resolution,
145
- )
146
-
147
- self.scale_invariant = scale_invariant
148
- self.shift_invariant = shift_invariant
149
- self.default_denoising_steps = default_denoising_steps
150
- self.default_processing_resolution = default_processing_resolution
151
-
152
- self.empty_text_embed = None
153
-
154
- @torch.no_grad()
155
- def __call__(
156
- self,
157
- input_image: Union[Image.Image, torch.Tensor],
158
- denoising_steps: Optional[int] = None,
159
- ensemble_size: int = 1,
160
- processing_res: Optional[int] = None,
161
- match_input_res: bool = True,
162
- resample_method: str = "bilinear",
163
- batch_size: int = 0,
164
- generator: Union[torch.Generator, None] = None,
165
- color_map: str = "Spectral",
166
- show_progress_bar: bool = True,
167
- ensemble_kwargs: Dict = None,
168
- ) -> MarigoldDepthOutput:
169
- """
170
- Function invoked when calling the pipeline.
171
-
172
- Args:
173
- input_image (`Image`):
174
- Input RGB (or gray-scale) image.
175
- denoising_steps (`int`, *optional*, defaults to `None`):
176
- Number of denoising diffusion steps during inference. The default value `None` results in automatic
177
- selection.
178
- ensemble_size (`int`, *optional*, defaults to `1`):
179
- Number of predictions to be ensembled.
180
- processing_res (`int`, *optional*, defaults to `None`):
181
- Effective processing resolution. When set to `0`, processes at the original image resolution. This
182
- produces crisper predictions, but may also lead to the overall loss of global context. The default
183
- value `None` resolves to the optimal value from the model config.
184
- match_input_res (`bool`, *optional*, defaults to `True`):
185
- Resize the prediction to match the input resolution.
186
- Only valid if `processing_res` > 0.
187
- resample_method: (`str`, *optional*, defaults to `bilinear`):
188
- Resampling method used to resize images and predictions. This can be one of `bilinear`, `bicubic` or
189
- `nearest`, defaults to: `bilinear`.
190
- batch_size (`int`, *optional*, defaults to `0`):
191
- Inference batch size, no bigger than `num_ensemble`.
192
- If set to 0, the script will automatically decide the proper batch size.
193
- generator (`torch.Generator`, *optional*, defaults to `None`)
194
- Random generator for initial noise generation.
195
- show_progress_bar (`bool`, *optional*, defaults to `True`):
196
- Display a progress bar of diffusion denoising.
197
- color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
198
- Colormap used to colorize the depth map.
199
- scale_invariant (`str`, *optional*, defaults to `True`):
200
- Flag of scale-invariant prediction, if True, scale will be adjusted from the raw prediction.
201
- shift_invariant (`str`, *optional*, defaults to `True`):
202
- Flag of shift-invariant prediction, if True, shift will be adjusted from the raw prediction, if False,
203
- near plane will be fixed at 0m.
204
- ensemble_kwargs (`dict`, *optional*, defaults to `None`):
205
- Arguments for detailed ensembling settings.
206
- Returns:
207
- `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
208
- - **depth_np** (`np.ndarray`) Predicted depth map with depth values in the range of [0, 1]
209
- - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [H, W, 3] and values in [0, 255], None if `color_map` is `None`
210
- - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
211
- coming from ensembling. None if `ensemble_size = 1`
212
- """
213
- # Model-specific optimal default values leading to fast and reasonable results.
214
- if denoising_steps is None:
215
- denoising_steps = self.default_denoising_steps
216
- if processing_res is None:
217
- processing_res = self.default_processing_resolution
218
-
219
- assert processing_res >= 0
220
- assert ensemble_size >= 1
221
-
222
- # Check if denoising step is reasonable
223
- self._check_inference_step(denoising_steps)
224
-
225
- resample_method: InterpolationMode = get_tv_resample_method(resample_method)
226
-
227
- # ----------------- Image Preprocess -----------------
228
- # Convert to torch tensor
229
- if isinstance(input_image, Image.Image):
230
- input_image = input_image.convert("RGB")
231
- # convert to torch tensor [H, W, rgb] -> [rgb, H, W]
232
- rgb = pil_to_tensor(input_image)
233
- rgb = rgb.unsqueeze(0) # [1, rgb, H, W]
234
- elif isinstance(input_image, torch.Tensor):
235
- rgb = input_image
236
- else:
237
- raise TypeError(f"Unknown input type: {type(input_image) = }")
238
- input_size = rgb.shape
239
- assert (
240
- 4 == rgb.dim() and 3 == input_size[-3]
241
- ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"
242
-
243
- # Resize image
244
- if processing_res > 0:
245
- rgb = resize_max_res(
246
- rgb,
247
- max_edge_resolution=processing_res,
248
- resample_method=resample_method,
249
- )
250
-
251
- # Normalize rgb values
252
- rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
253
- rgb_norm = rgb_norm.to(self.dtype)
254
- assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
255
-
256
- # ----------------- Predicting depth -----------------
257
- # Batch repeated input image
258
- duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
259
- single_rgb_dataset = TensorDataset(duplicated_rgb)
260
- if batch_size > 0:
261
- _bs = batch_size
262
- else:
263
- _bs = find_batch_size(
264
- ensemble_size=ensemble_size,
265
- input_res=max(rgb_norm.shape[1:]),
266
- dtype=self.dtype,
267
- )
268
-
269
- single_rgb_loader = DataLoader(
270
- single_rgb_dataset, batch_size=_bs, shuffle=False
271
- )
272
-
273
- # Predict depth maps (batched)
274
- target_pred_ls = []
275
- if show_progress_bar:
276
- iterable = tqdm(
277
- single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
278
- )
279
- else:
280
- iterable = single_rgb_loader
281
- for batch in iterable:
282
- (batched_img,) = batch
283
- target_pred_raw = self.single_infer(
284
- rgb_in=batched_img,
285
- num_inference_steps=denoising_steps,
286
- show_pbar=show_progress_bar,
287
- generator=generator,
288
- )
289
- target_pred_ls.append(target_pred_raw.detach())
290
- target_preds = torch.concat(target_pred_ls, dim=0)
291
- torch.cuda.empty_cache() # clear vram cache for ensembling
292
-
293
- # ----------------- Test-time ensembling -----------------
294
- if ensemble_size > 1:
295
- final_pred, pred_uncert = ensemble_depth(
296
- target_preds,
297
- scale_invariant=self.scale_invariant,
298
- shift_invariant=self.shift_invariant,
299
- **(ensemble_kwargs or {}),
300
- )
301
- else:
302
- final_pred = target_preds
303
- pred_uncert = None
304
-
305
- # Resize back to original resolution
306
- if match_input_res:
307
- final_pred = resize(
308
- final_pred,
309
- input_size[-2:],
310
- interpolation=resample_method,
311
- antialias=True,
312
- )
313
-
314
- # Convert to numpy
315
- final_pred = final_pred.squeeze()
316
- final_pred = final_pred.cpu().numpy()
317
- if pred_uncert is not None:
318
- pred_uncert = pred_uncert.squeeze().cpu().numpy()
319
-
320
- # Clip output range
321
- final_pred = final_pred.clip(0, 1)
322
-
323
- # Colorize
324
- if color_map is not None:
325
- depth_colored = colorize_depth_maps(
326
- final_pred, 0, 1, cmap=color_map
327
- ).squeeze() # [3, H, W], value in (0, 1)
328
- depth_colored = (depth_colored * 255).astype(np.uint8)
329
- depth_colored_hwc = chw2hwc(depth_colored)
330
- depth_colored_img = Image.fromarray(depth_colored_hwc)
331
- else:
332
- depth_colored_img = None
333
-
334
- return MarigoldDepthOutput(
335
- depth_np=final_pred,
336
- depth_colored=depth_colored_img,
337
- uncertainty=pred_uncert,
338
- )
339
-
340
- def _check_inference_step(self, n_step: int) -> None:
341
- """
342
- Check if denoising step is reasonable
343
- Args:
344
- n_step (`int`): denoising steps
345
- """
346
- assert n_step >= 1
347
-
348
- if isinstance(self.scheduler, DDIMScheduler):
349
- if "trailing" != self.scheduler.config.timestep_spacing:
350
- logging.warning(
351
- f"The loaded `DDIMScheduler` is configured with `timestep_spacing="
352
- f'"{self.scheduler.config.timestep_spacing}"`; the recommended setting is `"trailing"`. '
353
- f"This change is backward-compatible and yields better results. "
354
- f"Consider using `prs-eth/marigold-depth-v1-1` for the best experience."
355
- )
356
- else:
357
- if n_step > 10:
358
- logging.warning(
359
- f"Setting too many denoising steps ({n_step}) may degrade the prediction; consider relying on "
360
- f"the default values."
361
- )
362
- if not self.scheduler.config.rescale_betas_zero_snr:
363
- logging.warning(
364
- f"The loaded `DDIMScheduler` is configured with `rescale_betas_zero_snr="
365
- f"{self.scheduler.config.rescale_betas_zero_snr}`; the recommended setting is True. "
366
- f"Consider using `prs-eth/marigold-depth-v1-1` for the best experience."
367
- )
368
- elif isinstance(self.scheduler, LCMScheduler):
369
- logging.warning(
370
- "DeprecationWarning: LCMScheduler will not be supported in the future. "
371
- "Consider using `prs-eth/marigold-depth-v1-1` for the best experience."
372
- )
373
- if n_step > 10:
374
- logging.warning(
375
- f"Setting too many denoising steps ({n_step}) may degrade the prediction; consider relying on "
376
- f"the default values."
377
- )
378
- else:
379
- raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
380
-
381
- def encode_empty_text(self):
382
- """
383
- Encode text embedding for empty prompt
384
- """
385
- prompt = ""
386
- text_inputs = self.tokenizer(
387
- prompt,
388
- padding="do_not_pad",
389
- max_length=self.tokenizer.model_max_length,
390
- truncation=True,
391
- return_tensors="pt",
392
- )
393
- text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
394
- self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
395
-
396
- @torch.no_grad()
397
- def single_infer(
398
- self,
399
- rgb_in: torch.Tensor,
400
- num_inference_steps: int,
401
- generator: Union[torch.Generator, None],
402
- show_pbar: bool,
403
- ) -> torch.Tensor:
404
- """
405
- Perform a single prediction without ensembling.
406
-
407
- Args:
408
- rgb_in (`torch.Tensor`):
409
- Input RGB image.
410
- num_inference_steps (`int`):
411
- Number of diffusion denoisign steps (DDIM) during inference.
412
- show_pbar (`bool`):
413
- Display a progress bar of diffusion denoising.
414
- generator (`torch.Generator`)
415
- Random generator for initial noise generation.
416
- Returns:
417
- `torch.Tensor`: Predicted targets.
418
- """
419
- device = self.device
420
- rgb_in = rgb_in.to(device)
421
-
422
- # Set timesteps
423
- self.scheduler.set_timesteps(num_inference_steps, device=device)
424
- timesteps = self.scheduler.timesteps # [T]
425
-
426
- # Encode image
427
- rgb_latent = self.encode_rgb(rgb_in) # [B, 4, h, w]
428
-
429
- # Noisy latent for outputs
430
- target_latent = torch.randn(
431
- rgb_latent.shape,
432
- device=device,
433
- dtype=self.dtype,
434
- generator=generator,
435
- ) # [B, 4, h, w]
436
-
437
- # Batched empty text embedding
438
- if self.empty_text_embed is None:
439
- self.encode_empty_text()
440
- batch_empty_text_embed = self.empty_text_embed.repeat(
441
- (rgb_latent.shape[0], 1, 1)
442
- ).to(device) # [B, 2, 1024]
443
-
444
- # Denoising loop
445
- if show_pbar:
446
- iterable = tqdm(
447
- enumerate(timesteps),
448
- total=len(timesteps),
449
- leave=False,
450
- desc=" " * 4 + "Diffusion denoising",
451
- )
452
- else:
453
- iterable = enumerate(timesteps)
454
-
455
- for i, t in iterable:
456
- unet_input = torch.cat(
457
- [rgb_latent, target_latent], dim=1
458
- ) # this order is important
459
-
460
- # predict the noise residual
461
- noise_pred = self.unet(
462
- unet_input, t, encoder_hidden_states=batch_empty_text_embed
463
- ).sample # [B, 4, h, w]
464
-
465
- # compute the previous noisy sample x_t -> x_t-1
466
- target_latent = self.scheduler.step(
467
- noise_pred, t, target_latent, generator=generator
468
- ).prev_sample
469
-
470
- depth = self.decode_depth(target_latent) # [B,3,H,W]
471
-
472
- # clip prediction
473
- depth = torch.clip(depth, -1.0, 1.0)
474
- # shift to [0, 1]
475
- depth = (depth + 1.0) / 2.0
476
-
477
- return depth
478
-
479
- def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
480
- """
481
- Encode RGB image into latent.
482
-
483
- Args:
484
- rgb_in (`torch.Tensor`):
485
- Input RGB image to be encoded.
486
-
487
- Returns:
488
- `torch.Tensor`: Image latent.
489
- """
490
- # encode
491
- h = self.vae.encoder(rgb_in)
492
- moments = self.vae.quant_conv(h)
493
- mean, logvar = torch.chunk(moments, 2, dim=1)
494
- # scale latent
495
- rgb_latent = mean * self.latent_scale_factor
496
- return rgb_latent
497
-
498
- def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
499
- """
500
- Decode depth latent into depth map.
501
-
502
- Args:
503
- depth_latent (`torch.Tensor`):
504
- Depth latent to be decoded.
505
-
506
- Returns:
507
- `torch.Tensor`: Decoded depth map.
508
- """
509
- # scale latent
510
- depth_latent = depth_latent / self.latent_scale_factor
511
- # decode
512
- z = self.vae.post_quant_conv(depth_latent)
513
- stacked = self.vae.decoder(z)
514
- # mean of output channels
515
- depth_mean = stacked.mean(dim=1, keepdim=True)
516
- return depth_mean
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
marigold/marigold_normals_pipeline.py DELETED
@@ -1,479 +0,0 @@
1
- # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # --------------------------------------------------------------------------
15
- # More information about Marigold:
16
- # https://marigoldmonodepth.github.io
17
- # https://marigoldcomputervision.github.io
18
- # Efficient inference pipelines are now part of diffusers:
19
- # https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
20
- # https://huggingface.co/docs/diffusers/api/pipelines/marigold
21
- # Examples of trained models and live demos:
22
- # https://huggingface.co/prs-eth
23
- # Related projects:
24
- # https://rollingdepth.github.io/
25
- # https://marigolddepthcompletion.github.io/
26
- # Citation (BibTeX):
27
- # https://github.com/prs-eth/Marigold#-citation
28
- # If you find Marigold useful, we kindly ask you to cite our papers.
29
- # --------------------------------------------------------------------------
30
-
31
- import logging
32
- import numpy as np
33
- import torch
34
- from PIL import Image
35
- from diffusers import (
36
- AutoencoderKL,
37
- DDIMScheduler,
38
- DiffusionPipeline,
39
- LCMScheduler,
40
- UNet2DConditionModel,
41
- )
42
- from diffusers.utils import BaseOutput
43
- from torch.utils.data import DataLoader, TensorDataset
44
- from torchvision.transforms import InterpolationMode
45
- from torchvision.transforms.functional import pil_to_tensor, resize
46
- from tqdm.auto import tqdm
47
- from transformers import CLIPTextModel, CLIPTokenizer
48
- from typing import Dict, Optional, Union
49
-
50
- from .util.batchsize import find_batch_size
51
- from .util.ensemble import ensemble_normals
52
- from .util.image_util import (
53
- chw2hwc,
54
- get_tv_resample_method,
55
- resize_max_res,
56
- )
57
-
58
-
59
class MarigoldNormalsOutput(BaseOutput):
    """
    Output class for the Marigold Surface Normals Estimation pipeline.

    Args:
        normals_np (`np.ndarray`):
            Predicted normals map of shape [3, H, W] with values in the range of [-1, 1] (unit length vectors).
        normals_img (`PIL.Image.Image`):
            Normals image, with the shape of [H, W, 3] and values in [0, 255].
        uncertainty (`None` or `np.ndarray`):
            Uncalibrated uncertainty (MAD, median absolute deviation) coming from ensembling;
            `None` when `ensemble_size == 1`.
    """

    # Field annotations only; BaseOutput stores the values passed at construction.
    normals_np: np.ndarray
    normals_img: Image.Image
    uncertainty: Union[None, np.ndarray]
75
-
76
-
77
class MarigoldNormalsPipeline(DiffusionPipeline):
    """
    Pipeline for Marigold Surface Normals Estimation: https://marigoldcomputervision.github.io.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        unet (`UNet2DConditionModel`):
            Conditional U-Net to denoise the prediction latent, conditioned on image latent.
        vae (`AutoencoderKL`):
            Variational Auto-Encoder (VAE) Model to encode and decode images and predictions
            to and from latent representations.
        scheduler (`DDIMScheduler`):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        text_encoder (`CLIPTextModel`):
            Text-encoder, for empty text embedding.
        tokenizer (`CLIPTokenizer`):
            CLIP tokenizer.
        default_denoising_steps (`int`, *optional*):
            The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
            quality with the given model. This value must be set in the model config. When the pipeline is called
            without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
            reasonable results with various model flavors compatible with the pipeline, such as those relying on very
            short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
        default_processing_resolution (`int`, *optional*):
            The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
            the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
            default value is used. This is required to ensure reasonable results with various model flavors trained
            with varying optimal processing resolution values.
    """

    # Latent scaling constant of the Stable Diffusion VAE; applied in encode_rgb and
    # undone in decode_normals.
    latent_scale_factor = 0.18215

    def __init__(
        self,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        scheduler: Union[DDIMScheduler, LCMScheduler],
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        default_denoising_steps: Optional[int] = None,
        default_processing_resolution: Optional[int] = None,
    ):
        super().__init__()
        self.register_modules(
            unet=unet,
            vae=vae,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )
        # Persist the defaults in the model config so they survive save/load round-trips.
        self.register_to_config(
            default_denoising_steps=default_denoising_steps,
            default_processing_resolution=default_processing_resolution,
        )

        self.default_denoising_steps = default_denoising_steps
        self.default_processing_resolution = default_processing_resolution

        # Lazily computed in encode_empty_text() on first use.
        self.empty_text_embed = None

    @torch.no_grad()
    def __call__(
        self,
        input_image: Union[Image.Image, torch.Tensor],
        denoising_steps: Optional[int] = None,
        ensemble_size: int = 1,
        processing_res: Optional[int] = None,
        match_input_res: bool = True,
        resample_method: str = "bilinear",
        batch_size: int = 0,
        generator: Union[torch.Generator, None] = None,
        show_progress_bar: bool = True,
        ensemble_kwargs: Optional[Dict] = None,
    ) -> MarigoldNormalsOutput:
        """
        Function invoked when calling the pipeline.

        Args:
            input_image (`Image`):
                Input RGB (or gray-scale) image.
            denoising_steps (`int`, *optional*, defaults to `None`):
                Number of denoising diffusion steps during inference. The default value `None` results in automatic
                selection.
            ensemble_size (`int`, *optional*, defaults to `1`):
                Number of predictions to be ensembled.
            processing_res (`int`, *optional*, defaults to `None`):
                Effective processing resolution. When set to `0`, processes at the original image resolution. This
                produces crisper predictions, but may also lead to the overall loss of global context. The default
                value `None` resolves to the optimal value from the model config.
            match_input_res (`bool`, *optional*, defaults to `True`):
                Resize the prediction to match the input resolution.
                Only valid if `processing_res` > 0.
            resample_method: (`str`, *optional*, defaults to `bilinear`):
                Resampling method used to resize images and predictions. This can be one of `bilinear`, `bicubic` or
                `nearest`, defaults to: `bilinear`.
            batch_size (`int`, *optional*, defaults to `0`):
                Inference batch size, no bigger than `ensemble_size`.
                If set to 0, the script will automatically decide the proper batch size.
            generator (`torch.Generator`, *optional*, defaults to `None`)
                Random generator for initial noise generation.
            show_progress_bar (`bool`, *optional*, defaults to `True`):
                Display a progress bar of diffusion denoising.
            ensemble_kwargs (`dict`, *optional*, defaults to `None`):
                Arguments for detailed ensembling settings.
        Returns:
            `MarigoldNormalsOutput`: Output class for Marigold monocular surface normals estimation pipeline, including:
            - **normals_np** (`np.ndarray`) Predicted normals map of shape [3, H, W] with values in the range of [-1, 1]
              (unit length vectors)
            - **normals_img** (`PIL.Image.Image`) Normals image, with the shape of [H, W, 3] and values in [0, 255]
            - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty (MAD, median absolute deviation)
              coming from ensembling. None if `ensemble_size = 1`
        """
        # Model-specific optimal default values leading to fast and reasonable results.
        if denoising_steps is None:
            denoising_steps = self.default_denoising_steps
        if processing_res is None:
            processing_res = self.default_processing_resolution

        assert processing_res >= 0
        assert ensemble_size >= 1

        # Check if denoising step is reasonable
        self._check_inference_step(denoising_steps)

        # Translate the string into a torchvision InterpolationMode.
        resample_method: InterpolationMode = get_tv_resample_method(resample_method)

        # ----------------- Image Preprocess -----------------
        # Convert to torch tensor
        if isinstance(input_image, Image.Image):
            input_image = input_image.convert("RGB")
            # convert to torch tensor [H, W, rgb] -> [rgb, H, W]
            rgb = pil_to_tensor(input_image)
            rgb = rgb.unsqueeze(0)  # [1, rgb, H, W]
        elif isinstance(input_image, torch.Tensor):
            rgb = input_image
        else:
            raise TypeError(f"Unknown input type: {type(input_image) = }")
        input_size = rgb.shape
        assert (
            4 == rgb.dim() and 3 == input_size[-3]
        ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"

        # Resize image so that its longer edge does not exceed processing_res.
        if processing_res > 0:
            rgb = resize_max_res(
                rgb,
                max_edge_resolution=processing_res,
                resample_method=resample_method,
            )

        # Normalize rgb values
        rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
        rgb_norm = rgb_norm.to(self.dtype)
        assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0

        # ----------------- Predicting normals -----------------
        # Batch repeated input image (expand creates views, not copies).
        duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
        single_rgb_dataset = TensorDataset(duplicated_rgb)
        if batch_size > 0:
            _bs = batch_size
        else:
            # Heuristic batch size based on resolution and dtype.
            _bs = find_batch_size(
                ensemble_size=ensemble_size,
                input_res=max(rgb_norm.shape[1:]),
                dtype=self.dtype,
            )

        single_rgb_loader = DataLoader(
            single_rgb_dataset, batch_size=_bs, shuffle=False
        )

        # Predict normals maps (batched)
        target_pred_ls = []
        if show_progress_bar:
            iterable = tqdm(
                single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
            )
        else:
            iterable = single_rgb_loader
        for batch in iterable:
            (batched_img,) = batch
            target_pred_raw = self.single_infer(
                rgb_in=batched_img,
                num_inference_steps=denoising_steps,
                show_pbar=show_progress_bar,
                generator=generator,
            )
            target_pred_ls.append(target_pred_raw.detach())
        target_preds = torch.concat(target_pred_ls, dim=0)
        torch.cuda.empty_cache()  # clear vram cache for ensembling

        # ----------------- Test-time ensembling -----------------
        if ensemble_size > 1:
            final_pred, pred_uncert = ensemble_normals(
                target_preds,
                **(ensemble_kwargs or {}),
            )
        else:
            final_pred = target_preds
            pred_uncert = None

        # Resize back to original resolution
        if match_input_res:
            final_pred = resize(
                final_pred,
                input_size[-2:],
                interpolation=resample_method,
                antialias=True,
            )

        # Convert to numpy
        final_pred = final_pred.squeeze()
        final_pred = final_pred.cpu().numpy()
        if pred_uncert is not None:
            pred_uncert = pred_uncert.squeeze().cpu().numpy()

        # Clip output range (resizing can push values slightly outside [-1, 1]).
        final_pred = final_pred.clip(-1, 1)

        # Colorize: map [-1, 1] to [0, 255] for the visualization image.
        normals_img = ((final_pred + 1) * 127.5).astype(np.uint8)
        normals_img = chw2hwc(normals_img)
        normals_img = Image.fromarray(normals_img)

        return MarigoldNormalsOutput(
            normals_np=final_pred,
            normals_img=normals_img,
            uncertainty=pred_uncert,
        )

    def _check_inference_step(self, n_step: int) -> None:
        """
        Check if denoising step is reasonable and warn about suboptimal
        scheduler configurations.

        Args:
            n_step (`int`): denoising steps

        Raises:
            RuntimeError: if the scheduler is an `LCMScheduler` or of an
                unsupported type.
        """
        assert n_step >= 1

        if isinstance(self.scheduler, DDIMScheduler):
            if "trailing" != self.scheduler.config.timestep_spacing:
                logging.warning(
                    f"The loaded `DDIMScheduler` is configured with `timestep_spacing="
                    f'"{self.scheduler.config.timestep_spacing}"`; the recommended setting is `"trailing"`. '
                    f"This change is backward-compatible and yields better results. "
                    f"Consider using `prs-eth/marigold-normals-v1-1` for the best experience."
                )
            else:
                if n_step > 10:
                    logging.warning(
                        f"Setting too many denoising steps ({n_step}) may degrade the prediction; consider relying on "
                        f"the default values."
                    )
                if not self.scheduler.config.rescale_betas_zero_snr:
                    logging.warning(
                        f"The loaded `DDIMScheduler` is configured with `rescale_betas_zero_snr="
                        f"{self.scheduler.config.rescale_betas_zero_snr}`; the recommended setting is True. "
                        f"Consider using `prs-eth/marigold-normals-v1-1` for the best experience."
                    )
        elif isinstance(self.scheduler, LCMScheduler):
            raise RuntimeError(
                "This pipeline implementation does not support the LCMScheduler. Please refer to the project "
                "README.md for instructions about using LCM."
            )
        else:
            raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")

    def encode_empty_text(self):
        """
        Encode text embedding for empty prompt and cache it in
        `self.empty_text_embed`.
        """
        prompt = ""
        text_inputs = self.tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
        self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)

    @torch.no_grad()
    def single_infer(
        self,
        rgb_in: torch.Tensor,
        num_inference_steps: int,
        generator: Union[torch.Generator, None],
        show_pbar: bool,
    ) -> torch.Tensor:
        """
        Perform a single prediction without ensembling.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image, shape [B, 3, H, W], values in [-1, 1].
            num_inference_steps (`int`):
                Number of diffusion denoising steps (DDIM) during inference.
            show_pbar (`bool`):
                Display a progress bar of diffusion denoising.
            generator (`torch.Generator`)
                Random generator for initial noise generation.
        Returns:
            `torch.Tensor`: Predicted targets (unit-length normals, [B, 3, H, W]).
        """
        device = self.device
        rgb_in = rgb_in.to(device)

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps  # [T]

        # Encode image
        rgb_latent = self.encode_rgb(rgb_in)  # [B, 4, h, w]

        # Noisy latent for outputs
        target_latent = torch.randn(
            rgb_latent.shape,
            device=device,
            dtype=self.dtype,
            generator=generator,
        )  # [B, 4, h, w]

        # Batched empty text embedding
        if self.empty_text_embed is None:
            self.encode_empty_text()
        batch_empty_text_embed = self.empty_text_embed.repeat(
            (rgb_latent.shape[0], 1, 1)
        ).to(device)  # [B, 2, 1024]

        # Denoising loop
        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising",
            )
        else:
            iterable = enumerate(timesteps)

        for i, t in iterable:
            unet_input = torch.cat(
                [rgb_latent, target_latent], dim=1
            )  # this order is important

            # predict the noise residual
            noise_pred = self.unet(
                unet_input, t, encoder_hidden_states=batch_empty_text_embed
            ).sample  # [B, 4, h, w]

            # compute the previous noisy sample x_t -> x_t-1
            target_latent = self.scheduler.step(
                noise_pred, t, target_latent, generator=generator
            ).prev_sample

        normals = self.decode_normals(target_latent)  # [B,3,H,W]

        # clip prediction, then renormalize to unit-length vectors
        normals = torch.clip(normals, -1.0, 1.0)
        norm = torch.norm(normals, dim=1, keepdim=True)
        normals /= norm.clamp(min=1e-6)  # [B,3,H,W]

        return normals

    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
        """
        Encode RGB image into latent.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image to be encoded.

        Returns:
            `torch.Tensor`: Image latent.
        """
        # encode: use the mean of the VAE posterior (logvar is discarded)
        h = self.vae.encoder(rgb_in)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        # scale latent
        rgb_latent = mean * self.latent_scale_factor
        return rgb_latent

    def decode_normals(self, normals_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode normals latent into normals map.

        Args:
            normals_latent (`torch.Tensor`):
                Normals latent to be decoded.

        Returns:
            `torch.Tensor`: Decoded normals map (not yet clipped or normalized).
        """
        # scale latent
        normals_latent = normals_latent / self.latent_scale_factor
        # decode
        z = self.vae.post_quant_conv(normals_latent)
        stacked = self.vae.decoder(z)
        return stacked
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
{src → olbedo}/__init__.py RENAMED
@@ -27,3 +27,5 @@
27
  # https://github.com/prs-eth/Marigold#-citation
28
  # If you find Marigold useful, we kindly ask you to cite our papers.
29
  # --------------------------------------------------------------------------
 
 
 
27
  # https://github.com/prs-eth/Marigold#-citation
28
  # If you find Marigold useful, we kindly ask you to cite our papers.
29
  # --------------------------------------------------------------------------
30
+
31
+ from .olbedo_iid_pipeline import OlbedoIIDPipeline, OlbedoIIDOutput # noqa: F401
olbedo/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (231 Bytes). View file