Spaces:
Sleeping
Sleeping
| import ipdb # noqa: F401 | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from diffusionsfm.utils.visualization import plot_to_image | |
class NoiseScheduler(nn.Module):
    """Diffusion noise schedule stored as module buffers.

    Depending on ``type``, one of three beta schedules is built and registered
    as the buffers ``betas``, ``alphas`` (``1 - betas``) and ``alphas_cumprod``
    (cumulative product of ``alphas``):

    * ``"linear"``: linearly spaced betas, rescaled so that the last
      cumulative alpha is zero (zero terminal SNR,
      https://arxiv.org/pdf/2305.08891).
    * ``"cosine"``: cosine-shaped cumulative alphas raised to ``cos_power``.
    * ``"scaled_linear"``: betas linearly spaced in sqrt-space, then squared.

    Args:
        max_timesteps: Number of training timesteps (length of the buffers).
        beta_start: First beta of the linear / scaled-linear schedules.
        beta_end: Last beta of the linear / scaled-linear schedules (before
            any rescaling).
        cos_power: Exponent applied to the cosine in the cosine schedule.
        num_inference_steps: Default number of steps used by
            ``compute_inference_timesteps``.
        type: Schedule name: ``"linear"``, ``"cosine"`` or
            ``"scaled_linear"``. (The name shadows the builtin but is kept
            for backward compatibility with keyword callers.)

    Raises:
        ValueError: If ``type`` is not one of the supported schedule names.
    """

    SUPPORTED_TYPES = ("linear", "cosine", "scaled_linear")

    def __init__(
        self,
        max_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        cos_power=2,
        num_inference_steps=100,
        type="linear",
    ):
        super().__init__()
        self.max_timesteps = max_timesteps
        self.num_inference_steps = num_inference_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.cos_power = cos_power
        self.type = type

        if type == "linear":
            self.register_linear_schedule()
        elif type == "cosine":
            self.register_cosine_schedule(cos_power)
        elif type == "scaled_linear":
            self.register_scaled_linear_schedule()
        else:
            # Fail loudly: otherwise the module is built without any schedule
            # buffers and only breaks later with an AttributeError.
            raise ValueError(
                f"Unknown noise schedule type {type!r}; "
                f"expected one of {self.SUPPORTED_TYPES}"
            )

        self.inference_timesteps = self.compute_inference_timesteps()

    def _register_schedule_buffers(self, betas):
        """Register ``betas`` and the ``alphas`` / ``alphas_cumprod`` derived from it."""
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", 1.0 - self.betas)
        self.register_buffer("alphas_cumprod", torch.cumprod(self.alphas, dim=0))

    def register_linear_schedule(self):
        """Register a linear beta schedule rescaled to zero terminal SNR."""
        # zero terminal SNR (https://arxiv.org/pdf/2305.08891)
        betas = torch.linspace(
            self.beta_start,
            self.beta_end,
            self.max_timesteps,
            dtype=torch.float32,
        )
        alphas_bar_sqrt = torch.cumprod(1.0 - betas, dim=0).sqrt()

        sqrt_first = alphas_bar_sqrt[0]
        sqrt_last = alphas_bar_sqrt[-1]
        # Shift so the last value becomes zero, then scale so the first value
        # is unchanged.
        alphas_bar_sqrt = (alphas_bar_sqrt - sqrt_last) * (
            sqrt_first / (sqrt_first - sqrt_last)
        )

        # Convert the rescaled cumulative products back to per-step betas.
        alphas_bar = alphas_bar_sqrt**2
        alphas = torch.cat([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
        self._register_schedule_buffers(1 - alphas)

    def register_cosine_schedule(self, cos_power, s=0.008):
        """Register a cosine schedule.

        Args:
            cos_power: Exponent applied to the cosine curve.
            s: Offset added to the normalized timestep before the cosine.
        """
        timesteps = (
            torch.arange(self.max_timesteps + 1, dtype=torch.float32)
            / self.max_timesteps
        )
        alpha_bars = torch.cos((timesteps + s) / (1 + s) * np.pi / 2).pow(cos_power)
        alpha_bars = alpha_bars / alpha_bars[0]
        betas = 1 - alpha_bars[1:] / alpha_bars[:-1]
        # Clamp with torch (not numpy) so the result is guaranteed to remain a
        # tensor that register_buffer accepts.
        betas = torch.clamp(betas, min=0, max=0.999)
        self._register_schedule_buffers(betas)

    def register_scaled_linear_schedule(self):
        """Register betas that are linear in sqrt-space between start and end."""
        betas = (
            torch.linspace(
                self.beta_start**0.5,
                self.beta_end**0.5,
                self.max_timesteps,
                dtype=torch.float32,
            )
            ** 2
        )
        self._register_schedule_buffers(betas)

    def compute_inference_timesteps(
        self, num_inference_steps=None, num_train_steps=None
    ):
        """Return evenly strided timesteps in descending order.

        Args:
            num_inference_steps: Number of timesteps to return; defaults to
                ``self.num_inference_steps``.
            num_train_steps: Number of training timesteps to stride over;
                defaults to ``self.max_timesteps``.

        Returns:
            Integer numpy array ``[(n-1)*r, ..., r, 0]`` with
            ``r = num_train_steps // num_inference_steps``. Note that when
            ``num_inference_steps > num_train_steps`` the stride is 0 and every
            entry is 0.
        """
        # based on diffusers's scheduling code
        if num_inference_steps is None:
            num_inference_steps = self.num_inference_steps
        if num_train_steps is None:
            num_train_steps = self.max_timesteps
        step_ratio = num_train_steps // num_inference_steps
        timesteps = (
            (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(int)
        )
        return timesteps

    def plot_schedule(self, return_image=False):
        """Plot ``sqrt(alphas_cumprod)`` over the training timesteps.

        Args:
            return_image: If true, render the figure with ``plot_to_image``,
                close it and return the result. Otherwise the figure is left
                open and ``None`` is returned.
        """
        fig = plt.figure(figsize=(6, 4), dpi=100)
        alpha_bars = self.alphas_cumprod.cpu().numpy()
        plt.plot(np.sqrt(alpha_bars))
        plt.grid()

        if self.type == "linear":
            title = (
                f"Linear (T={self.max_timesteps}, S={self.beta_start}, E={self.beta_end})"
            )
        elif self.type == "cosine":
            title = f"Cosine (T={self.max_timesteps}, P={self.cos_power})"
        elif self.type == "scaled_linear":
            title = (
                f"Scaled linear (T={self.max_timesteps}, "
                f"S={self.beta_start}, E={self.beta_end})"
            )
        else:
            # self.type was changed after construction; label it generically.
            title = f"{self.type} (T={self.max_timesteps})"
        plt.title(title)

        if return_image:
            image = plot_to_image(fig)
            plt.close(fig)
            return image