| """ |
| DDPM - CelebA-HQ Face Generator |
Hugging Face Spaces app — AliMusaRizvi/ddpm
| """ |
|
|
| import math |
| import json |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from PIL import Image |
| import matplotlib.pyplot as plt |
| from tqdm import tqdm |
| import gradio as gr |
| from huggingface_hub import hf_hub_download |
| from safetensors.torch import load_file |
|
|
|
|
| |
|
|
class Config:
    """Module-wide constants for sampling and the default architecture.

    ``TIMESTEPS`` and ``BETA_SCHEDULE`` configure the diffusion schedule and
    ``IMAGE_SIZE`` is a default sample resolution. The UNet in this app is
    built from the downloaded JSON config rather than from the architecture
    fields below.
    """

    # --- diffusion process -------------------------------------------------
    TIMESTEPS = 1000
    BETA_SCHEDULE = "cosine"

    # --- data / architecture defaults --------------------------------------
    IMAGE_SIZE = 128
    BASE_CHANNELS = 128
    CHANNEL_MULTS = (1, 2, 2, 4)
    ATTN_RESOLUTIONS = (16,)
    NUM_RES_BLOCKS = 2
    DROPOUT = 0.1
|
|
|
|
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
| |
|
|
def to_pil(tensor: torch.Tensor) -> np.ndarray:
    """Convert a ``(C, H, W)`` image tensor in ``[-1, 1]`` to a ``uint8`` HWC array.

    Despite the name, the result is a NumPy array (what ``Image.fromarray`` and
    ``imshow`` consume), not a ``PIL.Image``.

    Values are clamped to ``[-1, 1]``, mapped linearly to ``[0, 255]`` and
    rounded to the nearest integer before the ``uint8`` cast. A bare
    ``.byte()`` cast truncates, which maps e.g. 254.9 to 254 and biases every
    pixel down by about half a level.

    Args:
        tensor: Image tensor with channels first; any device or float dtype.

    Returns:
        ``uint8`` array of shape ``(H, W, C)``.
    """
    # Work in float32 so low-precision inputs (fp16/bf16) round accurately.
    img = (tensor.detach().cpu().float().clamp(-1, 1) + 1) / 2 * 255
    return img.round().permute(1, 2, 0).to(torch.uint8).numpy()
|
|
|
|
| |
|
|
def build_cosine_betas(timesteps: int, s: float = 0.008) -> torch.Tensor:
    """Return ``timesteps`` betas derived from a squared-cosine alpha-bar curve.

    ``alpha_bar(t) = cos(((t / T) + s) / (1 + s) * pi / 2) ** 2`` is evaluated on
    the grid ``t = 0..T`` in float64 and normalized so ``alpha_bar(0) == 1``.
    Each beta is then ``1 - alpha_bar(t) / alpha_bar(t - 1)``. This is the form
    of the cosine schedule of Nichol & Dhariwal (2021).

    Args:
        timesteps: Number of diffusion steps ``T``.
        s: Small offset added to ``t / T`` inside the cosine.

    Returns:
        float32 tensor of shape ``(timesteps,)``, clamped to ``[1e-5, 0.9999]``.
    """
    def _curve(frac: torch.Tensor) -> torch.Tensor:
        return torch.cos((frac + s) / (1 + s) * math.pi / 2) ** 2

    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64)
    cumulative = _curve(grid / timesteps)
    cumulative = cumulative / cumulative[0]
    step_ratio = cumulative[1:] / cumulative[:-1]
    return (1.0 - step_ratio).clamp(1e-5, 0.9999).float()
