degbo committed on
Commit
19a564f
·
1 Parent(s): 7d57572
Files changed (2) hide show
  1. marigold/__init__.py +41 -0
  2. marigold/util/image_util.py +149 -0
marigold/__init__.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # More information about Marigold:
16
+ # https://marigoldmonodepth.github.io
17
+ # https://marigoldcomputervision.github.io
18
+ # Efficient inference pipelines are now part of diffusers:
19
+ # https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
20
+ # https://huggingface.co/docs/diffusers/api/pipelines/marigold
21
+ # Examples of trained models and live demos:
22
+ # https://huggingface.co/prs-eth
23
+ # Related projects:
24
+ # https://rollingdepth.github.io/
25
+ # https://marigolddepthcompletion.github.io/
26
+ # Citation (BibTeX):
27
+ # https://github.com/prs-eth/Marigold#-citation
28
+ # If you find Marigold useful, we kindly ask you to cite our papers.
29
+ # --------------------------------------------------------------------------
30
+
31
+ from .marigold_depth_pipeline import (
32
+ MarigoldDepthPipeline,
33
+ MarigoldDepthOutput, # noqa: F401
34
+ )
35
+ from .marigold_iid_pipeline import MarigoldIIDPipeline, MarigoldIIDOutput # noqa: F401
36
+ from .marigold_normals_pipeline import (
37
+ MarigoldNormalsPipeline, # noqa: F401
38
+ MarigoldNormalsOutput, # noqa: F401
39
+ )
40
+
41
+ MarigoldPipeline = MarigoldDepthPipeline # for backward compatibility
marigold/util/image_util.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # More information about Marigold:
16
+ # https://marigoldmonodepth.github.io
17
+ # https://marigoldcomputervision.github.io
18
+ # Efficient inference pipelines are now part of diffusers:
19
+ # https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
20
+ # https://huggingface.co/docs/diffusers/api/pipelines/marigold
21
+ # Examples of trained models and live demos:
22
+ # https://huggingface.co/prs-eth
23
+ # Related projects:
24
+ # https://rollingdepth.github.io/
25
+ # https://marigolddepthcompletion.github.io/
26
+ # Citation (BibTeX):
27
+ # https://github.com/prs-eth/Marigold#-citation
28
+ # If you find Marigold useful, we kindly ask you to cite our papers.
29
+ # --------------------------------------------------------------------------
30
+
31
+ import matplotlib
32
+ import numpy as np
33
+ import torch
34
+ from torchvision.transforms import InterpolationMode
35
+ from torchvision.transforms.functional import resize
36
+
37
+
38
def colorize_depth_maps(
    depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
):
    """
    Colorize depth maps with a matplotlib colormap.

    Args:
        depth_map (`torch.Tensor` or `np.ndarray`):
            Depth values of shape [H, W] or [B, H, W]; extra singleton
            dimensions are squeezed away.
        min_depth (`float`):
            Depth value mapped to the low end of the colormap.
        max_depth (`float`):
            Depth value mapped to the high end of the colormap.
        cmap (`str`):
            Matplotlib colormap name.
        valid_mask (`torch.Tensor` or `np.ndarray`, *optional*):
            Boolean mask of spatial shape [H, W] or [B, H, W]; pixels where
            the mask is False are set to 0 (black) in the output.

    Returns:
        Colorized maps of shape [B, 3, H, W] with values in [0, 1]:
        a float `torch.Tensor` if the input was a tensor, else an
        `np.ndarray`.

    Raises:
        TypeError: If `depth_map` is neither a tensor nor an ndarray.
    """
    assert len(depth_map.shape) >= 2, "Invalid dimension"

    if isinstance(depth_map, torch.Tensor):
        # .cpu() so CUDA tensors can be converted to numpy without raising.
        depth = depth_map.detach().cpu().squeeze().numpy()
    elif isinstance(depth_map, np.ndarray):
        depth = depth_map.copy().squeeze()
    else:
        raise TypeError("img should be np.ndarray or torch.Tensor")
    # reshape to [ (B,) H, W ]
    if depth.ndim < 3:
        depth = depth[np.newaxis, :, :]

    # normalize to [0, 1], apply the colormap, and drop the alpha channel
    cm = matplotlib.colormaps[cmap]
    depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
    img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3]  # value from 0 to 1
    img_colored_np = np.rollaxis(img_colored_np, 3, 1)  # [B, 3, H, W]

    if valid_mask is not None:
        # Bug fix: test the mask's own type (the original tested depth_map,
        # which broke for a numpy mask paired with a tensor depth map).
        if isinstance(valid_mask, torch.Tensor):
            valid_mask = valid_mask.detach().cpu().numpy()
        valid_mask = valid_mask.squeeze()  # [H, W] or [B, H, W]
        if valid_mask.ndim < 3:
            valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
        else:
            valid_mask = valid_mask[:, np.newaxis, :, :]
        valid_mask = np.repeat(valid_mask, 3, axis=1)  # broadcast over RGB
        img_colored_np[~valid_mask] = 0

    if isinstance(depth_map, torch.Tensor):
        img_colored = torch.from_numpy(img_colored_np).float()
    else:
        img_colored = img_colored_np

    return img_colored
