from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from torch import Tensor

from .denoiser import Denoiser


@dataclass
class DiffusionSamplerConfig:
    """Hyper-parameters for :class:`DiffusionSampler`.

    The ``s_*`` fields mirror the stochastic sampler of Karras et al. (EDM, Algorithm 2):
    ``s_churn`` sets how much extra noise is injected per step (capped at sqrt(2) - 1 per step),
    only on steps whose sigma lies in ``[s_tmin, s_tmax]``, scaled by ``s_noise``.
    ``s_cond`` is the noise level applied to the conditioning observations (0 disables it).
    """

    num_steps_denoising: int
    sigma_min: float = 2e-3
    sigma_max: float = 5
    rho: int = 7
    order: int = 1  # 1 -> Euler steps; any other value -> Heun (2nd-order) steps
    s_churn: float = 0
    s_tmin: float = 0
    s_tmax: float = float("inf")
    s_noise: float = 1
    s_cond: float = 0
    # Warm start: after the first call, initialise the sampler from the previously sampled
    # frame (noised via ``denoiser.apply_noise``) instead of from pure Gaussian noise.
    warmstart: bool = True
    warmstart_sigma: float = 0.05  # noise level passed to apply_noise for the warm start
    warmstart_offset_noise: float = 0.05  # sigma_offset_noise passed to apply_noise for the warm start


class DiffusionSampler:
    """Generate the next frame by integrating the denoiser over a decreasing sigma schedule.

    The sampler is stateful: every call to :meth:`sample` stores its result in
    ``last_frame``, which (when ``cfg.warmstart`` is enabled) seeds the next call.
    Call :meth:`reset` whenever consecutive calls are unrelated (e.g. at episode boundaries).
    """

    def __init__(self, denoiser: Denoiser, cfg: DiffusionSamplerConfig) -> None:
        self.denoiser = denoiser
        self.cfg = cfg
        self.sigmas = build_sigmas(cfg.num_steps_denoising, cfg.sigma_min, cfg.sigma_max, cfg.rho, denoiser.device)
        # Warm-start state. While ``is_first_frame`` is True the next sample starts from pure noise.
        self.is_first_frame = True
        self.last_frame: Optional[Tensor] = None

    def reset(self) -> None:
        """Forget the previous frame so the next call to :meth:`sample` starts from Gaussian noise."""
        self.is_first_frame = True
        self.last_frame = None

    def _initial_sample(self, b: int, c: int, h: int, w: int, device: torch.device) -> Tensor:
        """Return the starting point of the sampler: Gaussian noise, or a noised copy of the last frame."""
        last = self.last_frame
        use_warmstart = (
            self.cfg.warmstart
            and not self.is_first_frame
            and last is not None
            # A batch/resolution change makes the previous frame unusable; fall back to noise.
            and tuple(last.shape) == (b, c, h, w)
        )
        self.is_first_frame = False
        if not use_warmstart:
            return torch.randn(b, c, h, w, device=device)
        warm_sigma = torch.full((b,), fill_value=self.cfg.warmstart_sigma, device=device)
        return self.denoiser.apply_noise(
            last.to(device), warm_sigma, sigma_offset_noise=self.cfg.warmstart_offset_noise
        )

    @torch.no_grad()
    def sample(self, prev_obs: Tensor, prev_act: Optional[Tensor]) -> Tuple[Tensor, List[Tensor]]:
        """Sample one new frame conditioned on past observations and actions.

        Args:
            prev_obs: conditioning frames of shape ``(b, t, c, h, w)``; the ``t`` frames are stacked
                along the channel axis before being handed to the denoiser.
            prev_act: forwarded unchanged to ``denoiser.denoise`` (may be ``None``).

        Returns:
            ``(x, trajectory)`` where ``x`` is the final sample (presumably ``(b, c, h, w)``, the shape
            of the initial sample, assuming the denoiser preserves it) and ``trajectory`` lists the
            initial sample followed by the state after each of the ``num_steps_denoising`` steps.
        """
        device = prev_obs.device
        b, t, c, h, w = prev_obs.size()
        prev_obs = prev_obs.reshape(b, t * c, h, w)
        s_in = torch.ones(b, device=device)
        gamma_ = min(self.cfg.s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)

        x = self._initial_sample(b, c, h, w, device)

        # Noise the conditioning frames ONCE at level s_cond. Re-noising them on every step would
        # compound the noise (~sqrt(k) * s_cond after k steps) while the denoiser is still told s_cond.
        if self.cfg.s_cond > 0:
            sigma_cond = torch.full((b,), fill_value=self.cfg.s_cond, device=device)
            prev_obs = self.denoiser.apply_noise(prev_obs, sigma_cond, sigma_offset_noise=0)
        else:
            sigma_cond = None

        trajectory = [x]
        for sigma, next_sigma in zip(self.sigmas[:-1], self.sigmas[1:]):
            gamma = gamma_ if self.cfg.s_tmin <= sigma <= self.cfg.s_tmax else 0
            sigma_hat = sigma * (gamma + 1)
            if gamma > 0:
                # Churn: raise the noise level of x from sigma to sigma_hat.
                eps = torch.randn_like(x) * self.cfg.s_noise
                x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
            # x now carries noise level sigma_hat, so the denoiser must be evaluated at sigma_hat
            # (identical to sigma when no churn is applied).
            denoised = self.denoiser.denoise(x, sigma_hat, sigma_cond, prev_obs, prev_act)
            d = (x - denoised) / sigma_hat
            dt = next_sigma - sigma_hat
            if self.cfg.order == 1 or next_sigma == 0:
                # Euler step (also used for the final step to sigma = 0, where Heun would divide by 0).
                x = x + d * dt
            else:
                # Heun's method: average the slope at both ends of the step.
                x_2 = x + d * dt
                denoised_2 = self.denoiser.denoise(x_2, next_sigma * s_in, sigma_cond, prev_obs, prev_act)
                d_2 = (x_2 - denoised_2) / next_sigma
                d_prime = (d + d_2) / 2
                x = x + d_prime * dt
            trajectory.append(x)

        self.last_frame = x
        return x, trajectory


def build_sigmas(num_steps: int, sigma_min: float, sigma_max: float, rho: int, device: torch.device) -> Tensor:
    """Build the Karras-style noise schedule used by :class:`DiffusionSampler`.

    Produces ``num_steps`` sigmas going from ``sigma_max`` to ``sigma_min``, spaced linearly in
    ``sigma ** (1 / rho)``, followed by a trailing 0; the result has ``num_steps + 1`` entries.

    Raises:
        ValueError: if ``num_steps < 1`` (the sampler would otherwise divide by zero).
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    l = torch.linspace(0, 1, num_steps, device=device)
    sigmas = (max_inv_rho + l * (min_inv_rho - max_inv_rho)) ** rho
    return torch.cat((sigmas, sigmas.new_zeros(1)))