"""
DDPM - CelebA-HQ Face Generator
HuggingFace Spaces app — AliMusaRizvi/ddpm
"""

import json
import math
import warnings

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
from tqdm import tqdm


# ── Config ────────────────────────────────────────────────────────────────────
class Config:
    """Static defaults for the diffusion process and the U-Net."""

    TIMESTEPS = 1000
    BETA_SCHEDULE = "cosine"
    IMAGE_SIZE = 128
    BASE_CHANNELS = 128
    CHANNEL_MULTS = (1, 2, 2, 4)
    ATTN_RESOLUTIONS = (16,)
    NUM_RES_BLOCKS = 2
    DROPOUT = 0.1


# Device used for the model and for sampling; falls back to CPU without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# ── Helpers ───────────────────────────────────────────────────────────────────
def to_pil(tensor: torch.Tensor) -> np.ndarray:
    """Convert one image tensor to a channel-last ``uint8`` NumPy array.

    Despite the name, the result is an array, not a ``PIL.Image``.

    Args:
        tensor: A 3-D tensor whose first axis is moved last by the permute.
            Values are clamped to ``[-1, 1]`` before rescaling.

    Returns:
        The image mapped linearly from ``[-1, 1]`` to ``[0, 255]`` and cast
        to ``uint8`` (the cast truncates the fractional part).
    """
    img = (tensor.detach().cpu().clamp(-1, 1) + 1) / 2 * 255
    return img.permute(1, 2, 0).byte().numpy()


# ── Noise schedule ────────────────────────────────────────────────────────────
def build_cosine_betas(timesteps: int, s: float = 0.008) -> torch.Tensor:
    """Build the per-step betas of a cosine noise schedule.

    Args:
        timesteps: Number of diffusion steps.
        s: Offset added inside the cosine.

    Returns:
        A float32 tensor of ``timesteps`` betas, clamped to ``[1e-5, 0.9999]``.
    """
    steps = timesteps + 1
    # float64 keeps the ratio of neighbouring alpha-bars accurate.
    t = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_bar = torch.cos(((t / timesteps) + s) / (1 + s) * math.pi / 2) ** 2
    alphas_bar = alphas_bar / alphas_bar[0]
    betas = 1.0 - (alphas_bar[1:] / alphas_bar[:-1])
    return betas.clamp(1e-5, 0.9999).float()


