Spaces:
Sleeping
Sleeping
Delete Marigold
Browse files- Marigold/README.md +0 -1
- Marigold/marigold/__init__.py +0 -21
- Marigold/marigold/marigold_pipeline.py +0 -534
- Marigold/marigold/util/__init__.py +0 -0
- Marigold/marigold/util/batchsize.py +0 -81
- Marigold/marigold/util/ensemble.py +0 -132
- Marigold/marigold/util/image_util.py +0 -121
Marigold/README.md
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
Code is copied from https://github.com/prs-eth/Marigold. Modifications are indicated within the code.
|
|
|
|
|
|
Marigold/marigold/__init__.py
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
# --------------------------------------------------------------------------
|
| 15 |
-
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 16 |
-
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
|
| 17 |
-
# More information about the method can be found at https://marigoldmonodepth.github.io
|
| 18 |
-
# --------------------------------------------------------------------------
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
from .marigold_pipeline import MarigoldPipeline, MarigoldDepthOutput # noqa: F401
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Marigold/marigold/marigold_pipeline.py
DELETED
|
@@ -1,534 +0,0 @@
|
|
| 1 |
-
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
# --------------------------------------------------------------------------
|
| 15 |
-
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 16 |
-
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
|
| 17 |
-
# More information about the method can be found at https://marigoldmonodepth.github.io
|
| 18 |
-
# --------------------------------------------------------------------------
|
| 19 |
-
|
| 20 |
-
# @GonzaloMartinGarcia
|
| 21 |
-
# This file is a modified version of the original Marigold pipeline file.
|
| 22 |
-
# Based on GeoWizard, we added the option to sample surface normals, marked with # add.
|
| 23 |
-
|
| 24 |
-
from typing import Dict, Union
|
| 25 |
-
|
| 26 |
-
import numpy as np
|
| 27 |
-
import torch
|
| 28 |
-
from diffusers import (
|
| 29 |
-
AutoencoderKL,
|
| 30 |
-
DDIMScheduler,
|
| 31 |
-
DiffusionPipeline,
|
| 32 |
-
LCMScheduler,
|
| 33 |
-
UNet2DConditionModel,
|
| 34 |
-
DDPMScheduler,
|
| 35 |
-
)
|
| 36 |
-
from diffusers.utils import BaseOutput
|
| 37 |
-
from PIL import Image
|
| 38 |
-
from torchvision.transforms.functional import resize, pil_to_tensor
|
| 39 |
-
from torchvision.transforms import InterpolationMode
|
| 40 |
-
from torch.utils.data import DataLoader, TensorDataset
|
| 41 |
-
from tqdm.auto import tqdm
|
| 42 |
-
from transformers import CLIPTextModel, CLIPTokenizer
|
| 43 |
-
|
| 44 |
-
from .util.batchsize import find_batch_size
|
| 45 |
-
from .util.ensemble import ensemble_depths
|
| 46 |
-
from .util.image_util import (
|
| 47 |
-
chw2hwc,
|
| 48 |
-
colorize_depth_maps,
|
| 49 |
-
get_tv_resample_method,
|
| 50 |
-
resize_max_res,
|
| 51 |
-
)
|
| 52 |
-
|
| 53 |
-
# add
|
| 54 |
-
import random
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
# add
|
| 58 |
-
# Surface Normals Ensamble from the GeoWizard github repository (https://github.com/fuxiao0719/GeoWizard)
|
| 59 |
-
def ensemble_normals(input_images:torch.Tensor):
|
| 60 |
-
normal_preds = input_images
|
| 61 |
-
bsz, d, h, w = normal_preds.shape
|
| 62 |
-
normal_preds = normal_preds / (torch.norm(normal_preds, p=2, dim=1).unsqueeze(1)+1e-5)
|
| 63 |
-
phi = torch.atan2(normal_preds[:,1,:,:], normal_preds[:,0,:,:]).mean(dim=0)
|
| 64 |
-
theta = torch.atan2(torch.norm(normal_preds[:,:2,:,:], p=2, dim=1), normal_preds[:,2,:,:]).mean(dim=0)
|
| 65 |
-
normal_pred = torch.zeros((d,h,w)).to(normal_preds)
|
| 66 |
-
normal_pred[0,:,:] = torch.sin(theta) * torch.cos(phi)
|
| 67 |
-
normal_pred[1,:,:] = torch.sin(theta) * torch.sin(phi)
|
| 68 |
-
normal_pred[2,:,:] = torch.cos(theta)
|
| 69 |
-
angle_error = torch.acos(torch.clip(torch.cosine_similarity(normal_pred[None], normal_preds, dim=1),-0.999, 0.999))
|
| 70 |
-
normal_idx = torch.argmin(angle_error.reshape(bsz,-1).sum(-1))
|
| 71 |
-
return normal_preds[normal_idx], None
|
| 72 |
-
|
| 73 |
-
# add
|
| 74 |
-
# Pyramid nosie from
|
| 75 |
-
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2?s=31
|
| 76 |
-
def pyramid_noise_like(x, discount=0.9):
|
| 77 |
-
b, c, w, h = x.shape
|
| 78 |
-
u = torch.nn.Upsample(size=(w, h), mode='bilinear')
|
| 79 |
-
noise = torch.randn_like(x)
|
| 80 |
-
for i in range(10):
|
| 81 |
-
r = random.random()*2+2
|
| 82 |
-
w, h = max(1, int(w/(r**i))), max(1, int(h/(r**i)))
|
| 83 |
-
noise += u(torch.randn(b, c, w, h).to(x)) * discount**i
|
| 84 |
-
if w==1 or h==1:
|
| 85 |
-
break
|
| 86 |
-
return noise / noise.std()
|
| 87 |
-
|
| 88 |
-
class MarigoldDepthOutput(BaseOutput):
|
| 89 |
-
"""
|
| 90 |
-
Output class for Marigold monocular depth prediction pipeline.
|
| 91 |
-
|
| 92 |
-
Args:
|
| 93 |
-
depth_np (`np.ndarray`):
|
| 94 |
-
Predicted depth map, with depth values in the range of [0, 1].
|
| 95 |
-
depth_colored (`PIL.Image.Image`):
|
| 96 |
-
Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
|
| 97 |
-
uncertainty (`None` or `np.ndarray`):
|
| 98 |
-
Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling.
|
| 99 |
-
normal_np (`np.ndarray`):
|
| 100 |
-
Predicted normal map, with normal vectors in the range of [-1, 1].
|
| 101 |
-
normal_colored (`PIL.Image.Image`):
|
| 102 |
-
Colorized normal map
|
| 103 |
-
"""
|
| 104 |
-
|
| 105 |
-
depth_np: np.ndarray
|
| 106 |
-
depth_colored: Union[None, Image.Image]
|
| 107 |
-
uncertainty: Union[None, np.ndarray]
|
| 108 |
-
# add
|
| 109 |
-
normal_np: np.ndarray
|
| 110 |
-
normal_colored: Union[None, Image.Image]
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
class MarigoldPipeline(DiffusionPipeline):
|
| 114 |
-
"""
|
| 115 |
-
Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
|
| 116 |
-
|
| 117 |
-
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 118 |
-
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 119 |
-
|
| 120 |
-
Args:
|
| 121 |
-
unet (`UNet2DConditionModel`):
|
| 122 |
-
Conditional U-Net to denoise the depth latent, conditioned on image latent.
|
| 123 |
-
vae (`AutoencoderKL`):
|
| 124 |
-
Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
|
| 125 |
-
to and from latent representations.
|
| 126 |
-
scheduler (`DDIMScheduler`):
|
| 127 |
-
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
|
| 128 |
-
text_encoder (`CLIPTextModel`):
|
| 129 |
-
Text-encoder, for empty text embedding.
|
| 130 |
-
tokenizer (`CLIPTokenizer`):
|
| 131 |
-
CLIP tokenizer.
|
| 132 |
-
"""
|
| 133 |
-
|
| 134 |
-
rgb_latent_scale_factor = 0.18215
|
| 135 |
-
depth_latent_scale_factor = 0.18215
|
| 136 |
-
|
| 137 |
-
def __init__(
|
| 138 |
-
self,
|
| 139 |
-
unet: UNet2DConditionModel,
|
| 140 |
-
vae: AutoencoderKL,
|
| 141 |
-
scheduler: Union[DDIMScheduler,DDPMScheduler,LCMScheduler],
|
| 142 |
-
text_encoder: CLIPTextModel,
|
| 143 |
-
tokenizer: CLIPTokenizer,
|
| 144 |
-
):
|
| 145 |
-
super().__init__()
|
| 146 |
-
|
| 147 |
-
self.register_modules(
|
| 148 |
-
unet=unet,
|
| 149 |
-
vae=vae,
|
| 150 |
-
scheduler=scheduler,
|
| 151 |
-
text_encoder=text_encoder,
|
| 152 |
-
tokenizer=tokenizer,
|
| 153 |
-
)
|
| 154 |
-
|
| 155 |
-
self.empty_text_embed = None
|
| 156 |
-
|
| 157 |
-
@torch.no_grad()
|
| 158 |
-
def __call__(
|
| 159 |
-
self,
|
| 160 |
-
input_image: Union[Image.Image, torch.Tensor],
|
| 161 |
-
denoising_steps: int = 10,
|
| 162 |
-
ensemble_size: int = 10,
|
| 163 |
-
processing_res: int = 768,
|
| 164 |
-
match_input_res: bool = True,
|
| 165 |
-
resample_method: str = "bilinear",
|
| 166 |
-
batch_size: int = 0,
|
| 167 |
-
color_map: str = "Spectral",
|
| 168 |
-
show_progress_bar: bool = True,
|
| 169 |
-
ensemble_kwargs: Dict = None,
|
| 170 |
-
# add
|
| 171 |
-
noise="gaussian",
|
| 172 |
-
normals=False,
|
| 173 |
-
) -> MarigoldDepthOutput:
|
| 174 |
-
"""
|
| 175 |
-
Function invoked when calling the pipeline.
|
| 176 |
-
|
| 177 |
-
Args:
|
| 178 |
-
input_image (`Image`):
|
| 179 |
-
Input RGB (or gray-scale) image.
|
| 180 |
-
processing_res (`int`, *optional*, defaults to `768`):
|
| 181 |
-
Maximum resolution of processing.
|
| 182 |
-
If set to 0: will not resize at all.
|
| 183 |
-
match_input_res (`bool`, *optional*, defaults to `True`):
|
| 184 |
-
Resize depth prediction to match input resolution.
|
| 185 |
-
Only valid if `processing_res` > 0.
|
| 186 |
-
resample_method: (`str`, *optional*, defaults to `bilinear`):
|
| 187 |
-
Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`.
|
| 188 |
-
denoising_steps (`int`, *optional*, defaults to `10`):
|
| 189 |
-
Number of diffusion denoising steps (DDIM) during inference.
|
| 190 |
-
ensemble_size (`int`, *optional*, defaults to `10`):
|
| 191 |
-
Number of predictions to be ensembled.
|
| 192 |
-
batch_size (`int`, *optional*, defaults to `0`):
|
| 193 |
-
Inference batch size, no bigger than `num_ensemble`.
|
| 194 |
-
If set to 0, the script will automatically decide the proper batch size.
|
| 195 |
-
show_progress_bar (`bool`, *optional*, defaults to `True`):
|
| 196 |
-
Display a progress bar of diffusion denoising.
|
| 197 |
-
color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
|
| 198 |
-
Colormap used to colorize the depth map.
|
| 199 |
-
ensemble_kwargs (`dict`, *optional*, defaults to `None`):
|
| 200 |
-
Arguments for detailed ensembling settings.
|
| 201 |
-
noise (`str`, *optional*, defaults to `gaussian`):
|
| 202 |
-
Type of noise to be used for the initial depth map.
|
| 203 |
-
Can be one of `gaussian`, `pyramid`, `zeros`.
|
| 204 |
-
normals (`bool`, *optional*, defaults to `False`):
|
| 205 |
-
If `True`, the pipeline will predict surface normals instead of depth maps.
|
| 206 |
-
Returns:
|
| 207 |
-
`MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
|
| 208 |
-
- **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
|
| 209 |
-
- **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`
|
| 210 |
-
- **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
|
| 211 |
-
coming from ensembling. None if `ensemble_size = 1`
|
| 212 |
-
- **normal_np** (`np.ndarray`) Predicted normal map, with normal vectors in the range of [-1, 1]
|
| 213 |
-
- **normal_colored** (`PIL.Image.Image`) Colorized normal map
|
| 214 |
-
"""
|
| 215 |
-
|
| 216 |
-
assert processing_res >= 0
|
| 217 |
-
assert ensemble_size >= 1
|
| 218 |
-
|
| 219 |
-
resample_method: InterpolationMode = get_tv_resample_method(resample_method)
|
| 220 |
-
|
| 221 |
-
# ----------------- Image Preprocess -----------------
|
| 222 |
-
|
| 223 |
-
# Convert to torch tensor
|
| 224 |
-
if isinstance(input_image, Image.Image):
|
| 225 |
-
input_image = input_image.convert("RGB")
|
| 226 |
-
rgb = pil_to_tensor(input_image) # [H, W, rgb] -> [rgb, H, W]
|
| 227 |
-
elif isinstance(input_image, torch.Tensor):
|
| 228 |
-
rgb = input_image.squeeze()
|
| 229 |
-
else:
|
| 230 |
-
raise TypeError(f"Unknown input type: {type(input_image) = }")
|
| 231 |
-
input_size = rgb.shape
|
| 232 |
-
assert (
|
| 233 |
-
3 == rgb.dim() and 3 == input_size[0]
|
| 234 |
-
), f"Wrong input shape {input_size}, expected [rgb, H, W]"
|
| 235 |
-
|
| 236 |
-
# Resize image
|
| 237 |
-
if processing_res > 0:
|
| 238 |
-
rgb = resize_max_res(
|
| 239 |
-
rgb,
|
| 240 |
-
max_edge_resolution=processing_res,
|
| 241 |
-
resample_method=resample_method,
|
| 242 |
-
)
|
| 243 |
-
|
| 244 |
-
# Normalize rgb values
|
| 245 |
-
rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
|
| 246 |
-
rgb_norm = rgb_norm.to(self.dtype)
|
| 247 |
-
assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
|
| 248 |
-
|
| 249 |
-
# ----------------- Predicting depth/normal --------------
|
| 250 |
-
|
| 251 |
-
# Batch repeated input image
|
| 252 |
-
duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
|
| 253 |
-
single_rgb_dataset = TensorDataset(duplicated_rgb)
|
| 254 |
-
if batch_size > 0:
|
| 255 |
-
_bs = batch_size
|
| 256 |
-
else:
|
| 257 |
-
_bs = find_batch_size(
|
| 258 |
-
ensemble_size=ensemble_size,
|
| 259 |
-
input_res=max(rgb_norm.shape[1:]),
|
| 260 |
-
dtype=self.dtype,
|
| 261 |
-
)
|
| 262 |
-
|
| 263 |
-
single_rgb_loader = DataLoader(
|
| 264 |
-
single_rgb_dataset, batch_size=_bs, shuffle=False
|
| 265 |
-
)
|
| 266 |
-
|
| 267 |
-
# load iterator
|
| 268 |
-
pred_ls = []
|
| 269 |
-
if show_progress_bar:
|
| 270 |
-
iterable = tqdm(
|
| 271 |
-
single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
|
| 272 |
-
)
|
| 273 |
-
else:
|
| 274 |
-
iterable = single_rgb_loader
|
| 275 |
-
|
| 276 |
-
# inference (batched)
|
| 277 |
-
for batch in iterable:
|
| 278 |
-
(batched_img,) = batch
|
| 279 |
-
pred_raw = self.single_infer(
|
| 280 |
-
rgb_in=batched_img,
|
| 281 |
-
num_inference_steps=denoising_steps,
|
| 282 |
-
show_pbar=show_progress_bar,
|
| 283 |
-
# add
|
| 284 |
-
noise=noise,
|
| 285 |
-
normals=normals,
|
| 286 |
-
)
|
| 287 |
-
pred_ls.append(pred_raw.detach())
|
| 288 |
-
preds = torch.concat(pred_ls, dim=0).squeeze()
|
| 289 |
-
torch.cuda.empty_cache() # clear vram cache for ensembling
|
| 290 |
-
|
| 291 |
-
# ----------------- Test-time ensembling -----------------
|
| 292 |
-
|
| 293 |
-
if ensemble_size > 1: # add
|
| 294 |
-
pred, pred_uncert = ensemble_normals(preds) if normals else ensemble_depths(preds, **(ensemble_kwargs or {}))
|
| 295 |
-
else:
|
| 296 |
-
pred = preds
|
| 297 |
-
pred_uncert = None
|
| 298 |
-
|
| 299 |
-
# ----------------- Post processing -----------------
|
| 300 |
-
|
| 301 |
-
if normals:
|
| 302 |
-
# add
|
| 303 |
-
# Normalizae normal vectors to unit length
|
| 304 |
-
pred /= (torch.norm(pred, p=2, dim=0, keepdim=True)+1e-5)
|
| 305 |
-
else:
|
| 306 |
-
# Scale relative prediction to [0, 1]
|
| 307 |
-
min_d = torch.min(pred)
|
| 308 |
-
max_d = torch.max(pred)
|
| 309 |
-
if max_d == min_d:
|
| 310 |
-
pred = torch.zeros_like(pred)
|
| 311 |
-
else:
|
| 312 |
-
pred = (pred - min_d) / (max_d - min_d)
|
| 313 |
-
|
| 314 |
-
# Resize back to original resolution
|
| 315 |
-
if match_input_res:
|
| 316 |
-
pred = resize(
|
| 317 |
-
pred if normals else pred.unsqueeze(0),
|
| 318 |
-
(input_size[-2],input_size[-1]),
|
| 319 |
-
interpolation=resample_method,
|
| 320 |
-
antialias=True,
|
| 321 |
-
).squeeze()
|
| 322 |
-
|
| 323 |
-
# Convert to numpy
|
| 324 |
-
pred = pred.cpu().numpy()
|
| 325 |
-
|
| 326 |
-
# Process prediction for visualization
|
| 327 |
-
if not normals:
|
| 328 |
-
# add
|
| 329 |
-
pred = pred.clip(0, 1)
|
| 330 |
-
if color_map is not None:
|
| 331 |
-
colored = colorize_depth_maps(
|
| 332 |
-
pred, 0, 1, cmap=color_map
|
| 333 |
-
).squeeze() # [3, H, W], value in (0, 1)
|
| 334 |
-
colored = (colored * 255).astype(np.uint8)
|
| 335 |
-
colored_hwc = chw2hwc(colored)
|
| 336 |
-
colored_img = Image.fromarray(colored_hwc)
|
| 337 |
-
else:
|
| 338 |
-
colored_img = None
|
| 339 |
-
else:
|
| 340 |
-
pred = pred.clip(-1.0, 1.0)
|
| 341 |
-
colored = (((pred+1)/2) * 255).astype(np.uint8)
|
| 342 |
-
colored_hwc = chw2hwc(colored)
|
| 343 |
-
colored_img = Image.fromarray(colored_hwc)
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
return MarigoldDepthOutput(
|
| 347 |
-
depth_np = pred if not normals else None,
|
| 348 |
-
depth_colored = colored_img if not normals else None,
|
| 349 |
-
uncertainty = pred_uncert,
|
| 350 |
-
# add
|
| 351 |
-
normal_np = pred if normals else None,
|
| 352 |
-
normal_colored = colored_img if normals else None,
|
| 353 |
-
)
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
def encode_empty_text(self):
|
| 357 |
-
"""
|
| 358 |
-
Encode text embedding for empty prompt
|
| 359 |
-
"""
|
| 360 |
-
prompt = ""
|
| 361 |
-
text_inputs = self.tokenizer(
|
| 362 |
-
prompt,
|
| 363 |
-
padding="do_not_pad",
|
| 364 |
-
max_length=self.tokenizer.model_max_length,
|
| 365 |
-
truncation=True,
|
| 366 |
-
return_tensors="pt",
|
| 367 |
-
)
|
| 368 |
-
text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
|
| 369 |
-
self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
|
| 370 |
-
|
| 371 |
-
@torch.no_grad()
|
| 372 |
-
def single_infer(
|
| 373 |
-
self,
|
| 374 |
-
rgb_in: torch.Tensor,
|
| 375 |
-
num_inference_steps: int,
|
| 376 |
-
show_pbar: bool,
|
| 377 |
-
# add
|
| 378 |
-
noise="gaussian",
|
| 379 |
-
normals=False,
|
| 380 |
-
) -> torch.Tensor:
|
| 381 |
-
"""
|
| 382 |
-
Perform an individual depth prediction without ensembling.
|
| 383 |
-
|
| 384 |
-
Args:
|
| 385 |
-
rgb_in (`torch.Tensor`):
|
| 386 |
-
Input RGB image.
|
| 387 |
-
num_inference_steps (`int`):
|
| 388 |
-
Number of diffusion denoisign steps (DDIM) during inference.
|
| 389 |
-
show_pbar (`bool`):
|
| 390 |
-
Display a progress bar of diffusion denoising.
|
| 391 |
-
noise (`str`, *optional*, defaults to `gaussian`):
|
| 392 |
-
Type of noise to be used for the initial depth map.
|
| 393 |
-
Can be one of `gaussian`, `pyramid`, `zeros`.
|
| 394 |
-
Returns:
|
| 395 |
-
`torch.Tensor`: Predicted depth map.
|
| 396 |
-
"""
|
| 397 |
-
device = self.device
|
| 398 |
-
rgb_in = rgb_in.to(device)
|
| 399 |
-
|
| 400 |
-
# Set timesteps
|
| 401 |
-
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 402 |
-
timesteps = self.scheduler.timesteps # [T]
|
| 403 |
-
|
| 404 |
-
# Encode image
|
| 405 |
-
rgb_latent = self.encode_rgb(rgb_in)
|
| 406 |
-
|
| 407 |
-
# add
|
| 408 |
-
# Initial prediction
|
| 409 |
-
latent_shape = rgb_latent.shape
|
| 410 |
-
if noise == "gaussian":
|
| 411 |
-
latent = torch.randn(
|
| 412 |
-
latent_shape,
|
| 413 |
-
device=device,
|
| 414 |
-
dtype=self.dtype,
|
| 415 |
-
)
|
| 416 |
-
elif noise == "pyramid":
|
| 417 |
-
latent = pyramid_noise_like(rgb_latent).to(device) # [B, 4, h, w]
|
| 418 |
-
elif noise == "zeros":
|
| 419 |
-
latent = torch.zeros(
|
| 420 |
-
latent_shape,
|
| 421 |
-
device=device,
|
| 422 |
-
dtype=self.dtype,
|
| 423 |
-
)
|
| 424 |
-
else:
|
| 425 |
-
raise ValueError(f"Unknown noise type: {noise}")
|
| 426 |
-
|
| 427 |
-
# Batched empty text embedding
|
| 428 |
-
if self.empty_text_embed is None:
|
| 429 |
-
self.encode_empty_text()
|
| 430 |
-
batch_empty_text_embed = self.empty_text_embed.repeat(
|
| 431 |
-
(rgb_latent.shape[0], 1, 1)
|
| 432 |
-
) # [B, 2, 1024]
|
| 433 |
-
|
| 434 |
-
# Denoising loop
|
| 435 |
-
if show_pbar:
|
| 436 |
-
iterable = tqdm(
|
| 437 |
-
enumerate(timesteps),
|
| 438 |
-
total=len(timesteps),
|
| 439 |
-
leave=False,
|
| 440 |
-
desc=" " * 4 + "Diffusion denoising",
|
| 441 |
-
)
|
| 442 |
-
else:
|
| 443 |
-
iterable = enumerate(timesteps)
|
| 444 |
-
|
| 445 |
-
for i, t in iterable:
|
| 446 |
-
|
| 447 |
-
unet_input = torch.cat(
|
| 448 |
-
[rgb_latent, latent], dim=1
|
| 449 |
-
) # this order is important
|
| 450 |
-
|
| 451 |
-
# predict the noise residual
|
| 452 |
-
noise_pred = self.unet(
|
| 453 |
-
unet_input, t, encoder_hidden_states=batch_empty_text_embed
|
| 454 |
-
).sample # [B, 4, h, w]
|
| 455 |
-
|
| 456 |
-
# compute the previous noisy sample x_t -> x_t-1
|
| 457 |
-
scheduler_step = self.scheduler.step(
|
| 458 |
-
noise_pred, t, latent
|
| 459 |
-
)
|
| 460 |
-
|
| 461 |
-
latent = scheduler_step.prev_sample
|
| 462 |
-
|
| 463 |
-
if normals:
|
| 464 |
-
# add
|
| 465 |
-
# decode and normalize normal vectors
|
| 466 |
-
normal = self.decode_normal(latent)
|
| 467 |
-
normal /= (torch.norm(normal, p=2, dim=1, keepdim=True)+1e-5)
|
| 468 |
-
return normal
|
| 469 |
-
else:
|
| 470 |
-
# decode and normalize depth map
|
| 471 |
-
depth = self.decode_depth(latent)
|
| 472 |
-
depth = torch.clip(depth, -1.0, 1.0)
|
| 473 |
-
depth = (depth + 1.0) / 2.0
|
| 474 |
-
return depth
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
|
| 478 |
-
"""
|
| 479 |
-
Encode RGB image into latent.
|
| 480 |
-
|
| 481 |
-
Args:
|
| 482 |
-
rgb_in (`torch.Tensor`):
|
| 483 |
-
Input RGB image to be encoded.
|
| 484 |
-
|
| 485 |
-
Returns:
|
| 486 |
-
`torch.Tensor`: Image latent.
|
| 487 |
-
"""
|
| 488 |
-
# encode
|
| 489 |
-
h = self.vae.encoder(rgb_in)
|
| 490 |
-
moments = self.vae.quant_conv(h)
|
| 491 |
-
mean, logvar = torch.chunk(moments, 2, dim=1)
|
| 492 |
-
# scale latent
|
| 493 |
-
rgb_latent = mean * self.rgb_latent_scale_factor
|
| 494 |
-
return rgb_latent
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
|
| 498 |
-
"""
|
| 499 |
-
Decode depth latent into depth map.
|
| 500 |
-
|
| 501 |
-
Args:
|
| 502 |
-
depth_latent (`torch.Tensor`):
|
| 503 |
-
Depth latent to be decoded.
|
| 504 |
-
|
| 505 |
-
Returns:
|
| 506 |
-
`torch.Tensor`: Decoded depth map.
|
| 507 |
-
"""
|
| 508 |
-
# scale latent
|
| 509 |
-
depth_latent = depth_latent / self.depth_latent_scale_factor
|
| 510 |
-
# decode
|
| 511 |
-
z = self.vae.post_quant_conv(depth_latent)
|
| 512 |
-
stacked = self.vae.decoder(z)
|
| 513 |
-
# mean of output channels
|
| 514 |
-
depth_mean = stacked.mean(dim=1, keepdim=True)
|
| 515 |
-
return depth_mean
|
| 516 |
-
|
| 517 |
-
# add
|
| 518 |
-
def decode_normal(self, normal_latent: torch.Tensor) -> torch.Tensor:
|
| 519 |
-
"""
|
| 520 |
-
Decode normal latent into normal map.
|
| 521 |
-
|
| 522 |
-
Args:
|
| 523 |
-
normal_latent (`torch.Tensor`):
|
| 524 |
-
normal latent to be decoded.
|
| 525 |
-
|
| 526 |
-
Returns:
|
| 527 |
-
`torch.Tensor`: Decoded depth map.
|
| 528 |
-
"""
|
| 529 |
-
# scale latent
|
| 530 |
-
normal_latent = normal_latent / self.depth_latent_scale_factor
|
| 531 |
-
# decode
|
| 532 |
-
z = self.vae.post_quant_conv(normal_latent)
|
| 533 |
-
normal = self.vae.decoder(z)
|
| 534 |
-
return normal
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Marigold/marigold/util/__init__.py
DELETED
|
File without changes
|
Marigold/marigold/util/batchsize.py
DELETED
|
@@ -1,81 +0,0 @@
|
|
| 1 |
-
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
# --------------------------------------------------------------------------
|
| 15 |
-
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 16 |
-
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
|
| 17 |
-
# More information about the method can be found at https://marigoldmonodepth.github.io
|
| 18 |
-
# --------------------------------------------------------------------------
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
import torch
|
| 22 |
-
import math
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
# Search table for suggested max. inference batch size
|
| 26 |
-
bs_search_table = [
|
| 27 |
-
# tested on A100-PCIE-80GB
|
| 28 |
-
{"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
|
| 29 |
-
{"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
|
| 30 |
-
# tested on A100-PCIE-40GB
|
| 31 |
-
{"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
|
| 32 |
-
{"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
|
| 33 |
-
{"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
|
| 34 |
-
{"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
|
| 35 |
-
# tested on RTX3090, RTX4090
|
| 36 |
-
{"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
|
| 37 |
-
{"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
|
| 38 |
-
{"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
|
| 39 |
-
{"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
|
| 40 |
-
{"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
|
| 41 |
-
{"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
|
| 42 |
-
# tested on GTX1080Ti
|
| 43 |
-
{"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
|
| 44 |
-
{"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
|
| 45 |
-
{"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
|
| 46 |
-
{"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
|
| 47 |
-
{"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
|
| 48 |
-
]
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
|
| 52 |
-
"""
|
| 53 |
-
Automatically search for suitable operating batch size.
|
| 54 |
-
|
| 55 |
-
Args:
|
| 56 |
-
ensemble_size (`int`):
|
| 57 |
-
Number of predictions to be ensembled.
|
| 58 |
-
input_res (`int`):
|
| 59 |
-
Operating resolution of the input image.
|
| 60 |
-
|
| 61 |
-
Returns:
|
| 62 |
-
`int`: Operating batch size.
|
| 63 |
-
"""
|
| 64 |
-
if not torch.cuda.is_available():
|
| 65 |
-
return 1
|
| 66 |
-
|
| 67 |
-
total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
|
| 68 |
-
filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
|
| 69 |
-
for settings in sorted(
|
| 70 |
-
filtered_bs_search_table,
|
| 71 |
-
key=lambda k: (k["res"], -k["total_vram"]),
|
| 72 |
-
):
|
| 73 |
-
if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
|
| 74 |
-
bs = settings["bs"]
|
| 75 |
-
if bs > ensemble_size:
|
| 76 |
-
bs = ensemble_size
|
| 77 |
-
elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
|
| 78 |
-
bs = math.ceil(ensemble_size / 2)
|
| 79 |
-
return bs
|
| 80 |
-
|
| 81 |
-
return 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Marigold/marigold/util/ensemble.py
DELETED
|
@@ -1,132 +0,0 @@
|
|
| 1 |
-
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
# --------------------------------------------------------------------------
|
| 15 |
-
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 16 |
-
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
|
| 17 |
-
# More information about the method can be found at https://marigoldmonodepth.github.io
|
| 18 |
-
# --------------------------------------------------------------------------
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
import numpy as np
|
| 22 |
-
import torch
|
| 23 |
-
|
| 24 |
-
from scipy.optimize import minimize
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def inter_distances(tensors: torch.Tensor):
|
| 28 |
-
"""
|
| 29 |
-
To calculate the distance between each two depth maps.
|
| 30 |
-
"""
|
| 31 |
-
distances = []
|
| 32 |
-
for i, j in torch.combinations(torch.arange(tensors.shape[0])):
|
| 33 |
-
arr1 = tensors[i : i + 1]
|
| 34 |
-
arr2 = tensors[j : j + 1]
|
| 35 |
-
distances.append(arr1 - arr2)
|
| 36 |
-
dist = torch.concatenate(distances, dim=0)
|
| 37 |
-
return dist
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
def ensemble_depths(
|
| 41 |
-
input_images: torch.Tensor,
|
| 42 |
-
regularizer_strength: float = 0.02,
|
| 43 |
-
max_iter: int = 2,
|
| 44 |
-
tol: float = 1e-3,
|
| 45 |
-
reduction: str = "median",
|
| 46 |
-
max_res: int = None,
|
| 47 |
-
):
|
| 48 |
-
"""
|
| 49 |
-
To ensemble multiple affine-invariant depth images (up to scale and shift),
|
| 50 |
-
by aligning estimating the scale and shift
|
| 51 |
-
"""
|
| 52 |
-
device = input_images.device
|
| 53 |
-
dtype = input_images.dtype
|
| 54 |
-
np_dtype = np.float32
|
| 55 |
-
|
| 56 |
-
original_input = input_images.clone()
|
| 57 |
-
n_img = input_images.shape[0]
|
| 58 |
-
ori_shape = input_images.shape
|
| 59 |
-
|
| 60 |
-
if max_res is not None:
|
| 61 |
-
scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:]))
|
| 62 |
-
if scale_factor < 1:
|
| 63 |
-
downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
|
| 64 |
-
input_images = downscaler(input_images)
|
| 65 |
-
|
| 66 |
-
# init guess
|
| 67 |
-
_min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
|
| 68 |
-
_max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
|
| 69 |
-
s_init = 1.0 / (_max - _min).reshape((-1, 1, 1))
|
| 70 |
-
t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1))
|
| 71 |
-
x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype)
|
| 72 |
-
|
| 73 |
-
input_images = input_images.to(device)
|
| 74 |
-
|
| 75 |
-
# objective function
|
| 76 |
-
def closure(x):
|
| 77 |
-
len_x = len(x)
|
| 78 |
-
s = x[: int(len_x / 2)]
|
| 79 |
-
t = x[int(len_x / 2) :]
|
| 80 |
-
s = torch.from_numpy(s).to(dtype=dtype).to(device)
|
| 81 |
-
t = torch.from_numpy(t).to(dtype=dtype).to(device)
|
| 82 |
-
|
| 83 |
-
transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1))
|
| 84 |
-
dists = inter_distances(transformed_arrays)
|
| 85 |
-
sqrt_dist = torch.sqrt(torch.mean(dists**2))
|
| 86 |
-
|
| 87 |
-
if "mean" == reduction:
|
| 88 |
-
pred = torch.mean(transformed_arrays, dim=0)
|
| 89 |
-
elif "median" == reduction:
|
| 90 |
-
pred = torch.median(transformed_arrays, dim=0).values
|
| 91 |
-
else:
|
| 92 |
-
raise ValueError
|
| 93 |
-
|
| 94 |
-
near_err = torch.sqrt((0 - torch.min(pred)) ** 2)
|
| 95 |
-
far_err = torch.sqrt((1 - torch.max(pred)) ** 2)
|
| 96 |
-
|
| 97 |
-
err = sqrt_dist + (near_err + far_err) * regularizer_strength
|
| 98 |
-
err = err.detach().cpu().numpy().astype(np_dtype)
|
| 99 |
-
return err
|
| 100 |
-
|
| 101 |
-
res = minimize(
|
| 102 |
-
closure, x, method="BFGS", tol=tol, options={"maxiter": max_iter, "disp": False}
|
| 103 |
-
)
|
| 104 |
-
x = res.x
|
| 105 |
-
len_x = len(x)
|
| 106 |
-
s = x[: int(len_x / 2)]
|
| 107 |
-
t = x[int(len_x / 2) :]
|
| 108 |
-
|
| 109 |
-
# Prediction
|
| 110 |
-
s = torch.from_numpy(s).to(dtype=dtype).to(device)
|
| 111 |
-
t = torch.from_numpy(t).to(dtype=dtype).to(device)
|
| 112 |
-
transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1)
|
| 113 |
-
if "mean" == reduction:
|
| 114 |
-
aligned_images = torch.mean(transformed_arrays, dim=0)
|
| 115 |
-
std = torch.std(transformed_arrays, dim=0)
|
| 116 |
-
uncertainty = std
|
| 117 |
-
elif "median" == reduction:
|
| 118 |
-
aligned_images = torch.median(transformed_arrays, dim=0).values
|
| 119 |
-
# MAD (median absolute deviation) as uncertainty indicator
|
| 120 |
-
abs_dev = torch.abs(transformed_arrays - aligned_images)
|
| 121 |
-
mad = torch.median(abs_dev, dim=0).values
|
| 122 |
-
uncertainty = mad
|
| 123 |
-
else:
|
| 124 |
-
raise ValueError(f"Unknown reduction method: {reduction}")
|
| 125 |
-
|
| 126 |
-
# Scale and shift to [0, 1]
|
| 127 |
-
_min = torch.min(aligned_images)
|
| 128 |
-
_max = torch.max(aligned_images)
|
| 129 |
-
aligned_images = (aligned_images - _min) / (_max - _min)
|
| 130 |
-
uncertainty /= _max - _min
|
| 131 |
-
|
| 132 |
-
return aligned_images, uncertainty
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Marigold/marigold/util/image_util.py
DELETED
|
@@ -1,121 +0,0 @@
|
|
| 1 |
-
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
| 2 |
-
# Last modified: 2024-04-16
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
# --------------------------------------------------------------------------
|
| 16 |
-
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 17 |
-
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
|
| 18 |
-
# More information about the method can be found at https://marigoldmonodepth.github.io
|
| 19 |
-
# --------------------------------------------------------------------------
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
import matplotlib
|
| 23 |
-
import numpy as np
|
| 24 |
-
import torch
|
| 25 |
-
from torchvision.transforms import InterpolationMode
|
| 26 |
-
from torchvision.transforms.functional import resize
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def colorize_depth_maps(
|
| 30 |
-
depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
|
| 31 |
-
):
|
| 32 |
-
"""
|
| 33 |
-
Colorize depth maps.
|
| 34 |
-
"""
|
| 35 |
-
assert len(depth_map.shape) >= 2, "Invalid dimension"
|
| 36 |
-
|
| 37 |
-
if isinstance(depth_map, torch.Tensor):
|
| 38 |
-
depth = depth_map.detach().squeeze().numpy()
|
| 39 |
-
elif isinstance(depth_map, np.ndarray):
|
| 40 |
-
depth = depth_map.copy().squeeze()
|
| 41 |
-
# reshape to [ (B,) H, W ]
|
| 42 |
-
if depth.ndim < 3:
|
| 43 |
-
depth = depth[np.newaxis, :, :]
|
| 44 |
-
|
| 45 |
-
# colorize
|
| 46 |
-
cm = matplotlib.colormaps[cmap]
|
| 47 |
-
depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
|
| 48 |
-
img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1
|
| 49 |
-
img_colored_np = np.rollaxis(img_colored_np, 3, 1)
|
| 50 |
-
|
| 51 |
-
if valid_mask is not None:
|
| 52 |
-
if isinstance(depth_map, torch.Tensor):
|
| 53 |
-
valid_mask = valid_mask.detach().numpy()
|
| 54 |
-
valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
|
| 55 |
-
if valid_mask.ndim < 3:
|
| 56 |
-
valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
|
| 57 |
-
else:
|
| 58 |
-
valid_mask = valid_mask[:, np.newaxis, :, :]
|
| 59 |
-
valid_mask = np.repeat(valid_mask, 3, axis=1)
|
| 60 |
-
img_colored_np[~valid_mask] = 0
|
| 61 |
-
|
| 62 |
-
if isinstance(depth_map, torch.Tensor):
|
| 63 |
-
img_colored = torch.from_numpy(img_colored_np).float()
|
| 64 |
-
elif isinstance(depth_map, np.ndarray):
|
| 65 |
-
img_colored = img_colored_np
|
| 66 |
-
|
| 67 |
-
return img_colored
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
def chw2hwc(chw):
|
| 71 |
-
assert 3 == len(chw.shape)
|
| 72 |
-
if isinstance(chw, torch.Tensor):
|
| 73 |
-
hwc = torch.permute(chw, (1, 2, 0))
|
| 74 |
-
elif isinstance(chw, np.ndarray):
|
| 75 |
-
hwc = np.moveaxis(chw, 0, -1)
|
| 76 |
-
return hwc
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
def resize_max_res(
|
| 80 |
-
img: torch.Tensor,
|
| 81 |
-
max_edge_resolution: int,
|
| 82 |
-
resample_method: InterpolationMode = InterpolationMode.BILINEAR,
|
| 83 |
-
) -> torch.Tensor:
|
| 84 |
-
"""
|
| 85 |
-
Resize image to limit maximum edge length while keeping aspect ratio.
|
| 86 |
-
|
| 87 |
-
Args:
|
| 88 |
-
img (`torch.Tensor`):
|
| 89 |
-
Image tensor to be resized.
|
| 90 |
-
max_edge_resolution (`int`):
|
| 91 |
-
Maximum edge length (pixel).
|
| 92 |
-
resample_method (`PIL.Image.Resampling`):
|
| 93 |
-
Resampling method used to resize images.
|
| 94 |
-
|
| 95 |
-
Returns:
|
| 96 |
-
`torch.Tensor`: Resized image.
|
| 97 |
-
"""
|
| 98 |
-
assert 3 == img.dim()
|
| 99 |
-
_, original_height, original_width = img.shape
|
| 100 |
-
downscale_factor = min(
|
| 101 |
-
max_edge_resolution / original_width, max_edge_resolution / original_height
|
| 102 |
-
)
|
| 103 |
-
|
| 104 |
-
new_width = int(original_width * downscale_factor)
|
| 105 |
-
new_height = int(original_height * downscale_factor)
|
| 106 |
-
|
| 107 |
-
resized_img = resize(img, (new_height, new_width), resample_method, antialias=True)
|
| 108 |
-
return resized_img
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
def get_tv_resample_method(method_str: str) -> InterpolationMode:
|
| 112 |
-
resample_method_dict = {
|
| 113 |
-
"bilinear": InterpolationMode.BILINEAR,
|
| 114 |
-
"bicubic": InterpolationMode.BICUBIC,
|
| 115 |
-
"nearest": InterpolationMode.NEAREST_EXACT,
|
| 116 |
-
}
|
| 117 |
-
resample_method = resample_method_dict.get(method_str, None)
|
| 118 |
-
if resample_method is None:
|
| 119 |
-
raise ValueError(f"Unknown resampling method: {resample_method}")
|
| 120 |
-
else:
|
| 121 |
-
return resample_method
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|