"""
app.py — DDPM Image Generation Demo
Deploy on Hugging Face Spaces (SDK: gradio)

Repository structure expected:
.
├── app.py            ← this file
├── requirements.txt
└── ddpm_model.pth    ← your trained weights (upload via git-lfs)
"""

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import torchvision.utils as vutils
import gradio as gr

# ──────────────────────────────────────────────────────────────
# 1. CONFIGURATION  (must match your training config exactly)
# ──────────────────────────────────────────────────────────────
IMG_SIZE = 128          # change to 256 if you trained at 256
BASE_CHANNELS = 64
TIME_EMB_DIM = 256
T = 300                 # total diffusion timesteps
BETA_START = 1e-4
BETA_END = 0.02
MODEL_PATH = "ddpm_model.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# ──────────────────────────────────────────────────────────────
# 2. MODEL ARCHITECTURE  (identical to training notebook)
#    Attribute names below define the state_dict keys — do not rename.
# ──────────────────────────────────────────────────────────────
class SinusoidalTimeEmbedding(nn.Module):
    """
    Encodes integer timestep t into a fixed-dimensional vector using
    sine / cosine positional encoding, then projects it through an MLP.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Map a 1-D batch of timesteps ``t`` to (batch, dim) embeddings."""
        half = self.dim // 2
        # Divisor is (half - 1), as in the training notebook; changing it
        # would change the embedding the trained MLP expects.
        freq = torch.exp(
            -math.log(10_000) * torch.arange(half, device=t.device) / (half - 1)
        )
        args = t[:, None].float() * freq[None, :]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        return self.mlp(emb)


