Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass | |
| from typing import Optional | |
| import torch | |
| from torch import Tensor | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import os | |
| from PIL import Image | |
| import numpy as np | |
| import cv2 | |
| from ...data import Batch | |
| from .inner_model import InnerModel, InnerModelConfig | |
| from ...utils import LossAndLogs | |
| from ..contour_detection_model import ContourDetectionModel | |
def add_dims(input: Tensor, n: int) -> Tensor:
    """Append trailing singleton axes to ``input`` until it has ``n`` dimensions.

    If ``input`` already has ``n`` or more dimensions, its shape is unchanged.
    This is used to broadcast per-sample scalars (shape ``(b,)``) against
    image batches.
    """
    n_missing = max(n - input.ndim, 0)
    padded_shape = (*input.shape, *([1] * n_missing))
    return input.reshape(padded_shape)
@dataclass
class Conditioners:
    """Preconditioning coefficients produced by ``Denoiser.compute_conditioners``.

    Field order matters: instances are built positionally in the order
    ``(c_in, c_out, c_skip, c_noise, c_noise_cond)``.
    """

    c_in: Tensor  # scale applied to the noisy input before the inner network
    c_out: Tensor  # scale applied to the inner network's output
    c_skip: Tensor  # weight of the skip connection from the noisy input
    c_noise: Tensor  # noise-level embedding input for the target frame
    c_noise_cond: Tensor  # noise-level embedding input for the conditioning frames
@dataclass
class SigmaDistributionConfig:
    """Parameters of the log-normal training noise distribution.

    ``Denoiser.setup_training`` samples ``exp(N(loc, scale**2))`` and clips the
    result to ``[sigma_min, sigma_max]``.
    """

    loc: float  # mean of log(sigma)
    scale: float  # standard deviation of log(sigma)
    sigma_min: float  # lower clip bound for sampled sigma
    sigma_max: float  # upper clip bound for sampled sigma
@dataclass
class DenoiserConfig:
    """Configuration for ``Denoiser``."""

    inner_model: InnerModelConfig  # network config; Denoiser overwrites its ``is_upsampler`` flag
    sigma_data: float  # data-scale constant used in the preconditioning and to rescale conditioning obs
    sigma_offset_noise: float  # std of the per-(sample, channel) offset noise
    noise_previous_obs: bool  # if True, conditioning frames are also noised during training
    upsampling_factor: Optional[int] = None  # when set, the Denoiser acts as an upsampler by this factor
