"""
app.py — DDPM Image Generation Demo
Deploy on Hugging Face Spaces (SDK: gradio)

Repository structure expected:
.
├── app.py            ← this file
├── requirements.txt
└── ddpm_model.pth    ← your trained weights (upload via git-lfs)
"""
|
|
| import math |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from PIL import Image |
| import torchvision.utils as vutils |
| import gradio as gr |
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# Hyper-parameters.
# BASE_CHANNELS and TIME_EMB_DIM fix the parameter shapes of the UNet that is
# instantiated further down, so they must agree with the checkpoint at
# MODEL_PATH (load_state_dict is strict). T / BETA_* presumably also need to
# match the schedule used in training -- TODO confirm against training script.
# ---------------------------------------------------------------------------
IMG_SIZE: int = 128            # side length (px) of the square generated image
BASE_CHANNELS: int = 64        # UNet width at the first resolution level
TIME_EMB_DIM: int = 256        # width of the timestep embedding

T: int = 300                   # number of diffusion steps
BETA_START: float = 0.0001     # first beta of the linear schedule
BETA_END: float = 0.02         # last beta of the linear schedule

MODEL_PATH: str = "ddpm_model.pth"

# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE: str = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
|
|
|
|
| |
| |
| |
|
|
class SinusoidalTimeEmbedding(nn.Module):
    """
    Encode integer timesteps as sinusoidal features, then project them through an MLP.

    The raw encoding concatenates ``sin(t * f_k)`` and ``cos(t * f_k)`` for
    ``dim // 2`` geometrically spaced frequencies ``f_k`` (from 1 down to
    1/10_000 when ``dim >= 4``). For odd ``dim`` the encoding is zero-padded by
    one column so that it always matches the MLP's input width.

    Args:
        dim: width of both the sinusoidal encoding and the returned embedding.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Embed a batch of timesteps.

        Args:
            t: timesteps of shape ``(B,)``; a 0-dim scalar is treated as ``B = 1``.

        Returns:
            Tensor of shape ``(B, dim)``.
        """
        if t.dim() == 0:
            t = t[None]
        half = self.dim // 2
        # With half == 1 the denominator (half - 1) would be 0 and produce NaN.
        # For half >= 2 this is exactly (half - 1), so trained weights still match.
        denom = max(half - 1, 1)
        freq = torch.exp(
            -math.log(10_000) * torch.arange(half, device=t.device) / denom
        )
        args = t[:, None].float() * freq[None, :]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        if self.dim % 2 == 1:
            # sin/cos give 2 * (dim // 2) columns; pad one zero column for odd dim.
            emb = F.pad(emb, (0, 1))
        return self.mlp(emb)
