neural-boy / src /models /diffusion /denoiser.py
sachinkumarsingh's picture
Upload folder using huggingface_hub
d548197 verified
Raw
History Blame Contribute Delete
4.3 kB
from dataclasses import dataclass
from typing import Optional
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from data import Batch
from .inner_model import InnerModel, InnerModelConfig
from utils import LossAndLogs
def add_dims(input: Tensor, n: int) -> Tensor:
    """Append trailing singleton axes to ``input`` until it has ``n`` dimensions.

    Args:
        input: Tensor to pad on the right with size-1 axes.
        n: Target number of dimensions.

    Returns:
        ``input`` reshaped to ``input.shape + (1, ..., 1)``. When ``n`` does not
        exceed ``input.ndim`` nothing is appended and the shape is unchanged.
    """
    missing_axes = n - input.ndim
    # A non-positive repeat count yields an empty tuple, so no axis is added.
    padded_shape = input.shape + (1,) * missing_axes
    return input.reshape(padded_shape)
@dataclass
class Conditioners:
    """Sigma-dependent coefficients applied around the inner model.

    Field order matters: instances may be built positionally.
    """

    # Multiplies the noisy input before it is passed to the inner model.
    c_in: Tensor
    # Multiplies the inner model's output when forming the denoised estimate.
    c_out: Tensor
    # Multiplies the noisy input in the skip term of the denoised estimate.
    c_skip: Tensor
    # Noise-level value handed to the inner model as conditioning.
    c_noise: Tensor
@dataclass
class SigmaDistributionConfig:
    """Parameters of the noise-level (sigma) distribution sampled during training.

    A sample is drawn as ``exp(loc + scale * N(0, 1))`` and then clipped to
    ``[sigma_min, sigma_max]``.
    """

    # Mean of the Gaussian sampled in log-sigma space.
    loc: float
    # Standard deviation of the Gaussian sampled in log-sigma space.
    scale: float
    # Lower clipping bound applied to the sampled sigma.
    sigma_min: float
    # Upper clipping bound applied to the sampled sigma.
    sigma_max: float
@dataclass
class DenoiserConfig:
    """Configuration consumed by ``Denoiser``."""

    # Configuration forwarded to the ``InnerModel`` constructor.
    inner_model: InnerModelConfig
    # Data scale used in the conditioning coefficients and to rescale the
    # conditioning observations.
    sigma_data: float
    # Standard deviation of the per-(sample, channel) offset noise added on top
    # of the per-element noise.
    sigma_offset_noise: float
