# File size: 13,423 Bytes
# be89dda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
import torch
import torch.nn as nn
import torch.nn.functional as F

from .layers import ResidualBlock, AttnBlock
from .utils import get_named_beta_schedule

def sinusoidal_embedding(n, d):
    """
    Build the fixed sinusoidal (Transformer-style) positional embedding table.

    n: iteration steps (number of rows, one per time step),
    d: time embedding dimension (number of columns)

    Returns a float32 tensor of shape (n, d) where, for time step ``pos`` and
    frequency index ``k``::

        table[pos, 2k]     = sin(pos / 10000 ** (2k / d))
        table[pos, 2k + 1] = cos(pos / 10000 ** (2k / d))

    An odd ``d`` is supported; the last column is then a sine column.
    """
    positions = torch.arange(n, dtype=torch.float32).unsqueeze(1)  # (n, 1)

    # One frequency per (sin, cos) column pair: exponent 2k / d for k = 0, 1, ...
    exponents = torch.arange(0, d, 2, dtype=torch.float32) / d
    inv_freq = 10000.0 ** (-exponents)  # (ceil(d / 2),)
    angles = positions * inv_freq  # (n, ceil(d / 2))

    # Sine and cosine alternate along the embedding (column) axis so that every
    # row is a distinct encoding of its own time step.
    embedding = torch.zeros(n, d, dtype=torch.float32)
    embedding[:, 0::2] = torch.sin(angles)
    embedding[:, 1::2] = torch.cos(angles[:, : d // 2])

    return embedding

def _make_te(dim_in, dim_out):
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            nn.Linear(dim_out, dim_out)
        )

class UNet_with_time(nn.Module):
    """
    U-Net backbone conditioned on a diffusion time step and a conditional input.

    The input is folded by PixelUnshuffle(2) (channels x4), sent through a down
    path, a middle stage and an up path with skip connections, then unfolded by
    PixelShuffle(2) and projected to ``output_frame`` channels with a 3x3 conv.
    """

    def __init__(self, config):
        """
        config: object exposing input_frame, output_frame, n_steps,
            time_emb_dim, cond_nc, chs_mult, n_res_blocks, base_chs and
            use_attn_list (one attention flag per entry of chs_mult).
        """
        super().__init__()
        self.config = config

        input_frame = config.input_frame
        output_frame = config.output_frame
        n_steps = config.n_steps
        time_emb_dim = config.time_emb_dim
        cond_nc = config.cond_nc
        chs_mult = config.chs_mult  ## e.g. (1, 2, 4, 8)
        n_res_blocks = config.n_res_blocks
        base_chs = config.base_chs
        use_attn_list = config.use_attn_list  ## e.g. (0, 0, 1, 1) -> 0 means no attention

        layer_depth = len(chs_mult)
        assert len(use_attn_list) == layer_depth, "length of use_attn_list should be the same as chs_mult"
        assert input_frame >= output_frame, "input_frame should be larger than or equal to output_frame"

        def res_block(c_in, c_out, down=False, up=False):
            # Every residual block shares the same conditioning / time-embedding sizes.
            return ResidualBlock(c_in, c_out, cond_nc, time_emb_dim, down_flag=down, up_flag=up)

        stem_chs = input_frame * 4  ## channel count after pixel unshuffle
        self.filter_list = [base_chs * mult for mult in chs_mult]

        ## time embedding: frozen sinusoidal table followed by a trainable MLP
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)
        self.time_embed_fc = _make_te(time_emb_dim, time_emb_dim)

        ## input stage
        self.input_layer = nn.PixelUnshuffle(downscale_factor=2)

        ## down path: per level -> n_res_blocks plain blocks, optional attention,
        ## then one block with down_flag=True that changes the channel count
        self.down_blocks = nn.ModuleList()
        width = stem_chs
        for level, next_width in enumerate(self.filter_list):
            self.down_blocks.extend([res_block(width, width) for _ in range(n_res_blocks)])
            if use_attn_list[level]:
                self.down_blocks.append(AttnBlock(width, 4))  ## num_head=4
            self.down_blocks.append(res_block(width, next_width, down=True))
            width = next_width

        ## middle stage
        self.mid_block1 = res_block(width, width)
        self.mid_attn = AttnBlock(width, 4)
        self.mid_block2 = res_block(width, width)

        ## up path: mirror of the down path; every residual block receives its
        ## input concatenated with a skip feature, hence the doubled in-channels
        self.up_blocks = nn.ModuleList()
        self.filter_list = [stem_chs] + self.filter_list[:-1]
        for level in range(layer_depth - 1, -1, -1):
            target = self.filter_list[level]
            self.up_blocks.append(res_block(width * 2, target, up=True))
            if use_attn_list[level]:
                self.up_blocks.append(AttnBlock(target))  ## num_head=1
            self.up_blocks.extend([res_block(target * 2, target) for _ in range(n_res_blocks)])
            width = target

        ## output stage
        self.out_up = nn.PixelShuffle(upscale_factor=2)
        self.out_conv = nn.Conv2d(input_frame, output_frame, 3, padding=1)

    def forward(self, x, t, cond):
        """
        x: (b, in_c, h, w), noisy input (concatenated with some data)
        t: (b,), time step
        cond: (b, cond_nc, h, w), conditional input
        """
        # embed the time step: table lookup, then MLP -> (b, time_emb_dim)
        temb = self.time_embed_fc(self.time_embed(t))

        h = self.input_layer(x)

        # down path: remember the output of every residual block for the skips
        skips = []
        for layer in self.down_blocks:
            if isinstance(layer, ResidualBlock):
                h = layer(h, cond, temb)
                skips.append(h)
                continue
            if not isinstance(layer, AttnBlock):
                raise ValueError("Wrong layer type in down_blocks")
            h = layer(h)

        # middle stage
        h = self.mid_block1(h, cond, temb)
        h = self.mid_attn(h)
        h = self.mid_block2(h, cond, temb)

        # up path: each residual block consumes the most recent unused skip
        for layer in self.up_blocks:
            if isinstance(layer, ResidualBlock):
                h = torch.cat([h, skips.pop()], dim=1)  ## concat along channel dimension
                h = layer(h, cond, temb)
                continue
            if not isinstance(layer, AttnBlock):
                raise ValueError("Wrong layer type in up_blocks")
            h = layer(h)

        # output stage
        return self.out_conv(self.out_up(h))

