Update apdepth/marigold_pipeline.py
Browse files- apdepth/marigold_pipeline.py +31 -76
apdepth/marigold_pipeline.py
CHANGED
|
@@ -23,6 +23,7 @@ import logging
|
|
| 23 |
from typing import Dict, Optional, Union
|
| 24 |
|
| 25 |
import numpy as np
|
|
|
|
| 26 |
import torch
|
| 27 |
import torch.nn as nn
|
| 28 |
import torch.nn.functional as F
|
|
@@ -49,6 +50,7 @@ from .util.image_util import (
|
|
| 49 |
get_tv_resample_method,
|
| 50 |
resize_max_res,
|
| 51 |
)
|
|
|
|
| 52 |
|
| 53 |
|
| 54 |
class MarigoldDepthOutput(BaseOutput):
|
|
@@ -96,12 +98,6 @@ class MarigoldPipeline(DiffusionPipeline):
|
|
| 96 |
A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in
|
| 97 |
the model config. When used together with the `scale_invariant=True` flag, the model is also called
|
| 98 |
"affine-invariant". NB: overriding this value is not supported.
|
| 99 |
-
default_denoising_steps (`int`, *optional*):
|
| 100 |
-
The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
|
| 101 |
-
quality with the given model. This value must be set in the model config. When the pipeline is called
|
| 102 |
-
without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
|
| 103 |
-
reasonable results with various model flavors compatible with the pipeline, such as those relying on very
|
| 104 |
-
short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
|
| 105 |
default_processing_resolution (`int`, *optional*):
|
| 106 |
The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
|
| 107 |
the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
|
|
@@ -116,42 +112,54 @@ class MarigoldPipeline(DiffusionPipeline):
|
|
| 116 |
self,
|
| 117 |
unet: UNet2DConditionModel,
|
| 118 |
vae: AutoencoderKL,
|
|
|
|
| 119 |
text_encoder: CLIPTextModel,
|
| 120 |
tokenizer: CLIPTokenizer,
|
| 121 |
scale_invariant: Optional[bool] = True,
|
| 122 |
shift_invariant: Optional[bool] = True,
|
| 123 |
-
default_denoising_steps: Optional[int] = None,
|
| 124 |
default_processing_resolution: Optional[int] = None,
|
| 125 |
):
|
| 126 |
super().__init__()
|
| 127 |
self.register_modules(
|
| 128 |
unet=unet,
|
| 129 |
vae=vae,
|
|
|
|
| 130 |
text_encoder=text_encoder,
|
| 131 |
tokenizer=tokenizer,
|
| 132 |
)
|
| 133 |
self.register_to_config(
|
| 134 |
scale_invariant=scale_invariant,
|
| 135 |
shift_invariant=shift_invariant,
|
| 136 |
-
default_denoising_steps=default_denoising_steps,
|
| 137 |
default_processing_resolution=default_processing_resolution,
|
| 138 |
)
|
| 139 |
|
| 140 |
self.scale_invariant = scale_invariant
|
| 141 |
self.shift_invariant = shift_invariant
|
| 142 |
-
self.default_denoising_steps = default_denoising_steps
|
| 143 |
self.default_processing_resolution = default_processing_resolution
|
| 144 |
|
| 145 |
self.empty_text_embed = None
|
| 146 |
|
| 147 |
self._fft_masks = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
@torch.no_grad()
|
| 150 |
def __call__(
|
| 151 |
self,
|
| 152 |
input_image: Union[Image.Image, torch.Tensor],
|
| 153 |
-
|
| 154 |
-
ensemble_size: int = 5,
|
| 155 |
processing_res: Optional[int] = None,
|
| 156 |
match_input_res: bool = True,
|
| 157 |
resample_method: str = "bilinear",
|
|
@@ -166,10 +174,6 @@ class MarigoldPipeline(DiffusionPipeline):
|
|
| 166 |
Args:
|
| 167 |
input_image (`Image`):
|
| 168 |
Input RGB (or gray-scale) image.
|
| 169 |
-
denoising_steps (`int`, *optional*, defaults to `None`):
|
| 170 |
-
Number of denoising diffusion steps during inference. The default value `None` results in automatic
|
| 171 |
-
selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
|
| 172 |
-
for Marigold-LCM models.
|
| 173 |
ensemble_size (`int`, *optional*, defaults to `10`):
|
| 174 |
Number of predictions to be ensembled.
|
| 175 |
processing_res (`int`, *optional*, defaults to `None`):
|
|
@@ -209,9 +213,6 @@ class MarigoldPipeline(DiffusionPipeline):
|
|
| 209 |
|
| 210 |
assert processing_res >= 0
|
| 211 |
|
| 212 |
-
# Check if denoising step is reasonable
|
| 213 |
-
# self._check_inference_step(denoising_steps)
|
| 214 |
-
|
| 215 |
resample_method: InterpolationMode = get_tv_resample_method(resample_method)
|
| 216 |
|
| 217 |
# ----------------- Image Preprocess -----------------
|
|
@@ -245,7 +246,7 @@ class MarigoldPipeline(DiffusionPipeline):
|
|
| 245 |
|
| 246 |
# ----------------- Predicting depth -----------------
|
| 247 |
# Batch repeated input image
|
| 248 |
-
duplicated_rgb = rgb_norm.expand(
|
| 249 |
single_rgb_dataset = TensorDataset(duplicated_rgb)
|
| 250 |
if batch_size > 0:
|
| 251 |
_bs = batch_size
|
|
@@ -321,27 +322,6 @@ class MarigoldPipeline(DiffusionPipeline):
|
|
| 321 |
uncertainty=pred_uncert,
|
| 322 |
)
|
| 323 |
|
| 324 |
-
def _check_inference_step(self, n_step: int) -> None:
|
| 325 |
-
"""
|
| 326 |
-
Check if denoising step is reasonable
|
| 327 |
-
Args:
|
| 328 |
-
n_step (`int`): denoising steps
|
| 329 |
-
"""
|
| 330 |
-
assert n_step >= 1
|
| 331 |
-
|
| 332 |
-
if isinstance(self.scheduler, DDIMScheduler):
|
| 333 |
-
if n_step < 10:
|
| 334 |
-
logging.warning(
|
| 335 |
-
f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
|
| 336 |
-
)
|
| 337 |
-
elif isinstance(self.scheduler, LCMScheduler):
|
| 338 |
-
if not 1 <= n_step <= 4:
|
| 339 |
-
logging.warning(
|
| 340 |
-
f"Non-optimal setting of denoising steps: {n_step}. Recommended setting is 1-4 steps."
|
| 341 |
-
)
|
| 342 |
-
else:
|
| 343 |
-
raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
|
| 344 |
-
|
| 345 |
def encode_empty_text(self):
|
| 346 |
"""
|
| 347 |
Encode text embedding for empty prompt
|
|
@@ -357,36 +337,6 @@ class MarigoldPipeline(DiffusionPipeline):
|
|
| 357 |
text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
|
| 358 |
self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
|
| 359 |
|
| 360 |
-
@torch.no_grad()
|
| 361 |
-
def _get_highpass_mask(self, H, W, radius, device):
|
| 362 |
-
key = (H, W, radius, device)
|
| 363 |
-
if key not in self._fft_masks:
|
| 364 |
-
yy, xx = torch.meshgrid(torch.arange(H, device=device),
|
| 365 |
-
torch.arange(W, device=device),
|
| 366 |
-
indexing="ij")
|
| 367 |
-
yy = yy - H // 2
|
| 368 |
-
xx = xx - W // 2
|
| 369 |
-
mask_low = (xx**2 + yy**2 <= radius**2).float()
|
| 370 |
-
mask_high = 1 - mask_low
|
| 371 |
-
mask_high = mask_high[None, None, :, :]
|
| 372 |
-
self._fft_masks[key] = mask_high
|
| 373 |
-
return self._fft_masks[key]
|
| 374 |
-
|
| 375 |
-
@torch.no_grad()
|
| 376 |
-
def rgb_fft(self, x: torch.Tensor, highpass_radius: int = 30):
|
| 377 |
-
B, C, H, W = x.shape
|
| 378 |
-
device = x.device
|
| 379 |
-
|
| 380 |
-
f = torch.fft.fft2(x, norm="ortho")
|
| 381 |
-
fshift = torch.fft.fftshift(f)
|
| 382 |
-
|
| 383 |
-
mask_high = self._get_highpass_mask(H, W, highpass_radius, device)
|
| 384 |
-
fshift_high = fshift * mask_high
|
| 385 |
-
|
| 386 |
-
f_ishift = torch.fft.ifftshift(fshift_high)
|
| 387 |
-
img_high = torch.fft.ifft2(f_ishift, norm="ortho").real
|
| 388 |
-
return img_high
|
| 389 |
-
|
| 390 |
@torch.no_grad()
|
| 391 |
def single_infer(
|
| 392 |
self,
|
|
@@ -408,12 +358,14 @@ class MarigoldPipeline(DiffusionPipeline):
|
|
| 408 |
`torch.Tensor`: Predicted depth map.
|
| 409 |
"""
|
| 410 |
device = self.device
|
|
|
|
| 411 |
rgb_in = rgb_in.to(device)
|
| 412 |
-
|
| 413 |
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
|
|
|
| 417 |
|
| 418 |
# Batched empty text embedding
|
| 419 |
if self.empty_text_embed is None:
|
|
@@ -422,10 +374,13 @@ class MarigoldPipeline(DiffusionPipeline):
|
|
| 422 |
(rgb_latent.shape[0], 1, 1)
|
| 423 |
).to(device) # [B, 2, 1024]
|
| 424 |
|
| 425 |
-
#
|
|
|
|
|
|
|
|
|
|
| 426 |
|
| 427 |
depth_latent = self.unet(
|
| 428 |
-
|
| 429 |
).sample # [B, 4, h, w]
|
| 430 |
|
| 431 |
depth = self.decode_depth(depth_latent)
|
|
|
|
| 23 |
from typing import Dict, Optional, Union
|
| 24 |
|
| 25 |
import numpy as np
|
| 26 |
+
import cv2
|
| 27 |
import torch
|
| 28 |
import torch.nn as nn
|
| 29 |
import torch.nn.functional as F
|
|
|
|
| 50 |
get_tv_resample_method,
|
| 51 |
resize_max_res,
|
| 52 |
)
|
| 53 |
+
from DA2.depth_anything_v2.dpt import DepthAnythingV2
|
| 54 |
|
| 55 |
|
| 56 |
class MarigoldDepthOutput(BaseOutput):
|
|
|
|
| 98 |
A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in
|
| 99 |
the model config. When used together with the `scale_invariant=True` flag, the model is also called
|
| 100 |
"affine-invariant". NB: overriding this value is not supported.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
default_processing_resolution (`int`, *optional*):
|
| 102 |
The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
|
| 103 |
the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
|
|
|
|
| 112 |
self,
|
| 113 |
unet: UNet2DConditionModel,
|
| 114 |
vae: AutoencoderKL,
|
| 115 |
+
scheduler: Union[DDIMScheduler, LCMScheduler],
|
| 116 |
text_encoder: CLIPTextModel,
|
| 117 |
tokenizer: CLIPTokenizer,
|
| 118 |
scale_invariant: Optional[bool] = True,
|
| 119 |
shift_invariant: Optional[bool] = True,
|
|
|
|
| 120 |
default_processing_resolution: Optional[int] = None,
|
| 121 |
):
|
| 122 |
super().__init__()
|
| 123 |
self.register_modules(
|
| 124 |
unet=unet,
|
| 125 |
vae=vae,
|
| 126 |
+
scheduler=scheduler,
|
| 127 |
text_encoder=text_encoder,
|
| 128 |
tokenizer=tokenizer,
|
| 129 |
)
|
| 130 |
self.register_to_config(
|
| 131 |
scale_invariant=scale_invariant,
|
| 132 |
shift_invariant=shift_invariant,
|
|
|
|
| 133 |
default_processing_resolution=default_processing_resolution,
|
| 134 |
)
|
| 135 |
|
| 136 |
self.scale_invariant = scale_invariant
|
| 137 |
self.shift_invariant = shift_invariant
|
|
|
|
| 138 |
self.default_processing_resolution = default_processing_resolution
|
| 139 |
|
| 140 |
self.empty_text_embed = None
|
| 141 |
|
| 142 |
self._fft_masks = {}
|
| 143 |
+
|
| 144 |
+
da2_config = {
|
| 145 |
+
'encoder': 'vits', # 'vits', 'vitb', 'vitl', 'vitg'
|
| 146 |
+
'features': 64,
|
| 147 |
+
'out_channels': [48, 96, 192, 384],
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
# 初始化 DA2 模型
|
| 151 |
+
if da2_config is not None:
|
| 152 |
+
self.da2 = DepthAnythingV2(**da2_config)
|
| 153 |
+
self.da2.load_state_dict(torch.load(f'/root/Marigold/DA2/checkpoints/depth_anything_v2_{da2_config["encoder"]}.pth', map_location='cpu'))
|
| 154 |
+
self.da2.to(device="cuda").eval()
|
| 155 |
+
else:
|
| 156 |
+
self.da2 = None
|
| 157 |
|
| 158 |
@torch.no_grad()
|
| 159 |
def __call__(
|
| 160 |
self,
|
| 161 |
input_image: Union[Image.Image, torch.Tensor],
|
| 162 |
+
ensemble_size: int = 1,
|
|
|
|
| 163 |
processing_res: Optional[int] = None,
|
| 164 |
match_input_res: bool = True,
|
| 165 |
resample_method: str = "bilinear",
|
|
|
|
| 174 |
Args:
|
| 175 |
input_image (`Image`):
|
| 176 |
Input RGB (or gray-scale) image.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
ensemble_size (`int`, *optional*, defaults to `10`):
|
| 178 |
Number of predictions to be ensembled.
|
| 179 |
processing_res (`int`, *optional*, defaults to `None`):
|
|
|
|
| 213 |
|
| 214 |
assert processing_res >= 0
|
| 215 |
|
|
|
|
|
|
|
|
|
|
| 216 |
resample_method: InterpolationMode = get_tv_resample_method(resample_method)
|
| 217 |
|
| 218 |
# ----------------- Image Preprocess -----------------
|
|
|
|
| 246 |
|
| 247 |
# ----------------- Predicting depth -----------------
|
| 248 |
# Batch repeated input image
|
| 249 |
+
duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
|
| 250 |
single_rgb_dataset = TensorDataset(duplicated_rgb)
|
| 251 |
if batch_size > 0:
|
| 252 |
_bs = batch_size
|
|
|
|
| 322 |
uncertainty=pred_uncert,
|
| 323 |
)
|
| 324 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
def encode_empty_text(self):
|
| 326 |
"""
|
| 327 |
Encode text embedding for empty prompt
|
|
|
|
| 337 |
text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
|
| 338 |
self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
|
| 339 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
@torch.no_grad()
|
| 341 |
def single_infer(
|
| 342 |
self,
|
|
|
|
| 358 |
`torch.Tensor`: Predicted depth map.
|
| 359 |
"""
|
| 360 |
device = self.device
|
| 361 |
+
# preprare data
|
| 362 |
rgb_in = rgb_in.to(device)
|
| 363 |
+
depth_da2 = self.da2.infer_batch(rgb_in).to(device)
|
| 364 |
|
| 365 |
+
with torch.no_grad():
|
| 366 |
+
# Encode image
|
| 367 |
+
rgb_latent = self.encode_rgb(rgb_in)
|
| 368 |
+
depth_da2_latent = self.encode_rgb(depth_da2)
|
| 369 |
|
| 370 |
# Batched empty text embedding
|
| 371 |
if self.empty_text_embed is None:
|
|
|
|
| 374 |
(rgb_latent.shape[0], 1, 1)
|
| 375 |
).to(device) # [B, 2, 1024]
|
| 376 |
|
| 377 |
+
# get input
|
| 378 |
+
unet_input = torch.cat(
|
| 379 |
+
[depth_da2_latent, rgb_latent],dim=1
|
| 380 |
+
) # this order is important
|
| 381 |
|
| 382 |
depth_latent = self.unet(
|
| 383 |
+
unet_input, 1, encoder_hidden_states=batch_empty_text_embed
|
| 384 |
).sample # [B, 4, h, w]
|
| 385 |
|
| 386 |
depth = self.decode_depth(depth_latent)
|