Spaces:
Sleeping
Sleeping
File size: 8,086 Bytes
c64c726 17fd5e3 c64c726 17fd5e3 c64c726 17fd5e3 c64c726 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
from dataclasses import dataclass
from typing import Optional
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import os
from PIL import Image
import numpy as np
import cv2
from ...data import Batch
from .inner_model import InnerModel, InnerModelConfig
from ...utils import LossAndLogs
from ..contour_detection_model import ContourDetectionModel
def add_dims(input: Tensor, n: int) -> Tensor:
    """Right-pad ``input`` with singleton axes until it has ``n`` dimensions.

    If ``input`` already has ``n`` or more dimensions, no axes are appended and
    the tensor is reshaped to its own shape (values and shape unchanged).
    """
    # A non-positive repeat count yields an empty tuple, so nothing is appended.
    trailing_ones = (1,) * (n - input.ndim)
    return input.reshape((*input.shape, *trailing_ones))
@dataclass
class Conditioners:
    """Bundle of sigma-dependent coefficients produced by ``Denoiser.compute_conditioners``.

    The field order matters for positional construction.
    """

    # Multiplies the noisy input before it is fed to the inner model.
    c_in: Tensor
    # Multiplies the inner model's output when forming the denoised estimate.
    c_out: Tensor
    # Multiplies the noisy input in the skip connection of the denoised estimate.
    c_skip: Tensor
    # Noise-level conditioning passed to the inner model for the current sample.
    c_noise: Tensor
    # Noise-level conditioning passed to the inner model for the conditioning observations.
    c_noise_cond: Tensor
@dataclass
class SigmaDistributionConfig:
    """Parameters of the training-time noise-level (sigma) sampler.

    ``Denoiser.setup_training`` draws ``exp(randn * scale + loc)`` and clips
    the result to ``[sigma_min, sigma_max]``.
    """

    # Added to the scaled standard-normal draw before exponentiation.
    loc: float
    # Multiplies the standard-normal draw before exponentiation.
    scale: float
    # Lower clipping bound applied to the sampled sigma.
    sigma_min: float
    # Upper clipping bound applied to the sampled sigma.
    sigma_max: float
@dataclass
class DenoiserConfig:
    """Configuration consumed by ``Denoiser``."""

    # Configuration forwarded to ``InnerModel``; ``Denoiser.__init__`` sets its
    # ``is_upsampler`` attribute in place.
    inner_model: InnerModelConfig
    # Data scale used in the conditioner formulas and to rescale conditioning observations.
    sigma_data: float
    # Standard deviation of the per-(sample, channel) offset noise.
    sigma_offset_noise: float
    # When True, ``Denoiser.forward`` also noises the conditioning observations.
    noise_previous_obs: bool
    # Spatial upsampling factor; ``None`` means the denoiser is not an upsampler.
    upsampling_factor: Optional[int] = None
