# Gradio app: mini-DDPM
#
# Downloads a pretrained mini-DDPM checkpoint from the Hugging Face Hub and
# exposes ancestral (reverse-diffusion) sampling through a Gradio Blocks UI.
import math
import random
from typing import Dict, Optional

import gradio as gr
import numpy as np
import torch
import torchvision
from PIL import Image

from model import Unet
from scheduler import DDPMSchedule

CHECKPOINT_URL = (
    "https://huggingface.co/caixiaoshun/mini-ddpm/resolve/main/checkpoints.pt"
)


# -----------------------------
# Utilities
# -----------------------------
def seed_all(seed: Optional[int] = 42) -> None:
    """Seed Python, NumPy and PyTorch RNGs.

    Args:
        seed: Seed value; ``None`` leaves all RNG state untouched.
    """
    if seed is None:
        return
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Kept True for speed; set to False for full reproducibility.
    torch.backends.cudnn.benchmark = True


# Per-device caches so the checkpoint and schedule are built once per process.
_MODEL_CACHE: Dict[str, torch.nn.Module] = {}
_SCHED_CACHE: Dict[str, DDPMSchedule] = {}


def get_device() -> torch.device:
    """Return the CUDA device if available, otherwise CPU."""
    return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def load_model(device: torch.device) -> torch.nn.Module:
    """Download (or reuse from cache) the pretrained U-Net for *device*.

    The model is placed in eval mode before being cached and returned.
    """
    key = str(device)
    if key in _MODEL_CACHE:
        return _MODEL_CACHE[key]
    state_dict = torch.hub.load_state_dict_from_url(
        CHECKPOINT_URL, map_location=device, file_name="checkpoints.pt", progress=True
    )
    model = Unet().to(device)
    model.load_state_dict(state_dict)
    model.eval()
    _MODEL_CACHE[key] = model
    return model


def load_scheduler(device: torch.device) -> DDPMSchedule:
    """Build (or reuse from cache) the DDPM noise schedule on *device*."""
    key = str(device)
    if key in _SCHED_CACHE:
        return _SCHED_CACHE[key]
    sched = DDPMSchedule().to(device)
    _SCHED_CACHE[key] = sched
    return sched


# -----------------------------
# Sampling core
# -----------------------------
def sample_ddpm(
    num_samples: int = 9,
    seed: Optional[int] = 42,
    progress: Optional[gr.Progress] = None,
) -> Image.Image:
    """Run the full reverse DDPM chain and return the samples as one grid image.

    Args:
        num_samples: Requested sample count; snapped down to the nearest
            perfect square so the output grid is square.
        seed: RNG seed, or ``None`` for nondeterministic sampling.
        progress: Optional Gradio progress tracker, updated once per timestep.

    Returns:
        A PIL image containing a ``side x side`` grid of generated samples.
    """
    # Snap the sample count to the nearest (lower) square for a tidy grid.
    side = int(math.sqrt(max(1, num_samples)))
    num_samples = max(1, side * side)

    device = get_device()
    seed_all(seed)
    model = load_model(device)
    scheduler = load_scheduler(device)

    # Start from pure Gaussian noise; images are 3x128x128 in [-1, 1].
    x_t = torch.randn(num_samples, 3, 128, 128, device=device)
    total = getattr(scheduler, "steps", 1000)

    model.eval()
    with torch.no_grad():
        # Reverse diffusion: t = T-1 ... 0.
        for t in range(total - 1, -1, -1):
            if progress is not None:
                done = total - t
                progress(done / total, desc=f"t = {t}")
            time = torch.full((num_samples, 1), t, device=device, dtype=torch.long)
            pred_noise = model(x_t, time)
            x_t = scheduler.step(x_t, pred_noise, time)

    # Map [-1, 1] -> [0, 1], tile into a grid, and convert to uint8.
    x_0 = (x_t + 1) / 2
    grid = torchvision.utils.make_grid(x_0, nrow=side, padding=2)
    grid = grid.clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy()
    # Round rather than truncate to avoid a systematic darkening bias.
    img = Image.fromarray((grid * 255).round().astype(np.uint8))
    return img


# -----------------------------
# Gradio UI
# -----------------------------
def ui_generate(num_samples, seed, progress=gr.Progress()):
    """Event handler: validate inputs and run sampling.

    A ``None`` or negative seed means "use a random seed". ``progress`` must
    be a *default parameter* so Gradio injects its tracker for this event —
    instantiating ``gr.Progress()`` inside the body does not attach it.
    """
    real_seed = None if seed is None or int(seed) < 0 else int(seed)
    return sample_ddpm(
        num_samples=int(num_samples),
        seed=real_seed,
        progress=progress,
    )


def build_demo() -> gr.Blocks:
    """Assemble the Blocks UI: output image, sampling controls, run button."""
    with gr.Blocks(theme=gr.themes.Soft(), title="mini-DDPM") as demo:
        with gr.Row():
            with gr.Column(scale=2):
                out_img = gr.Image(type="pil", label="采样结果")
            with gr.Column(scale=1):
                num_samples = gr.Slider(
                    minimum=1,
                    maximum=25,
                    step=1,
                    value=9,
                    label="样本数量",
                )
                seed = gr.Number(value=42, precision=0, label="随机种子")
                btn = gr.Button("开始生成", variant="primary")

        btn.click(
            fn=ui_generate,
            inputs=[num_samples, seed],
            outputs=[out_img],
        )
    return demo


if __name__ == "__main__":
    demo = build_demo()
    # For remote deployment, pass server_name="0.0.0.0".
    demo.queue().launch()