developy committed
Commit 576b7f9 · verified · 1 Parent(s): 6d7da3f

Update apdepth/marigold_pipeline.py

Files changed (1)
  1. apdepth/marigold_pipeline.py +31 -76
apdepth/marigold_pipeline.py CHANGED
@@ -23,6 +23,7 @@ import logging
 from typing import Dict, Optional, Union
 
 import numpy as np
+import cv2
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -49,6 +50,7 @@ from .util.image_util import (
     get_tv_resample_method,
     resize_max_res,
 )
+from DA2.depth_anything_v2.dpt import DepthAnythingV2
 
 
 class MarigoldDepthOutput(BaseOutput):
@@ -96,12 +98,6 @@ class MarigoldPipeline(DiffusionPipeline):
             A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in
             the model config. When used together with the `scale_invariant=True` flag, the model is also called
             "affine-invariant". NB: overriding this value is not supported.
-        default_denoising_steps (`int`, *optional*):
-            The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
-            quality with the given model. This value must be set in the model config. When the pipeline is called
-            without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
-            reasonable results with various model flavors compatible with the pipeline, such as those relying on very
-            short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
         default_processing_resolution (`int`, *optional*):
             The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
             the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
@@ -116,42 +112,54 @@ class MarigoldPipeline(DiffusionPipeline):
         self,
         unet: UNet2DConditionModel,
         vae: AutoencoderKL,
+        scheduler: Union[DDIMScheduler, LCMScheduler],
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         scale_invariant: Optional[bool] = True,
         shift_invariant: Optional[bool] = True,
-        default_denoising_steps: Optional[int] = None,
         default_processing_resolution: Optional[int] = None,
     ):
         super().__init__()
         self.register_modules(
             unet=unet,
             vae=vae,
+            scheduler=scheduler,
             text_encoder=text_encoder,
             tokenizer=tokenizer,
         )
         self.register_to_config(
             scale_invariant=scale_invariant,
             shift_invariant=shift_invariant,
-            default_denoising_steps=default_denoising_steps,
             default_processing_resolution=default_processing_resolution,
         )
 
         self.scale_invariant = scale_invariant
         self.shift_invariant = shift_invariant
-        self.default_denoising_steps = default_denoising_steps
         self.default_processing_resolution = default_processing_resolution
 
         self.empty_text_embed = None
 
         self._fft_masks = {}
+
+        da2_config = {
+            'encoder': 'vits',  # 'vits', 'vitb', 'vitl', 'vitg'
+            'features': 64,
+            'out_channels': [48, 96, 192, 384],
+        }
+
+        # Initialize the DA2 model
+        if da2_config is not None:
+            self.da2 = DepthAnythingV2(**da2_config)
+            self.da2.load_state_dict(torch.load(f'/root/Marigold/DA2/checkpoints/depth_anything_v2_{da2_config["encoder"]}.pth', map_location='cpu'))
+            self.da2.to(device="cuda").eval()
+        else:
+            self.da2 = None
 
     @torch.no_grad()
     def __call__(
         self,
         input_image: Union[Image.Image, torch.Tensor],
-        denoising_steps: Optional[int] = None,
-        ensemble_size: int = 5,
+        ensemble_size: int = 1,
         processing_res: Optional[int] = None,
         match_input_res: bool = True,
         resample_method: str = "bilinear",
@@ -166,10 +174,6 @@ class MarigoldPipeline(DiffusionPipeline):
         Args:
             input_image (`Image`):
                 Input RGB (or gray-scale) image.
-            denoising_steps (`int`, *optional*, defaults to `None`):
-                Number of denoising diffusion steps during inference. The default value `None` results in automatic
-                selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
-                for Marigold-LCM models.
             ensemble_size (`int`, *optional*, defaults to `10`):
                 Number of predictions to be ensembled.
             processing_res (`int`, *optional*, defaults to `None`):
@@ -209,9 +213,6 @@ class MarigoldPipeline(DiffusionPipeline):
 
         assert processing_res >= 0
 
-        # Check if denoising step is reasonable
-        # self._check_inference_step(denoising_steps)
-
         resample_method: InterpolationMode = get_tv_resample_method(resample_method)
 
         # ----------------- Image Preprocess -----------------
@@ -245,7 +246,7 @@ class MarigoldPipeline(DiffusionPipeline):
 
         # ----------------- Predicting depth -----------------
         # Batch repeated input image
-        duplicated_rgb = rgb_norm.expand(1, -1, -1, -1)
+        duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
         single_rgb_dataset = TensorDataset(duplicated_rgb)
         if batch_size > 0:
             _bs = batch_size
@@ -321,27 +322,6 @@ class MarigoldPipeline(DiffusionPipeline):
             uncertainty=pred_uncert,
         )
 
-    def _check_inference_step(self, n_step: int) -> None:
-        """
-        Check if denoising step is reasonable
-        Args:
-            n_step (`int`): denoising steps
-        """
-        assert n_step >= 1
-
-        if isinstance(self.scheduler, DDIMScheduler):
-            if n_step < 10:
-                logging.warning(
-                    f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
-                )
-        elif isinstance(self.scheduler, LCMScheduler):
-            if not 1 <= n_step <= 4:
-                logging.warning(
-                    f"Non-optimal setting of denoising steps: {n_step}. Recommended setting is 1-4 steps."
-                )
-        else:
-            raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
-
     def encode_empty_text(self):
         """
         Encode text embedding for empty prompt
@@ -357,36 +337,6 @@ class MarigoldPipeline(DiffusionPipeline):
         text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
         self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
 
-    @torch.no_grad()
-    def _get_highpass_mask(self, H, W, radius, device):
-        key = (H, W, radius, device)
-        if key not in self._fft_masks:
-            yy, xx = torch.meshgrid(torch.arange(H, device=device),
-                                    torch.arange(W, device=device),
-                                    indexing="ij")
-            yy = yy - H // 2
-            xx = xx - W // 2
-            mask_low = (xx**2 + yy**2 <= radius**2).float()
-            mask_high = 1 - mask_low
-            mask_high = mask_high[None, None, :, :]
-            self._fft_masks[key] = mask_high
-        return self._fft_masks[key]
-
-    @torch.no_grad()
-    def rgb_fft(self, x: torch.Tensor, highpass_radius: int = 30):
-        B, C, H, W = x.shape
-        device = x.device
-
-        f = torch.fft.fft2(x, norm="ortho")
-        fshift = torch.fft.fftshift(f)
-
-        mask_high = self._get_highpass_mask(H, W, highpass_radius, device)
-        fshift_high = fshift * mask_high
-
-        f_ishift = torch.fft.ifftshift(fshift_high)
-        img_high = torch.fft.ifft2(f_ishift, norm="ortho").real
-        return img_high
-
     @torch.no_grad()
     def single_infer(
         self,
@@ -408,12 +358,14 @@ class MarigoldPipeline(DiffusionPipeline):
             `torch.Tensor`: Predicted depth map.
         """
         device = self.device
+        # prepare data
         rgb_in = rgb_in.to(device)
-        rgb_fft = self.rgb_fft(rgb_in)
+        depth_da2 = self.da2.infer_batch(rgb_in).to(device)
 
-        # Encode image
-        rgb_latent = self.encode_rgb(rgb_in)
-        # rgb_fft_latent = self.encode_rgb(rgb_fft)
+        with torch.no_grad():
+            # Encode image
+            rgb_latent = self.encode_rgb(rgb_in)
+            depth_da2_latent = self.encode_rgb(depth_da2)
 
         # Batched empty text embedding
         if self.empty_text_embed is None:
@@ -422,10 +374,13 @@ class MarigoldPipeline(DiffusionPipeline):
             (rgb_latent.shape[0], 1, 1)
         ).to(device)  # [B, 2, 1024]
 
-        # unet_input = torch.cat([rgb_latent, rgb_fft_latent],dim=1)
+        # Build the UNet input
+        unet_input = torch.cat(
+            [depth_da2_latent, rgb_latent], dim=1
+        )  # this order is important
 
         depth_latent = self.unet(
-            rgb_latent, 1, encoder_hidden_states=batch_empty_text_embed
+            unet_input, 1, encoder_hidden_states=batch_empty_text_embed
        ).sample  # [B, 4, h, w]
 
         depth = self.decode_depth(depth_latent)
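This commit replaces the FFT high-pass conditioning branch with a Depth Anything V2 (DA2) prior: a DA2 depth prediction is VAE-encoded and concatenated with the RGB latent as the UNet input. Below is a minimal usage sketch of the pipeline after this change. It is not part of the commit: the checkpoint and image paths are placeholders, it assumes a UNet whose input layer accepts the 8-channel concatenated latent (4 DA2-depth-latent plus 4 RGB-latent channels) and DA2 weights at the path hard-coded in `__init__`, and the `depth_np` field follows the upstream Marigold `MarigoldDepthOutput`.

# Minimal usage sketch (not part of this commit); paths are placeholders.
from PIL import Image

from apdepth.marigold_pipeline import MarigoldPipeline

pipe = MarigoldPipeline.from_pretrained("path/to/apdepth-checkpoint")  # placeholder
pipe = pipe.to("cuda")

image = Image.open("example.jpg").convert("RGB")  # placeholder input

output = pipe(
    image,
    ensemble_size=1,       # new default introduced by this commit
    processing_res=None,   # None falls back to default_processing_resolution
    match_input_res=True,  # resize the prediction back to the input resolution
)
depth_map = output.depth_np  # assumed upstream field: depth values in [0, 1]

Note that after this change `single_infer` runs one deterministic UNet pass (timestep 1) on the concatenated latent rather than an iterative denoising loop, so all ensemble members appear to receive identical inputs; `ensemble_size=1` matches that single-pass design.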