class Denoiser(nn.Module):
    """Noise-conditioned denoiser built around an ``InnerModel``.

    The inner model receives a rescaled noisy frame, a noise-level conditioning
    value, rescaled past observations and actions. Its raw output is combined
    with the noisy frame through the ``c_skip`` / ``c_out`` coefficients to form
    the denoised estimate. The coefficient formulas in ``compute_conditioners``
    have the same form as the preconditioning of Karras et al. (2022, "EDM").

    ``setup_training`` must be called once before ``forward`` is used.
    """

    def __init__(self, cfg: DenoiserConfig) -> None:
        super().__init__()
        self.cfg = cfg
        self.inner_model = InnerModel(cfg.inner_model)
        # Callable ``(n, device) -> Tensor`` installed by ``setup_training``.
        self.sample_sigma_training = None

    @property
    def device(self) -> torch.device:
        """Device holding the inner model's noise-embedding weight."""
        return self.inner_model.noise_emb.weight.device

    def setup_training(self, cfg: SigmaDistributionConfig) -> None:
        """Install the sigma sampler used by ``forward``.

        The sampler draws ``exp(loc + scale * N(0, 1))`` and clips the result to
        ``[sigma_min, sigma_max]``. May only be called once per instance.
        """
        assert self.sample_sigma_training is None

        def sample_sigma(n: int, device: torch.device) -> Tensor:
            s = torch.randn(n, device=device) * cfg.scale + cfg.loc
            return s.exp().clip(cfg.sigma_min, cfg.sigma_max)

        self.sample_sigma_training = sample_sigma

    def apply_noise(self, x: Tensor, sigma: Tensor, sigma_offset_noise: float) -> Tensor:
        """Return ``x`` corrupted with offset noise plus per-element Gaussian noise.

        Args:
            x: 4-D tensor; its first two axes are read as ``(b, c)``.
            sigma: Per-sample noise level, right-padded to ``x.ndim`` for broadcasting.
            sigma_offset_noise: Standard deviation of the offset noise.
        """
        b, c, _, _ = x.shape
        # One offset per (sample, channel), broadcast over the last two axes.
        # Allocated on the input's device so this helper does not depend on
        # where the inner model lives.
        offset_noise = sigma_offset_noise * torch.randn(b, c, 1, 1, device=x.device)
        return x + offset_noise + torch.randn_like(x) * add_dims(sigma, x.ndim)

    def compute_conditioners(self, sigma: Tensor) -> Conditioners:
        """Compute the conditioning coefficients for the given noise levels.

        ``c_in``, ``c_out`` and ``c_skip`` are right-padded to rank 4 so they
        broadcast against 4-D frames; ``c_noise`` is padded to rank 1 (a no-op
        for a 1-D ``sigma``).
        """
        # ``apply_noise`` adds two independent Gaussian terms (std ``sigma`` and
        # std ``sigma_offset_noise``); this is their combined standard deviation.
        sigma = (sigma**2 + self.cfg.sigma_offset_noise**2).sqrt()
        c_in = 1 / (sigma**2 + self.cfg.sigma_data**2).sqrt()
        c_skip = self.cfg.sigma_data**2 / (sigma**2 + self.cfg.sigma_data**2)
        c_out = sigma * c_skip.sqrt()
        c_noise = sigma.log() / 4
        # Each coefficient is padded explicitly; the previous zip() paired four
        # tensors with five ranks and silently dropped the extra rank.
        return Conditioners(
            c_in=add_dims(c_in, 4),
            c_out=add_dims(c_out, 4),
            c_skip=add_dims(c_skip, 4),
            c_noise=add_dims(c_noise, 1),
        )

    def compute_model_output(self, noisy_next_obs: Tensor, obs: Tensor, act: Tensor, cs: Conditioners) -> Tensor:
        """Run the inner model on rescaled inputs and return its raw output."""
        rescaled_obs = obs / self.cfg.sigma_data
        rescaled_noise = noisy_next_obs * cs.c_in
        return self.inner_model(rescaled_noise, cs.c_noise, rescaled_obs, act)

    @torch.no_grad()
    def wrap_model_output(self, noisy_next_obs: Tensor, model_output: Tensor, cs: Conditioners) -> Tensor:
        """Turn the raw model output into a quantized denoised frame (no gradients)."""
        d = cs.c_skip * noisy_next_obs + cs.c_out * model_output
        # Quantize to {0, ..., 255}, then back to [-1, 1]
        d = d.clamp(-1, 1).add(1).div(2).mul(255).byte().div(255).mul(2).sub(1)
        return d

    @torch.no_grad()
    def denoise(self, noisy_next_obs: Tensor, sigma: Tensor, obs: Tensor, act: Tensor) -> Tensor:
        """Denoise ``noisy_next_obs`` at noise level ``sigma`` (inference path, no gradients)."""
        cs = self.compute_conditioners(sigma)
        model_output = self.compute_model_output(noisy_next_obs, obs, act, cs)
        denoised = self.wrap_model_output(noisy_next_obs, model_output, cs)
        return denoised

    def forward(self, batch: Batch) -> LossAndLogs:
        """Compute the denoising loss averaged over the predicted steps of ``batch``.

        The first ``num_steps_conditioning`` frames along axis 1 of ``batch.obs``
        serve as initial context. For every following step the target frame is
        noised, the model predicts it, and the masked MSE against the training
        target is accumulated. The denoised prediction then replaces the frame in
        a working copy of the observations, so later steps are conditioned on
        the model's own outputs.

        Returns:
            ``(loss, logs)`` where ``logs`` maps ``"loss_denoising"`` to the
            detached loss.

        Raises:
            RuntimeError: If ``setup_training`` has not been called.
            ValueError: If the sequence has no step beyond the conditioning window.
        """
        if self.sample_sigma_training is None:
            raise RuntimeError("setup_training() must be called before forward().")

        n = self.cfg.inner_model.num_steps_conditioning
        seq_length = batch.obs.size(1) - n
        if seq_length <= 0:
            raise ValueError(
                f"batch.obs has {batch.obs.size(1)} steps but more than {n} are required "
                "(num_steps_conditioning plus at least one step to predict)."
            )

        # Working copy: predicted frames are written back into it below.
        all_obs = batch.obs.clone()
        loss = 0

        for i in range(seq_length):
            obs = all_obs[:, i : n + i]
            next_obs = all_obs[:, n + i]
            act = batch.act[:, i : n + i]
            mask = batch.mask_padding[:, n + i]

            # Stack the conditioning frames along the channel axis.
            b, t, c, h, w = obs.shape
            obs = obs.reshape(b, t * c, h, w)

            sigma = self.sample_sigma_training(b, self.device)
            noisy_next_obs = self.apply_noise(next_obs, sigma, self.cfg.sigma_offset_noise)

            cs = self.compute_conditioners(sigma)
            model_output = self.compute_model_output(noisy_next_obs, obs, act, cs)

            # Inverse of the mapping in wrap_model_output (before quantization):
            # the raw output that would reproduce ``next_obs`` exactly.
            target = (next_obs - cs.c_skip * noisy_next_obs) / cs.c_out
            loss += F.mse_loss(model_output[mask], target[mask])

            denoised = self.wrap_model_output(noisy_next_obs, model_output, cs)
            all_obs[:, n + i] = denoised

        loss /= seq_length
        return loss, {"loss_denoising": loss.detach()}