Upload folder using huggingface_hub

#1
by Bangchis - opened
Files changed (5) hide show
  1. config.yaml +97 -0
  2. diffusion.py +429 -0
  3. pytorch_model.bin +3 -0
  4. requirements.txt +24 -0
  5. unet.py +245 -0
config.yaml ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
project: diffusion-from-scratch
run_name: mnist32_small

data:
  dataset: mnist
  image_size: 32        # resize MNIST 28 -> 32 (divisible by the UNet downsample factor)
  channels: 1
  batch_size: 128
  num_workers: 4

opt:
  lr: 0.0002
  betas: [0.9, 0.999]
  grad_clip: 1.0

# NOTE(review): the original file declared `diffusion:` TWICE at top level.
# With last-wins YAML loaders (e.g. PyYAML) the second mapping silently
# replaced the first, dropping keys such as clamp_x0 / self_condition /
# learned_variance. The two blocks are merged here; where values conflicted
# (sample_every: 2000 vs 1000) the later value is kept, since that is what
# actually took effect at runtime.
diffusion:
  T: 400                  # fewer steps are enough for MNIST
  beta_schedule: cosine
  objective: pred_noise   # start simple; later try pred_v
  sampling_steps: 400     # < T => DDIM fast sampling; == T => plain DDPM
  eta: 0.0
  self_condition: false
  clamp_x0: true
  sample_every: 1000
  sample_n: 64
  learned_variance: false
  var_loss_weight: 0.0
  min_snr_loss_weight: false

model:
  dim: 32                 # lightweight UNet
  dim_mults: [1, 2, 4]    # shallow for MNIST
  channels: 1
  attn_heads: 2
  attn_dim_head: 16
  dropout: 0.0
  self_condition: false
  learned_variance: false
  outer_attn: false       # turn off outer attention; keep only bottleneck attention

train:
  max_steps: 30000
  log_every: 200
  ckpt_dir: ./checkpoints
  grad_accum: 1

ema:
  enabled: false
  decay: 0.995
  update_every: 10

wandb:
  enabled: true
  mode: online
  # SECURITY(review): the original file committed a RAW W&B API key in this
  # field. Rotate that key immediately. Per the field name, this should hold
  # the NAME of the environment variable carrying the key — confirm the
  # training code reads it that way.
  api_key_env: WANDB_API_KEY
  tags: [mnist, small, bottleneck-attn]

compute:
  enable_tf32: true

metrics:
  # norms
  global_norm_every: 1000

  # FID / IS (optional; need clean-fid and torch-fidelity installed)
  enable_fid: true
  enable_is: true
  fid_every: 4000
  is_every: 4000
  metric_num_gen: 5000
  metric_batch_size: 32

viz:
  enable_reverse_traj: true
  reverse_every_steps: 4000   # log videos sparsely to keep overhead low
  reverse_record_every: 5     # lower => more snapshots recorded (1 = smoothest)
  reverse_batch_n: 16
  enable_forward_traj: true
  forward_every_steps: 4000
  forward_t_values: [0, 20, 40, 60, 80, 120, 160, 240, 320, 399]  # slightly denser
  forward_batch_n: 16
  video_fps: 16               # higher FPS (16-24) gives smoother playback
+
diffusion.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # diffusion_core.py
2
+ # -- file này chứa các công thức toán cốt lõi của DDPM/DDIM
3
+ # -- mục tiêu: tính các hệ số từ beta-schedule, và 4 hàm quan trọng:
4
+ # q_sample, predict_start_from_noise, predict_noise_from_start, q_posterior
5
+
6
+ # diffusion_core.py
7
+ # Core DDPM math: schedules and q/p transformations.
8
+
9
+
10
+ import math
11
+ import torch
12
+
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+
17
def extract(a, t, x_shape):
    """Gather per-batch schedule constants and reshape them for broadcasting.

    Args:
        a: 1-D tensor of per-timestep constants (length T).
        t: (B,) long tensor of timestep indices.
        x_shape: shape of the tensor the result will broadcast against.

    Returns:
        Tensor of shape (B, 1, ..., 1) with len(x_shape) dimensions total.
    """
    gathered = a.gather(-1, t)
    trailing = (1,) * (len(x_shape) - 1)
    return gathered.reshape(t.shape[0], *trailing)
