Upload 6-parameter conditional DDPM (HI emulation, CAMELS LH params_6, best checkpoint)
c46900a verified | """ | |
| Diffusion Process Implementation (DDPM + DDIM) | |
| Changes from original: | |
| - Fixed q_posterior_mean_variance return value count (was 3, caller expected 4) | |
| - Fixed DDIM sigma/dir_xt formula inconsistency | |
| - Made GaussianDiffusion an nn.Module with registered buffers for proper device handling | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
class GaussianDiffusion(nn.Module):
    """Gaussian diffusion process with DDPM ancestral and DDIM sampling.

    All schedule tensors are registered as buffers, so they follow the module
    through ``.to(device)`` and are stored in the ``state_dict``.

    The wrapped ``model`` is always called as ``model(x_t, t, labels)`` and its
    output is treated as the predicted noise.

    Args:
        timesteps: Number of diffusion steps ``T``.
        beta_start: First beta of the linear schedule (unused for "cosine").
        beta_end: Last beta of the linear schedule (unused for "cosine").
        schedule_type: Either ``"linear"`` or ``"cosine"``.

    Raises:
        ValueError: If ``schedule_type`` is not one of the supported names.
    """

    def __init__(self, timesteps=1500, beta_start=1e-4, beta_end=0.02, schedule_type="linear"):
        super().__init__()
        self.timesteps = timesteps

        if schedule_type == "linear":
            betas = torch.linspace(beta_start, beta_end, timesteps)
        elif schedule_type == "cosine":
            betas = self._cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f"Unknown schedule: {schedule_type}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Cumulative product shifted right by one step; the value "before" t=0 is 1.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        # Register all schedule tensors as buffers so they move with .to(device).
        # Names and order are kept stable so existing checkpoints keep loading.
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - alphas_cumprod))

        # Coefficients of the posterior q(x_{t-1} | x_t, x_0).
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)
        # The variance is exactly 0 at t=0, so clamp before taking the log.
        self.register_buffer(
            'posterior_log_variance_clipped',
            torch.log(torch.clamp(posterior_variance, min=1e-20)),
        )
        self.register_buffer(
            'posterior_mean_coef1',
            betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
        )
        self.register_buffer(
            'posterior_mean_coef2',
            (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod),
        )

        # Precompute reciprocals used in _predict_xstart_from_noise
        # (avoids recomputation per step).
        self.register_buffer('recip_sqrt_alphas_cumprod', 1.0 / torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_recip_minus_one', torch.sqrt(1.0 / alphas_cumprod - 1.0))

    def _cosine_beta_schedule(self, timesteps, s=0.008):
        """Return ``timesteps`` betas derived from a cosine-squared alpha-bar curve.

        The cumulative product is ``cos^2(((i / T) + s) / (1 + s) * pi / 2)``
        normalised by its first value; betas are one minus the ratio of
        consecutive entries, clipped to ``[0.0001, 0.9999]``.
        """
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps)
        alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0.0001, 0.9999)

    def q_sample(self, x_start, t, noise=None):
        """Diffuse ``x_start`` forward to step ``t``.

        Returns ``sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise``.
        Fresh standard-normal noise is drawn when ``noise`` is ``None``.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        signal_scale = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        noise_scale = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return signal_scale * x_start + noise_scale * noise

    def p_mean_variance(self, model, x_t, t, labels, clip_denoised=True):
        """Compute the reverse-step distribution parameters predicted by ``model``.

        Returns:
            Tuple ``(model_mean, posterior_variance, posterior_log_variance,
            x_start)`` where ``x_start`` is the model's estimate of the clean
            sample, clamped to ``[-1, 1]`` when ``clip_denoised`` is true.
        """
        pred_noise = model(x_t, t, labels)
        x_start = self._predict_xstart_from_noise(x_t, t, pred_noise)
        if clip_denoised:
            x_start = torch.clamp(x_start, -1.0, 1.0)
        # q_posterior_mean_variance returns 3 values; x_start is appended here.
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior_mean_variance(
            x_start, x_t, t
        )
        return model_mean, posterior_variance, posterior_log_variance, x_start

    def _predict_xstart_from_noise(self, x_t, t, noise):
        """Invert ``q_sample``: recover the clean-sample estimate from ``x_t`` and noise."""
        return (
            self._extract(self.recip_sqrt_alphas_cumprod, t, x_t.shape) * x_t
            - self._extract(self.sqrt_recip_minus_one, t, x_t.shape) * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """Return mean, variance and clipped log-variance of ``q(x_{t-1} | x_t, x_0)``."""
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance = self._extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance

    def p_sample(self, model, x_t, t, labels):
        """Draw ``x_{t-1}`` from the model's reverse distribution (one ancestral step)."""
        model_mean, _, model_log_variance, _ = self.p_mean_variance(model, x_t, t, labels)
        noise = torch.randn_like(x_t)
        # No noise is added for batch entries that are already at t == 0.
        nonzero_mask = (t != 0).float().reshape(-1, *([1] * (x_t.dim() - 1)))
        return model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise

    def ddim_sample_step(self, model, x_t, t, t_next, labels, eta=0.0):
        """Perform one DDIM update from step ``t`` to step ``t_next``.

        Args:
            model: Callable invoked as ``model(x_t, t, labels)``; its output is
                used as the predicted noise.
            x_t: Current noisy sample.
            t: Integer tensor with the current step of every batch entry.
            t_next: Integer tensor with the target step of every batch entry.
                A negative entry means "fully denoised" and uses
                ``alpha_bar = 1`` for that entry.
            labels: Conditioning passed through to ``model``.
            eta: Stochasticity factor; ``0`` gives the deterministic update and
                draws no random numbers.

        Returns:
            The sample at step ``t_next``.
        """
        pred_noise = model(x_t, t, labels)

        alpha_t = self._extract(self.alphas_cumprod, t, x_t.shape)
        # Resolve the "no next step" case per batch entry instead of peeking at
        # t_next[0]: this avoids a host sync, works for empty batches and for
        # batches whose entries sit at different steps. Negative indices are
        # clamped only to keep gather() valid; their value is replaced by 1.
        has_next = t_next >= 0
        alpha_t_next = self._extract(self.alphas_cumprod, t_next.clamp(min=0), x_t.shape)
        alpha_t_next = torch.where(
            has_next.reshape(alpha_t_next.shape),
            alpha_t_next,
            torch.ones_like(alpha_t_next),
        )

        x0_pred = (x_t - torch.sqrt(1 - alpha_t) * pred_noise) / torch.sqrt(alpha_t)
        x0_pred = torch.clamp(x0_pred, -1.0, 1.0)
        # NOTE(review): dir_xt below uses the raw pred_noise, not a noise value
        # re-derived from the clamped x0_pred. Kept as-is so samples from
        # existing checkpoints do not change - confirm whether this is intended.

        if eta > 0:
            # sigma^2 = eta^2 * (1 - alpha_{t-1}) / (1 - alpha_t) * (1 - alpha_t / alpha_{t-1})
            sigma_sq = eta**2 * (1 - alpha_t_next) / (1 - alpha_t) * (1 - alpha_t / alpha_t_next)
            sigma_t = torch.sqrt(torch.clamp(sigma_sq, min=0))
            noise = torch.randn_like(x_t)
        else:
            sigma_sq = 0
            sigma_t = 0
            noise = 0

        # dir_xt uses the same sigma^2 as the noise term.
        dir_xt = torch.sqrt(torch.clamp(1 - alpha_t_next - sigma_sq, min=0)) * pred_noise
        return torch.sqrt(alpha_t_next) * x0_pred + dir_xt + sigma_t * noise

    def sample(self, model, labels, channels, height, width, device, progress=False, use_ddim=True, ddim_steps=50, eta=0.0):
        """Generate samples starting from pure Gaussian noise.

        Args:
            model: Callable invoked as ``model(x_t, t, labels)``.
            labels: Conditioning tensor; ``labels.shape[0]`` sets the batch size.
            channels, height, width: Shape of each generated sample.
            device: Device on which the noise and step tensors are created.
            progress: Show a ``tqdm`` progress bar (imported lazily).
            use_ddim: Use strided DDIM steps instead of the full ancestral chain.
            ddim_steps: Requested number of DDIM steps. The stride is
                ``timesteps // ddim_steps`` (at least 1), so the number of
                steps actually run can differ from the request when it does
                not divide ``timesteps`` or exceeds it.
            eta: DDIM stochasticity factor, forwarded to ``ddim_sample_step``.

        Returns:
            Tensor of shape ``(batch, channels, height, width)``.

        Raises:
            ValueError: If ``use_ddim`` is true and ``ddim_steps`` is below 1.
        """
        if use_ddim and ddim_steps < 1:
            raise ValueError(f"ddim_steps must be >= 1, got {ddim_steps}")

        batch_size = labels.shape[0]
        img = torch.randn((batch_size, channels, height, width), device=device)

        if use_ddim:
            # A stride of 0 (ddim_steps > timesteps) would make range() raise;
            # fall back to visiting every timestep in that case.
            skip = max(self.timesteps // ddim_steps, 1)
            seq = list(range(0, self.timesteps, skip))
            seq_next = [-1] + seq[:-1]
            seq_iter = reversed(list(zip(seq, seq_next)))
            if progress:
                from tqdm import tqdm
                seq_iter = tqdm(seq_iter, desc=f'DDIM Sampling ({len(seq)} steps)', total=len(seq))
            for i, j in seq_iter:
                t = torch.full((batch_size,), i, device=device, dtype=torch.long)
                t_next = torch.full((batch_size,), j, device=device, dtype=torch.long)
                img = self.ddim_sample_step(model, img, t, t_next, labels, eta)
        else:
            timesteps_iter = reversed(range(self.timesteps))
            if progress:
                from tqdm import tqdm
                timesteps_iter = tqdm(timesteps_iter, total=self.timesteps)
            for i in timesteps_iter:
                t = torch.full((batch_size,), i, device=device, dtype=torch.long)
                img = self.p_sample(model, img, t, labels)
        return img

    def training_losses(self, model, x_start, labels, t, noise=None):
        """Return the per-sample MSE between predicted and true noise.

        ``x_start`` is diffused to step ``t`` with ``noise`` (drawn if ``None``),
        the model predicts the noise, and the squared error is averaged over
        every non-batch dimension.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise)
        pred_noise = model(x_t, t, labels)
        per_element = F.mse_loss(pred_noise, noise, reduction='none')
        return per_element.mean(dim=list(range(1, len(pred_noise.shape))))

    def _extract(self, a, t, x_shape):
        """Gather ``a[t]`` and reshape to ``(batch, 1, ..., 1)`` for broadcasting.

        The result has as many dimensions as ``x_shape``.
        """
        batch_size = t.shape[0]
        out = a.gather(0, t)
        return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
class ConditionalDiffusionModel(nn.Module):
    """Pair a noise-prediction network with a diffusion process.

    Args:
        unet: Network invoked as ``unet(x, t, labels)``.
        diffusion_process: Object exposing ``timesteps``, ``training_losses``
            and ``sample``; stored as ``self.diffusion``.
    """

    def __init__(self, unet, diffusion_process):
        super().__init__()
        self.unet = unet
        self.diffusion = diffusion_process

    def forward(self, x, t, labels):
        """Delegate straight to the wrapped network."""
        prediction = self.unet(x, t, labels)
        return prediction

    def get_loss(self, x, labels, noise=None):
        """Return the batch-mean training loss at uniformly random timesteps.

        One integer step in ``[0, diffusion.timesteps)`` is drawn per batch
        entry on ``x``'s device; the per-sample losses returned by
        ``diffusion.training_losses`` are then averaged.
        """
        num_samples = x.shape[0]
        random_steps = torch.randint(
            0,
            self.diffusion.timesteps,
            (num_samples,),
            device=x.device,
            dtype=torch.long,
        )
        per_sample = self.diffusion.training_losses(self, x, labels, random_steps, noise=noise)
        return per_sample.mean()

    def sample(self, labels, channels, height, width, device, progress=False, use_ddim=True, ddim_steps=50, eta=0.0):
        """Run ``diffusion.sample`` with this module as the model, without gradients.

        The module is switched to eval mode first and is left in eval mode
        afterwards; callers that continue training must call ``train()``.
        """
        self.eval()
        sampler = self.diffusion.sample
        with torch.no_grad():
            generated = sampler(
                self, labels, channels, height, width, device,
                progress, use_ddim, ddim_steps, eta,
            )
        return generated