"""Gradio demo — DDIM Face Generation.

Single-page layout:
- Top: title + generate controls + output
- Middle: trajectory GIF + interpolation (collapsible)
- Bottom: how it works / architecture description
"""
from __future__ import annotations
import argparse
import os
import tempfile
from typing import Optional
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
import numpy as np
import torch
from PIL import Image
from sample import load_run
from utils.visualize import interpolate_latents, trajectory_to_gif, make_grid
# ---------------------------------------------------------------------------
# Global state — loaded once at startup
# ---------------------------------------------------------------------------
class State:
    """Everything the callbacks need, loaded once from a checkpoint.

    Attributes set by ``__init__``:
        device: MPS if available, otherwise CUDA if available, otherwise CPU.
        cfg, model, diffusion: the three values returned by ``load_run``.
        image_size, in_channels: copied from ``cfg`` for convenient access.
    """

    def __init__(self, ckpt_path: str, prefer_ema: bool = True):
        # Backend priority is MPS first, then CUDA, then CPU.
        if torch.backends.mps.is_available():
            backend = "mps"
        elif torch.cuda.is_available():
            backend = "cuda"
        else:
            backend = "cpu"
        self.device = torch.device(backend)

        loaded = load_run(ckpt_path, self.device, prefer_ema)
        self.cfg, self.model, self.diffusion = loaded

        config = self.cfg
        self.image_size = config.image_size
        self.in_channels = config.in_channels


# Module-wide singleton; stays None until main() assigns a State instance.
STATE: Optional[State] = None
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _seeded(seed: Optional[int]) -> torch.Generator:
g = torch.Generator(device="cpu")
if seed is not None and seed >= 0:
g.manual_seed(int(seed))
return g
def _grid_pil(samples: torch.Tensor, nrow: int) -> Image.Image:
    """Convert a batch of samples into one PIL image.

    The batch is moved to the CPU, handed to ``make_grid`` with the given
    ``nrow``, and the array it returns is wrapped with ``Image.fromarray``.
    """
    cpu_batch = samples.cpu()
    grid_array = make_grid(cpu_batch, nrow=nrow)
    return Image.fromarray(grid_array)
# ---------------------------------------------------------------------------
# Callbacks
# ---------------------------------------------------------------------------
def cb_generate(num: int, steps: int, seed: Optional[float]) -> Image.Image:
    """Sample ``num`` images with DDIM (``eta=0``) and return them as one grid.

    Args:
        num: Number of images to sample (batch size).
        steps: Number of DDIM steps, forwarded as ``num_steps``.
        seed: Seed from the UI, forwarded to ``_seeded``. ``None`` (an empty
            number field) is passed through instead of crashing in ``int()``.

    Returns:
        A PIL image with ``ceil(sqrt(num))`` samples per row.
    """
    s = STATE
    # An emptied gr.Number field arrives as None; int(None) would raise.
    g = _seeded(None if seed is None else int(seed))
    shape = (int(num), s.in_channels, s.image_size, s.image_size)
    # Draw the starting noise with the seeded generator, then move it to the
    # device the model runs on.
    x_T = torch.randn(*shape, generator=g).to(s.device)
    with torch.no_grad():
        out = s.diffusion.ddim_sample(
            s.model,
            shape,
            num_steps=int(steps),
            eta=0.0,
            x_T=x_T,
            device=s.device,
        )
    nrow = int(np.ceil(np.sqrt(num)))
    return _grid_pil(out, nrow)
def cb_trajectory(steps: int, seed: Optional[float]) -> str:
    """Render the denoising trajectory of a single sample as a GIF file.

    Args:
        steps: Number of DDIM steps, forwarded as ``num_steps``.
        seed: Seed from the UI, forwarded to ``_seeded``. ``None`` (an empty
            number field) is passed through instead of crashing in ``int()``.

    Returns:
        Path of the temporary ``.gif`` file that was written.
    """
    s = STATE
    # An emptied gr.Number field arrives as None; int(None) would raise.
    g = _seeded(None if seed is None else int(seed))
    shape = (1, s.in_channels, s.image_size, s.image_size)
    x_T = torch.randn(*shape, generator=g).to(s.device)
    with torch.no_grad():
        _, traj = s.diffusion.ddim_sample(
            s.model,
            shape,
            num_steps=int(steps),
            eta=0.0,
            x_T=x_T,
            device=s.device,
            return_trajectory=True,
            trajectory_stride=1,
        )
    # Only the unique name is needed here. delete=False keeps the file on
    # disk after the handle closes, because the path is returned to the
    # caller and must stay readable afterwards.
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp:
        gif_path = tmp.name
    trajectory_to_gif(traj, gif_path, fps=12)
    return gif_path
