degbo commited on
Commit
7d57572
·
1 Parent(s): 2c5006e
Files changed (1) hide show
  1. marigold/marigold_iid_pipeline.py +777 -0
marigold/marigold_iid_pipeline.py ADDED
@@ -0,0 +1,777 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # More information about Marigold:
16
+ # https://marigoldmonodepth.github.io
17
+ # https://marigoldcomputervision.github.io
18
+ # Efficient inference pipelines are now part of diffusers:
19
+ # https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
20
+ # https://huggingface.co/docs/diffusers/api/pipelines/marigold
21
+ # Examples of trained models and live demos:
22
+ # https://huggingface.co/prs-eth
23
+ # Related projects:
24
+ # https://rollingdepth.github.io/
25
+ # https://marigolddepthcompletion.github.io/
26
+ # Citation (BibTeX):
27
+ # https://github.com/prs-eth/Marigold#-citation
28
+ # If you find Marigold useful, we kindly ask you to cite our papers.
29
+ # --------------------------------------------------------------------------
30
+
31
+ import logging
32
+ import numpy as np
33
+ import torch
34
+ from PIL import Image
35
+ from diffusers import (
36
+ AutoencoderKL,
37
+ DDIMScheduler,
38
+ DiffusionPipeline,
39
+ LCMScheduler,
40
+ UNet2DConditionModel,
41
+ )
42
+ from torch.utils.data import DataLoader, TensorDataset
43
+ from torchvision.transforms import InterpolationMode
44
+ from torchvision.transforms.functional import pil_to_tensor, resize
45
+ from tqdm.auto import tqdm
46
+ from transformers import CLIPTextModel, CLIPTokenizer
47
+ from typing import Any, Dict, Optional, Union, List
48
+ from dataclasses import dataclass
49
+
50
+ from .util.batchsize import find_batch_size
51
+ from .util.ensemble import ensemble_iid
52
+ from .util.image_util import (
53
+ chw2hwc,
54
+ get_tv_resample_method,
55
+ resize_max_res,
56
+ )
57
+
58
+ from diffusers.loaders import (
59
+ LoraLoaderMixin,
60
+ TextualInversionLoaderMixin,
61
+ )
62
+
63
@dataclass
class IIDEntry:
    """
    Container for one decomposed component (modality) of an IID prediction.

    Only `name` is required at construction time; the remaining fields start as `None` and are
    populated later by `MarigoldIIDOutput.fill_entry`.

    Attributes:
        name (`str`):
            The name of the entry.
        array (`None` or `np.ndarray`):
            Predicted numpy array with the shape of [3, H, W] and values in the range of [0, 1].
        image (`None` or `PIL.Image.Image`):
            Predicted image with the shape of [H, W, 3] and values in [0, 255].
        uncertainty (`None` or `np.ndarray`):
            Uncalibrated uncertainty from ensembling; stays `None` when no ensembling took place.
    """

    # Identifier of the predicted component.
    name: str
    # Raw prediction as a numpy array.
    array: Optional[np.ndarray] = None
    # Visualization of the prediction.
    image: Optional[Image.Image] = None
    # Ensembling uncertainty, if any.
    uncertainty: Optional[np.ndarray] = None
