Spaces:
Sleeping
Sleeping
| import torch | |
| from diffusion_model.sampler.base_sampler import BaseSampler | |
class DDIM(BaseSampler):
    """DDIM sampler that walks an evenly strided subset of the ``T`` timesteps.

    The constructor precomputes, for every retained timestep, the
    ``alpha_bar`` value, its predecessor in the strided schedule and the noise
    scale ``sigma`` (scaled by the ``eta`` config entry). ``get_x_prev`` then
    applies a single reverse update using those tables.
    """

    def __init__(self, config_path):
        """Build the strided schedule and its per-step coefficients.

        Args:
            config_path: Forwarded unchanged to ``BaseSampler.__init__``.

        Raises:
            ValueError: If the configured ``sampling_T`` is not in ``[1, T]``.
        """
        super().__init__(config_path)
        # self.config, self.T and self.alpha_bar are read below but never
        # assigned in this class; presumably BaseSampler.__init__ provides
        # them -- confirm against the base class.
        self.sampling_T = self.config['sampling_T']
        if not 1 <= self.sampling_T <= self.T:
            # Without this guard a zero value divides by zero and a value
            # above T gives step == 0, which torch.arange rejects with an
            # unhelpful message.
            raise ValueError(
                f"sampling_T must be between 1 and T={self.T}, "
                f"got {self.sampling_T}"
            )

        step = self.T // self.sampling_T
        # NOTE(review): when T is not a multiple of sampling_T the arange
        # below yields more than sampling_T entries (e.g. T=1000,
        # sampling_T=300 gives 334); verify how the caller iterates.
        self.timesteps = torch.arange(0, self.T, step, dtype=torch.int)

        self.ddim_alpha = self.alpha_bar[self.timesteps]
        self.sqrt_one_minus_alpha_bar = (1. - self.ddim_alpha).sqrt()
        # Shift by one entry so index i holds the value of entry i - 1; the
        # first entry is paired with itself, which makes sigma[0] zero.
        self.alpha_bar_prev = torch.cat(
            [self.ddim_alpha[0:1], self.ddim_alpha[:-1]]
        )

        variance_ratio = (1 - self.alpha_bar_prev) / (1 - self.ddim_alpha)
        step_shrink = 1 - self.ddim_alpha / self.alpha_bar_prev
        self.sigma = self.config['eta'] * torch.sqrt(variance_ratio * step_shrink)

    def get_x_prev(self, x, tau, eps_hat):
        """Apply one DDIM reverse update at schedule index ``tau``.

        Args:
            x: Current sample tensor.
            tau: Index into the strided tables built in ``__init__`` (not a
                raw timestep). It must select a single entry, because
                ``sigma`` is used in a boolean test.
            eps_hat: Noise estimate for ``x``; presumably the model output
                with a shape broadcastable to ``x`` -- confirm with caller.

        Returns:
            The updated sample tensor for the preceding schedule entry.
        """
        alpha_prev = self.alpha_bar_prev[tau]
        sigma = self.sigma[tau]

        # Estimate of the clean sample implied by the current noise estimate.
        x0_hat = (
            (x - self.sqrt_one_minus_alpha_bar[tau] * eps_hat)
            / (self.ddim_alpha[tau] ** 0.5)
        )

        # Clamp before the square root: rounding (or eta > 1) can push the
        # radicand slightly below zero, which would otherwise yield NaN.
        dir_xt = (1. - alpha_prev - sigma ** 2).clamp(min=0.).sqrt() * eps_hat

        x_prev = alpha_prev.sqrt() * x0_hat + dir_xt
        # Only draw noise when it contributes, so deterministic sampling
        # (sigma == 0) does not advance the random number generator.
        if sigma != 0.:
            x_prev = x_prev + sigma * torch.randn_like(x)
        return x_prev