class ResidualBlock(nn.Module):
    """Conv residual block with time-embedding injection (scale + shift)."""

    def __init__(self, in_ch: int, out_ch: int, time_emb_dim: int,
                 groups: int = 8, dropout: float = 0.1):
        super().__init__()
        # Projects the time embedding to a per-channel (scale, shift) pair.
        self.time_proj = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_ch * 2))
        self.norm1 = nn.GroupNorm(groups, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.norm2 = nn.GroupNorm(groups, out_ch)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # 1x1 conv only when the channel count changes, so the skip path matches.
        self.shortcut = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        h = self.conv1(F.silu(self.norm1(x)))
        scale, shift = self.time_proj(t_emb).chunk(2, dim=-1)
        # "+ 1" so a zero scale leaves the features unchanged.
        h = h * (scale[:, :, None, None] + 1) + shift[:, :, None, None]
        h = self.conv2(self.dropout(F.silu(self.norm2(h))))
        return h + self.shortcut(x)


class Downsample(nn.Module):
    """Halves spatial resolution via strided convolution."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class Upsample(nn.Module):
    """Doubles spatial resolution via nearest-neighbour interpolation + conv."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(F.interpolate(x, scale_factor=2, mode="nearest"))


class UNet(nn.Module):
    """
    Simplified U-Net for DDPM noise prediction.

    Channel progression: base → 2·base → 4·base (encoder), mirrored in the
    decoder (64 → 128 → 256 with the default base of 64). Three Downsample
    stages mean the input height/width must be divisible by 8.
    """

    def __init__(self, in_channels: int = 3, base_channels: int = 64,
                 time_emb_dim: int = 256):
        super().__init__()
        ch, ch2, ch4 = base_channels, base_channels * 2, base_channels * 4
        T_DIM = time_emb_dim

        # Time embedding
        self.time_emb = SinusoidalTimeEmbedding(T_DIM)
        self.init_conv = nn.Conv2d(in_channels, ch, 3, padding=1)

        # Encoder
        self.enc1_res1 = ResidualBlock(ch, ch, T_DIM)
        self.enc1_res2 = ResidualBlock(ch, ch, T_DIM)
        self.down1 = Downsample(ch)
        self.enc2_res1 = ResidualBlock(ch, ch2, T_DIM)
        self.enc2_res2 = ResidualBlock(ch2, ch2, T_DIM)
        self.down2 = Downsample(ch2)
        self.enc3_res1 = ResidualBlock(ch2, ch4, T_DIM)
        self.enc3_res2 = ResidualBlock(ch4, ch4, T_DIM)
        self.down3 = Downsample(ch4)

        # Bottleneck
        self.mid_res1 = ResidualBlock(ch4, ch4, T_DIM)
        self.mid_res2 = ResidualBlock(ch4, ch4, T_DIM)

        # Decoder (input channels = upsampled features + skip connection)
        self.up3 = Upsample(ch4)
        self.dec3_res1 = ResidualBlock(ch4 + ch4, ch4, T_DIM)
        self.dec3_res2 = ResidualBlock(ch4, ch4, T_DIM)
        self.up2 = Upsample(ch4)
        self.dec2_res1 = ResidualBlock(ch4 + ch2, ch2, T_DIM)
        self.dec2_res2 = ResidualBlock(ch2, ch2, T_DIM)
        self.up1 = Upsample(ch2)
        self.dec1_res1 = ResidualBlock(ch2 + ch, ch, T_DIM)
        self.dec1_res2 = ResidualBlock(ch, ch, T_DIM)

        # Output
        self.out_norm = nn.GroupNorm(8, ch)
        self.out_conv = nn.Conv2d(ch, in_channels, 1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Predict the noise in ``x`` at timesteps ``t``; output matches ``x``'s shape."""
        t_emb = self.time_emb(t)
        x0 = self.init_conv(x)

        e1 = self.enc1_res2(self.enc1_res1(x0, t_emb), t_emb)     # H
        e1d = self.down1(e1)
        e2 = self.enc2_res2(self.enc2_res1(e1d, t_emb), t_emb)    # H/2
        e2d = self.down2(e2)
        e3 = self.enc3_res2(self.enc3_res1(e2d, t_emb), t_emb)    # H/4
        e3d = self.down3(e3)                                       # H/8

        b = self.mid_res2(self.mid_res1(e3d, t_emb), t_emb)

        d3 = self.up3(b)
        d3 = self.dec3_res2(self.dec3_res1(torch.cat([d3, e3], dim=1), t_emb), t_emb)
        d2 = self.up2(d3)
        d2 = self.dec2_res2(self.dec2_res1(torch.cat([d2, e2], dim=1), t_emb), t_emb)
        d1 = self.up1(d2)
        d1 = self.dec1_res2(self.dec1_res1(torch.cat([d1, e1], dim=1), t_emb), t_emb)

        return self.out_conv(F.silu(self.out_norm(d1)))


# ──────────────────────────────────────────────────────────────
# 3. NOISE SCHEDULE  (pre-computed tensors on DEVICE)
# ──────────────────────────────────────────────────────────────
betas = torch.linspace(BETA_START, BETA_END, T).to(DEVICE)
alphas = 1.0 - betas
alpha_hat = torch.cumprod(alphas, dim=0)
sqrt_1m_ah = torch.sqrt(1.0 - alpha_hat)


# ──────────────────────────────────────────────────────────────
# 4. LOAD MODEL WEIGHTS
# ──────────────────────────────────────────────────────────────
# Key prefixes added by training-time wrappers:
#   nn.DataParallel / DistributedDataParallel → "module."
#   torch.compile                             → "_orig_mod."
_WRAPPER_PREFIXES = ("module.", "_orig_mod.")


def _strip_wrapper_prefixes(state_dict):
    """Return *state_dict* with leading wrapper prefixes removed from its keys.

    Only *leading* prefixes are stripped (repeatedly, to handle nested
    wrappers). A key that merely contains ``"module."`` further in is left
    intact. If no key carries a wrapper prefix, the input mapping is returned
    unchanged, so any attached ``_metadata`` survives.
    """
    if not any(key.startswith(_WRAPPER_PREFIXES) for key in state_dict):
        return state_dict

    cleaned = {}
    for key, value in state_dict.items():
        # Each pass removes at least one prefix, so this loop terminates.
        while key.startswith(_WRAPPER_PREFIXES):
            for prefix in _WRAPPER_PREFIXES:
                key = key.removeprefix(prefix)
        cleaned[key] = value
    return cleaned


model = UNet(
    in_channels=3,
    base_channels=BASE_CHANNELS,
    time_emb_dim=TIME_EMB_DIM,
).to(DEVICE)

# NOTE: torch.load unpickles the file — only ship checkpoints you trust.
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
state_dict = _strip_wrapper_prefixes(state_dict)
model.load_state_dict(state_dict)
model.eval()
print(f"[INFO] Model loaded from '{MODEL_PATH}' on {DEVICE}")


# ──────────────────────────────────────────────────────────────
# 5. HELPERS: tensor → PIL
# ──────────────────────────────────────────────────────────────
def tensor_to_pil(t: torch.Tensor) -> Image.Image:
    """Convert a (3, H, W) or (1, 3, H, W) tensor in [-1, 1] to a uint8 PIL image."""
    arr = (
        t.squeeze().cpu().clamp(-1, 1)
        .add(1).div(2)          # → [0, 1]
        .mul(255).byte()
        .permute(1, 2, 0)       # → (H, W, 3)
        .numpy()
    )
    return Image.fromarray(arr)


def _frames_to_grid_pil(frames: list[torch.Tensor]) -> Image.Image:
    """Tile a non-empty list of (1, 3, H, W) frames in [-1, 1] into one row image."""
    grid_tensor = torch.cat(frames, dim=0)      # (n, 3, H, W)
    grid = vutils.make_grid(
        grid_tensor.clamp(-1, 1),
        nrow=len(frames),
        normalize=True,
        value_range=(-1, 1),
    )
    return Image.fromarray(
        (grid.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    )


# ──────────────────────────────────────────────────────────────
# 6. GENERATION FUNCTION  (called by Gradio)
# ──────────────────────────────────────────────────────────────
@torch.no_grad()
def generate_image(n_vis_steps: int = 7) -> tuple[Image.Image, Image.Image]:
    """
    Run the DDPM reverse process from t = T-1 down to t = 1.

    Args:
        n_vis_steps : how many intermediate frames to show in the
                      denoising-steps grid (evenly spaced across the
                      sampled timesteps). Values <= 0 show no grid;
                      the final image is returned in its place.

    Returns:
        final_pil : PIL image of the final generated output
        steps_pil : PIL image showing the denoising progression grid
    """
    # int() because the UI slider may deliver a float; negative counts would
    # make np.linspace raise, so treat them as "no intermediate frames".
    n_frames = max(int(n_vis_steps), 0)

    x = torch.randn(1, 3, IMG_SIZE, IMG_SIZE, device=DEVICE)

    # Timesteps at which we capture intermediate frames
    capture_at = set(np.linspace(T - 1, 1, n_frames, dtype=int).tolist())

    frames: list[torch.Tensor] = []
    # Timestep 0 is never fed to the model. This presumably mirrors a
    # training loop that sampled t from [1, T) — confirm against the notebook.
    for t_val in reversed(range(1, T)):
        t_tensor = torch.full((1,), t_val, device=DEVICE, dtype=torch.long)

        # U-Net predicts the noise at this timestep
        eps_pred = model(x, t_tensor)

        # DDPM reverse update:
        #   mean = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_hat_t) * eps)
        coeff = betas[t_val] / sqrt_1m_ah[t_val]
        mean = (1.0 / torch.sqrt(alphas[t_val])) * (x - coeff * eps_pred)

        if t_val > 1:
            # sigma_t^2 = beta_t
            x = mean + torch.sqrt(betas[t_val]) * torch.randn_like(x)
        else:
            x = mean  # final step: no extra noise

        if t_val in capture_at:
            frames.append(x.clone().cpu())

    # ── Final generated image ────────────────────────────────
    final_pil = tensor_to_pil(x)

    # ── Intermediate steps grid ──────────────────────────────
    steps_pil = _frames_to_grid_pil(frames) if frames else final_pil

    return final_pil, steps_pil


# ──────────────────────────────────────────────────────────────
# 7. GRADIO INTERFACE
# ──────────────────────────────────────────────────────────────
with gr.Blocks(title="DDPM Image Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🖼️ DDPM Image Generator
        Generates a **new image from pure Gaussian noise** using a Denoising
        Diffusion Probabilistic Model trained from scratch in PyTorch.

        Click **Generate** to run the full reverse diffusion process.
        The right panel shows intermediate denoising steps so you can watch
        the image emerge from noise.
        """
    )

    with gr.Row():
        n_steps_slider = gr.Slider(
            minimum=4,
            maximum=12,
            value=7,
            step=1,
            label="Number of intermediate steps to visualise",
        )

    with gr.Row():
        btn = gr.Button("✨ Generate Image", variant="primary", scale=1)

    with gr.Row():
        out_final = gr.Image(
            label="Final Generated Image",
            type="pil",
            height=IMG_SIZE * 2,
        )
        out_steps = gr.Image(
            label="Intermediate Denoising Steps (noise → image)",
            type="pil",
        )

    btn.click(
        fn=generate_image,
        inputs=[n_steps_slider],
        outputs=[out_final, out_steps],
    )

    gr.Markdown(
        """
        ---
        **Model:** Custom U-Net (64→128→256 channels) trained with MSE loss on image noise.

        **Assignment:** Generative AI (AI4009) — Spring 2026, NUCES.
        """
    )


if __name__ == "__main__":
    demo.launch()