82
+
83
+
84
class MarigoldIIDOutput:
    """
    Output class for Marigold Intrinsic Image Decomposition pipelines.

    Holds one `IIDEntry` per target name. Entries are created empty and must each be filled
    exactly once through `fill_entry`. Entries can be accessed by name (`output["albedo"]`)
    or by iterating over the output in the order of `target_names`.
    """

    def __init__(self, target_names: List[str]):
        """Initialize output container with target names.

        Args:
            target_names: List of names for each target component
        """
        self.n_targets = len(target_names)
        self.target_names = target_names
        self.entries: List[IIDEntry] = [IIDEntry(name=name) for name in target_names]
        self._entry_map = {entry.name: entry for entry in self.entries}
        self._filled_entries = set()

    def fill_entry(
        self,
        name: str,
        prediction: torch.Tensor,
        uncertainty: Optional[torch.Tensor] = None,
        target_properties: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Fill a single entry with prediction data.

        Args:
            name: Name of the entry to fill
            prediction: Tensor containing the prediction for this entry
            uncertainty: Optional tensor containing uncertainty values
            target_properties: Properties of the predicted targets, keyed by target name. The keys
                read here are `prediction_space` (`"srgb"`, `"linear"` or `"stack"`; defaults to
                `"srgb"`) and `up_to_scale` (`bool`, only used for `"linear"`). When the mapping or
                the entry for `name` is missing, the `"srgb"` defaults are used.

        Raises:
            KeyError: If `name` is not one of the target names of this output.
            RuntimeError: If the entry has already been filled.
        """
        if name not in self._entry_map:
            raise KeyError(f"Unknown entry name: {name}")
        if name in self._filled_entries:
            raise RuntimeError(f"Entry {name} already filled")

        entry = self._entry_map[name]

        # Process prediction; the raw array is stored untouched
        array = prediction.squeeze().cpu().numpy()
        img_array = array

        # `target_properties` is optional: fall back to default ("srgb") handling instead of
        # failing on `None[...]` or on a target without per-target properties.
        properties = (target_properties or {}).get(name) or {}

        # Prepare image visualization
        prediction_space = properties.get("prediction_space", "srgb")
        if prediction_space == "linear":
            if properties.get("up_to_scale", False):
                img_array = img_array / max(img_array.max(), 1e-6)
            # Clip before the gamma curve: a fractional power of a negative value is NaN.
            img_array = np.clip(img_array, 0.0, 1.0) ** (1 / 2.2)
        # "stack" and "srgb" predictions are visualized as-is.

        # Create image. Values slightly outside [0, 1] (e.g., interpolation overshoot after
        # resizing) would wrap around in the uint8 cast, so clip them first.
        img_array = (np.clip(img_array, 0.0, 1.0) * 255).astype(np.uint8)
        img_array = chw2hwc(img_array)  # Convert from CHW to HWC format
        image = Image.fromarray(img_array)

        # Process uncertainty if available
        uncert_array = (
            uncertainty.squeeze().cpu().numpy() if uncertainty is not None else None
        )

        # Update entry
        entry.array = array
        entry.image = image
        entry.uncertainty = uncert_array

        self._filled_entries.add(name)

    @property
    def is_complete(self) -> bool:
        """Check if all entries have been filled."""
        return len(self._filled_entries) == self.n_targets

    def __getitem__(self, key: str) -> IIDEntry:
        """Get an entry by name."""
        return self._entry_map[key]

    def __iter__(self):
        """Iterate over entries."""
        return iter(self.entries)
166
+
167
+
168
class MarigoldIIDPipeline(DiffusionPipeline):
    """
    Pipeline for Marigold Intrinsic Image Decomposition (IID): https://marigoldcomputervision.github.io.
    This class supports arbitrary number of target modalities with names set in `target_names`.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        unet (`UNet2DConditionModel`):
            Conditional U-Net to denoise the prediction latent, conditioned on image latent.
        vae (`AutoencoderKL`):
            Variational Auto-Encoder (VAE) Model to encode and decode images and predictions
            to and from latent representations.
        scheduler (`DDIMScheduler`):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        text_encoder (`CLIPTextModel`):
            Text-encoder, for empty text embedding.
        tokenizer (`CLIPTokenizer`):
            CLIP tokenizer.
        target_properties (`Dict[str, Any]`):
            Properties of the predicted modalities, such as `target_names`, a `List[str]` used to define the number,
            order and names of the predicted modalities, and any other metadata that may be required to interpret the
            predictions. Although the argument has a `None` default, the `target_names` key is read at construction
            time, so a value must be provided.
        default_denoising_steps (`int`, *optional*):
            The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
            quality with the given model. This value must be set in the model config. When the pipeline is called
            without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
            reasonable results with various model flavors compatible with the pipeline, such as those relying on very
            short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
        default_processing_resolution (`int`, *optional*):
            The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
            the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
            default value is used. This is required to ensure reasonable results with various model flavors trained
            with varying optimal processing resolution values.
        model_mode (`str`, *optional*, defaults to `"rgbx"`):
            Selects how the U-Net is driven. In `"rgbx"` mode the U-Net input is the concatenation
            `[target_latent, rgb_latent]` and the text conditioning is the embedding of `model_prompt`. In any other
            mode the input is `[rgb_latent, target_latent]` and the empty-text embedding is used.
        model_prompt (`str`, *optional*, defaults to `"Albedo (diffuse basecolor)"`):
            Text prompt encoded as conditioning in `"rgbx"` mode.
    """

    latent_scale_factor = 0.18215

    def __init__(
        self,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        scheduler: Union[DDIMScheduler, LCMScheduler],
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        target_properties: Optional[Dict[str, Any]] = None,
        default_denoising_steps: Optional[int] = None,
        default_processing_resolution: Optional[int] = None,
        model_mode: str = "rgbx",
        model_prompt: str = "Albedo (diffuse basecolor)",
    ):
        # Fail early with an actionable message instead of `'NoneType' object is not subscriptable`.
        if target_properties is None:
            raise ValueError(
                "`target_properties` must be provided and contain the `target_names` key."
            )

        super().__init__()
        self.register_modules(
            unet=unet,
            vae=vae,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )

        self.register_to_config(
            target_properties=target_properties,
            default_denoising_steps=default_denoising_steps,
            default_processing_resolution=default_processing_resolution,
        )

        self.target_properties = target_properties
        self.target_names = target_properties["target_names"]
        self.n_targets = len(self.target_names)
        self.mode = model_mode
        self.prompt = model_prompt

        self.default_denoising_steps = default_denoising_steps
        self.default_processing_resolution = default_processing_resolution

        # Lazily computed by `encode_empty_text` on first inference.
        self.empty_text_embed = None

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.

        Returns:
            `torch.FloatTensor`: The prompt embeddings. With classifier-free guidance enabled, the result is the
            concatenation `[prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]` along the batch axis.

        Raises:
            ValueError: If neither a `str`/`list` prompt nor `prompt_embeds` is given, or if the batch sizes of
                `prompt` and `negative_prompt` differ.
            TypeError: If `prompt` and `negative_prompt` have different types.
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        elif prompt_embeds is not None:
            batch_size = prompt_embeds.shape[0]
        else:
            # Previously surfaced as an AttributeError on `None.shape`.
            raise ValueError(
                "Either `prompt` (`str` or `List[str]`) or `prompt_embeds` must be provided."
            )

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(
                prompt, padding="longest", return_tensors="pt"
            ).input_ids

            # Warn the user when the tokenizer had to drop part of the prompt.
            if untruncated_ids.shape[-1] >= text_input_ids.shape[
                -1
            ] and not torch.equal(text_input_ids, untruncated_ids):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logging.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if (
                hasattr(self.text_encoder.config, "use_attention_mask")
                and self.text_encoder.config.use_attention_mask
            ):
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(
            bs_embed * num_images_per_prompt, seq_len, -1
        )

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if (
                hasattr(self.text_encoder.config, "use_attention_mask")
                and self.text_encoder.config.use_attention_mask
            ):
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(
                dtype=self.text_encoder.dtype, device=device
            )

            negative_prompt_embeds = negative_prompt_embeds.repeat(
                1, num_images_per_prompt, 1
            )
            negative_prompt_embeds = negative_prompt_embeds.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered
            # [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
            prompt_embeds = torch.cat(
                [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
            )

        return prompt_embeds

    @torch.no_grad()
    def __call__(
        self,
        input_image: Union[Image.Image, torch.Tensor],
        denoising_steps: Optional[int] = None,
        ensemble_size: int = 1,
        processing_res: Optional[int] = None,
        match_input_res: bool = True,
        resample_method: str = "bilinear",
        batch_size: int = 0,
        generator: Union[torch.Generator, None] = None,
        show_progress_bar: bool = True,
        ensemble_kwargs: Optional[Dict] = None,
    ) -> MarigoldIIDOutput:
        """
        Function invoked when calling the pipeline.

        Args:
            input_image (`Image`):
                Input RGB (or gray-scale) image.
            denoising_steps (`int`, *optional*, defaults to `None`):
                Number of denoising diffusion steps during inference. The default value `None` results in automatic
                selection.
            ensemble_size (`int`, *optional*, defaults to `1`):
                Number of predictions to be ensembled.
            processing_res (`int`, *optional*, defaults to `None`):
                Effective processing resolution. When set to `0`, processes at the original image resolution. This
                produces crisper predictions, but may also lead to the overall loss of global context. The default
                value `None` resolves to the optimal value from the model config.
            match_input_res (`bool`, *optional*, defaults to `True`):
                Resize the prediction (and the ensembling uncertainty, if any) to match the input resolution.
                Only valid if `processing_res` > 0.
            resample_method: (`str`, *optional*, defaults to `bilinear`):
                Resampling method used to resize images and predictions. This can be one of `bilinear`, `bicubic` or
                `nearest`, defaults to: `bilinear`.
            batch_size (`int`, *optional*, defaults to `0`):
                Inference batch size, no bigger than `num_ensemble`.
                If set to 0, the script will automatically decide the proper batch size.
            generator (`torch.Generator`, *optional*, defaults to `None`)
                Random generator for initial noise generation.
            show_progress_bar (`bool`, *optional*, defaults to `True`):
                Display a progress bar of diffusion denoising.
            ensemble_kwargs (`dict`, *optional*, defaults to `None`):
                Arguments for detailed ensembling settings.
        Returns:
            `MarigoldIIDOutput`: Output class for Marigold Intrinsic Image Decomposition prediction pipeline.

        Raises:
            ValueError: If `denoising_steps` or `processing_res` is `None` and the pipeline has no default for it.
            TypeError: If `input_image` is neither a PIL image nor a torch tensor.
        """
        # Model-specific optimal default values leading to fast and reasonable results.
        if denoising_steps is None:
            denoising_steps = self.default_denoising_steps
        if processing_res is None:
            processing_res = self.default_processing_resolution
        # Without this check a missing default surfaces as an opaque `None >= 0` TypeError.
        if denoising_steps is None or processing_res is None:
            raise ValueError(
                "`denoising_steps` and `processing_res` must be passed explicitly when the pipeline was "
                "created without `default_denoising_steps` / `default_processing_resolution`."
            )
        assert processing_res >= 0
        assert ensemble_size >= 1

        # Check if denoising step is reasonable
        self._check_inference_step(denoising_steps)

        resample_method: InterpolationMode = get_tv_resample_method(resample_method)

        # ----------------- Image Preprocess -----------------
        # Convert to torch tensor
        if isinstance(input_image, Image.Image):
            input_image = input_image.convert("RGB")
            # convert to torch tensor [H, W, rgb] -> [rgb, H, W]
            rgb = pil_to_tensor(input_image)
            rgb = rgb.unsqueeze(0)  # [1, rgb, H, W]
        elif isinstance(input_image, torch.Tensor):
            rgb = input_image
        else:
            raise TypeError(f"Unknown input type: {type(input_image) = }")
        input_size = rgb.shape
        assert (
            4 == rgb.dim() and 3 == input_size[-3]
        ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"

        # Resize image
        if processing_res > 0:
            rgb = resize_max_res(
                rgb,
                max_edge_resolution=processing_res,
                resample_method=resample_method,
            )

        # Normalize rgb values
        rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
        rgb_norm = rgb_norm.to(self.dtype)
        assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0

        # ----------------- Predicting IID -----------------
        # Batch repeated input image
        duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
        single_rgb_dataset = TensorDataset(duplicated_rgb)
        if batch_size > 0:
            _bs = batch_size
        else:
            _bs = find_batch_size(
                ensemble_size=ensemble_size,
                input_res=max(rgb_norm.shape[1:]),
                dtype=self.dtype,
            )

        single_rgb_loader = DataLoader(
            single_rgb_dataset, batch_size=_bs, shuffle=False
        )

        # Predict IID maps (batched)
        target_pred_ls = []
        if show_progress_bar:
            iterable = tqdm(
                single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
            )
        else:
            iterable = single_rgb_loader
        for batch in iterable:
            (batched_img,) = batch
            target_pred_raw = self.single_infer(
                rgb_in=batched_img,
                num_inference_steps=denoising_steps,
                show_pbar=show_progress_bar,
                generator=generator,
            )
            assert (
                target_pred_raw.dim() == 4
                and target_pred_raw.shape[1] == 3 * self.n_targets
            )
            target_pred_ls.append(target_pred_raw.detach())
        target_preds = torch.concat(target_pred_ls, dim=0)
        torch.cuda.empty_cache()  # clear vram cache for ensembling

        # ----------------- Test-time ensembling -----------------
        if ensemble_size > 1:
            final_pred, pred_uncert = ensemble_iid(
                target_preds,
                **(ensemble_kwargs or {}),
            )
        else:
            final_pred = target_preds
            pred_uncert = None

        # Resize back to original resolution
        if match_input_res:
            final_pred = resize(
                final_pred,
                input_size[-2:],
                interpolation=resample_method,
                antialias=True,
            )
            if pred_uncert is not None:
                # Keep the uncertainty map aligned with the resized prediction.
                # NOTE(review): assumes `ensemble_iid` returns the uncertainty as a [..., H, W]
                # tensor at the same resolution as the prediction - confirm against the helper.
                pred_uncert = resize(
                    pred_uncert,
                    input_size[-2:],
                    interpolation=resample_method,
                    antialias=True,
                )

        # Create output
        output = MarigoldIIDOutput(target_names=self.target_names)
        self.fill_outputs(output, final_pred, pred_uncert)
        assert output.is_complete
        return output

    def fill_outputs(
        self,
        output: MarigoldIIDOutput,
        final_pred: torch.Tensor,
        pred_uncert: Optional[torch.Tensor] = None,
    ):
        """
        Split the stacked prediction into per-target entries and store them in `output`.

        Target `i` (in the order of `self.target_names`) occupies channels `[3*i, 3*i + 3)` along dimension 1 of
        `final_pred`; `pred_uncert`, when given, is sliced with the same channel ranges.

        Args:
            output (`MarigoldIIDOutput`):
                Output container to be filled in place.
            final_pred (`torch.Tensor`):
                Stacked prediction of all targets.
            pred_uncert (`torch.Tensor`, *optional*):
                Stacked uncertainty of all targets.
        """
        for i, name in enumerate(self.target_names):
            channels = slice(3 * i, 3 * i + 3)
            output.fill_entry(
                name=name,
                prediction=final_pred[:, channels],
                uncertainty=(
                    pred_uncert[:, channels] if pred_uncert is not None else None
                ),
                target_properties=self.target_properties,
            )

    def _check_inference_step(self, n_step: int) -> None:
        """
        Check if denoising step is reasonable

        Logs warnings for discouraged `DDIMScheduler` settings and rejects unsupported schedulers.

        Args:
            n_step (`int`): denoising steps

        Raises:
            RuntimeError: If the scheduler is an `LCMScheduler` or any other unsupported type.
        """
        assert n_step >= 1

        if isinstance(self.scheduler, DDIMScheduler):
            if "trailing" != self.scheduler.config.timestep_spacing:
                logging.warning(
                    f"The loaded `DDIMScheduler` is configured with `timestep_spacing="
                    f'"{self.scheduler.config.timestep_spacing}"`; the recommended setting is `"trailing"`. '
                    f"This change is backward-compatible and yields better results. "
                    f"Consider using `prs-eth/marigold-iid-appearance-v1-1` or `prs-eth/marigold-iid-lighting-v1-1` "
                    f"for the best experience."
                )
            else:
                if n_step > 10:
                    logging.warning(
                        f"Setting too many denoising steps ({n_step}) may degrade the prediction; consider relying on "
                        f"the default values."
                    )
            if not self.scheduler.config.rescale_betas_zero_snr:
                logging.warning(
                    f"The loaded `DDIMScheduler` is configured with `rescale_betas_zero_snr="
                    f"{self.scheduler.config.rescale_betas_zero_snr}`; the recommended setting is True. "
                    f"Consider using `prs-eth/marigold-iid-appearance-v1-1` or `prs-eth/marigold-iid-lighting-v1-1` "
                    f"for the best experience."
                )
        elif isinstance(self.scheduler, LCMScheduler):
            raise RuntimeError(
                "This pipeline implementation does not support the LCMScheduler. Please refer to the project "
                "README.md for instructions about using LCM."
            )
        else:
            raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")

    def encode_empty_text(self):
        """
        Encode text embedding for empty prompt

        The result is cached in `self.empty_text_embed`.
        """
        prompt = ""
        text_inputs = self.tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
        self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)

    @torch.no_grad()
    def single_infer(
        self,
        rgb_in: torch.Tensor,
        num_inference_steps: int,
        generator: Union[torch.Generator, None],
        show_pbar: bool,
    ) -> torch.Tensor:
        """
        Perform a single prediction without ensembling.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image.
            num_inference_steps (`int`):
                Number of diffusion denoising steps (DDIM) during inference.
            show_pbar (`bool`):
                Display a progress bar of diffusion denoising.
            generator (`torch.Generator`)
                Random generator for initial noise generation.
        Returns:
            `torch.Tensor`: Predicted targets of shape (B,3*n_targets,H,W).
        """
        device = self.device
        rgb_in = rgb_in.to(device)

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps  # [T]

        # Encode image
        rgb_latent = self.encode_rgb(rgb_in)  # [B, 4, h, w]

        target_latent_shape = list(rgb_latent.shape)
        target_latent_shape[1] *= self.n_targets

        # Noisy latent for outputs
        target_latent = torch.randn(
            target_latent_shape, device=device, dtype=self.dtype, generator=generator
        )  # [B, 4*n_targets, h, w]

        # Empty text embedding is encoded once and cached on the pipeline
        if self.empty_text_embed is None:
            self.encode_empty_text()

        # Loop-invariant: decides both the text conditioning and the U-Net input order.
        is_rgbx = self.mode == "rgbx"

        # Batched text embedding
        if is_rgbx:
            prompt_embeds = self._encode_prompt(
                self.prompt,
                device,
                num_images_per_prompt=1,
                do_classifier_free_guidance=False,
                negative_prompt=None,
                prompt_embeds=None,
                negative_prompt_embeds=None,
            )
            batch_empty_text_embed = prompt_embeds.repeat(
                (rgb_latent.shape[0], 1, 1)
            ).to(device)
        else:
            batch_empty_text_embed = self.empty_text_embed.repeat(
                (rgb_latent.shape[0], 1, 1)
            ).to(device)  # [B, 2, 1024]

        # Denoising loop
        if show_pbar:
            iterable = tqdm(
                timesteps,
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising",
            )
        else:
            iterable = timesteps

        for t in iterable:
            # this order is important
            if is_rgbx:
                unet_input = torch.cat([target_latent, rgb_latent], dim=1)
            else:
                unet_input = torch.cat([rgb_latent, target_latent], dim=1)

            # predict the noise residual
            noise_pred = self.unet(
                unet_input, t, encoder_hidden_states=batch_empty_text_embed
            ).sample  # [B, 4*n_targets, h, w]

            # compute the previous noisy sample x_t -> x_t-1
            target_latent = self.scheduler.step(
                noise_pred, t, target_latent, generator=generator
            ).prev_sample

        targets = self.decode_targets(target_latent)  # [B,3*n_targets,H,W]

        # clip prediction
        targets = torch.clip(targets, -1.0, 1.0)
        # shift to [0, 1]
        targets = (targets + 1.0) / 2.0

        return targets

    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
        """
        Encode RGB image into latent.

        Only the first half of the encoder moments (the mean) is used; it is scaled by
        `latent_scale_factor`.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image to be encoded.

        Returns:
            `torch.Tensor`: Image latent.
        """
        # encode
        h = self.vae.encoder(rgb_in)
        moments = self.vae.quant_conv(h)
        mean, _ = torch.chunk(moments, 2, dim=1)
        # scale latent
        return mean * self.latent_scale_factor

    def decode_targets(self, target_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode target latents into image space.

        Each target owns a group of 4 consecutive latent channels, which is decoded independently
        by the VAE; the decoded targets are concatenated along the channel dimension.

        Args:
            target_latent: Target latent tensor of shape [B, 4*n_targets, h, w]

        Returns:
            Decoded target tensor of shape [B, 3*n_targets, H, W]
        """
        # undo the latent scaling applied at encoding time
        target_latent = target_latent / self.latent_scale_factor
        targets = [
            self.vae.decoder(
                self.vae.post_quant_conv(target_latent[:, i * 4 : (i + 1) * 4, :, :])
            )
            for i in range(self.n_targets)
        ]
        return torch.cat(targets, dim=1)