import torch
import torch.nn as nn
import inspect
import os.path as osp
from typing import Union, Optional
from tqdm import tqdm
from omegaconf import OmegaConf
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
from diffusers.schedulers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    PNDMScheduler,
    LMSDiscreteScheduler,
)


def exists(v):
    return v is not None


class CFGDenoiser(nn.Module):
    """
    Classifier-free guidance denoiser. A wrapper around a Stable Diffusion
    model (specifically its UNet) that takes a noisy image and produces a
    denoised prediction using two guidance prompts instead of one. Originally
    the second prompt is just an empty string, but we use a non-empty
    negative prompt.
    """

    def __init__(self, model, device):
        super().__init__()
        # Pick the k-diffusion wrapper matching the model's parameterization
        # (eps-prediction vs. v-prediction).
        denoiser = CompVisDenoiser if model.parameterization == "eps" else CompVisVDenoiser
        self.model_wrap = denoiser(model, device=device)

    @property
    def inner_model(self):
        return self.model_wrap

    def forward(
        self,
        x,
        sigma,
        cond: dict,
        cond_scale: Union[float, list[float]],
    ):
        """
        Simplified k-diffusion denoiser for sketch colorization.
        Supports a single CFG scale (reference CFG / sketch CFG) or
        dual CFG with two scales.
        """
        if not isinstance(cond_scale, list):
            if cond_scale > 1.:
                repeats = 2
            else:
                # No guidance needed; a single conditional pass suffices.
                return self.inner_model(x, sigma, cond=cond)
        else:
            repeats = 3
        x_in = torch.cat([x] * repeats)
        sigma_in = torch.cat([sigma] * repeats)
        x_out = self.inner_model(x_in, sigma_in, cond=cond).chunk(repeats)
        if repeats == 2:
            x_cond, x_uncond = x_out[:]
            return x_uncond + (x_cond - x_uncond) * cond_scale
        else:
            # Dual CFG: average two CFG extrapolations that share the same
            # conditional prediction but use different unconditional branches.
            x_cond, x_uncond_0, x_uncond_1 = x_out[:]
            return (x_uncond_0 + (x_cond - x_uncond_0) * cond_scale[0] +
                    x_uncond_1 + (x_cond - x_uncond_1) * cond_scale[1]) * 0.5
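

# Standalone sanity-check sketch (not part of the original pipeline; all names
# here are illustrative). It restates the dual-CFG combination from
# CFGDenoiser.forward as a pure-tensor reference usable with dummy inputs:
# the result is the average of two CFG extrapolations toward the same
# conditional prediction.
def _dual_cfg_reference(x_cond, x_uncond_0, x_uncond_1, scales):
    cfg_0 = x_uncond_0 + (x_cond - x_uncond_0) * scales[0]  # CFG w.r.t. branch 0
    cfg_1 = x_uncond_1 + (x_cond - x_uncond_1) * scales[1]  # CFG w.r.t. branch 1
    return (cfg_0 + cfg_1) * 0.5


# Usage sketch for CFGDenoiser (hypothetical; `ldm_model` and `cond` are
# assumed to come from the surrounding pipeline, with `cond` already batched
# to match the 2x/3x repeats performed in forward()):
#
#   import k_diffusion.sampling
#   denoiser = CFGDenoiser(ldm_model, device="cuda")
#   sigmas = denoiser.model_wrap.get_sigmas(20)
#   x = torch.randn(1, 4, 64, 64, device="cuda") * sigmas[0]
#   sample = k_diffusion.sampling.sample_euler(
#       lambda x, sigma, **_: denoiser(x, sigma, cond=cond, cond_scale=7.0),
#       x, sigmas)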


scheduler_config_path = "configs/scheduler_cfgs"


class DiffuserDenoiser:
    # "dpm" and "dpm_sde" share the same diffusers class; the per-scheduler
    # YAML config is what differentiates their settings.
    scheduler_types = {
        "ddim": DDIMScheduler,
        "dpm": DPMSolverMultistepScheduler,
        "dpm_sde": DPMSolverMultistepScheduler,
        "pndm": PNDMScheduler,
        "lms": LMSDiscreteScheduler,
    }

    def __init__(self, scheduler_type, prediction_type, use_karras=False):
        scheduler_type = scheduler_type.replace("diffuser_", "")
        assert scheduler_type in self.scheduler_types, "Selected scheduler is not implemented"
        scheduler = self.scheduler_types[scheduler_type]
        scheduler_config = OmegaConf.load(
            osp.abspath(osp.join(scheduler_config_path, scheduler_type + ".yaml"))
        )
        # Only some schedulers support Karras sigmas; set the flag only when
        # the constructor accepts it.
        if "use_karras_sigmas" in inspect.signature(scheduler).parameters:
            scheduler_config.use_karras_sigmas = use_karras
        self.scheduler = scheduler(prediction_type=prediction_type, **scheduler_config)
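
    # Assumed config layout (illustrative, not taken from the source repo):
    # each YAML file under configs/scheduler_cfgs/ holds the constructor
    # kwargs for its scheduler. For example, a typical Stable Diffusion
    # `ddim.yaml` might contain:
    #
    #   num_train_timesteps: 1000
    #   beta_start: 0.00085
    #   beta_end: 0.012
    #   beta_schedule: scaled_linear
    #   clip_sample: false
    #   set_alpha_to_one: false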

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers
        # have the same signature. eta (η) is only used with the DDIMScheduler
        # and will be ignored by other schedulers. eta corresponds to η in the
        # DDIM paper (https://arxiv.org/abs/2010.02502) and should be in [0, 1].
        extra_step_kwargs = {}
        accepts_eta = "eta" in inspect.signature(self.scheduler.step).parameters
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
        # check if the scheduler accepts generator
        accepts_generator = "generator" in inspect.signature(self.scheduler.step).parameters
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs
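
    # Illustration (not in the original): the same introspection can be
    # verified directly; DDIMScheduler.step accepts `eta`, while
    # DPMSolverMultistepScheduler.step does not:
    #
    #   >>> "eta" in inspect.signature(DDIMScheduler.step).parameters
    #   True
    #   >>> "eta" in inspect.signature(DPMSolverMultistepScheduler.step).parameters
    #   False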

    def __call__(
        self,
        x,
        cond,
        cond_scale,
        unet,
        timesteps,
        generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None,
        eta: float = 0.0,
        device: str = "cuda",
    ):
        # `timesteps` is passed in as the number of inference steps; the
        # scheduler expands it into the actual timestep schedule.
        self.scheduler.set_timesteps(timesteps, device=device)
        timesteps = self.scheduler.timesteps
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Keep the initial noise around: it is reused to re-noise the
        # inpainting background at each step.
        x_start = x
        x = x * self.scheduler.init_noise_sigma

        inpaint_latents = cond.pop("inpaint_bg", None)
        if exists(inpaint_latents):
            mask = cond.get("mask", None)
            threshold = cond.pop("threshold", 0.5)
            inpaint_latents = inpaint_latents[0]
            assert exists(mask)
            mask = mask[0]
            # Binarize the mask: 1 where the model generates freely,
            # 0 where the background latents are kept.
            mask = torch.where(mask > threshold, torch.ones_like(mask), torch.zeros_like(mask))
        for i, t in enumerate(tqdm(timesteps)):
            x_t = self.scheduler.scale_model_input(x, t)
            # Mirror CFGDenoiser.forward: 1 pass without guidance, 2 for a
            # single CFG scale, 3 for dual CFG.
            if not isinstance(cond_scale, list):
                if cond_scale > 1.:
                    repeats = 2
                else:
                    repeats = 1
            else:
                repeats = 3
            x_in = torch.cat([x_t] * repeats)
            x_out = unet.apply_model(
                x_in,
                t[None].expand(x_in.shape[0]),
                cond=cond,
            )
            if repeats == 1:
                pred = x_out
            elif repeats == 2:
                x_cond, x_uncond = x_out.chunk(2)
                pred = x_uncond + (x_cond - x_uncond) * cond_scale
            else:
                x_cond, x_uncond_0, x_uncond_1 = x_out.chunk(3)
                pred = (x_uncond_0 + (x_cond - x_uncond_0) * cond_scale[0] +
                        x_uncond_1 + (x_cond - x_uncond_1) * cond_scale[1]) * 0.5
            x = self.scheduler.step(
                pred, t, x, **extra_step_kwargs, return_dict=False
            )[0]
            # Inpainting: re-noise the background latents to the next
            # timestep and paste them back outside the mask.
            if exists(inpaint_latents) and exists(mask) and i < len(timesteps) - 1:
                noise_timestep = timesteps[i + 1]
                init_latents_proper = self.scheduler.add_noise(
                    inpaint_latents, x_start, torch.tensor([noise_timestep])
                )
                x = (1 - mask) * init_latents_proper + mask * x
        return x
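

# Usage sketch for DiffuserDenoiser (hypothetical names; assumes an LDM-style
# `unet` exposing `.apply_model(x, t, cond=...)` as called above, and a `cond`
# dict already batched to match the CFG repeats):
#
#   denoiser = DiffuserDenoiser("diffuser_ddim", prediction_type="epsilon")
#   x = torch.randn(1, 4, 64, 64, device="cuda")
#   latents = denoiser(x, cond, cond_scale=7.0, unet=unet,
#                      timesteps=20, eta=0.0, device="cuda")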