# Uploaded by waleed-12
# Commit message: Upload 2 files
# Commit: 48074a3 (verified)
"""
app.py — DDPM Image Generation Demo
Deploy on Hugging Face Spaces (SDK: gradio)
Repository structure expected:
.
├── app.py ← this file
├── requirements.txt
└── ddpm_model.pth ← your trained weights (upload via git-lfs)
"""
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import torchvision.utils as vutils
import gradio as gr
# ──────────────────────────────────────────────────────────────
# 1. CONFIGURATION (must match your training config exactly)
# ──────────────────────────────────────────────────────────────
# Side length of the square noise tensor sampled in generate_image; also
# used (doubled) as the display height of the final-image widget.
IMG_SIZE = 128 # change to 256 if you trained at 256
# Forwarded to UNet(base_channels=...); the encoder widths are 1x, 2x and 4x this.
BASE_CHANNELS = 64
# Forwarded to UNet(time_emb_dim=...); width of the timestep embedding vector.
TIME_EMB_DIM = 256
# Length of the beta schedule; the sampling loop walks t from T - 1 down to 1.
T = 300 # total diffusion timesteps
# First and last values of the linearly spaced beta schedule.
BETA_START = 1e-4
BETA_END = 0.02
# Path handed to torch.load when the weights are read at import time.
MODEL_PATH = "ddpm_model.pth"
# Device for the model, the schedule tensors and the sampled noise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ──────────────────────────────────────────────────────────────
# 2. MODEL ARCHITECTURE (identical to training notebook)
# ──────────────────────────────────────────────────────────────
class SinusoidalTimeEmbedding(nn.Module):
    """
    Encode integer timesteps as sinusoidal features, then project with an MLP.

    The raw encoding concatenates ``sin`` and ``cos`` of the timestep scaled
    by ``dim // 2`` geometrically spaced frequencies (from 1 down to 1/10000),
    giving a vector of width ``dim`` that is passed through
    ``Linear(dim, 4*dim) -> SiLU -> Linear(4*dim, dim)``.

    Args:
        dim: Width of the embedding. Must be an even integer >= 2, because
            the sine and cosine halves together must fill exactly ``dim``
            features for the first linear layer.

    Raises:
        ValueError: If ``dim`` is odd or smaller than 2.
    """

    def __init__(self, dim: int):
        super().__init__()
        # An odd dim would yield dim - 1 features and only fail later, inside
        # the first Linear layer, with an opaque shape error.
        if dim < 2 or dim % 2 != 0:
            raise ValueError(f"dim must be an even integer >= 2, got {dim}")
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Embed a batch of timesteps.

        Args:
            t: 1-D tensor of timesteps (indexed as ``t[:, None]``).

        Returns:
            Tensor of shape ``(len(t), dim)``.
        """
        half = self.dim // 2
        # With dim == 2 there is a single frequency and half - 1 == 0; clamp
        # the denominator so that frequency is exp(0) = 1 instead of NaN.
        # For half > 1 this is the same formula as before.
        denom = max(half - 1, 1)
        freq = torch.exp(
            -math.log(10_000) * torch.arange(half, device=t.device) / denom
        )
        args = t[:, None].float() * freq[None, :]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        return self.mlp(emb)
class ResidualBlock(nn.Module):
    """
    Two-convolution residual block modulated by a time embedding.

    The time embedding is projected to ``2 * out_ch`` values that are split
    into a per-channel scale and shift, applied to the output of the first
    convolution. A 1x1 convolution adapts the skip path when the channel
    count changes; otherwise the input is added unchanged.

    Args:
        in_ch: Channels of the input feature map.
        out_ch: Channels of the output feature map.
        time_emb_dim: Width of the time-embedding vector.
        groups: Number of groups for both GroupNorm layers.
        dropout: Dropout probability applied before the second convolution.
    """

    def __init__(self, in_ch: int, out_ch: int, time_emb_dim: int,
                 groups: int = 8, dropout: float = 0.1):
        super().__init__()
        # Sub-modules are created in a fixed order so that parameter names
        # and seeded initialisation stay stable.
        self.time_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, out_ch * 2),
        )
        self.norm1 = nn.GroupNorm(groups, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(groups, out_ch)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        """Return ``conv path(x, t_emb) + shortcut(x)``."""
        # First half: norm -> SiLU -> conv.
        hidden = self.norm1(x)
        hidden = F.silu(hidden)
        hidden = self.conv1(hidden)

        # Time conditioning: split the projection into scale and shift and
        # broadcast both over the two spatial axes.
        modulation = self.time_proj(t_emb)
        scale, shift = modulation.chunk(2, dim=-1)
        gain = scale[:, :, None, None] + 1
        hidden = hidden * gain + shift[:, :, None, None]

        # Second half: norm -> SiLU -> dropout -> conv.
        hidden = self.norm2(hidden)
        hidden = F.silu(hidden)
        hidden = self.dropout(hidden)
        hidden = self.conv2(hidden)

        return hidden + self.shortcut(x)
class Downsample(nn.Module):
    """Reduce spatial size with a 3x3 convolution of stride 2 (channels unchanged)."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the strided convolution to ``x``."""
        reduced = self.conv(x)
        return reduced
