AlexWortega committed on
Commit
70ced22
·
verified ·
1 Parent(s): 54c8086

Upload diffusion.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. diffusion.py +150 -0
diffusion.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gaussian Diffusion (DDPM) framework for PDE next-frame prediction.
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import math
8
+
9
+
10
class GaussianDiffusion(nn.Module):
    """DDPM with linear beta schedule.

    Training: given (condition, target), add noise to target, predict noise.
    Sampling: iteratively denoise starting from Gaussian noise.

    Args:
        model: U-Net (or any eps-predicting network) called as
            ``model(x_noisy, t, cond=x_cond)`` and returning predicted noise.
        timesteps: number of diffusion steps T.
        beta_start: starting noise level.
        beta_end: ending noise level.
    """

    def __init__(self, model, timesteps=1000, beta_start=1e-4, beta_end=0.02):
        super().__init__()
        self.model = model
        self.T = timesteps

        # --- precompute schedule; buffers so they move with .to(device) / state_dict ---
        betas = torch.linspace(beta_start, beta_end, timesteps)
        alphas = 1.0 - betas
        alpha_bar = torch.cumprod(alphas, dim=0)

        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alpha_bar", alpha_bar)
        self.register_buffer("sqrt_alpha_bar", torch.sqrt(alpha_bar))
        self.register_buffer("sqrt_one_minus_alpha_bar", torch.sqrt(1 - alpha_bar))
        self.register_buffer("sqrt_recip_alpha", torch.sqrt(1.0 / alphas))
        # q(x_{t-1} | x_t, x_0) variance: beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t).
        # The F.pad prepends alpha_bar_{-1} = 1, making the variance exactly 0 at t=0.
        self.register_buffer(
            "posterior_variance",
            betas * (1.0 - F.pad(alpha_bar[:-1], (1, 0), value=1.0)) / (1.0 - alpha_bar),
        )

    @staticmethod
    def _broadcast(coef, x):
        """Reshape per-sample coefficients [B] to [B, 1, ..., 1] so they
        broadcast against x of any rank (not just [B, C, H, W])."""
        return coef.view(-1, *([1] * (x.dim() - 1)))

    def q_sample(self, x0, t, noise=None):
        """Forward process: sample x_t ~ q(x_t | x_0).

        Args:
            x0: clean input [B, ...] (any rank with batch dim first).
            t: integer timesteps [B].
            noise: optional pre-drawn Gaussian noise, same shape as x0.

        Returns:
            (x_t, noise) tuple; noise is the epsilon actually used.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        a = self._broadcast(self.sqrt_alpha_bar[t], x0)
        b = self._broadcast(self.sqrt_one_minus_alpha_bar[t], x0)
        return a * x0 + b * noise, noise

    def training_loss(self, x_target, x_cond):
        """Compute training loss (predict noise).

        Args:
            x_target: clean target frames [B, C, H, W].
            x_cond: condition frames [B, C, H, W].

        Returns:
            scalar MSE loss between predicted and true noise.
        """
        B = x_target.shape[0]
        # Uniform timestep per sample, as in standard DDPM training.
        t = torch.randint(0, self.T, (B,), device=x_target.device)
        noise = torch.randn_like(x_target)
        x_noisy, _ = self.q_sample(x_target, t, noise)

        eps_pred = self.model(x_noisy, t, cond=x_cond)
        return F.mse_loss(eps_pred, noise)

    @torch.no_grad()
    def sample(self, x_cond, shape=None):
        """Generate target frames by iterative denoising (full DDPM, T steps).

        Args:
            x_cond: condition frames [B, C_cond, H, W].
            shape: (B, C_out, H, W) of the target. Inferred from x_cond if
                None (assumes target has the same channel count).

        Returns:
            denoised sample [B, C_out, H, W].
        """
        device = x_cond.device
        if shape is None:
            shape = x_cond.shape  # assume same channels as the condition

        x = torch.randn(shape, device=device)

        for i in reversed(range(self.T)):
            t = torch.full((shape[0],), i, device=device, dtype=torch.long)
            eps = self.model(x, t, cond=x_cond)

            alpha_bar = self.alpha_bar[i]
            beta = self.betas[i]

            # Posterior mean: 1/sqrt(alpha_t) * (x_t - beta_t/sqrt(1-alpha_bar_t) * eps).
            # Uses the precomputed sqrt_recip_alpha buffer (was registered but unused).
            mean = self.sqrt_recip_alpha[i] * (x - beta / (1 - alpha_bar).sqrt() * eps)

            if i > 0:
                sigma = self.posterior_variance[i].sqrt()
                x = mean + sigma * torch.randn_like(x)
            else:
                # Final step: no noise is added (posterior variance is 0 at t=0).
                x = mean

        return x

    @torch.no_grad()
    def sample_ddim(self, x_cond, shape=None, steps=50, eta=0.0):
        """DDIM accelerated sampling.

        Args:
            x_cond: condition [B, C_cond, H, W].
            shape: target shape; inferred from x_cond if None.
            steps: number of DDIM steps (<< T for speed; clamped to T).
            eta: stochasticity (0 = deterministic DDIM, 1 = DDPM-like).

        Returns:
            denoised sample [B, C_out, H, W].
        """
        device = x_cond.device
        if shape is None:
            shape = x_cond.shape

        # More steps than T would only produce duplicate timesteps.
        steps = min(steps, self.T)

        # Sub-sample timesteps uniformly, reversed: T-1 ... 0.
        step_indices = torch.linspace(0, self.T - 1, steps + 1, dtype=torch.long, device=device)
        step_indices = step_indices.flip(0)

        x = torch.randn(shape, device=device)

        for idx in range(len(step_indices) - 1):
            t_cur = step_indices[idx]
            t_next = step_indices[idx + 1]

            t_batch = t_cur.expand(shape[0])
            eps = self.model(x, t_batch, cond=x_cond)

            ab_cur = self.alpha_bar[t_cur]
            ab_next = self.alpha_bar[t_next]

            # Predict x0 from the current noisy sample and predicted noise.
            x0_pred = (x - (1 - ab_cur).sqrt() * eps) / ab_cur.sqrt()
            x0_pred = x0_pred.clamp(-5, 5)  # stability clamp

            # DDIM update: sigma controls the stochastic share of the step.
            sigma = eta * ((1 - ab_next) / (1 - ab_cur) * (1 - ab_cur / ab_next)).sqrt()
            dir_xt = (1 - ab_next - sigma**2).sqrt() * eps

            x = ab_next.sqrt() * x0_pred + dir_xt
            if sigma > 0:
                x = x + sigma * torch.randn_like(x)

        return x