|
|
|
|
| class DiffusionSchedule: |
| def __init__(self, timesteps: int = 1000, schedule: str = "cosine"): |
| self.T = timesteps |
| betas = build_cosine_betas(timesteps) |
|
|
| alphas = 1.0 - betas |
| abar = torch.cumprod(alphas, dim=0) |
| abar_prev = F.pad(abar[:-1], (1, 0), value=1.0) |
|
|
| self.betas = betas |
| self.alphas_bar = abar |
| self.alphas_bar_prev = abar_prev |
| self.sqrt_abar = abar.sqrt() |
| self.sqrt_one_minus_abar = (1 - abar).sqrt() |
|
|
| def _build_seq(self, num_steps): |
| skip = self.T // num_steps |
| return list(range(0, self.T, skip)) |
|
|
| def _to(self, device): |
| for attr in ("betas", "alphas_bar", "alphas_bar_prev", |
| "sqrt_abar", "sqrt_one_minus_abar"): |
| setattr(self, attr, getattr(self, attr).to(device)) |
| return self |
|
|
| def _ddim_step_backward(self, model, xt, t, t_prev, eta=0.0): |
| tbatch = torch.full((xt.shape[0],), t, device=DEVICE, dtype=torch.long) |
| eps = model(xt, tbatch) |
| ab_t = self.alphas_bar[t] |
| ab_prev = self.alphas_bar[t_prev] if t_prev >= 0 else torch.tensor(1.0, device=DEVICE) |
| x0_pred = ((xt - (1 - ab_t).sqrt() * eps) / ab_t.sqrt()).clamp(-1, 1) |
| sigma = (eta |
| * ((1 - ab_prev) / (1 - ab_t)).sqrt() |
| * (1 - ab_t / ab_prev).sqrt()) |
| dir_xt = (1 - ab_prev - sigma ** 2).clamp(min=0).sqrt() * eps |
| noise = torch.randn_like(xt) if (eta > 0 and t_prev >= 0) else torch.zeros_like(xt) |
| return ab_prev.sqrt() * x0_pred + dir_xt + sigma * noise |
|
|
| @torch.no_grad() |
| def ddim_sample(self, model, shape, num_steps=200, eta=0.0, |
| return_intermediates=False): |
| self._to(DEVICE) |
| seq = self._build_seq(num_steps) |
| xt = torch.randn(shape, device=DEVICE) |
| frames = [] |
|
|
| for i in tqdm(reversed(range(len(seq))), total=len(seq), |
| desc="DDIM sampling", leave=False): |
| t = seq[i] |
| t_prev = seq[i - 1] if i > 0 else -1 |
| xt = self._ddim_step_backward(model, xt, t, t_prev, eta) |
| if return_intermediates and i % max(1, len(seq) // 5) == 0: |
| frames.append(xt.clamp(-1, 1).cpu().clone()) |
|
|
| result = xt.clamp(-1, 1).cpu() |
| return (result, frames) if return_intermediates else result |
|
|
|
|
| |
|
|
class GroupNormFP32(nn.GroupNorm):
    """``nn.GroupNorm`` that normalizes in float32 and returns the input's dtype."""

    def forward(self, x):
        in_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.to(in_dtype)
|
|
|
|
class TimeEmbedding(nn.Module):
    """Sinusoidal timestep features followed by a two-layer SiLU MLP.

    ``time_dim // 8`` geometrically spaced frequencies yield sin and cos
    features, ``time_dim // 4`` in total when ``time_dim`` is a multiple of 8.
    The MLP maps them to ``time_dim``.
    """

    def __init__(self, time_dim):
        super().__init__()
        self.time_dim = time_dim
        feature_dim = time_dim // 4
        # Attribute name and layer order form the checkpoint's state_dict keys.
        self.mlp = nn.Sequential(
            nn.Linear(feature_dim, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim),
        )

    def forward(self, t):
        """Embed a 1-D tensor of timesteps into shape ``(len(t), time_dim)``."""
        n_freqs = self.time_dim // 8
        log_scaled = -math.log(10000) * torch.arange(n_freqs, device=t.device)
        freqs = torch.exp(log_scaled / n_freqs)
        phases = t.float()[:, None] * freqs.unsqueeze(0)
        features = torch.cat((torch.sin(phases), torch.cos(phases)), dim=-1)
        return self.mlp(features)
|
|
|
|
class ResBlock(nn.Module):
    """Two-convolution residual block with timestep scale-and-shift conditioning.

    ``x -> [GN, SiLU, conv3x3] -> h * (1 + scale) + shift -> [GN, SiLU, dropout,
    conv3x3]``, added to a residual path. The residual is a 1x1 conv when the
    channel count changes and identity otherwise. ``scale`` and ``shift`` are
    projected from the time embedding and clamped to ``[-3, 3]``.

    Args:
        in_ch: Input channels (``GroupNormFP32(32, ...)`` needs a multiple of 32).
        out_ch: Output channels (also a multiple of 32).
        time_dim: Width of the time embedding passed to ``forward``.
        dropout: Dropout probability before the second convolution.
    """

    def __init__(self, in_ch, out_ch, time_dim, dropout=0.1):
        super().__init__()
        # Submodule names and their internal order are the state_dict keys the
        # checkpoint is loaded against with strict=True; keep them unchanged.
        self.block1 = nn.Sequential(
            GroupNormFP32(32, in_ch),
            nn.SiLU(),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        )
        self.time_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_dim, 2 * out_ch),
        )
        self.block2 = nn.Sequential(
            GroupNormFP32(32, out_ch),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        )
        if in_ch == out_ch:
            self.skip_proj = nn.Identity()
        else:
            self.skip_proj = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x, t_emb):
        hidden = self.block1(x)
        # (B, 2*out_ch) -> (B, 2*out_ch, 1, 1) so it broadcasts over H and W.
        modulation = self.time_proj(t_emb)[:, :, None, None]
        scale, shift = modulation.chunk(2, dim=1)
        hidden = hidden * (1 + torch.clamp(scale, -3, 3)) + torch.clamp(shift, -3, 3)
        out = self.block2(hidden)
        return out + self.skip_proj(x)
