# PIWM — src/models/diffusion/denoiser.py
# (file-viewer residue preserved as a comment: commit 17fd5e3, "Fix bug 1", by musictimer)
from dataclasses import dataclass
from typing import Optional
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import os
from PIL import Image
import numpy as np
import cv2
from ...data import Batch
from .inner_model import InnerModel, InnerModelConfig
from ...utils import LossAndLogs
from ..contour_detection_model import ContourDetectionModel
def add_dims(input: Tensor, n: int) -> Tensor:
    """Return *input* with trailing singleton dimensions appended until it has ``n`` dims."""
    missing = n - input.ndim
    # Indexing with None inserts a size-1 axis per None, after all existing axes.
    return input[(...,) + (None,) * missing]
@dataclass
class Conditioners:
    """Preconditioning coefficients for one denoising step.

    Produced by Denoiser.compute_conditioners: c_in/c_out/c_skip are expanded
    to 4-D so they broadcast over (b, c, h, w) image tensors, while the two
    noise embeddings stay 1-D (one value per batch element).
    """

    c_in: Tensor  # scales the noisy input before it is fed to the inner model
    c_out: Tensor  # scales the raw network output
    c_skip: Tensor  # weight of the skip connection from the noisy input
    c_noise: Tensor  # log(sigma) / 4, noise-level embedding input
    c_noise_cond: Tensor  # log(sigma_cond) / 4 for noised conditioning frames; zeros when unused
@dataclass
class SigmaDistributionConfig:
    """Clipped log-normal distribution of training noise levels.

    Used by Denoiser.setup_training: sigma = exp(N(loc, scale^2)),
    clipped to [sigma_min, sigma_max].
    """

    loc: float  # mean of log(sigma)
    scale: float  # standard deviation of log(sigma)
    sigma_min: float  # lower clip bound for sampled sigma
    sigma_max: float  # upper clip bound for sampled sigma
@dataclass
class DenoiserConfig:
    """Static configuration of a Denoiser."""

    inner_model: InnerModelConfig  # configuration of the wrapped InnerModel network
    sigma_data: float  # data standard deviation used in the preconditioning formulas
    sigma_offset_noise: float  # scale of per-(batch, channel) offset noise added in apply_noise
    noise_previous_obs: bool  # if True, conditioning frames are noised during training
    upsampling_factor: Optional[int] = None  # when set, the model acts as an upsampler
