# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import Optional, Union

from .base import *  # expected to provide Scheduler and SchedulerStepOutput


class DDIMScheduler(Scheduler):
    def step(
        self,
        model_output: torch.Tensor,
        model_output_type: str,
        timestep: Union[torch.Tensor, int],
        sample: torch.Tensor,
        eta: float = 0.0,
        clip_sample: bool = False,
        dynamic_threshold: Optional[float] = None,
        variance_noise: Optional[torch.Tensor] = None,
    ) -> SchedulerStepOutput:
        # 1. get previous step value (t-1)
        if not isinstance(timestep, torch.Tensor):
            timestep = torch.tensor(timestep, device=self.device, dtype=torch.int)
        # locate each timestep's position in self.timesteps, then take the next
        # entry (clamped so the final step reuses the last index)
        idx = timestep.reshape(-1, 1).eq(self.timesteps.reshape(1, -1)).nonzero()[:, 1]
        prev_timestep = self.timesteps[idx.add(1).clamp_max(self.num_inference_timesteps - 1)]
        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep].reshape(-1, *([1] * (sample.ndim - 1)))
        alpha_prod_t_prev = torch.where(
            idx < self.num_inference_timesteps - 1,
            self.alphas_cumprod[prev_timestep],
            self.final_alpha_cumprod,
        ).reshape(-1, *([1] * (sample.ndim - 1)))
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        # 3. compute predicted original sample from predicted noise, also called
        #    "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        model_output_conversion = self.convert_output(model_output, model_output_type, sample, timestep)
        pred_original_sample = model_output_conversion.pred_original_sample
        pred_epsilon = model_output_conversion.pred_epsilon
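        # for an "epsilon"-type model output this conversion follows the standard
        # inversion of the noising process: x_0 = (x_t - sqrt(1 - ᾱ_t) * ε) / sqrt(ᾱ_t)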
        # 4. Clip or threshold "predicted x_0"
        if clip_sample:
            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
            # re-derive epsilon so the remaining steps stay consistent with the clipped x_0
            pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon
        if dynamic_threshold is not None:
            # dynamic thresholding from https://arxiv.org/abs/2205.11487: clamp each
            # sample to its per-sample |x_0| quantile, then rescale back into [-1, 1]
            dynamic_max_val = (
                pred_original_sample
                .flatten(1)
                .abs()
                .float()
                .quantile(dynamic_threshold, dim=1)
                .type_as(pred_original_sample)
                .clamp_min(1)
                .view(-1, *([1] * (pred_original_sample.ndim - 1)))
            )
            pred_original_sample = pred_original_sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
            pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon
        # 5. compute variance: "sigma_t(η)" -> see formula (16) from https://arxiv.org/pdf/2010.02502.pdf
        # σ_t = sqrt((1 − α_{t−1}) / (1 − α_t)) * sqrt(1 − α_t / α_{t−1})
        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
        std_dev_t = eta * variance ** 0.5
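        # eta = 0 yields the deterministic DDIM sampler (σ_t = 0); eta = 1 recovers
        # the DDPM ancestral-sampling variance (Sec. 4.1 of the DDIM paper)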
        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon
        # 7. compute x_{t-1} without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
        # 8. add "random noise" if needed
        if eta > 0:
            if variance_noise is None:
                variance_noise = torch.randn_like(model_output)
            prev_sample = prev_sample + std_dev_t * variance_noise

        return SchedulerStepOutput(
            prev_sample=prev_sample,
            pred_original_sample=pred_original_sample,
        )
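

# ---------------------------------------------------------------------------
# Minimal sketch (not part of the scheduler above): the deterministic DDIM
# update of formula (12) from https://arxiv.org/pdf/2010.02502.pdf applied to
# dummy tensors with eta = 0. The alpha values and shapes below are arbitrary
# assumptions for illustration; a real pipeline takes them from a trained
# noise schedule and a denoising model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    x_t = torch.randn(1, 4, 8, 8)   # current noisy sample
    eps = torch.randn_like(x_t)     # stand-in for a model's epsilon prediction
    a_t, a_prev = 0.5, 0.7          # dummy ᾱ_t and ᾱ_{t-1}
    # predicted x_0 under the epsilon parameterization
    x0 = (x_t - (1 - a_t) ** 0.5 * eps) / a_t ** 0.5
    # deterministic step: x_{t-1} = sqrt(ᾱ_{t-1}) * x_0 + sqrt(1 - ᾱ_{t-1}) * ε
    x_prev = a_prev ** 0.5 * x0 + (1 - a_prev) ** 0.5 * eps
    print(tuple(x_prev.shape))  # -> (1, 4, 8, 8)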