ddpm / app.py
SotaSF
Add files
621fad6
"""
DDPM - CelebA-HQ Face Generator
HuggingFace Spaces app — AliMusaRizvi/ddpm
"""
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import gradio as gr
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# ── Config ────────────────────────────────────────────────────────────────────
class Config:
    """Static settings for the demo, grouped as class-level constants."""

    # Diffusion process (read when the schedule is built).
    TIMESTEPS = 1000
    BETA_SCHEDULE = "cosine"

    # Side length of the square image requested from the sampler.
    IMAGE_SIZE = 128

    # Architecture values. NOTE(review): the visible code builds the model
    # from the downloaded JSON config, not from these - confirm they agree.
    BASE_CHANNELS = 128
    CHANNEL_MULTS = (1, 2, 2, 4)
    ATTN_RESOLUTIONS = (16,)
    NUM_RES_BLOCKS = 2
    DROPOUT = 0.1


# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# ── Helpers ───────────────────────────────────────────────────────────────────
def to_pil(tensor: torch.Tensor) -> np.ndarray:
    """Convert a 3-D image tensor in [-1, 1] to a uint8 array.

    The first axis is moved to the end (callers pass channel-first images),
    values outside [-1, 1] are clamped, and the result is mapped to [0, 255].
    The uint8 cast truncates rather than rounds.
    """
    clipped = tensor.detach().cpu().clamp(-1, 1)
    scaled = (clipped + 1) / 2 * 255
    channels_last = scaled.permute(1, 2, 0)
    return channels_last.byte().numpy()
