from dataclasses import dataclass
from typing import Optional

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
# NOTE(review): os, Image, np and cv2 are not referenced by the code in this
# module any more; they are kept in case other tooling relies on them.
import os
from PIL import Image
import numpy as np
import cv2

from ...data import Batch
from .inner_model import InnerModel, InnerModelConfig
from ...utils import LossAndLogs
from ..contour_detection_model import ContourDetectionModel


def add_dims(input: Tensor, n: int) -> Tensor:
    """Return `input` reshaped with trailing singleton dims so it has `n` dims.

    If `input` already has `n` or more dims, its shape is left unchanged.
    """
    return input.reshape(input.shape + (1,) * (n - input.ndim))


@dataclass
class Conditioners:
    """Per-sample scaling terms produced by `Denoiser.compute_conditioners`."""

    c_in: Tensor  # multiplies the noisy input before the inner model
    c_out: Tensor  # multiplies the inner model output
    c_skip: Tensor  # multiplies the noisy input in the skip connection
    c_noise: Tensor  # log(sigma) / 4, passed to the inner model
    c_noise_cond: Tensor  # log(sigma_cond) / 4, or zeros when sigma_cond is None


@dataclass
class SigmaDistributionConfig:
    """Parameters of the training noise-level sampler (see `setup_training`)."""

    loc: float  # shift applied to the standard normal sample before exp()
    scale: float  # factor applied to the standard normal sample before exp()
    sigma_min: float  # lower clipping bound of the sampled sigma
    sigma_max: float  # upper clipping bound of the sampled sigma


@dataclass
class DenoiserConfig:
    """Configuration of `Denoiser`."""

    inner_model: InnerModelConfig
    sigma_data: float
    sigma_offset_noise: float
    noise_previous_obs: bool  # if True, conditioning frames are noised during training
    upsampling_factor: Optional[int] = None  # not None => the denoiser is an upsampler


