Spaces:
Sleeping
Sleeping
File size: 2,475 Bytes
d548197 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 | from dataclasses import dataclass
from typing import List, Tuple
import torch
from torch import Tensor
from .denoiser import Denoiser
@dataclass
class DiffusionSamplerConfig:
    """Settings for the noise schedule and the sampling loop of ``DiffusionSampler``.

    All ``float``-annotated defaults are written as float literals so the stored
    values (and the dataclass repr) match their annotations.
    """

    # Number of denoising steps; the schedule holds this many noise levels plus a final 0.
    num_steps_denoising: int
    # Smallest non-zero noise level of the schedule.
    sigma_min: float = 2e-3
    # Largest noise level of the schedule (the level sampling starts from).
    sigma_max: float = 5.0
    # Exponent of the schedule: levels are interpolated linearly in sigma ** (1 / rho).
    rho: int = 7
    # 1 selects Euler steps; any other value selects Heun steps (with an Euler final step).
    order: int = 1
    # Total amount of churn; divided by the number of steps to get the per-step noise increase.
    s_churn: float = 0.0
    # Churn is applied only while s_tmin <= sigma <= s_tmax.
    s_tmin: float = 0.0
    s_tmax: float = float("inf")
    # Scale applied to the fresh noise injected by churn.
    s_noise: float = 1.0
class DiffusionSampler:
    """Generates samples from a denoiser by stepping through a fixed noise schedule."""

    def __init__(self, denoiser: Denoiser, cfg: DiffusionSamplerConfig) -> None:
        """Store the denoiser and config and precompute the noise schedule.

        Args:
            denoiser: Object exposing ``denoise(x, sigma, prev_obs, prev_act)`` and ``device``.
            cfg: Sampler settings (schedule bounds, step count, order, churn).
        """
        self.denoiser = denoiser
        self.cfg = cfg
        # Noise levels from cfg.sigma_max to cfg.sigma_min followed by a final 0,
        # built once on the denoiser's device.
        self.sigmas = build_sigmas(cfg.num_steps_denoising, cfg.sigma_min, cfg.sigma_max, cfg.rho, denoiser.device)

    @torch.no_grad()
    def sample(self, prev_obs: Tensor, prev_act: Tensor) -> Tuple[Tensor, List[Tensor]]:
        """Run the full denoising loop starting from Gaussian noise.

        Args:
            prev_obs: 5-D tensor unpacked as ``(b, t, c, h, w)``; the ``t`` and ``c``
                axes are merged before it is handed to the denoiser.
            prev_act: Passed unchanged to the denoiser.
                NOTE(review): its expected shape is not visible here; confirm
                against ``Denoiser.denoise``.

        Returns:
            ``(x, trajectory)`` where ``x`` has shape ``(b, c, h, w)`` and
            ``trajectory`` lists the initial noise followed by the state after
            every step (so it has one more entry than there are steps).
        """
        device = prev_obs.device
        b, t, c, h, w = prev_obs.size()
        prev_obs = prev_obs.reshape(b, t * c, h, w)
        s_in = torch.ones(b, device=device)
        # Per-step churn factor, capped at sqrt(2) - 1.
        gamma_ = min(self.cfg.s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
        x = torch.randn(b, c, h, w, device=device)
        trajectory = [x]
        for sigma, next_sigma in zip(self.sigmas[:-1], self.sigmas[1:]):
            # Churn only applies inside the [s_tmin, s_tmax] window.
            gamma = gamma_ if self.cfg.s_tmin <= sigma <= self.cfg.s_tmax else 0
            sigma_hat = sigma * (gamma + 1)
            if gamma > 0:
                # Add just enough fresh noise to lift x from level sigma to sigma_hat.
                eps = torch.randn_like(x) * self.cfg.s_noise
                x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
            # x now sits at noise level sigma_hat, so the denoiser must be conditioned
            # on sigma_hat (not sigma); the two coincide when no churn is applied.
            denoised = self.denoiser.denoise(x, sigma_hat, prev_obs, prev_act)
            d = (x - denoised) / sigma_hat
            dt = next_sigma - sigma_hat
            if self.cfg.order == 1 or next_sigma == 0:
                # Euler method (also used for the last step, where next_sigma is 0
                # and the Heun correction would divide by zero).
                x = x + d * dt
            else:
                # Heun's method: average the slope at both ends of the step.
                x_2 = x + d * dt
                denoised_2 = self.denoiser.denoise(x_2, next_sigma * s_in, prev_obs, prev_act)
                d_2 = (x_2 - denoised_2) / next_sigma
                d_prime = (d + d_2) / 2
                x = x + d_prime * dt
            trajectory.append(x)
        return x, trajectory
def build_sigmas(num_steps: int, sigma_min: float, sigma_max: float, rho: int, device: torch.device) -> Tensor:
    """Build the noise schedule: ``num_steps`` levels from ``sigma_max`` to ``sigma_min``, then 0.

    The levels are spaced linearly in ``sigma ** (1 / rho)`` and mapped back by
    raising to the power ``rho``.

    Args:
        num_steps: Number of non-zero noise levels.
        sigma_min: Last non-zero noise level.
        sigma_max: First noise level.
        rho: Exponent controlling the spacing of the levels.
        device: Device on which the schedule tensor is created.

    Returns:
        1-D tensor of length ``num_steps + 1`` whose last entry is 0.
    """
    exponent = 1 / rho
    start = sigma_max**exponent
    stop = sigma_min**exponent
    # Interpolation weights running from 0 (sigma_max) to 1 (sigma_min).
    ramp = torch.linspace(0, 1, num_steps, device=device)
    warped = start + ramp * (stop - start)
    sigmas = warped**rho
    final_zero = sigmas.new_zeros(1)
    return torch.cat((sigmas, final_zero))
|