# ── Noise schedule ────────────────────────────────────────────────────────────
def build_cosine_betas(timesteps: int, s: float = 0.008) -> torch.Tensor:
    """Return ``timesteps`` betas derived from a squared-cosine alpha-bar curve.

    The cumulative alpha curve is evaluated in float64 on ``timesteps + 1``
    points, normalised by its first value, and converted to per-step betas
    via the ratio of neighbouring points. Betas are clamped to
    [1e-5, 0.9999] and returned as float32.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64)
    phase = ((grid / timesteps) + s) / (1 + s) * math.pi / 2
    cum_alpha = torch.cos(phase) ** 2
    cum_alpha = cum_alpha / cum_alpha[0]
    step_ratio = cum_alpha[1:] / cum_alpha[:-1]
    betas = 1.0 - step_ratio
    return betas.clamp(1e-5, 0.9999).float()
class DiffusionSchedule:
    """Noise-schedule tensors and a DDIM sampler built on top of them.

    Attributes:
        T: number of diffusion timesteps.
        betas: per-step betas returned by ``build_cosine_betas``.
        alphas_bar: cumulative product of ``1 - betas``.
        alphas_bar_prev: ``alphas_bar`` shifted right by one, with a leading 1.0.
        sqrt_abar: element-wise square root of ``alphas_bar``.
        sqrt_one_minus_abar: element-wise square root of ``1 - alphas_bar``.
    """

    def __init__(self, timesteps: int = 1000, schedule: str = "cosine"):
        """Precompute the schedule tensors for ``timesteps`` steps.

        Only the cosine schedule is implemented. Any other ``schedule`` value
        still yields cosine betas (as before) but now emits a warning instead
        of being ignored silently.
        """
        if schedule != "cosine":
            import warnings

            warnings.warn(
                f"Unsupported schedule {schedule!r}; falling back to 'cosine'.",
                stacklevel=2,
            )
        self.T = timesteps
        betas = build_cosine_betas(timesteps)
        alphas = 1.0 - betas
        abar = torch.cumprod(alphas, dim=0)
        abar_prev = F.pad(abar[:-1], (1, 0), value=1.0)
        self.betas = betas
        self.alphas_bar = abar
        self.alphas_bar_prev = abar_prev
        self.sqrt_abar = abar.sqrt()
        self.sqrt_one_minus_abar = (1 - abar).sqrt()

    def _build_seq(self, num_steps):
        """Return exactly ``num_steps`` increasing timesteps in ``[0, T)``.

        The previous ``range(0, T, T // num_steps)`` form produced more steps
        than requested whenever ``num_steps`` did not divide ``T`` (e.g. 400
        of 1000 gave 500 steps) and failed for ``num_steps > T``. For exact
        divisors the result is identical to the old sequence.

        Raises:
            ValueError: if ``num_steps`` is smaller than 1.
        """
        num_steps = int(num_steps)
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")
        # There are only T distinct timesteps to choose from.
        num_steps = min(num_steps, self.T)
        return [i * self.T // num_steps for i in range(num_steps)]

    def _to(self, device):
        """Move every schedule tensor to ``device`` and return ``self``."""
        for attr in ("betas", "alphas_bar", "alphas_bar_prev",
                     "sqrt_abar", "sqrt_one_minus_abar"):
            setattr(self, attr, getattr(self, attr).to(device))
        return self

    def _ddim_step_backward(self, model, xt, t, t_prev, eta=0.0):
        """Take one DDIM step from timestep ``t`` to ``t_prev``.

        Args:
            model: callable ``model(xt, t_batch)`` returning the predicted noise.
            xt: current noisy batch.
            t: current timestep index.
            t_prev: target timestep index; a negative value means "final
                step", for which the previous alpha-bar is taken as 1.
            eta: stochasticity factor; 0 gives the deterministic update.

        Returns:
            The batch at timestep ``t_prev``.
        """
        # Follow the input's device rather than a module-level global so the
        # step works wherever the caller placed ``xt``.
        tbatch = torch.full((xt.shape[0],), t, device=xt.device, dtype=torch.long)
        eps = model(xt, tbatch)
        ab_t = self.alphas_bar[t]
        ab_prev = self.alphas_bar[t_prev] if t_prev >= 0 else torch.ones_like(ab_t)
        # Predicted clean image, clamped to the valid pixel range.
        x0_pred = ((xt - (1 - ab_t).sqrt() * eps) / ab_t.sqrt()).clamp(-1, 1)
        sigma = (eta
                 * ((1 - ab_prev) / (1 - ab_t)).sqrt()
                 * (1 - ab_t / ab_prev).sqrt())
        # clamp(min=0) guards the square root against tiny negative values.
        dir_xt = (1 - ab_prev - sigma ** 2).clamp(min=0).sqrt() * eps
        # No fresh noise on the final step or in the deterministic case.
        noise = torch.randn_like(xt) if (eta > 0 and t_prev >= 0) else torch.zeros_like(xt)
        return ab_prev.sqrt() * x0_pred + dir_xt + sigma * noise

    @torch.no_grad()
    def ddim_sample(self, model, shape, num_steps=200, eta=0.0,
                    return_intermediates=False):
        """Generate samples by running DDIM from pure Gaussian noise.

        Args:
            model: noise-prediction callable ``model(xt, t_batch)``.
            shape: shape of the noise tensor to start from.
            num_steps: number of DDIM steps to run.
            eta: stochasticity factor forwarded to every step.
            return_intermediates: also collect roughly five snapshots taken
                along the trajectory (the last one is the final step).

        Returns:
            The final batch clamped to [-1, 1] on the CPU, or
            ``(final, frames)`` when ``return_intermediates`` is true.
        """
        # Schedule tensors and the initial noise live on the module-level DEVICE.
        self._to(DEVICE)
        seq = self._build_seq(num_steps)
        xt = torch.randn(shape, device=DEVICE)
        frames = []
        for i in tqdm(reversed(range(len(seq))), total=len(seq),
                      desc="DDIM sampling", leave=False):
            t = seq[i]
            t_prev = seq[i - 1] if i > 0 else -1
            xt = self._ddim_step_backward(model, xt, t, t_prev, eta)
            if return_intermediates and i % max(1, len(seq) // 5) == 0:
                frames.append(xt.clamp(-1, 1).cpu().clone())
        result = xt.clamp(-1, 1).cpu()
        return (result, frames) if return_intermediates else result
# ── Model architecture ────────────────────────────────────────────────────────
class GroupNormFP32(nn.GroupNorm):
    """GroupNorm that normalises a float32 copy of the input.

    The result is cast back to the dtype the caller passed in.
    """

    def forward(self, x):
        caller_dtype = x.dtype
        normalised = nn.GroupNorm.forward(self, x.float())
        return normalised.to(caller_dtype)
class TimeEmbedding(nn.Module):
    """Sinusoidal timestep features followed by a two-layer MLP.

    The sinusoidal vector has ``2 * (time_dim // 8)`` entries (sines then
    cosines), matching the ``time_dim // 4`` inputs of the first linear
    layer when ``time_dim`` is a multiple of 8; the MLP outputs ``time_dim``.
    """

    def __init__(self, time_dim):
        super().__init__()
        self.time_dim = time_dim
        sinusoid_width = time_dim // 4
        layers = [
            nn.Linear(sinusoid_width, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, t):
        n_freqs = self.time_dim // 8
        positions = torch.arange(n_freqs, device=t.device)
        # Geometric frequency ladder from 1 down towards 1/10000.
        freqs = torch.exp(-math.log(10000) * positions / n_freqs)
        angles = t.unsqueeze(1).float() * freqs.unsqueeze(0)
        features = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return self.mlp(features)
class ResBlock(nn.Module):
    """Residual block whose hidden activations are modulated by a time embedding.

    ``block1`` maps ``in_ch`` to ``out_ch``; the time embedding is projected
    to a per-channel (scale, shift) pair, both clamped to [-3, 3]; ``block2``
    refines the result, which is added to a (possibly 1x1-projected) copy of
    the input.
    """

    def __init__(self, in_ch, out_ch, time_dim, dropout=0.1):
        super().__init__()
        self.block1 = nn.Sequential(
            GroupNormFP32(32, in_ch),
            nn.SiLU(),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        )
        # One linear layer emits scale and shift side by side.
        self.time_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_dim, out_ch * 2),
        )
        self.block2 = nn.Sequential(
            GroupNormFP32(32, out_ch),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        )
        # A 1x1 projection is only needed when the channel count changes.
        if in_ch != out_ch:
            self.skip_proj = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        else:
            self.skip_proj = nn.Identity()

    def forward(self, x, t_emb):
        hidden = self.block1(x)
        modulation = self.time_proj(t_emb)[:, :, None, None]
        scale, shift = modulation.chunk(2, dim=1)
        scale = scale.clamp(-3, 3)
        shift = shift.clamp(-3, 3)
        hidden = hidden * (1 + scale) + shift
        refined = self.block2(hidden)
        return refined + self.skip_proj(x)
class SelfAttention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    The output projection is zero-initialised, so a freshly constructed
    module returns its input unchanged through the residual connection.
    """

    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.norm = GroupNormFP32(32, channels)
        self.qkv = nn.Linear(channels, channels * 3, bias=True)
        self.out_proj = nn.Linear(channels, channels, bias=True)
        nn.init.xavier_uniform_(self.qkv.weight, gain=0.02)
        for param in (self.qkv.bias, self.out_proj.weight, self.out_proj.bias):
            nn.init.zeros_(param)

    def _split_heads(self, flat, batch):
        """Reshape (batch, tokens, channels) to (batch, heads, tokens, head_dim)."""
        per_head = flat.view(batch, -1, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        b, c, h, w = x.shape
        normed = self.norm(x).to(x.dtype)
        # Flatten the spatial grid into a token axis: (b, h*w, c).
        tokens = normed.view(b, c, -1).permute(0, 2, 1)
        q, k, v = (
            self._split_heads(part, b)
            for part in self.qkv(tokens).chunk(3, dim=-1)
        )
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
        merged = attended.transpose(1, 2).reshape(b, -1, c)
        projected = self.out_proj(merged).permute(0, 2, 1).view(b, c, h, w)
        return projected + x
class Downsample(nn.Module):
    """Stride-2 3x3 convolution that keeps the channel count.

    With padding 1 an even spatial size is halved exactly.
    """

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=ch,
            out_channels=ch,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        reduced = self.conv(x)
        return reduced
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 convolution.

    The channel count is unchanged.
    """

    def __init__(self, ch):
        super().__init__()
        enlarge = nn.Upsample(scale_factor=2, mode="nearest")
        smooth = nn.Conv2d(ch, ch, kernel_size=3, padding=1)
        self.seq = nn.Sequential(enlarge, smooth)

    def forward(self, x):
        out = x
        for layer in self.seq:
            out = layer(out)
        return out
class UNet(nn.Module):
    """Time-conditioned encoder/decoder network with skip connections.

    The encoder runs ``num_res_blocks`` residual blocks per level and
    downsamples between levels; the decoder mirrors it with
    ``num_res_blocks + 1`` blocks per level, each consuming one stored
    encoder activation. Self-attention is inserted at levels whose tracked
    resolution is listed in ``attn_res``, and always in the middle stage.

    NOTE: submodule names and registration order define the state-dict keys
    that the checkpoint is loaded against with ``strict=True``.
    """

    def __init__(self, in_ch=3, base_ch=128, ch_mults=(1,2,2,4),
                 attn_res=(16,), num_res_blocks=2, dropout=0.1, image_size=128):
        """Build the network.

        Args:
            in_ch: channels of the input (and of the output).
            base_ch: channel width that ``ch_mults`` multiplies.
            ch_mults: per-level channel multipliers; its length sets the
                number of resolution levels.
            attn_res: resolutions at which self-attention is added.
            num_res_blocks: residual blocks per encoder level.
            dropout: dropout probability passed to every ResBlock.
            image_size: input resolution, used only to track which levels
                fall on an ``attn_res`` entry.
        """
        super().__init__()
        self.num_levels = len(ch_mults)
        self.nrb = num_res_blocks
        time_dim = base_ch * 4
        self.time_embed = TimeEmbedding(time_dim)
        self.init_conv = nn.Conv2d(in_ch, base_ch, 3, padding=1)
        # ---- Encoder ----
        enc_res, enc_attn, enc_down = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
        # ``skips`` records the channel count of every activation forward()
        # will push, in push order; the first entry is the init_conv output.
        ch, cur_res, skips = base_ch, image_size, [base_ch]
        for lvl, mult in enumerate(ch_mults):
            out_ch = base_ch * mult
            for _ in range(num_res_blocks):
                enc_res.append(ResBlock(ch, out_ch, time_dim, dropout))
                enc_attn.append(SelfAttention(out_ch) if cur_res in attn_res else nn.Identity())
                skips.append(out_ch); ch = out_ch
            # Every level except the last ends with a downsample, whose
            # output is also kept as a skip.
            if lvl < self.num_levels - 1:
                enc_down.append(Downsample(ch)); skips.append(ch); cur_res //= 2
        self.enc_res, self.enc_attn, self.enc_down = enc_res, enc_attn, enc_down
        # ---- Middle ----
        self.mid_res1 = ResBlock(ch, ch, time_dim, dropout)
        self.mid_attn = SelfAttention(ch)
        self.mid_res2 = ResBlock(ch, ch, time_dim, dropout)
        # ---- Decoder ----
        dec_res, dec_attn, dec_up = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
        # Skips are consumed last-in-first-out, so walk the recorded channel
        # counts in reverse.
        rev_skips, sidx = list(reversed(skips)), 0
        for lvl, mult in enumerate(reversed(ch_mults)):
            out_ch = base_ch * mult
            for _ in range(num_res_blocks + 1):
                skip_ch = rev_skips[sidx]; sidx += 1
                # The block input is the current features concatenated with one skip.
                dec_res.append(ResBlock(ch + skip_ch, out_ch, time_dim, dropout))
                dec_attn.append(SelfAttention(out_ch) if cur_res in attn_res else nn.Identity())
                ch = out_ch
            if lvl < self.num_levels - 1:
                dec_up.append(Upsample(ch)); cur_res *= 2
        self.dec_res, self.dec_attn, self.dec_up = dec_res, dec_attn, dec_up
        # ---- Output head ----
        self.out_norm = GroupNormFP32(32, ch)
        self.out_act = nn.SiLU()
        self.out_conv = nn.Conv2d(ch, in_ch, 3, padding=1)

    def forward(self, x, t):
        """Run the network on batch ``x`` conditioned on timesteps ``t``.

        Returns a tensor with ``in_ch`` channels produced by ``out_conv``.
        """
        t_emb = self.time_embed(t)
        h = self.init_conv(x)
        # Activations saved for the decoder; pushed in the same order that
        # __init__ recorded their channel counts.
        stack = [h]
        bidx = 0
        for lvl in range(self.num_levels):
            for _ in range(self.nrb):
                h = self.enc_res[bidx](h, t_emb)
                h = self.enc_attn[bidx](h)
                stack.append(h); bidx += 1
            if lvl < self.num_levels - 1:
                h = self.enc_down[lvl](h); stack.append(h)
        h = self.mid_res1(h, t_emb)
        h = self.mid_attn(h)
        h = self.mid_res2(h, t_emb)
        # Decoder block indices restart at zero.
        bidx = 0
        for lvl in range(self.num_levels):
            for _ in range(self.nrb + 1):
                # Pop the most recent skip and join it along the channel axis.
                h = torch.cat([h, stack.pop()], dim=1)
                h = self.dec_res[bidx](h, t_emb)
                h = self.dec_attn[bidx](h); bidx += 1
            if lvl < self.num_levels - 1:
                h = self.dec_up[lvl](h)
        return self.out_conv(self.out_act(self.out_norm(h)))
# ── Load model (once at startup) ──────────────────────────────────────────────
# Hub repository that hosts both the model config and the weights.
_REPO_ID = "AliMusaRizvi/ddpm"

print("Downloading model weights...")
config_path = hf_hub_download(repo_id=_REPO_ID, filename="best_model_config.json")
weights_path = hf_hub_download(repo_id=_REPO_ID, filename="best_model.safetensors")

# Read the JSON config as UTF-8 explicitly instead of relying on the
# platform's locale-dependent default encoding.
with open(config_path, encoding="utf-8") as f:
    cfg = json.load(f)

# The architecture comes from the downloaded config so that it matches the
# checkpoint being loaded.
model = UNet(
    in_ch=cfg["in_ch"],
    base_ch=cfg["base_ch"],
    ch_mults=tuple(cfg["ch_mults"]),
    attn_res=tuple(cfg["attn_res"]),
    num_res_blocks=cfg["num_res_blocks"],
    dropout=cfg["dropout"],
    image_size=cfg["image_size"],
)
# Load on the CPU first; strict=True fails on any missing or unexpected key.
model.load_state_dict(load_file(weights_path, device="cpu"), strict=True)
model.to(DEVICE).eval()
print(f"Model ready on {DEVICE}")

schedule = DiffusionSchedule(Config.TIMESTEPS, Config.BETA_SCHEDULE)
# ── Gradio function ───────────────────────────────────────────────────────────
def generate_gradio(num_steps: int = 200, seed: int = 42):
    """Sample one image with DDIM and render a strip of intermediate frames.

    Args:
        num_steps: number of DDIM steps to request; cast to ``int`` before use.
        seed: value passed to ``torch.manual_seed`` before sampling; cast to
            ``int`` before use.

    Returns:
        ``(final_path, steps_path)``: paths of the PNG holding the final
        image and the PNG holding the denoising strip.
    """
    torch.manual_seed(int(seed))
    shape = (1, 3, Config.IMAGE_SIZE, Config.IMAGE_SIZE)
    final_x, frames = schedule.ddim_sample(
        model, shape,
        num_steps=int(num_steps), eta=0.0,
        return_intermediates=True,
    )

    # Denoising strip: at most six panels. Fall back to the final image so
    # the figure always has at least one column.
    shown = frames[:6] or [final_x]

    # NOTE(review): both output paths are fixed, so overlapping requests
    # would overwrite each other's files - confirm requests are serialised.
    steps_path = "/tmp/ddpm_steps.png"
    # squeeze=False always yields a 2-D axes array, so one panel needs no
    # special case.
    fig, axes = plt.subplots(1, len(shown), figsize=(18, 3.5), squeeze=False)
    try:
        for ax, frame in zip(axes[0], shown):
            ax.imshow(to_pil(frame[0]))
            ax.axis("off")
        # Work on the figure handle instead of pyplot's implicit current
        # figure, and always release it even if rendering or saving fails.
        fig.suptitle("Denoising Steps", fontsize=12)
        fig.tight_layout()
        fig.savefig(steps_path, bbox_inches="tight", dpi=100)
    finally:
        plt.close(fig)

    final_path = "/tmp/ddpm_final.png"
    Image.fromarray(to_pil(final_x[0])).save(final_path)
    return final_path, steps_path
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Layout: intro text, a row of two sliders, the trigger button, then a row
# with the two output images.
with gr.Blocks(title="DDPM - CelebA-HQ Face Generator") as demo:
    gr.Markdown(
        value=(
            "## DDPM - Unconditional Face Generation\n"
            "Generates a face from pure Gaussian noise using the trained diffusion model."
        )
    )

    with gr.Row():
        steps_slider = gr.Slider(
            minimum=50, maximum=400, value=200, step=50, label="DDIM Steps",
        )
        seed_slider = gr.Slider(
            minimum=0, maximum=9999, value=42, step=1, label="Random Seed",
        )

    run_btn = gr.Button(value="Generate Image", variant="primary")

    with gr.Row():
        out_final = gr.Image(label="Generated Face")
        out_steps = gr.Image(label="Denoising Process")

    # Slider values go in, the two image paths come out.
    run_btn.click(
        generate_gradio,
        [steps_slider, seed_slider],
        [out_final, out_steps],
    )

demo.launch()