diff --git a/app.py b/app.py index 4b51132a8c6e0bcbd604e1e7b2e52f3a24fdfc75..d0d95a28b7a57a73be4ce7cc7b9aa4f11755b9df 100644 --- a/app.py +++ b/app.py @@ -7,15 +7,13 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", import gradio as gr import numpy as np import torch -from marigold import MarigoldIIDOutput, MarigoldIIDPipeline +from olbedo import OlbedoIIDOutput, OlbedoIIDPipeline from src.util.image_util import read_img_from_file, img_hwc2chw, img_linear2srgb, is_hdr -from marigold.util.image_util import float2int +from olbedo.util.image_util import float2int from src.util.seeding import seed_all import logging from huggingface_hub import snapshot_download -HF_TOKEN = os.getenv("HF_TOKEN") - seed = 1234 seed_all(seed) if torch.cuda.is_available(): @@ -46,12 +44,11 @@ def get_demo(): local_dir = snapshot_download( repo_id="GDAOSU/olbedo", allow_patterns=f"{selected_model}/*", - token=HF_TOKEN, ) model_path = os.path.join(local_dir, selected_model) - pipe = MarigoldIIDPipeline.from_pretrained( + pipe = OlbedoIIDPipeline.from_pretrained( model_path, torch_dtype=torch.float32, ).to(device) @@ -102,7 +99,7 @@ def get_demo(): if "rgbx" in selected_model: pipe.prompt = prompt - pipe_out: MarigoldIIDOutput = pipe( + pipe_out: OlbedoIIDOutput = pipe( input_image, denoising_steps=inference_step, ensemble_size=1, @@ -136,7 +133,7 @@ def get_demo(): block = gr.Blocks() with block: with gr.Row(): - gr.Markdown("## OSU albedo demo") + gr.Markdown("## Olbedo: An Albedo and Shading Aerial Dataset for Large-Scale Outdoor Environments") with gr.Row(): # Input side with gr.Column(): diff --git a/config/dataset_depth/data_diode_all.yaml b/config/dataset_depth/data_diode_all.yaml deleted file mode 100644 index c2284fe699778b6ef4c23270d8f49c983c57306d..0000000000000000000000000000000000000000 --- a/config/dataset_depth/data_diode_all.yaml +++ /dev/null @@ -1,4 +0,0 @@ -name: diode_depth -disp_name: diode_depth_val_all -dir: diode/diode_val.tar -filenames: data_split/diode_depth/diode_val_all_filename_list.txt diff --git a/config/dataset_depth/data_eth3d.yaml b/config/dataset_depth/data_eth3d.yaml deleted file mode 100644 index 7175e2b583b2172feb1ba4630b98256de453ade7..0000000000000000000000000000000000000000 --- a/config/dataset_depth/data_eth3d.yaml +++ /dev/null @@ -1,4 +0,0 @@ -name: eth3d_depth -disp_name: eth3d_depth_full -dir: eth3d/eth3d.tar -filenames: data_split/eth3d_depth/eth3d_filename_list.txt diff --git a/config/dataset_depth/data_hypersim_train.yaml b/config/dataset_depth/data_hypersim_train.yaml deleted file mode 100644 index 400e7f45d7f64303a94e5c60fcb0c6792f4ee36d..0000000000000000000000000000000000000000 --- a/config/dataset_depth/data_hypersim_train.yaml +++ /dev/null @@ -1,4 +0,0 @@ -name: hypersim_depth -disp_name: hypersim_depth_train -dir: hypersim/hypersim_processed_train.tar -filenames: data_split/hypersim_depth/filename_list_train_filtered.txt diff --git a/config/dataset_depth/data_hypersim_val.yaml b/config/dataset_depth/data_hypersim_val.yaml deleted file mode 100644 index 2edd0fbd1b107a0a5af6ac324a34286e56d2926e..0000000000000000000000000000000000000000 --- a/config/dataset_depth/data_hypersim_val.yaml +++ /dev/null @@ -1,4 +0,0 @@ -name: hypersim_depth -disp_name: hypersim_depth_val -dir: hypersim/hypersim_processed_val.tar -filenames: data_split/hypersim_depth/filename_list_val_filtered.txt diff --git a/config/dataset_depth/data_kitti_eigen_test.yaml b/config/dataset_depth/data_kitti_eigen_test.yaml deleted file mode 100644 index a2ef3f8766dc22597188f7b8911a2d3ee54eb297..0000000000000000000000000000000000000000 --- a/config/dataset_depth/data_kitti_eigen_test.yaml +++ /dev/null @@ -1,6 +0,0 @@ -name: kitti_depth -disp_name: kitti_depth_eigen_test_full -dir: kitti/kitti_eigen_split_test.tar -filenames: data_split/kitti_depth/eigen_test_files_with_gt.txt -kitti_bm_crop: true -valid_mask_crop: eigen diff --git a/config/dataset_depth/data_kitti_val.yaml b/config/dataset_depth/data_kitti_val.yaml deleted file mode 100644 index 68bbc54036b142bd5867a8e5acf4acca36cd268b..0000000000000000000000000000000000000000 --- a/config/dataset_depth/data_kitti_val.yaml +++ /dev/null @@ -1,6 +0,0 @@ -name: kitti_depth -disp_name: kitti_depth_val800_from_eigen_train -dir: kitti/kitti_sampled_val_800.tar -filenames: data_split/kitti_depth/eigen_val_from_train_800.txt -kitti_bm_crop: true -valid_mask_crop: eigen diff --git a/config/dataset_depth/data_nyu_test.yaml b/config/dataset_depth/data_nyu_test.yaml deleted file mode 100644 index 23799be1f1a4939422cc5b5ff6217e226fb604c3..0000000000000000000000000000000000000000 --- a/config/dataset_depth/data_nyu_test.yaml +++ /dev/null @@ -1,5 +0,0 @@ -name: nyu_depth -disp_name: nyu_depth_test_full -dir: nyuv2/nyu_labeled_extracted.tar -filenames: data_split/nyu_depth/labeled/filename_list_test.txt -eigen_valid_mask: true diff --git a/config/dataset_depth/data_nyu_train.yaml b/config/dataset_depth/data_nyu_train.yaml deleted file mode 100644 index d8d7e84a0c6e7bc3698fbe45e8c8241966fab512..0000000000000000000000000000000000000000 --- a/config/dataset_depth/data_nyu_train.yaml +++ /dev/null @@ -1,5 +0,0 @@ -name: nyu_depth -disp_name: nyu_depth_train_full -dir: nyuv2/nyu_labeled_extracted.tar -filenames: data_split/nyu_depth/labeled/filename_list_train.txt -eigen_valid_mask: true diff --git a/config/dataset_depth/data_scannet_val.yaml b/config/dataset_depth/data_scannet_val.yaml deleted file mode 100644 index 0e99b6cee6e25ac880784d4a5dff9071390180ed..0000000000000000000000000000000000000000 --- a/config/dataset_depth/data_scannet_val.yaml +++ /dev/null @@ -1,4 +0,0 @@ -name: scannet_depth -disp_name: scannet_depth_val_800_1 -dir: scannet/scannet_val_sampled_800_1.tar -filenames: data_split/scannet_depth/scannet_val_sampled_list_800_1.txt diff --git a/config/dataset_depth/data_vkitti_train.yaml b/config/dataset_depth/data_vkitti_train.yaml deleted file mode 100644 index 11aea772903c640c815f8d70b1a2263e6e6dbca8..0000000000000000000000000000000000000000 --- a/config/dataset_depth/data_vkitti_train.yaml +++ /dev/null @@ -1,6 +0,0 @@ -name: vkitti_depth -disp_name: vkitti_depth_train -dir: vkitti/vkitti.tar -filenames: data_split/vkitti_depth/vkitti_train.txt -kitti_bm_crop: true -valid_mask_crop: null # no valid_mask_crop for training diff --git a/config/dataset_depth/data_vkitti_val.yaml b/config/dataset_depth/data_vkitti_val.yaml deleted file mode 100644 index 1c9862b8e1d59203b51398fadcd0b7ac1db511d8..0000000000000000000000000000000000000000 --- a/config/dataset_depth/data_vkitti_val.yaml +++ /dev/null @@ -1,6 +0,0 @@ -name: vkitti_depth -disp_name: vkitti_depth_val -dir: vkitti/vkitti.tar -filenames: data_split/vkitti_depth/vkitti_val.txt -kitti_bm_crop: true -valid_mask_crop: eigen diff --git a/config/dataset_depth/dataset_train.yaml b/config/dataset_depth/dataset_train.yaml deleted file mode 100644 index 214b3f2febad3d80ca61461c176d1614d7cb68e2..0000000000000000000000000000000000000000 --- a/config/dataset_depth/dataset_train.yaml +++ /dev/null @@ -1,18 +0,0 @@ -dataset: - train: - name: mixed - prob_ls: [0.9, 0.1] - dataset_list: - - name: hypersim_depth - disp_name: hypersim_depth_train - dir: hypersim/hypersim_processed_train.tar - filenames: data_split/hypersim_depth/filename_list_train_filtered.txt - resize_to_hw: - - 480 - - 640 - - name: vkitti_depth - disp_name: vkitti_depth_train - dir: vkitti/vkitti.tar - filenames: data_split/vkitti_depth/vkitti_train.txt - kitti_bm_crop: true - valid_mask_crop: null diff --git a/config/dataset_depth/dataset_val.yaml b/config/dataset_depth/dataset_val.yaml deleted file mode 100644 index 67207c01ee2c1a4c52f387b39f9ad6428b694f09..0000000000000000000000000000000000000000 --- a/config/dataset_depth/dataset_val.yaml +++ /dev/null @@ -1,45 +0,0 @@ -dataset: - val: - # - name: hypersim_depth - # disp_name: hypersim_depth_val - # dir: hypersim/hypersim_processed_val.tar - # filenames: data_split/hypersim_depth/filename_list_val_filtered.txt - # resize_to_hw: - # - 480 - # - 640 - - # - name: nyu_depth - # disp_name: nyu_depth_train_full - # dir: nyuv2/nyu_labeled_extracted.tar - # filenames: data_split/nyu_depth/labeled/filename_list_train.txt - # eigen_valid_mask: true - - # - name: kitti_depth - # disp_name: kitti_depth_val800_from_eigen_train - # dir: kitti/kitti_depth_sampled_val_800.tar - # filenames: data_split/kitti_depth/eigen_val_from_train_800.txt - # kitti_bm_crop: true - # valid_mask_crop: eigen - - # Smaller subsets for faster validation during training - # The first dataset is used to calculate main eval metric. - - name: hypersim_depth - disp_name: hypersim_depth_val_small_80 - dir: hypersim/hypersim_processed_val.tar - filenames: data_split/hypersim_depth/filename_list_val_filtered_small_80.txt - resize_to_hw: - - 480 - - 640 - - - name: nyu_depth - disp_name: nyu_depth_train_small_100 - dir: nyuv2/nyu_labeled_extracted.tar - filenames: data_split/nyu_depth/labeled/filename_list_train_small_100.txt - eigen_valid_mask: true - - - name: kitti_depth - disp_name: kitti_depth_val_from_train_sub_100 - dir: kitti/kitti_sampled_val_800.tar - filenames: data_split/kitti_depth/eigen_val_from_train_sub_100.txt - kitti_bm_crop: true - valid_mask_crop: eigen diff --git a/config/dataset_depth/dataset_vis.yaml b/config/dataset_depth/dataset_vis.yaml deleted file mode 100644 index 9b1dd23748fd784c5e5c2e3210e6b9a8a837eb0e..0000000000000000000000000000000000000000 --- a/config/dataset_depth/dataset_vis.yaml +++ /dev/null @@ -1,9 +0,0 @@ -dataset: - vis: - - name: hypersim_depth - disp_name: hypersim_depth_vis - dir: hypersim/hypersim_processed_val.tar - filenames: data_split/hypersim_depth/selected_vis_sample.txt - resize_to_hw: - - 480 - - 640 diff --git a/config/dataset_iid/data_appearance_interiorverse_test.yaml b/config/dataset_iid/data_appearance_interiorverse_test.yaml deleted file mode 100644 index 26e1d6cc2b8dba72dc5344dc77e4399bfdecdb4d..0000000000000000000000000000000000000000 --- a/config/dataset_iid/data_appearance_interiorverse_test.yaml +++ /dev/null @@ -1,4 +0,0 @@ -name: interiorverse_iid -disp_name: interiorverse_iid_appearance_test -dir: interiorverse/InteriorVerse.tar -filenames: data_split/interiorverse_iid/interiorverse_test_scenes_85.txt diff --git a/config/dataset_iid/data_appearance_synthetic_test.yaml b/config/dataset_iid/data_appearance_synthetic_test.yaml deleted file mode 100644 index 413910624eb238661ee2b1be14fb4d453ce1194b..0000000000000000000000000000000000000000 --- a/config/dataset_iid/data_appearance_synthetic_test.yaml +++ /dev/null @@ -1,4 +0,0 @@ -name: interiorverse_iid -disp_name: interiorverse_iid_appearance_test -dir: synthetic -filenames: data_split/osu/osu_test_scenes_85.txt diff --git a/config/dataset_iid/data_art_test.yaml b/config/dataset_iid/data_art_test.yaml deleted file mode 100644 index 8c65bdf6b15896f335f3f1234fe8a7403a85b897..0000000000000000000000000000000000000000 --- a/config/dataset_iid/data_art_test.yaml +++ /dev/null @@ -1,4 +0,0 @@ -name: interiorverse_iid -disp_name: interiorverse_iid_appearance_test -dir: art -filenames: data_split/osu/art_test_scenes.txt \ No newline at end of file diff --git a/config/dataset_iid/data_lighting_hypersim_test.yaml b/config/dataset_iid/data_lighting_hypersim_test.yaml deleted file mode 100644 index 8a6a6777ff279527a85cf70dcfc204c2073fd05a..0000000000000000000000000000000000000000 --- a/config/dataset_iid/data_lighting_hypersim_test.yaml +++ /dev/null @@ -1,4 +0,0 @@ -name: hypersim_iid -disp_name: hypersim_iid_lighting_test -dir: hypersim -filenames: data_split/hypersim_iid/hypersim_test.txt diff --git a/config/dataset_iid/dataset_appearance_train.yaml b/config/dataset_iid/dataset_appearance_train.yaml deleted file mode 100644 index 8a77bf692ceab8484c09659456bfcddfaa00ff9f..0000000000000000000000000000000000000000 --- a/config/dataset_iid/dataset_appearance_train.yaml +++ /dev/null @@ -1,9 +0,0 @@ -dataset: - train: - name: mixed - prob_ls: [1.0] - dataset_list: - - name: interiorverse_iid - disp_name: interiorverse_iid_appearance_train - dir: osu_albedo_new - filenames: data_split/osu/osu_train_scenes_85.txt diff --git a/config/dataset_iid/dataset_appearance_val.yaml b/config/dataset_iid/dataset_appearance_val.yaml deleted file mode 100644 index 75558aa513633ac5d4dc72e7532d1907712d8274..0000000000000000000000000000000000000000 --- a/config/dataset_iid/dataset_appearance_val.yaml +++ /dev/null @@ -1,6 +0,0 @@ -dataset: - val: - - name: interiorverse_iid - disp_name: interiorverse_iid_appearance_val - dir: synthetic - filenames: data_split/MatrixCity/matrixcity_val_scenes_small.txt diff --git a/config/dataset_iid/dataset_appearance_vis.yaml b/config/dataset_iid/dataset_appearance_vis.yaml deleted file mode 100644 index 4a7956fc29be25dc6ca7701275c13dbe6c6e3857..0000000000000000000000000000000000000000 --- a/config/dataset_iid/dataset_appearance_vis.yaml +++ /dev/null @@ -1,6 +0,0 @@ -dataset: - vis: - - name: interiorverse_iid - disp_name: interiorverse_iid_appearance_vis - dir: synthetic - filenames: data_split/MatrixCity/matrixcity_vis_scenes.txt diff --git a/config/dataset_iid/dataset_lighting_train.yaml b/config/dataset_iid/dataset_lighting_train.yaml deleted file mode 100644 index 031c62264689d512dea92efc51870d9522a88d4b..0000000000000000000000000000000000000000 --- a/config/dataset_iid/dataset_lighting_train.yaml +++ /dev/null @@ -1,12 +0,0 @@ -dataset: - train: - name: mixed - prob_ls: [1.0] - dataset_list: - - name: hypersim_iid - disp_name: hypersim_iid_lighting_train - dir: hypersim - filenames: data_split/hypersim_iid/hypersim_train_filtered.txt - resize_to_hw: - - 480 - - 640 \ No newline at end of file diff --git a/config/dataset_iid/dataset_lighting_val.yaml b/config/dataset_iid/dataset_lighting_val.yaml deleted file mode 100644 index 58e2da3facbff138bbe639bf88287f541ddbf281..0000000000000000000000000000000000000000 --- a/config/dataset_iid/dataset_lighting_val.yaml +++ /dev/null @@ -1,6 +0,0 @@ -dataset: - val: - - name: hypersim_iid - disp_name: hypersim_iid_lighting_val - dir: hypersim - filenames: data_split/hypersim_iid/hypersim_val.txt diff --git a/config/dataset_iid/dataset_lighting_vis.yaml b/config/dataset_iid/dataset_lighting_vis.yaml deleted file mode 100644 index 70608ae85e27f7c3aa5a2069edda3cbc20b33872..0000000000000000000000000000000000000000 --- a/config/dataset_iid/dataset_lighting_vis.yaml +++ /dev/null @@ -1,6 +0,0 @@ -dataset: - vis: - - name: hypersim_iid - disp_name: hypersim_iid_lighting_vis - dir: hypersim - filenames: data_split/hypersim_iid/hypersim_vis.txt diff --git a/config/dataset_iid/osu_data_appearance_interiorverse_test.yaml b/config/dataset_iid/osu_data_appearance_interiorverse_test.yaml deleted file mode 100644 index 413910624eb238661ee2b1be14fb4d453ce1194b..0000000000000000000000000000000000000000 --- a/config/dataset_iid/osu_data_appearance_interiorverse_test.yaml +++ /dev/null @@ -1,4 +0,0 @@ -name: interiorverse_iid -disp_name: interiorverse_iid_appearance_test -dir: synthetic -filenames: data_split/osu/osu_test_scenes_85.txt diff --git a/config/dataset_normals/data_diode_test.yaml b/config/dataset_normals/data_diode_test.yaml deleted file mode 100644 index b8de2bbf22bd58792e12fc2d126b8402a80bec64..0000000000000000000000000000000000000000 --- a/config/dataset_normals/data_diode_test.yaml +++ /dev/null @@ -1,4 +0,0 @@ -name: diode_normals -disp_name: diode_normals_test -dir: diode/val -filenames: data_split/diode_normals/diode_test.txt diff --git a/config/dataset_normals/data_ibims_test.yaml b/config/dataset_normals/data_ibims_test.yaml deleted file mode 100644 index 41e43ec4f5b31c7d1aaedeca2e90986018121c69..0000000000000000000000000000000000000000 --- a/config/dataset_normals/data_ibims_test.yaml +++ /dev/null @@ -1,4 +0,0 @@ -name: ibims_normals -disp_name: ibims_normals_test -dir: ibims/ibims -filenames: data_split/ibims_normals/ibims_test.txt diff --git a/config/dataset_normals/data_nyu_test.yaml b/config/dataset_normals/data_nyu_test.yaml deleted file mode 100644 index 13263dd0c6a06f89cf69ad94643261477db114ff..0000000000000000000000000000000000000000 --- a/config/dataset_normals/data_nyu_test.yaml +++ /dev/null @@ -1,4 +0,0 @@ -name: nyu_normals -disp_name: nyu_normals_test -dir: nyuv2/test -filenames: data_split/nyu_normals/nyuv2_test.txt diff --git a/config/dataset_normals/data_oasis_test.yaml b/config/dataset_normals/data_oasis_test.yaml deleted file mode 100644 index d00eb6d6ad37c4cf5e573fa8b45f0fae3ae73e98..0000000000000000000000000000000000000000 --- a/config/dataset_normals/data_oasis_test.yaml +++ /dev/null @@ -1,4 +0,0 @@ -name: oasis_normals -disp_name: oasis_normals_test -dir: oasis/val -filenames: data_split/oasis_normals/oasis_test.txt diff --git a/config/dataset_normals/data_scannet_test.yaml b/config/dataset_normals/data_scannet_test.yaml deleted file mode 100644 index 28d6e006795610c0772ff1704af664de63afb6b1..0000000000000000000000000000000000000000 --- a/config/dataset_normals/data_scannet_test.yaml +++ /dev/null @@ -1,4 +0,0 @@ -name: scannet_normals -disp_name: scannet_normals_test -dir: scannet -filenames: data_split/scannet_normals/scannet_test.txt diff --git a/config/dataset_normals/dataset_train.yaml b/config/dataset_normals/dataset_train.yaml deleted file mode 100644 index 053f25228e503f447da502da8e8915f9e22f174e..0000000000000000000000000000000000000000 --- a/config/dataset_normals/dataset_train.yaml +++ /dev/null @@ -1,25 +0,0 @@ -dataset: - train: - name: mixed - prob_ls: [0.5, 0.49, 0.01] - dataset_list: - - name: hypersim_normals - disp_name: hypersim_normals_train - dir: hypersim - filenames: data_split/hypersim_normals/hypersim_filtered_all.txt - resize_to_hw: - - 480 - - 640 - - name: interiorverse_normals - disp_name: interiorverse_normals_train - dir: interiorverse/scenes_85 - filenames: data_split/interiorverse_normals/interiorverse_filtered_all.txt - resize_to_hw: null - - name: sintel_normals - disp_name: sintel_normals_train - dir: sintel - filenames: data_split/sintel_normals/sintel_filtered.txt - resize_to_hw: - - 480 - - 640 - center_crop: true diff --git a/config/dataset_normals/dataset_val.yaml b/config/dataset_normals/dataset_val.yaml deleted file mode 100644 index b4a23d65b489af5e6afb3d6c5b325778bb85e8d7..0000000000000000000000000000000000000000 --- a/config/dataset_normals/dataset_val.yaml +++ /dev/null @@ -1,7 +0,0 @@ -dataset: - val: - - name: hypersim_normals - disp_name: hypersim_normals_val_small_100 - dir: hypersim - filenames: data_split/hypersim_normals/hypersim_filtered_val_100.txt - resize_to_hw: null diff --git a/config/dataset_normals/dataset_vis.yaml b/config/dataset_normals/dataset_vis.yaml deleted file mode 100644 index 6ff805087bcd423b2d8c66d0d96c250394d61cc3..0000000000000000000000000000000000000000 --- a/config/dataset_normals/dataset_vis.yaml +++ /dev/null @@ -1,7 +0,0 @@ -dataset: - vis: - - name: hypersim_normals - disp_name: hypersim_normals_vis - dir: hypersim - filenames: data_split/hypersim_normals/hypersim_filtered_vis_20.txt - resize_to_hw: null diff --git a/config/logging.yaml b/config/logging.yaml deleted file mode 100644 index 8cecbaeca91f2ae38677336b4dc7dcc5ad020f60..0000000000000000000000000000000000000000 --- a/config/logging.yaml +++ /dev/null @@ -1,5 +0,0 @@ -logging: - filename: logging.log - format: ' %(asctime)s - %(levelname)s -%(filename)s - %(funcName)s >> %(message)s' - console_level: 20 - file_level: 10 diff --git a/config/model_sdv2.yaml b/config/model_sdv2.yaml deleted file mode 100644 index 4fb702126ae11eb8b74d57f0d5569bd790269af6..0000000000000000000000000000000000000000 --- a/config/model_sdv2.yaml +++ /dev/null @@ -1,4 +0,0 @@ -model: - name: marigold_pipeline - pretrained_path: stable-diffusion-2 - latent_scale_factor: 0.18215 diff --git a/config/train_debug_depth.yaml b/config/train_debug_depth.yaml deleted file mode 100644 index ba4ddfada72fedd93acf7bc0d4a4f0e04129fabf..0000000000000000000000000000000000000000 --- a/config/train_debug_depth.yaml +++ /dev/null @@ -1,10 +0,0 @@ -base_config: - - config/train_marigold_depth.yaml - -trainer: - save_period: 5 - backup_period: 10 - validation_period: 5 - visualization_period: 5 - -max_iter: 50 diff --git a/config/train_debug_iid.yaml b/config/train_debug_iid.yaml deleted file mode 100644 index b8b95c038eb22c39f3b23a6329126a381e3ba46b..0000000000000000000000000000000000000000 --- a/config/train_debug_iid.yaml +++ /dev/null @@ -1,11 +0,0 @@ -base_config: - # - config/train_marigold_iid_lighting.yaml - - config/train_marigold_iid_appearance.yaml - -trainer: - save_period: 10 - backup_period: 10 - validation_period: 5 - visualization_period: 5 - -max_iter: 50 diff --git a/config/train_debug_normals.yaml b/config/train_debug_normals.yaml deleted file mode 100644 index dcfb4e146c63a3a662fd05e65f06f56b39611a67..0000000000000000000000000000000000000000 --- a/config/train_debug_normals.yaml +++ /dev/null @@ -1,10 +0,0 @@ -base_config: - - config/train_marigold_normals.yaml - -trainer: - save_period: 5 - backup_period: 10 - validation_period: 5 - visualization_period: 5 - -max_iter: 50 diff --git a/config/train_marigold_depth.yaml b/config/train_marigold_depth.yaml deleted file mode 100644 index f35828b7a51c4c39fb4ded843a4e7f35308b2bd2..0000000000000000000000000000000000000000 --- a/config/train_marigold_depth.yaml +++ /dev/null @@ -1,94 +0,0 @@ -base_config: -- config/logging.yaml -- config/wandb.yaml -- config/dataset_depth/dataset_train.yaml -- config/dataset_depth/dataset_val.yaml -- config/dataset_depth/dataset_vis.yaml -- config/model_sdv2.yaml - -pipeline: - name: MarigoldDepthPipeline - kwargs: - scale_invariant: true - shift_invariant: true - default_denoising_steps: 4 - default_processing_resolution: 768 - -depth_normalization: - type: scale_shift_depth - clip: true - norm_min: -1.0 - norm_max: 1.0 - min_max_quantile: 0.02 - -augmentation: - lr_flip_p: 0.5 - -dataloader: - num_workers: 2 - effective_batch_size: 32 - max_train_batch_size: 2 - seed: 2024 # to ensure continuity when resuming from checkpoint - -trainer: - name: MarigoldDepthTrainer - training_noise_scheduler: - pretrained_path: stable-diffusion-2 - init_seed: 2024 # use null to train w/o seeding - save_period: 50 - backup_period: 2000 - validation_period: 500 - visualization_period: 1000 - -multi_res_noise: - strength: 0.9 - annealed: true - downscale_strategy: original - -gt_depth_type: depth_raw_norm -gt_mask_type: valid_mask_raw - -max_epoch: 10000 # a large enough number -max_iter: 30000 # usually converges at around 20k - -optimizer: - name: Adam - -loss: - name: mse_loss - kwargs: - reduction: mean - -lr: 3.0e-05 -lr_scheduler: - name: IterExponential - kwargs: - total_iter: 25000 - final_ratio: 0.01 - warmup_steps: 100 - -# Light setting for the in-training validation and visualization -validation: - denoising_steps: 1 - ensemble_size: 1 - processing_res: 0 - match_input_res: false - resample_method: bilinear - main_val_metric: abs_relative_difference - main_val_metric_goal: minimize - init_seed: 2024 - -eval: - alignment: least_square - align_max_res: null - eval_metrics: - - abs_relative_difference - - squared_relative_difference - - rmse_linear - - rmse_log - - log10 - - delta1_acc - - delta2_acc - - delta3_acc - - i_rmse - - silog_rmse diff --git a/config/train_marigold_iid_appearance.yaml b/config/train_marigold_iid_appearance.yaml deleted file mode 100644 index e2492257c3e5e3559c1952bd606ac7914e495775..0000000000000000000000000000000000000000 --- a/config/train_marigold_iid_appearance.yaml +++ /dev/null @@ -1,81 +0,0 @@ -base_config: -- config/logging.yaml -- config/wandb.yaml -- config/dataset_iid/dataset_appearance_train.yaml -- config/dataset_iid/dataset_appearance_val.yaml -- config/dataset_iid/dataset_appearance_vis.yaml -- config/model_sdv2.yaml - -pipeline: - name: MarigoldIIDPipeline - kwargs: - default_denoising_steps: 4 - default_processing_resolution: 768 - target_properties: - target_names: - - albedo - albedo: - prediction_space: srgb - -augmentation: - lr_flip_p: 0.5 - -dataloader: - num_workers: 2 - effective_batch_size: 32 - max_train_batch_size: 8 - seed: 2024 # to ensure continuity when resuming from checkpoint - -trainer: - name: MarigoldIIDTrainer - training_noise_scheduler: - pretrained_path: stable-diffusion-2 - init_seed: 2024 # use null to train w/o seeding - save_period: 50 - backup_period: 2000 - validation_period: 500 - visualization_period: 1000 - -multi_res_noise: - strength: 0.9 - annealed: true - downscale_strategy: original - -gt_mask_type: mask - -max_epoch: 10000 # a large enough number -max_iter: 10000 # usually converges at around 40k - -optimizer: - name: Adam - -loss: - name: mse_loss - kwargs: - reduction: mean - -lr: 2.0e-05 -lr_scheduler: - name: IterExponential - kwargs: - total_iter: 5000 - final_ratio: 0.01 - warmup_steps: 100 - -# Light setting for the in-training validation and visualization -validation: - denoising_steps: 4 - ensemble_size: 1 - processing_res: 0 - match_input_res: true - resample_method: bilinear - main_val_metric: psnr - main_val_metric_goal: maximize - init_seed: 2024 - use_mask: false - -eval: - eval_metrics: - - psnr - targets_to_eval_in_linear_space: - - material diff --git a/config/train_marigold_iid_appearance_finetuned.yaml b/config/train_marigold_iid_appearance_finetuned.yaml deleted file mode 100644 index 1942b4ed03898a07282f03683edb5866dd6dfb83..0000000000000000000000000000000000000000 --- a/config/train_marigold_iid_appearance_finetuned.yaml +++ /dev/null @@ -1,81 +0,0 @@ -base_config: -- config/logging.yaml -- config/wandb.yaml -- config/dataset_iid/dataset_appearance_train.yaml -- config/dataset_iid/dataset_appearance_val.yaml -- config/dataset_iid/dataset_appearance_vis.yaml -- config/model_sdv2.yaml - -pipeline: - name: MarigoldIIDPipeline - kwargs: - default_denoising_steps: 4 - default_processing_resolution: 768 - target_properties: - target_names: - - albedo - albedo: - prediction_space: srgb - -augmentation: - lr_flip_p: 0.5 - -dataloader: - num_workers: 2 - effective_batch_size: 32 - max_train_batch_size: 8 - seed: 2024 # to ensure continuity when resuming from checkpoint - -trainer: - name: MarigoldIIDTrainer - training_noise_scheduler: - pretrained_path: stable-diffusion-2 - init_seed: 2024 # use null to train w/o seeding - save_period: 50 - backup_period: 2000 - validation_period: 177 - visualization_period: 177 - -multi_res_noise: - strength: 0.9 - annealed: true - downscale_strategy: original - -gt_mask_type: null - -max_epoch: 10000 # a large enough number -max_iter: 5000 # usually converges at around 40k - -optimizer: - name: Adam - -loss: - name: mse_loss - kwargs: - reduction: mean - -lr: 5.0e-07 -lr_scheduler: - name: IterExponential - kwargs: - total_iter: 2500 - final_ratio: 0.01 - warmup_steps: 100 - -# Light setting for the in-training validation and visualization -validation: - denoising_steps: 4 - ensemble_size: 1 - processing_res: 1000 - match_input_res: true - resample_method: bilinear - main_val_metric: psnr - main_val_metric_goal: maximize - init_seed: 2024 - use_mask: false - -eval: - eval_metrics: - - psnr - targets_to_eval_in_linear_space: - - material diff --git a/config/train_marigold_iid_lighting.yaml b/config/train_marigold_iid_lighting.yaml deleted file mode 100644 index 1194dc773ad62387e688efdca6a5f7ffb28a4cdd..0000000000000000000000000000000000000000 --- a/config/train_marigold_iid_lighting.yaml +++ /dev/null @@ -1,82 +0,0 @@ -base_config: -- config/logging.yaml -- config/wandb.yaml -- config/dataset_iid/dataset_lighting_train.yaml -- config/dataset_iid/dataset_lighting_val.yaml -- config/dataset_iid/dataset_lighting_vis.yaml -- config/model_sdv2.yaml - -pipeline: - name: MarigoldIIDPipeline - kwargs: - default_denoising_steps: 4 - default_processing_resolution: 768 - target_properties: - target_names: - - albedo - albedo: - prediction_space: linear - up_to_scale: false - -augmentation: - lr_flip_p: 0.5 - -dataloader: - num_workers: 2 - effective_batch_size: 32 - max_train_batch_size: 8 - seed: 2024 # to ensure continuity when resuming from checkpoint - -trainer: - name: MarigoldIIDTrainer - training_noise_scheduler: - pretrained_path: stable-diffusion-2 - init_seed: 2024 # use null to train w/o seeding - save_period: 50 - backup_period: 2000 - validation_period: 500 - visualization_period: 1000 - -multi_res_noise: - strength: 0.9 - annealed: true - downscale_strategy: original - -gt_mask_type: mask - -max_epoch: 10000 # a large enough number -max_iter: 50000 # usually converges at around 34k - -optimizer: - name: Adam - -loss: - name: mse_loss - kwargs: - reduction: mean - -lr: 8e-05 -lr_scheduler: - name: IterExponential - kwargs: - total_iter: 45000 - final_ratio: 0.01 - warmup_steps: 100 - -# Light setting for the in-training validation and visualization -validation: - denoising_steps: 4 - ensemble_size: 1 - processing_res: 0 - match_input_res: true - resample_method: bilinear - main_val_metric: psnr - main_val_metric_goal: maximize - init_seed: 2024 - use_mask: false - -eval: - eval_metrics: - - psnr - targets_to_eval_in_linear_space: - - None diff --git a/config/train_marigold_normals.yaml b/config/train_marigold_normals.yaml deleted file mode 100644 index 0a3a321e4a872a4334c4eedd4356de88a4d5619b..0000000000000000000000000000000000000000 --- a/config/train_marigold_normals.yaml +++ /dev/null @@ -1,86 +0,0 @@ -base_config: -- config/logging.yaml -- config/wandb.yaml -- config/dataset_normals/dataset_train.yaml -- config/dataset_normals/dataset_val.yaml -- config/dataset_normals/dataset_vis.yaml -- config/model_sdv2.yaml - -pipeline: - name: MarigoldNormalsPipeline - kwargs: - default_denoising_steps: 4 - default_processing_resolution: 768 - -augmentation: - lr_flip_p: 0.5 - color_jitter_p: 0.3 - gaussian_blur_p: 0.3 - motion_blur_p: 0.3 - gaussian_blur_sigma: 4 - motion_blur_kernel_size: 11 - motion_blur_angle_range: 360 - jitter_brightness_factor: 0.5 - jitter_contrast_factor: 0.5 - jitter_saturation_factor: 0.5 - jitter_hue_factor: 0.2 - -dataloader: - num_workers: 2 - effective_batch_size: 32 - max_train_batch_size: 2 - seed: 2024 # to ensure continuity when resuming from checkpoint - -trainer: - name: MarigoldNormalsTrainer - training_noise_scheduler: - pretrained_path: stable-diffusion-2 - init_seed: 2024 # use null to train w/o seeding - save_period: 50 - backup_period: 2000 - validation_period: 500 - visualization_period: 1000 - -multi_res_noise: - strength: 0.9 - annealed: true - downscale_strategy: original - -gt_normals_type: normals -gt_mask_type: null - -max_epoch: 10000 # a large enough number -max_iter: 30000 # usually converges at around 26k - -optimizer: - name: Adam - -loss: - name: mse_loss - kwargs: - reduction: mean - -lr: 6.0e-05 -lr_scheduler: - name: IterExponential - kwargs: - total_iter: 25000 - final_ratio: 0.01 - warmup_steps: 100 - -# Light setting for the in-training validation and visualization -validation: - denoising_steps: 4 - ensemble_size: 1 - processing_res: 768 - match_input_res: true - resample_method: bilinear - main_val_metric: mean_angular_error - main_val_metric_goal: minimize - init_seed: 0 - -eval: - align_max_res: null - eval_metrics: - - mean_angular_error - - sub11_25_error diff --git a/config/wandb.yaml b/config/wandb.yaml deleted file mode 100644 index 1840d26b8cd85da1d922117305dbaa9faaccaee2..0000000000000000000000000000000000000000 --- a/config/wandb.yaml +++ /dev/null @@ -1,3 +0,0 @@ -wandb: - # entity: your_entity - project: marigold diff --git a/marigold/__init__.py b/marigold/__init__.py deleted file mode 100644 index 47d2556ccacef86f30aac19402e690c11b5f629f..0000000000000000000000000000000000000000 --- a/marigold/__init__.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -from .marigold_depth_pipeline import ( - MarigoldDepthPipeline, - MarigoldDepthOutput, # noqa: F401 -) -from .marigold_iid_pipeline import MarigoldIIDPipeline, MarigoldIIDOutput # noqa: F401 -from .marigold_normals_pipeline import ( - MarigoldNormalsPipeline, # noqa: F401 - MarigoldNormalsOutput, # noqa: F401 -) - -MarigoldPipeline = MarigoldDepthPipeline # for backward compatibility diff --git a/marigold/marigold_depth_pipeline.py b/marigold/marigold_depth_pipeline.py deleted file mode 100644 index 07d02e2bba3ce8b4166541d65dd63d22d21ec8b7..0000000000000000000000000000000000000000 --- a/marigold/marigold_depth_pipeline.py +++ /dev/null @@ -1,516 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import logging -import numpy as np -import torch -from PIL import Image -from diffusers import ( - AutoencoderKL, - DDIMScheduler, - DiffusionPipeline, - LCMScheduler, - UNet2DConditionModel, -) -from diffusers.utils import BaseOutput -from torch.utils.data import DataLoader, TensorDataset -from torchvision.transforms import InterpolationMode -from torchvision.transforms.functional import pil_to_tensor, resize -from tqdm.auto import tqdm -from transformers import CLIPTextModel, CLIPTokenizer -from typing import Dict, Optional, Union - -from .util.batchsize import find_batch_size -from .util.ensemble import ensemble_depth -from .util.image_util import ( - chw2hwc, - colorize_depth_maps, - get_tv_resample_method, - resize_max_res, -) - - -class MarigoldDepthOutput(BaseOutput): - """ - Output class for Marigold Monocular Depth Estimation pipeline. - - Args: - depth_np (`np.ndarray`): - Predicted depth map, with depth values in the range of [0, 1]. - depth_colored (`PIL.Image.Image`): - Colorized depth map, with the shape of [H, W, 3] and values in [0, 255]. - uncertainty (`None` or `np.ndarray`): - Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling. - """ - - depth_np: np.ndarray - depth_colored: Union[None, Image.Image] - uncertainty: Union[None, np.ndarray] - - -class MarigoldDepthPipeline(DiffusionPipeline): - """ - Pipeline for Marigold Monocular Depth Estimation: https://marigoldcomputervision.github.io. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - unet (`UNet2DConditionModel`): - Conditional U-Net to denoise the prediction latent, conditioned on image latent. - vae (`AutoencoderKL`): - Variational Auto-Encoder (VAE) Model to encode and decode images and predictions - to and from latent representations. - scheduler (`DDIMScheduler`): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. - text_encoder (`CLIPTextModel`): - Text-encoder, for empty text embedding. - tokenizer (`CLIPTokenizer`): - CLIP tokenizer. - scale_invariant (`bool`, *optional*): - A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in - the model config. When used together with the `shift_invariant=True` flag, the model is also called - "affine-invariant". NB: overriding this value is not supported. - shift_invariant (`bool`, *optional*): - A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in - the model config. When used together with the `scale_invariant=True` flag, the model is also called - "affine-invariant". NB: overriding this value is not supported. - default_denoising_steps (`int`, *optional*): - The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable - quality with the given model. This value must be set in the model config. When the pipeline is called - without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure - reasonable results with various model flavors compatible with the pipeline, such as those relying on very - short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`). - default_processing_resolution (`int`, *optional*): - The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in - the model config. When the pipeline is called without explicitly setting `processing_resolution`, the - default value is used. This is required to ensure reasonable results with various model flavors trained - with varying optimal processing resolution values. - """ - - latent_scale_factor = 0.18215 - - def __init__( - self, - unet: UNet2DConditionModel, - vae: AutoencoderKL, - scheduler: Union[DDIMScheduler, LCMScheduler], - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - scale_invariant: Optional[bool] = True, - shift_invariant: Optional[bool] = True, - default_denoising_steps: Optional[int] = None, - default_processing_resolution: Optional[int] = None, - ): - super().__init__() - self.register_modules( - unet=unet, - vae=vae, - scheduler=scheduler, - text_encoder=text_encoder, - tokenizer=tokenizer, - ) - self.register_to_config( - scale_invariant=scale_invariant, - shift_invariant=shift_invariant, - default_denoising_steps=default_denoising_steps, - default_processing_resolution=default_processing_resolution, - ) - - self.scale_invariant = scale_invariant - self.shift_invariant = shift_invariant - self.default_denoising_steps = default_denoising_steps - self.default_processing_resolution = default_processing_resolution - - self.empty_text_embed = None - - @torch.no_grad() - def __call__( - self, - input_image: Union[Image.Image, torch.Tensor], - denoising_steps: Optional[int] = None, - ensemble_size: int = 1, - processing_res: Optional[int] = None, - match_input_res: bool = True, - resample_method: str = "bilinear", - batch_size: int = 0, - generator: Union[torch.Generator, None] = None, - color_map: str = "Spectral", - show_progress_bar: bool = True, - ensemble_kwargs: Dict = None, - ) -> MarigoldDepthOutput: - """ - Function invoked when calling the pipeline. - - Args: - input_image (`Image`): - Input RGB (or gray-scale) image. - denoising_steps (`int`, *optional*, defaults to `None`): - Number of denoising diffusion steps during inference. The default value `None` results in automatic - selection. - ensemble_size (`int`, *optional*, defaults to `1`): - Number of predictions to be ensembled. - processing_res (`int`, *optional*, defaults to `None`): - Effective processing resolution. When set to `0`, processes at the original image resolution. This - produces crisper predictions, but may also lead to the overall loss of global context. The default - value `None` resolves to the optimal value from the model config. - match_input_res (`bool`, *optional*, defaults to `True`): - Resize the prediction to match the input resolution. - Only valid if `processing_res` > 0. - resample_method: (`str`, *optional*, defaults to `bilinear`): - Resampling method used to resize images and predictions. This can be one of `bilinear`, `bicubic` or - `nearest`, defaults to: `bilinear`. - batch_size (`int`, *optional*, defaults to `0`): - Inference batch size, no bigger than `num_ensemble`. - If set to 0, the script will automatically decide the proper batch size. - generator (`torch.Generator`, *optional*, defaults to `None`) - Random generator for initial noise generation. - show_progress_bar (`bool`, *optional*, defaults to `True`): - Display a progress bar of diffusion denoising. - color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation): - Colormap used to colorize the depth map. - scale_invariant (`str`, *optional*, defaults to `True`): - Flag of scale-invariant prediction, if True, scale will be adjusted from the raw prediction. - shift_invariant (`str`, *optional*, defaults to `True`): - Flag of shift-invariant prediction, if True, shift will be adjusted from the raw prediction, if False, - near plane will be fixed at 0m. - ensemble_kwargs (`dict`, *optional*, defaults to `None`): - Arguments for detailed ensembling settings. - Returns: - `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including: - - **depth_np** (`np.ndarray`) Predicted depth map with depth values in the range of [0, 1] - - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [H, W, 3] and values in [0, 255], None if `color_map` is `None` - - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation) - coming from ensembling. None if `ensemble_size = 1` - """ - # Model-specific optimal default values leading to fast and reasonable results. - if denoising_steps is None: - denoising_steps = self.default_denoising_steps - if processing_res is None: - processing_res = self.default_processing_resolution - - assert processing_res >= 0 - assert ensemble_size >= 1 - - # Check if denoising step is reasonable - self._check_inference_step(denoising_steps) - - resample_method: InterpolationMode = get_tv_resample_method(resample_method) - - # ----------------- Image Preprocess ----------------- - # Convert to torch tensor - if isinstance(input_image, Image.Image): - input_image = input_image.convert("RGB") - # convert to torch tensor [H, W, rgb] -> [rgb, H, W] - rgb = pil_to_tensor(input_image) - rgb = rgb.unsqueeze(0) # [1, rgb, H, W] - elif isinstance(input_image, torch.Tensor): - rgb = input_image - else: - raise TypeError(f"Unknown input type: {type(input_image) = }") - input_size = rgb.shape - assert ( - 4 == rgb.dim() and 3 == input_size[-3] - ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]" - - # Resize image - if processing_res > 0: - rgb = resize_max_res( - rgb, - max_edge_resolution=processing_res, - resample_method=resample_method, - ) - - # Normalize rgb values - rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1] - rgb_norm = rgb_norm.to(self.dtype) - assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0 - - # ----------------- Predicting depth ----------------- - # Batch repeated input image - duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1) - single_rgb_dataset = TensorDataset(duplicated_rgb) - if batch_size > 0: - _bs = batch_size - else: - _bs = find_batch_size( - ensemble_size=ensemble_size, - input_res=max(rgb_norm.shape[1:]), - dtype=self.dtype, - ) - - single_rgb_loader = DataLoader( - single_rgb_dataset, batch_size=_bs, shuffle=False - ) - - # Predict depth maps (batched) - target_pred_ls = [] - if show_progress_bar: - iterable = tqdm( - single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False - ) - else: - iterable = single_rgb_loader - for batch in iterable: - (batched_img,) = batch - target_pred_raw = self.single_infer( - rgb_in=batched_img, - num_inference_steps=denoising_steps, - show_pbar=show_progress_bar, - generator=generator, - ) - target_pred_ls.append(target_pred_raw.detach()) - target_preds = torch.concat(target_pred_ls, dim=0) - torch.cuda.empty_cache() # clear vram cache for ensembling - - # ----------------- Test-time ensembling ----------------- - if ensemble_size > 1: - final_pred, pred_uncert = ensemble_depth( - target_preds, - scale_invariant=self.scale_invariant, - shift_invariant=self.shift_invariant, - **(ensemble_kwargs or {}), - ) - else: - final_pred = target_preds - pred_uncert = None - - # Resize back to original resolution - if match_input_res: - final_pred = resize( - final_pred, - input_size[-2:], - interpolation=resample_method, - antialias=True, - ) - - # Convert to numpy - final_pred = final_pred.squeeze() - final_pred = final_pred.cpu().numpy() - if pred_uncert is not None: - pred_uncert = pred_uncert.squeeze().cpu().numpy() - - # Clip output range - final_pred = final_pred.clip(0, 1) - - # Colorize - if color_map is not None: - depth_colored = colorize_depth_maps( - final_pred, 0, 1, cmap=color_map - ).squeeze() # [3, H, W], value in (0, 1) - depth_colored = (depth_colored * 255).astype(np.uint8) - depth_colored_hwc = chw2hwc(depth_colored) - depth_colored_img = Image.fromarray(depth_colored_hwc) - else: - depth_colored_img = None - - return MarigoldDepthOutput( - depth_np=final_pred, - depth_colored=depth_colored_img, - uncertainty=pred_uncert, - ) - - def _check_inference_step(self, n_step: int) -> None: - """ - Check if denoising step is reasonable - Args: - n_step (`int`): denoising steps - """ - assert n_step >= 1 - - if isinstance(self.scheduler, DDIMScheduler): - if "trailing" != self.scheduler.config.timestep_spacing: - logging.warning( - f"The loaded `DDIMScheduler` is configured with `timestep_spacing=" - f'"{self.scheduler.config.timestep_spacing}"`; the recommended setting is `"trailing"`. ' - f"This change is backward-compatible and yields better results. " - f"Consider using `prs-eth/marigold-depth-v1-1` for the best experience." - ) - else: - if n_step > 10: - logging.warning( - f"Setting too many denoising steps ({n_step}) may degrade the prediction; consider relying on " - f"the default values." - ) - if not self.scheduler.config.rescale_betas_zero_snr: - logging.warning( - f"The loaded `DDIMScheduler` is configured with `rescale_betas_zero_snr=" - f"{self.scheduler.config.rescale_betas_zero_snr}`; the recommended setting is True. " - f"Consider using `prs-eth/marigold-depth-v1-1` for the best experience." - ) - elif isinstance(self.scheduler, LCMScheduler): - logging.warning( - "DeprecationWarning: LCMScheduler will not be supported in the future. " - "Consider using `prs-eth/marigold-depth-v1-1` for the best experience." - ) - if n_step > 10: - logging.warning( - f"Setting too many denoising steps ({n_step}) may degrade the prediction; consider relying on " - f"the default values." - ) - else: - raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}") - - def encode_empty_text(self): - """ - Encode text embedding for empty prompt - """ - prompt = "" - text_inputs = self.tokenizer( - prompt, - padding="do_not_pad", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids.to(self.text_encoder.device) - self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype) - - @torch.no_grad() - def single_infer( - self, - rgb_in: torch.Tensor, - num_inference_steps: int, - generator: Union[torch.Generator, None], - show_pbar: bool, - ) -> torch.Tensor: - """ - Perform a single prediction without ensembling. - - Args: - rgb_in (`torch.Tensor`): - Input RGB image. - num_inference_steps (`int`): - Number of diffusion denoisign steps (DDIM) during inference. - show_pbar (`bool`): - Display a progress bar of diffusion denoising. - generator (`torch.Generator`) - Random generator for initial noise generation. - Returns: - `torch.Tensor`: Predicted targets. - """ - device = self.device - rgb_in = rgb_in.to(device) - - # Set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps # [T] - - # Encode image - rgb_latent = self.encode_rgb(rgb_in) # [B, 4, h, w] - - # Noisy latent for outputs - target_latent = torch.randn( - rgb_latent.shape, - device=device, - dtype=self.dtype, - generator=generator, - ) # [B, 4, h, w] - - # Batched empty text embedding - if self.empty_text_embed is None: - self.encode_empty_text() - batch_empty_text_embed = self.empty_text_embed.repeat( - (rgb_latent.shape[0], 1, 1) - ).to(device) # [B, 2, 1024] - - # Denoising loop - if show_pbar: - iterable = tqdm( - enumerate(timesteps), - total=len(timesteps), - leave=False, - desc=" " * 4 + "Diffusion denoising", - ) - else: - iterable = enumerate(timesteps) - - for i, t in iterable: - unet_input = torch.cat( - [rgb_latent, target_latent], dim=1 - ) # this order is important - - # predict the noise residual - noise_pred = self.unet( - unet_input, t, encoder_hidden_states=batch_empty_text_embed - ).sample # [B, 4, h, w] - - # compute the previous noisy sample x_t -> x_t-1 - target_latent = self.scheduler.step( - noise_pred, t, target_latent, generator=generator - ).prev_sample - - depth = self.decode_depth(target_latent) # [B,3,H,W] - - # clip prediction - depth = torch.clip(depth, -1.0, 1.0) - # shift to [0, 1] - depth = (depth + 1.0) / 2.0 - - return depth - - def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor: - """ - Encode RGB image into latent. - - Args: - rgb_in (`torch.Tensor`): - Input RGB image to be encoded. - - Returns: - `torch.Tensor`: Image latent. - """ - # encode - h = self.vae.encoder(rgb_in) - moments = self.vae.quant_conv(h) - mean, logvar = torch.chunk(moments, 2, dim=1) - # scale latent - rgb_latent = mean * self.latent_scale_factor - return rgb_latent - - def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor: - """ - Decode depth latent into depth map. - - Args: - depth_latent (`torch.Tensor`): - Depth latent to be decoded. - - Returns: - `torch.Tensor`: Decoded depth map. - """ - # scale latent - depth_latent = depth_latent / self.latent_scale_factor - # decode - z = self.vae.post_quant_conv(depth_latent) - stacked = self.vae.decoder(z) - # mean of output channels - depth_mean = stacked.mean(dim=1, keepdim=True) - return depth_mean diff --git a/marigold/marigold_normals_pipeline.py b/marigold/marigold_normals_pipeline.py deleted file mode 100644 index 9f1a06a1fc5926f35683edfd102246078f9b0d29..0000000000000000000000000000000000000000 --- a/marigold/marigold_normals_pipeline.py +++ /dev/null @@ -1,479 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import logging -import numpy as np -import torch -from PIL import Image -from diffusers import ( - AutoencoderKL, - DDIMScheduler, - DiffusionPipeline, - LCMScheduler, - UNet2DConditionModel, -) -from diffusers.utils import BaseOutput -from torch.utils.data import DataLoader, TensorDataset -from torchvision.transforms import InterpolationMode -from torchvision.transforms.functional import pil_to_tensor, resize -from tqdm.auto import tqdm -from transformers import CLIPTextModel, CLIPTokenizer -from typing import Dict, Optional, Union - -from .util.batchsize import find_batch_size -from .util.ensemble import ensemble_normals -from .util.image_util import ( - chw2hwc, - get_tv_resample_method, - resize_max_res, -) - - -class MarigoldNormalsOutput(BaseOutput): - """ - Output class for Marigold Surface Normals Estimation pipeline. - - Args: - normals_np (`np.ndarray`): - Predicted normals map of shape [3, H, W] with values in the range of [-1, 1] (unit length vectors). - normals_img (`PIL.Image.Image`): - Normals image, with the shape of [H, W, 3] and values in [0, 255]. - uncertainty (`None` or `np.ndarray`): - Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling. - """ - - normals_np: np.ndarray - normals_img: Image.Image - uncertainty: Union[None, np.ndarray] - - -class MarigoldNormalsPipeline(DiffusionPipeline): - """ - Pipeline for Marigold Surface Normals Estimation: https://marigoldcomputervision.github.io. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - unet (`UNet2DConditionModel`): - Conditional U-Net to denoise the prediction latent, conditioned on image latent. - vae (`AutoencoderKL`): - Variational Auto-Encoder (VAE) Model to encode and decode images and predictions - to and from latent representations. - scheduler (`DDIMScheduler`): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. - text_encoder (`CLIPTextModel`): - Text-encoder, for empty text embedding. - tokenizer (`CLIPTokenizer`): - CLIP tokenizer. - default_denoising_steps (`int`, *optional*): - The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable - quality with the given model. This value must be set in the model config. When the pipeline is called - without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure - reasonable results with various model flavors compatible with the pipeline, such as those relying on very - short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`). - default_processing_resolution (`int`, *optional*): - The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in - the model config. When the pipeline is called without explicitly setting `processing_resolution`, the - default value is used. This is required to ensure reasonable results with various model flavors trained - with varying optimal processing resolution values. - """ - - latent_scale_factor = 0.18215 - - def __init__( - self, - unet: UNet2DConditionModel, - vae: AutoencoderKL, - scheduler: Union[DDIMScheduler, LCMScheduler], - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - default_denoising_steps: Optional[int] = None, - default_processing_resolution: Optional[int] = None, - ): - super().__init__() - self.register_modules( - unet=unet, - vae=vae, - scheduler=scheduler, - text_encoder=text_encoder, - tokenizer=tokenizer, - ) - self.register_to_config( - default_denoising_steps=default_denoising_steps, - default_processing_resolution=default_processing_resolution, - ) - - self.default_denoising_steps = default_denoising_steps - self.default_processing_resolution = default_processing_resolution - - self.empty_text_embed = None - - @torch.no_grad() - def __call__( - self, - input_image: Union[Image.Image, torch.Tensor], - denoising_steps: Optional[int] = None, - ensemble_size: int = 1, - processing_res: Optional[int] = None, - match_input_res: bool = True, - resample_method: str = "bilinear", - batch_size: int = 0, - generator: Union[torch.Generator, None] = None, - show_progress_bar: bool = True, - ensemble_kwargs: Dict = None, - ) -> MarigoldNormalsOutput: - """ - Function invoked when calling the pipeline. - - Args: - input_image (`Image`): - Input RGB (or gray-scale) image. - denoising_steps (`int`, *optional*, defaults to `None`): - Number of denoising diffusion steps during inference. The default value `None` results in automatic - selection. - ensemble_size (`int`, *optional*, defaults to `1`): - Number of predictions to be ensembled. - processing_res (`int`, *optional*, defaults to `None`): - Effective processing resolution. When set to `0`, processes at the original image resolution. This - produces crisper predictions, but may also lead to the overall loss of global context. The default - value `None` resolves to the optimal value from the model config. - match_input_res (`bool`, *optional*, defaults to `True`): - Resize the prediction to match the input resolution. - Only valid if `processing_res` > 0. - resample_method: (`str`, *optional*, defaults to `bilinear`): - Resampling method used to resize images and predictions. This can be one of `bilinear`, `bicubic` or - `nearest`, defaults to: `bilinear`. - batch_size (`int`, *optional*, defaults to `0`): - Inference batch size, no bigger than `num_ensemble`. - If set to 0, the script will automatically decide the proper batch size. - generator (`torch.Generator`, *optional*, defaults to `None`) - Random generator for initial noise generation. - show_progress_bar (`bool`, *optional*, defaults to `True`): - Display a progress bar of diffusion denoising. - ensemble_kwargs (`dict`, *optional*, defaults to `None`): - Arguments for detailed ensembling settings. - Returns: - `MarigoldNormalsOutput`: Output class for Marigold monocular surface normals estimation pipeline, including: - - **normals_np** (`np.ndarray`) Predicted normals map of shape [3, H, W] with values in the range of [-1, 1] - (unit length vectors) - - **normals_img** (`PIL.Image.Image`) Normals image, with the shape of [H, W, 3] and values in [0, 255] - - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation) - coming from ensembling. None if `ensemble_size = 1` - """ - # Model-specific optimal default values leading to fast and reasonable results. - if denoising_steps is None: - denoising_steps = self.default_denoising_steps - if processing_res is None: - processing_res = self.default_processing_resolution - - assert processing_res >= 0 - assert ensemble_size >= 1 - - # Check if denoising step is reasonable - self._check_inference_step(denoising_steps) - - resample_method: InterpolationMode = get_tv_resample_method(resample_method) - - # ----------------- Image Preprocess ----------------- - # Convert to torch tensor - if isinstance(input_image, Image.Image): - input_image = input_image.convert("RGB") - # convert to torch tensor [H, W, rgb] -> [rgb, H, W] - rgb = pil_to_tensor(input_image) - rgb = rgb.unsqueeze(0) # [1, rgb, H, W] - elif isinstance(input_image, torch.Tensor): - rgb = input_image - else: - raise TypeError(f"Unknown input type: {type(input_image) = }") - input_size = rgb.shape - assert ( - 4 == rgb.dim() and 3 == input_size[-3] - ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]" - - # Resize image - if processing_res > 0: - rgb = resize_max_res( - rgb, - max_edge_resolution=processing_res, - resample_method=resample_method, - ) - - # Normalize rgb values - rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1] - rgb_norm = rgb_norm.to(self.dtype) - assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0 - - # ----------------- Predicting normals ----------------- - # Batch repeated input image - duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1) - single_rgb_dataset = TensorDataset(duplicated_rgb) - if batch_size > 0: - _bs = batch_size - else: - _bs = find_batch_size( - ensemble_size=ensemble_size, - input_res=max(rgb_norm.shape[1:]), - dtype=self.dtype, - ) - - single_rgb_loader = DataLoader( - single_rgb_dataset, batch_size=_bs, shuffle=False - ) - - # Predict normals maps (batched) - target_pred_ls = [] - if show_progress_bar: - iterable = tqdm( - single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False - ) - else: - iterable = single_rgb_loader - for batch in iterable: - (batched_img,) = batch - target_pred_raw = self.single_infer( - rgb_in=batched_img, - num_inference_steps=denoising_steps, - show_pbar=show_progress_bar, - generator=generator, - ) - target_pred_ls.append(target_pred_raw.detach()) - target_preds = torch.concat(target_pred_ls, dim=0) - torch.cuda.empty_cache() # clear vram cache for ensembling - - # ----------------- Test-time ensembling ----------------- - if ensemble_size > 1: - final_pred, pred_uncert = ensemble_normals( - target_preds, - **(ensemble_kwargs or {}), - ) - else: - final_pred = target_preds - pred_uncert = None - - # Resize back to original resolution - if match_input_res: - final_pred = resize( - final_pred, - input_size[-2:], - interpolation=resample_method, - antialias=True, - ) - - # Convert to numpy - final_pred = final_pred.squeeze() - final_pred = final_pred.cpu().numpy() - if pred_uncert is not None: - pred_uncert = pred_uncert.squeeze().cpu().numpy() - - # Clip output range - final_pred = final_pred.clip(-1, 1) - - # Colorize - normals_img = ((final_pred + 1) * 127.5).astype(np.uint8) - normals_img = chw2hwc(normals_img) - normals_img = Image.fromarray(normals_img) - - return MarigoldNormalsOutput( - normals_np=final_pred, - normals_img=normals_img, - uncertainty=pred_uncert, - ) - - def _check_inference_step(self, n_step: int) -> None: - """ - Check if denoising step is reasonable - Args: - n_step (`int`): denoising steps - """ - assert n_step >= 1 - - if isinstance(self.scheduler, DDIMScheduler): - if "trailing" != self.scheduler.config.timestep_spacing: - logging.warning( - f"The loaded `DDIMScheduler` is configured with `timestep_spacing=" - f'"{self.scheduler.config.timestep_spacing}"`; the recommended setting is `"trailing"`. ' - f"This change is backward-compatible and yields better results. " - f"Consider using `prs-eth/marigold-normals-v1-1` for the best experience." - ) - else: - if n_step > 10: - logging.warning( - f"Setting too many denoising steps ({n_step}) may degrade the prediction; consider relying on " - f"the default values." - ) - if not self.scheduler.config.rescale_betas_zero_snr: - logging.warning( - f"The loaded `DDIMScheduler` is configured with `rescale_betas_zero_snr=" - f"{self.scheduler.config.rescale_betas_zero_snr}`; the recommended setting is True. " - f"Consider using `prs-eth/marigold-normals-v1-1` for the best experience." - ) - elif isinstance(self.scheduler, LCMScheduler): - raise RuntimeError( - "This pipeline implementation does not support the LCMScheduler. Please refer to the project " - "README.md for instructions about using LCM." - ) - else: - raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}") - - def encode_empty_text(self): - """ - Encode text embedding for empty prompt - """ - prompt = "" - text_inputs = self.tokenizer( - prompt, - padding="do_not_pad", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids.to(self.text_encoder.device) - self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype) - - @torch.no_grad() - def single_infer( - self, - rgb_in: torch.Tensor, - num_inference_steps: int, - generator: Union[torch.Generator, None], - show_pbar: bool, - ) -> torch.Tensor: - """ - Perform a single prediction without ensembling. - - Args: - rgb_in (`torch.Tensor`): - Input RGB image. - num_inference_steps (`int`): - Number of diffusion denoisign steps (DDIM) during inference. - show_pbar (`bool`): - Display a progress bar of diffusion denoising. - generator (`torch.Generator`) - Random generator for initial noise generation. - Returns: - `torch.Tensor`: Predicted targets. - """ - device = self.device - rgb_in = rgb_in.to(device) - - # Set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps # [T] - - # Encode image - rgb_latent = self.encode_rgb(rgb_in) # [B, 4, h, w] - - # Noisy latent for outputs - target_latent = torch.randn( - rgb_latent.shape, - device=device, - dtype=self.dtype, - generator=generator, - ) # [B, 4, h, w] - - # Batched empty text embedding - if self.empty_text_embed is None: - self.encode_empty_text() - batch_empty_text_embed = self.empty_text_embed.repeat( - (rgb_latent.shape[0], 1, 1) - ).to(device) # [B, 2, 1024] - - # Denoising loop - if show_pbar: - iterable = tqdm( - enumerate(timesteps), - total=len(timesteps), - leave=False, - desc=" " * 4 + "Diffusion denoising", - ) - else: - iterable = enumerate(timesteps) - - for i, t in iterable: - unet_input = torch.cat( - [rgb_latent, target_latent], dim=1 - ) # this order is important - - # predict the noise residual - noise_pred = self.unet( - unet_input, t, encoder_hidden_states=batch_empty_text_embed - ).sample # [B, 4, h, w] - - # compute the previous noisy sample x_t -> x_t-1 - target_latent = self.scheduler.step( - noise_pred, t, target_latent, generator=generator - ).prev_sample - - normals = self.decode_normals(target_latent) # [B,3,H,W] - - # clip prediction - normals = torch.clip(normals, -1.0, 1.0) - norm = torch.norm(normals, dim=1, keepdim=True) - normals /= norm.clamp(min=1e-6) # [B,3,H,W] - - return normals - - def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor: - """ - Encode RGB image into latent. - - Args: - rgb_in (`torch.Tensor`): - Input RGB image to be encoded. - - Returns: - `torch.Tensor`: Image latent. - """ - # encode - h = self.vae.encoder(rgb_in) - moments = self.vae.quant_conv(h) - mean, logvar = torch.chunk(moments, 2, dim=1) - # scale latent - rgb_latent = mean * self.latent_scale_factor - return rgb_latent - - def decode_normals(self, normals_latent: torch.Tensor) -> torch.Tensor: - """ - Decode normals latent into normals map. - - Args: - normals_latent (`torch.Tensor`): - Normals latent to be decoded. - - Returns: - `torch.Tensor`: Decoded normals map. - """ - # scale latent - normals_latent = normals_latent / self.latent_scale_factor - # decode - z = self.vae.post_quant_conv(normals_latent) - stacked = self.vae.decoder(z) - return stacked diff --git a/src/__init__.py b/olbedo/__init__.py similarity index 94% rename from src/__init__.py rename to olbedo/__init__.py index e530afed19f0ee401aea743a259dcd182434b106..6d5f45933de564e2dc041d2c59acd870d74d8a65 100644 --- a/src/__init__.py +++ b/olbedo/__init__.py @@ -27,3 +27,5 @@ # https://github.com/prs-eth/Marigold#-citation # If you find Marigold useful, we kindly ask you to cite our papers. # -------------------------------------------------------------------------- + +from .olbedo_iid_pipeline import OlbedoIIDPipeline, OlbedoIIDOutput # noqa: F401 \ No newline at end of file diff --git a/olbedo/__pycache__/__init__.cpython-310.pyc b/olbedo/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1853bced53488938981a4c8ac6b9231daf6e0151 Binary files /dev/null and b/olbedo/__pycache__/__init__.cpython-310.pyc differ diff --git a/olbedo/__pycache__/olbedo_iid_pipeline.cpython-310.pyc b/olbedo/__pycache__/olbedo_iid_pipeline.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55d2d3c781e2d9507e3fa89173d30917b850d59c Binary files /dev/null and b/olbedo/__pycache__/olbedo_iid_pipeline.cpython-310.pyc differ diff --git a/marigold/marigold_iid_pipeline.py b/olbedo/olbedo_iid_pipeline.py similarity index 97% rename from marigold/marigold_iid_pipeline.py rename to olbedo/olbedo_iid_pipeline.py index 8dcd7e3ad132b53bd557bfa07e53fcd4c4a77c4d..2edd5f1a24ab564b28ddfd93b7ac61b2a2beeddd 100644 --- a/marigold/marigold_iid_pipeline.py +++ b/olbedo/olbedo_iid_pipeline.py @@ -81,8 +81,8 @@ class IIDEntry: uncertainty: Optional[np.ndarray] = None -class MarigoldIIDOutput: - """Output class for Marigold Intrinsic Image Decomposition pipelines.""" +class OlbedoIIDOutput: + """Output class for Olbedo Intrinsic Image Decomposition pipelines.""" def __init__(self, target_names: List[str]): """Initialize output container with target names. @@ -165,9 +165,9 @@ class MarigoldIIDOutput: return iter(self.entries) -class MarigoldIIDPipeline(DiffusionPipeline): +class OlbedoIIDPipeline(DiffusionPipeline): """ - Pipeline for Marigold Intrinsic Image Decomposition (IID): https://marigoldcomputervision.github.io. + Pipeline for Olbedo Intrinsic Image Decomposition (IID). This class supports arbitrary number of target modalities with names set in `target_names`. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the @@ -421,7 +421,7 @@ class MarigoldIIDPipeline(DiffusionPipeline): generator: Union[torch.Generator, None] = None, show_progress_bar: bool = True, ensemble_kwargs: Dict = None, - ) -> MarigoldIIDOutput: + ) -> OlbedoIIDOutput: """ Function invoked when calling the pipeline. @@ -453,7 +453,7 @@ class MarigoldIIDPipeline(DiffusionPipeline): ensemble_kwargs (`dict`, *optional*, defaults to `None`): Arguments for detailed ensembling settings. Returns: - `MarigoldIIDOutput`: Output class for Marigold Intrinsic Image Decomposition prediction pipeline. + `OlbedoIIDOutput`: Output class for Olbedo Intrinsic Image Decomposition prediction pipeline. """ # Model-specific optimal default values leading to fast and reasonable results. if denoising_steps is None: @@ -558,14 +558,14 @@ class MarigoldIIDPipeline(DiffusionPipeline): ) # Create output - output = MarigoldIIDOutput(target_names=self.target_names) + output = OlbedoIIDOutput(target_names=self.target_names) self.fill_outputs(output, final_pred, pred_uncert) assert output.is_complete return output def fill_outputs( self, - output: MarigoldIIDOutput, + output: OlbedoIIDOutput, final_pred: torch.Tensor, pred_uncert: Optional[torch.Tensor] = None, ): @@ -597,8 +597,6 @@ class MarigoldIIDPipeline(DiffusionPipeline): f"The loaded `DDIMScheduler` is configured with `timestep_spacing=" f'"{self.scheduler.config.timestep_spacing}"`; the recommended setting is `"trailing"`. ' f"This change is backward-compatible and yields better results. " - f"Consider using `prs-eth/marigold-iid-appearance-v1-1` or `prs-eth/marigold-iid-lighting-v1-1` " - f"for the best experience." ) else: if n_step > 10: @@ -610,8 +608,6 @@ class MarigoldIIDPipeline(DiffusionPipeline): logging.warning( f"The loaded `DDIMScheduler` is configured with `rescale_betas_zero_snr=" f"{self.scheduler.config.rescale_betas_zero_snr}`; the recommended setting is True. " - f"Consider using `prs-eth/marigold-iid-appearance-v1-1` or `prs-eth/marigold-iid-lighting-v1-1` " - f"for the best experience." ) elif isinstance(self.scheduler, LCMScheduler): raise RuntimeError( diff --git a/olbedo/util/__pycache__/batchsize.cpython-310.pyc b/olbedo/util/__pycache__/batchsize.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cfecac8428e5d393ed906d25565947dba1249ab Binary files /dev/null and b/olbedo/util/__pycache__/batchsize.cpython-310.pyc differ diff --git a/olbedo/util/__pycache__/ensemble.cpython-310.pyc b/olbedo/util/__pycache__/ensemble.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cfba08d56c246351ec0084020a582838e66f8ed Binary files /dev/null and b/olbedo/util/__pycache__/ensemble.cpython-310.pyc differ diff --git a/olbedo/util/__pycache__/image_util.cpython-310.pyc b/olbedo/util/__pycache__/image_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67446b70fe12c3320c49cdecfd02fac909ff538c Binary files /dev/null and b/olbedo/util/__pycache__/image_util.cpython-310.pyc differ diff --git a/marigold/util/batchsize.py b/olbedo/util/batchsize.py similarity index 100% rename from marigold/util/batchsize.py rename to olbedo/util/batchsize.py diff --git a/marigold/util/ensemble.py b/olbedo/util/ensemble.py similarity index 100% rename from marigold/util/ensemble.py rename to olbedo/util/ensemble.py diff --git a/marigold/util/image_util.py b/olbedo/util/image_util.py similarity index 100% rename from marigold/util/image_util.py rename to olbedo/util/image_util.py diff --git a/requirements++.txt b/requirements++.txt deleted file mode 100644 index 54bcc575f0a71b42b149c1949a8a5d227c0cf10e..0000000000000000000000000000000000000000 --- a/requirements++.txt +++ /dev/null @@ -1,7 +0,0 @@ -h5py -opencv-python -tensorboard -wandb -scikit-learn -xformers==0.0.28 - diff --git a/requirements+.txt b/requirements+.txt deleted file mode 100644 index bdb3181b41057a276097e74f6774954087683b22..0000000000000000000000000000000000000000 --- a/requirements+.txt +++ /dev/null @@ -1,4 +0,0 @@ -omegaconf -pandas -tabulate -torchmetrics \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 34fa658b103d85961843d6efb723f63ccd5f390f..8928aa1a5c01bcc4cb2dd021aa50e6bc7f443483 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,4 @@ scipy torch==2.4.1 torchvision==0.19.1 transformers>=4.32.1 -opencv-python \ No newline at end of file +rasterio \ No newline at end of file diff --git a/runtime.txt b/runtime.txt new file mode 100644 index 0000000000000000000000000000000000000000..e3d06d763d11faa4e32cf87b7286c0a1fc18ac67 --- /dev/null +++ b/runtime.txt @@ -0,0 +1 @@ +python-3.10.12 \ No newline at end of file diff --git a/script/download_weights.py b/script/download_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..9acc100f9bdb226c5319d6cb9fd01f1ae221c4d6 --- /dev/null +++ b/script/download_weights.py @@ -0,0 +1,65 @@ +import argparse +import os +import shutil +from huggingface_hub import snapshot_download + +available_models = [ + "marigold_appearance/finetuned", + "marigold_appearance/pretrained", + "marigold_lighting/finetuned", + "marigold_lighting/pretrained", + "rgbx/finetuned", + "rgbx/pretrained" +] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", + type=str, + default="rgbx/finetuned", + choices=available_models, + help="Select model to download (default: rgbx/finetuned)" + ) + parser.add_argument( + "--local_dir", + type=str, + default="checkpoint", + help="Directory to save the model" + ) + + args = parser.parse_args() + + LOCAL_DIR = args.local_dir + selected_model = args.model + + if os.path.exists(LOCAL_DIR): + if os.path.abspath(LOCAL_DIR) in ["/", os.path.expanduser("~")]: + raise ValueError("Refusing to delete critical directory.") + print(f"Removing existing directory: {LOCAL_DIR}") + shutil.rmtree(LOCAL_DIR) + + print(f"Downloading model: {selected_model}") + + + snapshot_download( + repo_id="GDAOSU/olbedo", + allow_patterns=f"{selected_model}/*", + local_dir=LOCAL_DIR, + local_dir_use_symlinks=False, + ) + + src = os.path.join(LOCAL_DIR, *selected_model.split("/")) + + for name in os.listdir(src): + shutil.move( + os.path.join(src, name), + os.path.join(LOCAL_DIR, name) + ) + + top_level_folder = selected_model.split("/")[0] + shutil.rmtree(os.path.join(LOCAL_DIR, top_level_folder), ignore_errors=True) + shutil.rmtree(os.path.join(LOCAL_DIR, ".cache"), ignore_errors=True) + +if __name__ == "__main__": + main() diff --git a/script/iid/run.py b/script/iid/run.py new file mode 100644 index 0000000000000000000000000000000000000000..5b760e0a8035302bf647f24e8ea5e46aec4ab9bd --- /dev/null +++ b/script/iid/run.py @@ -0,0 +1,271 @@ +# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information about Marigold: +# https://marigoldmonodepth.github.io +# https://marigoldcomputervision.github.io +# Efficient inference pipelines are now part of diffusers: +# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage +# https://huggingface.co/docs/diffusers/api/pipelines/marigold +# Examples of trained models and live demos: +# https://huggingface.co/prs-eth +# Related projects: +# https://rollingdepth.github.io/ +# https://marigolddepthcompletion.github.io/ +# Citation (BibTeX): +# https://github.com/prs-eth/Marigold#-citation +# If you find Marigold useful, we kindly ask you to cite our papers. +# -------------------------------------------------------------------------- + +import sys +import os + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) + +import argparse +import logging +import numpy as np +import os +import torch +from PIL import Image +from glob import glob +from tqdm.auto import tqdm + +from olbedo import OlbedoIIDPipeline, OlbedoIIDOutput +from olbedo.util.image_util import chw2hwc + +import rasterio + +EXTENSION_LIST = [".jpg", ".jpeg", ".png", ".tif", ".tiff"] + + +if "__main__" == __name__: + logging.basicConfig(level=logging.INFO) + + # -------------------- Arguments -------------------- + parser = argparse.ArgumentParser( + description="Olbedo: Intrinsic Image Decomposition for Large-Scale Outdoor Environments" + ) + parser.add_argument( + "--checkpoint", + type=str, + default="checkpoint", + help="Checkpoint path or hub name.", + ) + parser.add_argument( + "--input_rgb_dir", + type=str, + required=True, + help="Path to the input image folder.", + ) + parser.add_argument( + "--output_dir", type=str, required=True, help="Output directory." + ) + parser.add_argument( + "--denoise_steps", + type=int, + default=4, + help="Diffusion denoising steps, more steps results in higher accuracy but slower inference speed. If set to " + "`None`, default value will be read from checkpoint.", + ) + parser.add_argument( + "--processing_res", + type=int, + default=2000, + help="Resolution to which the input is resized before performing estimation. `0` uses the original input " + "resolution; `None` resolves the best default from the model checkpoint. Default: `None`", + ) + parser.add_argument( + "--ensemble_size", + type=int, + default=1, + help="Number of predictions to be ensembled. Default: `1`.", + ) + parser.add_argument( + "--half_precision", + "--fp16", + action="store_true", + help="Run with half-precision (16-bit float), might lead to suboptimal result.", + ) + parser.add_argument( + "--output_processing_res", + action="store_true", + help="Setting this flag will output the result at the effective value of `processing_res`, otherwise the " + "output will be resized to the input resolution.", + ) + parser.add_argument( + "--resample_method", + choices=["bilinear", "bicubic", "nearest"], + default="bilinear", + help="Resampling method used to resize images and predictions. This can be one of `bilinear`, `bicubic` or " + "`nearest`. Default: `bilinear`", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="Reproducibility seed. Set to `None` for randomized inference. Default: `None`", + ) + parser.add_argument( + "--batch_size", + type=int, + default=0, + help="Inference batch size. Default: 0 (will be set automatically).", + ) + + parser.add_argument( + "--model", + type=str, + default="rgbx", + choices=["rgbx", "others"], + help="Choose model", + ) + + args = parser.parse_args() + + checkpoint_path = args.checkpoint + input_rgb_dir = args.input_rgb_dir + output_dir = args.output_dir + + denoise_steps = args.denoise_steps + ensemble_size = args.ensemble_size + if ensemble_size > 15: + logging.warning("Running with large ensemble size will be slow.") + half_precision = args.half_precision + + processing_res = args.processing_res + match_input_res = not args.output_processing_res + if 0 == processing_res and match_input_res is False: + logging.warning( + "Processing at native resolution without resizing output might NOT lead to exactly the same resolution, " + "due to the padding and pooling properties of conv layers." + ) + resample_method = args.resample_method + + seed = args.seed + batch_size = args.batch_size + model = args.model + + # -------------------- Preparation -------------------- + # Output directories + + output_dir_vis = os.path.join(output_dir, "albedo") + + os.makedirs(output_dir, exist_ok=True) + os.makedirs(output_dir_vis, exist_ok=True) + logging.info(f"output dir = {output_dir}") + + # -------------------- Device -------------------- + + if torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + logging.warning("CUDA is not available. Running on CPU will be slow.") + logging.info(f"device = {device}") + + # -------------------- Data -------------------- + rgb_filename_list = glob(os.path.join(input_rgb_dir, "*")) + rgb_filename_list = [ + f for f in rgb_filename_list if os.path.splitext(f)[1].lower() in EXTENSION_LIST + ] + rgb_filename_list = sorted(rgb_filename_list) + n_images = len(rgb_filename_list) + if n_images > 0: + logging.info(f"Found {n_images} images") + else: + logging.error(f"No image found in '{input_rgb_dir}'") + exit(1) + + # -------------------- Model -------------------- + if half_precision: + dtype = torch.float16 + variant = "fp16" + logging.info( + f"Running with half precision ({dtype}), might lead to suboptimal result." + ) + else: + dtype = torch.float32 + variant = None + + pipe: OlbedoIIDPipeline = OlbedoIIDPipeline.from_pretrained( + checkpoint_path, variant=variant, torch_dtype=dtype + ) + + pipe.mode = model + + try: + pipe.enable_xformers_memory_efficient_attention() + except ImportError: + pass # run without xformers + + pipe = pipe.to(device) + logging.info("Loaded IID pipeline") + + # Print out config + logging.info( + f"Inference settings: checkpoint = `{checkpoint_path}`, " + f"with predicted target names: {pipe.target_names}, " + f"denoise_steps = {denoise_steps or pipe.default_denoising_steps}, " + f"ensemble_size = {ensemble_size}, " + f"processing resolution = {processing_res or pipe.default_processing_resolution}, " + f"seed = {seed}; " + ) + + # -------------------- Inference and saving -------------------- + with torch.no_grad(): + os.makedirs(output_dir, exist_ok=True) + + for rgb_path in tqdm(rgb_filename_list, desc="IID Inference", leave=True): + # Read input image + input_image = Image.open(rgb_path) + + # Random number generator + if seed is None: + generator = None + else: + generator = torch.Generator(device=device) + generator.manual_seed(seed) + + # Perform inference + pipe_out: OlbedoIIDOutput = pipe( + input_image, + denoising_steps=denoise_steps, + ensemble_size=ensemble_size, + processing_res=processing_res, + match_input_res=match_input_res, + batch_size=batch_size, + show_progress_bar=False, + resample_method=resample_method, + generator=generator, + ) + + rgb_name_base, rgb_ext = os.path.splitext(os.path.basename(rgb_path)) + for target_name in pipe.target_names: + if target_name!='albedo': + continue + target_entry = pipe_out[target_name] + + if rgb_ext.lower() in [".tif", ".tiff"]: + img = np.array(target_entry.image).transpose((2, 0, 1)) + + with rasterio.open(rgb_path) as src: + profile = src.profile.copy() + + with rasterio.open(os.path.join(output_dir_vis, f"{rgb_name_base}.tif"), 'w', **profile) as dst: + dst.write(img.astype(profile['dtype'])) + else: + target_entry.image.save( + os.path.join(output_dir_vis, f"{rgb_name_base}{rgb_ext}") + ) \ No newline at end of file diff --git a/src/dataset/__init__.py b/src/dataset/__init__.py deleted file mode 100644 index 1ef12299f659d527b0b4ddec1ce361a499d7a957..0000000000000000000000000000000000000000 --- a/src/dataset/__init__.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import os -from typing import Union, List - -from .base_depth_dataset import ( - BaseDepthDataset, - get_pred_name, # noqa: F401 - DatasetMode, -) # noqa: F401 -from .base_iid_dataset import BaseIIDDataset # noqa: F401 -from .base_normals_dataset import BaseNormalsDataset # noqa: F401 -from .diode_dataset import DIODEDepthDataset, DIODENormalsDataset -from .eth3d_dataset import ETH3DDepthDataset -from .hypersim_dataset import ( - HypersimDepthDataset, - HypersimNormalsDataset, - HypersimIIDDataset, -) -from .ibims_dataset import IBimsNormalsDataset -from .interiorverse_dataset import InteriorVerseNormalsDataset, InteriorVerseIIDDataset -from .kitti_dataset import KITTIDepthDataset -from .nyu_dataset import NYUDepthDataset, NYUNormalsDataset -from .oasis_dataset import OasisNormalsDataset -from .scannet_dataset import ScanNetDepthDataset, ScanNetNormalsDataset -from .sintel_dataset import SintelNormalsDataset -from .vkitti_dataset import VirtualKITTIDepthDataset - -dataset_name_class_dict = { - "hypersim_depth": HypersimDepthDataset, - "vkitti_depth": VirtualKITTIDepthDataset, - "nyu_depth": NYUDepthDataset, - "kitti_depth": KITTIDepthDataset, - "eth3d_depth": ETH3DDepthDataset, - "diode_depth": DIODEDepthDataset, - "scannet_depth": ScanNetDepthDataset, - "hypersim_normals": HypersimNormalsDataset, - "interiorverse_normals": InteriorVerseNormalsDataset, - "sintel_normals": SintelNormalsDataset, - "ibims_normals": IBimsNormalsDataset, - "nyu_normals": NYUNormalsDataset, - "scannet_normals": ScanNetNormalsDataset, - "diode_normals": DIODENormalsDataset, - "oasis_normals": OasisNormalsDataset, - "interiorverse_iid": InteriorVerseIIDDataset, - "hypersim_iid": HypersimIIDDataset, -} - - -def get_dataset( - cfg_data_split, base_data_dir: str, mode: DatasetMode, **kwargs -) -> Union[ - BaseDepthDataset, - BaseIIDDataset, - BaseNormalsDataset, - List[BaseDepthDataset], - List[BaseIIDDataset], - List[BaseNormalsDataset], -]: - if "mixed" == cfg_data_split.name: - assert DatasetMode.TRAIN == mode, "Only training mode supports mixed datasets." - dataset_ls = [ - get_dataset(_cfg, base_data_dir, mode, **kwargs) - for _cfg in cfg_data_split.dataset_list - ] - return dataset_ls - elif cfg_data_split.name in dataset_name_class_dict.keys(): - dataset_class = dataset_name_class_dict[cfg_data_split.name] - dataset = dataset_class( - mode=mode, - filename_ls_path=cfg_data_split.filenames, - dataset_dir=os.path.join(base_data_dir, cfg_data_split.dir), - **cfg_data_split, - **kwargs, - ) - else: - raise NotImplementedError - - return dataset diff --git a/src/dataset/base_depth_dataset.py b/src/dataset/base_depth_dataset.py deleted file mode 100644 index 5b7a9978700a449f84a7c928f1e44a63f9d2062d..0000000000000000000000000000000000000000 --- a/src/dataset/base_depth_dataset.py +++ /dev/null @@ -1,285 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import io -import numpy as np -import os -import random -import tarfile -import torch -from PIL import Image -from enum import Enum -from torch.utils.data import Dataset -from torchvision.transforms import InterpolationMode, Resize -from typing import Union - -from src.util.depth_transform import DepthNormalizerBase - - -class DatasetMode(Enum): - RGB_ONLY = "rgb_only" - EVAL = "evaluate" - TRAIN = "train" - - -class DepthFileNameMode(Enum): - """Prediction file naming modes""" - - id = 1 # id.png - rgb_id = 2 # rgb_id.png - i_d_rgb = 3 # i_d_1_rgb.png - rgb_i_d = 4 - - -class BaseDepthDataset(Dataset): - def __init__( - self, - mode: DatasetMode, - filename_ls_path: str, - dataset_dir: str, - disp_name: str, - min_depth: float, - max_depth: float, - has_filled_depth: bool, - name_mode: DepthFileNameMode, - depth_transform: Union[DepthNormalizerBase, None] = None, - augmentation_args: dict = None, - resize_to_hw=None, - move_invalid_to_far_plane: bool = True, - rgb_transform=lambda x: x / 255.0 * 2 - 1, # [0, 255] -> [-1, 1], - **kwargs, - ) -> None: - super().__init__() - self.mode = mode - # dataset info - self.filename_ls_path = filename_ls_path - self.dataset_dir = dataset_dir - assert os.path.exists( - self.dataset_dir - ), f"Dataset does not exist at: {self.dataset_dir}" - self.disp_name = disp_name - self.has_filled_depth = has_filled_depth - self.name_mode: DepthFileNameMode = name_mode - self.min_depth = min_depth - self.max_depth = max_depth - - # training arguments - self.depth_transform: DepthNormalizerBase = depth_transform - self.augm_args = augmentation_args - self.resize_to_hw = resize_to_hw - self.rgb_transform = rgb_transform - self.move_invalid_to_far_plane = move_invalid_to_far_plane - - # Load filenames - with open(self.filename_ls_path, "r") as f: - self.filenames = [ - s.split() for s in f.readlines() - ] # [['rgb.png', 'depth.tif'], [], ...] - - # Tar dataset - self.tar_obj = None - self.is_tar = ( - True - if os.path.isfile(dataset_dir) and tarfile.is_tarfile(dataset_dir) - else False - ) - - def __len__(self): - return len(self.filenames) - - def __getitem__(self, index): - rasters, other = self._get_data_item(index) - if DatasetMode.TRAIN == self.mode: - rasters = self._training_preprocess(rasters) - # merge - outputs = rasters - outputs.update(other) - return outputs - - def _get_data_item(self, index): - rgb_rel_path, depth_rel_path, filled_rel_path = self._get_data_path(index=index) - - rasters = {} - - # RGB data - rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path)) - - # Depth data - if DatasetMode.RGB_ONLY != self.mode: - # load data - depth_data = self._load_depth_data( - depth_rel_path=depth_rel_path, filled_rel_path=filled_rel_path - ) - rasters.update(depth_data) - # valid mask - rasters["valid_mask_raw"] = self._get_valid_mask( - rasters["depth_raw_linear"] - ).clone() - rasters["valid_mask_filled"] = self._get_valid_mask( - rasters["depth_filled_linear"] - ).clone() - - other = {"index": index, "rgb_relative_path": rgb_rel_path} - - return rasters, other - - def _load_rgb_data(self, rgb_rel_path): - # Read RGB data - rgb = self._read_rgb_file(rgb_rel_path) - rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1] - - outputs = { - "rgb_int": torch.from_numpy(rgb).int(), - "rgb_norm": torch.from_numpy(rgb_norm).float(), - } - return outputs - - def _load_depth_data(self, depth_rel_path, filled_rel_path): - # Read depth data - outputs = {} - depth_raw = self._read_depth_file(depth_rel_path).squeeze() - depth_raw_linear = torch.from_numpy(depth_raw).float().unsqueeze(0) # [1, H, W] - outputs["depth_raw_linear"] = depth_raw_linear.clone() - - if self.has_filled_depth: - depth_filled = self._read_depth_file(filled_rel_path).squeeze() - depth_filled_linear = torch.from_numpy(depth_filled).float().unsqueeze(0) - outputs["depth_filled_linear"] = depth_filled_linear - else: - outputs["depth_filled_linear"] = depth_raw_linear.clone() - - return outputs - - def _get_data_path(self, index): - filename_line = self.filenames[index] - - # Get data path - rgb_rel_path = filename_line[0] - - depth_rel_path, filled_rel_path = None, None - if DatasetMode.RGB_ONLY != self.mode: - depth_rel_path = filename_line[1] - if self.has_filled_depth: - filled_rel_path = filename_line[2] - return rgb_rel_path, depth_rel_path, filled_rel_path - - def _read_image(self, img_rel_path) -> np.ndarray: - if self.is_tar: - if self.tar_obj is None: - self.tar_obj = tarfile.open(self.dataset_dir) - image_to_read = self.tar_obj.extractfile("./" + img_rel_path) - image_to_read = image_to_read.read() - image_to_read = io.BytesIO(image_to_read) - else: - image_to_read = os.path.join(self.dataset_dir, img_rel_path) - image = Image.open(image_to_read) # [H, W, rgb] - image = np.asarray(image) - return image - - def _read_rgb_file(self, rel_path) -> np.ndarray: - rgb = self._read_image(rel_path) - rgb = np.transpose(rgb, (2, 0, 1)).astype(int) # [rgb, H, W] - return rgb - - def _read_depth_file(self, rel_path): - depth_in = self._read_image(rel_path) - # Replace code below to decode depth according to dataset definition - depth_decoded = depth_in - - return depth_decoded - - def _get_valid_mask(self, depth: torch.Tensor): - valid_mask = torch.logical_and( - (depth > self.min_depth), (depth < self.max_depth) - ).bool() - return valid_mask - - def _training_preprocess(self, rasters): - # Augmentation - if self.augm_args is not None: - rasters = self._augment_data(rasters) - - # Normalization - rasters["depth_raw_norm"] = self.depth_transform( - rasters["depth_raw_linear"], rasters["valid_mask_raw"] - ).clone() - rasters["depth_filled_norm"] = self.depth_transform( - rasters["depth_filled_linear"], rasters["valid_mask_filled"] - ).clone() - - # Set invalid pixel to far plane - if self.move_invalid_to_far_plane: - if self.depth_transform.far_plane_at_max: - rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = ( - self.depth_transform.norm_max - ) - else: - rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = ( - self.depth_transform.norm_min - ) - - # Resize - if self.resize_to_hw is not None: - resize_transform = Resize( - size=self.resize_to_hw, interpolation=InterpolationMode.NEAREST_EXACT - ) - rasters = {k: resize_transform(v) for k, v in rasters.items()} - - return rasters - - def _augment_data(self, rasters_dict): - # lr flipping - lr_flip_p = self.augm_args.lr_flip_p - if random.random() < lr_flip_p: - rasters_dict = {k: v.flip(-1) for k, v in rasters_dict.items()} - - return rasters_dict - - def __del__(self): - if hasattr(self, "tar_obj") and self.tar_obj is not None: - self.tar_obj.close() - self.tar_obj = None - - -def get_pred_name(rgb_basename, name_mode, suffix=".png"): - if DepthFileNameMode.rgb_id == name_mode: - pred_basename = "pred_" + rgb_basename.split("_")[1] - elif DepthFileNameMode.i_d_rgb == name_mode: - pred_basename = rgb_basename.replace("_rgb.", "_pred.") - elif DepthFileNameMode.id == name_mode: - pred_basename = "pred_" + rgb_basename - elif DepthFileNameMode.rgb_i_d == name_mode: - pred_basename = "pred_" + "_".join(rgb_basename.split("_")[1:]) - else: - raise NotImplementedError - # change suffix - pred_basename = os.path.splitext(pred_basename)[0] + suffix - - return pred_basename diff --git a/src/dataset/base_iid_dataset.py b/src/dataset/base_iid_dataset.py deleted file mode 100644 index f76900ebd1c234c6f9916486bbb3a21b2a0896ec..0000000000000000000000000000000000000000 --- a/src/dataset/base_iid_dataset.py +++ /dev/null @@ -1,205 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import os - -os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" # Enable OpenCV support for EXR -# ruff: noqa: E402 - -import io -import tarfile -import numpy as np -import random -import torch -from torch.utils.data import Dataset -from torchvision.transforms import InterpolationMode, Resize -from .base_depth_dataset import DatasetMode - -from src.util.image_util import ( - img_hwc2chw, - img_linear2srgb, - is_hdr, - read_img_from_file, - read_img_from_tar, -) - - -class BaseIIDDataset(Dataset): - def __init__( - self, - mode: DatasetMode, - filename_ls_path: str, - dataset_dir: str, - disp_name: str, - augmentation_args: dict = None, - resize_to_hw=None, - **kwargs, - ) -> None: - super().__init__() - self.mode = mode - # dataset info - self.filename_ls_path = filename_ls_path - self.dataset_dir = dataset_dir - assert os.path.exists( - self.dataset_dir - ), f"Dataset does not exist at: {self.dataset_dir}" - self.disp_name = disp_name - - # training arguments - self.augm_args = augmentation_args - self.resize_to_hw = resize_to_hw - - # Load filenames - with open(self.filename_ls_path, "r") as f: - self.filenames = [s.split() for s in f.readlines()] - - # Tar dataset - self.tar_obj = None - self.is_tar = ( - True - if os.path.isfile(self.dataset_dir) and tarfile.is_tarfile(self.dataset_dir) - else False - ) - - def __len__(self): - return len(self.filenames) - - def __getitem__(self, index): - rasters, other = self._get_data_item(index) - if DatasetMode.TRAIN == self.mode: - rasters = self._training_preprocess(rasters) - # merge - outputs = rasters - outputs.update(other) - return outputs - - def _get_data_item(self, index): - rgb_rel_path, targets_rel_path = self._get_data_path(index=index) - - rasters = {} - - # RGB data - rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path)) - - # load targets data should be filled in specialized dataset definition - if DatasetMode.RGB_ONLY != self.mode: - targets_data = self._load_targets_data(rel_paths=targets_rel_path) - rasters.update(targets_data) - - other = {"index": index, "rgb_relative_path": rgb_rel_path} - - return rasters, other - - def _read_image(self, rel_path) -> np.ndarray: - if self.is_tar: - if self.tar_obj is None: - self.tar_obj = tarfile.open(self.dataset_dir) - img = read_img_from_tar(self.tar_obj, rel_path) - else: - img = read_img_from_file(os.path.join(self.dataset_dir, rel_path)) - - if len(img.shape) == 3: # hwc->chw, except for single-channel images - img = img_hwc2chw(img) - - assert img.min() >= 0 and img.max() <= 1 - return img - - def _load_rgb_data(self, rgb_rel_path): - # rgb is in [0,1] range - rgb = self._read_image(rgb_rel_path) - - # SD is pretrained in sRGB space. If we load HDR data, we should also convert to sRGB space. - if is_hdr(rgb_rel_path): - rgb = img_linear2srgb(rgb) - - if rgb.shape[0] == 4: - rgb = rgb[:3, :, :] - - outputs = {"rgb": torch.from_numpy(rgb).float()} # [0,1] - - return outputs - - def _read_numpy(self, rel_path): - if self.is_tar: - if self.tar_obj is None: - self.tar_obj = tarfile.open(self.dataset_dir) - image_to_read = self.tar_obj.extractfile("./" + rel_path) - image_to_read = image_to_read.read() - image_to_read = io.BytesIO(image_to_read) - else: - image_to_read = os.path.join(self.dataset_dir, rel_path) - - image = np.load(image_to_read).transpose((2, 0, 1)) # [3,H,W] - return image - - def _load_targets_data(self, rel_paths): - outputs = {} - return outputs - - def _get_data_path(self, index): - filename_line = self.filenames[index] - - # Only the first is the input image, the rest should be specialized by each dataset. - rgb_rel_path = filename_line[0] - targets_rel_path = filename_line[1:] - - return rgb_rel_path, targets_rel_path - - def _training_preprocess(self, rasters): - # Augmentation - if self.augm_args is not None: - rasters = self._augment_data(rasters) - - # Resize - if self.resize_to_hw is not None: - resize_bilinear = Resize( - size=self.resize_to_hw, interpolation=InterpolationMode.BILINEAR - ) - resize_nearest = Resize( - size=self.resize_to_hw, interpolation=InterpolationMode.NEAREST_EXACT - ) - - rasters = { - k: (resize_nearest(v) if "valid_mask" in k else resize_bilinear(v)) - for k, v in rasters.items() - } - - return rasters - - def _augment_data(self, rasters): - # horizontal flip - if random.random() < self.augm_args.lr_flip_p: - rasters = {k: v.flip(-1) for k, v in rasters.items()} - return rasters - - def __del__(self): - if hasattr(self, "tar_obj") and self.tar_obj is not None: - self.tar_obj.close() - self.tar_obj = None diff --git a/src/dataset/base_normals_dataset.py b/src/dataset/base_normals_dataset.py deleted file mode 100644 index 3da3711f7a184a300b94386a59358b2f32c520b8..0000000000000000000000000000000000000000 --- a/src/dataset/base_normals_dataset.py +++ /dev/null @@ -1,265 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import io -import numpy as np -import os -import random -import tarfile -import torch -from PIL import Image -import torchvision.transforms.functional as TF -from torch.utils.data import Dataset, get_worker_info -from torchvision.transforms import InterpolationMode, Resize, ColorJitter -from .base_depth_dataset import DatasetMode - - -class BaseNormalsDataset(Dataset): - def __init__( - self, - mode: DatasetMode, - filename_ls_path: str, - dataset_dir: str, - disp_name: str, - augmentation_args: dict = None, - resize_to_hw=None, - **kwargs, - ) -> None: - super().__init__() - self.mode = mode - # dataset info - self.filename_ls_path = filename_ls_path - self.dataset_dir = dataset_dir - assert os.path.exists( - self.dataset_dir - ), f"Dataset does not exist at: {self.dataset_dir}" - self.disp_name = disp_name - - # training arguments - self.augm_args = augmentation_args - self.resize_to_hw = resize_to_hw - - # Load filenames - with open(self.filename_ls_path, "r") as f: - self.filenames = [s.split() for s in f.readlines()] - - # Tar dataset - self.tar_obj = None - self.is_tar = ( - True - if os.path.isfile(dataset_dir) and tarfile.is_tarfile(dataset_dir) - else False - ) - - if self.is_tar: - self.tar_obj = tarfile.open(self.dataset_dir) - - def __len__(self): - return len(self.filenames) - - def __getitem__(self, index): - rasters, other = self._get_data_item(index) - if DatasetMode.TRAIN == self.mode: - rasters = self._training_preprocess(rasters) - # merge - outputs = rasters - outputs.update(other) - return outputs - - def _get_data_item(self, index): - rgb_rel_path, normals_rel_path = self._get_data_path(index=index) - - rasters = {} - - # RGB data - rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path)) - - # Normals data - if DatasetMode.RGB_ONLY != self.mode: - normals_data = self._load_normals_data(normals_rel_path=normals_rel_path) - rasters.update(normals_data) - - other = {"index": index, "rgb_relative_path": rgb_rel_path} - - return rasters, other - - def _load_rgb_data(self, rgb_rel_path): - # Read RGB data - rgb = self._read_rgb_file(rgb_rel_path) - rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1] - - outputs = { - "rgb_int": torch.from_numpy(rgb).int(), - "rgb_norm": torch.from_numpy(rgb_norm).float(), - } - return outputs - - def _load_normals_data(self, normals_rel_path): - outputs = {} - normals = torch.from_numpy( - self._read_normals_file(normals_rel_path) - ).float() # [3,H,W] - outputs["normals"] = normals - - return outputs - - def _get_data_path(self, index): - filename_line = self.filenames[index] - - # Get data path - rgb_rel_path = filename_line[0] - normals_rel_path = filename_line[1] - - return rgb_rel_path, normals_rel_path - - def _read_image(self, img_rel_path) -> np.ndarray: - if self.is_tar: - if self.tar_obj is None: - self.tar_obj = tarfile.open(self.dataset_dir) - image_to_read = self.tar_obj.extractfile("./" + img_rel_path) - image_to_read = image_to_read.read() - image_to_read = io.BytesIO(image_to_read) - else: - image_to_read = os.path.join(self.dataset_dir, img_rel_path) - image = Image.open(image_to_read) # [H, W, rgb] - image = np.asarray(image) - return image - - def _read_rgb_file(self, rel_path) -> np.ndarray: - rgb = self._read_image(rel_path) - rgb = np.transpose(rgb, (2, 0, 1)).astype(int) # [rgb, H, W] - return rgb - - def _read_normals_file(self, rel_path): - if self.is_tar: - if self.tar_obj is None: - self.tar_obj = tarfile.open(self.dataset_dir) - # normal = self.tar_obj.extractfile(f'./{tar_name}/'+rel_path) - normal = self.tar_obj.extractfile("./" + rel_path) - normal = normal.read() - normal = np.load(io.BytesIO(normal)) # [H, W, 3] - else: - normal_path = os.path.join(self.dataset_dir, rel_path) - normal = np.load(normal_path) - normal = np.transpose(normal, (2, 0, 1)) # [3, H, W] - return normal - - def _training_preprocess(self, rasters): - # Augmentation - if self.augm_args is not None: - rasters = self._augment_data(rasters) - - # Resize - if self.resize_to_hw is not None: - resize_transform = Resize( - size=self.resize_to_hw, interpolation=InterpolationMode.BILINEAR - ) - rasters = {k: resize_transform(v) for k, v in rasters.items()} - - return rasters - - def _augment_data(self, rasters): - # horizontal flip (gt normals have to be flipped too) - if random.random() < self.augm_args.lr_flip_p: - rasters = {k: v.flip(-1) for k, v in rasters.items()} - rasters["normals"][0, :, :] *= -1 - - # if the process is on the main thread, we can use gpu to to the augmentation - use_gpu = get_worker_info() is None - if use_gpu: - rasters = {k: v.cuda() for k, v in rasters.items()} - - # random gaussian blur - if ( - random.random() < self.augm_args.gaussian_blur_p - and rasters["rgb_int"].shape[-2] == 768 - ): # only blur if Hypersim sample - random_rgb_sigma = random.uniform(0.0, self.augm_args.gaussian_blur_sigma) - rasters["rgb_int"] = TF.gaussian_blur( - rasters["rgb_int"], kernel_size=33, sigma=random_rgb_sigma - ).int() - - # motion blur - if ( - random.random() < self.augm_args.motion_blur_p - and rasters["rgb_int"].shape[-2] == 768 - ): # only blur if Hypersim sample - random_kernel_size = random.choice( - [ - x - for x in range(3, self.augm_args.motion_blur_kernel_size + 1) - if x % 2 == 1 - ] - ) - kernel = torch.zeros( - random_kernel_size, - random_kernel_size, - dtype=rasters["rgb_int"].dtype, - device=rasters["rgb_int"].device, - ) - kernel[random_kernel_size // 2, :] = torch.ones(random_kernel_size) - kernel = TF.rotate( - kernel.unsqueeze(0), - random.uniform(0.0, self.augm_args.motion_blur_angle_range), - ) - kernel = kernel / kernel.sum() - channels = rasters["rgb_int"].shape[0] - kernel = kernel.expand(channels, 1, random_kernel_size, random_kernel_size) - rasters["rgb_int"] = ( - torch.conv2d( - rasters["rgb_int"].unsqueeze(0).float(), - kernel, - stride=1, - padding=random_kernel_size // 2, - groups=channels, - ) - .squeeze(0) - .int() - ) - # color jitter - if random.random() < self.augm_args.color_jitter_p: - color_jitter = ColorJitter( - brightness=self.augm_args.jitter_brightness_factor, - contrast=self.augm_args.jitter_contrast_factor, - saturation=self.augm_args.jitter_saturation_factor, - hue=self.augm_args.jitter_hue_factor, - ) - rgb_int_temp = rasters["rgb_int"].float() / 255.0 - rgb_int_temp = color_jitter(rgb_int_temp) - rasters["rgb_int"] = (rgb_int_temp * 255.0).int() - - # update normalized rgb - rasters["rgb_norm"] = rasters["rgb_int"].float() / 255.0 * 2.0 - 1.0 - return rasters - - def __del__(self): - if hasattr(self, "tar_obj") and self.tar_obj is not None: - self.tar_obj.close() - self.tar_obj = None diff --git a/src/dataset/diode_dataset.py b/src/dataset/diode_dataset.py deleted file mode 100644 index 02af4fcf4db15c9ac965d3a200573575595d3a37..0000000000000000000000000000000000000000 --- a/src/dataset/diode_dataset.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import numpy as np -import os -import tarfile -import torch -from io import BytesIO - -from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode, DatasetMode -from .base_normals_dataset import BaseNormalsDataset - - -class DIODEDepthDataset(BaseDepthDataset): - def __init__( - self, - **kwargs, - ) -> None: - super().__init__( - # DIODE data parameter - min_depth=0.6, - max_depth=350, - has_filled_depth=False, - name_mode=DepthFileNameMode.id, - **kwargs, - ) - - def _read_npy_file(self, rel_path): - if self.is_tar: - if self.tar_obj is None: - self.tar_obj = tarfile.open(self.dataset_dir) - fileobj = self.tar_obj.extractfile("./" + rel_path) - npy_path_or_content = BytesIO(fileobj.read()) - else: - npy_path_or_content = os.path.join(self.dataset_dir, rel_path) - data = np.load(npy_path_or_content).squeeze()[np.newaxis, :, :] - return data - - def _read_depth_file(self, rel_path): - depth = self._read_npy_file(rel_path) - return depth - - def _get_data_path(self, index): - return self.filenames[index] - - def _get_data_item(self, index): - # Special: depth mask is read from data - - rgb_rel_path, depth_rel_path, mask_rel_path = self._get_data_path(index=index) - - rasters = {} - - # RGB data - rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path)) - - # Depth data - if DatasetMode.RGB_ONLY != self.mode: - # load data - depth_data = self._load_depth_data( - depth_rel_path=depth_rel_path, filled_rel_path=None - ) - rasters.update(depth_data) - - # valid mask - mask = self._read_npy_file(mask_rel_path).astype(bool) - mask = torch.from_numpy(mask).bool() - rasters["valid_mask_raw"] = mask.clone() - rasters["valid_mask_filled"] = mask.clone() - - other = {"index": index, "rgb_relative_path": rgb_rel_path} - - return rasters, other - - -class DIODENormalsDataset(BaseNormalsDataset): - def __getitem__(self, index): - return super().__getitem__(index) diff --git a/src/dataset/eth3d_dataset.py b/src/dataset/eth3d_dataset.py deleted file mode 100644 index edd84f5204f7c29e77e7aa45ce5153d68925e97d..0000000000000000000000000000000000000000 --- a/src/dataset/eth3d_dataset.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import numpy as np -import os -import tarfile -import torch - -from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode - - -class ETH3DDepthDataset(BaseDepthDataset): - HEIGHT, WIDTH = 4032, 6048 - - def __init__( - self, - **kwargs, - ) -> None: - super().__init__( - # ETH3D data parameter - min_depth=1e-5, - max_depth=torch.inf, - has_filled_depth=False, - name_mode=DepthFileNameMode.id, - **kwargs, - ) - - def _read_depth_file(self, rel_path): - # Read special binary data: https://www.eth3d.net/documentation#format-of-multi-view-data-image-formats - if self.is_tar: - if self.tar_obj is None: - self.tar_obj = tarfile.open(self.dataset_dir) - binary_data = self.tar_obj.extractfile("./" + rel_path) - binary_data = binary_data.read() - - else: - depth_path = os.path.join(self.dataset_dir, rel_path) - with open(depth_path, "rb") as file: - binary_data = file.read() - # Convert the binary data to a numpy array of 32-bit floats - depth_decoded = np.frombuffer(binary_data, dtype=np.float32).copy() - - depth_decoded[depth_decoded == torch.inf] = 0.0 - - depth_decoded = depth_decoded.reshape((self.HEIGHT, self.WIDTH)) - return depth_decoded diff --git a/src/dataset/hypersim_dataset.py b/src/dataset/hypersim_dataset.py deleted file mode 100644 index b50106d849834206104664319d3e40b5ea0882c9..0000000000000000000000000000000000000000 --- a/src/dataset/hypersim_dataset.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import torch - -from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode -from .base_normals_dataset import BaseNormalsDataset -from .base_iid_dataset import BaseIIDDataset - - -class HypersimDepthDataset(BaseDepthDataset): - def __init__( - self, - **kwargs, - ) -> None: - super().__init__( - # Hypersim data parameter - min_depth=1e-5, - max_depth=65.0, - has_filled_depth=False, - name_mode=DepthFileNameMode.rgb_i_d, - **kwargs, - ) - - def _read_depth_file(self, rel_path): - depth_in = self._read_image(rel_path) - # Decode Hypersim depth - depth_decoded = depth_in / 1000.0 - return depth_decoded - - -class HypersimNormalsDataset(BaseNormalsDataset): - pass - - -class HypersimIIDDataset(BaseIIDDataset): - def _load_targets_data(self, rel_paths): - albedo_path = rel_paths[0] # albedo_cam_00_fr0000.npy - shading_path = rel_paths[1] # shading_cam_00_fr0000.npy - residual_path = rel_paths[2] # residual_cam_00_fr0000.npy - - albedo_raw = self._read_numpy(albedo_path) # in linear space - shading_raw = self._read_numpy(shading_path) - residual_raw = self._read_numpy(residual_path) - - rasters = { - "albedo": torch.from_numpy(albedo_raw).float(), # [0,1] linear space - "shading_raw": torch.from_numpy(shading_raw), - "residual_raw": torch.from_numpy(residual_raw), - } - del albedo_raw, shading_raw, residual_raw - - # get the cut off value for shading/residual - cut_off_value = self._get_cut_off_value(rasters) - # clip and normalize shading and residual (to bring them to the same scale and value range) - rasters = self._process_shading_residual(rasters, cut_off_value) - - # Load masks - valid_mask_albedo, valid_mask_shading, valid_mask_residual = ( - self._get_valid_masks(rasters) - ) - rasters.update( - { - "mask_albedo": valid_mask_albedo.bool(), - "mask_shading": valid_mask_shading.bool(), - "mask_residual": valid_mask_residual.bool(), - } - ) - return rasters - - def _process_shading_residual(self, rasters, cut_off_value): - # Clip by cut_off_value - shading_clipped = torch.clip(rasters["shading_raw"], 0, cut_off_value) - residual_clipped = torch.clip(rasters["residual_raw"], 0, cut_off_value) - # Divide by them same cut off value to bring them to the same scale - shading_norm = shading_clipped / cut_off_value # [0,1] - residual_norm = residual_clipped / cut_off_value # [0,1] - - rasters.update( - { - "shading": shading_norm.float(), - "residual": residual_norm.float(), - } - ) - return rasters - - def _get_cut_off_value(self, rasters): - shading_raw = rasters["shading_raw"] - residual_raw = rasters["residual_raw"] - - # take the maximum of residual_98 and shading_98 as cutoff value - residual_98 = torch.quantile(residual_raw, 0.98) - shading_98 = torch.quantile(shading_raw, 0.98) - cut_off_value = torch.max(torch.tensor([residual_98, shading_98])) - - return cut_off_value - - def _get_valid_masks(self, rasters): - albedo_gt_ts = rasters["albedo"] # [3,H,W] - invalid_mask_albedo = torch.isnan(albedo_gt_ts) | torch.isinf(albedo_gt_ts) - zero_mask = (albedo_gt_ts == 0).all(dim=0, keepdim=True) - zero_mask = zero_mask.expand_as(albedo_gt_ts) - invalid_mask_albedo |= zero_mask - valid_mask_albedo = ~invalid_mask_albedo - - shading_gt_ts = rasters["shading"] - invalid_mask_shading = torch.isnan(shading_gt_ts) | torch.isinf(shading_gt_ts) - valid_mask_shading = ~invalid_mask_shading - - residual_gt_ts = rasters["residual"] - invalid_mask_residual = torch.isnan(residual_gt_ts) | torch.isinf( - residual_gt_ts - ) - valid_mask_residual = ~invalid_mask_residual - - return valid_mask_albedo, valid_mask_shading, valid_mask_residual diff --git a/src/dataset/ibims_dataset.py b/src/dataset/ibims_dataset.py deleted file mode 100644 index a488d2317754cd34b494039d6a8bacbe35340ffa..0000000000000000000000000000000000000000 --- a/src/dataset/ibims_dataset.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -from .base_normals_dataset import BaseNormalsDataset - - -class IBimsNormalsDataset(BaseNormalsDataset): - pass diff --git a/src/dataset/interiorverse_dataset.py b/src/dataset/interiorverse_dataset.py deleted file mode 100644 index 2d4db857d0170a09e44ad69f8974e6223d7c43f1..0000000000000000000000000000000000000000 --- a/src/dataset/interiorverse_dataset.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import numpy as np -import torch - -from src.util.image_util import img_linear2srgb, is_hdr -from .base_iid_dataset import BaseIIDDataset, DatasetMode -from .base_normals_dataset import BaseNormalsDataset - - -class InteriorVerseNormalsDataset(BaseNormalsDataset): - pass - - -# https://github.com/jingsenzhu/IndoorInverseRendering/tree/main/interiorverse -class InteriorVerseIIDDataset(BaseIIDDataset): - def _load_targets_data(self, rel_paths): - albedo_path = rel_paths[0] # 000_albedo.exr - # material_path = rel_paths[1] # 000_material.exr - - albedo_img = self._read_image(albedo_path) - # material_img = self._read_image( - # material_path - # ) # R is roughness, G is metallicity - # material_img[2, :, :] = 0 - - mask_img_squeezed = None - mask_img = None - - if len(rel_paths) == 2: - mask_path = rel_paths[1] # 000_mask.png - mask_img = self._read_image(mask_path) - mask_img = mask_img != 0 # Convert to a boolean np array - if mask_img.ndim == 3: - mask_img_squeezed = np.expand_dims( - np.all(mask_img, axis=0), axis=0 - ) # Convert 3 channel to 1 channel - elif mask_img.ndim == 2: - mask_img_squeezed = np.expand_dims(mask_img, axis=0) - mask_img = np.stack([mask_img] * 3, axis=0) # (3, H, W) - - # SD is pretrained in sRGB space. If we load HDR data, we should also convert to sRGB space. - if is_hdr(albedo_path): - albedo_img = img_linear2srgb(albedo_img) - - # if is_hdr(material_path): - # material_img = img_linear2srgb(material_img) - - if albedo_img.shape[0] == 4: - albedo_img = albedo_img[:3, :, :] - - if len(rel_paths) == 2: - outputs = { - "albedo": torch.from_numpy(albedo_img), - # "material": torch.from_numpy(material_img), - "mask": torch.from_numpy(mask_img_squeezed), - } - else: - outputs = { - "albedo": torch.from_numpy(albedo_img), - } - - # add three channel mask for evaluation - if self.mode == DatasetMode.EVAL: - if len(rel_paths) == 2: - eval_masks = { - "mask_albedo": torch.from_numpy(mask_img).bool(), - # "mask_material": torch.from_numpy(mask_img).bool(), - } - outputs.update(eval_masks) - - return outputs diff --git a/src/dataset/kitti_dataset.py b/src/dataset/kitti_dataset.py deleted file mode 100644 index c92f9315f50c1ab4b3104fa5730a56f11a467d2b..0000000000000000000000000000000000000000 --- a/src/dataset/kitti_dataset.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import torch - -from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode - - -class KITTIDepthDataset(BaseDepthDataset): - def __init__( - self, - kitti_bm_crop, # Crop to KITTI benchmark size - valid_mask_crop, # Evaluation mask. [None, garg or eigen] - **kwargs, - ) -> None: - super().__init__( - # KITTI data parameter - min_depth=1e-5, - max_depth=80, - has_filled_depth=False, - name_mode=DepthFileNameMode.id, - **kwargs, - ) - self.kitti_bm_crop = kitti_bm_crop - self.valid_mask_crop = valid_mask_crop - assert self.valid_mask_crop in [ - None, - "garg", # set evaluation mask according to Garg ECCV16 - "eigen", # set evaluation mask according to Eigen NIPS14 - ], f"Unknown crop type: {self.valid_mask_crop}" - - # Filter out empty depth - self.filenames = [f for f in self.filenames if "None" != f[1]] - - def _read_depth_file(self, rel_path): - depth_in = self._read_image(rel_path) - # Decode KITTI depth - depth_decoded = depth_in / 256.0 - return depth_decoded - - def _load_rgb_data(self, rgb_rel_path): - rgb_data = super()._load_rgb_data(rgb_rel_path) - if self.kitti_bm_crop: - rgb_data = {k: self.kitti_benchmark_crop(v) for k, v in rgb_data.items()} - return rgb_data - - def _load_depth_data(self, depth_rel_path, filled_rel_path): - depth_data = super()._load_depth_data(depth_rel_path, filled_rel_path) - if self.kitti_bm_crop: - depth_data = { - k: self.kitti_benchmark_crop(v) for k, v in depth_data.items() - } - return depth_data - - @staticmethod - def kitti_benchmark_crop(input_img): - """ - Crop images to KITTI benchmark size - Args: - `input_img` (torch.Tensor): Input image to be cropped. - - Returns: - torch.Tensor:Cropped image. - """ - KB_CROP_HEIGHT = 352 - KB_CROP_WIDTH = 1216 - - height, width = input_img.shape[-2:] - top_margin = int(height - KB_CROP_HEIGHT) - left_margin = int((width - KB_CROP_WIDTH) / 2) - if 2 == len(input_img.shape): - out = input_img[ - top_margin : top_margin + KB_CROP_HEIGHT, - left_margin : left_margin + KB_CROP_WIDTH, - ] - elif 3 == len(input_img.shape): - out = input_img[ - :, - top_margin : top_margin + KB_CROP_HEIGHT, - left_margin : left_margin + KB_CROP_WIDTH, - ] - return out - - def _get_valid_mask(self, depth: torch.Tensor): - # reference: https://github.com/cleinc/bts/blob/master/pytorch/bts_eval.py - valid_mask = super()._get_valid_mask(depth) # [1, H, W] - - if self.valid_mask_crop is not None: - eval_mask = torch.zeros_like(valid_mask.squeeze()).bool() - gt_height, gt_width = eval_mask.shape - - if "garg" == self.valid_mask_crop: - eval_mask[ - int(0.40810811 * gt_height) : int(0.99189189 * gt_height), - int(0.03594771 * gt_width) : int(0.96405229 * gt_width), - ] = 1 - elif "eigen" == self.valid_mask_crop: - eval_mask[ - int(0.3324324 * gt_height) : int(0.91351351 * gt_height), - int(0.0359477 * gt_width) : int(0.96405229 * gt_width), - ] = 1 - - eval_mask.reshape(valid_mask.shape) - valid_mask = torch.logical_and(valid_mask, eval_mask) - return valid_mask diff --git a/src/dataset/mixed_sampler.py b/src/dataset/mixed_sampler.py deleted file mode 100644 index d9514d03396e17602bc88c77aff3a1ac067e8521..0000000000000000000000000000000000000000 --- a/src/dataset/mixed_sampler.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import torch -from torch.utils.data import ( - BatchSampler, - RandomSampler, - SequentialSampler, -) - - -class MixedBatchSampler(BatchSampler): - """Sample one batch from a selected dataset with given probability. - Compatible with datasets at different resolution - """ - - def __init__( - self, src_dataset_ls, batch_size, drop_last, shuffle, prob=None, generator=None - ): - self.base_sampler = None - self.batch_size = batch_size - self.shuffle = shuffle - self.drop_last = drop_last - self.generator = generator - - self.src_dataset_ls = src_dataset_ls - self.n_dataset = len(self.src_dataset_ls) - - # Dataset length - self.dataset_length = [len(ds) for ds in self.src_dataset_ls] - self.cum_dataset_length = [ - sum(self.dataset_length[:i]) for i in range(self.n_dataset) - ] # cumulative dataset length - - # BatchSamplers for each source dataset - if self.shuffle: - self.src_batch_samplers = [ - BatchSampler( - sampler=RandomSampler( - ds, replacement=False, generator=self.generator - ), - batch_size=self.batch_size, - drop_last=self.drop_last, - ) - for ds in self.src_dataset_ls - ] - else: - self.src_batch_samplers = [ - BatchSampler( - sampler=SequentialSampler(ds), - batch_size=self.batch_size, - drop_last=self.drop_last, - ) - for ds in self.src_dataset_ls - ] - self.raw_batches = [ - list(bs) for bs in self.src_batch_samplers - ] # index in original dataset - self.n_batches = [len(b) for b in self.raw_batches] - self.n_total_batch = sum(self.n_batches) - - # sampling probability - if prob is None: - # if not given, decide by dataset length - self.prob = torch.tensor(self.n_batches) / self.n_total_batch - else: - self.prob = torch.as_tensor(prob) - - def __iter__(self): - """_summary_ - - Yields: - list(int): a batch of indics, corresponding to ConcatDataset of src_dataset_ls - """ - for _ in range(self.n_total_batch): - idx_ds = torch.multinomial( - self.prob, 1, replacement=True, generator=self.generator - ).item() - # if batch list is empty, generate new list - if 0 == len(self.raw_batches[idx_ds]): - self.raw_batches[idx_ds] = list(self.src_batch_samplers[idx_ds]) - # get a batch from list - batch_raw = self.raw_batches[idx_ds].pop() - # shift by cumulative dataset length - shift = self.cum_dataset_length[idx_ds] - batch = [n + shift for n in batch_raw] - - yield batch - - def __len__(self): - return self.n_total_batch - - -# Unit test -if "__main__" == __name__: - from torch.utils.data import ConcatDataset, DataLoader, Dataset - - class SimpleDataset(Dataset): - def __init__(self, start, len) -> None: - super().__init__() - self.start = start - self.len = len - - def __len__(self): - return self.len - - def __getitem__(self, index): - return self.start + index - - dataset_1 = SimpleDataset(0, 10) - dataset_2 = SimpleDataset(200, 20) - dataset_3 = SimpleDataset(1000, 50) - - concat_dataset = ConcatDataset( - [dataset_1, dataset_2, dataset_3] - ) # will directly concatenate - - mixed_sampler = MixedBatchSampler( - src_dataset_ls=[dataset_1, dataset_2, dataset_3], - batch_size=4, - drop_last=True, - shuffle=False, - prob=[0.6, 0.3, 0.1], - generator=torch.Generator().manual_seed(0), - ) - - loader = DataLoader(concat_dataset, batch_sampler=mixed_sampler) - - for d in loader: - print(d) diff --git a/src/dataset/nyu_dataset.py b/src/dataset/nyu_dataset.py deleted file mode 100644 index aee3272f6060315c7cace2fa44a2429439f63e68..0000000000000000000000000000000000000000 --- a/src/dataset/nyu_dataset.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import torch - -from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode -from .base_normals_dataset import BaseNormalsDataset - - -class NYUDepthDataset(BaseDepthDataset): - def __init__( - self, - eigen_valid_mask: bool, - **kwargs, - ) -> None: - super().__init__( - # NYUv2 dataset parameter - min_depth=1e-3, - max_depth=10.0, - has_filled_depth=True, - name_mode=DepthFileNameMode.rgb_id, - **kwargs, - ) - - self.eigen_valid_mask = eigen_valid_mask - - def _read_depth_file(self, rel_path): - depth_in = self._read_image(rel_path) - # Decode NYU depth - depth_decoded = depth_in / 1000.0 - return depth_decoded - - def _get_valid_mask(self, depth: torch.Tensor): - valid_mask = super()._get_valid_mask(depth) - - # Eigen crop for evaluation - if self.eigen_valid_mask: - eval_mask = torch.zeros_like(valid_mask.squeeze()).bool() - eval_mask[45:471, 41:601] = 1 - eval_mask.reshape(valid_mask.shape) - valid_mask = torch.logical_and(valid_mask, eval_mask) - - return valid_mask - - -class NYUNormalsDataset(BaseNormalsDataset): - pass diff --git a/src/dataset/oasis_dataset.py b/src/dataset/oasis_dataset.py deleted file mode 100644 index 29476e70b5110663d1fc1468532b271676f87493..0000000000000000000000000000000000000000 --- a/src/dataset/oasis_dataset.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -from .base_normals_dataset import BaseNormalsDataset - - -class OasisNormalsDataset(BaseNormalsDataset): - pass diff --git a/src/dataset/scannet_dataset.py b/src/dataset/scannet_dataset.py deleted file mode 100644 index ba8ed94cc6ec1d8b46aec3522d241979fe38f6b3..0000000000000000000000000000000000000000 --- a/src/dataset/scannet_dataset.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode -from .base_normals_dataset import BaseNormalsDataset - - -class ScanNetDepthDataset(BaseDepthDataset): - def __init__( - self, - **kwargs, - ) -> None: - super().__init__( - # ScanNet data parameter - min_depth=1e-3, - max_depth=10, - has_filled_depth=False, - name_mode=DepthFileNameMode.id, - **kwargs, - ) - - def _read_depth_file(self, rel_path): - depth_in = self._read_image(rel_path) - # Decode ScanNet depth - depth_decoded = depth_in / 1000.0 - return depth_decoded - - -class ScanNetNormalsDataset(BaseNormalsDataset): - pass diff --git a/src/dataset/sintel_dataset.py b/src/dataset/sintel_dataset.py deleted file mode 100644 index 13f721b9f1163d248316807fc7976edfe0c58916..0000000000000000000000000000000000000000 --- a/src/dataset/sintel_dataset.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import torch - -from .base_normals_dataset import BaseNormalsDataset - -# Sintel original resolution -H, W = 436, 1024 - - -# crop to [436,582] --> later upsample with factor 1.1 to [480,640] -# crop off 221 pixels on both sides (1024 - 2*221 = 582) -def center_crop(img): - assert img.shape[0] == 3 or img.shape[0] == 1, "Channel dim should be first dim" - crop = 221 - - out = img[:, :, crop : W - crop] # [3,H,W] - - return out # [3,436,582] - - -class SintelNormalsDataset(BaseNormalsDataset): - def _load_rgb_data(self, rgb_rel_path): - # Read RGB data - rgb = self._read_rgb_file(rgb_rel_path) - rgb = center_crop(rgb) - rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1] - - outputs = { - "rgb_int": torch.from_numpy(rgb).int(), - "rgb_norm": torch.from_numpy(rgb_norm).float(), - } - return outputs - - def _load_normals_data(self, normals_rel_path): - outputs = {} - normals = torch.from_numpy( - self._read_normals_file(normals_rel_path) - ).float() # [3,H,W] - - # replace invalid sky values with camera facing normals - valid_normal_mask = torch.norm(normals, p=2, dim=0) > 0.1 - normals[:, ~valid_normal_mask] = torch.tensor( - [0.0, 0.0, 1.0], dtype=normals.dtype - ).view(3, 1) - # crop on both sides - outputs["normals"] = center_crop(normals) - - return outputs diff --git a/src/dataset/vkitti_dataset.py b/src/dataset/vkitti_dataset.py deleted file mode 100644 index e9294f95534699d9b7b3b5845829fba02b9c788f..0000000000000000000000000000000000000000 --- a/src/dataset/vkitti_dataset.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import torch - -from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode -from .kitti_dataset import KITTIDepthDataset - - -class VirtualKITTIDepthDataset(BaseDepthDataset): - def __init__( - self, - kitti_bm_crop, # Crop to KITTI benchmark size - valid_mask_crop, # Evaluation mask. [None, garg or eigen] - **kwargs, - ) -> None: - super().__init__( - # virtual KITTI data parameter - min_depth=1e-5, - max_depth=80, # 655.35 - has_filled_depth=False, - name_mode=DepthFileNameMode.id, - **kwargs, - ) - self.kitti_bm_crop = kitti_bm_crop - self.valid_mask_crop = valid_mask_crop - assert self.valid_mask_crop in [ - None, - "garg", # set evaluation mask according to Garg ECCV16 - "eigen", # set evaluation mask according to Eigen NIPS14 - ], f"Unknown crop type: {self.valid_mask_crop}" - - # Filter out empty depth - self.filenames = [f for f in self.filenames if "None" != f[1]] - - def _read_depth_file(self, rel_path): - depth_in = self._read_image(rel_path) - # Decode vKITTI depth - depth_decoded = depth_in / 100.0 - return depth_decoded - - def _load_rgb_data(self, rgb_rel_path): - rgb_data = super()._load_rgb_data(rgb_rel_path) - if self.kitti_bm_crop: - rgb_data = { - k: KITTIDepthDataset.kitti_benchmark_crop(v) - for k, v in rgb_data.items() - } - return rgb_data - - def _load_depth_data(self, depth_rel_path, filled_rel_path): - depth_data = super()._load_depth_data(depth_rel_path, filled_rel_path) - if self.kitti_bm_crop: - depth_data = { - k: KITTIDepthDataset.kitti_benchmark_crop(v) - for k, v in depth_data.items() - } - return depth_data - - def _get_valid_mask(self, depth: torch.Tensor): - # reference: https://github.com/cleinc/bts/blob/master/pytorch/bts_eval.py - valid_mask = super()._get_valid_mask(depth) # [1, H, W] - - if self.valid_mask_crop is not None: - eval_mask = torch.zeros_like(valid_mask.squeeze()).bool() - gt_height, gt_width = eval_mask.shape - - if "garg" == self.valid_mask_crop: - eval_mask[ - int(0.40810811 * gt_height) : int(0.99189189 * gt_height), - int(0.03594771 * gt_width) : int(0.96405229 * gt_width), - ] = 1 - elif "eigen" == self.valid_mask_crop: - eval_mask[ - int(0.3324324 * gt_height) : int(0.91351351 * gt_height), - int(0.0359477 * gt_width) : int(0.96405229 * gt_width), - ] = 1 - - eval_mask.reshape(valid_mask.shape) - valid_mask = torch.logical_and(valid_mask, eval_mask) - return valid_mask diff --git a/src/trainer/__init__.py b/src/trainer/__init__.py deleted file mode 100644 index fbee2deea204cd63de3a4e7c8b0cd0c4c1897210..0000000000000000000000000000000000000000 --- a/src/trainer/__init__.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -from .marigold_depth_trainer import MarigoldDepthTrainer -from .marigold_iid_trainer import MarigoldIIDTrainer -from .marigold_normals_trainer import MarigoldNormalsTrainer - - -trainer_cls_name_dict = { - "MarigoldDepthTrainer": MarigoldDepthTrainer, - "MarigoldIIDTrainer": MarigoldIIDTrainer, - "MarigoldNormalsTrainer": MarigoldNormalsTrainer, -} - - -def get_trainer_cls(trainer_name): - return trainer_cls_name_dict[trainer_name] diff --git a/src/trainer/marigold_depth_trainer.py b/src/trainer/marigold_depth_trainer.py deleted file mode 100644 index 749f78dfcaa8b96cde946410c71ccd2abad20852..0000000000000000000000000000000000000000 --- a/src/trainer/marigold_depth_trainer.py +++ /dev/null @@ -1,699 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import logging -import numpy as np -import os -import shutil -import torch -from PIL import Image -from datetime import datetime -from diffusers import DDPMScheduler, DDIMScheduler -from omegaconf import OmegaConf -from torch.nn import Conv2d -from torch.nn.parameter import Parameter -from torch.optim import Adam -from torch.optim.lr_scheduler import LambdaLR -from torch.utils.data import DataLoader -from tqdm import tqdm -from typing import List, Union - -from marigold.marigold_depth_pipeline import MarigoldDepthPipeline, MarigoldDepthOutput -from src.util import metric -from src.util.alignment import align_depth_least_square -from src.util.data_loader import skip_first_batches -from src.util.logging_util import tb_logger, eval_dict_to_text -from src.util.loss import get_loss -from src.util.lr_scheduler import IterExponential -from src.util.metric import MetricTracker -from src.util.multi_res_noise import multi_res_noise_like -from src.util.seeding import generate_seed_sequence - - -class MarigoldDepthTrainer: - def __init__( - self, - cfg: OmegaConf, - model: MarigoldDepthPipeline, - train_dataloader: DataLoader, - device, - out_dir_ckpt, - out_dir_eval, - out_dir_vis, - accumulation_steps: int, - val_dataloaders: List[DataLoader] = None, - vis_dataloaders: List[DataLoader] = None, - ): - self.cfg: OmegaConf = cfg - self.model: MarigoldDepthPipeline = model - self.device = device - self.seed: Union[int, None] = ( - self.cfg.trainer.init_seed - ) # used to generate seed sequence, set to `None` to train w/o seeding - self.out_dir_ckpt = out_dir_ckpt - self.out_dir_eval = out_dir_eval - self.out_dir_vis = out_dir_vis - self.train_loader: DataLoader = train_dataloader - self.val_loaders: List[DataLoader] = val_dataloaders - self.vis_loaders: List[DataLoader] = vis_dataloaders - self.accumulation_steps: int = accumulation_steps - - # Adapt input layers - if 8 != self.model.unet.config["in_channels"]: - self._replace_unet_conv_in() - - # Encode empty text prompt - self.model.encode_empty_text() - self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device) - - self.model.unet.enable_xformers_memory_efficient_attention() - - # Trainability - self.model.vae.requires_grad_(False) - self.model.text_encoder.requires_grad_(False) - self.model.unet.requires_grad_(True) - - # Optimizer !should be defined after input layer is adapted - lr = self.cfg.lr - self.optimizer = Adam(self.model.unet.parameters(), lr=lr) - - # LR scheduler - lr_func = IterExponential( - total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter, - final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio, - warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps, - ) - self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func) - - # Loss - self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs) - - # Training noise scheduler - self.training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_config( - self.model.scheduler.config, - rescale_betas_zero_snr=True, - timestep_spacing="trailing", - ) - - logging.info( - "DDPM training noise scheduler config is updated: " - f"rescale_betas_zero_snr = {self.training_noise_scheduler.config.rescale_betas_zero_snr}, " - f"timestep_spacing = {self.training_noise_scheduler.config.timestep_spacing}" - ) - - self.prediction_type = self.training_noise_scheduler.config.prediction_type - assert ( - self.prediction_type == self.model.scheduler.config.prediction_type - ), "Different prediction types" - self.scheduler_timesteps = ( - self.training_noise_scheduler.config.num_train_timesteps - ) - - # Inference DDIM scheduler (used for validation) - self.model.scheduler = DDIMScheduler.from_config( - self.training_noise_scheduler.config, - ) - - # Eval metrics - self.metric_funcs = [getattr(metric, _met) for _met in cfg.eval.eval_metrics] - - self.train_metrics = MetricTracker(*["loss"]) - self.val_metrics = MetricTracker(*[m.__name__ for m in self.metric_funcs]) - - # main metric for best checkpoint saving - self.main_val_metric = cfg.validation.main_val_metric - self.main_val_metric_goal = cfg.validation.main_val_metric_goal - - assert ( - self.main_val_metric in cfg.eval.eval_metrics - ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics." - - self.best_metric = 1e8 if "minimize" == self.main_val_metric_goal else -1e8 - - # Settings - self.max_epoch = self.cfg.max_epoch - self.max_iter = self.cfg.max_iter - self.gradient_accumulation_steps = accumulation_steps - self.gt_depth_type = self.cfg.gt_depth_type - self.gt_mask_type = self.cfg.gt_mask_type - self.save_period = self.cfg.trainer.save_period - self.backup_period = self.cfg.trainer.backup_period - self.val_period = self.cfg.trainer.validation_period - self.vis_period = self.cfg.trainer.visualization_period - - # Multi-resolution noise - self.apply_multi_res_noise = self.cfg.multi_res_noise is not None - if self.apply_multi_res_noise: - self.mr_noise_strength = self.cfg.multi_res_noise.strength - self.annealed_mr_noise = self.cfg.multi_res_noise.annealed - self.mr_noise_downscale_strategy = ( - self.cfg.multi_res_noise.downscale_strategy - ) - - # Internal variables - self.epoch = 1 - self.n_batch_in_epoch = 0 # batch index in the epoch, used when resume training - self.effective_iter = 0 # how many times optimizer.step() is called - self.in_evaluation = False - self.global_seed_sequence: List = [] # consistent global seed sequence, used to seed random generator, to ensure consistency when resuming - - def _replace_unet_conv_in(self): - # replace the first layer to accept 8 in_channels - _weight = self.model.unet.conv_in.weight.clone() # [320, 4, 3, 3] - _bias = self.model.unet.conv_in.bias.clone() # [320] - _weight = _weight.repeat((1, 2, 1, 1)) # Keep selected channel(s) - # half the activation magnitude - _weight *= 0.5 - # new conv_in channel - _n_convin_out_channel = self.model.unet.conv_in.out_channels - _new_conv_in = Conv2d( - 8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) - ) - _new_conv_in.weight = Parameter(_weight) - _new_conv_in.bias = Parameter(_bias) - self.model.unet.conv_in = _new_conv_in - logging.info("Unet conv_in layer is replaced") - # replace config - self.model.unet.config["in_channels"] = 8 - logging.info("Unet config is updated") - return - - def train(self, t_end=None): - logging.info("Start training") - - device = self.device - self.model.to(device) - - if self.in_evaluation: - logging.info( - "Last evaluation was not finished, will do evaluation before continue training." - ) - self.validate() - - self.train_metrics.reset() - accumulated_step = 0 - - for epoch in range(self.epoch, self.max_epoch + 1): - self.epoch = epoch - logging.debug(f"epoch: {self.epoch}") - - # Skip previous batches when resume - for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch): - self.model.unet.train() - - # globally consistent random generators - if self.seed is not None: - local_seed = self._get_next_seed() - rand_num_generator = torch.Generator(device=device) - rand_num_generator.manual_seed(local_seed) - else: - rand_num_generator = None - - # >>> With gradient accumulation >>> - - # Get data - rgb = batch["rgb_norm"].to(device) - depth_gt_for_latent = batch[self.gt_depth_type].to(device) - - if self.gt_mask_type is not None: - valid_mask_for_latent = batch[self.gt_mask_type].to(device) - invalid_mask = ~valid_mask_for_latent - valid_mask_down = ~torch.max_pool2d( - invalid_mask.float(), 8, 8 - ).bool() - valid_mask_down = valid_mask_down.repeat((1, 4, 1, 1)) - - batch_size = rgb.shape[0] - - with torch.no_grad(): - # Encode image - rgb_latent = self.encode_rgb(rgb) # [B, 4, h, w] - # Encode GT depth - gt_target_latent = self.encode_depth( - depth_gt_for_latent - ) # [B, 4, h, w] - - # Sample a random timestep for each image - timesteps = torch.randint( - 0, - self.scheduler_timesteps, - (batch_size,), - device=device, - generator=rand_num_generator, - ).long() # [B] - - # Sample noise - if self.apply_multi_res_noise: - strength = self.mr_noise_strength - if self.annealed_mr_noise: - # calculate strength depending on t - strength = strength * (timesteps / self.scheduler_timesteps) - noise = multi_res_noise_like( - gt_target_latent, - strength=strength, - downscale_strategy=self.mr_noise_downscale_strategy, - generator=rand_num_generator, - device=device, - ) - else: - noise = torch.randn( - gt_target_latent.shape, - device=device, - generator=rand_num_generator, - ) # [B, 4, h, w] - - # Add noise to the latents (diffusion forward process) - noisy_latents = self.training_noise_scheduler.add_noise( - gt_target_latent, noise, timesteps - ) # [B, 4, h, w] - - # Text embedding - text_embed = self.empty_text_embed.to(device).repeat( - (batch_size, 1, 1) - ) # [B, 77, 1024] - - # Concat rgb and target latents - cat_latents = torch.cat( - [rgb_latent, noisy_latents], dim=1 - ) # [B, 8, h, w] - cat_latents = cat_latents.float() - - # Predict the noise residual - model_pred = self.model.unet( - cat_latents, timesteps, text_embed - ).sample # [B, 4, h, w] - if torch.isnan(model_pred).any(): - logging.warning("model_pred contains NaN.") - - # Get the target for loss depending on the prediction type - if "sample" == self.prediction_type: - target = gt_target_latent - elif "epsilon" == self.prediction_type: - target = noise - elif "v_prediction" == self.prediction_type: - target = self.training_noise_scheduler.get_velocity( - gt_target_latent, noise, timesteps - ) - else: - raise ValueError(f"Unknown prediction type {self.prediction_type}") - - # Masked latent loss - if self.gt_mask_type is not None: - latent_loss = self.loss( - model_pred[valid_mask_down].float(), - target[valid_mask_down].float(), - ) - else: - latent_loss = self.loss(model_pred.float(), target.float()) - - loss = latent_loss.mean() - - self.train_metrics.update("loss", loss.item()) - - loss = loss / self.gradient_accumulation_steps - loss.backward() - accumulated_step += 1 - - self.n_batch_in_epoch += 1 - # Practical batch end - - # Perform optimization step - if accumulated_step >= self.gradient_accumulation_steps: - self.optimizer.step() - self.lr_scheduler.step() - self.optimizer.zero_grad() - accumulated_step = 0 - - self.effective_iter += 1 - - # Log to tensorboard - accumulated_loss = self.train_metrics.result()["loss"] - tb_logger.log_dict( - { - f"train/{k}": v - for k, v in self.train_metrics.result().items() - }, - global_step=self.effective_iter, - ) - tb_logger.writer.add_scalar( - "lr", - self.lr_scheduler.get_last_lr()[0], - global_step=self.effective_iter, - ) - tb_logger.writer.add_scalar( - "n_batch_in_epoch", - self.n_batch_in_epoch, - global_step=self.effective_iter, - ) - logging.info( - f"iter {self.effective_iter:5d} (epoch {epoch:2d}): loss={accumulated_loss:.5f}" - ) - self.train_metrics.reset() - - # Per-step callback - self._train_step_callback() - - # End of training - if self.max_iter > 0 and self.effective_iter >= self.max_iter: - self.save_checkpoint( - ckpt_name=self._get_backup_ckpt_name(), - save_train_state=False, - ) - logging.info("Training ended.") - return - # Time's up - elif t_end is not None and datetime.now() >= t_end: - self.save_checkpoint(ckpt_name="latest", save_train_state=True) - logging.info("Time is up, training paused.") - return - - torch.cuda.empty_cache() - # <<< Effective batch end <<< - - # Epoch end - self.n_batch_in_epoch = 0 - - def encode_rgb(self, image_in): - assert len(image_in.shape) == 4 and image_in.shape[1] == 3 - latent = self.model.encode_rgb(image_in) - return latent - - def encode_depth(self, depth_in): - # stack depth into 3-channel - stacked = self.stack_depth_images(depth_in) - # encode using VAE encoder - depth_latent = self.model.encode_rgb(stacked) - return depth_latent - - @staticmethod - def stack_depth_images(depth_in): - if 4 == len(depth_in.shape): - stacked = depth_in.repeat(1, 3, 1, 1) - elif 3 == len(depth_in.shape): - stacked = depth_in.unsqueeze(1).repeat(1, 3, 1, 1) - return stacked - - def _train_step_callback(self): - """Executed after every iteration""" - # Save backup (with a larger interval, without training states) - if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period: - self.save_checkpoint( - ckpt_name=self._get_backup_ckpt_name(), save_train_state=False - ) - - _is_latest_saved = False - # Validation - if self.val_period > 0 and 0 == self.effective_iter % self.val_period: - self.in_evaluation = True # flag to do evaluation in resume run if validation is not finished - self.save_checkpoint(ckpt_name="latest", save_train_state=True) - _is_latest_saved = True - self.validate() - self.in_evaluation = False - self.save_checkpoint(ckpt_name="latest", save_train_state=True) - - # Save training checkpoint (can be resumed) - if ( - self.save_period > 0 - and 0 == self.effective_iter % self.save_period - and not _is_latest_saved - ): - self.save_checkpoint(ckpt_name="latest", save_train_state=True) - - # Visualization - if self.vis_period > 0 and 0 == self.effective_iter % self.vis_period: - self.visualize() - - def validate(self): - for i, val_loader in enumerate(self.val_loaders): - val_dataset_name = val_loader.dataset.disp_name - val_metric_dict = self.validate_single_dataset( - data_loader=val_loader, metric_tracker=self.val_metrics - ) - logging.info( - f"Iter {self.effective_iter}. Validation metrics on `{val_dataset_name}`: {val_metric_dict}" - ) - tb_logger.log_dict( - {f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dict.items()}, - global_step=self.effective_iter, - ) - # save to file - eval_text = eval_dict_to_text( - val_metrics=val_metric_dict, - dataset_name=val_dataset_name, - sample_list_path=val_loader.dataset.filename_ls_path, - ) - _save_to = os.path.join( - self.out_dir_eval, - f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt", - ) - with open(_save_to, "w+") as f: - f.write(eval_text) - - # Update main eval metric - if 0 == i: - main_eval_metric = val_metric_dict[self.main_val_metric] - if ( - "minimize" == self.main_val_metric_goal - and main_eval_metric < self.best_metric - or "maximize" == self.main_val_metric_goal - and main_eval_metric > self.best_metric - ): - self.best_metric = main_eval_metric - logging.info( - f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}" - ) - # Save a checkpoint - self.save_checkpoint( - ckpt_name=self._get_backup_ckpt_name(), save_train_state=False - ) - - def visualize(self): - for val_loader in self.vis_loaders: - vis_dataset_name = val_loader.dataset.disp_name - vis_out_dir = os.path.join( - self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name - ) - os.makedirs(vis_out_dir, exist_ok=True) - _ = self.validate_single_dataset( - data_loader=val_loader, - metric_tracker=self.val_metrics, - save_to_dir=vis_out_dir, - ) - - @torch.no_grad() - def validate_single_dataset( - self, - data_loader: DataLoader, - metric_tracker: MetricTracker, - save_to_dir: str = None, - ): - self.model.to(self.device) - metric_tracker.reset() - - # Generate seed sequence for consistent evaluation - val_init_seed = self.cfg.validation.init_seed - val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader)) - - for i, batch in enumerate( - tqdm(data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"), - start=1, - ): - assert 1 == data_loader.batch_size - # Read input image - rgb_int = batch["rgb_int"] # [B, 3, H, W] - # GT depth - depth_raw_ts = batch["depth_raw_linear"].squeeze() - depth_raw = depth_raw_ts.numpy() - depth_raw_ts = depth_raw_ts.to(self.device) - valid_mask_ts = batch["valid_mask_raw"].squeeze() - valid_mask = valid_mask_ts.numpy() - valid_mask_ts = valid_mask_ts.to(self.device) - - # Random number generator - seed = val_seed_ls.pop() - if seed is None: - generator = None - else: - generator = torch.Generator(device=self.device) - generator.manual_seed(seed) - - # Predict depth - pipe_out: MarigoldDepthOutput = self.model( - rgb_int, - denoising_steps=self.cfg.validation.denoising_steps, - ensemble_size=self.cfg.validation.ensemble_size, - processing_res=self.cfg.validation.processing_res, - match_input_res=self.cfg.validation.match_input_res, - generator=generator, - batch_size=1, # use batch size 1 to increase reproducibility - color_map=None, - show_progress_bar=False, - resample_method=self.cfg.validation.resample_method, - ) - - depth_pred: np.ndarray = pipe_out.depth_np - - if "least_square" == self.cfg.eval.alignment: - depth_pred, scale, shift = align_depth_least_square( - gt_arr=depth_raw, - pred_arr=depth_pred, - valid_mask_arr=valid_mask, - return_scale_shift=True, - max_resolution=self.cfg.eval.align_max_res, - ) - else: - raise RuntimeError(f"Unknown alignment type: {self.cfg.eval.alignment}") - - # Clip to dataset min max - depth_pred = np.clip( - depth_pred, - a_min=data_loader.dataset.min_depth, - a_max=data_loader.dataset.max_depth, - ) - - # clip to d > 0 for evaluation - depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None) - - # Evaluate - sample_metric = [] - depth_pred_ts = torch.from_numpy(depth_pred).to(self.device) - - for met_func in self.metric_funcs: - _metric_name = met_func.__name__ - _metric = met_func(depth_pred_ts, depth_raw_ts, valid_mask_ts).item() - sample_metric.append(_metric.__str__()) - metric_tracker.update(_metric_name, _metric) - - # Save as 16-bit uint png - if save_to_dir is not None: - img_name = batch["rgb_relative_path"][0].replace("/", "_") - png_save_path = os.path.join(save_to_dir, f"{img_name}.png") - depth_to_save = (pipe_out.depth_np * 65535.0).astype(np.uint16) - Image.fromarray(depth_to_save).save(png_save_path, mode="I;16") - - return metric_tracker.result() - - def _get_next_seed(self): - if 0 == len(self.global_seed_sequence): - self.global_seed_sequence = generate_seed_sequence( - initial_seed=self.seed, - length=self.max_iter * self.gradient_accumulation_steps, - ) - logging.info( - f"Global seed sequence is generated, length={len(self.global_seed_sequence)}" - ) - return self.global_seed_sequence.pop() - - def save_checkpoint(self, ckpt_name, save_train_state): - ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name) - logging.info(f"Saving checkpoint to: {ckpt_dir}") - # Backup previous checkpoint - temp_ckpt_dir = None - if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir): - temp_ckpt_dir = os.path.join( - os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}" - ) - if os.path.exists(temp_ckpt_dir): - shutil.rmtree(temp_ckpt_dir, ignore_errors=True) - os.rename(ckpt_dir, temp_ckpt_dir) - logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}") - - # Save UNet - unet_path = os.path.join(ckpt_dir, "unet") - self.model.unet.save_pretrained(unet_path, safe_serialization=True) - logging.info(f"UNet is saved to: {unet_path}") - - # Save scheduler - scheduelr_path = os.path.join(ckpt_dir, "scheduler") - self.model.scheduler.save_pretrained(scheduelr_path) - logging.info(f"Scheduler is saved to: {scheduelr_path}") - - if save_train_state: - state = { - "optimizer": self.optimizer.state_dict(), - "lr_scheduler": self.lr_scheduler.state_dict(), - "config": self.cfg, - "effective_iter": self.effective_iter, - "epoch": self.epoch, - "n_batch_in_epoch": self.n_batch_in_epoch, - "best_metric": self.best_metric, - "in_evaluation": self.in_evaluation, - "global_seed_sequence": self.global_seed_sequence, - } - train_state_path = os.path.join(ckpt_dir, "trainer.ckpt") - torch.save(state, train_state_path) - # iteration indicator - f = open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w") - f.close() - - logging.info(f"Trainer state is saved to: {train_state_path}") - - # Remove temp ckpt - if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir): - shutil.rmtree(temp_ckpt_dir, ignore_errors=True) - logging.debug("Old checkpoint backup is removed.") - - def load_checkpoint( - self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True - ): - logging.info(f"Loading checkpoint from: {ckpt_path}") - # Load UNet - _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin") - self.model.unet.load_state_dict( - torch.load(_model_path, map_location=self.device) - ) - self.model.unet.to(self.device) - logging.info(f"UNet parameters are loaded from {_model_path}") - - # Load training states - if load_trainer_state: - checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt")) - self.effective_iter = checkpoint["effective_iter"] - self.epoch = checkpoint["epoch"] - self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"] - self.in_evaluation = checkpoint["in_evaluation"] - self.global_seed_sequence = checkpoint["global_seed_sequence"] - - self.best_metric = checkpoint["best_metric"] - - self.optimizer.load_state_dict(checkpoint["optimizer"]) - logging.info(f"optimizer state is loaded from {ckpt_path}") - - if resume_lr_scheduler: - self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) - logging.info(f"LR scheduler state is loaded from {ckpt_path}") - - logging.info( - f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})" - ) - return - - def _get_backup_ckpt_name(self): - return f"iter_{self.effective_iter:06d}" diff --git a/src/trainer/marigold_iid_trainer.py b/src/trainer/marigold_iid_trainer.py deleted file mode 100644 index 5250af1897aa11518b0eaa07eab9d3f647229535..0000000000000000000000000000000000000000 --- a/src/trainer/marigold_iid_trainer.py +++ /dev/null @@ -1,993 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import logging -import os -import shutil -import torch -from datetime import datetime -from diffusers import DDPMScheduler, DDIMScheduler -from omegaconf import OmegaConf -from torchmetrics.image import PeakSignalNoiseRatio -from torch.nn import Conv2d -from torch.nn.parameter import Parameter -from torch.optim import Adam -from torch.optim.lr_scheduler import LambdaLR -from torch.utils.data import DataLoader -from tqdm import tqdm -from typing import List, Union, Optional - -from marigold.marigold_iid_pipeline import MarigoldIIDPipeline, MarigoldIIDOutput -from src.util.data_loader import skip_first_batches -from src.util.image_util import ( - img_normalize, - img_float2int, - img_srgb2linear, - img_linear2srgb, -) -from src.util.logging_util import tb_logger, eval_dict_to_text -from src.util.loss import get_loss -from src.util.lr_scheduler import IterExponential -from src.util.metric import MetricTracker -from src.util.multi_res_noise import multi_res_noise_like -from src.util.seeding import generate_seed_sequence -from src.util.metric import compute_iid_metric -import torch.nn as nn -from peft import LoraConfig, get_peft_model, PeftModel - -from diffusers.loaders import ( - LoraLoaderMixin, - TextualInversionLoaderMixin, -) - -class MarigoldIIDTrainer: - def __init__( - self, - cfg: OmegaConf, - model: MarigoldIIDPipeline, - train_dataloader: DataLoader, - device, - out_dir_ckpt, - out_dir_eval, - out_dir_vis, - accumulation_steps: int, - val_dataloaders: List[DataLoader] = None, - vis_dataloaders: List[DataLoader] = None, - ): - self.cfg: OmegaConf = cfg - self.model: MarigoldIIDPipeline = model - self.device = device - self.seed: Union[int, None] = ( - self.cfg.trainer.init_seed - ) # used to generate seed sequence, set to `None` to train w/o seeding - self.out_dir_ckpt = out_dir_ckpt - self.out_dir_eval = out_dir_eval - self.out_dir_vis = out_dir_vis - self.train_loader: DataLoader = train_dataloader - self.val_loaders: List[DataLoader] = val_dataloaders - self.vis_loaders: List[DataLoader] = vis_dataloaders - self.accumulation_steps: int = accumulation_steps - - # Adapt input layers - if 4 * (model.n_targets + 1) != self.model.unet.config["in_channels"]: - self._replace_unet_conv_in_out_multimodal() - - # Encode empty text prompt - self.model.encode_empty_text() - self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device) - - self.model.unet.enable_xformers_memory_efficient_attention() - - # Trainability - self.model.vae.requires_grad_(False) - self.model.text_encoder.requires_grad_(False) - self.model.unet.requires_grad_(False) - - unet_lora_config = LoraConfig( - r=8, - lora_alpha=16, - init_lora_weights="gaussian", - target_modules=["to_k", "to_q", "to_v", "to_out.0"], - ) - self.model.unet = get_peft_model(self.model.unet, unet_lora_config) - lora_layers = filter(lambda p: p.requires_grad, self.model.unet.parameters()) - - trainable, frozen = 0, 0 - for name, param in self.model.unet.named_parameters(): - if param.requires_grad: - print(f"✅ Trainable: {name}, shape={param.shape}") - trainable += param.numel() - else: - frozen += param.numel() - - print(f"\nTotal trainable params: {trainable:,}") - print(f"Total frozen params: {frozen:,}") - print(f"Trainable ratio: {trainable / (trainable + frozen) * 100:.4f}%") - - - # Optimizer !should be defined after input layer is adapted - lr = self.cfg.lr - # self.optimizer = Adam(self.model.unet.parameters(), lr=lr) - self.optimizer = Adam(lora_layers, lr=lr) - - self.targets_to_eval_in_linear_space = ( - self.cfg.eval.targets_to_eval_in_linear_space - ) - - # LR scheduler - # lr_func = IterExponential( - # total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter, - # final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio, - # warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps, - # ) - - lr_func = lambda step: 1.0 - self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func) - - # Loss - self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs) - - # Training noise scheduler - self.training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_config( - self.model.scheduler.config, - rescale_betas_zero_snr=True, - timestep_spacing="trailing", - ) - - logging.info( - "DDPM training noise scheduler config is updated: " - f"rescale_betas_zero_snr = {self.training_noise_scheduler.config.rescale_betas_zero_snr}, " - f"timestep_spacing = {self.training_noise_scheduler.config.timestep_spacing}" - ) - - self.prediction_type = self.training_noise_scheduler.config.prediction_type - assert ( - self.prediction_type == self.model.scheduler.config.prediction_type - ), "Different prediction types" - self.scheduler_timesteps = ( - self.training_noise_scheduler.config.num_train_timesteps - ) - - # Inference DDIM scheduler (used for validation) - self.model.scheduler = DDIMScheduler.from_config( - self.training_noise_scheduler.config, - ) - - self.train_metrics = MetricTracker(*["loss"]) - val_metric_names = [ - f"{target}_{cfg.validation.main_val_metric}" - for target in model.target_names - ] - - self.val_metrics = MetricTracker(*val_metric_names) - self._val_metric = PeakSignalNoiseRatio(data_range=1.0).to(device) - - # main metric for best checkpoint saving - if "albedo" in model.target_names: - self.main_val_metric = "albedo_" + cfg.validation.main_val_metric - else: - self.main_val_metric = ( - model.target_names[0] + "_" + cfg.validation.main_val_metric - ) - - self.main_val_metric_goal = cfg.validation.main_val_metric_goal - - assert ( - self.main_val_metric in val_metric_names - ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics." - - self.best_metric = 1e8 if "minimize" == self.main_val_metric_goal else -1e8 - - # Settings - self.max_epoch = self.cfg.max_epoch - self.max_iter = self.cfg.max_iter - self.gradient_accumulation_steps = accumulation_steps - self.gt_mask_type = self.cfg.gt_mask_type - self.save_period = self.cfg.trainer.save_period - self.backup_period = self.cfg.trainer.backup_period - self.val_period = self.cfg.trainer.validation_period - self.vis_period = self.cfg.trainer.visualization_period - self.tokenizer = self.model.tokenizer - self.text_encoder = self.model.text_encoder.to(device) - - # Multi-resolution noise - self.apply_multi_res_noise = self.cfg.multi_res_noise is not None - if self.apply_multi_res_noise: - self.mr_noise_strength = self.cfg.multi_res_noise.strength - self.annealed_mr_noise = self.cfg.multi_res_noise.annealed - self.mr_noise_downscale_strategy = ( - self.cfg.multi_res_noise.downscale_strategy - ) - - # Internal variables - self.epoch = 1 - self.n_batch_in_epoch = 0 # batch index in the epoch, used when resume training - self.effective_iter = 0 # how many times optimizer.step() is called - self.in_evaluation = False - self.global_seed_sequence: List = [] # consistent global seed sequence, used to seed random generator, to ensure consistency when resuming - - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_ prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[ - -1 - ] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logging.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if ( - hasattr(self.text_encoder.config, "use_attention_mask") - and self.text_encoder.config.use_attention_mask - ): - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view( - bs_embed * num_images_per_prompt, seq_len, -1 - ) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if ( - hasattr(self.text_encoder.config, "use_attention_mask") - and self.text_encoder.config.use_attention_mask - ): - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to( - dtype=self.text_encoder.dtype, device=device - ) - - negative_prompt_embeds = negative_prompt_embeds.repeat( - 1, num_images_per_prompt, 1 - ) - negative_prompt_embeds = negative_prompt_embeds.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] - prompt_embeds = torch.cat( - [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] - ) - - return prompt_embeds - - def _replace_unet_conv_in_out_multimodal(self): - n_outputs = self.model.n_targets - # replace the first layer to accept (n+1)*4 in_channels - _weight = self.model.unet.conv_in.weight.clone() # [320, 4, 3, 3] - _bias = self.model.unet.conv_in.bias.clone() # [320] - if _weight.shape[1] == 12 or _weight.shape[1] == 16: - _weight = _weight[:, :(n_outputs + 1)*4, :, :] - _n_convin_out_channel = self.model.unet.conv_in.out_channels - _new_conv_in = Conv2d( - (n_outputs + 1) * 4, - _n_convin_out_channel, - kernel_size=(3, 3), - stride=(1, 1), - padding=(1, 1), - ) - _new_conv_in.weight = Parameter(_weight) - _new_conv_in.bias = Parameter(_bias) - self.model.unet.conv_in = _new_conv_in - logging.info("Unet conv_in layer is replaced by selecting first {} channels".format((n_outputs + 1)*4)) - - elif _weight.shape[1] == 4: - _weight = _weight.repeat((1, n_outputs + 1, 1, 1)) # Keep selected channel(s) - # scale the activation magnitude - _weight /= n_outputs + 1 - # new conv_in channel - _n_convin_out_channel = self.model.unet.conv_in.out_channels - _new_conv_in = Conv2d( - (n_outputs + 1) * 4, - _n_convin_out_channel, - kernel_size=(3, 3), - stride=(1, 1), - padding=(1, 1), - ) - _new_conv_in.weight = Parameter(_weight) - _new_conv_in.bias = Parameter(_bias) - self.model.unet.conv_in = _new_conv_in - logging.info("Unet conv_in layer is replaced by repeating the weights") - - # replace the last layer to output n*4 in_channels - _weight = self.model.unet.conv_out.weight.clone() # [4, 320, 3, 3] - _bias = self.model.unet.conv_out.bias.clone() # [4] - if (_weight.shape[0] == 8 and _bias.shape[0] == 8) or (_weight.shape[0] == 12 and _bias.shape[0] == 12): - _weight = _weight[:(n_outputs * 4), :, :, :] - _bias = _bias[:(n_outputs * 4)] - # Since we are repeating output channels, no need to scale the weights here. - _n_convout_in_channel = self.model.unet.conv_out.in_channels - _new_conv_out = Conv2d( - _n_convout_in_channel, - n_outputs * 4, - kernel_size=(3, 3), - stride=(1, 1), - padding=(1, 1), - ) - _new_conv_out.weight = Parameter(_weight) - _new_conv_out.bias = Parameter(_bias) - self.model.unet.conv_out = _new_conv_out - logging.info("Unet conv_out layer is replaced by selecting first {} channels".format(n_outputs * 4)) - elif _weight.shape[0] == 4: - _weight = _weight.repeat((n_outputs, 1, 1, 1)) - _bias = _bias.repeat(n_outputs) - # Since we are repeating output channels, no need to scale the weights here. - _n_convout_in_channel = self.model.unet.conv_out.in_channels - _new_conv_out = Conv2d( - _n_convout_in_channel, - n_outputs * 4, - kernel_size=(3, 3), - stride=(1, 1), - padding=(1, 1), - ) - _new_conv_out.weight = Parameter(_weight) - _new_conv_out.bias = Parameter(_bias) - self.model.unet.conv_out = _new_conv_out - logging.info("Unet conv_out layer is replaced by repeating the weights") - - # replace config - self.model.unet.config["in_channels"] = (n_outputs + 1) * 4 - self.model.unet.config["out_channels"] = n_outputs * 4 - logging.info("Unet config is updated") - return - - def train(self, t_end=None): - logging.info("Start training") - - device = self.device - self.model.to(device) - - if self.in_evaluation: - logging.info( - "Last evaluation was not finished, will do evaluation before continue training." - ) - self.validate() - - self.validate() - self.visualize() - - self.train_metrics.reset() - accumulated_step = 0 - - ch_target_latent = 4 * self.model.n_targets - - for epoch in range(self.epoch, self.max_epoch + 1): - self.epoch = epoch - logging.debug(f"epoch: {self.epoch}") - - # Skip previous batches when resume - for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch): - self.model.unet.train() - - # globally consistent random generators - if self.seed is not None: - local_seed = self._get_next_seed() - rand_num_generator = torch.Generator(device=device) - rand_num_generator.manual_seed(local_seed) - else: - rand_num_generator = None - - # >>> With gradient accumulation >>> - - # Get data, send to device, normalize - batch["rgb"] = img_normalize(batch["rgb"].to(device)) - for modality in self.model.target_names: - batch[modality] = img_normalize(batch[modality].to(device)) - - if self.gt_mask_type is not None: - valid_mask_for_latent = batch[self.gt_mask_type].to(device) - invalid_mask = ~valid_mask_for_latent - valid_mask_down = ~torch.max_pool2d( - invalid_mask.float(), 8, 8 - ).bool() - valid_mask_down = valid_mask_down.repeat( - (1, ch_target_latent, 1, 1) - ) - - batch_size = batch["rgb"].shape[0] - - with torch.no_grad(): - # Encode image - rgb_latent = self.encode_rgb(batch["rgb"]) # [B, 4, h, w] - # Encode iid properties - gt_target_latent = torch.cat( - [ - self.encode_rgb(batch[target_name]) - for target_name in self.model.target_names - ], - dim=1, - ) # [B, 4*n_targets, h, w] - - # Sample a random timestep for each image - timesteps = torch.randint( - 0, - self.scheduler_timesteps, - (batch_size,), - device=device, - generator=rand_num_generator, - ).long() # [B] - - # Sample noise - if self.apply_multi_res_noise: - strength = self.mr_noise_strength - if self.annealed_mr_noise: - # calculate strength depending on t - strength = strength * (timesteps / self.scheduler_timesteps) - noise = multi_res_noise_like( - gt_target_latent, - strength=strength, - downscale_strategy=self.mr_noise_downscale_strategy, - generator=rand_num_generator, - device=device, - ) - else: - noise = torch.randn( - gt_target_latent.shape, - device=device, - generator=rand_num_generator, - ) # [B, 4*n_targets, h, w] - - # Add noise to the latents (diffusion forward process) - noisy_latents = self.training_noise_scheduler.add_noise( - gt_target_latent, noise, timesteps - ) # [B, 4*n_targets, h, w] - - # Text embedding - # text_embed = self.empty_text_embed.to(device).repeat( - # (batch_size, 1, 1) - # ) # [B, 77, 1024] - - prompt_embeds = None - prompt_embeds = self._encode_prompt( - "Albedo (diffuse basecolor)", - device, - num_images_per_prompt=1, - do_classifier_free_guidance=False, - negative_prompt=None, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=None, - ) - text_embed = prompt_embeds.to(device).repeat((batch_size, 1, 1)) - - # Concat rgb and target latents - cat_latents = torch.cat( - # [rgb_latent, noisy_latents], dim=1 - [noisy_latents, rgb_latent], dim=1 - ) # [B, 4*n_targets + 4, h, w] - cat_latents = cat_latents.float() - - # Predict the noise residual - model_pred = self.model.unet( - cat_latents, timesteps, text_embed - ).sample # [B, 4*n_targets, h, w] - if torch.isnan(model_pred).any(): - logging.warning("model_pred contains NaN.") - - # Get the target for loss depending on the prediction type - if "sample" == self.prediction_type: - target = gt_target_latent - elif "epsilon" == self.prediction_type: - target = noise - elif "v_prediction" == self.prediction_type: - target = self.training_noise_scheduler.get_velocity( - gt_target_latent, noise, timesteps - ) - else: - raise ValueError(f"Unknown prediction type {self.prediction_type}") - - # Masked latent loss - if self.gt_mask_type is not None: - latent_loss = self.loss( - model_pred[valid_mask_down].float(), - target[valid_mask_down].float(), - ) - else: - latent_loss = self.loss(model_pred.float(), target.float()) - - loss = latent_loss.mean() - - self.train_metrics.update("loss", loss.item()) - - loss = loss / self.gradient_accumulation_steps - loss.backward() - accumulated_step += 1 - - self.n_batch_in_epoch += 1 - # Practical batch end - - # Perform optimization step - if accumulated_step >= self.gradient_accumulation_steps: - self.optimizer.step() - self.lr_scheduler.step() - self.optimizer.zero_grad() - accumulated_step = 0 - - self.effective_iter += 1 - - # Log to tensorboard - accumulated_loss = self.train_metrics.result()["loss"] - tb_logger.log_dict( - { - f"train/{k}": v - for k, v in self.train_metrics.result().items() - }, - global_step=self.effective_iter, - ) - tb_logger.writer.add_scalar( - "lr", - self.lr_scheduler.get_last_lr()[0], - global_step=self.effective_iter, - ) - tb_logger.writer.add_scalar( - "n_batch_in_epoch", - self.n_batch_in_epoch, - global_step=self.effective_iter, - ) - logging.info( - f"iter {self.effective_iter:5d} (epoch {epoch:2d}): loss={accumulated_loss:.5f}" - ) - self.train_metrics.reset() - - # Per-step callback - self._train_step_callback() - - # End of training - if self.max_iter > 0 and self.effective_iter >= self.max_iter: - self.save_checkpoint( - ckpt_name=self._get_backup_ckpt_name(), - save_train_state=False, - ) - logging.info("Training ended.") - return - # Time's up - elif t_end is not None and datetime.now() >= t_end: - self.save_checkpoint(ckpt_name="latest", save_train_state=True) - logging.info("Time is up, training paused.") - return - - torch.cuda.empty_cache() - # <<< Effective batch end <<< - - # Epoch end - self.n_batch_in_epoch = 0 - - def encode_rgb(self, image_in): - assert len(image_in.shape) == 4 and image_in.shape[1] == 3 - latent = self.model.encode_rgb(image_in) - return latent - - def _train_step_callback(self): - """Executed after every iteration""" - # Save backup (with a larger interval, without training states) - if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period: - self.save_checkpoint( - ckpt_name=self._get_backup_ckpt_name(), save_train_state=False - ) - - _is_latest_saved = False - # Validation - if self.val_period > 0 and 0 == self.effective_iter % self.val_period: - self.in_evaluation = True # flag to do evaluation in resume run if validation is not finished - self.save_checkpoint(ckpt_name="latest", save_train_state=True) - _is_latest_saved = True - self.validate() - self.in_evaluation = False - self.save_checkpoint(ckpt_name="latest", save_train_state=True) - - # Save training checkpoint (can be resumed) - if ( - self.save_period > 0 - and 0 == self.effective_iter % self.save_period - and not _is_latest_saved - ): - self.save_checkpoint(ckpt_name="latest", save_train_state=True) - - # Visualization - if self.vis_period > 0 and 0 == self.effective_iter % self.vis_period: - self.visualize() - - def validate(self): - for i, val_loader in enumerate(self.val_loaders): - val_dataset_name = val_loader.dataset.disp_name - val_metric_dict = self.validate_single_dataset( - data_loader=val_loader, metric_tracker=self.val_metrics - ) - logging.info( - f"Iter {self.effective_iter}. Validation metrics on `{val_dataset_name}`: {val_metric_dict}" - ) - tb_logger.log_dict( - {f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dict.items()}, - global_step=self.effective_iter, - ) - # save to file - eval_text = eval_dict_to_text( - val_metrics=val_metric_dict, - dataset_name=val_dataset_name, - sample_list_path=val_loader.dataset.filename_ls_path, - ) - _save_to = os.path.join( - self.out_dir_eval, - f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt", - ) - with open(_save_to, "w+") as f: - f.write(eval_text) - - # Update main eval metric - if 0 == i: - main_eval_metric = val_metric_dict[self.main_val_metric] - if ( - "minimize" == self.main_val_metric_goal - and main_eval_metric < self.best_metric - or "maximize" == self.main_val_metric_goal - and main_eval_metric > self.best_metric - ): - self.best_metric = main_eval_metric - logging.info( - f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}" - ) - # Save a checkpoint - self.save_checkpoint( - ckpt_name=self._get_backup_ckpt_name(), save_train_state=False - ) - else: - self.save_checkpoint( - ckpt_name=self._get_backup_ckpt_name(), save_train_state=False - ) - - def visualize(self): - for val_loader in self.vis_loaders: - vis_dataset_name = val_loader.dataset.disp_name - vis_out_dir = os.path.join( - self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name - ) - os.makedirs(vis_out_dir, exist_ok=True) - _ = self.validate_single_dataset( - data_loader=val_loader, - metric_tracker=self.val_metrics, - save_to_dir=vis_out_dir, - ) - - @torch.no_grad() - def validate_single_dataset( - self, - data_loader: DataLoader, - metric_tracker: MetricTracker, - save_to_dir: str = None, - ): - if hasattr(self.model.unet, "peft_config"): - logging.info("Validation using LoRA-adapted UNet.") - else: - logging.info("Validation using full UNet (no LoRA).") - self.model.to(self.device) - metric_tracker.reset() - - # Generate seed sequence for consistent evaluation - val_init_seed = self.cfg.validation.init_seed - val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader)) - - for i, batch in enumerate( - tqdm(data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"), - start=1, - ): - assert 1 == data_loader.batch_size - # Read input image - img_int = img_float2int(batch["rgb"]) # [3, H, W] in [0, 255], sRGB space - # GT targets - for target_name in self.model.target_names: - batch[target_name] = batch[target_name].squeeze().to(self.device) - - # Random number generator - seed = val_seed_ls.pop() - if seed is None: - generator = None - else: - generator = torch.Generator(device=self.device) - generator.manual_seed(seed) - - # Predict materials - pipe_out: MarigoldIIDOutput = self.model( - img_int, - denoising_steps=self.cfg.validation.denoising_steps, - ensemble_size=self.cfg.validation.ensemble_size, - processing_res=self.cfg.validation.processing_res, - match_input_res=self.cfg.validation.match_input_res, - generator=generator, - batch_size=1, # use batch size 1 to increase reproducibility - show_progress_bar=False, - resample_method=self.cfg.validation.resample_method, - ) - - for target_name in self.model.target_names: - target_pred = pipe_out[target_name].array - target_pred_ts = ( - torch.from_numpy(target_pred).to(self.device).unsqueeze(0) - ) - target_gt = batch[target_name].to(self.device) - if self.cfg.validation.use_mask: - _mask_name = "mask_" + target_name - valid_mask = batch[_mask_name].to(self.device) - else: - valid_mask = None - if target_name in self.targets_to_eval_in_linear_space: - target_pred_ts = img_srgb2linear(target_pred_ts) - # Hypersim GT and IID Lighting model predictions are in linear space - # We evaluate albedo in sRGB space - if len(self.model.target_names) == 3 and target_name == "albedo": - # linear --> sRGB - target_gt = img_linear2srgb(target_gt) - target_pred_ts = img_linear2srgb(target_pred_ts) - - # eval pnsr - _metric_name = self.cfg.validation.main_val_metric - _metric_target = compute_iid_metric( - target_pred_ts, - target_gt, - target_name, - "psnr", - self._val_metric, - valid_mask=valid_mask, - ) - metric_tracker.update(f"{target_name}_{_metric_name}", _metric_target) - - # Save target as image - if save_to_dir is not None: - img_name = batch["rgb_relative_path"][0].replace("/", "_") - img_name_without_ext = os.path.splitext(img_name)[0] - - # Save target - target_save_path = os.path.join( - save_to_dir, f"{img_name_without_ext}_{target_name}.png" - ) - target_img = pipe_out[target_name].image - target_img.save(target_save_path) - - return metric_tracker.result() - - def _get_next_seed(self): - if 0 == len(self.global_seed_sequence): - self.global_seed_sequence = generate_seed_sequence( - initial_seed=self.seed, - length=self.max_iter * self.gradient_accumulation_steps, - ) - logging.info( - f"Global seed sequence is generated, length={len(self.global_seed_sequence)}" - ) - return self.global_seed_sequence.pop() - - def save_checkpoint(self, ckpt_name, save_train_state): - ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name) - logging.info(f"Saving checkpoint to: {ckpt_dir}") - # Backup previous checkpoint - temp_ckpt_dir = None - if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir): - temp_ckpt_dir = os.path.join( - os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}" - ) - if os.path.exists(temp_ckpt_dir): - shutil.rmtree(temp_ckpt_dir, ignore_errors=True) - os.rename(ckpt_dir, temp_ckpt_dir) - logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}") - - # Save UNet - unet_path = os.path.join(ckpt_dir, "unet") - if isinstance(self.model.unet, PeftModel): - print('save lora model') - self.model.unet.save_pretrained(unet_path) - else: - self.model.unet.save_pretrained(unet_path, safe_serialization=True) - logging.info(f"UNet is saved to: {unet_path}") - - # Save scheduler - scheduelr_path = os.path.join(ckpt_dir, "scheduler") - self.model.scheduler.save_pretrained(scheduelr_path) - logging.info(f"Scheduler is saved to: {scheduelr_path}") - - if save_train_state: - state = { - "optimizer": self.optimizer.state_dict(), - "lr_scheduler": self.lr_scheduler.state_dict(), - "config": self.cfg, - "effective_iter": self.effective_iter, - "epoch": self.epoch, - "n_batch_in_epoch": self.n_batch_in_epoch, - "best_metric": self.best_metric, - "in_evaluation": self.in_evaluation, - "global_seed_sequence": self.global_seed_sequence, - } - train_state_path = os.path.join(ckpt_dir, "trainer.ckpt") - torch.save(state, train_state_path) - # iteration indicator - f = open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w") - f.close() - - logging.info(f"Trainer state is saved to: {train_state_path}") - - # Remove temp ckpt - if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir): - shutil.rmtree(temp_ckpt_dir, ignore_errors=True) - logging.debug("Old checkpoint backup is removed.") - - def load_checkpoint( - self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True - ): - logging.info(f"Loading checkpoint from: {ckpt_path}") - # Load UNet - _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin") - self.model.unet.load_state_dict( - torch.load(_model_path, map_location=self.device) - ) - self.model.unet.to(self.device) - logging.info(f"UNet parameters are loaded from {_model_path}") - - # Load training states - if load_trainer_state: - checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt")) - self.effective_iter = checkpoint["effective_iter"] - self.epoch = checkpoint["epoch"] - self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"] - self.in_evaluation = checkpoint["in_evaluation"] - self.global_seed_sequence = checkpoint["global_seed_sequence"] - - self.best_metric = checkpoint["best_metric"] - - self.optimizer.load_state_dict(checkpoint["optimizer"]) - logging.info(f"optimizer state is loaded from {ckpt_path}") - - if resume_lr_scheduler: - self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) - logging.info(f"LR scheduler state is loaded from {ckpt_path}") - - logging.info( - f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})" - ) - return - - def _get_backup_ckpt_name(self): - return f"iter_{self.effective_iter:06d}" diff --git a/src/trainer/marigold_normals_trainer.py b/src/trainer/marigold_normals_trainer.py deleted file mode 100644 index 79f4691a3fb19ea98d83f9fa2e5b3d8f05023f0e..0000000000000000000000000000000000000000 --- a/src/trainer/marigold_normals_trainer.py +++ /dev/null @@ -1,666 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import logging -import numpy as np -import os -import shutil -import torch -from PIL import Image -from datetime import datetime -from diffusers import DDPMScheduler, DDIMScheduler -from omegaconf import OmegaConf -from torch.nn import Conv2d -from torch.nn.parameter import Parameter -from torch.optim import Adam -from torch.optim.lr_scheduler import LambdaLR -from torch.utils.data import DataLoader -from tqdm import tqdm -from typing import List, Union - -from marigold.marigold_normals_pipeline import ( - MarigoldNormalsPipeline, - MarigoldNormalsOutput, -) -from src.util.image_util import img_chw2hwc -from src.util import metric -from src.util.data_loader import skip_first_batches -from src.util.logging_util import tb_logger, eval_dict_to_text -from src.util.loss import get_loss -from src.util.lr_scheduler import IterExponential -from src.util.metric import MetricTracker, compute_cosine_error -from src.util.multi_res_noise import multi_res_noise_like -from src.util.seeding import generate_seed_sequence - - -class MarigoldNormalsTrainer: - def __init__( - self, - cfg: OmegaConf, - model: MarigoldNormalsPipeline, - train_dataloader: DataLoader, - device, - out_dir_ckpt, - out_dir_eval, - out_dir_vis, - accumulation_steps: int, - val_dataloaders: List[DataLoader] = None, - vis_dataloaders: List[DataLoader] = None, - ): - self.cfg: OmegaConf = cfg - self.model: MarigoldNormalsPipeline = model - self.device = device - self.seed: Union[int, None] = ( - self.cfg.trainer.init_seed - ) # used to generate seed sequence, set to `None` to train w/o seeding - self.out_dir_ckpt = out_dir_ckpt - self.out_dir_eval = out_dir_eval - self.out_dir_vis = out_dir_vis - self.train_loader: DataLoader = train_dataloader - self.val_loaders: List[DataLoader] = val_dataloaders - self.vis_loaders: List[DataLoader] = vis_dataloaders - self.accumulation_steps: int = accumulation_steps - - # Adapt input layers - if 8 != self.model.unet.config["in_channels"]: - self._replace_unet_conv_in() - - # Encode empty text prompt - self.model.encode_empty_text() - self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device) - - self.model.unet.enable_xformers_memory_efficient_attention() - - # Trainability - self.model.vae.requires_grad_(False) - self.model.text_encoder.requires_grad_(False) - self.model.unet.requires_grad_(True) - - # Optimizer !should be defined after input layer is adapted - lr = self.cfg.lr - self.optimizer = Adam(self.model.unet.parameters(), lr=lr) - - # LR scheduler - lr_func = IterExponential( - total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter, - final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio, - warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps, - ) - self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func) - - # Loss - self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs) - - # Training noise scheduler - self.training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_config( - self.model.scheduler.config, - rescale_betas_zero_snr=True, - timestep_spacing="trailing", - ) - - logging.info( - "DDPM training noise scheduler config is updated: " - f"rescale_betas_zero_snr = {self.training_noise_scheduler.config.rescale_betas_zero_snr}, " - f"timestep_spacing = {self.training_noise_scheduler.config.timestep_spacing}" - ) - - self.prediction_type = self.training_noise_scheduler.config.prediction_type - assert ( - self.prediction_type == self.model.scheduler.config.prediction_type - ), "Different prediction types" - self.scheduler_timesteps = ( - self.training_noise_scheduler.config.num_train_timesteps - ) - - # Inference DDIM scheduler (used for validation) - self.model.scheduler = DDIMScheduler.from_config( - self.training_noise_scheduler.config, - ) - - # Eval metrics - self.metric_funcs = [getattr(metric, _met) for _met in cfg.eval.eval_metrics] - - self.train_metrics = MetricTracker(*["loss"]) - self.val_metrics = MetricTracker(*[m.__name__ for m in self.metric_funcs]) - - # main metric for best checkpoint saving - self.main_val_metric = cfg.validation.main_val_metric - self.main_val_metric_goal = cfg.validation.main_val_metric_goal - - assert ( - self.main_val_metric in cfg.eval.eval_metrics - ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics." - - self.best_metric = 1e8 if "minimize" == self.main_val_metric_goal else -1e8 - - # Settings - self.max_epoch = self.cfg.max_epoch - self.max_iter = self.cfg.max_iter - self.gradient_accumulation_steps = accumulation_steps - self.gt_normals_type = self.cfg.gt_normals_type - self.gt_mask_type = self.cfg.gt_mask_type - self.save_period = self.cfg.trainer.save_period - self.backup_period = self.cfg.trainer.backup_period - self.val_period = self.cfg.trainer.validation_period - self.vis_period = self.cfg.trainer.visualization_period - - # Multi-resolution noise - self.apply_multi_res_noise = self.cfg.multi_res_noise is not None - if self.apply_multi_res_noise: - self.mr_noise_strength = self.cfg.multi_res_noise.strength - self.annealed_mr_noise = self.cfg.multi_res_noise.annealed - self.mr_noise_downscale_strategy = ( - self.cfg.multi_res_noise.downscale_strategy - ) - - # Internal variables - self.epoch = 1 - self.n_batch_in_epoch = 0 # batch index in the epoch, used when resume training - self.effective_iter = 0 # how many times optimizer.step() is called - self.in_evaluation = False - self.global_seed_sequence: List = [] # consistent global seed sequence, used to seed random generator, to ensure consistency when resuming - - def _replace_unet_conv_in(self): - # replace the first layer to accept 8 in_channels - _weight = self.model.unet.conv_in.weight.clone() # [320, 4, 3, 3] - _bias = self.model.unet.conv_in.bias.clone() # [320] - _weight = _weight.repeat((1, 2, 1, 1)) # Keep selected channel(s) - # half the activation magnitude - _weight *= 0.5 - # new conv_in channel - _n_convin_out_channel = self.model.unet.conv_in.out_channels - _new_conv_in = Conv2d( - 8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) - ) - _new_conv_in.weight = Parameter(_weight) - _new_conv_in.bias = Parameter(_bias) - self.model.unet.conv_in = _new_conv_in - logging.info("Unet conv_in layer is replaced") - # replace config - self.model.unet.config["in_channels"] = 8 - logging.info("Unet config is updated") - return - - def train(self, t_end=None): - logging.info("Start training") - - device = self.device - self.model.to(device) - - if self.in_evaluation: - logging.info( - "Last evaluation was not finished, will do evaluation before continue training." - ) - self.validate() - - self.train_metrics.reset() - accumulated_step = 0 - - for epoch in range(self.epoch, self.max_epoch + 1): - self.epoch = epoch - logging.debug(f"epoch: {self.epoch}") - - # Skip previous batches when resume - for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch): - self.model.unet.train() - - # globally consistent random generators - if self.seed is not None: - local_seed = self._get_next_seed() - rand_num_generator = torch.Generator(device=device) - rand_num_generator.manual_seed(local_seed) - else: - rand_num_generator = None - - # >>> With gradient accumulation >>> - - # Get data - rgb = batch["rgb_norm"].to(device) - normals_gt_for_latent = batch[self.gt_normals_type].to(device) - - if self.gt_mask_type is not None: - valid_mask_for_latent = batch[self.gt_mask_type].to(device) - invalid_mask = ~valid_mask_for_latent - valid_mask_down = ~torch.max_pool2d( - invalid_mask.float(), 8, 8 - ).bool() - valid_mask_down = valid_mask_down.repeat((1, 4, 1, 1)) - - batch_size = rgb.shape[0] - - with torch.no_grad(): - # Encode image - rgb_latent = self.encode_rgb(rgb) # [B, 4, h, w] - # Encode GT normals - gt_target_latent = self.encode_rgb( - normals_gt_for_latent - ) # [B, 4, h, w] - - # Sample a random timestep for each image - timesteps = torch.randint( - 0, - self.scheduler_timesteps, - (batch_size,), - device=device, - generator=rand_num_generator, - ).long() # [B] - - # Sample noise - if self.apply_multi_res_noise: - strength = self.mr_noise_strength - if self.annealed_mr_noise: - # calculate strength depending on t - strength = strength * (timesteps / self.scheduler_timesteps) - noise = multi_res_noise_like( - gt_target_latent, - strength=strength, - downscale_strategy=self.mr_noise_downscale_strategy, - generator=rand_num_generator, - device=device, - ) - else: - noise = torch.randn( - gt_target_latent.shape, - device=device, - generator=rand_num_generator, - ) # [B, 4, h, w] - - # Add noise to the latents (diffusion forward process) - noisy_latents = self.training_noise_scheduler.add_noise( - gt_target_latent, noise, timesteps - ) # [B, 4, h, w] - - # Text embedding - text_embed = self.empty_text_embed.to(device).repeat( - (batch_size, 1, 1) - ) # [B, 77, 1024] - - # Concat rgb and target latents - cat_latents = torch.cat( - [rgb_latent, noisy_latents], dim=1 - ) # [B, 8, h, w] - cat_latents = cat_latents.float() - - # Predict the noise residual - model_pred = self.model.unet( - cat_latents, timesteps, text_embed - ).sample # [B, 4, h, w] - if torch.isnan(model_pred).any(): - logging.warning("model_pred contains NaN.") - - # Get the target for loss depending on the prediction type - if "sample" == self.prediction_type: - target = gt_target_latent - elif "epsilon" == self.prediction_type: - target = noise - elif "v_prediction" == self.prediction_type: - target = self.training_noise_scheduler.get_velocity( - gt_target_latent, noise, timesteps - ) - else: - raise ValueError(f"Unknown prediction type {self.prediction_type}") - - # Masked latent loss - if self.gt_mask_type is not None: - latent_loss = self.loss( - model_pred[valid_mask_down].float(), - target[valid_mask_down].float(), - ) - else: - latent_loss = self.loss(model_pred.float(), target.float()) - - loss = latent_loss.mean() - - self.train_metrics.update("loss", loss.item()) - - loss = loss / self.gradient_accumulation_steps - loss.backward() - accumulated_step += 1 - - self.n_batch_in_epoch += 1 - # Practical batch end - - # Perform optimization step - if accumulated_step >= self.gradient_accumulation_steps: - self.optimizer.step() - self.lr_scheduler.step() - self.optimizer.zero_grad() - accumulated_step = 0 - - self.effective_iter += 1 - - # Log to tensorboard - accumulated_loss = self.train_metrics.result()["loss"] - tb_logger.log_dict( - { - f"train/{k}": v - for k, v in self.train_metrics.result().items() - }, - global_step=self.effective_iter, - ) - tb_logger.writer.add_scalar( - "lr", - self.lr_scheduler.get_last_lr()[0], - global_step=self.effective_iter, - ) - tb_logger.writer.add_scalar( - "n_batch_in_epoch", - self.n_batch_in_epoch, - global_step=self.effective_iter, - ) - logging.info( - f"iter {self.effective_iter:5d} (epoch {epoch:2d}): loss={accumulated_loss:.5f}" - ) - self.train_metrics.reset() - - # Per-step callback - self._train_step_callback() - - # End of training - if self.max_iter > 0 and self.effective_iter >= self.max_iter: - self.save_checkpoint( - ckpt_name=self._get_backup_ckpt_name(), - save_train_state=False, - ) - logging.info("Training ended.") - return - # Time's up - elif t_end is not None and datetime.now() >= t_end: - self.save_checkpoint(ckpt_name="latest", save_train_state=True) - logging.info("Time is up, training paused.") - return - - torch.cuda.empty_cache() - # <<< Effective batch end <<< - - # Epoch end - self.n_batch_in_epoch = 0 - - def encode_rgb(self, image_in): - assert len(image_in.shape) == 4 and image_in.shape[1] == 3 - latent = self.model.encode_rgb(image_in) - return latent - - def _train_step_callback(self): - """Executed after every iteration""" - # Save backup (with a larger interval, without training states) - if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period: - self.save_checkpoint( - ckpt_name=self._get_backup_ckpt_name(), save_train_state=False - ) - - _is_latest_saved = False - # Validation - if self.val_period > 0 and 0 == self.effective_iter % self.val_period: - self.in_evaluation = True # flag to do evaluation in resume run if validation is not finished - self.save_checkpoint(ckpt_name="latest", save_train_state=True) - _is_latest_saved = True - self.validate() - self.in_evaluation = False - self.save_checkpoint(ckpt_name="latest", save_train_state=True) - - # Save training checkpoint (can be resumed) - if ( - self.save_period > 0 - and 0 == self.effective_iter % self.save_period - and not _is_latest_saved - ): - self.save_checkpoint(ckpt_name="latest", save_train_state=True) - - # Visualization - if self.vis_period > 0 and 0 == self.effective_iter % self.vis_period: - self.visualize() - - def validate(self): - for i, val_loader in enumerate(self.val_loaders): - val_dataset_name = val_loader.dataset.disp_name - val_metric_dict = self.validate_single_dataset( - data_loader=val_loader, metric_tracker=self.val_metrics - ) - logging.info( - f"Iter {self.effective_iter}. Validation metrics on `{val_dataset_name}`: {val_metric_dict}" - ) - tb_logger.log_dict( - {f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dict.items()}, - global_step=self.effective_iter, - ) - # save to file - eval_text = eval_dict_to_text( - val_metrics=val_metric_dict, - dataset_name=val_dataset_name, - sample_list_path=val_loader.dataset.filename_ls_path, - ) - _save_to = os.path.join( - self.out_dir_eval, - f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt", - ) - with open(_save_to, "w+") as f: - f.write(eval_text) - - # Update main eval metric - if 0 == i: - main_eval_metric = val_metric_dict[self.main_val_metric] - if ( - "minimize" == self.main_val_metric_goal - and main_eval_metric < self.best_metric - or "maximize" == self.main_val_metric_goal - and main_eval_metric > self.best_metric - ): - self.best_metric = main_eval_metric - logging.info( - f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}" - ) - # Save a checkpoint - self.save_checkpoint( - ckpt_name=self._get_backup_ckpt_name(), save_train_state=False - ) - - def visualize(self): - for val_loader in self.vis_loaders: - vis_dataset_name = val_loader.dataset.disp_name - vis_out_dir = os.path.join( - self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name - ) - os.makedirs(vis_out_dir, exist_ok=True) - _ = self.validate_single_dataset( - data_loader=val_loader, - metric_tracker=self.val_metrics, - save_to_dir=vis_out_dir, - ) - - @torch.no_grad() - def validate_single_dataset( - self, - data_loader: DataLoader, - metric_tracker: MetricTracker, - save_to_dir: str = None, - ): - self.model.to(self.device) - metric_tracker.reset() - - # Generate seed sequence for consistent evaluation - val_init_seed = self.cfg.validation.init_seed - val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader)) - - for i, batch in enumerate( - tqdm(data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"), - start=1, - ): - assert 1 == data_loader.batch_size - # Read input image - rgb_int = batch["rgb_int"] # [B, 3, H, W] - # GT normals - normals_gt = batch["normals"].to(self.device) # [B, 3, H, W] - - # Random number generator - seed = val_seed_ls.pop() - if seed is None: - generator = None - else: - generator = torch.Generator(device=self.device) - generator.manual_seed(seed) - - # Predict normals - pipe_out: MarigoldNormalsOutput = self.model( - rgb_int, - denoising_steps=self.cfg.validation.denoising_steps, - ensemble_size=self.cfg.validation.ensemble_size, - processing_res=self.cfg.validation.processing_res, - match_input_res=self.cfg.validation.match_input_res, - generator=generator, - batch_size=1, # use batch size 1 to increase reproducibility - show_progress_bar=False, - resample_method=self.cfg.validation.resample_method, - ) - - normals_pred = pipe_out.normals_np # [3, H, W] - - normals_pred_ts = ( - torch.from_numpy(normals_pred).unsqueeze(0).to(self.device) - ) - cosine_error = compute_cosine_error( - normals_pred_ts, normals_gt, masked=True - ) - sample_metric = [] - - for met_func in self.metric_funcs: - _metric_name = met_func.__name__ - _metric = met_func(cosine_error).item() - sample_metric.append(_metric.__str__()) - metric_tracker.update(_metric_name, _metric) - - # Save predicted normals as images - if save_to_dir is not None: - img_name = batch["rgb_relative_path"][0].replace("/", "_") - png_save_path = os.path.join(save_to_dir, img_name) - normals_to_save = img_chw2hwc(((normals_pred + 1) * 127.5)).astype( - np.uint8 - ) - Image.fromarray(normals_to_save).save(png_save_path) - - return metric_tracker.result() - - def _get_next_seed(self): - if 0 == len(self.global_seed_sequence): - self.global_seed_sequence = generate_seed_sequence( - initial_seed=self.seed, - length=self.max_iter * self.gradient_accumulation_steps, - ) - logging.info( - f"Global seed sequence is generated, length={len(self.global_seed_sequence)}" - ) - return self.global_seed_sequence.pop() - - def save_checkpoint(self, ckpt_name, save_train_state): - ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name) - logging.info(f"Saving checkpoint to: {ckpt_dir}") - # Backup previous checkpoint - temp_ckpt_dir = None - if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir): - temp_ckpt_dir = os.path.join( - os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}" - ) - if os.path.exists(temp_ckpt_dir): - shutil.rmtree(temp_ckpt_dir, ignore_errors=True) - os.rename(ckpt_dir, temp_ckpt_dir) - logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}") - - # Save UNet - unet_path = os.path.join(ckpt_dir, "unet") - self.model.unet.save_pretrained(unet_path, safe_serialization=True) - logging.info(f"UNet is saved to: {unet_path}") - - # Save scheduler - scheduelr_path = os.path.join(ckpt_dir, "scheduler") - self.model.scheduler.save_pretrained(scheduelr_path) - logging.info(f"Scheduler is saved to: {scheduelr_path}") - - if save_train_state: - state = { - "optimizer": self.optimizer.state_dict(), - "lr_scheduler": self.lr_scheduler.state_dict(), - "config": self.cfg, - "effective_iter": self.effective_iter, - "epoch": self.epoch, - "n_batch_in_epoch": self.n_batch_in_epoch, - "best_metric": self.best_metric, - "in_evaluation": self.in_evaluation, - "global_seed_sequence": self.global_seed_sequence, - } - train_state_path = os.path.join(ckpt_dir, "trainer.ckpt") - torch.save(state, train_state_path) - # iteration indicator - f = open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w") - f.close() - - logging.info(f"Trainer state is saved to: {train_state_path}") - - # Remove temp ckpt - if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir): - shutil.rmtree(temp_ckpt_dir, ignore_errors=True) - logging.debug("Old checkpoint backup is removed.") - - def load_checkpoint( - self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True - ): - logging.info(f"Loading checkpoint from: {ckpt_path}") - # Load UNet - _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin") - self.model.unet.load_state_dict( - torch.load(_model_path, map_location=self.device) - ) - self.model.unet.to(self.device) - logging.info(f"UNet parameters are loaded from {_model_path}") - - # Load training states - if load_trainer_state: - checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt")) - self.effective_iter = checkpoint["effective_iter"] - self.epoch = checkpoint["epoch"] - self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"] - self.in_evaluation = checkpoint["in_evaluation"] - self.global_seed_sequence = checkpoint["global_seed_sequence"] - - self.best_metric = checkpoint["best_metric"] - - self.optimizer.load_state_dict(checkpoint["optimizer"]) - logging.info(f"optimizer state is loaded from {ckpt_path}") - - if resume_lr_scheduler: - self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) - logging.info(f"LR scheduler state is loaded from {ckpt_path}") - - logging.info( - f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})" - ) - return - - def _get_backup_ckpt_name(self): - return f"iter_{self.effective_iter:06d}" diff --git a/src/util/__pycache__/image_util.cpython-310.pyc b/src/util/__pycache__/image_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e725128be63415d9a2d11979df7d9a898d94bf5 Binary files /dev/null and b/src/util/__pycache__/image_util.cpython-310.pyc differ diff --git a/src/util/__pycache__/seeding.cpython-310.pyc b/src/util/__pycache__/seeding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52fe6627a85f43ad1338ea1f72a52e2bf1ff880a Binary files /dev/null and b/src/util/__pycache__/seeding.cpython-310.pyc differ diff --git a/src/util/alignment.py b/src/util/alignment.py deleted file mode 100644 index d3a73c13d98b2e314a3bd192440e86c46502e81a..0000000000000000000000000000000000000000 --- a/src/util/alignment.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import numpy as np -import torch - - -def align_depth_least_square( - gt_arr: np.ndarray, - pred_arr: np.ndarray, - valid_mask_arr: np.ndarray, - return_scale_shift=True, - max_resolution=None, -): - ori_shape = pred_arr.shape # input shape - - gt = gt_arr.squeeze() # [H, W] - pred = pred_arr.squeeze() - valid_mask = valid_mask_arr.squeeze() - - # Downsample - if max_resolution is not None: - scale_factor = np.min(max_resolution / np.array(ori_shape[-2:])) - if scale_factor < 1: - downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest") - gt = downscaler(torch.as_tensor(gt).unsqueeze(0)).numpy() - pred = downscaler(torch.as_tensor(pred).unsqueeze(0)).numpy() - valid_mask = ( - downscaler(torch.as_tensor(valid_mask).unsqueeze(0).float()) - .bool() - .numpy() - ) - - assert ( - gt.shape == pred.shape == valid_mask.shape - ), f"{gt.shape}, {pred.shape}, {valid_mask.shape}" - - gt_masked = gt[valid_mask].reshape((-1, 1)) - pred_masked = pred[valid_mask].reshape((-1, 1)) - - # numpy solver - _ones = np.ones_like(pred_masked) - A = np.concatenate([pred_masked, _ones], axis=-1) - X = np.linalg.lstsq(A, gt_masked, rcond=None)[0] - scale, shift = X - - aligned_pred = pred_arr * scale + shift - - # restore dimensions - aligned_pred = aligned_pred.reshape(ori_shape) - - if return_scale_shift: - return aligned_pred, scale, shift - else: - return aligned_pred - - -def depth2disparity(depth, return_mask=False): - if isinstance(depth, torch.Tensor): - disparity = torch.zeros_like(depth) - elif isinstance(depth, np.ndarray): - disparity = np.zeros_like(depth) - non_negtive_mask = depth > 0 - disparity[non_negtive_mask] = 1.0 / depth[non_negtive_mask] - if return_mask: - return disparity, non_negtive_mask - else: - return disparity - - -def disparity2depth(disparity, **kwargs): - return depth2disparity(disparity, **kwargs) diff --git a/src/util/config_util.py b/src/util/config_util.py deleted file mode 100644 index 9d29bbb0599d062d7b2f9146859021b7d627847e..0000000000000000000000000000000000000000 --- a/src/util/config_util.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import omegaconf -from omegaconf import OmegaConf - - -def recursive_load_config(config_path: str) -> OmegaConf: - conf = OmegaConf.load(config_path) - - output_conf = OmegaConf.create({}) - - # Load base config. Later configs on the list will overwrite previous - base_configs = conf.get("base_config", default_value=None) - if base_configs is not None: - assert isinstance(base_configs, omegaconf.listconfig.ListConfig) - for _path in base_configs: - assert ( - _path != config_path - ), "Circulate merging, base_config should not include itself." - _base_conf = recursive_load_config(_path) - output_conf = OmegaConf.merge(output_conf, _base_conf) - - # Merge configs and overwrite values - output_conf = OmegaConf.merge(output_conf, conf) - - return output_conf - - -def find_value_in_omegaconf(search_key, config): - result_list = [] - - if isinstance(config, omegaconf.DictConfig): - for key, value in config.items(): - if key == search_key: - result_list.append(value) - elif isinstance(value, (omegaconf.DictConfig, omegaconf.ListConfig)): - result_list.extend(find_value_in_omegaconf(search_key, value)) - elif isinstance(config, omegaconf.ListConfig): - for item in config: - if isinstance(item, (omegaconf.DictConfig, omegaconf.ListConfig)): - result_list.extend(find_value_in_omegaconf(search_key, item)) - - return result_list - - -if "__main__" == __name__: - conf = recursive_load_config("config/train_base.yaml") - print(OmegaConf.to_yaml(conf)) diff --git a/src/util/data_loader.py b/src/util/data_loader.py deleted file mode 100644 index 226631f16b097129a97357143f5eda597c32f6e4..0000000000000000000000000000000000000000 --- a/src/util/data_loader.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -# Adapted from https://github.com/huggingface/accelerate/blob/e2ae254008061b3e53fc1c97f88d65743a857e75/src/accelerate/data_loader.py - -from torch.utils.data import BatchSampler, DataLoader, IterableDataset - -# kwargs of the DataLoader in min version 1.4.0. -_PYTORCH_DATALOADER_KWARGS = { - "batch_size": 1, - "shuffle": False, - "sampler": None, - "batch_sampler": None, - "num_workers": 0, - "collate_fn": None, - "pin_memory": False, - "drop_last": False, - "timeout": 0, - "worker_init_fn": None, - "multiprocessing_context": None, - "generator": None, - "prefetch_factor": 2, - "persistent_workers": False, -} - - -class SkipBatchSampler(BatchSampler): - """ - A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`. - """ - - def __init__(self, batch_sampler, skip_batches=0): - self.batch_sampler = batch_sampler - self.skip_batches = skip_batches - - def __iter__(self): - for index, samples in enumerate(self.batch_sampler): - if index >= self.skip_batches: - yield samples - - @property - def total_length(self): - return len(self.batch_sampler) - - def __len__(self): - return len(self.batch_sampler) - self.skip_batches - - -class SkipDataLoader(DataLoader): - """ - Subclass of a PyTorch `DataLoader` that will skip the first batches. - - Args: - dataset (`torch.utils.data.dataset.Dataset`): - The dataset to use to build this datalaoder. - skip_batches (`int`, *optional*, defaults to 0): - The number of batches to skip at the beginning. - kwargs: - All other keyword arguments to pass to the regular `DataLoader` initialization. - """ - - def __init__(self, dataset, skip_batches=0, **kwargs): - super().__init__(dataset, **kwargs) - self.skip_batches = skip_batches - - def __iter__(self): - for index, batch in enumerate(super().__iter__()): - if index >= self.skip_batches: - yield batch - - -def skip_first_batches(dataloader, num_batches=0): - """ - Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. - """ - dataset = dataloader.dataset - sampler_is_batch_sampler = False - if isinstance(dataset, IterableDataset): - new_batch_sampler = None - else: - sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler) - batch_sampler = ( - dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler - ) - new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches) - - # We ignore all of those since they are all dealt with by our new_batch_sampler - ignore_kwargs = [ - "batch_size", - "shuffle", - "sampler", - "batch_sampler", - "drop_last", - ] - - kwargs = { - k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k]) - for k in _PYTORCH_DATALOADER_KWARGS - if k not in ignore_kwargs - } - - # Need to provide batch_size as batch_sampler is None for Iterable dataset - if new_batch_sampler is None: - kwargs["drop_last"] = dataloader.drop_last - kwargs["batch_size"] = dataloader.batch_size - - if new_batch_sampler is None: - # Need to manually skip batches in the dataloader - dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs) - else: - dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs) - - return dataloader diff --git a/src/util/depth_transform.py b/src/util/depth_transform.py deleted file mode 100644 index 4c12308d5a14f7811d647201e286011f35f16594..0000000000000000000000000000000000000000 --- a/src/util/depth_transform.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import logging -import torch - - -def get_depth_normalizer(cfg_normalizer): - if cfg_normalizer is None: - - def identical(x): - return x - - depth_transform = identical - - elif "scale_shift_depth" == cfg_normalizer.type: - depth_transform = ScaleShiftDepthNormalizer( - norm_min=cfg_normalizer.norm_min, - norm_max=cfg_normalizer.norm_max, - min_max_quantile=cfg_normalizer.min_max_quantile, - clip=cfg_normalizer.clip, - ) - else: - raise NotImplementedError - return depth_transform - - -class DepthNormalizerBase: - is_absolute = None - far_plane_at_max = None - - def __init__( - self, - norm_min=-1.0, - norm_max=1.0, - ) -> None: - self.norm_min = norm_min - self.norm_max = norm_max - raise NotImplementedError - - def __call__(self, depth, valid_mask=None, clip=None): - raise NotImplementedError - - def denormalize(self, depth_norm, **kwargs): - # For metric depth: convert prediction back to metric depth - # For relative depth: convert prediction to [0, 1] - raise NotImplementedError - - -class ScaleShiftDepthNormalizer(DepthNormalizerBase): - """ - Use near and far plane to linearly normalize depth, - i.e. d' = d * s + t, - where near plane is mapped to `norm_min`, and far plane is mapped to `norm_max` - Near and far planes are determined by taking quantile values. - """ - - is_absolute = False - far_plane_at_max = True - - def __init__( - self, norm_min=-1.0, norm_max=1.0, min_max_quantile=0.02, clip=True - ) -> None: - self.norm_min = norm_min - self.norm_max = norm_max - self.norm_range = self.norm_max - self.norm_min - self.min_quantile = min_max_quantile - self.max_quantile = 1.0 - self.min_quantile - self.clip = clip - - def __call__(self, depth_linear, valid_mask=None, clip=None): - clip = clip if clip is not None else self.clip - - if valid_mask is None: - valid_mask = torch.ones_like(depth_linear).bool() - valid_mask = valid_mask & (depth_linear > 0) - - # Take quantiles as min and max - _min, _max = torch.quantile( - depth_linear[valid_mask], - torch.tensor([self.min_quantile, self.max_quantile]), - ) - - # scale and shift - depth_norm_linear = (depth_linear - _min) / ( - _max - _min - ) * self.norm_range + self.norm_min - - if clip: - depth_norm_linear = torch.clip( - depth_norm_linear, self.norm_min, self.norm_max - ) - - return depth_norm_linear - - def scale_back(self, depth_norm): - # scale to [0, 1] - depth_linear = (depth_norm - self.norm_min) / self.norm_range - return depth_linear - - def denormalize(self, depth_norm, **kwargs): - logging.warning(f"{self.__class__} is not revertible without GT") - return self.scale_back(depth_norm=depth_norm) diff --git a/src/util/logging_util.py b/src/util/logging_util.py deleted file mode 100644 index 31c553c599ced0366ca70b19bd99f6e270b0cb1a..0000000000000000000000000000000000000000 --- a/src/util/logging_util.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import logging -import os -import sys -import wandb -from tabulate import tabulate -from torch.utils.tensorboard import SummaryWriter - - -def config_logging(cfg_logging, out_dir=None): - file_level = cfg_logging.get("file_level", 10) - console_level = cfg_logging.get("console_level", 10) - - log_formatter = logging.Formatter(cfg_logging["format"]) - - root_logger = logging.getLogger() - root_logger.handlers.clear() - - root_logger.setLevel(min(file_level, console_level)) - - if out_dir is not None: - _logging_file = os.path.join( - out_dir, cfg_logging.get("filename", "logging.log") - ) - file_handler = logging.FileHandler(_logging_file) - file_handler.setFormatter(log_formatter) - file_handler.setLevel(file_level) - root_logger.addHandler(file_handler) - - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setFormatter(log_formatter) - console_handler.setLevel(console_level) - root_logger.addHandler(console_handler) - - # Avoid pollution by packages - logging.getLogger("PIL").setLevel(logging.INFO) - logging.getLogger("matplotlib").setLevel(logging.INFO) - - -class MyTrainingLogger: - """Tensorboard + wandb logger""" - - writer: SummaryWriter - is_initialized = False - - def __init__(self) -> None: - pass - - def set_dir(self, tb_log_dir): - if self.is_initialized: - raise ValueError("Do not initialize writer twice") - self.writer = SummaryWriter(tb_log_dir) - self.is_initialized = True - - def log_dict(self, scalar_dict, global_step, walltime=None): - for k, v in scalar_dict.items(): - self.writer.add_scalar(k, v, global_step=global_step, walltime=walltime) - return - - -# global instance -tb_logger = MyTrainingLogger() - - -# -------------- wandb tools -------------- -def init_wandb(enable: bool, **kwargs): - if enable: - run = wandb.init(sync_tensorboard=True, **kwargs) - else: - run = wandb.init(mode="disabled") - return run - - -def log_slurm_job_id(step): - global tb_logger - _jobid = os.getenv("SLURM_JOB_ID") - if _jobid is None: - _jobid = -1 - tb_logger.writer.add_scalar("job_id", int(_jobid), global_step=step) - logging.debug(f"Slurm job_id: {_jobid}") - - -def load_wandb_job_id(out_dir): - with open(os.path.join(out_dir, "WANDB_ID"), "r") as f: - wandb_id = f.read() - return wandb_id - - -def save_wandb_job_id(run, out_dir): - with open(os.path.join(out_dir, "WANDB_ID"), "w+") as f: - f.write(run.id) - - -def eval_dict_to_text(val_metrics: dict, dataset_name: str, sample_list_path: str): - eval_text = f"Evaluation metrics:\n\ - on dataset: {dataset_name}\n\ - over samples in: {sample_list_path}\n" - - eval_text += tabulate([val_metrics.keys(), val_metrics.values()]) - return eval_text diff --git a/src/util/loss.py b/src/util/loss.py deleted file mode 100644 index 43c617a147afd2f5525900e447b3262e95a4748a..0000000000000000000000000000000000000000 --- a/src/util/loss.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import torch - - -def get_loss(loss_name, **kwargs): - if "silog_mse" == loss_name: - criterion = SILogMSELoss(**kwargs) - elif "silog_rmse" == loss_name: - criterion = SILogRMSELoss(**kwargs) - elif "mse_loss" == loss_name: - criterion = torch.nn.MSELoss(**kwargs) - elif "l1_loss" == loss_name: - criterion = torch.nn.L1Loss(**kwargs) - elif "l1_loss_with_mask" == loss_name: - criterion = L1LossWithMask(**kwargs) - elif "mean_abs_rel" == loss_name: - criterion = MeanAbsRelLoss() - else: - raise NotImplementedError - - return criterion - - -class L1LossWithMask: - def __init__(self, batch_reduction=False): - self.batch_reduction = batch_reduction - - def __call__(self, depth_pred, depth_gt, valid_mask=None): - diff = depth_pred - depth_gt - if valid_mask is not None: - diff[~valid_mask] = 0 - n = valid_mask.sum((-1, -2)) - else: - n = depth_gt.shape[-2] * depth_gt.shape[-1] - - loss = torch.sum(torch.abs(diff)) / n - if self.batch_reduction: - loss = loss.mean() - return loss - - -class MeanAbsRelLoss: - def __init__(self) -> None: - # super().__init__() - pass - - def __call__(self, pred, gt): - diff = pred - gt - rel_abs = torch.abs(diff / gt) - loss = torch.mean(rel_abs, dim=0) - return loss - - -class SILogMSELoss: - def __init__(self, lamb, log_pred=True, batch_reduction=True): - """Scale Invariant Log MSE Loss - - Args: - lamb (_type_): lambda, lambda=1 -> scale invariant, lambda=0 -> L2 loss - log_pred (bool, optional): True if model prediction is logarithmic depht. Will not do log for depth_pred - """ - super(SILogMSELoss, self).__init__() - self.lamb = lamb - self.pred_in_log = log_pred - self.batch_reduction = batch_reduction - - def __call__(self, depth_pred, depth_gt, valid_mask=None): - log_depth_pred = ( - depth_pred if self.pred_in_log else torch.log(torch.clip(depth_pred, 1e-8)) - ) - log_depth_gt = torch.log(depth_gt) - - diff = log_depth_pred - log_depth_gt - if valid_mask is not None: - diff[~valid_mask] = 0 - n = valid_mask.sum((-1, -2)) - else: - n = depth_gt.shape[-2] * depth_gt.shape[-1] - - diff2 = torch.pow(diff, 2) - - first_term = torch.sum(diff2, (-1, -2)) / n - second_term = self.lamb * torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2) - loss = first_term - second_term - if self.batch_reduction: - loss = loss.mean() - return loss - - -class SILogRMSELoss: - def __init__(self, lamb, alpha, log_pred=True): - """Scale Invariant Log RMSE Loss - - Args: - lamb (_type_): lambda, lambda=1 -> scale invariant, lambda=0 -> L2 loss - alpha: - log_pred (bool, optional): True if model prediction is logarithmic depht. Will not do log for depth_pred - """ - super(SILogRMSELoss, self).__init__() - self.lamb = lamb - self.alpha = alpha - self.pred_in_log = log_pred - - def __call__(self, depth_pred, depth_gt, valid_mask): - log_depth_pred = depth_pred if self.pred_in_log else torch.log(depth_pred) - log_depth_gt = torch.log(depth_gt) - # borrowed from https://github.com/aliyun/NeWCRFs - # diff = log_depth_pred[valid_mask] - log_depth_gt[valid_mask] - # return torch.sqrt((diff ** 2).mean() - self.lamb * (diff.mean() ** 2)) * self.alpha - - diff = log_depth_pred - log_depth_gt - if valid_mask is not None: - diff[~valid_mask] = 0 - n = valid_mask.sum((-1, -2)) - else: - n = depth_gt.shape[-2] * depth_gt.shape[-1] - - diff2 = torch.pow(diff, 2) - first_term = torch.sum(diff2, (-1, -2)) / n - second_term = self.lamb * torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2) - loss = torch.sqrt(first_term - second_term).mean() * self.alpha - return loss diff --git a/src/util/lr_scheduler.py b/src/util/lr_scheduler.py deleted file mode 100644 index f54c8de73449971c828265dbfb163c39f80cf17f..0000000000000000000000000000000000000000 --- a/src/util/lr_scheduler.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import numpy as np - - -class IterExponential: - def __init__(self, total_iter_length, final_ratio, warmup_steps=0) -> None: - """ - Customized iteration-wise exponential scheduler. - Re-calculate for every step, to reduce error accumulation - - Args: - total_iter_length (int): Expected total iteration number - final_ratio (float): Expected LR ratio at n_iter = total_iter_length - """ - self.total_length = total_iter_length - self.effective_length = total_iter_length - warmup_steps - self.final_ratio = final_ratio - self.warmup_steps = warmup_steps - - def __call__(self, n_iter) -> float: - if n_iter < self.warmup_steps: - alpha = 1.0 * n_iter / self.warmup_steps - elif n_iter >= self.total_length: - alpha = self.final_ratio - else: - actual_iter = n_iter - self.warmup_steps - alpha = np.exp( - actual_iter / self.effective_length * np.log(self.final_ratio) - ) - return alpha - - -if "__main__" == __name__: - lr_scheduler = IterExponential( - total_iter_length=50000, final_ratio=0.01, warmup_steps=200 - ) - # lr_scheduler = IterExponential( - # total_iter_length=50000, final_ratio=0.01, warmup_steps=0 - # ) - - x = np.arange(100000) - alphas = [lr_scheduler(i) for i in x] - import matplotlib.pyplot as plt - - plt.plot(alphas) - plt.savefig("lr_scheduler.png") diff --git a/src/util/metric.py b/src/util/metric.py deleted file mode 100644 index 7491c9ec6e624c1dadd515a2c10a3272f34c5c5a..0000000000000000000000000000000000000000 --- a/src/util/metric.py +++ /dev/null @@ -1,338 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import numpy as np -import pandas as pd -import torch - - -# Adapted from: https://github.com/victoresque/pytorch-template/blob/master/utils/util.py -class MetricTracker: - def __init__(self, *keys, writer=None): - self.writer = writer - self._data = pd.DataFrame(index=keys, columns=["total", "counts", "average"]) - self.reset() - - def reset(self): - for col in self._data.columns: - self._data[col].values[:] = 0 - - def update(self, key, value, n=1): - if self.writer is not None: - self.writer.add_scalar(key, value) - self._data.loc[key, "total"] += value * n - self._data.loc[key, "counts"] += n - self._data.loc[key, "average"] = self._data.total[key] / self._data.counts[key] - - def avg(self, key): - return self._data.average[key] - - def result(self): - return dict(self._data.average) - - -# -------------------- Depth Metrics -------------------- - - -def abs_relative_difference(output, target, valid_mask=None): - actual_output = output - actual_target = target - abs_relative_diff = torch.abs(actual_output - actual_target) / actual_target - if valid_mask is not None: - abs_relative_diff[~valid_mask] = 0 - n = valid_mask.sum((-1, -2)) - else: - n = output.shape[-1] * output.shape[-2] - abs_relative_diff = torch.sum(abs_relative_diff, (-1, -2)) / n - return abs_relative_diff.mean() - - -def squared_relative_difference(output, target, valid_mask=None): - actual_output = output - actual_target = target - square_relative_diff = ( - torch.pow(torch.abs(actual_output - actual_target), 2) / actual_target - ) - if valid_mask is not None: - square_relative_diff[~valid_mask] = 0 - n = valid_mask.sum((-1, -2)) - else: - n = output.shape[-1] * output.shape[-2] - square_relative_diff = torch.sum(square_relative_diff, (-1, -2)) / n - return square_relative_diff.mean() - - -def rmse_linear(output, target, valid_mask=None): - actual_output = output - actual_target = target - diff = actual_output - actual_target - if valid_mask is not None: - diff[~valid_mask] = 0 - n = valid_mask.sum((-1, -2)) - else: - n = output.shape[-1] * output.shape[-2] - diff2 = torch.pow(diff, 2) - mse = torch.sum(diff2, (-1, -2)) / n - rmse = torch.sqrt(mse) - return rmse.mean() - - -def rmse_log(output, target, valid_mask=None): - diff = torch.log(output) - torch.log(target) - if valid_mask is not None: - diff[~valid_mask] = 0 - n = valid_mask.sum((-1, -2)) - else: - n = output.shape[-1] * output.shape[-2] - diff2 = torch.pow(diff, 2) - mse = torch.sum(diff2, (-1, -2)) / n # [B] - rmse = torch.sqrt(mse) - return rmse.mean() - - -def log10(output, target, valid_mask=None): - if valid_mask is not None: - diff = torch.abs( - torch.log10(output[valid_mask]) - torch.log10(target[valid_mask]) - ) - else: - diff = torch.abs(torch.log10(output) - torch.log10(target)) - return diff.mean() - - -# Adapted from: https://github.com/imran3180/depth-map-prediction/blob/master/main.py -def threshold_percentage(output, target, threshold_val, valid_mask=None): - d1 = output / target - d2 = target / output - max_d1_d2 = torch.max(d1, d2) - zero = torch.zeros(*output.shape) - one = torch.ones(*output.shape) - bit_mat = torch.where(max_d1_d2.cpu() < threshold_val, one, zero) - if valid_mask is not None: - bit_mat[~valid_mask] = 0 - n = valid_mask.sum((-1, -2)) - else: - n = output.shape[-1] * output.shape[-2] - count_mat = torch.sum(bit_mat, (-1, -2)) - threshold_mat = count_mat / n.cpu() - return threshold_mat.mean() - - -def delta1_acc(pred, gt, valid_mask): - return threshold_percentage(pred, gt, 1.25, valid_mask) - - -def delta2_acc(pred, gt, valid_mask): - return threshold_percentage(pred, gt, 1.25**2, valid_mask) - - -def delta3_acc(pred, gt, valid_mask): - return threshold_percentage(pred, gt, 1.25**3, valid_mask) - - -def i_rmse(output, target, valid_mask=None): - output_inv = 1.0 / output - target_inv = 1.0 / target - diff = output_inv - target_inv - if valid_mask is not None: - diff[~valid_mask] = 0 - n = valid_mask.sum((-1, -2)) - else: - n = output.shape[-1] * output.shape[-2] - diff2 = torch.pow(diff, 2) - mse = torch.sum(diff2, (-1, -2)) / n # [B] - rmse = torch.sqrt(mse) - return rmse.mean() - - -def silog_rmse(depth_pred, depth_gt, valid_mask=None): - diff = torch.log(depth_pred) - torch.log(depth_gt) - if valid_mask is not None: - diff[~valid_mask] = 0 - n = valid_mask.sum((-1, -2)) - else: - n = depth_gt.shape[-2] * depth_gt.shape[-1] - - diff2 = torch.pow(diff, 2) - - first_term = torch.sum(diff2, (-1, -2)) / n - second_term = torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2) - loss = torch.sqrt(torch.mean(first_term - second_term)) * 100 - return loss - - -# -------------------- Normals Metrics -------------------- - - -def compute_cosine_error(pred_norm, gt_norm, masked=False): - if len(pred_norm.shape) == 4: - pred_norm = pred_norm.squeeze(0) - if len(gt_norm.shape) == 4: - gt_norm = gt_norm.squeeze(0) - - # shape must be [3,H,W] - assert (gt_norm.shape[0] == 3) and ( - pred_norm.shape[0] == 3 - ), "Channel dim should be the first dimension!" - # mask out the zero vectors, otherwise torch.cosine_similarity computes 90° as error - if masked: - ch, h, w = gt_norm.shape - - mask = torch.norm(gt_norm, dim=0) > 0 - - pred_norm = pred_norm[:, mask.view(h, w)] - gt_norm = gt_norm[:, mask.view(h, w)] - - pred_error = torch.cosine_similarity(pred_norm, gt_norm, dim=0) - pred_error = torch.clamp(pred_error, min=-1.0, max=1.0) - pred_error = torch.acos(pred_error) * 180.0 / np.pi # (H, W) - - return ( - pred_error.view(-1).detach().cpu().numpy() - ) # flatten so can directly input to compute_normal_metrics() - - -def mean_angular_error(cosine_error): - return round(np.average(cosine_error), 4) - - -def median_angular_error(cosine_error): - return round(np.median(cosine_error), 4) - - -def rmse_angular_error(cosine_error): - num_pixels = cosine_error.shape[0] - return round(np.sqrt(np.sum(cosine_error * cosine_error) / num_pixels), 4) - - -def sub5_error(cosine_error): - num_pixels = cosine_error.shape[0] - return round(100.0 * (np.sum(cosine_error < 5) / num_pixels), 4) - - -def sub7_5_error(cosine_error): - num_pixels = cosine_error.shape[0] - return round(100.0 * (np.sum(cosine_error < 7.5) / num_pixels), 4) - - -def sub11_25_error(cosine_error): - num_pixels = cosine_error.shape[0] - return round(100.0 * (np.sum(cosine_error < 11.25) / num_pixels), 4) - - -def sub22_5_error(cosine_error): - num_pixels = cosine_error.shape[0] - return round(100.0 * (np.sum(cosine_error < 22.5) / num_pixels), 4) - - -def sub30_error(cosine_error): - num_pixels = cosine_error.shape[0] - return round(100.0 * (np.sum(cosine_error < 30) / num_pixels), 4) - - -# -------------------- IID Metrics -------------------- - - -def compute_iid_metric(pred, gt, target_name, metric_name, metric, valid_mask=None): - # Shading and residual are up-to-scale. We first scale-align them to the gt - # and map them to the range [0,1] for metric computation - if target_name == "shading" or target_name == "residual": - alignment_scale = compute_alignment_scale(pred, gt, valid_mask) - pred = alignment_scale * pred - # map to [0,1] - pred, gt = quantile_map(pred, gt, valid_mask) - - if len(pred.shape) == 3: - pred = pred.unsqueeze(0) - if len(gt.shape) == 3: - gt = gt.unsqueeze(0) - if valid_mask is not None: - if len(valid_mask.shape) == 3: - valid_mask = valid_mask.unsqueeze(0) - if metric_name == "psnr": - return metric(pred[valid_mask], gt[valid_mask]).item() - # for SSIM and LPIPs set the invalid pixels to zero - else: - invalid_mask = ~valid_mask - pred[invalid_mask] = 0 - gt[invalid_mask] = 0 - - return metric(pred, gt).item() - - -# compute least-squares alignment scale to align shading/residual prediction to gt -def compute_alignment_scale(pred, gt, valid_mask=None): - pred = pred.squeeze() - gt = gt.squeeze() - assert pred.shape[0] == 3 and gt.shape[0] == 3, "First dim should be channel dim" - - if valid_mask is not None: - valid_mask = valid_mask.squeeze() - pred = pred[valid_mask] - gt = gt[valid_mask] - - A_flattened = pred.view(-1, 1) - b_flattened = gt.view(-1, 1) - # Solve the least squares problem - x, residuals, rank, s = torch.linalg.lstsq(A_flattened.float(), b_flattened.float()) - return x - - -def quantile_map(pred, gt, valid_mask=None): - pred = pred.squeeze() - gt = gt.squeeze() - assert gt.shape[0] == 3, "channel dim must be first dim" - - percentile = 90 - brightness_nth_percentile_desired = 0.8 - brightness = 0.3 * gt[0, :, :] + 0.59 * gt[1, :, :] + 0.11 * gt[2, :, :] - - if valid_mask is not None: - valid_mask = valid_mask.squeeze() - brightness = brightness[valid_mask[0]] - else: - brightness = brightness.flatten() - - eps = 0.0001 - - brightness_nth_percentile_current = torch.quantile(brightness, percentile / 100.0) - - if brightness_nth_percentile_current < eps: - scale = 0 - else: - scale = float( - brightness_nth_percentile_desired / brightness_nth_percentile_current - ) - - # Apply scaling to ground truth and prediction - gt_mapped = torch.clamp(scale * gt, 0, 1).unsqueeze(0) # [1,3,H,W] - pred_mapped = torch.clamp(scale * pred, 0, 1).unsqueeze(0) # [1,3,H,W] - - return pred_mapped, gt_mapped diff --git a/src/util/multi_res_noise.py b/src/util/multi_res_noise.py deleted file mode 100644 index a92d14c6cb0669b464f92af04560a963ea327af3..0000000000000000000000000000000000000000 --- a/src/util/multi_res_noise.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -# Adapted from: https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2?s=31 - -import math -import torch - - -def multi_res_noise_like( - x, strength=0.9, downscale_strategy="original", generator=None, device=None -): - if torch.is_tensor(strength): - strength = strength.reshape((-1, 1, 1, 1)) - b, c, w, h = x.shape - - if device is None: - device = x.device - - up_sampler = torch.nn.Upsample(size=(w, h), mode="bilinear") - noise = torch.randn(x.shape, device=x.device, generator=generator) - - if "original" == downscale_strategy: - for i in range(10): - r = ( - torch.rand(1, generator=generator, device=device) * 2 + 2 - ) # Rather than always going 2x, - w, h = max(1, int(w / (r**i))), max(1, int(h / (r**i))) - noise += ( - up_sampler( - torch.randn(b, c, w, h, generator=generator, device=device).to(x) - ) - * strength**i - ) - if w == 1 or h == 1: - break # Lowest resolution is 1x1 - elif "every_layer" == downscale_strategy: - for i in range(int(math.log2(min(w, h)))): - w, h = max(1, int(w / 2)), max(1, int(h / 2)) - noise += ( - up_sampler( - torch.randn(b, c, w, h, generator=generator, device=device).to(x) - ) - * strength**i - ) - elif "power_of_two" == downscale_strategy: - for i in range(10): - r = 2 - w, h = max(1, int(w / (r**i))), max(1, int(h / (r**i))) - noise += ( - up_sampler( - torch.randn(b, c, w, h, generator=generator, device=device).to(x) - ) - * strength**i - ) - if w == 1 or h == 1: - break # Lowest resolution is 1x1 - elif "random_step" == downscale_strategy: - for i in range(10): - r = ( - torch.rand(1, generator=generator, device=device) * 2 + 2 - ) # Rather than always going 2x, - w, h = max(1, int(w / (r))), max(1, int(h / (r))) - noise += ( - up_sampler( - torch.randn(b, c, w, h, generator=generator, device=device).to(x) - ) - * strength**i - ) - if w == 1 or h == 1: - break # Lowest resolution is 1x1 - else: - raise ValueError(f"unknown downscale strategy: {downscale_strategy}") - - noise = noise / noise.std() # Scaled back to roughly unit variance - return noise diff --git a/src/util/slurm_util.py b/src/util/slurm_util.py deleted file mode 100644 index 77c0f479979eb815a31a3c149d546dfa68c252af..0000000000000000000000000000000000000000 --- a/src/util/slurm_util.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -------------------------------------------------------------------------- -# More information about Marigold: -# https://marigoldmonodepth.github.io -# https://marigoldcomputervision.github.io -# Efficient inference pipelines are now part of diffusers: -# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage -# https://huggingface.co/docs/diffusers/api/pipelines/marigold -# Examples of trained models and live demos: -# https://huggingface.co/prs-eth -# Related projects: -# https://rollingdepth.github.io/ -# https://marigolddepthcompletion.github.io/ -# Citation (BibTeX): -# https://github.com/prs-eth/Marigold#-citation -# If you find Marigold useful, we kindly ask you to cite our papers. -# -------------------------------------------------------------------------- - -import os - - -def is_on_slurm(): - cluster_name = os.getenv("SLURM_CLUSTER_NAME") - is_on_slurm = cluster_name is not None - return is_on_slurm - - -def get_local_scratch_dir(): - local_scratch_dir = os.getenv("TMPDIR") - return local_scratch_dir