class Denoiser(nn.Module):
    """Denoiser wrapping an `InnerModel` with sigma-dependent input/output scaling."""

    def __init__(self, cfg: DenoiserConfig) -> None:
        super().__init__()
        self.cfg = cfg
        self.is_upsampler = cfg.upsampling_factor is not None
        # The inner model config is told whether it is built for upsampling.
        cfg.inner_model.is_upsampler = self.is_upsampler
        self.inner_model = InnerModel(cfg.inner_model)
        # Set by `setup_training`; stays None until then.
        self.sample_sigma_training = None
        # Only instantiated for the upsampler. None of the methods below use it
        # yet (see the TODO in `forward`).
        self.contour_detection_model = (
            ContourDetectionModel(min_contour_area=25) if self.is_upsampler else None
        )

    @property
    def device(self) -> torch.device:
        """Device of the inner model's noise embedding weight."""
        return self.inner_model.noise_emb.weight.device

    def setup_training(self, cfg: SigmaDistributionConfig) -> None:
        """Install the noise-level sampler used by `forward`.

        The sampler draws `exp(N(0, 1) * cfg.scale + cfg.loc)` and clips the
        result to `[cfg.sigma_min, cfg.sigma_max]`. May only be called once.
        """
        assert self.sample_sigma_training is None, "setup_training was already called"

        def sample_sigma(n: int, device: torch.device):
            s = torch.randn(n, device=device) * cfg.scale + cfg.loc
            return s.exp().clip(cfg.sigma_min, cfg.sigma_max)

        self.sample_sigma_training = sample_sigma

    def apply_noise(self, x: Tensor, sigma: Tensor, sigma_offset_noise: float) -> Tensor:
        """Return `x` plus per-channel offset noise plus Gaussian noise scaled by `sigma`.

        `x` must have 4 dims; the offset noise has shape (b, c, 1, 1), so a
        single value is shared by all positions of a channel.
        """
        b, c, _, _ = x.shape
        offset_noise = sigma_offset_noise * torch.randn(b, c, 1, 1, device=self.device)
        return x + offset_noise + torch.randn_like(x) * add_dims(sigma, x.ndim)

    def compute_conditioners(self, sigma: Tensor, sigma_cond: Optional[Tensor]) -> Conditioners:
        """Compute the scaling terms for noise level `sigma`.

        `sigma` is first combined with the configured offset-noise level as
        sqrt(sigma^2 + sigma_offset_noise^2). `c_in`, `c_out` and `c_skip` are
        expanded to 4 dims; `c_noise` and `c_noise_cond` keep their shape.
        """
        sigma = (sigma**2 + self.cfg.sigma_offset_noise**2).sqrt()
        c_in = 1 / (sigma**2 + self.cfg.sigma_data**2).sqrt()
        c_skip = self.cfg.sigma_data**2 / (sigma**2 + self.cfg.sigma_data**2)
        c_out = sigma * c_skip.sqrt()
        c_noise = sigma.log() / 4
        c_noise_cond = sigma_cond.log() / 4 if sigma_cond is not None else torch.zeros_like(c_noise)
        return Conditioners(
            *(
                add_dims(c, n)
                for c, n in zip((c_in, c_out, c_skip, c_noise, c_noise_cond), (4, 4, 4, 1, 1))
            )
        )

    def compute_model_output(
        self, noisy_next_obs: Tensor, obs: Tensor, act: Optional[Tensor], cs: Conditioners
    ) -> Tensor:
        """Run the inner model on the rescaled noisy observation and conditioning."""
        rescaled_obs = obs / self.cfg.sigma_data
        rescaled_noise = noisy_next_obs * cs.c_in
        return self.inner_model(rescaled_noise, cs.c_noise, cs.c_noise_cond, rescaled_obs, act)

    @torch.no_grad()
    def wrap_model_output(self, noisy_next_obs: Tensor, model_output: Tensor, cs: Conditioners) -> Tensor:
        """Combine skip connection and model output into a quantized denoised image."""
        d = cs.c_skip * noisy_next_obs + cs.c_out * model_output
        # Quantize to {0, ..., 255}, then back to [-1, 1]
        d = d.clamp(-1, 1).add(1).div(2).mul(255).byte().div(255).mul(2).sub(1)
        return d

    @torch.no_grad()
    def denoise(
        self,
        noisy_next_obs: Tensor,
        sigma: Tensor,
        sigma_cond: Optional[Tensor],
        obs: Tensor,
        act: Optional[Tensor],
    ) -> Tensor:
        """Denoise `noisy_next_obs` at noise level `sigma`, conditioned on `obs` and `act`."""
        cs = self.compute_conditioners(sigma, sigma_cond)
        model_output = self.compute_model_output(noisy_next_obs, obs, act, cs)
        denoised = self.wrap_model_output(noisy_next_obs, model_output, cs)
        return denoised

    def forward(self, batch: Batch) -> LossAndLogs:
        """Compute the denoising loss averaged over the predicted timesteps.

        `batch.obs` is unpacked as (b, t, c, h, w). The first
        `num_steps_conditioning` frames are used as conditioning; each later
        frame is noised, predicted, and then replaced in the working sequence
        by its quantized denoised version, so that subsequent steps condition
        on the model's own outputs.

        Returns:
            A `(loss, logs)` pair where `logs` is `{"loss_denoising": loss.item()}`.

        Raises:
            RuntimeError: if `setup_training` has not been called.
            ValueError: if the batch has no timestep left to predict
                (t <= num_steps_conditioning).
        """
        if self.sample_sigma_training is None:
            raise RuntimeError("Denoiser.setup_training must be called before forward")

        b, t, c, h, w = batch.obs.size()
        if self.is_upsampler:
            H, W = self.cfg.upsampling_factor * h, self.cfg.upsampling_factor * w
        else:
            H, W = h, w
        n = self.cfg.inner_model.num_steps_conditioning
        seq_length = t - n  # t = n + 1 + num_autoregressive_steps
        if seq_length <= 0:
            # Previously this surfaced as a ZeroDivisionError / AttributeError
            # at the very end of the method.
            raise ValueError(
                f"batch has {t} timesteps but {n} conditioning steps are required, "
                "leaving no timestep to predict"
            )

        if self.is_upsampler:
            # Targets are the full-resolution frames; the low-resolution frames
            # are bicubically upsampled and used as extra conditioning.
            all_obs = torch.stack([x["full_res"] for x in batch.info]).to(self.device)
            low_res = F.interpolate(
                batch.obs.reshape(b * t, c, h, w),
                scale_factor=self.cfg.upsampling_factor,
                mode="bicubic",
            ).reshape(b, t, c, H, W)
            assert all_obs.shape == low_res.shape, (
                f"full_res shape {tuple(all_obs.shape)} does not match "
                f"upsampled obs shape {tuple(low_res.shape)}"
            )
        else:
            # Cloned because the sequence is overwritten in place below.
            all_obs = batch.obs.clone()

        loss = 0
        for i in range(seq_length):
            prev_obs = all_obs[:, i : n + i].reshape(b, n * c, H, W)
            prev_act = None if self.is_upsampler else batch.act[:, i : n + i]
            obs = all_obs[:, n + i]
            mask = batch.mask_padding[:, n + i]

            if self.cfg.noise_previous_obs:
                sigma_cond = self.sample_sigma_training(b, self.device)
                prev_obs = self.apply_noise(prev_obs, sigma_cond, self.cfg.sigma_offset_noise)
            else:
                sigma_cond = None

            if self.is_upsampler:
                prev_obs = torch.cat((prev_obs, low_res[:, n + i]), dim=1)

            sigma = self.sample_sigma_training(b, self.device)
            noisy_obs = self.apply_noise(obs, sigma, self.cfg.sigma_offset_noise)

            cs = self.compute_conditioners(sigma, sigma_cond)
            model_output = self.compute_model_output(noisy_obs, prev_obs, prev_act, cs)

            # Regression target such that c_skip * noisy_obs + c_out * target == obs.
            target = (obs - cs.c_skip * noisy_obs) / cs.c_out

            # NOTE(review): an earlier debugging note recorded obs of shape
            # [1, 3, 30, 120] against raw images of size [150, 600] - confirm
            # whether that is still the expected resolution.
            # TODO: add a contour-detection-based existence loss for the upsampler
            # (self.contour_detection_model is currently unused).
            loss += F.mse_loss(model_output[mask], target[mask])

            # Feed the (quantized) prediction back as conditioning for later steps.
            denoised = self.wrap_model_output(noisy_obs, model_output, cs)
            all_obs[:, n + i] = denoised

        loss /= seq_length
        return loss, {"loss_denoising": loss.item()}