class Denoiser(nn.Module):
    """Diffusion denoiser with EDM-style (Karras et al.) preconditioning around an ``InnerModel``.

    Operates in one of two modes:
    * next-observation denoiser: conditioned on the previous ``n`` frames and actions;
    * upsampler (``cfg.upsampling_factor`` set): conditioned on the previous ``n``
      full-resolution frames plus a bicubic upsample of the current low-res frame.
    """

    def __init__(self, cfg: DenoiserConfig) -> None:
        super().__init__()
        self.cfg = cfg
        self.is_upsampler = cfg.upsampling_factor is not None
        # Propagate the mode to the inner model config before building it
        # (note: this mutates the caller's config object).
        cfg.inner_model.is_upsampler = self.is_upsampler
        self.inner_model = InnerModel(cfg.inner_model)
        # Installed by setup_training(); None until training is configured.
        self.sample_sigma_training = None
        # NOTE(review): built for the upsampler but not yet used by any method
        # below (see the TODO in forward).
        self.contour_detection_model = ContourDetectionModel(min_contour_area=25) if self.is_upsampler else None

    @property
    def device(self) -> torch.device:
        """Device on which the inner model's parameters live."""
        return self.inner_model.noise_emb.weight.device

    def setup_training(self, cfg: SigmaDistributionConfig) -> None:
        """Install the training-time noise-level sampler.

        Sigmas are drawn log-normally, ``exp(N(cfg.loc, cfg.scale**2))``, then
        clipped to ``[cfg.sigma_min, cfg.sigma_max]``. May only be called once.
        """
        assert self.sample_sigma_training is None

        def sample_sigma(n: int, device: torch.device) -> Tensor:
            s = torch.randn(n, device=device) * cfg.scale + cfg.loc
            return s.exp().clip(cfg.sigma_min, cfg.sigma_max)

        self.sample_sigma_training = sample_sigma

    def apply_noise(self, x: Tensor, sigma: Tensor, sigma_offset_noise: float) -> Tensor:
        """Return ``x`` plus Gaussian noise of per-sample std ``sigma`` plus offset noise.

        ``x`` is unpacked as ``(b, c, h, w)``. The offset term is a single random
        value per (sample, channel), scaled by ``sigma_offset_noise`` and
        broadcast over the spatial dimensions.
        """
        b, c, _, _ = x.shape
        offset_noise = sigma_offset_noise * torch.randn(b, c, 1, 1, device=self.device)
        return x + offset_noise + torch.randn_like(x) * add_dims(sigma, x.ndim)

    def compute_conditioners(self, sigma: Tensor, sigma_cond: Optional[Tensor]) -> Conditioners:
        """Compute the EDM preconditioning coefficients for noise level ``sigma``.

        ``cfg.sigma_offset_noise`` is first folded into ``sigma`` in quadrature.
        ``c_in``/``c_out``/``c_skip`` are padded to 4 dims (e.g. ``(b, 1, 1, 1)``
        for a 1-D ``sigma`` of shape ``(b,)``); the noise embeddings stay 1-D.
        If ``sigma_cond`` is None, the conditioning-noise embedding is zero.
        """
        sigma = (sigma**2 + self.cfg.sigma_offset_noise**2).sqrt()
        c_in = 1 / (sigma**2 + self.cfg.sigma_data**2).sqrt()
        c_skip = self.cfg.sigma_data**2 / (sigma**2 + self.cfg.sigma_data**2)
        c_out = sigma * c_skip.sqrt()
        c_noise = sigma.log() / 4
        c_noise_cond = sigma_cond.log() / 4 if sigma_cond is not None else torch.zeros_like(c_noise)
        return Conditioners(*(add_dims(c, n) for c, n in zip((c_in, c_out, c_skip, c_noise, c_noise_cond), (4, 4, 4, 1, 1))))

    def compute_model_output(self, noisy_next_obs: Tensor, obs: Tensor, act: Optional[Tensor], cs: Conditioners) -> Tensor:
        """Run the inner network on the ``c_in``-scaled noisy input.

        The conditioning observations ``obs`` are divided by ``cfg.sigma_data``
        before being passed to the network.
        """
        rescaled_obs = obs / self.cfg.sigma_data
        rescaled_noise = noisy_next_obs * cs.c_in
        return self.inner_model(rescaled_noise, cs.c_noise, cs.c_noise_cond, rescaled_obs, act)

    def wrap_model_output(self, noisy_next_obs: Tensor, model_output: Tensor, cs: Conditioners) -> Tensor:
        """Combine skip connection and network output into a denoised estimate.

        The estimate is clamped to [-1, 1] and quantized to 256 levels (by
        truncation via ``.byte()``), then mapped back to [-1, 1].
        """
        d = cs.c_skip * noisy_next_obs + cs.c_out * model_output
        # Quantize to {0, ..., 255}, then back to [-1, 1]
        d = d.clamp(-1, 1).add(1).div(2).mul(255).byte().div(255).mul(2).sub(1)
        return d

    def denoise(self, noisy_next_obs: Tensor, sigma: Tensor, sigma_cond: Optional[Tensor], obs: Tensor, act: Optional[Tensor]) -> Tensor:
        """Return the quantized denoised estimate of ``noisy_next_obs`` at noise level ``sigma``."""
        cs = self.compute_conditioners(sigma, sigma_cond)
        model_output = self.compute_model_output(noisy_next_obs, obs, act, cs)
        denoised = self.wrap_model_output(noisy_next_obs, model_output, cs)
        return denoised

    def forward(self, batch: Batch) -> LossAndLogs:
        """Compute the denoising training loss over the autoregressive steps of ``batch``.

        For each of the ``t - n`` target frames, the previous ``n`` frames (plus
        actions, unless upsampling) condition the prediction for a noised target.
        After each step, the target slot is overwritten with the model's own
        denoised estimate, so later steps condition on generated frames.
        Requires setup_training() to have been called.

        Returns:
            The loss averaged over steps, and a log dict holding its scalar value.

        Raises:
            ValueError: if the batch has fewer than ``n + 1`` frames.
        """
        b, t, c, h, w = batch.obs.size()
        H, W = (self.cfg.upsampling_factor * h, self.cfg.upsampling_factor * w) if self.is_upsampler else (h, w)
        n = self.cfg.inner_model.num_steps_conditioning
        seq_length = t - n  # t = n + 1 + num_autoregressive_steps
        if seq_length < 1:
            raise ValueError(f"Batch has {t} frames but at least {n + 1} are required ({n} conditioning + 1 target).")

        if self.is_upsampler:
            all_obs = torch.stack([x["full_res"] for x in batch.info]).to(self.device)
            low_res = F.interpolate(batch.obs.reshape(b * t, c, h, w), scale_factor=self.cfg.upsampling_factor, mode="bicubic").reshape(b, t, c, H, W)
            assert all_obs.shape == low_res.shape
        else:
            # Cloned because target slots are overwritten with denoised frames below.
            all_obs = batch.obs.clone()

        loss = 0
        for i in range(seq_length):
            prev_obs = all_obs[:, i : n + i].reshape(b, n * c, H, W)
            prev_act = None if self.is_upsampler else batch.act[:, i : n + i]
            obs = all_obs[:, n + i]
            mask = batch.mask_padding[:, n + i]

            if self.cfg.noise_previous_obs:
                sigma_cond = self.sample_sigma_training(b, self.device)
                prev_obs = self.apply_noise(prev_obs, sigma_cond, self.cfg.sigma_offset_noise)
            else:
                sigma_cond = None

            if self.is_upsampler:
                # Append the bicubic-upsampled current frame as extra conditioning channels.
                prev_obs = torch.cat((prev_obs, low_res[:, n + i]), dim=1)

            sigma = self.sample_sigma_training(b, self.device)
            noisy_obs = self.apply_noise(obs, sigma, self.cfg.sigma_offset_noise)

            cs = self.compute_conditioners(sigma, sigma_cond)
            model_output = self.compute_model_output(noisy_obs, prev_obs, prev_act, cs)

            # Regression target for the raw network output under EDM preconditioning.
            target = (obs - cs.c_skip * noisy_obs) / cs.c_out
            # TODO: add a contour-detection-based existence loss for the upsampler
            # (self.contour_detection_model is built for this but not yet used).
            loss += F.mse_loss(model_output[mask], target[mask])

            denoised = self.wrap_model_output(noisy_obs, model_output, cs)
            all_obs[:, n + i] = denoised

        loss /= seq_length
        return loss, {"loss_denoising": loss.item()}