class DDPM(nn.Module):
    """
    Diffusion wrapper around a noise-prediction backbone.

    Stores the beta schedule (``betas``, ``alphas``, ``alpha_bars``) as buffers
    and provides forward noising (``add_noise``) plus two reverse samplers
    (``sample_ddpm`` and ``sample_ddim``).
    """

    def __init__(self, backbone, output_shape, n_steps=1000, min_beta=1e-4, max_beta=0.02, device='cuda'):
        """
        backbone: module called as ``backbone(x, t, cond)``
        output_shape: dim(C, H, W)
        n_steps: number of diffusion steps
        min_beta, max_beta: forwarded to ``get_named_beta_schedule``
        device: stored on the instance; not used elsewhere in this class
        """
        super().__init__()
        self.device = device
        self.backbone_model = backbone
        self.output_shape = output_shape

        self.n_steps = n_steps

        ## linear betas
        # NOTE(review): register_buffer needs tensors, so the schedule helper is
        # assumed to return a 1-D torch tensor of length n_steps -- confirm.
        betas = get_named_beta_schedule("linear", n_steps, min_beta, max_beta)
        alphas = 1.0 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alpha_bars', alpha_bars)

    def forward(self, x, t, cond):
        """
        x: (b, in_c, h, w), noisy input (concatenated with some data)
        cond: (b, cond_nc, h, w), conditional input
        t: (b,), time step
        """
        return self.backbone_model(x, t, cond)

    @torch.no_grad()
    def add_noise(self, x0, t, eta=None):
        """
        Sample x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eta.

        x0: (b, c, h, w), original data
        t: (b,), time step (0 <= t < n_steps)
        eta: optional noise with the same shape as x0; drawn from N(0, 1) when None
        """
        b, c, h, w = x0.shape
        if eta is None:
            eta = torch.randn(b, c, h, w, device=x0.device)

        alpha_bar = self.alpha_bars[t]
        signal_scale = alpha_bar.sqrt().reshape(b, 1, 1, 1)
        noise_scale = (1 - alpha_bar).sqrt().reshape(b, 1, 1, 1)

        return signal_scale * x0 + noise_scale * eta

    def denoise(self, xt, t, cond):
        """
        Return the backbone's prediction for the given noisy input.

        xt: (b, in_c, h, w), noisy input (concatenated with some data)
        cond: (b, cond_nc, h, w), conditional input
        t: (b,), time step (0 <= t < n_steps)
        """
        return self(xt, t, cond)

    def _build_progress_iter(self, iterable, total, mode: str, desc: str = "DDPM sampling"):
        """
        Internal helper to create a progress iterator based on verbose mode.

        Returns ``(iterator, effective_mode)``. ``mode`` is lower-cased and a
        falsy value is treated as "none". When "tqdm" is requested but tqdm is
        not installed, the plain iterable is returned with mode "none".
        """
        mode = (mode or "none").lower()
        if mode != "tqdm":
            return iterable, mode

        try:
            from tqdm import tqdm
        except ImportError:
            # tqdm is optional: fall back to a silent loop instead of failing.
            return iterable, "none"

        return tqdm(iterable, total=total, desc=desc, leave=False), mode

    @staticmethod
    def _to_frame(x):
        """Apply (x + 1) / 2, clamp to [0, 1] and return the result as a CPU numpy array."""
        return ((x + 1) / 2).clamp(0, 1).detach().cpu().numpy()

    @torch.no_grad()
    def sample_ddpm(self, cond, input_cond=None, verbose: str = "none", store_intermediate: bool = False):
        """
        Ancestral DDPM sampling over all ``n_steps`` time steps.

        cond: (b, cond_nc, h, w), conditional input
        input_cond: optional tensor concatenated to the current sample along
            the channel dimension before every backbone call
        verbose: "none", "text", or "tqdm" for progress display
        store_intermediate: when True, also return a list of numpy snapshots
            (every 50th iteration plus the final step)

        Returns ``x`` or ``(x, frames)`` depending on ``store_intermediate``.
        """
        ## confirm that the model is in eval mode
        self.backbone_model.eval()

        batch = cond.shape[0]
        device = cond.device

        x = torch.randn(batch, *self.output_shape, device=device)

        progress_iter, mode = self._build_progress_iter(
            reversed(range(self.n_steps)), self.n_steps, verbose
        )
        use_text = mode == "text"
        text_interval = max(1, self.n_steps // 10)

        frames = []
        for idx, t in enumerate(progress_iter):
            time_tensor = torch.full((batch,), t, device=device, dtype=torch.long)
            input_ = x if input_cond is None else torch.cat((x, input_cond), dim=1)

            eta_theta = self.denoise(input_, time_tensor, cond)

            alpha_t = self.alphas[t]
            alpha_t_bar = self.alpha_bars[t]

            # posterior mean: (x - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
            a = 1 / alpha_t.sqrt()
            b = ((1 - alpha_t) / (1 - alpha_t_bar).sqrt()) * eta_theta
            x = a * (x - b)

            if t > 0:
                # no noise is added on the final step (t == 0)
                z = torch.randn(batch, *self.output_shape, device=device)
                x = x + self.betas[t].sqrt() * z

            ## store intermediate frames for visualization
            # Only when requested: the snapshot forces a device-to-host copy.
            if store_intermediate and ((idx % 50 == 0) or (t == 0)):
                frames.append(self._to_frame(x))

            if use_text and (idx + 1) % text_interval == 0:
                print(f"DDPM sampling {idx + 1}/{self.n_steps}", flush=True)

        if mode == "tqdm" and hasattr(progress_iter, "close"):
            progress_iter.close()

        if store_intermediate:
            return x, frames
        return x

    @torch.no_grad()
    def sample_ddim(self, cond, input_cond=None, ddim_steps: int = 100, eta: float = 0.2, verbose: str = "none", store_intermediate: bool = False):
        """
        Deterministic/stochastic DDIM sampling.

        cond: (b, cond_nc, h, w)
        input_cond: optional conditional input concatenated with the predicted frames
        ddim_steps: number of steps to sample (clamped to 1..n_steps)
        eta: 0 for deterministic DDIM, >0 adds noise controlled by eta
        verbose: "none", "text", or "tqdm" for progress display
        store_intermediate: when True, also return a list of numpy snapshots
            (every 25th iteration plus the step at t == 0)

        Returns ``x`` or ``(x, frames)`` depending on ``store_intermediate``.
        """
        self.backbone_model.eval()

        batch = cond.shape[0]
        device = cond.device
        ddim_steps = max(1, min(ddim_steps, self.n_steps))

        # create evenly spaced timesteps
        ddim_timesteps = torch.linspace(0, self.n_steps - 1, steps=ddim_steps, device=device).long()
        ddim_timesteps = torch.unique(ddim_timesteps, sorted=True)  # safety against duplicates
        ddim_t_reverse = list(reversed(ddim_timesteps.tolist()))
        n_iters = len(ddim_t_reverse)

        x = torch.randn(batch, *self.output_shape, device=device)

        progress_iter, mode = self._build_progress_iter(
            ddim_t_reverse, n_iters, verbose, desc="DDIM sampling"
        )
        use_text = mode == "text"
        text_interval = max(1, n_iters // 10)

        frames = []
        for idx, t in enumerate(progress_iter):
            time_tensor = torch.full((batch,), t, device=device, dtype=torch.long)
            input_ = x if input_cond is None else torch.cat((x, input_cond), dim=1)

            eps = self.denoise(input_, time_tensor, cond)

            alpha_bar_t = self.alpha_bars[t]
            sqrt_alpha_bar_t = alpha_bar_t.sqrt()
            sqrt_one_minus_alpha_bar_t = (1 - alpha_bar_t).sqrt()

            # estimate of the clean sample implied by the predicted noise
            x0_pred = (x - sqrt_one_minus_alpha_bar_t * eps) / sqrt_alpha_bar_t

            if idx + 1 < n_iters:
                alpha_bar_prev = self.alpha_bars[ddim_t_reverse[idx + 1]]
            else:
                # last step: jump to the noise-free end point (alpha_bar = 1)
                alpha_bar_prev = torch.ones_like(alpha_bar_t, device=device)

            # Noise is injected only for eta > 0 and never on the last step.
            stochastic = eta > 0 and bool(alpha_bar_prev < 1)
            if stochastic:
                sigma_t = eta * torch.sqrt(
                    (1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_prev)
                )
                noise = torch.randn_like(x)
            else:
                sigma_t = 0.0
                noise = torch.zeros_like(x)
            sigma_t = torch.as_tensor(sigma_t, device=device, dtype=x.dtype)

            # clamp guards against a tiny negative radicand from rounding
            c_t = torch.sqrt(torch.clamp(1 - alpha_bar_prev - sigma_t ** 2, min=0.0))
            x = (
                alpha_bar_prev.sqrt() * x0_pred
                + c_t * eps
                + sigma_t * noise
            )

            ## store intermediate frames for visualization
            # Only when requested: the snapshot forces a device-to-host copy.
            if store_intermediate and ((idx % 25 == 0) or (t == 0)):
                frames.append(self._to_frame(x))

            if use_text and (idx + 1) % text_interval == 0:
                print(f"DDIM sampling {idx + 1}/{n_iters}", flush=True)

        if mode == "tqdm" and hasattr(progress_iter, "close"):
            progress_iter.close()

        if store_intermediate:
            return x, frames
        return x

    # Backward-compatible alias
    sample = sample_ddpm