# Provenance: uploaded to Hugging Face by user Bangchis via huggingface_hub (commit 2a8536d, verified).
# diffusion_core.py
# -- this file contains the core math formulas of DDPM/DDIM
# -- goal: compute the coefficients from the beta schedule, plus 4 key functions:
#        q_sample, predict_start_from_noise, predict_noise_from_start, q_posterior
# Core DDPM math: schedules and q/p transformations.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def extract(a, t, x_shape):
    """Pick per-sample coefficients a[t] and shape them for broadcasting.

    Args:
        a: 1-D tensor of per-timestep constants (length T).
        t: LongTensor of timestep indices, shape [B].
        x_shape: shape of the tensor the result will multiply (e.g. [B,C,H,W]).

    Returns:
        Tensor of shape (B, 1, ..., 1) with as many dims as x_shape, so it
        broadcasts over the non-batch dimensions.
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    # one trailing singleton dim per non-batch axis of x_shape
    broadcast_shape = (batch,) + (1,) * (len(x_shape) - 1)
    return gathered.reshape(broadcast_shape)
def cosine_beta_schedule(timesteps, s=0.008):
steps = timesteps + 1
t = torch.linspace(0, timesteps, steps, dtype=torch.float32) / timesteps
alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1.0 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
class GaussianDiffusion(nn.Module):
    """
    Core diffusion module that wraps a denoiser (UNet):
      - Precomputes diffusion constants (betas, alphas, etc.)
      - Provides training loss (forward): randomly pick t, add noise, regress target
      - Provides sampling loops (DDPM or DDIM)

    The denoiser must have forward(x, t, [x_self_cond]), returning a predicted target
    (epsilon, x0, or v depending on `objective`), and expose a `.channels` attribute.
    """

    def __init__(self, model, *, image_size, timesteps=400, beta_schedule='cosine',
                 objective='pred_noise', sampling_steps=None, eta=0.0,
                 self_condition=False, auto_normalize=True, clamp_x0=True):
        """
        Args:
            model (nn.Module): denoiser network (e.g., UNet); must expose `.channels`.
            image_size (int or (h,w)): training/sampling resolution (must match UNet).
            timesteps (int): T. Smaller (e.g., 400) is enough for MNIST.
            beta_schedule (str): only 'cosine' implemented here for simplicity.
            objective (str): 'pred_noise'|'pred_x0'|'pred_v' (training target).
            sampling_steps (int or None): if set < T => DDIM sampling with S steps; else DDPM full T.
            eta (float): DDIM stochasticity (0.0 => deterministic).
            self_condition (bool): optional self-conditioning flag.
            auto_normalize (bool): map inputs [0,1] <-> [-1,1] inside module.
            clamp_x0 (bool): clamp predicted x0 to [-1,1] during sampling for stability.
        """
        super().__init__()
        self.model = model
        # Match the schedule constants to the denoiser's dtype/device.
        param = next(model.parameters())
        param_dtype = param.dtype
        param_device = param.device
        self.channels = model.channels
        self.self_condition = self_condition
        self.objective = objective
        # NOTE(review): clamp_x0 is stored but never consulted below — the
        # sampling paths hard-code clip_x_start=True. Confirm intent.
        self.clamp_x0 = clamp_x0
        # In-module normalization helpers (kept simple & explicit):
        # [0,1] -> [-1,1] for the model, and back for visualization.
        self.normalize = (lambda x: x * 2 -
                          1) if auto_normalize else (lambda x: x)
        self.unnormalize = (lambda x: (x + 1) *
                            0.5) if auto_normalize else (lambda x: x)
        # Normalize image_size to (H, W)
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = image_size
        # --- schedule setup ---
        if beta_schedule != 'cosine':
            raise NotImplementedError(
                "For MNIST small, keep beta_schedule='cosine'")
        betas = cosine_beta_schedule(timesteps).to(
            device=param_device, dtype=param_dtype)  # shape [T]
        alphas = 1.0 - betas  # alpha_t
        alphas_cumprod = torch.cumprod(alphas, dim=0)  # alpha_bar_t
        # alpha_bar_{t-1}, with alpha_bar_{-1} := 1 prepended
        alphas_cumprod_prev = F.pad(
            alphas_cumprod[:-1], (1, 0), value=1.0)  # alpha_bar_{t-1}
        # Timesteps used in training and sampling
        self.num_timesteps = int(betas.shape[0])
        self.sampling_steps = int(
            sampling_steps) if sampling_steps else self.num_timesteps
        self.is_ddim_sampling = self.sampling_steps < self.num_timesteps
        self.ddim_sampling_eta = float(eta)
        # Register constants as buffers (moved with .to(device), saved in state_dict)
        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             torch.sqrt(1.0 - alphas_cumprod))
        # 1/sqrt(alpha_bar) and sqrt(1/alpha_bar - 1): used to invert q_sample
        # (recover x0 from x_t and epsilon, and vice versa).
        self.register_buffer('sqrt_recip_alphas_cumprod',
                             torch.sqrt(1.0 / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod',
                             torch.sqrt(1.0 / alphas_cumprod - 1.0))
        # Posterior q(x_{t-1} | x_t, x_0) parameters (DDPM Eq. 6-7)
        posterior_variance = betas * \
            (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)
        # log-variance clamped because posterior_variance is 0 at t=0
        self.register_buffer('posterior_log_variance_clipped', torch.log(
            posterior_variance.clamp(min=1e-20)))
        self.register_buffer('posterior_mean_coef1', betas *
                             torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
        # sqrt(1 - betas) == sqrt(alpha_t)
        self.register_buffer('posterior_mean_coef2', (1.0 - alphas_cumprod_prev)
                             * torch.sqrt(1.0 - betas) / (1.0 - alphas_cumprod))
        # Optional loss re-weighting by SNR (kept simple here)
        snr = alphas_cumprod / (1 - alphas_cumprod)
        if objective == 'pred_noise':
            loss_weight = snr / snr  # becomes 1 (i.e. no re-weighting)
        elif objective == 'pred_x0':
            loss_weight = snr
        else:  # pred_v
            loss_weight = snr / (snr + 1)
        self.register_buffer('loss_weight', loss_weight)

    @property
    def device(self):
        """Convenience: returns the device where buffers live."""
        return self.betas.device

    # ----------------------
    # Forward diffusion (q)
    # ----------------------
    def q_sample(self, x0, t, noise=None):
        """
        Sample x_t from q(x_t | x_0):
            x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise

        Args:
            x0: clean images in [-1,1], shape [B,C,H,W].
            t: LongTensor of timestep indices, shape [B].
            noise: optional standard-normal noise; drawn if None.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        return extract(self.sqrt_alphas_cumprod, t, x0.shape) * x0 + \
            extract(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise

    # ---------------------------------
    # Converters between parameterizations
    # ---------------------------------
    def predict_start_from_noise(self, x_t, t, eps):
        """Given epsilon prediction, reconstruct x0 (inverse of q_sample)."""
        return extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps

    def predict_noise_from_start(self, x_t, t, x0):
        """Given x0 prediction, reconstruct epsilon (algebraic inverse of the above)."""
        return (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def predict_v(self, x0, t, eps):
        """v-parameterization: v = sqrt(alpha_bar)*eps - sqrt(1-alpha_bar)*x0."""
        return extract(self.alphas_cumprod.sqrt(), t, x0.shape) * eps - \
            extract((1.0 - self.alphas_cumprod).sqrt(), t, x0.shape) * x0

    def predict_start_from_v(self, x_t, t, v):
        """Given v prediction, reconstruct x0 = sqrt(alpha_bar)*x_t - sqrt(1-alpha_bar)*v."""
        return extract(self.alphas_cumprod.sqrt(), t, x_t.shape) * x_t - \
            extract((1.0 - self.alphas_cumprod).sqrt(), t, x_t.shape) * v

    # ---------------------------------
    # Model predictions at time t
    # ---------------------------------
    def model_predictions(self, x, t, x_self_cond=None, clip_x_start=False, rederive_pred_noise=False):
        """
        Run the denoiser and return (pred_noise, x0):
          - If objective == pred_noise: UNet predicts epsilon directly.
          - If objective == pred_x0:    UNet predicts x0 directly.
          - If objective == pred_v:     UNet predicts v; we convert to x0 & epsilon.

        Args:
            x (Tensor): noised image x_t.
            t (LongTensor): time indices.
            x_self_cond (Tensor|None): optional self-conditioning input.
            clip_x_start (bool): clamp x0 to [-1,1] after prediction.
            rederive_pred_noise (bool): if True, recompute epsilon from clamped x0
                so that (epsilon, x0) stay mutually consistent after clamping.

        Returns:
            (pred_noise, x0) both shape like x.
        """
        # Only pass x_self_cond through when present (model may take 2 or 3 args).
        out = self.model(
            x, t, x_self_cond) if x_self_cond is not None else self.model(x, t)
        maybe_clip = (lambda z: z.clamp(-1, 1)
                      ) if clip_x_start else (lambda z: z)
        if self.objective == 'pred_noise':
            pred_noise = out
            x0 = self.predict_start_from_noise(x, t, pred_noise)
            x0 = maybe_clip(x0)
            if clip_x_start and rederive_pred_noise:
                pred_noise = self.predict_noise_from_start(x, t, x0)
        elif self.objective == 'pred_x0':
            x0 = maybe_clip(out)
            pred_noise = self.predict_noise_from_start(x, t, x0)
        else:  # 'pred_v'
            v = out
            x0 = self.predict_start_from_v(x, t, v)
            x0 = maybe_clip(x0)
            pred_noise = self.predict_noise_from_start(x, t, x0)
        return pred_noise, x0

    def q_posterior(self, x0, x_t, t):
        """
        Compute the Gaussian q(x_{t-1} | x_t, x0) parameters:
            mean = c1 * x0 + c2 * x_t
            var, log_var: closed-form from betas and alpha_bars
            (log_var clamped to avoid log(0) at t=0).
        """
        mean = extract(self.posterior_mean_coef1, t, x_t.shape) * x0 + \
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        var = extract(self.posterior_variance, t, x_t.shape)
        log_var = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return mean, var, log_var

    # ----------------------
    # Training loss (forward)
    # ----------------------
    def p_losses(self, x_start, t, noise=None):
        """
        DDPM training objective:
          - Sample x_t = q(x_t | x_0)
          - Predict target according to objective and MSE it
          - (Optional) self-conditioning can be added outside for simplicity

        Args:
            x_start: clean images already normalized to [-1,1].
            t: LongTensor timestep indices, shape [B].
            noise: optional fixed noise (drawn if None).

        Returns:
            scalar loss (mean over the batch).
        """
        noise = torch.randn_like(x_start) if noise is None else noise
        x = self.q_sample(x_start, t, noise)
        x_self_cond = None
        # With probability 0.5, feed back a no-grad x0 estimate (self-conditioning).
        if self.self_condition and torch.rand(1, device=self.device) < 0.5:
            # simple self-conditioning: predict x0 once and feed back
            with torch.no_grad():
                _, x_self_cond = self.model_predictions(
                    x, t, None, clip_x_start=True)
        model_out = self.model(
            x, t, x_self_cond) if x_self_cond is not None else self.model(x, t)
        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        else:  # pred_v
            v = self.predict_v(x_start, t, noise)
            target = v
        # MSE over channels/spatial dims -> mean over batch
        loss = F.mse_loss(model_out, target, reduction='none')
        loss = loss.mean(dim=list(range(1, loss.ndim)))  # average over C,H,W
        # snr-based weight (here often ==1)
        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()

    def forward(self, img):
        """
        Training entry point:
          - Normalize to [-1,1]
          - Draw random timesteps
          - Compute loss

        Args:
            img: clean images in [0,1], shape [B,C,H,W] matching image_size.
        """
        # Move input to the model's device/dtype before anything else.
        img = img.to(device=self.device, dtype=next(
            self.model.parameters()).dtype)
        b, c, h, w = img.shape
        assert (
            h, w) == self.image_size, f"image must be {self.image_size}, got {(h,w)}"
        t = torch.randint(0, self.num_timesteps, (b,),
                          device=img.device).long()
        img = self.normalize(img)
        return self.p_losses(img, t)

    # ----------------------
    # Single DDPM step p(x_{t-1}|x_t)
    # ----------------------
    @torch.inference_mode()
    def p_sample(self, x, t: int, x_self_cond=None):
        """
        Compute one reverse step:
          - predict (epsilon, x0), compute posterior q(x_{t-1}|x_t, x0)
          - sample from that Gaussian (add noise except at t=0)

        Returns:
            (x_{t-1} sample, predicted x0).
        """
        b = x.shape[0]
        # broadcast the scalar step to a per-sample index tensor
        tt = torch.full((b,), t, device=self.device, dtype=torch.long)
        pred_noise, x0 = self.model_predictions(
            x, tt, x_self_cond, clip_x_start=True)
        mean, _, log_var = self.q_posterior(x0, x, tt)
        # no noise at the final step (t=0): return the posterior mean
        noise = torch.randn_like(x) if t > 0 else 0.0
        return mean + (0.5 * log_var).exp() * noise, x0

    # ----------------------
    # Sampling loops
    # ----------------------
    @torch.inference_mode()
    def ddpm_sample(self, shape):
        """
        DDPM sampling with T steps (slow, high quality).

        Returns:
            images in [0,1], shape `shape`.
        """
        img = torch.randn(shape, device=self.device)
        x0 = None
        for t in reversed(range(self.num_timesteps)):
            # feed last step's x0 estimate back in when self-conditioning
            self_cond = x0 if self.self_condition else None
            img, x0 = self.p_sample(img, t, self_cond)
        return self.unnormalize(img)

    @torch.inference_mode()
    def ddim_sample(self, shape):
        """
        DDIM sampling with S < T steps (fast, often good quality).
        Deterministic when eta=0.0.

        Returns:
            images in [0,1], shape `shape`.
        """
        T, S, eta = self.num_timesteps, self.sampling_steps, self.ddim_sampling_eta
        # create a decreasing time index schedule of length S+1: [T-1, ..., 0, -1]
        times = torch.linspace(-1, T - 1, steps=S + 1,
                               device=self.device).long().flip(0)
        # consecutive (t, t_next) pairs along the shortened trajectory
        pairs = list(zip(times[:-1].tolist(), times[1:].tolist()))
        img = torch.randn(shape, device=self.device)
        x0 = None
        for t, t_next in pairs:
            tt = torch.full(
                (shape[0],), t, device=self.device, dtype=torch.long)
            # rederive epsilon from the clamped x0 so the update stays consistent
            pred_noise, x0 = self.model_predictions(
                img, tt, None, clip_x_start=True, rederive_pred_noise=True)
            if t_next < 0:
                # final step: directly set to predicted x0
                img = x0
                continue
            a_t, a_next = self.alphas_cumprod[t], self.alphas_cumprod[t_next]
            # DDIM sigma (Song et al. Eq. 16); zero when eta == 0
            sigma = eta * ((1 - a_t / a_next) *
                           (1 - a_next) / (1 - a_t)).sqrt()
            c = (1 - a_next - sigma ** 2).sqrt()
            noise = torch.randn_like(img)
            # DDIM update rule
            img = x0 * a_next.sqrt() + c * pred_noise + sigma * noise
        return self.unnormalize(img)

    @torch.inference_mode()
    def sample(self, batch_size=16):
        """
        Public sampling API:
          - choose DDPM or DDIM depending on `sampling_steps`
          - returns a batch of images in [0,1]
        """
        H, W = self.image_size
        fn = self.ddim_sample if self.is_ddim_sampling else self.ddpm_sample
        return fn((batch_size, self.channels, H, W))

    # In diffusion_core.py (add these methods inside GaussianDiffusion)
    # ----------------------
    # DDPM sampling with trajectory recording and forward transformations
    # ----------------------
    @torch.inference_mode()
    def ddpm_sample_trajectory(self, shape, record_every=50, return_x0=False):
        """
        DDPM sampling but also record intermediate frames.
          - record_every: save a snapshot every N steps (also includes first/last).
          - return_x0: if True, also store predicted x0 at the same checkpoints.

        Returns:
            final_img [B,C,H,W] in [0,1],
            frames_xt: list of tensors in [0,1], each [B,C,H,W]
            frames_x0 (or None): one entry FEWER than frames_xt — the first
                checkpoint (t=T-1) is skipped because no x0 prediction exists
                before the first reverse step.
        """
        img = torch.randn(shape, device=self.device)
        frames_xt = []
        frames_x0 = [] if return_x0 else None
        x0 = None
        T = self.num_timesteps
        for t in reversed(range(T)):
            # record current x_t before stepping
            if t == T - 1 or t == 0 or (t % record_every) == 0:
                # unnormalize for visualization (to [0,1])
                frames_xt.append(self.unnormalize(img.clamp(-1, 1)))
                # x0 is None at the very first checkpoint, so frames_x0
                # ends up one entry shorter than frames_xt
                if return_x0 and x0 is not None:
                    frames_x0.append(self.unnormalize(x0.clamp(-1, 1)))
            self_cond = x0 if self.self_condition else None
            img, x0 = self.p_sample(img, t, self_cond)
        # record the final image
        frames_xt.append(self.unnormalize(img.clamp(-1, 1)))
        if return_x0:
            frames_x0.append(self.unnormalize(x0.clamp(-1, 1)))
        return self.unnormalize(img), frames_xt, frames_x0

    @torch.no_grad()
    def forward_noising_trajectory(self, x0, t_values):
        """
        Visualize forward diffusion q(x_t | x_0) at selected t.

        Args:
            x0: clean images in [0,1], [B,C,H,W]
            t_values: list/iterable of ints (0..T-1)

        Returns:
            frames_xt: list of tensors in [0,1], each [B,C,H,W]
        """
        # normalize like training path
        x0n = self.normalize(x0.to(self.device))
        frames = []
        for t in t_values:
            tt = torch.full((x0n.size(0),), int(
                t), device=self.device, dtype=torch.long)
            xt = self.q_sample(x0n, tt)  # in [-1,1] domain
            # map back to [0,1] for viewing
            frames.append(self.unnormalize(xt))
        return frames