DDIM_Image_Generation / models /diffusion.py
NoobNovel's picture
DDIM face generation — full project
0ca4c93
"""DDPM training objective + DDIM (and DDPM) sampling.
Notation follows the original DDPM paper:
- betas, alphas = 1 - betas, alpha_bar_t = prod_{s<=t} alpha_s.
- Forward (closed-form): q(x_t | x_0) = N(sqrt(abar_t) x_0, (1 - abar_t) I).
- Training: predict epsilon from x_t and t with simple MSE loss.
DDIM sampling (Song et al. 2020):
x_{t-1} = sqrt(abar_{t-1}) * x0_pred
+ sqrt(1 - abar_{t-1} - sigma_t^2) * eps_pred
+ sigma_t * z
With eta = 0 -> sigma_t = 0 -> deterministic. eta = 1 reproduces DDPM.
The class is model-agnostic: it just needs a `model(x, t) -> eps` callable.
"""
from __future__ import annotations
from typing import Callable, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Schedules
# ---------------------------------------------------------------------------
def linear_beta_schedule(timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2):
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
def cosine_beta_schedule(timesteps: int, s: float = 0.008):
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
f = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
abar = f / f[0]
betas = 1 - (abar[1:] / abar[:-1])
return betas.clamp(1e-8, 0.999)
def make_betas(schedule: str, timesteps: int, beta_start: float, beta_end: float):
if schedule == "linear":
return linear_beta_schedule(timesteps, beta_start, beta_end)
if schedule == "cosine":
return cosine_beta_schedule(timesteps)
raise ValueError(f"unknown schedule {schedule}")
# ---------------------------------------------------------------------------
# Diffusion wrapper
# ---------------------------------------------------------------------------
def _gather(coef: torch.Tensor, t: torch.Tensor, target_shape) -> torch.Tensor:
"""Gather `coef` at indices `t` and reshape to broadcast against target_shape."""
out = coef.to(device=t.device).gather(0, t)
return out.reshape(t.shape[0], *([1] * (len(target_shape) - 1)))
class GaussianDiffusion(nn.Module):
def __init__(
self,
timesteps: int = 1000,
beta_start: float = 1e-4,
beta_end: float = 2e-2,
schedule: str = "linear",
):
super().__init__()
betas = make_betas(schedule, timesteps, beta_start, beta_end).float()
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
self.timesteps = timesteps
# buffers move with .to(device) and serialize
self.register_buffer("betas", betas)
self.register_buffer("alphas_cumprod", alphas_cumprod)
self.register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
# ------------------------------------------------------------------
# forward (training)
# ------------------------------------------------------------------
def q_sample(self, x0: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None):
if noise is None:
noise = torch.randn_like(x0)
sa = _gather(self.sqrt_alphas_cumprod, t, x0.shape)
sma = _gather(self.sqrt_one_minus_alphas_cumprod, t, x0.shape)
return sa * x0 + sma * noise, noise
def training_loss(
self,
model: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
x0: torch.Tensor,
t: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if t is None:
t = torch.randint(0, self.timesteps, (x0.shape[0],), device=x0.device, dtype=torch.long)
x_t, noise = self.q_sample(x0, t)
eps_pred = model(x_t, t)
return F.mse_loss(eps_pred, noise)
# ------------------------------------------------------------------
# DDIM / DDPM sampling
# ------------------------------------------------------------------
@torch.no_grad()
def ddim_sample(
self,
model: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
shape,
num_steps: int = 50,
eta: float = 0.0,
x_T: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
return_trajectory: bool = False,
trajectory_stride: int = 1,
clip_x0: bool = True,
):
"""Run DDIM sampling.
eta=0 => deterministic DDIM. eta=1 => DDPM-equivalent stochastic.
Set num_steps == self.timesteps for a full DDPM-like schedule.
"""
device = device or self.betas.device
if x_T is None:
x_t = torch.randn(shape, device=device)
else:
x_t = x_T.to(device)
# uniform subsequence of timesteps, length num_steps, high -> low.
# Keep these as Python ints — indexing buffers with MPS tensors is
# buggy in some PyTorch builds (returns garbage indices).
step_list = torch.linspace(0, self.timesteps - 1, num_steps, dtype=torch.long).tolist()
step_list = list(reversed(step_list))
prev_list = step_list[1:] + [-1]
trajectory: List[torch.Tensor] = []
if return_trajectory:
trajectory.append(x_t.detach().cpu())
for i, (t_idx, t_prev) in enumerate(zip(step_list, prev_list)):
t_batch = torch.full((shape[0],), t_idx, device=device, dtype=torch.long)
eps = model(x_t, t_batch)
abar_t = self.alphas_cumprod[t_idx]
abar_prev = self.alphas_cumprod[t_prev] if t_prev >= 0 else torch.tensor(1.0, device=device)
x0_pred = (x_t - torch.sqrt(1.0 - abar_t) * eps) / torch.sqrt(abar_t)
if clip_x0:
x0_pred = x0_pred.clamp(-1.0, 1.0)
sigma_t = eta * torch.sqrt((1 - abar_prev) / (1 - abar_t)) * torch.sqrt(1 - abar_t / abar_prev)
dir_xt = torch.sqrt(torch.clamp(1 - abar_prev - sigma_t ** 2, min=0.0)) * eps
noise = torch.randn_like(x_t) if eta > 0 and t_prev >= 0 else torch.zeros_like(x_t)
x_t = torch.sqrt(abar_prev) * x0_pred + dir_xt + sigma_t * noise
if return_trajectory and ((i + 1) % trajectory_stride == 0 or i == len(step_list) - 1):
trajectory.append(x_t.detach().cpu())
if return_trajectory:
return x_t, trajectory
return x_t
# ---------------------------------------------------------------------------
# EMA helper (used during training to track a smoothed copy of weights)
# ---------------------------------------------------------------------------
class EMA:
"""EMA of model weights with shadow copy on CPU to save GPU memory.
For a 245M-param model the shadow takes ~1GB; keeping it off MPS frees
that memory for activations.
"""
def __init__(self, model: nn.Module, decay: float = 0.9999,
device: str = "cpu"):
self.decay = decay
self.device = torch.device(device)
self.shadow = {k: v.detach().to(self.device).clone()
for k, v in model.state_dict().items()}
@torch.no_grad()
def update(self, model: nn.Module):
for k, v in model.state_dict().items():
v_cpu = v.detach().to(self.device, non_blocking=False)
if v.dtype.is_floating_point:
self.shadow[k].mul_(self.decay).add_(v_cpu, alpha=1 - self.decay)
else:
self.shadow[k].copy_(v_cpu)
def state_dict(self):
return self.shadow
def load_state_dict(self, sd):
self.shadow = {k: v.to(self.device).clone() for k, v in sd.items()}
def copy_to(self, model: nn.Module):
# load_state_dict moves tensors to model's device automatically
target_device = next(model.parameters()).device
sd = {k: v.to(target_device) for k, v in self.shadow.items()}
model.load_state_dict(sd, strict=True)
# ---------------------------------------------------------------------------
# AdamW (manual)
# ---------------------------------------------------------------------------
class AdamW:
"""Hand-rolled AdamW with CPU-resident optimizer state.
Two reasons this is custom:
1. PyTorch 2.3.1's MPS AdamW kernel produces NaN parameters after one
step when some grads are exactly zero (which happens here because
several layers are zero-initialized). The Python impl is stable.
2. The first/second moment buffers (m, v) live on CPU, halving GPU
memory usage. We copy the grad to CPU each step, compute the AdamW
update on CPU, and copy only the resulting weight delta back to MPS.
For a 245M-param network this saves ~2GB of GPU memory.
"""
def __init__(self, params, lr: float = 2e-4, betas=(0.9, 0.999),
eps: float = 1e-8, weight_decay: float = 0.0,
state_device: str = "cpu"):
self.params = [p for p in params if p.requires_grad]
self.lr = lr
self.b1, self.b2 = betas
self.eps = eps
self.weight_decay = weight_decay
self.state_device = torch.device(state_device)
self.t = 0
self.m = [torch.zeros_like(p, device=self.state_device) for p in self.params]
self.v = [torch.zeros_like(p, device=self.state_device) for p in self.params]
def zero_grad(self, set_to_none: bool = True):
for p in self.params:
if p.grad is None:
continue
if set_to_none:
p.grad = None
else:
p.grad.zero_()
@torch.no_grad()
def step(self):
self.t += 1
bc1 = 1.0 - self.b1 ** self.t
bc2 = 1.0 - self.b2 ** self.t
for p, m, v in zip(self.params, self.m, self.v):
if p.grad is None:
continue
# bring grad to optimizer state device (typically CPU)
g = p.grad.to(self.state_device, non_blocking=False)
m.mul_(self.b1).add_(g, alpha=1 - self.b1)
v.mul_(self.b2).addcmul_(g, g, value=1 - self.b2)
m_hat = m / bc1
denom = (v / bc2).sqrt().add_(self.eps)
update = m_hat / denom # on state device
# decoupled weight decay is applied in-place on the param itself
if self.weight_decay > 0:
p.mul_(1.0 - self.lr * self.weight_decay)
p.add_(update.to(p.device, non_blocking=False), alpha=-self.lr)
def state_dict(self):
return {"t": self.t, "m": [x.clone() for x in self.m],
"v": [x.clone() for x in self.v],
"lr": self.lr, "b1": self.b1, "b2": self.b2,
"eps": self.eps, "weight_decay": self.weight_decay,
"state_device": str(self.state_device)}
def load_state_dict(self, sd):
self.t = sd["t"]
self.m = [x.to(self.state_device).clone() for x in sd["m"]]
self.v = [x.to(self.state_device).clone() for x in sd["v"]]
# ---------------------------------------------------------------------------
# Self-test
# ---------------------------------------------------------------------------
if __name__ == "__main__":
torch.manual_seed(0)
diff = GaussianDiffusion(timesteps=100, beta_start=1e-4, beta_end=2e-2, schedule="linear")
assert diff.alphas_cumprod.shape == (100,)
assert diff.alphas_cumprod[0].item() < 1.0 and diff.alphas_cumprod[-1].item() < diff.alphas_cumprod[0].item()
# 1) q_sample at t=0 should be very close to x0 (since beta_0 ~ 0)
x0 = torch.randn(4, 3, 16, 16)
t0 = torch.zeros(4, dtype=torch.long)
xt, _ = diff.q_sample(x0, t0)
# at t=0, sqrt(1 - abar_0) ~ 1e-2, so noise contribution is small but nonzero
assert (xt - x0).abs().max().item() < 0.1, (xt - x0).abs().max().item()
# 2) at t=T-1 the marginal should be ~ pure noise (mean ~0, var ~1)
tT = torch.full((4,), diff.timesteps - 1, dtype=torch.long)
xtT, _ = diff.q_sample(x0, tT)
assert xtT.std().item() > 0.7
# 3) training loss with a dummy model that returns zeros should equal
# var of the noise (~1)
zero_model = lambda x, t: torch.zeros_like(x)
loss = diff.training_loss(zero_model, x0)
assert 0.5 < loss.item() < 1.5, loss.item()
# 4) DDIM sampling shape + determinism (eta=0)
class Identity(nn.Module):
def forward(self, x, t): return torch.zeros_like(x)
model = Identity()
out = diff.ddim_sample(model, (2, 3, 16, 16), num_steps=10, eta=0.0,
x_T=torch.randn(2, 3, 16, 16))
assert out.shape == (2, 3, 16, 16)
# determinism: same x_T -> same output
xT = torch.randn(2, 3, 16, 16)
a = diff.ddim_sample(model, (2, 3, 16, 16), num_steps=10, eta=0.0, x_T=xT.clone())
b = diff.ddim_sample(model, (2, 3, 16, 16), num_steps=10, eta=0.0, x_T=xT.clone())
assert torch.allclose(a, b, atol=1e-6)
# 5) trajectory return
out2, traj = diff.ddim_sample(model, (1, 3, 16, 16), num_steps=10, eta=0.0,
x_T=torch.randn(1, 3, 16, 16), return_trajectory=True)
# initial + 10 steps = 11 frames
assert len(traj) == 11, len(traj)
# 6) EMA
net = nn.Linear(4, 4)
ema = EMA(net, decay=0.5)
with torch.no_grad():
for p in net.parameters():
p.add_(torch.ones_like(p))
ema.update(net)
# shadow should have moved halfway toward new weights
for k, v in ema.shadow.items():
if v.dtype.is_floating_point:
assert (v - net.state_dict()[k]).abs().max() <= (net.state_dict()[k]).abs().max()
# 7) MPS round trip
if torch.backends.mps.is_available():
diff_mps = GaussianDiffusion(timesteps=50).to("mps")
x_mps = torch.randn(1, 3, 8, 8, device="mps")
out_mps = diff_mps.ddim_sample(Identity().to("mps"), (1, 3, 8, 8),
num_steps=5, eta=0.0)
assert out_mps.shape == (1, 3, 8, 8)
print("mps ok")
print("diffusion.py: all tests passed")