Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from tqdm import tqdm | |
| import yaml | |
| from helper.util import extract | |
| from helper.beta_generator import BetaGenerator | |
| from abc import ABC, abstractmethod | |
class BaseSampler(nn.Module, ABC):
    """Common base for diffusion samplers configured from a YAML file.

    The constructor reads the ``sampler`` section of the config, builds a
    beta schedule, and registers ``beta``, ``alpha``, ``alpha_sqrt`` and
    ``alpha_bar`` as module buffers.  ``self.timesteps`` starts out as
    ``None`` and must be assigned (e.g. by a subclass) before sampling, and
    a network must be attached with :meth:`set_network` before
    :meth:`p_sample` can run.
    """

    def __init__(self, config_path: str):
        """Load the sampler configuration and precompute schedule buffers.

        Args:
            config_path: Path to a YAML file whose top-level ``sampler``
                mapping provides at least the keys ``T`` and ``beta``.
        """
        super().__init__()
        with open(config_path, "r") as file:
            self.config = yaml.safe_load(file)['sampler']

        # presumably the total number of diffusion steps -- confirm against
        # BetaGenerator.
        self.T = self.config['T']
        beta_generator = BetaGenerator(T=self.T)

        # Not populated here; reverse_process() iterates over this value.
        self.timesteps = None

        # Look up "<name>_beta_schedule" on the generator.  An unrecognised
        # name silently falls back to the linear schedule.
        schedule_fn = getattr(
            beta_generator,
            f"{self.config['beta']}_beta_schedule",
            beta_generator.linear_beta_schedule,
        )
        self.register_buffer('beta', schedule_fn())
        self.register_buffer('alpha', 1 - self.beta)
        self.register_buffer('alpha_sqrt', self.alpha.sqrt())
        # Running product of alpha along dim 0.
        self.register_buffer('alpha_bar', torch.cumprod(self.alpha, dim=0))

    def get_x_prev(self, x, t, idx, eps_hat):
        """Hook for computing the previous-step sample; does nothing here.

        The base implementation has an empty body and returns ``None``.

        NOTE(review): :meth:`p_sample` calls this hook with three positional
        arguments ``(x, idx, eps_hat)``, which does not match the
        four-parameter signature declared here.  Overrides presumably take
        ``(x, idx, eps_hat)`` -- confirm against the concrete samplers and
        align this declaration with them.
        """

    def set_network(self, network: nn.Module):
        """Attach the network that :meth:`p_sample` calls to predict noise.

        Args:
            network: Module invoked as ``network(x=..., t=..., **kwargs)``.
        """
        self.network = network

    def q_sample(self, x0, t, eps=None):
        """Return ``sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * eps``.

        Args:
            x0: Input tensor to be noised.
            t: Timestep index/indices passed to ``extract`` to select
                entries of ``alpha_bar``.
            eps: Noise to mix in; drawn with ``torch.randn_like(x0)`` when
                omitted.

        Returns:
            The noised tensor computed by the formula above.
        """
        # `extract` lives in helper.util; presumably it gathers alpha_bar at
        # `t` and reshapes the result to broadcast against x0 -- TODO confirm.
        alpha_t_bar = extract(self.alpha_bar, t, x0.shape)
        if eps is None:
            eps = torch.randn_like(x0)
        return alpha_t_bar.sqrt() * x0 + (1 - alpha_t_bar).sqrt() * eps

    def reverse_process(self, x_T, only_last=True, **kwargs):
        """Run :meth:`p_sample` over ``self.timesteps`` in reverse order.

        Args:
            x_T: Starting sample.
            only_last: When true, return only the final sample; otherwise
                return every intermediate sample as a list.
            **kwargs: Forwarded unchanged to :meth:`p_sample`.

        Returns:
            The final sample, or a list starting with ``x_T`` followed by
            the output of each step.

        Raises:
            RuntimeError: If ``self.timesteps`` has not been assigned.
        """
        if self.timesteps is None:
            # Fail with a clear message rather than an opaque error from
            # reversed(None) further down.
            raise RuntimeError(
                "timesteps is not set; assign self.timesteps before sampling"
            )

        num_steps = len(self.timesteps)
        x = x_T
        trajectory = None if only_last else [x_T]

        # Passing `total` lets the progress bar show completion, which it
        # cannot infer from the enumerate() iterator.
        for i, t in tqdm(enumerate(reversed(self.timesteps)), total=num_steps):
            # Position of `t` within the original (un-reversed) sequence.
            idx = num_steps - i - 1
            x = self.p_sample(x, t, idx, **kwargs)
            if trajectory is not None:
                trajectory.append(x)

        return x if only_last else trajectory

    def p_sample(self, x, t, idx, gamma=None, **kwargs):
        """Perform one reverse step from ``x`` at timestep ``t``.

        Args:
            x: Current sample.
            t: Timestep value passed to the network.
            idx: Position of ``t`` in ``self.timesteps``; passed to
                :meth:`get_x_prev`.
            gamma: Optional guidance weight.  When given, the network is
                evaluated a second time with ``cond_drop_all=True`` and the
                two predictions are blended as
                ``gamma * eps_hat + (1 - gamma) * eps_null``.
            **kwargs: Extra keyword arguments forwarded to the network.

        Returns:
            Whatever :meth:`get_x_prev` returns for this step.
        """
        eps_hat = self.network(x=x, t=t, **kwargs)
        if gamma is not None:
            # Second pass with conditioning dropped, then blend the two
            # predictions (classifier-free-guidance-style weighting).
            eps_null = self.network(x=x, t=t, cond_drop_all=True, **kwargs)
            eps_hat = gamma * eps_hat + (1 - gamma) * eps_null
        # NOTE(review): `t` is not forwarded here even though the base
        # get_x_prev signature declares it -- confirm which side is correct.
        return self.get_x_prev(x, idx, eps_hat)

    def forward(self, x_T, **kwargs):
        """Alias for :meth:`reverse_process`, so the module is callable."""
        return self.reverse_process(x_T, **kwargs)