File size: 8,086 Bytes
c64c726
 
 
 
 
 
 
 
 
 
 
 
17fd5e3
c64c726
17fd5e3
c64c726
17fd5e3
c64c726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
from dataclasses import dataclass
from typing import Optional

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import os
from PIL import Image
import numpy as np
import cv2

from ...data import Batch
from .inner_model import InnerModel, InnerModelConfig
from ...utils import LossAndLogs

from ..contour_detection_model import ContourDetectionModel

def add_dims(input: Tensor, n: int) -> Tensor:
    """Return ``input`` reshaped with trailing size-1 dims so it has ``n`` dims.

    If ``input`` already has ``n`` or more dims, its shape is left unchanged.
    This lets per-sample values (e.g. shape ``(b,)``) broadcast against
    higher-rank tensors such as ``(b, c, h, w)``.
    """
    n_missing = n - input.ndim
    target_shape = (*input.shape, *(1 for _ in range(n_missing)))
    return input.reshape(target_shape)


@dataclass
class Conditioners:
    """Preconditioning coefficients for a single denoiser call.

    ``Denoiser.compute_conditioners`` constructs this positionally, so the
    field order below is part of the contract.
    """

    c_in: Tensor  # multiplies the noisy input before it enters the inner model
    c_out: Tensor  # multiplies the inner model's output
    c_skip: Tensor  # weight of the skip connection from the noisy input
    c_noise: Tensor  # noise-level input passed to the inner model for the target
    c_noise_cond: Tensor  # noise-level input for the (optionally noised) conditioning frames


@dataclass
class SigmaDistributionConfig:
    """Parameters of the training-time noise-level distribution.

    ``Denoiser.setup_training`` samples ``sigma = exp(N(loc, scale**2))`` and
    clips the result to ``[sigma_min, sigma_max]``.
    """

    loc: float  # mean of log(sigma) before clipping
    scale: float  # standard deviation of log(sigma) before clipping
    sigma_min: float  # lower clip bound for sampled sigmas
    sigma_max: float  # upper clip bound for sampled sigmas


@dataclass
class DenoiserConfig:
    """Configuration for ``Denoiser``.

    Setting ``upsampling_factor`` turns the denoiser into an upsampler;
    leaving it ``None`` gives the action-conditioned next-frame denoiser.
    """

    inner_model: InnerModelConfig  # network config; Denoiser overwrites its ``is_upsampler`` flag
    sigma_data: float  # data scale used by the preconditioning coefficients
    sigma_offset_noise: float  # std of the per-(sample, channel) offset noise
    noise_previous_obs: bool  # whether conditioning frames are also noised during training
    upsampling_factor: Optional[int] = None  # spatial upscaling factor, or None for no upsampling


