Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass | |
| from typing import List, Tuple | |
| import torch | |
| from torch import Tensor | |
| from .denoiser import Denoiser | |
@dataclass
class DiffusionSamplerConfig:
    """Hyper-parameters consumed by ``DiffusionSampler``.

    Declared as a dataclass so the annotated fields become constructor
    arguments; ``num_steps_denoising`` has no default and must be supplied.
    """

    # Number of sigma values generated by ``build_sigmas`` (a trailing 0 is appended there).
    num_steps_denoising: int
    # Smallest and largest sigma of the schedule.
    sigma_min: float = 2e-3
    sigma_max: float = 5
    # Exponent shaping how the schedule interpolates between sigma_max and sigma_min.
    rho: int = 7
    # 1 selects Euler steps; any other value selects Heun's method in the sampler.
    order: int = 1
    # Churn amount; the sampler divides it by the number of steps and caps it at sqrt(2) - 1.
    s_churn: float = 0
    # Churn is applied only for sigmas inside [s_tmin, s_tmax].
    s_tmin: float = 0
    s_tmax: float = float("inf")
    # Scale of the fresh noise injected when churn is active.
    s_noise: float = 1
class DiffusionSampler:
    """Turn Gaussian noise into a sample by stepping down a fixed sigma schedule.

    The schedule is computed once at construction time with ``build_sigmas``
    and ends with a trailing zero, so the final step targets sigma = 0.
    """

    def __init__(self, denoiser: Denoiser, cfg: DiffusionSamplerConfig) -> None:
        """Store the denoiser and config and precompute the sigma schedule.

        Args:
            denoiser: Object exposing ``device`` and
                ``denoise(x, sigma, prev_obs, prev_act)``.
            cfg: Sampler hyper-parameters.
        """
        self.denoiser = denoiser
        self.cfg = cfg
        self.sigmas = build_sigmas(cfg.num_steps_denoising, cfg.sigma_min, cfg.sigma_max, cfg.rho, denoiser.device)

    def sample(self, prev_obs: Tensor, prev_act: Tensor) -> Tuple[Tensor, List[Tensor]]:
        """Generate one sample conditioned on previous observations and actions.

        Args:
            prev_obs: 5-D tensor unpacked as ``(b, t, c, h, w)``; the ``t`` and
                ``c`` axes are merged before being handed to the denoiser.
            prev_act: Passed through to the denoiser unchanged.

        Returns:
            A pair ``(x, trajectory)`` where ``x`` is the final ``(b, c, h, w)``
            tensor and ``trajectory`` lists the initial noise followed by the
            state after every step.

        NOTE(review): this method does not disable autograd itself; callers
        that only need inference presumably want to wrap the call in
        ``torch.no_grad()`` -- confirm against the call sites.
        """
        cfg = self.cfg
        device = prev_obs.device
        b, t, c, h, w = prev_obs.size()
        # Fold the time axis into channels: (b, t, c, h, w) -> (b, t * c, h, w).
        prev_obs = prev_obs.reshape(b, t * c, h, w)
        s_in = torch.ones(b, device=device)
        # Per-step churn factor, capped at sqrt(2) - 1.
        gamma_ = min(cfg.s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
        x = torch.randn(b, c, h, w, device=device)
        trajectory = [x]
        for sigma, next_sigma in zip(self.sigmas[:-1], self.sigmas[1:]):
            # Churn only applies while sigma lies inside [s_tmin, s_tmax].
            gamma = gamma_ if cfg.s_tmin <= sigma <= cfg.s_tmax else 0
            sigma_hat = sigma * (gamma + 1)
            if gamma > 0:
                # Raise the noise level of x from sigma to sigma_hat.
                eps = torch.randn_like(x) * cfg.s_noise
                x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
            # x now sits at noise level sigma_hat, so that is the level the
            # denoiser must be told about (identical to sigma when gamma == 0).
            denoised = self.denoiser.denoise(x, sigma_hat, prev_obs, prev_act)
            d = (x - denoised) / sigma_hat
            dt = next_sigma - sigma_hat
            if cfg.order == 1 or next_sigma == 0:
                # Euler method (also used for the last step, where dividing by
                # next_sigma == 0 in the Heun correction would be undefined).
                x = x + d * dt
            else:
                # Heun's method: average the slope at both ends of the step.
                x_2 = x + d * dt
                denoised_2 = self.denoiser.denoise(x_2, next_sigma * s_in, prev_obs, prev_act)
                d_2 = (x_2 - denoised_2) / next_sigma
                d_prime = (d + d_2) / 2
                x = x + d_prime * dt
            trajectory.append(x)
        return x, trajectory
def build_sigmas(num_steps: int, sigma_min: float, sigma_max: float, rho: int, device: torch.device) -> Tensor:
    """Return a decreasing sigma schedule followed by a single trailing zero.

    The schedule interpolates linearly between ``sigma_max ** (1 / rho)`` and
    ``sigma_min ** (1 / rho)`` over ``num_steps`` points and raises the result
    to the power ``rho``; the output therefore has ``num_steps + 1`` entries.
    """
    inv_rho = 1 / rho
    lo = sigma_min**inv_rho
    hi = sigma_max**inv_rho
    # Interpolation weights going from 0 (sigma_max) to 1 (sigma_min).
    ramp = torch.linspace(0, 1, num_steps, device=device)
    schedule = (hi + ramp * (lo - hi)) ** rho
    final_zero = torch.zeros(1, dtype=schedule.dtype, device=schedule.device)
    return torch.cat((schedule, final_zero))