File size: 4,987 Bytes
70ced22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
Gaussian Diffusion (DDPM) framework for PDE next-frame prediction.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class GaussianDiffusion(nn.Module):
    """DDPM with a linear beta schedule for conditional next-frame prediction.

    Training: given (condition, target), noise the target at a random
    timestep and train the model to predict the added noise (epsilon).
    Sampling: iteratively denoise, starting from pure Gaussian noise,
    conditioned on ``x_cond`` (ancestral DDPM or accelerated DDIM).

    Args:
        model: epsilon-predicting network, called as ``model(x, t, cond=...)``.
        timesteps: number of diffusion steps T.
        beta_start: starting noise level of the linear schedule.
        beta_end: ending noise level of the linear schedule.
    """

    def __init__(self, model, timesteps=1000, beta_start=1e-4, beta_end=0.02):
        super().__init__()
        self.model = model
        self.T = timesteps

        # --- precompute schedule ---
        # Registered as buffers so they move with .to(device) / are saved
        # in the state_dict but are not trained.
        betas = torch.linspace(beta_start, beta_end, timesteps)
        alphas = 1.0 - betas
        alpha_bar = torch.cumprod(alphas, dim=0)

        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alpha_bar", alpha_bar)
        self.register_buffer("sqrt_alpha_bar", torch.sqrt(alpha_bar))
        self.register_buffer("sqrt_one_minus_alpha_bar", torch.sqrt(1 - alpha_bar))
        self.register_buffer("sqrt_recip_alpha", torch.sqrt(1.0 / alphas))
        # Variance of q(x_{t-1} | x_t, x_0); the pad shifts alpha_bar right
        # so index t uses alpha_bar_{t-1} (with alpha_bar_{-1} := 1).
        # Entry 0 is exactly 0, but sampling never draws noise at t=0.
        self.register_buffer(
            "posterior_variance",
            betas * (1.0 - F.pad(alpha_bar[:-1], (1, 0), value=1.0)) / (1.0 - alpha_bar),
        )

    def q_sample(self, x0, t, noise=None):
        """Forward process: diffuse x0 to timestep t in closed form.

        Args:
            x0: clean input [B, C, H, W].
            t: integer timesteps [B].
            noise: optional pre-drawn noise (same shape as x0); sampled
                fresh if None.

        Returns:
            (x_t, noise) tuple — the noised sample and the noise used,
            so the caller can reuse the noise as the regression target.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        # Broadcast per-sample scalar coefficients over C, H, W.
        a = self.sqrt_alpha_bar[t][:, None, None, None]
        b = self.sqrt_one_minus_alpha_bar[t][:, None, None, None]
        return a * x0 + b * noise, noise

    def training_loss(self, x_target, x_cond):
        """Compute the simple DDPM training loss (epsilon prediction).

        A timestep is drawn uniformly per batch element; the model must
        recover the injected noise from the noised target.

        Args:
            x_target: clean target frames [B, C, H, W].
            x_cond: condition frames [B, C, H, W].

        Returns:
            scalar MSE loss between predicted and true noise.
        """
        B = x_target.shape[0]
        t = torch.randint(0, self.T, (B,), device=x_target.device)
        noise = torch.randn_like(x_target)
        x_noisy, _ = self.q_sample(x_target, t, noise)

        eps_pred = self.model(x_noisy, t, cond=x_cond)
        return F.mse_loss(eps_pred, noise)

    @torch.no_grad()
    def sample(self, x_cond, shape=None):
        """Generate target frames by full ancestral DDPM denoising (T steps).

        Args:
            x_cond: condition frames [B, C_cond, H, W].
            shape: (B, C_out, H, W) of the target. If None, assumes the
                target has the same shape as the condition.

        Returns:
            denoised sample [B, C_out, H, W].
        """
        device = x_cond.device
        if shape is None:
            shape = x_cond.shape  # assume same channels

        x = torch.randn(shape, device=device)

        for i in reversed(range(self.T)):
            t = torch.full((shape[0],), i, device=device, dtype=torch.long)
            eps = self.model(x, t, cond=x_cond)

            alpha_bar = self.alpha_bar[i]
            beta = self.betas[i]

            # Posterior mean; use the precomputed 1/sqrt(alpha_t) buffer
            # instead of recomputing it every step.
            mean = self.sqrt_recip_alpha[i] * (x - beta / (1 - alpha_bar).sqrt() * eps)

            if i > 0:
                sigma = self.posterior_variance[i].sqrt()
                x = mean + sigma * torch.randn_like(x)
            else:
                # Final step is deterministic (no noise added at t=0).
                x = mean

        return x

    @torch.no_grad()
    def sample_ddim(self, x_cond, shape=None, steps=50, eta=0.0):
        """DDIM accelerated sampling over a uniform sub-sequence of timesteps.

        Args:
            x_cond: condition [B, C_cond, H, W].
            shape: target shape; inferred from x_cond if None.
            steps: number of DDIM steps (<<T for speed).
            eta: stochasticity (0 = deterministic DDIM, 1 = DDPM-like).

        Returns:
            denoised sample [B, C_out, H, W].
        """
        device = x_cond.device
        if shape is None:
            shape = x_cond.shape

        # Sub-sample timesteps uniformly; steps+1 endpoints give `steps`
        # (t_cur -> t_next) transitions ending exactly at t=0.
        step_indices = torch.linspace(0, self.T - 1, steps + 1, dtype=torch.long, device=device)
        step_indices = step_indices.flip(0)  # reverse: T-1 ... 0

        x = torch.randn(shape, device=device)

        for idx in range(len(step_indices) - 1):
            t_cur = step_indices[idx]
            t_next = step_indices[idx + 1]

            t_batch = t_cur.expand(shape[0])
            eps = self.model(x, t_batch, cond=x_cond)

            ab_cur = self.alpha_bar[t_cur]
            ab_next = self.alpha_bar[t_next]

            # Predict x0 from the epsilon estimate.
            x0_pred = (x - (1 - ab_cur).sqrt() * eps) / ab_cur.sqrt()
            x0_pred = x0_pred.clamp(-5, 5)  # stability clamp

            # DDIM direction term; clamp the radicand at 0 to avoid NaN
            # from tiny negative float residue when eta > 0.
            sigma = eta * ((1 - ab_next) / (1 - ab_cur) * (1 - ab_cur / ab_next)).sqrt()
            dir_xt = (1 - ab_next - sigma**2).clamp_min(0.0).sqrt() * eps

            x = ab_next.sqrt() * x0_pred + dir_xt
            if sigma > 0:
                x = x + sigma * torch.randn_like(x)

        return x