File size: 8,506 Bytes
c46900a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
Diffusion Process Implementation (DDPM + DDIM)

Changes from original:
- Fixed q_posterior_mean_variance return value count (was 3, caller expected 4)
- Fixed DDIM sigma/dir_xt formula inconsistency
- Made GaussianDiffusion an nn.Module with registered buffers for proper device handling
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class GaussianDiffusion(nn.Module):
    """Gaussian diffusion process with DDPM ancestral and DDIM sampling.

    All schedule-derived tensors are registered as buffers, so they follow the
    module through ``.to(device)`` and are included in ``state_dict``.

    The noise-prediction ``model`` passed to the methods below is always
    invoked as ``model(x_t, t, labels)`` and its output is treated as the
    predicted noise.
    """

    def __init__(self, timesteps=1500, beta_start=1e-4, beta_end=0.02, schedule_type="linear"):
        """Build the beta schedule and precompute every derived coefficient.

        Args:
            timesteps: Number of diffusion steps in the schedule.
            beta_start: First beta of the linear schedule (unused for "cosine").
            beta_end: Last beta of the linear schedule (unused for "cosine").
            schedule_type: Either "linear" or "cosine".

        Raises:
            ValueError: If ``schedule_type`` is not one of the supported names.
        """
        super().__init__()
        self.timesteps = timesteps

        if schedule_type == "linear":
            betas = torch.linspace(beta_start, beta_end, timesteps)
        elif schedule_type == "cosine":
            betas = self._cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f"Unknown schedule: {schedule_type}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Shifted by one step with a leading 1.0, i.e. the cumulative product "before" step 0.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        # Register all schedule tensors as buffers so they move with .to(device)
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - alphas_cumprod))

        # Coefficients of the posterior q(x_{t-1} | x_t, x_0).
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)
        # The variance is exactly 0 at step 0 (alphas_cumprod_prev == 1), so clamp before the log.
        self.register_buffer('posterior_log_variance_clipped', torch.log(torch.clamp(posterior_variance, min=1e-20)))
        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2', (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod))

        # Precompute reciprocals used in _predict_xstart_from_noise (avoids recomputation per step)
        self.register_buffer('recip_sqrt_alphas_cumprod', 1.0 / torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_recip_minus_one', torch.sqrt(1.0 / alphas_cumprod - 1.0))

    def _cosine_beta_schedule(self, timesteps, s=0.008):
        """Return ``timesteps`` betas derived from a squared-cosine cumulative-alpha curve.

        Args:
            timesteps: Number of betas to produce.
            s: Offset added to the normalized time before the cosine is applied.

        Returns:
            1-D tensor of betas clipped to ``[0.0001, 0.9999]``.
        """
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps)
        alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
        # Normalize so the curve starts at exactly 1.
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0.0001, 0.9999)

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` forward to step ``t``.

        Args:
            x_start: Clean input batch.
            t: 1-D integer tensor of per-sample timestep indices.
            noise: Optional noise tensor shaped like ``x_start``; drawn from a
                standard normal when omitted.

        Returns:
            ``sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise``.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def p_mean_variance(self, model, x_t, t, labels, clip_denoised=True):
        """Compute the reverse-step mean and variance from the model's noise prediction.

        Args:
            model: Callable invoked as ``model(x_t, t, labels)``.
            x_t: Noisy batch at step ``t``.
            t: 1-D integer tensor of per-sample timestep indices.
            labels: Conditioning passed through to ``model`` unchanged.
            clip_denoised: When true, clamp the predicted clean sample to ``[-1, 1]``.

        Returns:
            Tuple ``(model_mean, posterior_variance, posterior_log_variance, x_start)``.
        """
        pred_noise = model(x_t, t, labels)
        x_start = self._predict_xstart_from_noise(x_t, t, pred_noise)
        if clip_denoised:
            x_start = torch.clamp(x_start, -1.0, 1.0)
        # q_posterior_mean_variance returns 3 values; x_start is appended here as the 4th.
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior_mean_variance(x_start, x_t, t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    def _predict_xstart_from_noise(self, x_t, t, noise):
        """Invert the forward noising formula to recover the clean sample from ``x_t`` and ``noise``."""
        return (
            self._extract(self.recip_sqrt_alphas_cumprod, t, x_t.shape) * x_t -
            self._extract(self.sqrt_recip_minus_one, t, x_t.shape) * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """Return the mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0).

        Returns:
            Tuple ``(posterior_mean, posterior_variance, posterior_log_variance)``;
            the last two are broadcastable against ``x_t``.
        """
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance = self._extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance

    def p_sample(self, model, x_t, t, labels):
        """Draw one ancestral (DDPM) reverse step from ``x_t``.

        No noise is added for samples whose timestep is 0.
        """
        model_mean, _, model_log_variance, _ = self.p_mean_variance(model, x_t, t, labels)
        noise = torch.randn_like(x_t)
        # Per-sample mask, reshaped so it broadcasts over all non-batch dimensions.
        nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1))))
        return model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise

    def ddim_sample_step(self, model, x_t, t, t_next, labels, eta=0.0):
        """Perform one DDIM update from timestep ``t`` to ``t_next``.

        Args:
            model: Callable invoked as ``model(x_t, t, labels)``.
            x_t: Noisy batch at step ``t``.
            t: 1-D integer tensor of current timestep indices.
            t_next: 1-D integer tensor of target timestep indices. A negative
                entry marks the final step for that sample, for which the
                target cumulative alpha is taken to be 1.
            labels: Conditioning passed through to ``model`` unchanged.
            eta: Stochasticity factor. Values ``<= 0`` give the deterministic
                update and draw no random numbers.

        Returns:
            The batch at step ``t_next``.
        """
        pred_noise = model(x_t, t, labels)
        alpha_t = self._extract(self.alphas_cumprod, t, x_t.shape)

        # Resolve the "final step" case per sample instead of inspecting only
        # t_next[0]: this avoids a host-side tensor comparison and stays valid
        # for mixed or empty batches. Negative indices are clamped to 0 only to
        # keep gather in bounds; their gathered value is discarded by the where.
        has_next = (t_next >= 0).reshape(alpha_t.shape)
        alpha_t_next = torch.where(
            has_next,
            self._extract(self.alphas_cumprod, t_next.clamp(min=0), x_t.shape),
            torch.ones_like(alpha_t),
        )

        x0_pred = (x_t - torch.sqrt(1 - alpha_t) * pred_noise) / torch.sqrt(alpha_t)
        x0_pred = torch.clamp(x0_pred, -1.0, 1.0)

        if eta > 0:
            # sigma^2 = eta^2 * (1 - alpha_{t-1}) / (1 - alpha_t) * (1 - alpha_t / alpha_{t-1})
            sigma_sq = eta**2 * (1 - alpha_t_next) / (1 - alpha_t) * (1 - alpha_t / alpha_t_next)
            sigma_t = torch.sqrt(torch.clamp(sigma_sq, min=0))
            noise = torch.randn_like(x_t)
        else:
            # Deterministic update: no variance term and no RNG consumption.
            sigma_sq = 0
            sigma_t = 0
            noise = 0

        # dir_xt uses the same sigma^2 as the noise term.
        dir_xt = torch.sqrt(torch.clamp(1 - alpha_t_next - sigma_sq, min=0)) * pred_noise
        return torch.sqrt(alpha_t_next) * x0_pred + dir_xt + sigma_t * noise

    def sample(self, model, labels, channels, height, width, device, progress=False, use_ddim=True, ddim_steps=50, eta=0.0):
        """Generate a batch by running the reverse process from pure noise.

        Args:
            model: Callable invoked as ``model(x_t, t, labels)``.
            labels: Conditioning tensor; its first dimension sets the batch size.
            channels: Channel count of the generated batch.
            height: Height of the generated batch.
            width: Width of the generated batch.
            device: Device on which the noise and timestep tensors are created.
            progress: When true, wrap the loop in a ``tqdm`` progress bar.
            use_ddim: Use strided DDIM steps when true, otherwise run every
                DDPM step.
            ddim_steps: Requested number of DDIM steps. The stride is
                ``timesteps // ddim_steps`` (at least 1), so the number of
                steps actually taken can differ from the request when it does
                not divide ``timesteps`` or exceeds it.
            eta: DDIM stochasticity factor forwarded to ``ddim_sample_step``.

        Returns:
            Tensor of shape ``(batch_size, channels, height, width)``.

        Raises:
            ValueError: If ``use_ddim`` is true and ``ddim_steps`` is below 1.
        """
        batch_size = labels.shape[0]
        img = torch.randn((batch_size, channels, height, width), device=device)

        if use_ddim:
            if ddim_steps < 1:
                raise ValueError(f"ddim_steps must be >= 1, got {ddim_steps}")
            # Requesting more steps than the schedule has would give a stride
            # of 0 and make range() fail; fall back to visiting every timestep.
            skip = max(self.timesteps // ddim_steps, 1)
            seq = list(range(0, self.timesteps, skip))
            # Each timestep is paired with the one it steps to; -1 marks the final step.
            seq_next = [-1] + seq[:-1]
            seq_iter = reversed(list(zip(seq, seq_next)))
            if progress:
                from tqdm import tqdm
                seq_iter = tqdm(seq_iter, desc=f'DDIM Sampling ({len(seq)} steps)', total=len(seq))
            for i, j in seq_iter:
                t = torch.full((batch_size,), i, device=device, dtype=torch.long)
                t_next = torch.full((batch_size,), j, device=device, dtype=torch.long)
                img = self.ddim_sample_step(model, img, t, t_next, labels, eta)
        else:
            if progress:
                from tqdm import tqdm
                timesteps_iter = tqdm(reversed(range(self.timesteps)), total=self.timesteps)
            else:
                timesteps_iter = reversed(range(self.timesteps))
            for i in timesteps_iter:
                t = torch.full((batch_size,), i, device=device, dtype=torch.long)
                img = self.p_sample(model, img, t, labels)
        return img

    def training_losses(self, model, x_start, labels, t, noise=None):
        """Return the per-sample MSE between the model's prediction and the injected noise.

        Args:
            model: Callable invoked as ``model(x_t, t, labels)``.
            x_start: Clean input batch.
            labels: Conditioning passed through to ``model`` unchanged.
            t: 1-D integer tensor of per-sample timestep indices.
            noise: Optional noise tensor shaped like ``x_start``; drawn from a
                standard normal when omitted.

        Returns:
            1-D tensor with one loss value per batch element (mean over all
            non-batch dimensions).
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise)
        pred_noise = model(x_t, t, labels)
        return F.mse_loss(pred_noise, noise, reduction='none').mean(dim=list(range(1, len(pred_noise.shape))))

    def _extract(self, a, t, x_shape):
        """Gather ``a[t]`` and reshape it to ``(batch, 1, ..., 1)`` for broadcasting.

        Args:
            a: 1-D schedule tensor.
            t: 1-D integer index tensor; its length is the batch size.
            x_shape: Shape of the tensor the result will be broadcast against;
                only its number of dimensions is used.
        """
        batch_size = t.shape[0]
        out = a.gather(0, t)
        return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))


class ConditionalDiffusionModel(nn.Module):
    """Pair a label-conditioned noise-prediction network with a diffusion process.

    The wrapper itself is passed to the diffusion process as the ``model``
    callable, so every call goes through ``forward``.
    """

    def __init__(self, unet, diffusion_process):
        """Store the network and the diffusion process.

        Args:
            unet: Network invoked as ``unet(x, t, labels)``.
            diffusion_process: Object providing ``timesteps``,
                ``training_losses(...)`` and ``sample(...)``.
        """
        super().__init__()
        self.unet = unet
        self.diffusion = diffusion_process

    def forward(self, x, t, labels):
        """Delegate to the wrapped network and return its output unchanged."""
        return self.unet(x, t, labels)

    def get_loss(self, x, labels, noise=None):
        """Return the scalar training loss for one batch.

        A timestep is drawn uniformly from ``[0, diffusion.timesteps)`` for
        each sample, and the per-sample losses from
        ``diffusion.training_losses`` are averaged.

        Args:
            x: Clean input batch; its first dimension is the batch size.
            labels: Conditioning forwarded to the diffusion process.
            noise: Optional noise tensor forwarded to ``training_losses``.
        """
        t = torch.randint(
            0, self.diffusion.timesteps, (x.shape[0],), device=x.device, dtype=torch.long
        )
        return self.diffusion.training_losses(self, x, labels, t, noise=noise).mean()

    def sample(self, labels, channels, height, width, device, progress=False, use_ddim=True, ddim_steps=50, eta=0.0):
        """Generate samples in eval mode without tracking gradients.

        All arguments are forwarded positionally to ``diffusion.sample``. The
        train/eval mode that every submodule had on entry is restored before
        returning, so calling this in the middle of a training loop does not
        leave the model stuck in eval mode.

        Returns:
            Whatever ``diffusion.sample`` returns.
        """
        # Snapshot per-module flags rather than a single bool so that modules
        # deliberately held in a different mode than the root keep that mode.
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                return self.diffusion.sample(self, labels, channels, height, width, device, progress, use_ddim, ddim_steps, eta)
        finally:
            for module, was_training in previous_modes:
                module.training = was_training