mini-ddpm / app.py
caixiaoshun's picture
Update app.py
2b3facb verified
# Gradio app: mini-DDPM
import math
import random
from typing import Optional, Dict
import gradio as gr
import numpy as np
from PIL import Image
import torch
import torchvision
from model import Unet
from scheduler import DDPMSchedule
# Pretrained mini-DDPM weights hosted on the Hugging Face Hub.
CHECKPOINT_URL = (
    "https://huggingface.co/caixiaoshun/mini-ddpm/resolve/main/checkpoints.pt"
)
# -----------------------------
# Utility functions
# -----------------------------
def seed_all(seed: Optional[int] = 42):
    """Seed Python, NumPy and PyTorch RNGs; ``None`` leaves them untouched."""
    if seed is None:
        return
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Trades exact reproducibility for speed; set to False to fully reproduce runs.
    torch.backends.cudnn.benchmark = True
# Per-device caches keyed by str(device) (e.g. "cpu", "cuda") so the model
# weights and the diffusion schedule are only constructed once per process.
_MODEL_CACHE: Dict[str, torch.nn.Module] = {}
_SCHED_CACHE: Dict[str, DDPMSchedule] = {}
def get_device() -> torch.device:
    """Select CUDA when available, otherwise fall back to the CPU."""
    name = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(name)
def load_model(device: torch.device) -> torch.nn.Module:
    """Return the cached UNet for *device*, downloading the weights on first use."""
    cache_key = str(device)
    cached = _MODEL_CACHE.get(cache_key)
    if cached is not None:
        return cached
    # First request for this device: fetch the checkpoint from the Hub.
    weights = torch.hub.load_state_dict_from_url(
        CHECKPOINT_URL, map_location=device, file_name="checkpoints.pt", progress=True
    )
    net = Unet().to(device)
    net.load_state_dict(weights)
    net.eval()
    _MODEL_CACHE[cache_key] = net
    return net
def load_scheduler(device: torch.device) -> DDPMSchedule:
    """Return a per-device cached DDPM schedule, creating it on first request."""
    cache_key = str(device)
    try:
        return _SCHED_CACHE[cache_key]
    except KeyError:
        schedule = DDPMSchedule().to(device)
        _SCHED_CACHE[cache_key] = schedule
        return schedule
# -----------------------------
# Sampling core
# -----------------------------
def sample_ddpm(
    num_samples: int = 9,
    seed: Optional[int] = 42,
    progress: Optional[gr.Progress] = None,
) -> Image.Image:
    """Run the full reverse DDPM chain and return the samples as one grid image.

    Args:
        num_samples: Requested number of images. Rounded to the nearest
            perfect square so the result tiles into a square grid.
        seed: RNG seed; ``None`` means non-deterministic sampling.
        progress: Optional Gradio progress callback, updated once per step.

    Returns:
        A PIL image containing the ``side x side`` grid of generated samples.
    """
    # Round to the NEAREST perfect square. The previous int(sqrt(n)) floored,
    # so e.g. 8 silently produced only 4 images; round() gives 9 instead.
    side = max(1, round(math.sqrt(max(1, num_samples))))
    num_samples = side * side
    device = get_device()
    seed_all(seed)
    model = load_model(device)  # returned in eval mode by load_model
    scheduler = load_scheduler(device)
    # Start from pure Gaussian noise in image space (N, 3, 128, 128).
    x_t = torch.randn(num_samples, 3, 128, 128, device=device)
    total = getattr(scheduler, "steps", 1000)
    with torch.no_grad():
        for t in range(total - 1, -1, -1):
            if progress is not None:
                progress((total - t) / total, desc=f"t = {t}")
            # NOTE(review): assumes the model/scheduler expect a (N, 1) long
            # timestep tensor — matches the original call; confirm against model.py.
            time = torch.full((num_samples, 1), t, device=device, dtype=torch.long)
            pred_noise = model(x_t, time)
            x_t = scheduler.step(x_t, pred_noise, time)
    # Map from model output range [-1, 1] to [0, 1] before image conversion.
    x_0 = (x_t + 1) / 2
    grid = torchvision.utils.make_grid(x_0, nrow=side, padding=2)
    grid = grid.clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy()
    return Image.fromarray((grid * 255).astype(np.uint8))
# -----------------------------
# Gradio UI
# -----------------------------
def ui_generate(num_samples, seed):
    """Gradio callback: normalize widget values, then run the sampler."""
    # A missing or negative seed means "use a random seed".
    if seed is None or int(seed) < 0:
        real_seed = None
    else:
        real_seed = int(seed)
    return sample_ddpm(
        num_samples=int(num_samples),
        seed=real_seed,
        progress=gr.Progress(),
    )
def build_demo() -> gr.Blocks:
    """Assemble the Gradio Blocks interface for the sampler."""
    with gr.Blocks(theme=gr.themes.Soft(), title="mini-DDPM") as demo:
        with gr.Row():
            with gr.Column(scale=2):
                result_image = gr.Image(type="pil", label="采样结果")
            with gr.Column(scale=1):
                sample_slider = gr.Slider(
                    minimum=1,
                    maximum=25,
                    step=1,
                    value=9,
                    label="样本数量",
                )
                seed_input = gr.Number(value=42, precision=0, label="随机种子")
                generate_btn = gr.Button("开始生成", variant="primary")
        generate_btn.click(
            fn=ui_generate,
            inputs=[sample_slider, seed_input],
            outputs=[result_image],
        )
    return demo
if __name__ == "__main__":
    demo = build_demo()
    # For remote deployment, pass server_name="0.0.0.0" to launch().
    demo.queue().launch()