21
+
22
+
23
+ def cosine_beta_schedule(timesteps, s=0.008):
24
+ steps = timesteps + 1
25
+ t = torch.linspace(0, timesteps, steps, dtype=torch.float32) / timesteps
26
+ alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
27
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
28
+ betas = 1.0 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
29
+ return torch.clip(betas, 0, 0.999)
30
+
31
+
32
class GaussianDiffusion(nn.Module):
    """
    Core diffusion module that wraps a denoiser (UNet):
    - Precomputes diffusion constants (betas, alphas, etc.)
    - Provides training loss (forward): randomly pick t, add noise, regress target
    - Provides sampling loops (DDPM or DDIM)

    The denoiser must have forward(x, t, [x_self_cond]), returning a predicted target
    (epsilon, x0, or v depending on `objective`).
    """

    def __init__(self, model, *, image_size, timesteps=400, beta_schedule='cosine',
                 objective='pred_noise', sampling_steps=None, eta=0.0,
                 self_condition=False, auto_normalize=True, clamp_x0=True):
        """
        Args:
            model (nn.Module): denoiser network (e.g., UNet); must expose `.channels`.
            image_size (int or (h,w)): training/sampling resolution (must match UNet).
            timesteps (int): T. Smaller (e.g., 400) is enough for MNIST.
            beta_schedule (str): only 'cosine' implemented here for simplicity.
            objective (str): 'pred_noise'|'pred_x0'|'pred_v' (training target).
            sampling_steps (int or None): if set < T => DDIM sampling with S steps; else DDPM full T.
            eta (float): DDIM stochasticity (0.0 => deterministic).
            self_condition (bool): optional self-conditioning flag.
            auto_normalize (bool): map inputs [0,1] <-> [-1,1] inside module.
            clamp_x0 (bool): clamp predicted x0 to [-1,1] during sampling for stability.
        """
        super().__init__()
        self.model = model
        param = next(model.parameters())
        param_dtype = param.dtype
        param_device = param.device
        self.channels = model.channels
        self.self_condition = self_condition
        self.objective = objective
        self.clamp_x0 = clamp_x0

        # In-module normalization helpers (kept simple & explicit)
        self.normalize = (lambda x: x * 2 - 1) if auto_normalize else (lambda x: x)
        self.unnormalize = (lambda x: (x + 1) * 0.5) if auto_normalize else (lambda x: x)

        # Normalize image_size to (H, W)
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = image_size

        # --- schedule setup ---
        if beta_schedule != 'cosine':
            raise NotImplementedError(
                "For MNIST small, keep beta_schedule='cosine'")
        betas = cosine_beta_schedule(timesteps).to(
            device=param_device, dtype=param_dtype)  # shape [T]

        alphas = 1.0 - betas                                   # alpha_t
        alphas_cumprod = torch.cumprod(alphas, dim=0)          # alpha_bar_t
        alphas_cumprod_prev = F.pad(
            alphas_cumprod[:-1], (1, 0), value=1.0)            # alpha_bar_{t-1}

        # Timesteps used in training and sampling
        self.num_timesteps = int(betas.shape[0])
        self.sampling_steps = int(
            sampling_steps) if sampling_steps else self.num_timesteps
        self.is_ddim_sampling = self.sampling_steps < self.num_timesteps
        self.ddim_sampling_eta = float(eta)

        # Register constants as buffers (moved with .to(device), saved in state_dict)
        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             torch.sqrt(1.0 - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod',
                             torch.sqrt(1.0 / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod',
                             torch.sqrt(1.0 / alphas_cumprod - 1.0))

        # Posterior q(x_{t-1} | x_t, x_0) parameters
        posterior_variance = betas * \
            (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)
        self.register_buffer('posterior_log_variance_clipped', torch.log(
            posterior_variance.clamp(min=1e-20)))
        self.register_buffer('posterior_mean_coef1', betas *
                             torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2', (1.0 - alphas_cumprod_prev)
                             * torch.sqrt(1.0 - betas) / (1.0 - alphas_cumprod))

        # Optional loss re-weighting by SNR (kept simple here)
        snr = alphas_cumprod / (1 - alphas_cumprod)
        if objective == 'pred_noise':
            # FIX: was `snr / snr`. That is NaN whenever snr hits 0 or inf
            # (alpha_bar saturating at either end); the intent is a constant 1.
            loss_weight = torch.ones_like(snr)
        elif objective == 'pred_x0':
            loss_weight = snr
        else:  # pred_v
            loss_weight = snr / (snr + 1)
        self.register_buffer('loss_weight', loss_weight)

    @property
    def device(self):
        """Convenience: returns the device where buffers live."""
        return self.betas.device

    # ----------------------
    # Forward diffusion (q)
    # ----------------------
    def q_sample(self, x0, t, noise=None):
        """
        Sample x_t from q(x_t | x_0):
            x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise
        """
        if noise is None:
            noise = torch.randn_like(x0)
        return extract(self.sqrt_alphas_cumprod, t, x0.shape) * x0 + \
            extract(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise

    # ---------------------------------
    # Converters between parameterizations
    # ---------------------------------
    def predict_start_from_noise(self, x_t, t, eps):
        """Given epsilon prediction, reconstruct x0."""
        return extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps

    def predict_noise_from_start(self, x_t, t, x0):
        """Given x0 prediction, reconstruct epsilon."""
        return (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def predict_v(self, x0, t, eps):
        """v-parameterization = sqrt(alpha_bar)*eps - sqrt(1-alpha_bar)*x0."""
        return extract(self.alphas_cumprod.sqrt(), t, x0.shape) * eps - \
            extract((1.0 - self.alphas_cumprod).sqrt(), t, x0.shape) * x0

    def predict_start_from_v(self, x_t, t, v):
        """Given v prediction, reconstruct x0."""
        return extract(self.alphas_cumprod.sqrt(), t, x_t.shape) * x_t - \
            extract((1.0 - self.alphas_cumprod).sqrt(), t, x_t.shape) * v

    # ---------------------------------
    # Model predictions at time t
    # ---------------------------------
    def model_predictions(self, x, t, x_self_cond=None, clip_x_start=False, rederive_pred_noise=False):
        """
        Run the denoiser and return (pred_noise, x0):
        - If objective == pred_noise: UNet predicts epsilon directly.
        - If objective == pred_x0: UNet predicts x0 directly.
        - If objective == pred_v: UNet predicts v; we convert to x0 & epsilon.

        Args:
            x (Tensor): noised image x_t.
            t (LongTensor): time indices.
            x_self_cond (Tensor|None): optional self-conditioning input.
            clip_x_start (bool): clamp x0 to [-1,1] after prediction.
            rederive_pred_noise (bool): if True, recompute epsilon from clamped x0.

        Returns:
            (pred_noise, x0) both shape like x.
        """
        out = self.model(
            x, t, x_self_cond) if x_self_cond is not None else self.model(x, t)

        maybe_clip = (lambda z: z.clamp(-1, 1)
                      ) if clip_x_start else (lambda z: z)

        if self.objective == 'pred_noise':
            pred_noise = out
            x0 = self.predict_start_from_noise(x, t, pred_noise)
            x0 = maybe_clip(x0)
            if clip_x_start and rederive_pred_noise:
                pred_noise = self.predict_noise_from_start(x, t, x0)

        elif self.objective == 'pred_x0':
            x0 = maybe_clip(out)
            pred_noise = self.predict_noise_from_start(x, t, x0)

        else:  # 'pred_v'
            v = out
            x0 = self.predict_start_from_v(x, t, v)
            x0 = maybe_clip(x0)
            pred_noise = self.predict_noise_from_start(x, t, x0)

        return pred_noise, x0

    def q_posterior(self, x0, x_t, t):
        """
        Compute the Gaussian q(x_{t-1} | x_t, x0) parameters:
            mean = c1 * x0 + c2 * x_t
            var, log_var: closed-form from betas and alpha_bars.
        """
        mean = extract(self.posterior_mean_coef1, t, x_t.shape) * x0 + \
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        var = extract(self.posterior_variance, t, x_t.shape)
        log_var = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return mean, var, log_var

    # ----------------------
    # Training loss (forward)
    # ----------------------
    def p_losses(self, x_start, t, noise=None):
        """
        DDPM training objective:
        - Sample x_t = q(x_t | x_0)
        - Predict target according to objective and MSE it
        - (Optional) self-conditioning can be added outside for simplicity
        """
        noise = torch.randn_like(x_start) if noise is None else noise
        x = self.q_sample(x_start, t, noise)

        x_self_cond = None
        if self.self_condition and torch.rand(1, device=self.device) < 0.5:
            # simple self-conditioning: predict x0 once and feed back
            with torch.no_grad():
                _, x_self_cond = self.model_predictions(
                    x, t, None, clip_x_start=True)

        model_out = self.model(
            x, t, x_self_cond) if x_self_cond is not None else self.model(x, t)

        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        else:  # pred_v
            v = self.predict_v(x_start, t, noise)
            target = v

        # MSE over channels/spatial dims -> mean over batch
        loss = F.mse_loss(model_out, target, reduction='none')
        loss = loss.mean(dim=list(range(1, loss.ndim)))  # average over C,H,W
        # snr-based weight (here often ==1)
        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()

    def forward(self, img):
        """
        Training entry point:
        - Normalize to [-1,1]
        - Draw random timesteps
        - Compute loss

        Raises:
            ValueError: if the image resolution does not match `image_size`.
        """
        img = img.to(device=self.device, dtype=next(
            self.model.parameters()).dtype)
        b, c, h, w = img.shape
        # FIX: explicit raise instead of `assert` (asserts vanish under `python -O`)
        if (h, w) != self.image_size:
            raise ValueError(f"image must be {self.image_size}, got {(h,w)}")
        t = torch.randint(0, self.num_timesteps, (b,),
                          device=img.device).long()
        img = self.normalize(img)
        return self.p_losses(img, t)

    # ----------------------
    # Single DDPM step p(x_{t-1}|x_t)
    # ----------------------
    @torch.inference_mode()
    def p_sample(self, x, t: int, x_self_cond=None):
        """
        Compute one reverse step:
        - predict (epsilon, x0), compute posterior q(x_{t-1}|x_t, x0)
        - sample from that Gaussian (add noise except at t=0)
        """
        b = x.shape[0]
        tt = torch.full((b,), t, device=self.device, dtype=torch.long)
        pred_noise, x0 = self.model_predictions(
            x, tt, x_self_cond, clip_x_start=True)
        mean, _, log_var = self.q_posterior(x0, x, tt)
        noise = torch.randn_like(x) if t > 0 else 0.0
        return mean + (0.5 * log_var).exp() * noise, x0

    # ----------------------
    # Sampling loops
    # ----------------------
    @torch.inference_mode()
    def ddpm_sample(self, shape):
        """
        DDPM sampling with T steps (slow, high quality).
        """
        img = torch.randn(shape, device=self.device)
        x0 = None
        for t in reversed(range(self.num_timesteps)):
            self_cond = x0 if self.self_condition else None
            img, x0 = self.p_sample(img, t, self_cond)
        return self.unnormalize(img)

    @torch.inference_mode()
    def ddim_sample(self, shape):
        """
        DDIM sampling with S < T steps (fast, often good quality).
        Deterministic when eta=0.0.
        """
        T, S, eta = self.num_timesteps, self.sampling_steps, self.ddim_sampling_eta
        # create a decreasing time index schedule of length S+1: [T-1, ..., 0, -1]
        times = torch.linspace(-1, T - 1, steps=S + 1,
                               device=self.device).long().flip(0)
        pairs = list(zip(times[:-1].tolist(), times[1:].tolist()))

        img = torch.randn(shape, device=self.device)
        x0 = None

        for t, t_next in pairs:
            tt = torch.full(
                (shape[0],), t, device=self.device, dtype=torch.long)
            pred_noise, x0 = self.model_predictions(
                img, tt, None, clip_x_start=True, rederive_pred_noise=True)

            if t_next < 0:
                # final step: directly set to predicted x0
                img = x0
                continue

            a_t, a_next = self.alphas_cumprod[t], self.alphas_cumprod[t_next]
            sigma = eta * ((1 - a_t / a_next) *
                           (1 - a_next) / (1 - a_t)).sqrt()
            c = (1 - a_next - sigma ** 2).sqrt()
            noise = torch.randn_like(img)

            # DDIM update rule
            img = x0 * a_next.sqrt() + c * pred_noise + sigma * noise

        return self.unnormalize(img)

    @torch.inference_mode()
    def sample(self, batch_size=16):
        """
        Public sampling API:
        - choose DDPM or DDIM depending on `sampling_steps`
        - returns a batch of images in [0,1]
        """
        H, W = self.image_size
        fn = self.ddim_sample if self.is_ddim_sampling else self.ddpm_sample
        return fn((batch_size, self.channels, H, W))

    # ----------------------
    # DDPM sampling with trajectory recording and forward transformations
    # ----------------------
    @torch.inference_mode()
    def ddpm_sample_trajectory(self, shape, record_every=50, return_x0=False):
        """
        DDPM sampling but also record intermediate frames.
        - record_every: save a snapshot every N steps (also includes first/last).
        - return_x0: if True, also store predicted x0 at the same checkpoints.

        Returns:
            final_img [B,C,H,W] in [0,1],
            frames_xt: list of tensors in [0,1], each [B,C,H,W]
            frames_x0 (or None): same length as frames_xt if return_x0=True
        """
        img = torch.randn(shape, device=self.device)
        frames_xt = []
        frames_x0 = [] if return_x0 else None

        x0 = None
        T = self.num_timesteps

        for t in reversed(range(T)):
            # record current x_t before stepping
            if t == T - 1 or t == 0 or (t % record_every) == 0:
                # unnormalize for visualization (to [0,1])
                frames_xt.append(self.unnormalize(img.clamp(-1, 1)))
                if return_x0 and x0 is not None:
                    frames_x0.append(self.unnormalize(x0.clamp(-1, 1)))

            self_cond = x0 if self.self_condition else None
            img, x0 = self.p_sample(img, t, self_cond)

        # record the final image
        frames_xt.append(self.unnormalize(img.clamp(-1, 1)))
        if return_x0:
            frames_x0.append(self.unnormalize(x0.clamp(-1, 1)))

        # NOTE(review): the returned final image is unnormalized WITHOUT the
        # clamp applied to the recorded frames — kept as-is to match callers.
        return self.unnormalize(img), frames_xt, frames_x0

    @torch.no_grad()
    def forward_noising_trajectory(self, x0, t_values):
        """
        Visualize forward diffusion q(x_t | x_0) at selected t.
        Args:
            x0: clean images in [0,1], [B,C,H,W]
            t_values: list/iterable of ints (0..T-1)

        Returns:
            frames_xt: list of tensors in [0,1], each [B,C,H,W]
        """
        # normalize like training path
        x0n = self.normalize(x0.to(self.device))
        frames = []
        for t in t_values:
            tt = torch.full((x0n.size(0),), int(
                t), device=self.device, dtype=torch.long)
            xt = self.q_sample(x0n, tt)  # in [-1,1] domain
            # map back to [0,1] for viewing
            frames.append(self.unnormalize(xt))
        return frames
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b434bb9f31f1b7204aa76c4b93881e896f3b0280233d4237518357cb71cd14d5
3
+ size 33190930
requirements.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Use PyTorch CUDA 12.1 wheels for torch/torchvision
--index-url https://download.pytorch.org/whl/cu121
torch==2.3.1
torchvision==0.18.1

# Core utils
pyyaml>=6.0.1
tqdm>=4.66.0
numpy>=1.26.0
Pillow>=10.0.0
einops>=0.7.0

# Logging & videos
wandb>=0.16.0
imageio>=2.31.0
imageio-ffmpeg>=0.4.9  # lets imageio write MP4s without a system ffmpeg install

# Metrics (only needed when FID/IS is enabled)
clean-fid>=0.1.35
torch-fidelity>=0.3.0

# (Optional) push the model to the Hugging Face Hub
huggingface_hub>=0.23.0
unet.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # unet.py
2
+ # Lightweight UNet for MNIST:
3
+ # - Optional outer attention disabled (Identity)
4
+ # - Full attention only at bottleneck
5
+
6
+
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ from einops import rearrange
11
+
12
+
13
def divisible_by(x, y):
    """Return True when x is an exact multiple of y."""
    return (x % y) == 0
15
+
16
+
17
class RMSNorm(nn.Module):
    """Channel-wise RMS normalization with a learned per-channel gain.

    Each spatial position's channel vector is scaled to unit length, then
    multiplied by sqrt(dim) (so activation magnitude is preserved) and by a
    learnable per-channel weight initialized to 1.
    """

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.weight = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        unit = F.normalize(x, dim=1)
        return unit * self.weight * self.scale
25
+
26
+
27
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal timestep embedding.

    Maps a batch of scalar timesteps t -> concat(sin(t*f), cos(t*f)) of width
    `dim`, with `dim // 2` frequencies geometrically spaced from 1 down to
    1/theta.
    """

    def __init__(self, dim, theta=10000):
        super().__init__()
        self.dim = dim      # output embedding width (assumed even)
        self.theta = theta  # maximum sinusoid period

    def forward(self, t):
        device = t.device
        half = self.dim // 2
        # Geometric frequency ladder: theta ** (-i / (half - 1)) for i in
        # [0, half). FIX: algebraically identical to the original
        # exp(i * -log(theta)/(half-1)) but avoids materializing a temporary
        # CPU tensor for log(theta) on every forward call.
        idx = torch.arange(half, device=device).float()
        freqs = self.theta ** (-idx / (half - 1))
        args = t.float()[:, None] * freqs[None, :]
        return torch.cat([args.sin(), args.cos()], dim=-1)
40
+
41
+
42
class Block(nn.Module):
    """Basic conv unit: conv3x3 -> RMSNorm -> optional FiLM -> SiLU -> dropout."""

    def __init__(self, dim, dim_out, dropout=0.):
        super().__init__()
        self.project = nn.Conv2d(dim, dim_out, kernel_size=3, padding=1)
        self.norm = RMSNorm(dim_out)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.SiLU()

    def forward(self, x, shift_scale=None):
        """Apply the block; `shift_scale` is an optional (scale, shift) pair
        from a time-embedding MLP, applied after normalization."""
        h = self.norm(self.project(x))
        if shift_scale is not None:
            scale, shift = shift_scale
            h = h * (scale + 1) + shift
        return self.dropout(self.activation(h))
60
+
61
+
62
class ResnetBlock(nn.Module):
    """Two `Block`s with FiLM-style time conditioning and a residual skip.

    The time embedding (if provided) is projected to a (scale, shift) pair
    that modulates only the first block; the skip path is a 1x1 conv when the
    channel count changes, identity otherwise.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, dropout=0.):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) if time_emb_dim else None
        self.b1 = Block(dim, dim_out, dropout=dropout)
        self.b2 = Block(dim_out, dim_out, dropout=0.)
        self.skip = nn.Identity() if dim == dim_out else nn.Conv2d(dim, dim_out, 1)

    def forward(self, x, t=None):
        cond = None
        if self.mlp is not None and t is not None:
            emb = self.mlp(t).view(t.size(0), -1, 1, 1)
            cond = emb.chunk(2, dim=1)
        out = self.b2(self.b1(x, cond))
        return out + self.skip(x)
80
+
81
+
82
class LinearAttention(nn.Module):
    """Kernelized ("linear") attention — cost linear in the pixel count.

    q is softmaxed over the feature axis and k over the spatial axis, so the
    context can be aggregated with two einsums instead of an N x N map.
    """

    def __init__(self, dim, heads=2, dim_head=16):
        super().__init__()
        self.heads = heads
        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Conv2d(dim, dim_head * heads * 3, 1, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(dim_head * heads, dim, 1), RMSNorm(dim))
        self.scale = dim_head ** -0.5

    def forward(self, x):
        _, _, height, width = x.shape
        qkv = self.to_qkv(self.norm(x)).chunk(3, dim=1)
        q, k, v = (rearrange(part, 'b (h d) x y -> b h d (x y)', h=self.heads)
                   for part in qkv)
        q = torch.softmax(q, dim=-2) * self.scale
        k = torch.softmax(k, dim=-1)
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', x=height, y=width)
        return self.to_out(out)
105
+
106
+
107
class FullAttention(nn.Module):
    """Standard multi-head softmax attention over all spatial positions (O(N^2)).

    NOTE(review): the residual added here uses the *normalized* input, and
    `UNet.forward` adds a second residual on top of this module's output.
    That double residual looks unintentional, but it is preserved exactly to
    stay compatible with the released checkpoint — confirm before changing.
    """

    def __init__(self, dim, heads=2, dim_head=16):
        super().__init__()
        self.heads = heads
        inner = heads * dim_head
        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Conv2d(dim, inner * 3, 1, bias=False)
        self.to_out = nn.Conv2d(inner, dim, 1)
        self.scale = dim_head ** -0.5

    def forward(self, x):
        _, _, height, width = x.shape
        normed = self.norm(x)
        qkv = self.to_qkv(normed).chunk(3, dim=1)
        q, k, v = (rearrange(part, 'b (h d) x y -> b h (x y) d', h=self.heads)
                   for part in qkv)
        scores = (q @ k.transpose(-1, -2)) * self.scale
        out = torch.softmax(scores, dim=-1) @ v
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=height, y=width)
        return self.to_out(out) + normed
