File size: 6,515 Bytes
c64c726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# from dataclasses import dataclass
# from typing import List, Optional, Tuple

# import torch
# from torch import Tensor

# from .denoiser import Denoiser


# @dataclass
# class DiffusionSamplerConfig:
#     num_steps_denoising: int
#     sigma_min: float = 2e-3
#     sigma_max: float = 5
#     rho: int = 7
#     order: int = 1
#     s_churn: float = 0
#     s_tmin: float = 0
#     s_tmax: float = float("inf")
#     s_noise: float = 1
#     s_cond: float = 0


# class DiffusionSampler:
#     def __init__(self, denoiser: Denoiser, cfg: DiffusionSamplerConfig) -> None:
#         self.denoiser = denoiser
#         self.cfg = cfg
#         self.sigmas = build_sigmas(cfg.num_steps_denoising, cfg.sigma_min, cfg.sigma_max, cfg.rho, denoiser.device)

#     @torch.no_grad()
#     def sample(self, prev_obs: Tensor, prev_act: Optional[Tensor]) -> Tuple[Tensor, List[Tensor]]:
#         device = prev_obs.device
#         b, t, c, h, w = prev_obs.size()
#         prev_obs = prev_obs.reshape(b, t * c, h, w)
#         s_in = torch.ones(b, device=device)
#         gamma_ = min(self.cfg.s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
#         x = torch.randn(b, c, h, w, device=device)
#         trajectory = [x]
#         for sigma, next_sigma in zip(self.sigmas[:-1], self.sigmas[1:]):
#             gamma = gamma_ if self.cfg.s_tmin <= sigma <= self.cfg.s_tmax else 0
#             sigma_hat = sigma * (gamma + 1)
#             if gamma > 0:
#                 eps = torch.randn_like(x) * self.cfg.s_noise
#                 x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
#             if self.cfg.s_cond > 0:
#                 sigma_cond = torch.full((b,), fill_value=self.cfg.s_cond, device=device)
#                 prev_obs = self.denoiser.apply_noise(prev_obs, sigma_cond, sigma_offset_noise=0)
#             else:
#                 sigma_cond = None
#             denoised = self.denoiser.denoise(x, sigma, sigma_cond, prev_obs, prev_act)
#             d = (x - denoised) / sigma_hat
#             dt = next_sigma - sigma_hat
#             if self.cfg.order == 1 or next_sigma == 0:
#                 # Euler method
#                 x = x + d * dt
#             else:
#                 # Heun's method
#                 x_2 = x + d * dt
#                 denoised_2 = self.denoiser.denoise(x_2, next_sigma * s_in, sigma_cond, prev_obs, prev_act)
#                 d_2 = (x_2 - denoised_2) / next_sigma
#                 d_prime = (d + d_2) / 2
#                 x = x + d_prime * dt
#             trajectory.append(x)
#         return x, trajectory


# def build_sigmas(num_steps: int, sigma_min: float, sigma_max: float, rho: int, device: torch.device) -> Tensor:
#     min_inv_rho = sigma_min ** (1 / rho)
#     max_inv_rho = sigma_max ** (1 / rho)
#     l = torch.linspace(0, 1, num_steps, device=device)
#     sigmas = (max_inv_rho + l * (min_inv_rho - max_inv_rho)) ** rho
#     return torch.cat((sigmas, sigmas.new_zeros(1)))



from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from torch import Tensor

from .denoiser import Denoiser

@dataclass
class DiffusionSamplerConfig:
    """Hyper-parameters consumed by ``DiffusionSampler`` and ``build_sigmas``.

    Attributes:
        num_steps_denoising: Number of non-zero noise levels in the schedule
            (``build_sigmas`` appends one extra trailing zero).
        sigma_min: Last (smallest) non-zero noise level of the schedule.
        sigma_max: First (largest) noise level of the schedule.
        rho: Exponent of the schedule; interpolation is linear in
            ``sigma ** (1 / rho)``.
        order: ``1`` selects the Euler update; any other value selects the
            Heun update (except on the final step, which is always Euler).
        s_churn: Stochasticity amount; the per-step gamma is
            ``min(s_churn / num_sampling_steps, sqrt(2) - 1)``.
        s_tmin: Lower bound of the sigma range in which gamma is applied.
        s_tmax: Upper bound of the sigma range in which gamma is applied.
        s_noise: Scale of the fresh noise injected when gamma > 0.
        s_cond: Noise level applied to the conditioning observations; values
            ``<= 0`` disable conditioning noise.
    """

    num_steps_denoising: int
    # Float-annotated fields use float literals so the runtime type matches
    # the declared type (the values are numerically unchanged).
    sigma_min: float = 2e-3
    sigma_max: float = 5.0
    rho: int = 7
    order: int = 1
    s_churn: float = 0.0
    s_tmin: float = 0.0
    s_tmax: float = float("inf")
    s_noise: float = 1.0
    s_cond: float = 0.0