class Denoiser(nn.Module):
    """Preconditioned denoiser over autoregressive frame sequences.

    Wraps an InnerModel with the (c_in, c_out, c_skip, c_noise)
    preconditioning of Karras et al. (2022, "EDM"). When
    cfg.upsampling_factor is set, the model acts as an upsampler: targets are
    full-resolution frames and a bicubic upscale of the low-resolution input
    is concatenated to the conditioning frames.
    """

    def __init__(self, cfg: DenoiserConfig) -> None:
        super().__init__()
        self.cfg = cfg
        self.is_upsampler = cfg.upsampling_factor is not None
        cfg.inner_model.is_upsampler = self.is_upsampler
        self.inner_model = InnerModel(cfg.inner_model)
        # Installed later by setup_training(); kept None so we can assert that
        # training setup happens exactly once.
        self.sample_sigma_training = None
        # Contour detector is only instantiated for the upsampler variant.
        self.contour_detection_model = ContourDetectionModel(min_contour_area=25) if self.is_upsampler else None

    @property
    def device(self) -> torch.device:
        # Any registered parameter gives the module's device; the noise
        # embedding weight is a convenient one.
        return self.inner_model.noise_emb.weight.device

    def setup_training(self, cfg: SigmaDistributionConfig) -> None:
        """Install the clipped log-normal sigma sampler used by forward()."""
        assert self.sample_sigma_training is None

        def sample_sigma(n: int, device: torch.device):
            s = torch.randn(n, device=device) * cfg.scale + cfg.loc
            return s.exp().clip(cfg.sigma_min, cfg.sigma_max)

        self.sample_sigma_training = sample_sigma

    def apply_noise(self, x: Tensor, sigma: Tensor, sigma_offset_noise: float) -> Tensor:
        """Add Gaussian noise of per-sample level `sigma` plus channel-wise offset noise.

        The offset noise is a single scalar per (batch, channel), broadcast
        over the spatial dimensions.
        """
        b, c, _, _ = x.shape
        offset_noise = sigma_offset_noise * torch.randn(b, c, 1, 1, device=self.device)
        return x + offset_noise + torch.randn_like(x) * add_dims(sigma, x.ndim)

    def compute_conditioners(self, sigma: Tensor, sigma_cond: Optional[Tensor]) -> Conditioners:
        """Compute EDM preconditioning coefficients for noise level `sigma`.

        `sigma_cond` is the noise level applied to the conditioning frames
        (None when the previous observations are clean).
        """
        # Fold the offset noise into the effective noise level.
        sigma = (sigma**2 + self.cfg.sigma_offset_noise**2).sqrt()
        c_in = 1 / (sigma**2 + self.cfg.sigma_data**2).sqrt()
        c_skip = self.cfg.sigma_data**2 / (sigma**2 + self.cfg.sigma_data**2)
        c_out = sigma * c_skip.sqrt()
        c_noise = sigma.log() / 4
        c_noise_cond = sigma_cond.log() / 4 if sigma_cond is not None else torch.zeros_like(c_noise)
        # c_in/c_out/c_skip get 4 dims to broadcast over (b, c, h, w); the
        # two noise embeddings stay 1-D.
        return Conditioners(*(add_dims(c, n) for c, n in zip((c_in, c_out, c_skip, c_noise, c_noise_cond), (4, 4, 4, 1, 1))))

    def compute_model_output(self, noisy_next_obs: Tensor, obs: Tensor, act: Optional[Tensor], cs: Conditioners) -> Tensor:
        """Run the inner network on preconditioned inputs; returns the raw network output."""
        rescaled_obs = obs / self.cfg.sigma_data
        rescaled_noise = noisy_next_obs * cs.c_in
        return self.inner_model(rescaled_noise, cs.c_noise, cs.c_noise_cond, rescaled_obs, act)

    @torch.no_grad()
    def wrap_model_output(self, noisy_next_obs: Tensor, model_output: Tensor, cs: Conditioners) -> Tensor:
        """Combine skip connection and network output into the denoised frame."""
        d = cs.c_skip * noisy_next_obs + cs.c_out * model_output
        # Quantize to {0, ..., 255}, then back to [-1, 1]
        d = d.clamp(-1, 1).add(1).div(2).mul(255).byte().div(255).mul(2).sub(1)
        return d

    @torch.no_grad()
    def denoise(self, noisy_next_obs: Tensor, sigma: Tensor, sigma_cond: Optional[Tensor], obs: Tensor, act: Optional[Tensor]) -> Tensor:
        """Full inference denoising step: conditioners -> network -> wrapped output."""
        cs = self.compute_conditioners(sigma, sigma_cond)
        model_output = self.compute_model_output(noisy_next_obs, obs, act, cs)
        denoised = self.wrap_model_output(noisy_next_obs, model_output, cs)
        return denoised

    def forward(self, batch: Batch) -> LossAndLogs:
        """Autoregressive denoising-score-matching loss over a trajectory batch.

        Requires setup_training() to have been called. Returns (loss, logs).
        """
        b, t, c, h, w = batch.obs.size()
        H, W = (self.cfg.upsampling_factor * h, self.cfg.upsampling_factor * w) if self.is_upsampler else (h, w)
        n = self.cfg.inner_model.num_steps_conditioning
        seq_length = t - n  # t = n + 1 + num_autoregressive_steps

        if self.is_upsampler:
            # Targets are the full-resolution frames stored in batch.info; the
            # conditioning receives a bicubic upscale of the low-res frames.
            all_obs = torch.stack([x["full_res"] for x in batch.info]).to(self.device)
            low_res = F.interpolate(batch.obs.reshape(b * t, c, h, w), scale_factor=self.cfg.upsampling_factor, mode="bicubic").reshape(b, t, c, H, W)
            assert all_obs.shape == low_res.shape
        else:
            all_obs = batch.obs.clone()

        loss = 0
        for i in range(seq_length):
            prev_obs = all_obs[:, i : n + i].reshape(b, n * c, H, W)
            prev_act = None if self.is_upsampler else batch.act[:, i : n + i]
            obs = all_obs[:, n + i]
            mask = batch.mask_padding[:, n + i]

            if self.cfg.noise_previous_obs:
                sigma_cond = self.sample_sigma_training(b, self.device)
                prev_obs = self.apply_noise(prev_obs, sigma_cond, self.cfg.sigma_offset_noise)
            else:
                sigma_cond = None

            if self.is_upsampler:
                prev_obs = torch.cat((prev_obs, low_res[:, n + i]), dim=1)

            sigma = self.sample_sigma_training(b, self.device)
            noisy_obs = self.apply_noise(obs, sigma, self.cfg.sigma_offset_noise)

            cs = self.compute_conditioners(sigma, sigma_cond)
            model_output = self.compute_model_output(noisy_obs, prev_obs, prev_act, cs)
            # EDM training target: (obs - c_skip * noisy) / c_out, restricted
            # to non-padded frames via the mask.
            target = (obs - cs.c_skip * noisy_obs) / cs.c_out
            loss += F.mse_loss(model_output[mask], target[mask])

            # Feed the quantized, detached denoised frame back so the next
            # step conditions on the model's own prediction.
            denoised = self.wrap_model_output(noisy_obs, model_output, cs)
            all_obs[:, n + i] = denoised

        loss /= seq_length
        return loss, {"loss_denoising": loss.item()}