def cb_interpolate(
    frames: int,
    steps: int,
    seed_a: Optional[float],
    seed_b: Optional[float],
) -> Image.Image:
    """Decode a sequence of latents interpolated between two seeded noises.

    Args:
        frames: Number of interpolated latents (also the grid's row width).
        steps: Number of DDIM steps, forwarded as ``num_steps``.
        seed_a: Seed for the first endpoint, forwarded to ``_seeded``.
        seed_b: Seed for the second endpoint, forwarded to ``_seeded``.
            For both seeds ``None`` (an empty number field) is passed
            through instead of crashing in ``int()``.

    Returns:
        A PIL image with all ``frames`` samples on a single row.
    """
    s = STATE
    n_frames = int(frames)
    shape_one = (1, s.in_channels, s.image_size, s.image_size)

    # An emptied gr.Number field arrives as None; int(None) would raise.
    gen_a = _seeded(None if seed_a is None else int(seed_a))
    gen_b = _seeded(None if seed_b is None else int(seed_b))
    z1 = torch.randn(*shape_one, generator=gen_a)
    z2 = torch.randn(*shape_one, generator=gen_b)

    # squeeze(1) removes a size-1 axis from the helper's output, presumably
    # the per-endpoint batch axis. TODO confirm against interpolate_latents.
    latents = interpolate_latents(z1, z2, num_steps=n_frames).squeeze(1).to(s.device)
    with torch.no_grad():
        out = s.diffusion.ddim_sample(
            s.model,
            (n_frames, s.in_channels, s.image_size, s.image_size),
            num_steps=int(steps),
            eta=0.0,
            x_T=latents,
            device=s.device,
        )
    return _grid_pil(out, n_frames)
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Markdown body of the "How it works" accordion rendered by build_ui().
TECH_MD = """
## How it works
This demo runs a **DDIM (Denoising Diffusion Implicit Model)** trained from scratch — no pretrained weights, no diffusers library.
### The core idea
A diffusion model learns to reverse a noise process. During training, we take a real face and progressively corrupt it with Gaussian noise over T=1000 steps until it's pure noise. The model (a U-Net) learns to predict the noise added at each step. At inference, we start from pure random noise and run the reverse process — but with DDIM we can skip most steps, getting a good result in just 20–50 steps instead of 1000.
### Architecture
```
Input (noise + timestep t)
     │
┌────▼────┐
│  U-Net  │  Channels: [64, 128, 256, 256]
│         │  Self-attention at 8×8 and 16×16 resolution
│  Time   │  Sinusoidal time embedding → MLP → injected at every ResBlock
│  Embed  │  GroupNorm + SiLU activations throughout
└────┬────┘
     │
predicted ε (noise)
```
The U-Net has:
- **4 resolution levels** with strided conv downsampling / nearest-neighbour upsampling
- **Residual blocks** with time-step conditioning (FiLM-style additive injection)
- **Multi-head self-attention** at the two lowest resolutions (8×8, 16×16)
- **EMA weights** used for inference — a running exponential average of training weights that produces cleaner samples
### Training
- **Dataset:** CelebA-HQ — 30,000 aligned face photographs at 256×256, resized to 64×64
- **Hardware:** Apple Mac Mini M-series (MPS backend), no cloud GPU
- **Duration:** ~100 epochs, ~14 hours total
- **Optimizer:** AdamW (CPU-resident state to avoid MPS memory pressure)
- **Loss:** simple MSE between predicted and actual noise — `L = ||ε - ε_θ(x_t, t)||²`
- **Noise schedule:** linear β from 1×10⁻⁴ → 0.02 over T=1000 steps
### Sampling modes
| Mode | What it shows |
|------|--------------|
| **Generate** | New faces sampled from pure Gaussian noise via DDIM |
| **Trajectory** | The full denoising path animated as a GIF — from noise to face |
| **Interpolate** | Spherical linear interpolation (slerp) between two noise vectors, showing a smooth transition between two generated faces |
### DDIM speedup
Standard DDPM requires T=1000 sequential network passes. DDIM uses a non-Markovian sampler that achieves comparable quality in 20–50 steps — a **20–50× speedup** with no retraining.
### Built entirely from scratch
Every component is hand-written in PyTorch:
`attention.py` · `unet.py` · `diffusion.py` · `dataset.py` · `train.py`
No Hugging Face Diffusers, no guided-diffusion, no pre-trained encoders.
"""
def build_ui():
    """Assemble the single-page Gradio interface and return it.

    Reads ``STATE.cfg.timesteps`` to bound the step sliders, so ``STATE``
    must already hold a loaded ``State`` when this is called. Gradio is
    imported here rather than at module import time.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    import gradio as gr

    s = STATE
    max_steps = min(s.cfg.timesteps, 100)  # cap at 100 for CPU

    with gr.Blocks(title="DDIM Face Generation") as demo:
        gr.Markdown("""