class DiffusionSampler:
    """Iterative sampler that turns an initial sample into a denoised frame.

    The sampler walks the noise schedule produced by ``build_sigmas`` and, at
    every step, asks the denoiser for a clean estimate and moves the sample
    towards it with an Euler or Heun update.

    The sampler is stateful: the output of the previous ``sample`` call is
    cached in ``last_frame`` and, lightly re-noised, used as the starting
    point of the next call (warm start). Call ``reset`` (or set
    ``is_first_frame = True``) to force the next call to start from Gaussian
    noise again.
    """

    # Noise level / offset-noise level used to perturb the cached frame when
    # warm-starting.
    WARMSTART_SIGMA: float = 0.05
    WARMSTART_OFFSET_NOISE: float = 0.05

    def __init__(self, denoiser: "Denoiser", cfg: "DiffusionSamplerConfig") -> None:
        """Store the denoiser and config and precompute the noise schedule.

        Args:
            denoiser: Model exposing ``device``, ``denoise`` and ``apply_noise``.
            cfg: Sampling hyper-parameters.
        """
        self.denoiser = denoiser
        self.cfg = cfg
        self.sigmas = build_sigmas(cfg.num_steps_denoising, cfg.sigma_min, cfg.sigma_max, cfg.rho, denoiser.device)
        self.is_first_frame = True
        self.last_frame = None

    def reset(self) -> None:
        """Forget the cached frame so the next ``sample`` starts from noise."""
        self.is_first_frame = True
        self.last_frame = None

    def _initial_sample(self, b: int, c: int, h: int, w: int, device: torch.device) -> Tensor:
        """Return the starting sample: warm start if possible, else Gaussian noise."""
        cached = self.last_frame
        can_warm_start = (
            not self.is_first_frame
            and cached is not None
            # A cached frame from a different batch size / resolution cannot
            # seed this call; fall back to a cold start instead of silently
            # producing a wrongly shaped output.
            and tuple(cached.shape) == (b, c, h, w)
        )
        if not can_warm_start:
            return torch.randn(b, c, h, w, device=device)
        # NOTE(review): the warm-start sample only carries WARMSTART_SIGMA of
        # noise while the schedule below still starts at cfg.sigma_max --
        # confirm this mismatch is intended.
        sigma_warm = torch.full((b,), fill_value=self.WARMSTART_SIGMA, device=device)
        return self.denoiser.apply_noise(
            cached.to(device), sigma_warm, sigma_offset_noise=self.WARMSTART_OFFSET_NOISE
        )

    @torch.no_grad()
    def sample(self, prev_obs: Tensor, prev_act: Optional[Tensor]) -> Tuple[Tensor, List[Tensor]]:
        """Generate one frame conditioned on previous observations and actions.

        Args:
            prev_obs: 5-D tensor unpacked as ``(b, t, c, h, w)``; the ``t`` and
                ``c`` axes are merged before being handed to the denoiser.
            prev_act: Forwarded unchanged to ``denoiser.denoise``.

        Returns:
            ``(x, trajectory)`` where ``x`` is the final sample of shape
            ``(b, c, h, w)`` and ``trajectory`` lists the initial sample
            followed by the sample after every step.
        """
        device = prev_obs.device
        b, t, c, h, w = prev_obs.size()
        prev_obs = prev_obs.reshape(b, t * c, h, w)
        s_in = torch.ones(b, device=device)
        gamma_ = min(self.cfg.s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)

        x = self._initial_sample(b, c, h, w, device)
        trajectory = [x]

        # The conditioning noise level is constant across steps, so build the
        # tensor once instead of on every iteration.
        if self.cfg.s_cond > 0:
            sigma_cond = torch.full((b,), fill_value=self.cfg.s_cond, device=device)
        else:
            sigma_cond = None

        for sigma, next_sigma in zip(self.sigmas[:-1], self.sigmas[1:]):
            gamma = gamma_ if self.cfg.s_tmin <= sigma <= self.cfg.s_tmax else 0
            sigma_hat = sigma * (gamma + 1)
            if gamma > 0:
                # Raise the noise level of x from sigma to sigma_hat.
                eps = torch.randn_like(x) * self.cfg.s_noise
                x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
            if sigma_cond is not None:
                # NOTE(review): prev_obs is re-noised on every step, so the
                # conditioning noise accumulates across steps -- confirm.
                prev_obs = self.denoiser.apply_noise(prev_obs, sigma_cond, sigma_offset_noise=0)
            denoised = self.denoiser.denoise(x, sigma, sigma_cond, prev_obs, prev_act)
            d = (x - denoised) / sigma_hat
            dt = next_sigma - sigma_hat
            if self.cfg.order == 1 or next_sigma == 0:
                # Euler method (also used on the last step, where dividing by
                # next_sigma == 0 in the Heun correction would be invalid).
                x = x + d * dt
            else:
                # Heun's method
                x_2 = x + d * dt
                denoised_2 = self.denoiser.denoise(x_2, next_sigma * s_in, sigma_cond, prev_obs, prev_act)
                d_2 = (x_2 - denoised_2) / next_sigma
                d_prime = (d + d_2) / 2
                x = x + d_prime * dt
            trajectory.append(x)

        # Update the warm-start state only after sampling succeeded, so a
        # failed call never leaves the flag cleared with no cached frame.
        self.last_frame = x
        self.is_first_frame = False
        return x, trajectory


def build_sigmas(num_steps: int, sigma_min: float, sigma_max: float, rho: int, device: torch.device) -> Tensor:
    """Build the decreasing noise schedule used by the sampler.

    The schedule interpolates linearly between ``sigma_max ** (1 / rho)`` and
    ``sigma_min ** (1 / rho)`` over ``num_steps`` points, raises the result to
    the power ``rho``, and appends a final zero.

    Args:
        num_steps: Number of non-zero noise levels.
        sigma_min: Last non-zero noise level.
        sigma_max: First noise level.
        rho: Exponent controlling the spacing of the levels.
        device: Device on which the schedule tensor is created.

    Returns:
        1-D tensor of length ``num_steps + 1`` whose last entry is ``0``.
    """
    inv_rho = 1 / rho
    start = sigma_max ** inv_rho
    end = sigma_min ** inv_rho
    ramp = torch.linspace(0, 1, num_steps, device=device)
    schedule = (start + ramp * (end - start)) ** rho
    terminal_zero = schedule.new_zeros(1)
    return torch.cat((schedule, terminal_zero))