Spaces:
Runtime error
Runtime error
| # Copyright 2024 EPFL and Apple Inc. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import math | |
| import numpy as np | |
| import torch | |
def enforce_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    """Rescale a beta schedule so that the final timestep has zero SNR.

    The schedule is mapped to sqrt(alphas_cumprod), shifted so the last entry
    becomes zero, rescaled so the first entry keeps its original value, and
    finally mapped back to betas.

    See https://arxiv.org/abs/2305.08891

    Args:
        betas: the initial diffusion noise schedule betas

    Returns:
        The diffusion noise schedule betas with the last time step having zero SNR
    """
    # betas -> sqrt of the cumulative product of alphas.
    sqrt_cumprod = (1 - betas).cumprod(0).sqrt()

    # Endpoints of the original curve; nothing below mutates sqrt_cumprod in
    # place, so plain indexing is enough to remember them.
    first = sqrt_cumprod[0]
    last = sqrt_cumprod[-1]

    # Shift so the last timestep lands on zero ...
    shifted = sqrt_cumprod - last
    # ... then stretch so the first timestep returns to its original value.
    rescaled = shifted * (first / (first - last))

    # sqrt(alphas_cumprod) -> alphas_cumprod -> per-step alphas -> betas.
    cumprod = rescaled ** 2
    step_alphas = torch.cat([cumprod[0:1], cumprod[1:] / cumprod[:-1]])
    return 1 - step_alphas
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta: float = 0.999) -> torch.Tensor:
    """Discretize a cosine alpha_bar curve into a beta schedule.

    alpha_bar(t) is the cumulative product of (1 - beta) up to time t, with t
    in [0, 1]. Each beta is derived from the ratio of alpha_bar at the two ends
    of its step and is capped at max_beta.

    Args:
        num_diffusion_timesteps: the number of betas to produce.
        max_beta: the maximum beta to use; use values lower than 1 to
            prevent singularities.

    Returns:
        The betas used by the scheduler to step the model outputs
    """

    def _cumulative_alpha(fraction):
        # Squared-cosine curve with a 0.008 offset on the time fraction.
        angle = (fraction + 0.008) / 1.008 * math.pi / 2
        return math.cos(angle) ** 2

    def _beta_at(step):
        # Step `step` spans [step / N, (step + 1) / N] on the unit interval.
        start = step / num_diffusion_timesteps
        end = (step + 1) / num_diffusion_timesteps
        return min(1 - _cumulative_alpha(end) / _cumulative_alpha(start), max_beta)

    schedule = [_beta_at(step) for step in range(num_diffusion_timesteps)]
    return torch.tensor(schedule, dtype=torch.float32)
def scaled_cosine_alphas(num_diffusion_timesteps: int, noise_shift: float = 1.0) -> torch.Tensor:
    """Shifts a cosine noise schedule by a specified amount in log-SNR space.

    noise_shift = 1.0 corresponds to the standard cosine noise schedule.
    0 < noise_shift < 1.0 corresponds to a less noisy schedule (better
    suited if the conditioning is highly informative, e.g. low-res images).
    noise_shift > 1.0 corresponds to a more noisy schedule (better suited
    if the conditioning is not as informative, e.g. captions).

    See https://arxiv.org/abs/2305.18231

    Args:
        num_diffusion_timesteps: the number of diffusion timesteps.
        noise_shift: the amount to shift the noise schedule by in log-SNR
            space. Must be strictly positive.

    Returns:
        The alphas_cumprod used by the diffusion noise scheduler, as a float32
        tensor with num_diffusion_timesteps entries whose last entry is 0.

    Raises:
        ValueError: if noise_shift is not strictly positive.
    """
    # log(noise_shift) is undefined for non-positive values. Without this
    # guard, a zero shift saturates the whole schedule and a negative shift
    # turns it into NaNs, with nothing more than a NumPy RuntimeWarning.
    # `not ... > 0` also rejects NaN.
    if not noise_shift > 0:
        raise ValueError(f"noise_shift must be strictly positive, got {noise_shift}")

    # Evenly spaced times on [0, 1], promoted to float64 for the tan/log math.
    t = torch.linspace(0, 1, num_diffusion_timesteps).to(torch.float64)
    log_snr = -2 * (torch.tan(torch.pi * t / 2).log() + np.log(noise_shift))
    # tan(0) = 0 makes the log diverge at t = 0 (and tan blows up near t = 1),
    # so bound log-SNR to [-15, 15] before casting back to float32.
    log_snr = log_snr.clamp(-15, 15).float()
    alphas_cumprod = log_snr.sigmoid()
    # Clamping leaves a tiny positive value at the end; force it to exactly 0.
    alphas_cumprod[-1] = 0.0
    return alphas_cumprod