# 🧠 DDIM Face Generation
**Denoising Diffusion Implicit Model trained from scratch on CelebA-HQ.**
Generates novel human faces by reversing a learned noise process — no pretrained weights used.
> ⏱️ Running on CPU — generation takes ~30–60 seconds. Use **seed ≥ 0** to reproduce results.
""")

        # ── Generate ──────────────────────────────────────────────────
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### ⚙️ Controls")
                num = gr.Slider(1, 9, value=4, step=1, label="Number of faces")
                steps = gr.Slider(10, max_steps, value=20, step=5,
                                  label="DDIM steps (more = sharper, slower)")
                seed = gr.Number(value=-1, label="Seed (-1 = random each time)")
                gen_btn = gr.Button("✨ Generate Faces", variant="primary", size="lg")
            with gr.Column(scale=2):
                gr.Markdown("### 🖼️ Output")
                gen_out = gr.Image(label="Generated faces", type="pil",
                                   show_label=False, height=400)
        gen_btn.click(cb_generate, [num, steps, seed], gen_out)

        gr.Markdown("---")

        # ── Trajectory & Interpolation (accordion) ────────────────────
        with gr.Accordion("🎞️ Denoising Trajectory (noise → face GIF)", open=False):
            gr.Markdown("Watch a single face emerge from pure Gaussian noise step by step.")
            with gr.Row():
                t_steps = gr.Slider(10, max_steps, value=20, step=5, label="Steps")
                t_seed = gr.Number(value=42, label="Seed")
            t_btn = gr.Button("Animate", variant="secondary")
            # cb_trajectory returns a file path, hence type="filepath".
            t_out = gr.Image(label="Denoising trajectory", type="filepath")
            t_btn.click(cb_trajectory, [t_steps, t_seed], t_out)

        with gr.Accordion("🔀 Latent Interpolation (face A → face B)", open=False):
            gr.Markdown(
                "Spherical linear interpolation (slerp) between two noise vectors — "
                "each column is a smooth blend between two independently sampled faces."
            )
            with gr.Row():
                i_frames = gr.Slider(4, 10, value=6, step=1, label="Frames")
                i_steps = gr.Slider(10, max_steps, value=20, step=5, label="DDIM steps")
                i_seed_a = gr.Number(value=0, label="Seed A")
                i_seed_b = gr.Number(value=7, label="Seed B")
            i_btn = gr.Button("Interpolate", variant="secondary")
            i_out = gr.Image(label="A ⟶ B interpolation", type="pil")
            i_btn.click(cb_interpolate, [i_frames, i_steps, i_seed_a, i_seed_b], i_out)

        gr.Markdown("---")

        # ── Tech description ──────────────────────────────────────────
        with gr.Accordion("📖 How it works — architecture, training & theory", open=False):
            gr.Markdown(TECH_MD)

        gr.Markdown(
            "<div style='text-align:center;color:#888;font-size:0.85em'>"
            "Built from scratch · PyTorch · CelebA-HQ · Apple Silicon · "
            "<a href='https://github.com/Gh-Novel/DDIM_Image_Generation' target='_blank'>GitHub</a>"
            "</div>"
        )
    return demo
# ---------------------------------------------------------------------------
# Default checkpoint: <directory of this file>/checkpoints/stage-64_best.pt
DEFAULT_CKPT = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "checkpoints",
    "stage-64_best.pt",
)


def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse the demo's command-line options.

    Args:
        argv: Explicit argument list to parse. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, matching a plain ``parse_args()``
            call; passing a list makes the parser usable from tests.

    Returns:
        Namespace with ``ckpt`` (str), ``no_ema`` (bool), ``share`` (bool)
        and ``port`` (int).
    """
    p = argparse.ArgumentParser(description="Gradio demo for DDIM face generation.")
    p.add_argument("--ckpt", default=DEFAULT_CKPT,
                   help="path to the checkpoint to load (default: %(default)s)")
    p.add_argument("--no-ema", action="store_true",
                   help="do not prefer the EMA weights when loading")
    p.add_argument("--share", action="store_true",
                   help="pass share=True to Gradio's launch()")
    p.add_argument("--port", type=int, default=7860,
                   help="server port (default: %(default)s)")
    return p.parse_args(argv)
def main():
    """Load the checkpoint into ``STATE``, build the UI and start the server."""
    global STATE

    opts = parse_args()
    use_ema = not opts.no_ema
    STATE = State(opts.ckpt, prefer_ema=use_ema)

    app = build_ui()
    app.queue()
    launch_options = {
        "server_name": "0.0.0.0",  # required for HF Spaces Docker
        "server_port": opts.port,
        "share": opts.share,
    }
    app.launch(**launch_options)


if __name__ == "__main__":
    main()
|