|
|
|
|
class ResidualBlock(nn.Module):
    """
    Two-convolution residual block conditioned on a time embedding.

    The time embedding is mapped to a per-channel ``(scale, shift)`` pair that
    modulates the first convolution's output as ``h * (1 + scale) + shift``.
    The skip path is a 1x1 convolution when ``in_ch != out_ch``, else identity.

    Args:
        in_ch: input channels (must be divisible by ``groups`` for GroupNorm).
        out_ch: output channels (must be divisible by ``groups`` for GroupNorm).
        time_emb_dim: width of the incoming time embedding.
        groups: number of GroupNorm groups.
        dropout: dropout probability applied before the second convolution.
    """

    def __init__(self, in_ch: int, out_ch: int, time_emb_dim: int,
                 groups: int = 8, dropout: float = 0.1):
        super().__init__()
        # Attribute names are state_dict keys; keep them (and this creation
        # order) stable so saved checkpoints keep loading.
        self.time_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, 2 * out_ch),
        )
        self.norm1 = nn.GroupNorm(groups, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(groups, out_ch)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        if in_ch == out_ch:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        """Apply the block to feature map ``x`` conditioned on ``t_emb``."""
        h = self.conv1(F.silu(self.norm1(x)))

        # First half of the projection is the scale, second half the shift.
        film = self.time_proj(t_emb)
        scale, shift = film.chunk(2, dim=-1)
        h = h * (1 + scale[:, :, None, None]) + shift[:, :, None, None]

        h = self.norm2(h)
        h = F.silu(h)
        h = self.dropout(h)
        h = self.conv2(h)
        return self.shortcut(x) + h
|
|
|
|
class Downsample(nn.Module):
    """Shrink H and W by 2x (ceil for odd sizes) with a 3x3, stride-2, padding-1 conv."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` at half resolution with the same channel count."""
        downsampled = self.conv(x)
        return downsampled
|
|
|
|
class Upsample(nn.Module):
    """Grow H and W by 2x: nearest-neighbour resize, then a 3x3 same-padding conv."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` at double resolution with the same channel count."""
        enlarged = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(enlarged)
|
|
|
|
class UNet(nn.Module):
    """
    Noise-prediction U-Net for DDPM.

    Encoder: three resolution levels with ``base_channels * (1, 2, 4)`` channels
    (64 -> 128 -> 256 with the defaults). Each level has two residual blocks
    followed by a stride-2 downsample. Bottleneck: two residual blocks.
    Decoder: mirrors the encoder. Each level upsamples, concatenates the
    matching encoder activation on the channel axis (skip connection) and
    applies two residual blocks.

    The input is halved three times and re-joined to the skips with
    ``torch.cat``, so H and W must be divisible by 8.

    Args:
        in_channels: image channels; the output (predicted noise) has the same count.
        base_channels: channel width of the first resolution level.
        time_emb_dim: width of the timestep embedding fed to every residual block.
    """

    def __init__(self, in_channels: int = 3,
                 base_channels: int = 64,
                 time_emb_dim: int = 256):
        super().__init__()
        c1 = base_channels
        c2 = 2 * base_channels
        c3 = 4 * base_channels
        td = time_emb_dim

        # Attribute names are state_dict keys -- do not rename or reorder.
        self.time_emb = SinusoidalTimeEmbedding(td)
        self.init_conv = nn.Conv2d(in_channels, c1, kernel_size=3, padding=1)

        # Encoder level 1 (full resolution).
        self.enc1_res1 = ResidualBlock(c1, c1, td)
        self.enc1_res2 = ResidualBlock(c1, c1, td)
        self.down1 = Downsample(c1)

        # Encoder level 2 (1/2 resolution).
        self.enc2_res1 = ResidualBlock(c1, c2, td)
        self.enc2_res2 = ResidualBlock(c2, c2, td)
        self.down2 = Downsample(c2)

        # Encoder level 3 (1/4 resolution).
        self.enc3_res1 = ResidualBlock(c2, c3, td)
        self.enc3_res2 = ResidualBlock(c3, c3, td)
        self.down3 = Downsample(c3)

        # Bottleneck (1/8 resolution).
        self.mid_res1 = ResidualBlock(c3, c3, td)
        self.mid_res2 = ResidualBlock(c3, c3, td)

        # Decoder: input channels = upsampled path + skip connection.
        self.up3 = Upsample(c3)
        self.dec3_res1 = ResidualBlock(c3 + c3, c3, td)
        self.dec3_res2 = ResidualBlock(c3, c3, td)

        self.up2 = Upsample(c3)
        self.dec2_res1 = ResidualBlock(c3 + c2, c2, td)
        self.dec2_res2 = ResidualBlock(c2, c2, td)

        self.up1 = Upsample(c2)
        self.dec1_res1 = ResidualBlock(c2 + c1, c1, td)
        self.dec1_res2 = ResidualBlock(c1, c1, td)

        # Output head: back to image channels.
        self.out_norm = nn.GroupNorm(8, c1)
        self.out_conv = nn.Conv2d(c1, in_channels, kernel_size=1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Predict the noise component of ``x`` at timesteps ``t``."""
        emb = self.time_emb(t)

        def two_blocks(first: nn.Module, second: nn.Module, inp: torch.Tensor) -> torch.Tensor:
            # Every stage is a pair of time-conditioned residual blocks.
            return second(first(inp, emb), emb)

        h = self.init_conv(x)

        # Encoder: keep each level's pre-downsample activation for the skips.
        skip1 = two_blocks(self.enc1_res1, self.enc1_res2, h)
        skip2 = two_blocks(self.enc2_res1, self.enc2_res2, self.down1(skip1))
        skip3 = two_blocks(self.enc3_res1, self.enc3_res2, self.down2(skip2))

        h = two_blocks(self.mid_res1, self.mid_res2, self.down3(skip3))

        # Decoder: upsample, concatenate the matching skip, refine.
        h = two_blocks(self.dec3_res1, self.dec3_res2,
                       torch.cat([self.up3(h), skip3], dim=1))
        h = two_blocks(self.dec2_res1, self.dec2_res2,
                       torch.cat([self.up2(h), skip2], dim=1))
        h = two_blocks(self.dec1_res1, self.dec1_res2,
                       torch.cat([self.up1(h), skip1], dim=1))

        return self.out_conv(F.silu(self.out_norm(h)))
|
|
|
|
| |
| |
| |
# Linear DDPM noise schedule, indexed by timestep 0 .. T-1.
betas = torch.linspace(BETA_START, BETA_END, T).to(DEVICE)  # beta_t
alphas = 1.0 - betas                                          # alpha_t = 1 - beta_t
alpha_hat = alphas.cumprod(dim=0)                             # running product of alpha_0..alpha_t
sqrt_1m_ah = (1.0 - alpha_hat).sqrt()                         # sqrt(1 - alpha_hat_t)
|
|
|
|
| |
| |
| |
# Build the network from the config hyper-parameters. load_state_dict below is
# strict, so any shape or key mismatch with the checkpoint fails loudly here.
model = UNet(
    in_channels=3,
    base_channels=BASE_CHANNELS,
    time_emb_dim=TIME_EMB_DIM,
).to(DEVICE)

# map_location remaps the stored tensors onto DEVICE, so a checkpoint saved on
# a GPU still loads on a CPU-only Space.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)

# Checkpoints saved from an nn.DataParallel / DDP wrapper typically prefix
# every key with "module.". Strip only that leading prefix; str.replace would
# also delete "module." wherever else it appears inside a key.
if any(k.startswith("module.") for k in state_dict):
    state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}

model.load_state_dict(state_dict)
model.eval()  # inference mode: disables the Dropout inside every ResidualBlock
print(f"[INFO] Model loaded from '{MODEL_PATH}' on {DEVICE}")
|
|
|
|
| |
| |
| |
def tensor_to_pil(t: torch.Tensor) -> Image.Image:
    """
    Convert an image tensor with values in [-1, 1] to a uint8 PIL image.

    Accepted shapes are ``(C, H, W)``, a single-item batch ``(1, C, H, W)``, or
    ``(H, W)`` (treated as one channel). Values are clamped to [-1, 1], mapped
    linearly to [0, 255] and rounded to the nearest integer. Truncating instead
    would bias every pixel down by about half a level. A one-channel input
    becomes a greyscale image.

    Raises:
        ValueError: if the shape is none of the accepted forms.
    """
    if t.dim() == 4 and t.shape[0] == 1:
        t = t[0]  # drop only the batch axis; squeeze() would also drop H/W == 1
    elif t.dim() == 2:
        t = t[None]
    if t.dim() != 3:
        raise ValueError(
            f"expected a (C, H, W) or (1, C, H, W) tensor, got shape {tuple(t.shape)}"
        )
    arr = (
        t.detach().float().cpu().clamp(-1, 1)
        .add(1).mul(127.5)            # [-1, 1] -> [0, 255]
        .round().to(torch.uint8)
        .permute(1, 2, 0)             # CHW -> HWC, the layout PIL expects
        .numpy()
    )
    if arr.shape[2] == 1:
        arr = arr[:, :, 0]            # PIL wants (H, W) for greyscale, not (H, W, 1)
    return Image.fromarray(arr)
|
|
|
|
| |
| |
| |
@torch.no_grad()
def generate_image(n_vis_steps: int = 7) -> tuple[Image.Image, Image.Image]:
    """
    Sample one image by running the DDPM reverse process from pure noise.

    Args:
        n_vis_steps: number of intermediate frames for the denoising-steps
            grid, spaced evenly over the reverse steps. The last frame is the
            final image. The value goes through ``int()``, so floats are
            accepted. Values <= 0 disable the grid, and the final image is then
            returned in both slots.

    Returns:
        ``(final_pil, steps_pil)``: the generated image and a single-row grid
        of the captured frames, ordered from noisiest to cleanest.
    """
    # A negative count would make np.linspace raise; treat it as "no frames".
    n_frames = max(int(n_vis_steps), 0)

    x = torch.randn(1, 3, IMG_SIZE, IMG_SIZE, device=DEVICE)

    # Timesteps after which x is snapshotted. The set removes any duplicates
    # that integer truncation in linspace might produce.
    capture_at = set(np.linspace(T - 1, 1, n_frames, dtype=int).tolist())
    frames: list[torch.Tensor] = []

    # NOTE(review): t runs T-1 .. 1, so timestep 0 is never evaluated and the
    # final update (t == 1) adds no noise. This is consistent with a model
    # trained on t drawn from [1, T). Confirm against the training script
    # before changing the range.
    for t_val in reversed(range(1, T)):
        t_tensor = torch.full((1,), t_val, device=DEVICE, dtype=torch.long)
        eps_pred = model(x, t_tensor)

        # Reverse-step mean:
        # 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_hat_t) * eps_pred)
        coeff = betas[t_val] / sqrt_1m_ah[t_val]
        mean = (1.0 / torch.sqrt(alphas[t_val])) * (x - coeff * eps_pred)

        # Add sigma_t * z with sigma_t^2 = beta_t on all but the last step.
        if t_val > 1:
            x = mean + torch.sqrt(betas[t_val]) * torch.randn_like(x)
        else:
            x = mean

        if t_val in capture_at:
            frames.append(x.clone().cpu())

    final_pil = tensor_to_pil(x)
    if not frames:
        return final_pil, final_pil

    # Tile the frames in one row. Padding of -1 renders black once
    # tensor_to_pil maps [-1, 1] to [0, 255], the same conversion used for
    # the final image, so both outputs share a single pixel pipeline.
    grid = vutils.make_grid(
        torch.cat(frames, dim=0).clamp(-1, 1),
        nrow=len(frames),
        pad_value=-1.0,
    )
    return final_pil, tensor_to_pil(grid)
|
|
|
|
| |
| |
| |
# Gradio UI. User-facing strings previously held mojibake (UTF-8 read as a
# Greek code page); the intended emoji, arrows and dashes are restored here.
with gr.Blocks(title="DDPM Image Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🖼️ DDPM Image Generator
        Generates a **new image from pure Gaussian noise** using a
        Denoising Diffusion Probabilistic Model trained from scratch in PyTorch.

        Click **Generate** to run the full reverse diffusion process.
        The right panel shows intermediate denoising steps so you can
        watch the image emerge from noise.
        """
    )

    with gr.Row():
        n_steps_slider = gr.Slider(
            minimum=4,
            maximum=12,
            value=7,
            step=1,
            label="Number of intermediate steps to visualise",
        )

    with gr.Row():
        btn = gr.Button("✨ Generate Image", variant="primary", scale=1)

    with gr.Row():
        out_final = gr.Image(
            label="Final Generated Image",
            type="pil",
            height=IMG_SIZE * 2,
        )
        out_steps = gr.Image(
            label="Intermediate Denoising Steps (noise → image)",
            type="pil",
        )

    # generate_image returns (final image, steps grid) in this output order.
    btn.click(
        fn=generate_image,
        inputs=[n_steps_slider],
        outputs=[out_final, out_steps],
    )

    # Blank line between the two footer lines so Markdown renders them as
    # separate paragraphs instead of joining them.
    gr.Markdown(
        """
        ---
        **Model:** Custom U-Net (64→128→256 channels) trained with MSE loss on image noise.

        **Assignment:** Generative AI (AI4009) — Spring 2026, NUCES.
        """
    )


if __name__ == "__main__":
    demo.launch()
|
|