128
+
129
+
130
class UNet(nn.Module):
    """
    Minimal UNet for MNIST 32x32.
    - outer_attn=False -> outer levels use nn.Identity() in place of attention
    - FullAttention only at the bottleneck
    """

    def __init__(self, dim=32, init_dim=None, out_dim=None, dim_mults=(1, 2, 4),
                 channels=1, dropout=0.0, attn_heads=2, attn_dim_head=16,
                 self_condition=False, learned_variance=False, outer_attn=False):
        super().__init__()
        self.channels = channels
        self.self_condition = self_condition
        self.learned_variance = learned_variance

        # Self-conditioning doubles the input channels (x0 guess concatenated).
        in_ch = channels * (2 if self_condition else 1)
        init_dim = init_dim or dim
        self.init_conv = nn.Conv2d(in_ch, init_dim, 7, padding=3)

        dims = [init_dim] + [dim * m for m in dim_mults]
        in_out = list(zip(dims[:-1], dims[1:]))

        # Timestep embedding: sinusoidal features -> 2-layer MLP.
        time_dim = dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        # Encoder levels: two res blocks, optional attention, downsample.
        for idx, (c_in, c_out) in enumerate(in_out):
            last = idx == len(in_out) - 1
            enc_attn = LinearAttention(
                c_in, heads=attn_heads, dim_head=attn_dim_head) if outer_attn else nn.Identity()
            if last:
                enc_down = nn.Conv2d(c_in, c_out, 3, padding=1)
            else:
                enc_down = nn.Sequential(
                    nn.Conv2d(c_in, c_in, 4, stride=2, padding=1),
                    nn.Conv2d(c_in, c_out, 3, padding=1))
            self.downs.append(nn.ModuleList([
                ResnetBlock(c_in, c_in, time_emb_dim=time_dim, dropout=dropout),
                ResnetBlock(c_in, c_in, time_emb_dim=time_dim, dropout=dropout),
                enc_attn,
                enc_down,
            ]))

        # Bottleneck: res block -> full attention -> res block.
        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock(
            mid_dim, mid_dim, time_emb_dim=time_dim, dropout=dropout)
        self.mid_attn = FullAttention(
            mid_dim, heads=attn_heads, dim_head=attn_dim_head)
        self.mid_block2 = ResnetBlock(
            mid_dim, mid_dim, time_emb_dim=time_dim, dropout=dropout)

        # Decoder levels mirror the encoder; each consumes two skip tensors.
        for idx, (c_in, c_out) in enumerate(reversed(in_out)):
            last = idx == len(in_out) - 1
            dec_attn = LinearAttention(
                c_out, heads=attn_heads, dim_head=attn_dim_head) if outer_attn else nn.Identity()
            if last:
                dec_up = nn.Conv2d(c_out, c_in, 3, padding=1)
            else:
                dec_up = nn.Sequential(
                    nn.ConvTranspose2d(c_out, c_out, 4, stride=2, padding=1),
                    nn.Conv2d(c_out, c_in, 3, padding=1))
            self.ups.append(nn.ModuleList([
                ResnetBlock(c_out + c_in, c_out,
                            time_emb_dim=time_dim, dropout=dropout),
                ResnetBlock(c_out + c_in, c_out,
                            time_emb_dim=time_dim, dropout=dropout),
                dec_attn,
                dec_up,
            ]))

        self.out_dim = out_dim or channels  # learned_variance=False for MNIST
        self.final_res_block = ResnetBlock(
            init_dim * 2, init_dim, time_emb_dim=time_dim, dropout=dropout)
        self.final_conv = nn.Conv2d(init_dim, self.out_dim, 1)

    @property
    def downsample_factor(self):
        # Number of 2x downsamples is len(self.downs) - 1 (e.g. 3 levels -> 4x).
        return 2 ** (len(self.downs) - 1)

    def forward(self, x, time, x_self_cond=None):
        assert all(divisible_by(d, self.downsample_factor)
                   for d in x.shape[-2:])
        if self.self_condition:
            if x_self_cond is None:
                x_self_cond = torch.zeros_like(x)
            x = torch.cat([x_self_cond, x], dim=1)

        x = self.init_conv(x)
        residual = x.clone()  # kept for the final skip connection
        t = self.time_mlp(time)

        skips = []
        for block_a, block_b, attn, downsample in self.downs:
            x = block_a(x, t)
            skips.append(x)
            x = block_b(x, t)
            x = attn(x) + x
            skips.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x) + x
        x = self.mid_block2(x, t)

        for block_a, block_b, attn, upsample in self.ups:
            x = block_a(torch.cat([x, skips.pop()], dim=1), t)
            x = block_b(torch.cat([x, skips.pop()], dim=1), t)
            x = attn(x) + x
            x = upsample(x)

        x = self.final_res_block(torch.cat([x, residual], dim=1), t)
        return self.final_conv(x)