class DiffusionSchedule:
    """Precomputed schedule tensors plus a DDIM sampler."""

    def __init__(self, timesteps: int = 1000, schedule: str = "cosine"):
        """Precompute the schedule tensors.

        Args:
            timesteps: Number of diffusion steps ``T``.
            schedule: Schedule name. Only ``"cosine"`` is implemented; any
                other value emits a warning and the cosine schedule is used.
        """
        if schedule != "cosine":
            # Previously the argument was ignored without any notice.
            warnings.warn(
                f"Unsupported beta schedule {schedule!r}; "
                "falling back to 'cosine'.",
                stacklevel=2,
            )
        self.T = timesteps
        betas = build_cosine_betas(timesteps)
        alphas = 1.0 - betas
        abar = torch.cumprod(alphas, dim=0)
        # Shift right by one step; the value before step 0 is 1.0.
        abar_prev = F.pad(abar[:-1], (1, 0), value=1.0)
        self.betas = betas
        self.alphas_bar = abar
        self.alphas_bar_prev = abar_prev
        self.sqrt_abar = abar.sqrt()
        self.sqrt_one_minus_abar = (1 - abar).sqrt()

    def _build_seq(self, num_steps):
        """Return ``num_steps`` increasing timesteps spread over ``[0, T)``.

        When ``num_steps`` divides ``T`` the result equals
        ``range(0, T, T // num_steps)``. Otherwise the indices are spaced as
        evenly as integer arithmetic allows, so the sequence still has
        exactly ``num_steps`` entries. Requests above ``T`` are capped at
        ``T``.

        Raises:
            ValueError: If ``num_steps`` is smaller than 1.
        """
        num_steps = int(num_steps)
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")
        num_steps = min(num_steps, self.T)
        # Integer arithmetic avoids float rounding; with num_steps <= T the
        # values are strictly increasing.
        return [(i * self.T) // num_steps for i in range(num_steps)]

    def _to(self, device):
        """Move every schedule tensor to ``device`` and return ``self``."""
        for attr in ("betas", "alphas_bar", "alphas_bar_prev",
                     "sqrt_abar", "sqrt_one_minus_abar"):
            setattr(self, attr, getattr(self, attr).to(device))
        return self

    def _ddim_step_backward(self, model, xt, t, t_prev, eta=0.0):
        """Perform one DDIM update from timestep ``t`` to ``t_prev``.

        Args:
            model: Callable ``model(xt, t_batch)`` returning predicted noise.
            xt: Current noisy batch.
            t: Current integer timestep.
            t_prev: Target integer timestep; a negative value marks the last
                step, where the target alpha-bar is 1.
            eta: Scale of the stochastic term; ``0`` is deterministic.

        Returns:
            The batch at timestep ``t_prev``.
        """
        # Follow the input tensor's device instead of a module global.
        tbatch = torch.full((xt.shape[0],), t, device=xt.device,
                            dtype=torch.long)
        eps = model(xt, tbatch)

        ab_t = self.alphas_bar[t]
        if t_prev >= 0:
            ab_prev = self.alphas_bar[t_prev]
        else:
            # Same dtype/device as the schedule tensors.
            ab_prev = ab_t.new_ones(())

        x0_pred = ((xt - (1 - ab_t).sqrt() * eps) / ab_t.sqrt()).clamp(-1, 1)
        sigma = (eta
                 * ((1 - ab_prev) / (1 - ab_t)).sqrt()
                 * (1 - ab_t / ab_prev).sqrt())
        # clamp(min=0) guards the sqrt against small negative round-off.
        dir_xt = (1 - ab_prev - sigma ** 2).clamp(min=0).sqrt() * eps
        if eta > 0 and t_prev >= 0:
            noise = torch.randn_like(xt)
        else:
            noise = torch.zeros_like(xt)
        return ab_prev.sqrt() * x0_pred + dir_xt + sigma * noise

    @torch.no_grad()
    def ddim_sample(self, model, shape, num_steps=200, eta=0.0,
                    return_intermediates=False):
        """Generate samples from Gaussian noise with DDIM.

        Args:
            model: Noise-prediction network called as ``model(x, t)``.
            shape: Shape of the initial noise tensor.
            num_steps: Requested number of sampling steps.
            eta: Stochasticity passed to each DDIM step.
            return_intermediates: Also return snapshots taken during sampling.

        Returns:
            The final batch clamped to ``[-1, 1]`` on the CPU, or
            ``(result, frames)`` when ``return_intermediates`` is true, where
            ``frames`` is a list of clamped CPU snapshots ordered from early
            (noisy) to the final step.
        """
        self._to(DEVICE)
        seq = self._build_seq(num_steps)
        total = len(seq)
        # Snapshot roughly every fifth of the trajectory, always including
        # the last step (index 0).
        frame_every = max(1, total // 5)
        xt = torch.randn(shape, device=DEVICE)
        frames = []
        for i in tqdm(reversed(range(total)), total=total,
                      desc="DDIM sampling", leave=False):
            t = seq[i]
            t_prev = seq[i - 1] if i > 0 else -1
            xt = self._ddim_step_backward(model, xt, t, t_prev, eta)
            if return_intermediates and i % frame_every == 0:
                frames.append(xt.clamp(-1, 1).cpu().clone())
        result = xt.clamp(-1, 1).cpu()
        return (result, frames) if return_intermediates else result


# ── Model architecture ────────────────────────────────────────────────────────
class GroupNormFP32(nn.GroupNorm):
    """GroupNorm that computes on a float32 copy and casts back."""

    def forward(self, x):
        return super().forward(x.float()).to(x.dtype)


class TimeEmbedding(nn.Module):
    """Sinusoidal timestep features followed by a two-layer MLP."""

    def __init__(self, time_dim):
        super().__init__()
        self.time_dim = time_dim
        # The sinusoidal input has width time_dim // 4 (sin and cos halves).
        self.mlp = nn.Sequential(
            nn.Linear(time_dim // 4, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim),
        )

    def forward(self, t):
        """Embed a 1-D batch of timesteps into ``time_dim`` features."""
        half = self.time_dim // 8
        freqs = torch.exp(
            -math.log(10000) * torch.arange(half, device=t.device) / half
        )
        args = t[:, None].float() * freqs[None]
        emb = torch.cat([args.sin(), args.cos()], dim=-1)
        return self.mlp(emb)


class ResBlock(nn.Module):
    """Residual block modulated by a scale/shift from the time embedding."""

    def __init__(self, in_ch, out_ch, time_dim, dropout=0.1):
        super().__init__()
        self.block1 = nn.Sequential(
            GroupNormFP32(32, in_ch),
            nn.SiLU(),
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
        )
        # Produces 2 * out_ch values: one scale and one shift per channel.
        self.time_proj = nn.Sequential(
            nn.SiLU(), nn.Linear(time_dim, out_ch * 2)
        )
        self.block2 = nn.Sequential(
            GroupNormFP32(32, out_ch),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
        )
        # A 1x1 conv matches channel counts on the skip path when needed.
        self.skip_proj = (
            nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
        )

    def forward(self, x, t_emb):
        h = self.block1(x)
        scale, shift = self.time_proj(t_emb)[:, :, None, None].chunk(2, dim=1)
        # Both modulation terms are limited to [-3, 3].
        h = h * (1 + scale.clamp(-3, 3)) + shift.clamp(-3, 3)
        return self.block2(h) + self.skip_proj(x)


class SelfAttention(nn.Module):
    """Multi-head self-attention over spatial positions, with a residual."""

    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.norm = GroupNormFP32(32, channels)
        self.qkv = nn.Linear(channels, channels * 3, bias=True)
        self.out_proj = nn.Linear(channels, channels, bias=True)
        nn.init.xavier_uniform_(self.qkv.weight, gain=0.02)
        nn.init.zeros_(self.qkv.bias)
        # Zero output projection: at initialisation the block returns x.
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

    def forward(self, x):
        b, c, h, w = x.shape
        # Flatten the spatial grid into a token axis: (b, h*w, c).
        tokens = self.norm(x).to(x.dtype).view(b, c, -1).permute(0, 2, 1)
        q, k, v = self.qkv(tokens).chunk(3, dim=-1)
        q = q.view(b, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, -1, self.num_heads, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
        out = out.transpose(1, 2).reshape(b, -1, c)
        return self.out_proj(out).permute(0, 2, 1).view(b, c, h, w) + x


class Downsample(nn.Module):
    """Stride-2 3x3 convolution that keeps the channel count."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)


class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 convolution."""

    def __init__(self, ch):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(ch, ch, 3, padding=1),
        )

    def forward(self, x):
        return self.seq(x)


class UNet(nn.Module):
    """Time-conditioned U-Net with skip connections and optional attention."""

    def __init__(self, in_ch=3, base_ch=128, ch_mults=(1, 2, 2, 4),
                 attn_res=(16,), num_res_blocks=2, dropout=0.1,
                 image_size=128):
        """Build the encoder, middle and decoder stacks.

        Args:
            in_ch: Channels of the input and of the output.
            base_ch: Channel width at the first level.
            ch_mults: Per-level multipliers applied to ``base_ch``.
            attn_res: Resolutions at which self-attention is inserted.
            num_res_blocks: Residual blocks per encoder level (the decoder
                uses one more per level).
            dropout: Dropout probability passed to every residual block.
            image_size: Input resolution, used only to decide where the
                attention blocks go.
        """
        super().__init__()
        self.num_levels = len(ch_mults)
        self.nrb = num_res_blocks
        time_dim = base_ch * 4
        self.time_embed = TimeEmbedding(time_dim)
        self.init_conv = nn.Conv2d(in_ch, base_ch, 3, padding=1)

        # ── Encoder ──
        enc_res = nn.ModuleList()
        enc_attn = nn.ModuleList()
        enc_down = nn.ModuleList()
        ch = base_ch
        cur_res = image_size
        # Channel count of each tensor forward() pushes on its skip stack.
        skips = [base_ch]
        for lvl, mult in enumerate(ch_mults):
            out_ch = base_ch * mult
            for _ in range(num_res_blocks):
                enc_res.append(ResBlock(ch, out_ch, time_dim, dropout))
                enc_attn.append(
                    SelfAttention(out_ch) if cur_res in attn_res
                    else nn.Identity()
                )
                skips.append(out_ch)
                ch = out_ch
            if lvl < self.num_levels - 1:
                enc_down.append(Downsample(ch))
                skips.append(ch)
                cur_res //= 2
        self.enc_res, self.enc_attn, self.enc_down = enc_res, enc_attn, enc_down

        # ── Middle ──
        self.mid_res1 = ResBlock(ch, ch, time_dim, dropout)
        self.mid_attn = SelfAttention(ch)
        self.mid_res2 = ResBlock(ch, ch, time_dim, dropout)

        # ── Decoder ──
        dec_res = nn.ModuleList()
        dec_attn = nn.ModuleList()
        dec_up = nn.ModuleList()
        # Skips are consumed in the reverse of the order they were produced.
        rev_skips = list(reversed(skips))
        sidx = 0
        for lvl, mult in enumerate(reversed(ch_mults)):
            out_ch = base_ch * mult
            for _ in range(num_res_blocks + 1):
                skip_ch = rev_skips[sidx]
                sidx += 1
                dec_res.append(
                    ResBlock(ch + skip_ch, out_ch, time_dim, dropout)
                )
                dec_attn.append(
                    SelfAttention(out_ch) if cur_res in attn_res
                    else nn.Identity()
                )
                ch = out_ch
            if lvl < self.num_levels - 1:
                dec_up.append(Upsample(ch))
                cur_res *= 2
        self.dec_res, self.dec_attn, self.dec_up = dec_res, dec_attn, dec_up

        self.out_norm = GroupNormFP32(32, ch)
        self.out_act = nn.SiLU()
        self.out_conv = nn.Conv2d(ch, in_ch, 3, padding=1)

    def forward(self, x, t):
        """Run the network on batch ``x`` at timesteps ``t``.

        Returns:
            A tensor with ``in_ch`` channels produced by the final conv.
        """
        t_emb = self.time_embed(t)
        h = self.init_conv(x)
        stack = [h]

        bidx = 0
        for lvl in range(self.num_levels):
            for _ in range(self.nrb):
                h = self.enc_res[bidx](h, t_emb)
                h = self.enc_attn[bidx](h)
                stack.append(h)
                bidx += 1
            if lvl < self.num_levels - 1:
                h = self.enc_down[lvl](h)
                stack.append(h)

        h = self.mid_res1(h, t_emb)
        h = self.mid_attn(h)
        h = self.mid_res2(h, t_emb)

        bidx = 0
        for lvl in range(self.num_levels):
            for _ in range(self.nrb + 1):
                # Concatenate the matching encoder activation on channels.
                h = torch.cat([h, stack.pop()], dim=1)
                h = self.dec_res[bidx](h, t_emb)
                h = self.dec_attn[bidx](h)
                bidx += 1
            if lvl < self.num_levels - 1:
                h = self.dec_up[lvl](h)

        return self.out_conv(self.out_act(self.out_norm(h)))


# ── Load model (once at startup) ──────────────────────────────────────────────
print("Downloading model weights...")
config_path = hf_hub_download(repo_id="AliMusaRizvi/ddpm",
                              filename="best_model_config.json")
weights_path = hf_hub_download(repo_id="AliMusaRizvi/ddpm",
                               filename="best_model.safetensors")

with open(config_path) as f:
    cfg = json.load(f)

model = UNet(
    in_ch=cfg["in_ch"],
    base_ch=cfg["base_ch"],
    ch_mults=tuple(cfg["ch_mults"]),
    attn_res=tuple(cfg["attn_res"]),
    num_res_blocks=cfg["num_res_blocks"],
    dropout=cfg["dropout"],
    image_size=cfg["image_size"],
)
# strict=True: every key in the checkpoint must match the module tree above.
model.load_state_dict(load_file(weights_path, device="cpu"), strict=True)
model.to(DEVICE).eval()
print(f"Model ready on {DEVICE}")

schedule = DiffusionSchedule(Config.TIMESTEPS, Config.BETA_SCHEDULE)


# ── Gradio function ───────────────────────────────────────────────────────────
def generate_gradio(num_steps: int = 200, seed: int = 42):
    """Sample one face and render a strip of intermediate denoising frames.

    Args:
        num_steps: DDIM step count forwarded to the sampler.
        seed: Seed for the global torch RNG.

    Returns:
        ``(final_path, steps_path)``: paths of the PNG with the generated
        image and the PNG with the denoising strip.
    """
    import os
    import tempfile

    torch.manual_seed(int(seed))
    # NOTE(review): the model is built from the downloaded cfg, while the
    # sample shape comes from Config — confirm the two always agree.
    shape = (1, 3, Config.IMAGE_SIZE, Config.IMAGE_SIZE)
    final_x, frames = schedule.ddim_sample(
        model,
        shape,
        num_steps=int(num_steps),
        eta=0.0,
        return_intermediates=True,
    )

    # A fresh directory per call: fixed /tmp file names would let
    # overlapping requests overwrite each other's images.
    out_dir = tempfile.mkdtemp(prefix="ddpm_")
    steps_path = os.path.join(out_dir, "ddpm_steps.png")
    final_path = os.path.join(out_dir, "ddpm_final.png")

    # Denoising strip (at most six frames).
    shown = frames[:6]
    # squeeze=False always yields a 2-D axes array, even for one frame.
    fig, axes = plt.subplots(1, len(shown), figsize=(18, 3.5), squeeze=False)
    try:
        for ax, frame in zip(axes[0], shown):
            ax.imshow(to_pil(frame[0]))
            ax.axis("off")
        # Use the figure handle rather than pyplot's implicit current figure.
        fig.suptitle("Denoising Steps", fontsize=12)
        fig.tight_layout()
        fig.savefig(steps_path, bbox_inches="tight", dpi=100)
    finally:
        # Release the figure even if rendering or saving fails.
        plt.close(fig)

    Image.fromarray(to_pil(final_x[0])).save(final_path)
    return final_path, steps_path


# ── Gradio UI ─────────────────────────────────────────────────────────────────
with gr.Blocks(title="DDPM - CelebA-HQ Face Generator") as demo:
    gr.Markdown(
        "## DDPM - Unconditional Face Generation\n"
        "Generates a face from pure Gaussian noise using the trained diffusion model."
    )
    with gr.Row():
        steps_slider = gr.Slider(50, 400, value=200, step=50, label="DDIM Steps")
        seed_slider = gr.Slider(0, 9999, value=42, step=1, label="Random Seed")
    run_btn = gr.Button("Generate Image", variant="primary")
    with gr.Row():
        out_final = gr.Image(label="Generated Face")
        out_steps = gr.Image(label="Denoising Process")
    run_btn.click(
        fn=generate_gradio,
        inputs=[steps_slider, seed_slider],
        outputs=[out_final, out_steps],
    )

demo.launch()