77
+
78
+
79
def chw2hwc(chw):
    """
    Convert an image from channels-first [C, H, W] to channels-last
    [H, W, C] layout.

    Accepts either a `torch.Tensor` or an `np.ndarray` and returns the
    same type; any other input raises `TypeError`.
    """
    assert 3 == len(chw.shape)
    if isinstance(chw, np.ndarray):
        return np.moveaxis(chw, 0, -1)
    if isinstance(chw, torch.Tensor):
        return torch.permute(chw, (1, 2, 0))
    raise TypeError("img should be np.ndarray or torch.Tensor")
88
+
89
+
90
def resize_max_res(
    img: torch.Tensor,
    max_edge_resolution: int,
    resample_method: InterpolationMode = InterpolationMode.BILINEAR,
) -> torch.Tensor:
    """
    Rescale an image so that its longer edge equals `max_edge_resolution`,
    preserving the aspect ratio.

    Args:
        img (`torch.Tensor`):
            Image tensor of shape [B, C, H, W].
        max_edge_resolution (`int`):
            Target length of the longer edge, in pixels.
        resample_method (`InterpolationMode`):
            torchvision interpolation mode used for resizing.

    Returns:
        `torch.Tensor`: The resized image.
    """
    assert 4 == img.dim(), f"Invalid input shape {img.shape}"

    height, width = img.shape[-2:]
    # One common factor for both edges keeps the aspect ratio. NOTE(review):
    # when both edges are already below the limit this factor is > 1, i.e.
    # the image is upscaled — presumably intentional; confirm with callers.
    scale = min(max_edge_resolution / width, max_edge_resolution / height)
    target_size = (int(height * scale), int(width * scale))

    return resize(img, target_size, resample_method, antialias=True)
121
+
122
+
123
def get_tv_resample_method(method_str: str) -> InterpolationMode:
    """
    Map a resampling-method name to a torchvision `InterpolationMode`.

    Args:
        method_str (`str`):
            One of "bilinear", "bicubic", "nearest", or "nearest-exact";
            both "nearest" variants map to `NEAREST_EXACT`.

    Returns:
        `InterpolationMode`: The corresponding interpolation mode.

    Raises:
        ValueError: If `method_str` is not a recognized method name.
    """
    resample_method_dict = {
        "bilinear": InterpolationMode.BILINEAR,
        "bicubic": InterpolationMode.BICUBIC,
        "nearest": InterpolationMode.NEAREST_EXACT,
        "nearest-exact": InterpolationMode.NEAREST_EXACT,
    }
    resample_method = resample_method_dict.get(method_str, None)
    if resample_method is None:
        # Bug fix: report the offending input; the original interpolated
        # `resample_method`, which is always None on this path.
        raise ValueError(f"Unknown resampling method: {method_str}")
    else:
        return resample_method
135
+
136
+
137
def float2int(img):
    """
    Convert a float image in [0, 1] to an 8-bit image in [0, 255].

    Values outside [0, 1] are clamped before the cast so they cannot wrap
    around during the unsigned-integer conversion (the original code let
    e.g. 2.0 wrap to 254 and negatives wrap to large values).

    Args:
        img: `np.ndarray` with float values, or any `torch.Tensor`
            (the torch branch is the fallback for non-ndarray inputs).

    Returns:
        Image of the same container type with dtype uint8.
    """
    if isinstance(img, np.ndarray):
        return (img * 255.0).clip(0, 255).astype(np.uint8)
    else:
        return (img * 255.0).clamp(0, 255).to(torch.uint8)
142
+
143
+
144
def srgb2linear(img):
    """
    Convert an sRGB image to linear light using the simple gamma-2.2
    approximation (not the piecewise sRGB transfer function).
    """
    gamma = 2.2
    return img**gamma
146
+
147
+
148
def linear2srgb(img):
    """
    Convert a linear-light image to sRGB using the simple gamma-2.2
    approximation (inverse of `srgb2linear`).
    """
    inv_gamma = 1.0 / 2.2
    return img**inv_gamma