yosepyossi committed on
Commit
03371fb
·
verified ·
1 Parent(s): 160a8a7

Upload pipeline_mvrag.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pipeline_mvrag.py +535 -0
pipeline_mvrag.py ADDED
@@ -0,0 +1,535 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL
2
+ import einops
3
+ import inspect
4
+ from torchvision.transforms import v2
5
+ from typing import List, Optional
6
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor
7
+ from diffusers import AutoencoderKL, DiffusionPipeline
8
+ from diffusers.utils import (
9
+ deprecate,
10
+ is_accelerate_available,
11
+ is_accelerate_version,
12
+ logging,
13
+ )
14
+ from diffusers.utils.torch_utils import randn_tensor
15
+ from diffusers.configuration_utils import FrozenDict
16
+ from diffusers.schedulers import DDIMScheduler
17
+
18
+ from modules.mv_unet import MultiViewUNetModel
19
+ from utils import *
20
+ from modules.f_score import SimilarityModel
21
+ from modules.resampler import Resampler
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
class MVRAGPipeline(DiffusionPipeline):
    """Multi-view retrieval-augmented generation (MV-RAG) diffusion pipeline.

    Combines a multi-view UNet with the usual Stable-Diffusion components
    (VAE, CLIP text encoder/tokenizer, DDIM scheduler), a CLIP image encoder
    plus resampler that turn retrieved reference images into conditioning
    tokens, and a similarity model used to score generated views against the
    retrieved references.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        unet: MultiViewUNetModel,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        scheduler: DDIMScheduler,
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModel,
        resampler: Resampler,
        requires_safety_checker: bool = False,
    ):
        super().__init__()

        # Legacy-config patch #1: old scheduler configs shipped with
        # `steps_offset != 1`; warn via `deprecate` and rewrite the frozen
        # config in place (standard diffusers boilerplate).
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:  # type: ignore
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "  # type: ignore
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate(
                "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
            )
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        # Legacy-config patch #2: force `clip_sample` to False, again
        # rewriting the frozen scheduler config after a deprecation warning.
        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:  # type: ignore
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate(
                "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
            )
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)
        # Kept outside register_modules: the similarity model is moved to the
        # target device explicitly in `to()` rather than managed by diffusers.
        self.similarity_model = SimilarityModel()
        self.register_modules(
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
            resampler=resampler,
        )
        # Spatial downsampling factor between pixel space and VAE latent space.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

        # Preprocessing for retrieved reference images: pad to a square canvas,
        # then resize so the short side is 256 (pad_to_square comes from utils).
        self.transform = v2.Compose([
            pad_to_square,
            v2.Resize(size=256),
        ])
89
+
90
+ def to(self, device, **kwargs):
91
+ super().to(device, **kwargs)
92
+ if self.similarity_model is not None:
93
+ self.similarity_model.to(device)
94
+ return self
95
+
96
+
97
+ def enable_vae_slicing(self):
98
+ r"""
99
+ Enable sliced VAE decoding.
100
+
101
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
102
+ steps. This is useful to save some memory and allow larger batch sizes.
103
+ """
104
+ self.vae.enable_slicing()
105
+
106
+ def disable_vae_slicing(self):
107
+ r"""
108
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
109
+ computing decoding in one step.
110
+ """
111
+ self.vae.disable_slicing()
112
+
113
+ def enable_vae_tiling(self):
114
+ r"""
115
+ Enable tiled VAE decoding.
116
+
117
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
118
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
119
+ """
120
+ self.vae.enable_tiling()
121
+
122
+ def disable_vae_tiling(self):
123
+ r"""
124
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
125
+ computing decoding in one step.
126
+ """
127
+ self.vae.disable_tiling()
128
+
129
+ def enable_sequential_cpu_offload(self, gpu_id=0):
130
+ r"""
131
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
132
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
133
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
134
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
135
+ `enable_model_cpu_offload`, but performance is lower.
136
+ """
137
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
138
+ from accelerate import cpu_offload
139
+ else:
140
+ raise ImportError(
141
+ "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher"
142
+ )
143
+
144
+ device = torch.device(f"cuda:{gpu_id}")
145
+
146
+ if self.device.type != "cpu":
147
+ self.to("cpu", silence_dtype_warnings=True)
148
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
149
+
150
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
151
+ cpu_offload(cpu_offloaded_model, device)
152
+
153
+ def enable_model_cpu_offload(self, gpu_id=0):
154
+ r"""
155
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
156
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
157
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
158
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
159
+ """
160
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
161
+ from accelerate import cpu_offload_with_hook
162
+ else:
163
+ raise ImportError(
164
+ "`enable_model_offload` requires `accelerate v0.17.0` or higher."
165
+ )
166
+
167
+ device = torch.device(f"cuda:{gpu_id}")
168
+
169
+ if self.device.type != "cpu":
170
+ self.to("cpu", silence_dtype_warnings=True)
171
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
172
+
173
+ hook = None
174
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
175
+ _, hook = cpu_offload_with_hook(
176
+ cpu_offloaded_model, device, prev_module_hook=hook
177
+ )
178
+
179
+ # We'll offload the last model manually.
180
+ self.final_offload_hook = hook
181
+
182
+ @property
183
+ def _execution_device(self):
184
+ r"""
185
+ Returns the device on which the pipeline's models will be executed. After calling
186
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
187
+ hooks.
188
+ """
189
+ if not hasattr(self.unet, "_hf_hook"):
190
+ return self.device
191
+ for module in self.unet.modules():
192
+ if (
193
+ hasattr(module, "_hf_hook")
194
+ and hasattr(module._hf_hook, "execution_device")
195
+ and module._hf_hook.execution_device is not None
196
+ ):
197
+ return torch.device(module._hf_hook.execution_device)
198
+ return self.device
199
+
200
    def _encode_prompt(
        self,
        prompt,
        device,
        do_classifier_free_guidance: bool,
        negative_prompt=None,
    ):
        r"""
        Encode the prompt into text-encoder hidden states.

        Every prompt gets a ", 3d asset" suffix appended (a trailing period is
        stripped first). When classifier-free guidance is enabled, the
        unconditional embeddings are concatenated in front of the conditional
        ones along the batch dimension so the UNet can run both in one pass.

        Args:
            prompt (`str` or `List[str]`):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation.
                Ignored when not using guidance (i.e., ignored if
                `guidance_scale` is less than `1`).

        Returns:
            `torch.Tensor`: prompt embeddings; batch is `[uncond; cond]` when
            `do_classifier_free_guidance` is True, otherwise just `cond`.
        """
        # Normalize to a list of prompts with the ", 3d asset" suffix and
        # record the batch size.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
            if prompt.endswith('.'):
                prompt = prompt[:-1]
            prompt = [prompt + ", 3d asset"]

        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
            prompt = [p[:-1] + ", 3d asset" if p.endswith(".") else p + ", 3d asset" for p in prompt]
        else:
            raise ValueError(
                f"`prompt` should be either a string or a list of strings, but got {type(prompt)}."
            )

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Re-tokenize without truncation to detect (and warn about) any text
        # that was cut off at the CLIP context limit.
        untruncated_ids = self.tokenizer(
            prompt, padding="longest", return_tensors="pt"
        ).input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # Some text encoders expect an explicit attention mask.
        if (
            hasattr(self.text_encoder.config, "use_attention_mask")
            and self.text_encoder.config.use_attention_mask
        ):
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        # First element of the encoder output is the last hidden state.
        prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt] * batch_size
            else:
                uncond_tokens = negative_prompt

            # Pad the negative prompt to the same sequence length as the
            # conditional embeddings so the two halves can be concatenated.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if (
                hasattr(self.text_encoder.config, "use_attention_mask")
                and self.text_encoder.config.use_attention_mask
            ):
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )[0]
            negative_prompt_embeds = negative_prompt_embeds.to(
                dtype=self.text_encoder.dtype, device=device
            )
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds
315
+
316
+ def decode_latents(self, latents):
317
+ latents = 1 / self.vae.config.scaling_factor * latents
318
+ image = self.vae.decode(latents).sample
319
+ image = (image / 2 + 0.5).clamp(0, 1)
320
+ return image
321
+
322
+ def prepare_extra_step_kwargs(self, generator, eta):
323
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
324
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
325
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
326
+ # and should be between [0, 1]
327
+
328
+ accepts_eta = "eta" in set(
329
+ inspect.signature(self.scheduler.step).parameters.keys()
330
+ )
331
+ extra_step_kwargs = {}
332
+ if accepts_eta:
333
+ extra_step_kwargs["eta"] = eta
334
+
335
+ # check if the scheduler accepts generator
336
+ accepts_generator = "generator" in set(
337
+ inspect.signature(self.scheduler.step).parameters.keys()
338
+ )
339
+ if accepts_generator:
340
+ extra_step_kwargs["generator"] = generator
341
+ return extra_step_kwargs
342
+
343
+ def prepare_latents(
344
+ self,
345
+ batch_size,
346
+ num_channels_latents,
347
+ height,
348
+ width,
349
+ dtype,
350
+ device,
351
+ generator,
352
+ latents=None,
353
+ ):
354
+ shape = (
355
+ batch_size,
356
+ num_channels_latents,
357
+ height // self.vae_scale_factor,
358
+ width // self.vae_scale_factor,
359
+ )
360
+
361
+ if latents is None:
362
+ latents = randn_tensor(
363
+ shape, generator=generator, device=device, dtype=dtype
364
+ )
365
+ else:
366
+ latents = latents.to(device)
367
+
368
+ # scale the initial noise by the standard deviation required by the scheduler
369
+ latents = latents * self.scheduler.init_noise_sigma
370
+ return latents
371
+
372
+ def encode_images(self, images, device):
373
+ dtype = next(self.image_encoder.parameters()).dtype
374
+
375
+ ret_images = [self.transform(ret) for ret in images]
376
+ images_proc = self.feature_extractor(ret_images, return_tensors="pt").pixel_values
377
+ images_proc = images_proc.to(device=device, dtype=dtype)
378
+ clip_images = self.image_encoder(images_proc, output_hidden_states=True).hidden_states[-2]
379
+
380
+ neg_images = torch.zeros_like(images_proc, device=device)
381
+ clip_images_neg = self.image_encoder(neg_images, output_hidden_states=True).hidden_states[-2]
382
+
383
+ image_embeds = torch.cat([clip_images_neg, clip_images], dim=0)
384
+ image_tokens = self.resampler(image_embeds)
385
+ return image_tokens
386
+
387
+
388
+ @torch.no_grad()
389
+ def __call__(
390
+ self,
391
+ prompt: str = "",
392
+ images = None,
393
+ height: int = 256,
394
+ width: int = 256,
395
+ elevation: int = 0,
396
+ azimuth_start:int = 0,
397
+ num_inference_steps: int = 50,
398
+ num_initial_steps: int = 10,
399
+ guidance_scale: float = 7.0,
400
+ negative_prompt: str = "",
401
+ eta: float = 0.0,
402
+ output_type: Optional[str] = "numpy", # pil, numpy
403
+ num_frames: int = 4,
404
+ device=torch.device("cuda:0"),
405
+ seed: Optional[int] = None,
406
+ ):
407
+ self.unet = self.unet.to(device=device)
408
+ self.vae = self.vae.to(device=device)
409
+ self.text_encoder = self.text_encoder.to(device=device)
410
+
411
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
412
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
413
+ # corresponds to doing no classifier free guidance.
414
+ do_classifier_free_guidance = guidance_scale > 1.0
415
+ bs = len(prompt) if isinstance(prompt, list) else 1
416
+ bs = bs * 2 if do_classifier_free_guidance else bs
417
+ assert isinstance(images, list) and len(images) > 0 and isinstance(images[0], PIL.Image.Image)
418
+ self.image_encoder = self.image_encoder.to(device=device)
419
+ ret_eval_embs = self.similarity_model.get_embeddings(images)
420
+ image_tokens = self.encode_images(images, device)
421
+ image_tokens = einops.rearrange(image_tokens, "(b n) c f -> b (n c) f", b=bs)
422
+ image_tokens = torch.repeat_interleave(image_tokens, num_frames, dim=0)
423
+
424
+ prompt_embeds = self._encode_prompt(
425
+ prompt=prompt,
426
+ device=device,
427
+ do_classifier_free_guidance=do_classifier_free_guidance,
428
+ negative_prompt=negative_prompt,
429
+ )
430
+ generator = torch.Generator(device=device)
431
+ if seed is not None:
432
+ generator.manual_seed(seed)
433
+ # Prepare latent variables
434
+ latents: torch.Tensor = self.prepare_latents(
435
+ num_frames,
436
+ 4,
437
+ height,
438
+ width,
439
+ prompt_embeds.dtype,
440
+ device,
441
+ generator=generator,
442
+ )
443
+ # Get camera
444
+ camera = get_camera(num_frames, elevation=elevation, azimuth_start=azimuth_start).to(dtype=latents.dtype, device=device)
445
+
446
+ # Prepare extra step kwargs.
447
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
448
+
449
+ # initial forward pass for fusion coefficient
450
+ base_out = self._sample(camera=camera,
451
+ device=device,
452
+ extra_step_kwargs=extra_step_kwargs,
453
+ guidance_scale=guidance_scale,
454
+ latents=latents.clone(),
455
+ num_frames=num_frames,
456
+ num_inference_steps=num_initial_steps,
457
+ prompt_embeds=prompt_embeds,
458
+ )
459
+ embs = self.similarity_model.get_embeddings(base_out)
460
+ scale = max(0, self.similarity_model.get_similarity_score(embs, ret_eval_embs))
461
+
462
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
463
+ latents = self._sample(camera=camera,
464
+ device=device,
465
+ extra_step_kwargs=extra_step_kwargs,
466
+ guidance_scale=guidance_scale,
467
+ image_tokens=image_tokens,
468
+ latents=latents,
469
+ num_frames=num_frames,
470
+ num_inference_steps=num_inference_steps,
471
+ progress_bar=progress_bar,
472
+ prompt_embeds=prompt_embeds,
473
+ scale=scale,
474
+ )
475
+ # Post-processing
476
+ images = latents.cpu().permute(0, 2, 3, 1).float().numpy()
477
+ if output_type == "pil":
478
+ images = self.numpy_to_pil(images)
479
+ # Offload last model to CPU
480
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
481
+ self.final_offload_hook.offload()
482
+
483
+ return images
484
+
485
    def _sample(self, device,
                extra_step_kwargs, guidance_scale,
                latents,
                num_inference_steps, prompt_embeds,
                progress_bar=None,
                num_frames=4,
                camera=None,
                do_classifier_free_guidance=True,
                image_tokens=None,
                scale=1.0
                ):
        """Run the denoising loop and return decoded images.

        Args:
            device: device the scheduler timesteps are placed on.
            extra_step_kwargs: optional kwargs forwarded to `scheduler.step`.
            guidance_scale: CFG weight applied to the two noise predictions.
            latents: initial noise latents (one per frame).
            num_inference_steps: number of scheduler timesteps.
            prompt_embeds: text embeddings; expected `[uncond; cond]` along the
                batch dimension when CFG is enabled.
            progress_bar: optional progress bar, updated per completed step.
            num_frames: number of views denoised jointly.
            camera: camera embedding, duplicated across CFG halves.
            do_classifier_free_guidance: when True, latents/camera/timesteps
                are doubled and the prediction halves are recombined.
            image_tokens: optional image-conditioning tokens for the UNet.
            scale: fusion coefficient forwarded to the UNet.

        Returns:
            Decoded images in [0, 1] (via `decode_latents`), not raw latents.
        """
        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            multiplier = 2 if do_classifier_free_guidance else 1
            latent_model_input = torch.cat([latents] * multiplier)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            unet_inputs = {
                'x': latent_model_input,
                # one timestep entry per frame per CFG half, cast to latent dtype
                'timesteps': torch.tensor([t] * num_frames * multiplier, dtype=latent_model_input.dtype, device=device),
                # each prompt embedding is shared by all frames of its sample
                'context': torch.repeat_interleave(prompt_embeds, num_frames, dim=0),
                'num_frames': num_frames,
                'camera': torch.cat([camera] * multiplier),
                'images_tokens': image_tokens,
                'scale': scale
            }
            # predict the noise residual
            noise_pred = self.unet.forward(**unet_inputs)

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            # compute the previous noisy sample x_t -> x_t-1
            latents: torch.Tensor = self.scheduler.step(
                noise_pred, t, latents, **extra_step_kwargs, return_dict=False
            )[0]

            # Advance the bar on the last step or once past warmup, on each
            # completed scheduler-order group of steps.
            if (i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            )) and progress_bar is not None:
                progress_bar.update()
        # Decode to image space; the name `latents` now holds decoded images.
        latents = self.decode_latents(latents)
        return latents