class Denoiser(nn.Module):
    """Diffusion denoiser wrapping an ``InnerModel`` with EDM-style preconditioning.

    The coefficients in ``compute_conditioners`` follow the Karras et al.
    (2022) parameterization: ``c_skip``, ``c_out``, ``c_in`` derived from
    ``sigma_data`` and ``c_noise = log(sigma) / 4``.

    With ``cfg.upsampling_factor`` unset, the model denoises the next
    observation conditioned on previous observations and actions. With it set,
    the model acts as an upsampler. It is then conditioned on previous
    full-resolution frames plus a bicubically upsampled low-resolution frame,
    and receives no actions.
    """

    def __init__(self, cfg: DenoiserConfig) -> None:
        super().__init__()
        self.cfg = cfg
        self.is_upsampler = cfg.upsampling_factor is not None
        # The inner network needs to know its role; note this mutates the caller's config object.
        cfg.inner_model.is_upsampler = self.is_upsampler
        self.inner_model = InnerModel(cfg.inner_model)
        # Installed by ``setup_training``; ``forward`` cannot run before that.
        self.sample_sigma_training = None
        # Only built for the upsampler. Not used by ``forward`` yet (see the TODO there).
        self.contour_detection_model = ContourDetectionModel(min_contour_area=25) if self.is_upsampler else None

    @property
    def device(self) -> torch.device:
        """Device of the inner model's noise-embedding weights."""
        return self.inner_model.noise_emb.weight.device

    def setup_training(self, cfg: SigmaDistributionConfig) -> None:
        """Install the training-time noise-level sampler.

        Sampled sigmas are ``exp(N(loc, scale**2))`` clipped to
        ``[sigma_min, sigma_max]``. Must be called exactly once.
        """
        assert self.sample_sigma_training is None

        def sample_sigma(n: int, device: torch.device) -> Tensor:
            s = torch.randn(n, device=device) * cfg.scale + cfg.loc
            return s.exp().clip(cfg.sigma_min, cfg.sigma_max)

        self.sample_sigma_training = sample_sigma

    def apply_noise(self, x: Tensor, sigma: Tensor, sigma_offset_noise: float) -> Tensor:
        """Return ``x`` plus Gaussian noise of std ``sigma`` and a constant offset noise.

        ``x`` must be 4-D ``(b, c, h, w)``. ``sigma`` is expected to hold one
        value per sample, as produced by ``sample_sigma_training``. The offset
        term is one draw per (sample, channel), scaled by
        ``sigma_offset_noise`` and shared across all pixels.
        """
        b, c, _, _ = x.shape
        offset_noise = sigma_offset_noise * torch.randn(b, c, 1, 1, device=self.device)
        return x + offset_noise + torch.randn_like(x) * add_dims(sigma, x.ndim)

    def compute_conditioners(self, sigma: Tensor, sigma_cond: Optional[Tensor]) -> Conditioners:
        """Compute preconditioning coefficients for noise level ``sigma``.

        ``sigma`` is first combined with the offset-noise std, so the
        scalings reflect the total per-pixel noise std injected by
        ``apply_noise``. ``c_in``, ``c_out`` and ``c_skip`` are reshaped to 4-D
        to broadcast against images. ``c_noise`` and ``c_noise_cond`` are kept
        1-D; ``c_noise_cond`` is zeros when ``sigma_cond`` is None.
        """
        sigma = (sigma**2 + self.cfg.sigma_offset_noise**2).sqrt()
        c_in = 1 / (sigma**2 + self.cfg.sigma_data**2).sqrt()
        c_skip = self.cfg.sigma_data**2 / (sigma**2 + self.cfg.sigma_data**2)
        c_out = sigma * c_skip.sqrt()
        c_noise = sigma.log() / 4
        c_noise_cond = sigma_cond.log() / 4 if sigma_cond is not None else torch.zeros_like(c_noise)
        return Conditioners(*(add_dims(c, n) for c, n in zip((c_in, c_out, c_skip, c_noise, c_noise_cond), (4, 4, 4, 1, 1))))

    def compute_model_output(self, noisy_next_obs: Tensor, obs: Tensor, act: Optional[Tensor], cs: Conditioners) -> Tensor:
        """Run the inner network on the scaled noisy frame and its conditioning.

        Conditioning observations are divided by ``sigma_data``; the noisy
        frame is multiplied by ``c_in``.
        """
        rescaled_obs = obs / self.cfg.sigma_data
        rescaled_noise = noisy_next_obs * cs.c_in
        return self.inner_model(rescaled_noise, cs.c_noise, cs.c_noise_cond, rescaled_obs, act)

    @torch.no_grad()
    def wrap_model_output(self, noisy_next_obs: Tensor, model_output: Tensor, cs: Conditioners) -> Tensor:
        """Combine the skip connection and network output into a denoised frame.

        The estimate is clamped to [-1, 1] and quantized to 256 levels. The
        ``.byte()`` cast truncates, so values snap down to the lower level
        rather than rounding. Runs without gradient tracking.
        """
        d = cs.c_skip * noisy_next_obs + cs.c_out * model_output
        # Quantize to {0, ..., 255}, then back to [-1, 1]
        d = d.clamp(-1, 1).add(1).div(2).mul(255).byte().div(255).mul(2).sub(1)
        return d

    @torch.no_grad()
    def denoise(self, noisy_next_obs: Tensor, sigma: Tensor, sigma_cond: Optional[Tensor], obs: Tensor, act: Optional[Tensor]) -> Tensor:
        """Return the quantized denoised estimate of ``noisy_next_obs`` (no gradients)."""
        cs = self.compute_conditioners(sigma, sigma_cond)
        model_output = self.compute_model_output(noisy_next_obs, obs, act, cs)
        denoised = self.wrap_model_output(noisy_next_obs, model_output, cs)
        return denoised

    def forward(self, batch: Batch) -> LossAndLogs:
        """Compute the denoising training loss over an autoregressive rollout.

        Let ``n = num_steps_conditioning``. Each of the ``t - n`` target frames
        is predicted from a noisy copy of itself, conditioned on the preceding
        ``n`` frames (optionally noised). In upsampler mode the upsampled
        low-res frame is added as conditioning. The quantized denoised
        prediction then overwrites that frame, so later steps condition on
        the model's own outputs. Per-step masked MSE losses are averaged.

        Returns:
            ``(loss, logs)`` where ``logs`` contains ``loss_denoising``.

        Raises:
            ValueError: if the batch has no target frames (``t <= n``).
        """
        b, t, c, h, w = batch.obs.size()
        H, W = (self.cfg.upsampling_factor * h, self.cfg.upsampling_factor * w) if self.is_upsampler else (h, w)
        n = self.cfg.inner_model.num_steps_conditioning
        seq_length = t - n  # t = n + 1 + num_autoregressive_steps
        if seq_length <= 0:
            # Without this check the method fails later with ZeroDivisionError / AttributeError.
            raise ValueError(f"Batch has {t} frames; need more than num_steps_conditioning={n}.")

        if self.is_upsampler:
            # Targets are the full-resolution frames; the upsampled low-res frames become extra conditioning.
            all_obs = torch.stack([x["full_res"] for x in batch.info]).to(self.device)
            low_res = F.interpolate(batch.obs.reshape(b * t, c, h, w), scale_factor=self.cfg.upsampling_factor, mode="bicubic").reshape(b, t, c, H, W)
            assert all_obs.shape == low_res.shape
        else:
            # Cloned because frames are overwritten with predictions below.
            all_obs = batch.obs.clone()

        loss = 0
        for i in range(seq_length):
            prev_obs = all_obs[:, i : n + i].reshape(b, n * c, H, W)
            prev_act = None if self.is_upsampler else batch.act[:, i : n + i]
            obs = all_obs[:, n + i]
            mask = batch.mask_padding[:, n + i]

            if self.cfg.noise_previous_obs:
                sigma_cond = self.sample_sigma_training(b, self.device)
                prev_obs = self.apply_noise(prev_obs, sigma_cond, self.cfg.sigma_offset_noise)
            else:
                sigma_cond = None

            if self.is_upsampler:
                prev_obs = torch.cat((prev_obs, low_res[:, n + i]), dim=1)

            sigma = self.sample_sigma_training(b, self.device)
            noisy_obs = self.apply_noise(obs, sigma, self.cfg.sigma_offset_noise)

            cs = self.compute_conditioners(sigma, sigma_cond)
            model_output = self.compute_model_output(noisy_obs, prev_obs, prev_act, cs)
            # Regression target such that c_skip * noisy_obs + c_out * target == obs.
            target = (obs - cs.c_skip * noisy_obs) / cs.c_out

            # TODO: for the upsampler, add a contour-detection-based existence loss
            # (e.g. via self.contour_detection_model.detect_from_image on the denoised frame).
            loss += F.mse_loss(model_output[mask], target[mask])

            # Feed the model's own (quantized) prediction to the following steps.
            denoised = self.wrap_model_output(noisy_obs, model_output, cs)
            all_obs[:, n + i] = denoised

        loss /= seq_length
        return loss, {"loss_denoising": loss.item()}