|
|
|
|
class SelfAttention(nn.Module):
    """Residual multi-head self-attention over the spatial positions of a feature map.

    The input is group-normalized and flattened to ``(b, h*w, c)`` tokens. It is
    then passed through a fused q/k/v projection and scaled dot-product
    attention, projected back to ``c`` channels, reshaped to ``(b, c, h, w)``,
    and added to the input.

    Args:
        channels: Channel count ``c``. ``GroupNormFP32(32, channels)`` needs a
            multiple of 32, and the head split assumes divisibility by
            ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        # Attribute names are the checkpoint's state_dict keys.
        self.norm = GroupNormFP32(32, channels)
        self.qkv = nn.Linear(channels, 3 * channels, bias=True)
        self.out_proj = nn.Linear(channels, channels, bias=True)
        # Small q/k/v weights and a zeroed output projection, so a freshly
        # initialized block returns its input unchanged.
        nn.init.xavier_uniform_(self.qkv.weight, gain=0.02)
        nn.init.zeros_(self.qkv.bias)
        for param in (self.out_proj.weight, self.out_proj.bias):
            nn.init.zeros_(param)

    def forward(self, x):
        b, c, h, w = x.shape
        normed = self.norm(x).to(x.dtype)
        tokens = normed.view(b, c, -1).permute(0, 2, 1)  # (b, h*w, c)
        q, k, v = (
            part.view(b, -1, self.num_heads, self.head_dim).transpose(1, 2)
            for part in self.qkv(tokens).chunk(3, dim=-1)
        )  # each (b, heads, h*w, head_dim)
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
        merged = attended.transpose(1, 2).reshape(b, -1, c)
        spatial = self.out_proj(merged).permute(0, 2, 1).view(b, c, h, w)
        return spatial + x
|
|
|
|
class Downsample(nn.Module):
    """Halve the spatial size with a stride-2, padding-1 3x3 convolution (channels unchanged)."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=ch,
            out_channels=ch,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        downsampled = self.conv(x)
        return downsampled
|
|
|
|
class Upsample(nn.Module):
    """Double the spatial size (nearest neighbour), then apply a 3x3 convolution."""

    def __init__(self, ch):
        super().__init__()
        # Kept as one Sequential so the conv's state_dict keys stay ``seq.1.*``.
        layers = [
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(ch, ch, kernel_size=3, padding=1),
        ]
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        upsampled = self.seq(x)
        return upsampled