class Denoiser(nn.Module):
    """Diffusion denoiser that wraps an ``InnerModel`` with sigma-dependent conditioning.

    The inner model receives a rescaled noisy observation, rescaled conditioning
    observations and noise-level embeddings. Its raw output is combined with the
    noisy input through ``c_skip`` / ``c_out`` to form the denoised estimate.

    When ``cfg.upsampling_factor`` is set, the instance acts as an upsampler:
    it is trained on the ``"full_res"`` entries of ``batch.info``, conditions on
    a bicubically upsampled copy of ``batch.obs``, and receives no actions.
    """

    def __init__(self, cfg: DenoiserConfig) -> None:
        super().__init__()
        self.cfg = cfg
        self.is_upsampler = cfg.upsampling_factor is not None
        # The inner-model config is mutated in place so the network is built in the matching mode.
        cfg.inner_model.is_upsampler = self.is_upsampler
        self.inner_model = InnerModel(cfg.inner_model)
        # Installed by setup_training(); stays None until then.
        self.sample_sigma_training = None
        # NOTE(review): only instantiated for the upsampler; no active code in this class calls it yet.
        self.contour_detection_model = ContourDetectionModel(min_contour_area=25) if self.is_upsampler else None

    @property
    def device(self) -> torch.device:
        """Device of the inner model's noise-embedding weight."""
        return self.inner_model.noise_emb.weight.device

    def setup_training(self, cfg: SigmaDistributionConfig) -> None:
        """Install the sigma sampler used by ``forward``; may only be called once."""
        assert self.sample_sigma_training is None

        def sample_sigma(n: int, device: torch.device):
            # exp(N(loc, scale^2)) clipped to [sigma_min, sigma_max]; returns a tensor of shape (n,).
            s = torch.randn(n, device=device) * cfg.scale + cfg.loc
            return s.exp().clip(cfg.sigma_min, cfg.sigma_max)

        self.sample_sigma_training = sample_sigma

    def apply_noise(self, x: Tensor, sigma: Tensor, sigma_offset_noise: float) -> Tensor:
        """Return ``x`` corrupted with Gaussian noise of level ``sigma`` plus offset noise.

        Args:
            x: 4-D tensor; its first two axes size the offset noise.
            sigma: Noise level, right-padded with singleton axes to ``x.ndim`` before use.
            sigma_offset_noise: Scale of the offset noise, which is drawn once per
                (axis-0, axis-1) entry and broadcast over the last two axes.
        """
        b, c, _, _ = x.shape
        # Allocate on x's own device so the addition cannot hit a device mismatch
        # when x does not live on the module's device.
        offset_noise = sigma_offset_noise * torch.randn(b, c, 1, 1, device=x.device)
        return x + offset_noise + torch.randn_like(x) * add_dims(sigma, x.ndim)

    def compute_conditioners(self, sigma: Tensor, sigma_cond: Optional[Tensor]) -> Conditioners:
        """Compute the conditioning coefficients for noise level ``sigma``.

        Args:
            sigma: Noise level of the sample being denoised.
            sigma_cond: Noise level of the conditioning observations, or ``None``,
                in which case ``c_noise_cond`` is all zeros.

        Returns:
            ``Conditioners`` with ``c_in``, ``c_out`` and ``c_skip`` right-padded to
            at least 4 dimensions and ``c_noise`` / ``c_noise_cond`` to at least 1.
        """
        # NOTE(review): these look like the EDM preconditioning coefficients
        # (Karras et al., 2022) — confirm.
        # Fold the offset-noise level into the effective sigma.
        sigma = (sigma**2 + self.cfg.sigma_offset_noise**2).sqrt()
        c_in = 1 / (sigma**2 + self.cfg.sigma_data**2).sqrt()
        c_skip = self.cfg.sigma_data**2 / (sigma**2 + self.cfg.sigma_data**2)
        c_out = sigma * c_skip.sqrt()
        c_noise = sigma.log() / 4
        c_noise_cond = sigma_cond.log() / 4 if sigma_cond is not None else torch.zeros_like(c_noise)
        return Conditioners(
            c_in=add_dims(c_in, 4),
            c_out=add_dims(c_out, 4),
            c_skip=add_dims(c_skip, 4),
            c_noise=add_dims(c_noise, 1),
            c_noise_cond=add_dims(c_noise_cond, 1),
        )

    def compute_model_output(self, noisy_next_obs: Tensor, obs: Tensor, act: Optional[Tensor], cs: Conditioners) -> Tensor:
        """Run the inner model on the rescaled noisy sample and rescaled conditioning."""
        rescaled_obs = obs / self.cfg.sigma_data
        rescaled_noise = noisy_next_obs * cs.c_in
        return self.inner_model(rescaled_noise, cs.c_noise, cs.c_noise_cond, rescaled_obs, act)

    @torch.no_grad()
    def wrap_model_output(self, noisy_next_obs: Tensor, model_output: Tensor, cs: Conditioners) -> Tensor:
        """Turn the raw model output into a denoised estimate (computed without gradients)."""
        d = cs.c_skip * noisy_next_obs + cs.c_out * model_output
        # Quantize to {0, ..., 255}, then back to [-1, 1]
        d = d.clamp(-1, 1).add(1).div(2).mul(255).byte().div(255).mul(2).sub(1)
        return d

    @torch.no_grad()
    def denoise(self, noisy_next_obs: Tensor, sigma: Tensor, sigma_cond: Optional[Tensor], obs: Tensor, act: Optional[Tensor]) -> Tensor:
        """Denoise ``noisy_next_obs`` at noise level ``sigma`` given ``obs`` and ``act``."""
        cs = self.compute_conditioners(sigma, sigma_cond)
        model_output = self.compute_model_output(noisy_next_obs, obs, act, cs)
        return self.wrap_model_output(noisy_next_obs, model_output, cs)

    def forward(self, batch: Batch) -> LossAndLogs:
        """Compute the denoising loss, averaged over the autoregressive steps of ``batch``.

        ``batch.obs`` is unpacked as ``(b, t, c, h, w)``. The first
        ``num_steps_conditioning`` frames condition the first prediction. After
        each step, the target frame is overwritten with the model's own denoised
        estimate so later steps condition on predictions.

        Returns:
            ``(loss, {"loss_denoising": loss.item()})``.

        Raises:
            RuntimeError: If ``setup_training`` has not been called.
            ValueError: If the batch has no time step left to predict after the
                conditioning frames.
        """
        if self.sample_sigma_training is None:
            raise RuntimeError("Denoiser.setup_training() must be called before forward().")

        b, t, c, h, w = batch.obs.size()
        H, W = (self.cfg.upsampling_factor * h, self.cfg.upsampling_factor * w) if self.is_upsampler else (h, w)
        n = self.cfg.inner_model.num_steps_conditioning
        seq_length = t - n  # t = n + 1 + num_autoregressive_steps
        if seq_length <= 0:
            # Otherwise the final division would fail with an unrelated-looking error.
            raise ValueError(
                f"batch.obs has {t} time steps but at least {n + 1} are required "
                f"({n} conditioning steps plus one target)."
            )

        if self.is_upsampler:
            all_obs = torch.stack([x["full_res"] for x in batch.info]).to(self.device)
            low_res = F.interpolate(
                batch.obs.reshape(b * t, c, h, w),
                scale_factor=self.cfg.upsampling_factor,
                mode="bicubic",
            ).reshape(b, t, c, H, W)
            assert all_obs.shape == low_res.shape
        else:
            # Cloned because the loop below writes predictions back into it.
            all_obs = batch.obs.clone()

        loss = 0
        for i in range(seq_length):
            # Stack the n conditioning frames along the channel axis.
            prev_obs = all_obs[:, i : n + i].reshape(b, n * c, H, W)
            prev_act = None if self.is_upsampler else batch.act[:, i : n + i]
            obs = all_obs[:, n + i]
            # presumably a boolean (b,) selector of non-padded samples — TODO confirm
            mask = batch.mask_padding[:, n + i]

            if self.cfg.noise_previous_obs:
                sigma_cond = self.sample_sigma_training(b, self.device)
                prev_obs = self.apply_noise(prev_obs, sigma_cond, self.cfg.sigma_offset_noise)
            else:
                sigma_cond = None

            if self.is_upsampler:
                # The upsampled low-resolution target frame is appended as extra conditioning channels.
                prev_obs = torch.cat((prev_obs, low_res[:, n + i]), dim=1)

            sigma = self.sample_sigma_training(b, self.device)
            noisy_obs = self.apply_noise(obs, sigma, self.cfg.sigma_offset_noise)

            cs = self.compute_conditioners(sigma, sigma_cond)
            model_output = self.compute_model_output(noisy_obs, prev_obs, prev_act, cs)

            # Regression target such that c_skip * noisy_obs + c_out * target == obs.
            target = (obs - cs.c_skip * noisy_obs) / cs.c_out
            # TODO: add a contour-detection based existence loss for the upsampler.
            loss += F.mse_loss(model_output[mask], target[mask])

            # Feed the (gradient-free, quantized) prediction back for the next steps.
            denoised = self.wrap_model_output(noisy_obs, model_output, cs)
            all_obs[:, n + i] = denoised

        loss /= seq_length
        return loss, {"loss_denoising": loss.item()}
|