| |
|
|
| import torch |
| from tqdm import tqdm |
| from typing import Tuple, List, Union, Optional |
| from diffusers.schedulers import DDIMScheduler |
|
|
|
|
# Public API of this module: only the DDIM sampler is exported.
__all__ = [
    "ddim_sample",
]
|
|
|
|
def ddim_sample(ddim_scheduler: DDIMScheduler,
                diffusion_model: torch.nn.Module,
                shape: Union[List[int], Tuple[int, ...]],
                cond: torch.FloatTensor,
                steps: int,
                eta: float = 0.0,
                guidance_scale: float = 3.0,
                do_classifier_free_guidance: bool = True,
                generator: Optional[torch.Generator] = None,
                device: Union[str, torch.device] = "cuda:0",
                disable_prog: bool = True):
    """Run DDIM sampling, yielding the latents after every denoising step.

    This is a generator: no work happens until it is iterated, and the final
    sample is the ``latents`` of the last yielded pair. It does not disable
    autograd itself; callers that do not need gradients should iterate it
    under ``torch.no_grad()``.

    Args:
        ddim_scheduler: Scheduler providing ``init_noise_sigma``,
            ``set_timesteps``, ``timesteps`` and ``step``. ``set_timesteps``
            is called on it, so its state is modified in place.
        diffusion_model: Module called as ``diffusion_model(x, t, cond)``;
            its output must have the same batch layout as ``x``.
        shape: Per-sample latent shape, without the batch dimension.
        cond: Conditioning passed unchanged to the model on every step. Its
            dtype and device are also used for the initial noise. With
            classifier-free guidance it must stack the unconditional half
            first and the conditional half second along dim 0.
        steps: Number of denoising steps; must be positive.
        eta: Forwarded as ``eta`` to ``ddim_scheduler.step``.
        guidance_scale: Classifier-free guidance weight.
        do_classifier_free_guidance: If True, run the model on a doubled
            batch and combine the unconditional/conditional predictions.
        generator: Optional RNG used for the initial noise and forwarded to
            ``ddim_scheduler.step``.
        device: Device the scheduler's timesteps are moved to; should
            normally match ``cond.device``.
        disable_prog: Hide the tqdm progress bar when True.

    Yields:
        ``(latents, t)`` after each scheduler step, where ``latents`` has
        shape ``(B, *shape)`` with ``B = cond.shape[0] // 2`` under
        classifier-free guidance and ``B = cond.shape[0]`` otherwise.

    Raises:
        AssertionError: If ``steps`` is not positive.
        ValueError: If classifier-free guidance is enabled and ``cond`` has
            an odd batch size.
    """
    assert steps > 0, f"steps must be > 0, got {steps}."

    batch_size = cond.shape[0]
    if do_classifier_free_guidance:
        # `cond` carries both guidance halves; an odd batch cannot be split
        # evenly and would otherwise be silently truncated by `// 2`.
        if batch_size % 2 != 0:
            raise ValueError(
                "With classifier-free guidance, `cond` must stack the "
                "unconditional and conditional halves along dim 0; got odd "
                f"batch size {batch_size}."
            )
        batch_size //= 2

    # Initial noise lives on cond's device/dtype so it matches the model inputs.
    latents = torch.randn(
        (batch_size, *shape),
        generator=generator,
        device=cond.device,
        dtype=cond.dtype,
    )
    latents = latents * ddim_scheduler.init_noise_sigma

    ddim_scheduler.set_timesteps(steps)
    timesteps = ddim_scheduler.timesteps.to(device)

    extra_step_kwargs = {"eta": eta, "generator": generator}

    for t in tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False):
        # Under CFG the same latents are evaluated twice: once against each
        # half of `cond`.
        latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )

        # Broadcast the 0-d timestep to one entry per sample, on the same
        # device as the other model inputs and without a host round-trip.
        timestep_tensor = t.to(device=latent_model_input.device, dtype=torch.long)
        timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])

        # Call the module (not `.forward`) so registered hooks are honoured.
        noise_pred = diffusion_model(latent_model_input, timestep_tensor, cond)

        if do_classifier_free_guidance:
            # First half pairs with the unconditional `cond`, second with the
            # conditional one.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        latents = ddim_scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs
        ).prev_sample

        yield latents, t
|
|
|
|
def karra_sample():
    """Unimplemented sampler stub.

    The name suggests a Karras-style sampler was intended (TODO confirm).
    Until it is written, calling it fails loudly instead of silently
    returning ``None``.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "karra_sample is not implemented yet; use ddim_sample instead."
    )
|
|