RinKana committed on
Commit
468bd92
·
verified ·
1 Parent(s): 884fd0e

Upload diffusion_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. diffusion_utils.py +260 -0
diffusion_utils.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from typing import Tuple, Optional
4
+
5
+
6
class NoiseScheduler:
    """Noise schedule for the forward (noising) process of a diffusion model.

    On construction the per-step ``alphas`` and their cumulative product are
    computed once and stored as plain tensor attributes:

    * ``alphas``                         -- per-timestep alpha, shape ``[num_timesteps]``
    * ``alphas_cumprod``                 -- cumulative product of ``alphas``
    * ``sqrt_alphas_cumprod``            -- ``sqrt(alphas_cumprod)``
    * ``sqrt_one_minus_alphas_cumprod``  -- ``sqrt(1 - alphas_cumprod)``

    Args:
        num_timesteps: Number of diffusion timesteps.
        schedule_type: ``"cosine"`` selects the cosine schedule; any other
            value selects the linear beta schedule (1e-4 .. 0.02).
    """

    def __init__(self, num_timesteps: int = 1000, schedule_type: str = "cosine"):
        self.num_timesteps = num_timesteps
        self.schedule_type = schedule_type

        if schedule_type == "cosine":
            # Cosine schedule (more stable for small images)
            s = 0.008
            steps = num_timesteps + 1
            x = torch.linspace(0, num_timesteps, steps)
            alpha_bars = torch.cos(((x / num_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
            alpha_bars = alpha_bars / alpha_bars[0]  # Normalize so alpha_bar(0) == 1
            alphas = alpha_bars[1:] / alpha_bars[:-1]
            # Keep every alpha strictly inside (0, 1) so later square roots and
            # divisions stay finite.
            alphas = torch.clamp(alphas, 0.0001, 0.9999)
        else:
            # Linear schedule (original DDPM); also the fallback for any
            # schedule_type other than "cosine".
            betas = torch.linspace(1e-4, 0.02, num_timesteps)
            alphas = 1.0 - betas

        alphas_cumprod = torch.cumprod(alphas, dim=0)

        # Pre-compute values for training
        self.register_buffer('alphas', alphas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - alphas_cumprod))

    def register_buffer(self, name: str, tensor: torch.Tensor) -> None:
        """Store ``tensor`` on this object under attribute ``name``.

        This class is not a ``torch.nn.Module``; the tensor is set as a plain
        attribute and is not tracked by any state dict.
        """
        setattr(self, name, tensor)

    def sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Draw ``batch_size`` uniform random timesteps in ``[0, num_timesteps)``.

        Returns:
            A ``torch.long`` tensor of shape ``[batch_size]`` on ``device``.
        """
        return torch.randint(0, self.num_timesteps, (batch_size,), device=device, dtype=torch.long)

    def add_noise(
        self,
        x_0: torch.Tensor,
        t: torch.Tensor,
        noise: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the forward diffusion process to clean samples.

        Computes ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``.

        Args:
            x_0: Clean samples with the batch on dimension 0, e.g.
                ``[B, C, H, W]``. Any number of trailing dimensions is
                supported.
            t: Timestep indices, shape ``[B]``.
            noise: Optional pre-sampled noise with the same shape as ``x_0``;
                drawn from a standard normal when omitted.

        Returns:
            ``(x_t, noise)`` -- the noised samples and the noise that was used.

        Side effects:
            If ``x_0`` lives on a different device than the cached square-root
            tables, those two tables are moved to ``x_0``'s device and kept
            there.
        """
        if noise is None:
            noise = torch.randn_like(x_0)

        # Ensure the lookup tables are on the same device as the input
        if self.sqrt_alphas_cumprod.device != x_0.device:
            self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(x_0.device)
            self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(x_0.device)

        # One coefficient per batch element, shaped [B, 1, ..., 1] so that it
        # broadcasts over however many trailing dimensions x_0 has (a fixed
        # [B, 1, 1, 1] view would mis-broadcast for non-4-D inputs).
        coef_shape = (-1,) + (1,) * (x_0.dim() - 1)
        sqrt_alpha_bar = self.sqrt_alphas_cumprod[t].view(coef_shape)
        sqrt_one_minus_alpha_bar = self.sqrt_one_minus_alphas_cumprod[t].view(coef_shape)

        # Forward diffusion: x_t = sqrt(alpha_bar) * x_0 + sqrt(1-alpha_bar) * epsilon
        x_t = sqrt_alpha_bar * x_0 + sqrt_one_minus_alpha_bar * noise

        return x_t, noise

    def get_sampling_schedule(self, num_samples: Optional[int] = None) -> np.ndarray:
        """Return descending timesteps for the reverse (sampling) process.

        Args:
            num_samples: Number of timesteps to return. ``None`` returns every
                timestep from ``num_timesteps - 1`` down to ``0``; otherwise
                ``num_samples`` integer timesteps evenly spaced over the same
                range are returned.

        Returns:
            A 1-D integer NumPy array of timesteps in descending order.
        """
        if num_samples is None:
            return np.arange(self.num_timesteps - 1, -1, -1)
        return np.linspace(self.num_timesteps - 1, 0, num_samples, dtype=int)
79
+
80
+
81
@torch.no_grad()
def sample_diffusion(
    model: torch.nn.Module,
    scheduler: "NoiseScheduler",
    shape: Tuple[int, ...],
    device: torch.device,
    num_steps: Optional[int] = None,
    guidance_scale: float = 1.0,
    clip_x0: bool = True,
    eta: float = 1.0,
) -> torch.Tensor:
    """Generate samples by running the reverse diffusion process.

    Each step predicts the noise with ``model(x, t)``, reconstructs the
    implied clean sample ``x_0`` and moves to the *next timestep in the
    sampling schedule* using the generalized DDIM update::

        sigma^2 = eta^2 * (1 - ab_prev) / (1 - ab_t) * (1 - ab_t / ab_prev)
        x_prev  = sqrt(ab_prev) * x_0
                  + sqrt(1 - ab_prev - sigma^2) * eps
                  + sigma * z,              z ~ N(0, I)

    where ``ab`` denotes ``scheduler.alphas_cumprod``. With consecutive
    timesteps and ``eta=1`` the variance equals the DDPM posterior variance;
    ``eta=0`` gives deterministic DDIM sampling. The final schedule entry
    returns the predicted ``x_0`` directly.

    Args:
        model: Noise-prediction network called as ``model(x, t_batch)``. It is
            switched to eval mode and left in that mode.
        scheduler: Object providing ``alphas_cumprod`` and
            ``get_sampling_schedule(num_steps)``.
        shape: ``(C, H, W)`` for a single sample or ``(B, C, H, W)`` for a
            batch.
        device: Device to run on.
        num_steps: Number of denoising steps (``None`` = use every timestep).
        guidance_scale: Reserved for classifier-free guidance; currently has
            no effect because that requires a conditional model.
        clip_x0: Whether to clip the predicted ``x_0`` to ``[-1, 1]``.
        eta: Stochasticity of each step (``0`` = deterministic,
            ``1`` = DDPM-like).

    Returns:
        Generated images of shape ``[B, C, H, W]`` clamped to ``[-1, 1]``.
    """
    model.eval()

    batch_size = shape[0] if len(shape) == 4 else 1
    c, h, w = shape[-3:]

    # Start from pure noise
    x = torch.randn(batch_size, c, h, w, device=device)

    # Plain Python ints keep indexing and comparisons independent of NumPy
    # scalar types.
    timesteps = [int(step) for step in scheduler.get_sampling_schedule(num_steps)]
    alphas_cumprod = scheduler.alphas_cumprod.to(device)
    last_index = len(timesteps) - 1

    for i, t in enumerate(timesteps):
        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)

        # Predict noise
        noise_pred = model(x, t_batch)

        # Classifier-free guidance would be applied to noise_pred here when
        # guidance_scale != 1.0 (requires a conditional model).

        alpha_bar = alphas_cumprod[t]

        # Predicted clean sample implied by the noise estimate
        pred_x0 = (x - noise_pred * torch.sqrt(1 - alpha_bar)) / torch.sqrt(alpha_bar)
        if clip_x0:
            pred_x0 = torch.clamp(pred_x0, -1, 1)

        if i == last_index:
            # Nothing left to denoise towards: the estimate is the result.
            x = pred_x0
            break

        # Step to the next *scheduled* timestep, not to t - 1: with a strided
        # schedule the two differ and the sample's noise level must match the
        # timestep the model is given on the following iteration.
        prev_alpha_bar = alphas_cumprod[timesteps[i + 1]]

        variance = (
            (eta ** 2)
            * (1 - prev_alpha_bar) / (1 - alpha_bar)
            * (1 - alpha_bar / prev_alpha_bar)
        )
        # Guard against tiny negative values from floating-point round-off.
        variance = torch.clamp(variance, min=0.0)

        direction = torch.sqrt(torch.clamp(1 - prev_alpha_bar - variance, min=0.0)) * noise_pred
        x = torch.sqrt(prev_alpha_bar) * pred_x0 + direction

        if eta > 0:
            x = x + torch.sqrt(variance) * torch.randn_like(x)

    return torch.clamp(x, -1, 1)
166
+
167
+
168
@torch.no_grad()
def interpolate_images(
    model: torch.nn.Module,
    scheduler: "NoiseScheduler",
    img1: torch.Tensor,
    img2: torch.Tensor,
    num_interpolations: int = 5,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Interpolate between two images in noised space and denoise the results.

    Both images are noised to the last timestep with the *same* noise,
    linearly blended, and each blend is then denoised deterministically over
    the scheduler's full timestep schedule. Gradients are disabled, and the
    model is switched to eval mode and left in that mode.

    Args:
        model: Noise-prediction network called as ``model(x, t_batch)``.
        scheduler: Noise scheduler providing ``num_timesteps``,
            ``alphas_cumprod``, ``add_noise`` and ``get_sampling_schedule``.
        img1: First image ``[1, C, H, W]``.
        img2: Second image ``[1, C, H, W]``.
        num_interpolations: Number of intermediate images between the two
            endpoints.
        device: Device to run on; defaults to the device of the model's first
            parameter.

    Returns:
        Denoised images ``[num_interpolations + 2, C, H, W]``, ordered from
        the ``img1`` end to the ``img2`` end.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()

    img1 = img1.to(device)
    img2 = img2.to(device)

    # Add the same noise to both images at the highest timestep
    t_high = torch.tensor([scheduler.num_timesteps - 1], device=device)
    noise = torch.randn_like(img1)

    x1_noisy, _ = scheduler.add_noise(img1, t_high, noise)
    x2_noisy, _ = scheduler.add_noise(img2, t_high, noise)

    # Interpolate in noisy space (blend weight 0 -> img1, 1 -> img2)
    blend_weights = torch.linspace(0, 1, num_interpolations + 2)
    interpolated_noisy = torch.cat(
        [(1 - weight) * x1_noisy + weight * x2_noisy for weight in blend_weights],
        dim=0,
    )

    # The schedule and the alpha table do not change between samples.
    timesteps = [int(step) for step in scheduler.get_sampling_schedule()]
    alphas_cumprod = scheduler.alphas_cumprod.to(device)

    # Denoise all interpolated images
    # Note: This is a simplified approach - proper interpolation requires more careful handling
    results = []
    for interp in interpolated_noisy:
        x = interp.unsqueeze(0)

        for t in timesteps:
            t_batch = torch.tensor([t], device=device)
            noise_pred = model(x, t_batch)

            alpha_bar = alphas_cumprod[t]

            pred_x0 = (x - noise_pred * torch.sqrt(1 - alpha_bar)) / torch.sqrt(alpha_bar)
            pred_x0 = torch.clamp(pred_x0, -1, 1)

            if t == 0:
                x = pred_x0
            else:
                # Deterministic step to the previous timestep
                prev_alpha_bar = alphas_cumprod[t - 1]
                direction = torch.sqrt(1 - prev_alpha_bar) * noise_pred
                x = torch.sqrt(prev_alpha_bar) * pred_x0 + direction

        results.append(x)

    return torch.cat(results, dim=0)
238
+
239
+
240
if __name__ == "__main__":
    # Smoke-run of NoiseScheduler: noise a random batch and print summary
    # statistics. Nothing is asserted; the output is meant for eyeballing.
    print("Testing NoiseScheduler...")
    noise_schedule = NoiseScheduler(num_timesteps=1000)

    # Forward-noise a random batch at random timesteps
    batch = 2
    clean = torch.randn(batch, 3, 64, 64)
    random_steps = torch.randint(0, 1000, (batch,))
    noisy, eps = noise_schedule.add_noise(clean, random_steps)

    clean_lo, clean_hi = clean.min(), clean.max()
    noisy_lo, noisy_hi = noisy.min(), noisy.max()
    print(f"Clean image range: [{clean_lo:.3f}, {clean_hi:.3f}]")
    print(f"Noisy image range: [{noisy_lo:.3f}, {noisy_hi:.3f}]")
    print(f"Noise shape: {eps.shape}")

    # At t=0 the noised batch should stay very close to the clean batch
    first_step = torch.zeros(batch, dtype=torch.long)
    nearly_clean, _ = noise_schedule.add_noise(clean, first_step)
    squared_error = (nearly_clean - clean) ** 2
    mse = torch.mean(squared_error)
    print(f"MSE at t=0 (should be ~0): {mse:.6f}")

    print("\nNoiseScheduler tests passed!")