PIWM / src /models /diffusion /diffusion_sampler.py
musictimer's picture
Initial Diamond CSGO AI deployment
c64c726
# from dataclasses import dataclass
# from typing import List, Optional, Tuple
# import torch
# from torch import Tensor
# from .denoiser import Denoiser
# @dataclass
# class DiffusionSamplerConfig:
# num_steps_denoising: int
# sigma_min: float = 2e-3
# sigma_max: float = 5
# rho: int = 7
# order: int = 1
# s_churn: float = 0
# s_tmin: float = 0
# s_tmax: float = float("inf")
# s_noise: float = 1
# s_cond: float = 0
# class DiffusionSampler:
# def __init__(self, denoiser: Denoiser, cfg: DiffusionSamplerConfig) -> None:
# self.denoiser = denoiser
# self.cfg = cfg
# self.sigmas = build_sigmas(cfg.num_steps_denoising, cfg.sigma_min, cfg.sigma_max, cfg.rho, denoiser.device)
# @torch.no_grad()
# def sample(self, prev_obs: Tensor, prev_act: Optional[Tensor]) -> Tuple[Tensor, List[Tensor]]:
# device = prev_obs.device
# b, t, c, h, w = prev_obs.size()
# prev_obs = prev_obs.reshape(b, t * c, h, w)
# s_in = torch.ones(b, device=device)
# gamma_ = min(self.cfg.s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
# x = torch.randn(b, c, h, w, device=device)
# trajectory = [x]
# for sigma, next_sigma in zip(self.sigmas[:-1], self.sigmas[1:]):
# gamma = gamma_ if self.cfg.s_tmin <= sigma <= self.cfg.s_tmax else 0
# sigma_hat = sigma * (gamma + 1)
# if gamma > 0:
# eps = torch.randn_like(x) * self.cfg.s_noise
# x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
# if self.cfg.s_cond > 0:
# sigma_cond = torch.full((b,), fill_value=self.cfg.s_cond, device=device)
# prev_obs = self.denoiser.apply_noise(prev_obs, sigma_cond, sigma_offset_noise=0)
# else:
# sigma_cond = None
# denoised = self.denoiser.denoise(x, sigma, sigma_cond, prev_obs, prev_act)
# d = (x - denoised) / sigma_hat
# dt = next_sigma - sigma_hat
# if self.cfg.order == 1 or next_sigma == 0:
# # Euler method
# x = x + d * dt
# else:
# # Heun's method
# x_2 = x + d * dt
# denoised_2 = self.denoiser.denoise(x_2, next_sigma * s_in, sigma_cond, prev_obs, prev_act)
# d_2 = (x_2 - denoised_2) / next_sigma
# d_prime = (d + d_2) / 2
# x = x + d_prime * dt
# trajectory.append(x)
# return x, trajectory
# def build_sigmas(num_steps: int, sigma_min: float, sigma_max: float, rho: int, device: torch.device) -> Tensor:
# min_inv_rho = sigma_min ** (1 / rho)
# max_inv_rho = sigma_max ** (1 / rho)
# l = torch.linspace(0, 1, num_steps, device=device)
# sigmas = (max_inv_rho + l * (min_inv_rho - max_inv_rho)) ** rho
# return torch.cat((sigmas, sigmas.new_zeros(1)))
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
from torch import Tensor
from .denoiser import Denoiser
@dataclass
class DiffusionSamplerConfig:
num_steps_denoising: int
sigma_min: float = 2e-3
sigma_max: float = 5
rho: int = 7
order: int = 1
s_churn: float = 0
s_tmin: float = 0
s_tmax: float = float("inf")
s_noise: float = 1
s_cond: float = 0
class DiffusionSampler:
def __init__(self, denoiser: Denoiser, cfg: DiffusionSamplerConfig) -> None:
self.denoiser = denoiser
self.cfg = cfg
self.sigmas = build_sigmas(cfg.num_steps_denoising, cfg.sigma_min, cfg.sigma_max, cfg.rho, denoiser.device)
self.is_first_frame = True
self.last_frame = None
@torch.no_grad()
def sample(self, prev_obs: Tensor, prev_act: Optional[Tensor]) -> Tuple[Tensor, List[Tensor]]:
device = prev_obs.device
b, t, c, h, w = prev_obs.size()
prev_obs = prev_obs.reshape(b, t * c, h, w)
s_in = torch.ones(b, device=device)
gamma_ = min(self.cfg.s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
# x = torch.randn(b, c, h, w, device=device) # use Gaussian noise as initial sample
# use warmstart of last frame if available
if self.is_first_frame: # first frame
x = torch.randn(b, c, h, w, device=device) # use Gaussian noise as initial sample
self.is_first_frame = False
else: # use last framw for warmstart
sigma_cond = torch.full((b,), fill_value=0.05, device=device)
x = self.denoiser.apply_noise(self.last_frame, sigma_cond, sigma_offset_noise=0.05)
trajectory = [x]
for sigma, next_sigma in zip(self.sigmas[:-1], self.sigmas[1:]):
gamma = gamma_ if self.cfg.s_tmin <= sigma <= self.cfg.s_tmax else 0
sigma_hat = sigma * (gamma + 1)
if gamma > 0:
eps = torch.randn_like(x) * self.cfg.s_noise
x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
if self.cfg.s_cond > 0:
sigma_cond = torch.full((b,), fill_value=self.cfg.s_cond, device=device)
prev_obs = self.denoiser.apply_noise(prev_obs, sigma_cond, sigma_offset_noise=0)
else:
sigma_cond = None
denoised = self.denoiser.denoise(x, sigma, sigma_cond, prev_obs, prev_act)
d = (x - denoised) / sigma_hat
dt = next_sigma - sigma_hat
if self.cfg.order == 1 or next_sigma == 0:
# Euler method
x = x + d * dt
else:
# Heun's method
x_2 = x + d * dt
denoised_2 = self.denoiser.denoise(x_2, next_sigma * s_in, sigma_cond, prev_obs, prev_act)
d_2 = (x_2 - denoised_2) / next_sigma
d_prime = (d + d_2) / 2
x = x + d_prime * dt
trajectory.append(x)
self.last_frame = x
# visulize low resolution observation
# Denoiser.save_tensor_as_image(x, "inference_output_low_res.png", tensor_name="Inference Low Resolution Observation")
return x, trajectory
def build_sigmas(num_steps: int, sigma_min: float, sigma_max: float, rho: int, device: torch.device) -> Tensor:
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
l = torch.linspace(0, 1, num_steps, device=device)
sigmas = (max_inv_rho + l * (min_inv_rho - max_inv_rho)) ** rho
return torch.cat((sigmas, sigmas.new_zeros(1)))