|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| from diffusers import LCMScheduler
|
| from dataclasses import dataclass
|
| from typing import List, Optional, Tuple, Union
|
| import numpy as np
|
| import torch
|
| from config import config
|
|
|
|
|
| from diffusers.schedulers.scheduling_lcm import BaseOutput, logging, randn_tensor
|
| from diffusers.schedulers.scheduling_lcm import register_to_config
|
|
|
| logger = logging.get_logger(__name__)
|
|
|
|
|
@dataclass
class BBDMSchedulerOutput(BaseOutput):
    """
    Result container returned by the scheduler's `step` function.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of the previous timestep. Feed `prev_sample` back in as the next model
            input in the denoising loop.
        predicted_original_sample (`torch.FloatTensor`):
            The estimate of the original (fully denoised) sample that `step` derived from the model output at
            the current timestep.
    """

    prev_sample: torch.FloatTensor
    predicted_original_sample: torch.FloatTensor
|
|
|
|
|
class BBDMScheduler(LCMScheduler):
    """
    Scheduler that moves a sample between an original sample and a second tensor `y` along a noisy bridge.

    With `m_t = t / num_train_timesteps` and `sigma_t = 2 * (m_t - m_t ** 2) * config.s`, the forward process
    implemented by `add_noise` is `x_t = (1 - m_t) * x_0 + m_t * y + sqrt(sigma_t) * noise`. `step` inverts one
    transition of that process using the model output.

    The module-level `config` object supplies `s` (variance scale), `noise_correction` (whether `step`
    standardizes the model output) and `prediction_type` (`"noise"` selects the noise parameterization).

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            Number of timesteps the model was trained with.
        original_inference_steps (`int`, defaults to 1000):
            Stored in the scheduler config; `set_timesteps` falls back to it when called with
            `original_inference_steps=None`.
    """

    order = 1

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, original_inference_steps: int = 1000):
        # Pass by keyword: the second positional parameter of `LCMScheduler.__init__` is `beta_start`, so a
        # positional call would silently store `original_inference_steps` as `beta_start`.
        super().__init__(
            num_train_timesteps=num_train_timesteps,
            original_inference_steps=original_inference_steps,
        )
        self.num_inference_steps = None
        # Default schedule: every training timestep, in descending order.
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
        self.custom_timesteps = False

        self._step_index = None
        self._begin_index = None

    def m(self, timestep) -> torch.FloatTensor:
        """Return the interpolation weight `m_t = timestep / num_train_timesteps`."""
        # Read through `self.config` like the rest of the class instead of relying on attribute fallback.
        return timestep / self.config.num_train_timesteps

    def sigma(self, timestep) -> torch.FloatTensor:
        """Return the variance `sigma_t = 2 * (m_t - m_t ** 2) * config.s` (zero at `timestep == 0`)."""
        sigma_t = 2 * (self.m(timestep) - self.m(timestep) ** 2) * config.s
        return sigma_t

    def sqrt_sigma(self, timestep) -> torch.FloatTensor:
        """Return `sqrt(sigma_t)`; `timestep` must be a tensor because `torch.sqrt` is applied."""
        return torch.sqrt(self.sigma(timestep))

    def sigma_previous(self, timestep, prev_timestep) -> torch.FloatTensor:
        """Return `sigma_t - sigma_prev * (1 - m_t) ** 2 / (1 - m_prev) ** 2`."""
        sigma_previous_t = self.sigma(timestep) - self.sigma(prev_timestep) * (1 - self.m(timestep)) ** 2 / (
            1 - self.m(prev_timestep)
        ) ** 2
        return sigma_previous_t

    def tilde_sigma(self, timestep, prev_timestep) -> torch.FloatTensor:
        """Return `sigma_previous(t, prev) * sigma_prev / sigma_t`, the variance of the noise added by `step`."""
        tilde_sigma_t = self.sigma_previous(timestep, prev_timestep) * self.sigma(prev_timestep) / self.sigma(timestep)
        return tilde_sigma_t

    def sqrt_tilde_sigma(self, timestep, prev_timestep) -> torch.FloatTensor:
        """Return `sqrt(tilde_sigma(t, prev))`, the standard deviation of the noise added by `step`."""
        return torch.sqrt(self.tilde_sigma(timestep, prev_timestep))

    def c_x(self, timestep, prev_timestep) -> torch.FloatTensor:
        """Return the coefficient `c_x` computed from `m`, `sigma` and `sigma_previous` at both timesteps."""
        c_xt = self.sigma(prev_timestep) / self.sigma(timestep) * (1 - self.m(timestep)) / (
            1 - self.m(prev_timestep)
        ) + self.sigma_previous(timestep, prev_timestep) / self.sigma(timestep) * (1 - self.m(prev_timestep))
        return c_xt

    def c_y(self, timestep, prev_timestep) -> torch.FloatTensor:
        """Return the coefficient `c_y` computed from `m` and `sigma` at both timesteps."""
        c_yt = self.m(prev_timestep) - self.m(timestep) * (1 - self.m(timestep)) / (
            1 - self.m(prev_timestep)
        ) * self.sigma(prev_timestep) / self.sigma(timestep)
        return c_yt

    def c_epsilon(self, timestep, prev_timestep) -> torch.FloatTensor:
        """Return the coefficient `c_epsilon` computed from `m`, `sigma`, `sigma_previous` and `sqrt_sigma`."""
        c_epsilon_t = (1 - self.m(prev_timestep)) * self.sigma_previous(timestep, prev_timestep) / self.sigma(
            timestep
        ) * self.sqrt_sigma(timestep) / (1 - self.m(timestep))
        return c_epsilon_t

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Return the position of `timestep` inside `schedule_timesteps` (defaults to `self.timesteps`).

        When the timestep occurs more than once, the second occurrence is returned. An `IndexError` is raised
        when the timestep is not part of the schedule.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        """Initialise `_step_index` from `begin_index` if one was set, otherwise by looking up `timestep`."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                # Compare on the schedule's device so the equality test in `index_for_timestep` is valid.
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    @property
    def step_index(self):
        """Index of the current position in `self.timesteps`; `None` until the first `step` call."""
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
        original_inference_steps: Optional[int] = 1000,
        timesteps: Optional[List[int]] = None,
        strength: float = 1.0,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`, *optional*):
                The number of diffusion steps used when generating samples with a pre-trained model. If used,
                `timesteps` must be `None`.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            original_inference_steps (`int`, *optional*, defaults to 1000):
                The original number of inference steps, which will be used to generate a linearly-spaced timestep
                schedule. We will then take `num_inference_steps` timesteps from this schedule, evenly spaced in
                terms of indices, and use that as our final timestep schedule. Pass `None` to use the
                `original_inference_steps` value stored in the scheduler config.
            timesteps (`List[int]`, *optional*):
                Custom timesteps, in strictly descending order, used to support arbitrary spacing between
                timesteps. If `timesteps` is passed, `num_inference_steps` must be `None`.
            strength (`float`, defaults to 1.0):
                Fraction of the schedule to run. It shortens the original schedule to
                `int(original_inference_steps * strength)` entries and, for custom timesteps, keeps only the last
                `int(len(timesteps) * strength)` entries.

        Raises:
            ValueError: If the arguments are inconsistent, exceed the training schedule, or leave an empty
                schedule.
        """
        # 0. Exactly one of the two ways of describing the schedule must be used.
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")

        if num_inference_steps is not None and timesteps is not None:
            raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")

        # 1. Resolve the length of the original schedule.
        original_steps = (
            original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
        )

        if original_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        # 2. Original schedule in ascending order: k - 1, 2k - 1, ..., shortened by `strength`.
        k = self.config.num_train_timesteps // original_steps
        bbdm_origin_timesteps = np.arange(1, int(original_steps * strength) + 1, dtype=np.int64) * k - 1

        if timesteps is not None:
            # 3a. Custom schedule: validate it, warn about unusual choices, then apply `strength`.
            if len(timesteps) == 0:
                raise ValueError("`timesteps` must contain at least one timestep.")

            train_timesteps = set(bbdm_origin_timesteps)
            non_train_timesteps = []
            for i in range(1, len(timesteps)):
                if timesteps[i] >= timesteps[i - 1]:
                    raise ValueError("`custom_timesteps` must be in descending order.")

                if timesteps[i] not in train_timesteps:
                    non_train_timesteps.append(timesteps[i])

            if timesteps[0] >= self.config.num_train_timesteps:
                raise ValueError(
                    f"`timesteps` must start before `self.config.train_timesteps`:"
                    f" {self.config.num_train_timesteps}."
                )

            if strength == 1.0 and timesteps[0] != self.config.num_train_timesteps - 1:
                logger.warning(
                    f"The first timestep on the custom timestep schedule is {timesteps[0]}, not"
                    f" `self.config.num_train_timesteps - 1`: {self.config.num_train_timesteps - 1}. You may get"
                    f" unexpected results when using this timestep schedule."
                )

            if non_train_timesteps:
                logger.warning(
                    f"The custom timestep schedule contains the following timesteps which are not on the original"
                    f" training/distillation timestep schedule: {non_train_timesteps}. You may get unexpected results"
                    f" when using this timestep schedule."
                )

            if len(timesteps) > original_steps:
                logger.warning(
                    f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds the"
                    f" the length of the timestep schedule used for training: {original_steps}. You may get some"
                    f" unexpected results when using this timestep schedule."
                )

            timesteps = np.array(timesteps, dtype=np.int64)
            self.num_inference_steps = len(timesteps)
            self.custom_timesteps = True

            # Keep only the tail of the schedule that corresponds to `strength`.
            init_timestep = min(int(self.num_inference_steps * strength), self.num_inference_steps)
            t_start = max(self.num_inference_steps - init_timestep, 0)
            timesteps = timesteps[t_start * self.order :]
        else:
            # 3b. Automatic schedule: pick `num_inference_steps` evenly spaced entries of the original schedule.
            if num_inference_steps > self.config.num_train_timesteps:
                raise ValueError(
                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                    f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                    f" maximal {self.config.num_train_timesteps} timesteps."
                )

            if num_inference_steps > original_steps:
                raise ValueError(
                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
                    f" {original_steps} because the final timestep schedule will be a subset of the"
                    f" `original_inference_steps`-sized initial timestep schedule."
                )

            if len(bbdm_origin_timesteps) == 0:
                # Without this guard the indexing below fails with an opaque IndexError.
                raise ValueError(
                    f"`strength`: {strength} leaves no timesteps in the original schedule of {original_steps} steps."
                )

            # Only commit state once every check has passed.
            self.num_inference_steps = num_inference_steps
            self.custom_timesteps = False

            bbdm_origin_timesteps = bbdm_origin_timesteps[::-1].copy()

            inference_indices = np.linspace(0, len(bbdm_origin_timesteps) - 1, num=num_inference_steps, endpoint=True)
            inference_indices = np.floor(inference_indices).astype(np.int64)
            timesteps = bbdm_origin_timesteps[inference_indices]

        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)

        # A new schedule invalidates any position from a previous run.
        self._step_index = None
        self._begin_index = None

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        y: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[BBDMSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`int` or `torch.Tensor`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            y (`torch.FloatTensor`):
                The tensor that is blended with the predicted original sample using weight `m_t`.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_bbdm.BBDMSchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.BBDMSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_bbdm.BBDMSchedulerOutput`] is returned, otherwise a
                tuple `(prev_sample, predicted_original_sample)` is returned.

        Raises:
            ValueError: If `set_timesteps` has not been called yet.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # Plain ints have no `.view`; normalise to a tensor on the sample's device so the broadcast
        # coefficients below can be combined with `sample` and `y`.
        timestep = torch.as_tensor(timestep, device=sample.device)

        if self.step_index is None:
            self._init_step_index(timestep)

        # 1. Find the timestep we are stepping to. Use the real schedule length: `num_inference_steps` can be
        # larger than `len(self.timesteps)` when a custom schedule was shortened by `strength`.
        prev_step_index = self.step_index + 1
        is_last_step = prev_step_index >= len(self.timesteps)
        if is_last_step:
            prev_timestep = timestep
        else:
            prev_timestep = self.timesteps[prev_step_index].to(sample.device)

        # 2. Optionally standardise the model output to zero mean and unit standard deviation.
        if config.noise_correction:
            model_output = model_output - torch.mean(model_output)
            model_output = model_output / torch.std(model_output)

        m_t = self.m(timestep).view(-1, 1, 1, 1)

        # 3. Estimate the original sample.
        if config.prediction_type == "noise":
            # Invert x_t = (1 - m_t) * x_0 + m_t * y + sqrt(sigma_t) * noise; clamp to avoid dividing by ~0.
            denom = torch.clamp(1 - m_t, min=1e-5)
            predicted_original_sample = (
                sample - m_t * y - self.sqrt_sigma(timestep).view(-1, 1, 1, 1) * model_output
            ) / denom
        else:
            predicted_original_sample = sample - model_output

        # 4. Mean of the sample at the previous timestep.
        m_prev = self.m(prev_timestep).view(-1, 1, 1, 1)
        alpha_prev = 1 - m_prev

        prev_sample_mean = alpha_prev * predicted_original_sample + m_prev * y

        # 5. Add noise on every step except the last one, where the transition variance would be 0 or 0/0.
        if not is_last_step:
            noise = randn_tensor(
                model_output.shape, generator=generator, device=model_output.device, dtype=sample.dtype
            )
            noise_scale = self.sqrt_tilde_sigma(timestep, prev_timestep).view(-1, 1, 1, 1)
            prev_sample = prev_sample_mean + noise_scale * noise
        else:
            prev_sample = prev_sample_mean

        # Advance the position in the schedule for the next call.
        self._step_index += 1

        if not return_dict:
            return (prev_sample, predicted_original_sample)

        return BBDMSchedulerOutput(prev_sample=prev_sample, predicted_original_sample=predicted_original_sample)

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        ys: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        Apply the forward process: `(1 - m_t) * original_samples + m_t * ys + sqrt(sigma_t) * noise`.

        Args:
            original_samples (`torch.FloatTensor`):
                Samples at timestep 0.
            ys (`torch.FloatTensor`):
                Tensors blended in with weight `m_t`.
            noise (`torch.FloatTensor`):
                Noise scaled by `sqrt(sigma_t)`.
            timesteps (`torch.IntTensor`):
                One timestep per batch element; the coefficients are reshaped to `(-1, 1, 1, 1)` for broadcasting.

        Returns:
            `torch.FloatTensor`: The noisy samples.
        """
        # The reshaped coefficients are not scalars, so they must live on the samples' device.
        timesteps = timesteps.to(original_samples.device)
        m_t = self.m(timesteps).view(-1, 1, 1, 1)
        sqrt_sigma_t = self.sqrt_sigma(timesteps).view(-1, 1, 1, 1)

        noisy_samples = (1 - m_t) * original_samples + m_t * ys + sqrt_sigma_t * noise
        return noisy_samples

    def __len__(self):
        """Return the number of training timesteps stored in the config."""
        return self.config.num_train_timesteps
|
|
|
|
|
|
|
# This module only provides definitions; running it directly as a script does nothing.
if __name__ == "__main__":
    pass
|
|
|