class Upsample(nn.Module):
    """Enlarge by 2x with nearest-neighbour interpolation, then a 3x3 convolution."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Interpolate ``x`` to twice its spatial size and convolve the result."""
        enlarged = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(enlarged)
class UNet(nn.Module):
    """
    Simplified U-Net that predicts noise for DDPM sampling.

    Three encoder stages widen the features (base → 2x base → 4x base), each
    followed by a Downsample. A two-block bottleneck sits at the lowest
    resolution, and three decoder stages mirror the encoder. Each decoder
    stage upsamples, concatenates the matching encoder output along the
    channel axis, and applies two residual blocks.

    Args:
        in_channels: Channels of the input image; the output has the same count.
        base_channels: Width of the first encoder stage.
        time_emb_dim: Width of the timestep embedding fed to every residual block.
    """

    def __init__(self, in_channels: int = 3,
                 base_channels: int = 64,
                 time_emb_dim: int = 256):
        super().__init__()
        c1 = base_channels
        c2 = base_channels * 2
        c4 = base_channels * 4
        tdim = time_emb_dim

        # Attribute names and creation order are kept fixed: they determine
        # the state-dict keys the checkpoint is matched against.

        # Timestep embedding and input stem
        self.time_emb = SinusoidalTimeEmbedding(tdim)
        self.init_conv = nn.Conv2d(in_channels, c1, kernel_size=3, padding=1)

        # Encoder stage 1
        self.enc1_res1 = ResidualBlock(c1, c1, tdim)
        self.enc1_res2 = ResidualBlock(c1, c1, tdim)
        self.down1 = Downsample(c1)
        # Encoder stage 2
        self.enc2_res1 = ResidualBlock(c1, c2, tdim)
        self.enc2_res2 = ResidualBlock(c2, c2, tdim)
        self.down2 = Downsample(c2)
        # Encoder stage 3
        self.enc3_res1 = ResidualBlock(c2, c4, tdim)
        self.enc3_res2 = ResidualBlock(c4, c4, tdim)
        self.down3 = Downsample(c4)

        # Bottleneck
        self.mid_res1 = ResidualBlock(c4, c4, tdim)
        self.mid_res2 = ResidualBlock(c4, c4, tdim)

        # Decoder stage 3 (input channels = upsampled + skip)
        self.up3 = Upsample(c4)
        self.dec3_res1 = ResidualBlock(c4 + c4, c4, tdim)
        self.dec3_res2 = ResidualBlock(c4, c4, tdim)
        # Decoder stage 2
        self.up2 = Upsample(c4)
        self.dec2_res1 = ResidualBlock(c4 + c2, c2, tdim)
        self.dec2_res2 = ResidualBlock(c2, c2, tdim)
        # Decoder stage 1
        self.up1 = Upsample(c2)
        self.dec1_res1 = ResidualBlock(c2 + c1, c1, tdim)
        self.dec1_res2 = ResidualBlock(c1, c1, tdim)

        # Output head
        self.out_norm = nn.GroupNorm(8, c1)
        self.out_conv = nn.Conv2d(c1, in_channels, kernel_size=1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Predict the noise in ``x`` at timestep ``t``.

        Args:
            x: Image batch with ``in_channels`` channels.
            t: Timestep tensor passed to the time embedding.

        Returns:
            Tensor with ``in_channels`` channels.
        """
        temb = self.time_emb(t)

        # Encoder: keep each stage's pre-downsample output as a skip tensor.
        skip1 = self.init_conv(x)
        skip1 = self.enc1_res1(skip1, temb)
        skip1 = self.enc1_res2(skip1, temb)

        skip2 = self.down1(skip1)
        skip2 = self.enc2_res1(skip2, temb)
        skip2 = self.enc2_res2(skip2, temb)

        skip3 = self.down2(skip2)
        skip3 = self.enc3_res1(skip3, temb)
        skip3 = self.enc3_res2(skip3, temb)

        # Bottleneck at the lowest resolution.
        feat = self.down3(skip3)
        feat = self.mid_res1(feat, temb)
        feat = self.mid_res2(feat, temb)

        # Decoder: upsample, concatenate the matching skip, refine.
        feat = self.up3(feat)
        feat = torch.cat([feat, skip3], dim=1)
        feat = self.dec3_res1(feat, temb)
        feat = self.dec3_res2(feat, temb)

        feat = self.up2(feat)
        feat = torch.cat([feat, skip2], dim=1)
        feat = self.dec2_res1(feat, temb)
        feat = self.dec2_res2(feat, temb)

        feat = self.up1(feat)
        feat = torch.cat([feat, skip1], dim=1)
        feat = self.dec1_res1(feat, temb)
        feat = self.dec1_res2(feat, temb)

        # Output head: norm -> SiLU -> 1x1 conv back to in_channels.
        feat = F.silu(self.out_norm(feat))
        return self.out_conv(feat)
# ──────────────────────────────────────────────────────────────
# 3. NOISE SCHEDULE (pre-computed tensors on DEVICE)
# ──────────────────────────────────────────────────────────────
# Linearly spaced betas, built on the default device and then moved to DEVICE.
betas = torch.linspace(start=BETA_START, end=BETA_END, steps=T).to(device=DEVICE)
# Per-step alpha = 1 - beta.
alphas = 1.0 - betas
# Running product of the alphas up to and including each step.
alpha_hat = alphas.cumprod(dim=0)
# sqrt(1 - alpha_hat); divides beta in the reverse-step noise coefficient.
sqrt_1m_ah = (1.0 - alpha_hat).sqrt()
# ──────────────────────────────────────────────────────────────
# 4. LOAD MODEL WEIGHTS
# ──────────────────────────────────────────────────────────────
# Build the network with the configured sizes and move it to DEVICE.
model = UNet(
    in_channels=3,
    base_channels=BASE_CHANNELS,
    time_emb_dim=TIME_EMB_DIM,
).to(DEVICE)
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust. If the installed torch supports it, consider weights_only=True.
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
# Strip the DataParallel "module." prefix if present. removeprefix only drops
# a leading occurrence; str.replace would also rewrite "module." appearing
# later in a key (e.g. "submodule.weight") and corrupt the name.
if any(k.startswith("module.") for k in state_dict):
    state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
# Inference mode: disables dropout inside the residual blocks.
model.eval()
print(f"[INFO] Model loaded from '{MODEL_PATH}' on {DEVICE}")
# ──────────────────────────────────────────────────────────────
# 5. HELPER: tensor → PIL
# ──────────────────────────────────────────────────────────────
def tensor_to_pil(t: torch.Tensor) -> Image.Image:
    """
    Turn an image tensor with values in [-1, 1] into a uint8 PIL image.

    All size-1 dimensions are squeezed away first, so both ``(3, H, W)`` and
    ``(1, 3, H, W)`` inputs work. Values outside [-1, 1] are clamped.
    """
    chw = t.squeeze().cpu().clamp(-1, 1)
    unit = (chw + 1) / 2                    # [-1, 1] → [0, 1]
    as_bytes = (unit * 255).to(torch.uint8)  # [0, 1] → 0..255 (truncating cast)
    hwc = as_bytes.permute(1, 2, 0)          # (3, H, W) → (H, W, 3)
    return Image.fromarray(hwc.numpy())
# ──────────────────────────────────────────────────────────────
# 6. GENERATION FUNCTION (called by Gradio)
# ──────────────────────────────────────────────────────────────
@torch.no_grad()
def generate_image(n_vis_steps: int = 7) -> tuple[Image.Image, Image.Image]:
    """
    Sample one image by running the DDPM reverse process from pure noise.

    Starts from Gaussian noise of shape ``(1, 3, IMG_SIZE, IMG_SIZE)`` and
    applies the reverse update for every timestep from ``T - 1`` down to
    ``1``, using the module-level ``model`` and schedule tensors. Fresh noise
    is added after every step except the last one (``t == 1``).

    Args:
        n_vis_steps: Number of intermediate frames to show in the progression
            strip, evenly spaced across the timesteps. Cast to ``int``;
            timesteps that coincide after rounding are captured only once.

    Returns:
        final_pil: PIL image of the final sample.
        steps_pil: PIL image with the captured frames side by side in one
            row, or the final image itself when no frame was captured.
    """
    sample = torch.randn(1, 3, IMG_SIZE, IMG_SIZE, device=DEVICE)

    # Timesteps after which the current sample is stored for the strip.
    snapshot_steps = {
        int(step)
        for step in np.linspace(T - 1, 1, int(n_vis_steps), dtype=int)
    }
    snapshots: list[torch.Tensor] = []

    for step in range(T - 1, 0, -1):
        step_batch = torch.full((1,), step, device=DEVICE, dtype=torch.long)

        # The network estimates the noise contained in the current sample.
        predicted_noise = model(sample, step_batch)

        # Mean of the reverse step:
        # (x - beta / sqrt(1 - alpha_hat) * eps) / sqrt(alpha)
        noise_weight = betas[step] / sqrt_1m_ah[step]
        inv_sqrt_alpha = 1.0 / torch.sqrt(alphas[step])
        posterior_mean = inv_sqrt_alpha * (sample - noise_weight * predicted_noise)

        if step == 1:
            # Last step: return the mean without adding noise.
            sample = posterior_mean
        else:
            sample = posterior_mean + torch.sqrt(betas[step]) * torch.randn_like(sample)

        if step in snapshot_steps:
            snapshots.append(sample.clone().cpu())

    # ── Final generated image ────────────────────────────────
    final_pil = tensor_to_pil(sample)

    # ── Intermediate steps grid ──────────────────────────────
    if not snapshots:
        return final_pil, final_pil

    stacked = torch.cat(snapshots, dim=0)  # (n, 3, H, W)
    strip = vutils.make_grid(
        stacked.clamp(-1, 1),
        nrow=len(snapshots),
        normalize=True,
        value_range=(-1, 1),
    )
    strip_array = (strip.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    steps_pil = Image.fromarray(strip_array)
    return final_pil, steps_pil
# ──────────────────────────────────────────────────────────────
# 7. GRADIO INTERFACE
# ──────────────────────────────────────────────────────────────
# Layout: intro text, a slider for the number of visualised steps, a generate
# button, and two image panels (final sample and denoising progression).
# User-facing strings use the intended Unicode characters (emoji, arrows,
# em dash) instead of their mis-decoded byte sequences.
with gr.Blocks(title="DDPM Image Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 🖼️ DDPM Image Generator
Generates a **new image from pure Gaussian noise** using a
Denoising Diffusion Probabilistic Model trained from scratch in PyTorch.
Click **Generate** to run the full reverse diffusion process.
The right panel shows intermediate denoising steps so you can
watch the image emerge from noise.
"""
    )
    with gr.Row():
        n_steps_slider = gr.Slider(
            minimum=4,
            maximum=12,
            value=7,
            step=1,
            label="Number of intermediate steps to visualise",
        )
    with gr.Row():
        btn = gr.Button("✨ Generate Image", variant="primary", scale=1)
    with gr.Row():
        out_final = gr.Image(
            label="Final Generated Image",
            type="pil",
            height=IMG_SIZE * 2,
        )
        out_steps = gr.Image(
            label="Intermediate Denoising Steps (noise → image)",
            type="pil",
        )
    # The slider value becomes generate_image's n_vis_steps; its two return
    # values fill the two image panels in order.
    btn.click(
        fn=generate_image,
        inputs=[n_steps_slider],
        outputs=[out_final, out_steps],
    )
    gr.Markdown(
        """
---
**Model:** Custom U-Net (64→128→256 channels) trained with MSE loss on image noise.
**Assignment:** Generative AI (AI4009) — Spring 2026, NUCES.
"""
    )
if __name__ == "__main__":
    demo.launch()