|
|
|
|
class UNet(nn.Module):
    """Timestep-conditioned U-Net used as the diffusion noise predictor.

    Layout for ``L = len(ch_mults)`` levels:

    * encoder: per level, ``num_res_blocks`` ResBlocks, each followed by a
      SelfAttention when the tracked resolution is in ``attn_res`` (Identity
      otherwise). Every level except the last ends with a Downsample.
    * middle: ResBlock -> SelfAttention -> ResBlock.
    * decoder: per level in reverse, ``num_res_blocks + 1`` ResBlocks, each
      consuming one encoder activation by channel concatenation. Every level
      except the last ends with an Upsample.
    * head: GroupNorm -> SiLU -> 3x3 conv back to ``in_ch`` channels.

    The order in which modules are appended to the ModuleLists determines the
    ``state_dict`` key names, so it must stay in sync with saved checkpoints.
    """
    def __init__(self, in_ch=3, base_ch=128, ch_mults=(1,2,2,4),
                 attn_res=(16,), num_res_blocks=2, dropout=0.1, image_size=128):
        """
        Args:
            in_ch: Channels of the input image and of the predicted output.
            base_ch: Width of the first conv; level ``i`` uses
                ``base_ch * ch_mults[i]`` channels.
            ch_mults: Per-level channel multipliers (one entry per level).
            attn_res: Spatial resolutions at which SelfAttention is inserted.
            num_res_blocks: ResBlocks per encoder level (the decoder uses one more).
            dropout: Dropout probability passed to every ResBlock.
            image_size: Input side length. Used only to track the resolution
                at each level when deciding where attention goes (presumably
                inputs are square ``image_size`` images; TODO confirm).
        """
        super().__init__()
        self.num_levels = len(ch_mults)
        self.nrb = num_res_blocks
        time_dim = base_ch * 4

        self.time_embed = TimeEmbedding(time_dim)
        self.init_conv = nn.Conv2d(in_ch, base_ch, 3, padding=1)

        enc_res, enc_attn, enc_down = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
        # `skips` records, in push order, the channel count of every tensor that
        # forward() pushes onto its skip stack, starting with init_conv's output.
        ch, cur_res, skips = base_ch, image_size, [base_ch]

        for lvl, mult in enumerate(ch_mults):
            out_ch = base_ch * mult
            for _ in range(num_res_blocks):
                enc_res.append(ResBlock(ch, out_ch, time_dim, dropout))
                enc_attn.append(SelfAttention(out_ch) if cur_res in attn_res else nn.Identity())
                skips.append(out_ch); ch = out_ch
            # No downsampling after the deepest level.
            if lvl < self.num_levels - 1:
                enc_down.append(Downsample(ch)); skips.append(ch); cur_res //= 2

        self.enc_res, self.enc_attn, self.enc_down = enc_res, enc_attn, enc_down
        self.mid_res1 = ResBlock(ch, ch, time_dim, dropout)
        self.mid_attn = SelfAttention(ch)
        self.mid_res2 = ResBlock(ch, ch, time_dim, dropout)

        dec_res, dec_attn, dec_up = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
        # The decoder pops skips LIFO, so consume the recorded channels reversed.
        rev_skips, sidx = list(reversed(skips)), 0

        for lvl, mult in enumerate(reversed(ch_mults)):
            out_ch = base_ch * mult
            # One extra block per level: the encoder pushed num_res_blocks
            # activations plus one downsample (or init_conv) output per level.
            for _ in range(num_res_blocks + 1):
                skip_ch = rev_skips[sidx]; sidx += 1
                dec_res.append(ResBlock(ch + skip_ch, out_ch, time_dim, dropout))
                dec_attn.append(SelfAttention(out_ch) if cur_res in attn_res else nn.Identity())
                ch = out_ch
            if lvl < self.num_levels - 1:
                dec_up.append(Upsample(ch)); cur_res *= 2

        self.dec_res, self.dec_attn, self.dec_up = dec_res, dec_attn, dec_up
        self.out_norm = GroupNormFP32(32, ch)
        self.out_act = nn.SiLU()
        self.out_conv = nn.Conv2d(ch, in_ch, 3, padding=1)

    def forward(self, x, t):
        """Predict the noise in ``x`` at timesteps ``t``.

        Args:
            x: Noisy images with ``in_ch`` channels (batch first).
            t: Timesteps, one per batch item (passed to ``TimeEmbedding``).

        Returns:
            Tensor with the same batch and spatial size as ``x`` and ``in_ch`` channels.
        """
        t_emb = self.time_embed(t)
        h = self.init_conv(x)
        stack = [h]

        # Encoder: push every ResBlock/attention output and every downsample output.
        bidx = 0
        for lvl in range(self.num_levels):
            for _ in range(self.nrb):
                h = self.enc_res[bidx](h, t_emb)
                h = self.enc_attn[bidx](h)
                stack.append(h); bidx += 1
            if lvl < self.num_levels - 1:
                h = self.enc_down[lvl](h); stack.append(h)

        h = self.mid_res1(h, t_emb)
        h = self.mid_attn(h)
        h = self.mid_res2(h, t_emb)

        # Decoder: each block concatenates the most recent skip along channels.
        bidx = 0
        for lvl in range(self.num_levels):
            for _ in range(self.nrb + 1):
                h = torch.cat([h, stack.pop()], dim=1)
                h = self.dec_res[bidx](h, t_emb)
                h = self.dec_attn[bidx](h); bidx += 1
            if lvl < self.num_levels - 1:
                h = self.dec_up[lvl](h)

        return self.out_conv(self.out_act(self.out_norm(h)))
|
|
|
|
| |
|
|
# Hugging Face Hub repository holding both the checkpoint and its config.
HF_REPO_ID = "AliMusaRizvi/ddpm"

