Spaces:
Runtime error
Runtime error
| # Gradio 应用:mini-DDPM | |
| import math | |
| import random | |
| from typing import Optional, Dict | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| import torchvision | |
| from model import Unet | |
| from scheduler import DDPMSchedule | |
# Location of the pretrained mini-DDPM weights on the Hugging Face Hub.
# load_model() downloads it through torch.hub.load_state_dict_from_url.
CHECKPOINT_URL = "https://huggingface.co/caixiaoshun/mini-ddpm/resolve/main/checkpoints.pt"
| # ----------------------------- | |
| # 工具函数 | |
| # ----------------------------- | |
def seed_all(seed: Optional[int] = 42, *, deterministic: bool = False) -> None:
    """Seed the Python, NumPy and PyTorch RNGs so sampling is repeatable.

    Args:
        seed: Seed value. ``None`` returns immediately and leaves every RNG
            (and the cuDNN flags) untouched, i.e. results stay random. Any
            integer is accepted: values outside the range a library supports
            are folded into it instead of raising.
        deterministic: Keyword-only. When True, disable cuDNN autotuning and
            request deterministic cuDNN kernels (slower, but repeatable on
            GPU). The default False keeps autotuning on, favoring speed.
    """
    if seed is None:
        return
    seed = int(seed)
    random.seed(seed)
    # NumPy's legacy seeding only accepts 0 <= seed < 2**32. Fold other values
    # into that range rather than crash: the UI lets users type any integer.
    np.random.seed(seed % (1 << 32))
    # torch.manual_seed accepts [-2**63, 2**64 - 1]. Pass those through
    # unchanged and fold anything outside that window.
    torch_seed = seed if -(1 << 63) <= seed < (1 << 64) else seed % (1 << 64)
    torch.manual_seed(torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(torch_seed)
    # Autotuned cuDNN kernels are faster but can vary run to run;
    # deterministic=True trades that speed for bit-for-bit repeatability.
    torch.backends.cudnn.benchmark = not deterministic
    if deterministic:
        torch.backends.cudnn.deterministic = True
# Per-device memo tables, keyed by str(torch.device). load_model() and
# load_scheduler() fill them so that each device builds its model/scheduler once.
_MODEL_CACHE: Dict[str, torch.nn.Module] = dict()
_SCHED_CACHE: Dict[str, DDPMSchedule] = dict()
def get_device() -> torch.device:
    """Pick the compute device: ``cuda`` when CUDA is available, else ``cpu``."""
    name = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(name)
def load_model(device: torch.device) -> torch.nn.Module:
    """Return the pretrained ``Unet`` on ``device``, in eval mode.

    The first call for a device does three things:
    - fetches the state dict from ``CHECKPOINT_URL`` with
      ``torch.hub.load_state_dict_from_url``, mapping tensors onto ``device``;
    - loads it into a fresh ``Unet``;
    - memoizes the module in ``_MODEL_CACHE`` under ``str(device)``.

    Later calls for the same device return that cached module.
    """
    cache_key = str(device)
    cached = _MODEL_CACHE.get(cache_key)
    if cached is not None:
        return cached

    weights = torch.hub.load_state_dict_from_url(
        CHECKPOINT_URL,
        map_location=device,
        file_name="checkpoints.pt",
        progress=True,
    )
    net = Unet().to(device)
    net.load_state_dict(weights)
    net.eval()
    _MODEL_CACHE[cache_key] = net
    return net
def load_scheduler(device: torch.device) -> DDPMSchedule:
    """Return a ``DDPMSchedule`` moved to ``device``, built once per device.

    Instances are memoized in ``_SCHED_CACHE`` under ``str(device)``.
    """
    cache_key = str(device)
    if cache_key not in _SCHED_CACHE:
        _SCHED_CACHE[cache_key] = DDPMSchedule().to(device)
    return _SCHED_CACHE[cache_key]
| # ----------------------------- | |
| # 采样核心 | |
| # ----------------------------- | |
def sample_ddpm(
    num_samples: int = 9,
    seed: Optional[int] = 42,
    progress: Optional[gr.Progress] = None,
) -> Image.Image:
    """Run the full DDPM reverse process and return the samples tiled as one image.

    Args:
        num_samples: Number of images to generate; values below 1 are treated
            as 1. Images are laid out in a near-square grid with
            ``ceil(sqrt(num_samples))`` columns. Unused cells in the last row
            are left as padding.
        seed: Seed forwarded to ``seed_all``; ``None`` means unseeded (random).
        progress: Optional Gradio progress tracker, updated once per timestep.

    Returns:
        A PIL RGB image containing the sample grid.
    """
    num_samples = max(1, int(num_samples))
    # Smallest column count whose square holds every sample: ceil(sqrt(n)).
    nrow = math.isqrt(num_samples - 1) + 1

    device = get_device()
    # Load (and cache) the model and scheduler BEFORE seeding. Building them on
    # the first call may consume the global RNG (e.g. random weight init inside
    # Unet()). That would make the first request for a seed differ from repeats.
    model = load_model(device)
    scheduler = load_scheduler(device)
    seed_all(seed)

    # Start the reverse process from Gaussian noise. The 3x128x128 shape is
    # assumed to match the resolution the checkpoint was trained at.
    x_t = torch.randn(num_samples, 3, 128, 128, device=device)
    total = getattr(scheduler, "steps", 1000)

    model.eval()
    with torch.no_grad():
        for t in range(total - 1, -1, -1):
            if progress is not None:
                progress((total - t) / total, desc=f"t = {t}")
            # One timestep index per sample; (N, 1) is presumably the layout
            # Unet and DDPMSchedule.step expect.
            time = torch.full((num_samples, 1), t, device=device, dtype=torch.long)
            pred_noise = model(x_t, time)
            x_t = scheduler.step(x_t, pred_noise, time)

    # Map samples from [-1, 1] to [0, 1] for display.
    x_0 = (x_t + 1) / 2
    grid = torchvision.utils.make_grid(x_0, nrow=nrow, padding=2)
    # Round (not truncate) when quantizing to 8-bit, as torchvision's save_image does.
    grid = grid.clamp(0, 1).mul(255).round().to(torch.uint8)
    pixels = grid.permute(1, 2, 0).contiguous().cpu().numpy()
    return Image.fromarray(pixels)
| # ----------------------------- | |
| # Gradio UI | |
| # ----------------------------- | |
def ui_generate(num_samples, seed, progress=gr.Progress()):
    """Gradio click handler: normalize the widget values and run ``sample_ddpm``.

    Args:
        num_samples: Slider value (number of images requested).
        seed: Value of the seed ``gr.Number``. An empty box (``None``) or a
            negative number means "use a random seed".
        progress: Progress tracker. Gradio's documented pattern is to declare
            ``gr.Progress()`` as a default argument of the event handler so
            Gradio can attach it to the running event. A tracker constructed
            inside the body may not be connected (version-dependent).

    Returns:
        The PIL image produced by ``sample_ddpm``.
    """
    real_seed = None if seed is None or int(seed) < 0 else int(seed)
    return sample_ddpm(
        num_samples=int(num_samples),
        seed=real_seed,
        progress=progress,
    )
def build_demo() -> gr.Blocks:
    """Assemble the mini-DDPM Gradio interface.

    Layout: a wide column showing the generated grid, next to a narrow
    control column with the sample-count slider, the seed box and the
    generate button. Clicking the button calls ``ui_generate``.
    """
    theme = gr.themes.Soft()
    with gr.Blocks(theme=theme, title="mini-DDPM") as demo:
        with gr.Row():
            with gr.Column(scale=2):
                result_image = gr.Image(type="pil", label="采样结果")
            with gr.Column(scale=1):
                count_slider = gr.Slider(
                    minimum=1, maximum=25, step=1, value=9, label="样本数量"
                )
                seed_box = gr.Number(value=42, precision=0, label="随机种子")
                generate_btn = gr.Button("开始生成", variant="primary")
        # Event listeners must be registered inside the Blocks context.
        generate_btn.click(
            fn=ui_generate,
            inputs=[count_slider, seed_box],
            outputs=[result_image],
        )
    return demo
if __name__ == "__main__":
    demo = build_demo()
    # Enable Gradio's request queue, then start the server.
    # For remote deployment, pass server_name="0.0.0.0" to launch().
    demo.queue()
    demo.launch()