print("Downloading model weights...")
config_path = hf_hub_download(repo_id=HF_REPO_ID, filename="best_model_config.json")
weights_path = hf_hub_download(repo_id=HF_REPO_ID, filename="best_model.safetensors")

# JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
with open(config_path, encoding="utf-8") as f:
    cfg = json.load(f)

# Architecture comes from the checkpoint's own config so the module structure
# matches the saved tensors (load_state_dict below is strict).
model = UNet(
    in_ch=cfg["in_ch"],
    base_ch=cfg["base_ch"],
    ch_mults=tuple(cfg["ch_mults"]),
    attn_res=tuple(cfg["attn_res"]),
    num_res_blocks=cfg["num_res_blocks"],
    dropout=cfg["dropout"],
    image_size=cfg["image_size"],
)
model.load_state_dict(load_file(weights_path, device="cpu"), strict=True)
model.to(DEVICE).eval()
print(f"Model ready on {DEVICE}")

# The sampling schedule uses the module constants, not the downloaded config.
schedule = DiffusionSchedule(Config.TIMESTEPS, Config.BETA_SCHEDULE)
|
|
|
|
| |
|
|
def generate_gradio(num_steps: int = 200, seed: int = 42):
    """Sample one image with DDIM and save it plus a strip of intermediate frames.

    Args:
        num_steps: Number of DDIM steps (converted with ``int``; Gradio
            sliders may deliver floats).
        seed: Seed for ``torch.manual_seed`` so a given seed is reproducible.

    Returns:
        ``(final_path, steps_path)``: PNG files in the system temp directory
        holding the generated image and the denoising-steps figure.
    """
    import os
    import tempfile

    torch.manual_seed(int(seed))
    # Sample at the channel count / resolution the loaded checkpoint was built
    # with, rather than the independent Config.IMAGE_SIZE constant.
    image_size = cfg["image_size"]
    shape = (1, cfg["in_ch"], image_size, image_size)

    final_x, frames = schedule.ddim_sample(
        model, shape,
        num_steps=int(num_steps), eta=0.0,
        return_intermediates=True,
    )
    # Never ask matplotlib for a zero-column figure.
    frames = frames or [final_x]

    # Portable temp dir (honours TMPDIR; "/tmp" does not exist on Windows).
    out_dir = tempfile.gettempdir()

    n_show = min(len(frames), 6)
    # squeeze=False always yields a 2-D axes grid, even for a single frame.
    fig, axes = plt.subplots(1, n_show, figsize=(18, 3.5), squeeze=False)
    try:
        for ax, frame in zip(axes[0], frames[:n_show]):
            ax.imshow(to_pil(frame[0]))
            ax.axis("off")
        # Use the Figure object directly instead of pyplot's implicit "current
        # figure", which is shared state across Gradio's worker threads.
        fig.suptitle("Denoising Steps", fontsize=12)
        fig.tight_layout()
        steps_path = os.path.join(out_dir, "ddpm_steps.png")
        fig.savefig(steps_path, bbox_inches="tight", dpi=100)
    finally:
        # Close this specific figure even if plotting/saving fails.
        plt.close(fig)

    final_path = os.path.join(out_dir, "ddpm_final.png")
    Image.fromarray(to_pil(final_x[0])).save(final_path)

    return final_path, steps_path
|
|
|
|
| |
|
|
# Gradio UI. Components appear in the order they are created inside each
# context manager, so statement order here defines the page layout.
with gr.Blocks(title="DDPM - CelebA-HQ Face Generator") as demo:
    gr.Markdown(
        "## DDPM - Unconditional Face Generation\n"
        "Generates a face from pure Gaussian noise using the trained diffusion model."
    )
    # Controls: number of DDIM steps and the RNG seed passed to generate_gradio.
    with gr.Row():
        steps_slider = gr.Slider(50, 400, value=200, step=50, label="DDIM Steps")
        seed_slider = gr.Slider(0, 9999, value=42, step=1, label="Random Seed")
    run_btn = gr.Button("Generate Image", variant="primary")
    # Outputs: generate_gradio returns two image file paths, in this order.
    with gr.Row():
        out_final = gr.Image(label="Generated Face")
        out_steps = gr.Image(label="Denoising Process")
    run_btn.click(
        fn = generate_gradio,
        inputs = [steps_slider, seed_slider],
        outputs = [out_final, out_steps],
    )


# Runs at import time; the app is served as soon as this module executes.
demo.launch()
|
|