Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,12 +1,8 @@
|
|
| 1 |
-
|
| 2 |
-
Gradio 应用:mini-DDPM
|
| 3 |
-
|
| 4 |
-
"""
|
| 5 |
|
| 6 |
import math
|
| 7 |
-
import os
|
| 8 |
import random
|
| 9 |
-
from typing import Optional
|
| 10 |
|
| 11 |
import gradio as gr
|
| 12 |
import numpy as np
|
|
@@ -25,6 +21,7 @@ CHECKPOINT_URL = (
|
|
| 25 |
# 工具函数
|
| 26 |
# -----------------------------
|
| 27 |
|
|
|
|
| 28 |
def seed_all(seed: Optional[int] = 42):
|
| 29 |
if seed is None:
|
| 30 |
return
|
|
@@ -33,13 +30,12 @@ def seed_all(seed: Optional[int] = 42):
|
|
| 33 |
torch.manual_seed(seed)
|
| 34 |
if torch.cuda.is_available():
|
| 35 |
torch.cuda.manual_seed_all(seed)
|
| 36 |
-
# 为了速度;如需完全复现,可改为 False
|
| 37 |
torch.backends.cudnn.benchmark = True
|
| 38 |
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
_SCHED_CACHE: dict[str, DDPMSchedule] = {}
|
| 43 |
|
| 44 |
|
| 45 |
def get_device() -> torch.device:
|
|
@@ -50,7 +46,7 @@ def load_model(device: torch.device) -> torch.nn.Module:
|
|
| 50 |
key = str(device)
|
| 51 |
if key in _MODEL_CACHE:
|
| 52 |
return _MODEL_CACHE[key]
|
| 53 |
-
|
| 54 |
state_dict = torch.hub.load_state_dict_from_url(
|
| 55 |
CHECKPOINT_URL, map_location=device, file_name="checkpoints.pt", progress=True
|
| 56 |
)
|
|
@@ -74,13 +70,14 @@ def load_scheduler(device: torch.device) -> DDPMSchedule:
|
|
| 74 |
# 采样核心
|
| 75 |
# -----------------------------
|
| 76 |
|
|
|
|
| 77 |
def sample_ddpm(
|
| 78 |
num_samples: int = 9,
|
| 79 |
-
seed: int
|
| 80 |
-
progress: gr.Progress
|
| 81 |
) -> Image.Image:
|
| 82 |
-
|
| 83 |
-
side = int(math.sqrt(num_samples))
|
| 84 |
num_samples = max(1, side * side)
|
| 85 |
|
| 86 |
device = get_device()
|
|
@@ -110,12 +107,14 @@ def sample_ddpm(
|
|
| 110 |
img = Image.fromarray((grid * 255).astype(np.uint8))
|
| 111 |
return img
|
| 112 |
|
|
|
|
| 113 |
# -----------------------------
|
| 114 |
# Gradio UI
|
| 115 |
# -----------------------------
|
| 116 |
|
|
|
|
| 117 |
def ui_generate(num_samples, seed):
|
| 118 |
-
# seed
|
| 119 |
real_seed = None if seed is None or int(seed) < 0 else int(seed)
|
| 120 |
return sample_ddpm(
|
| 121 |
num_samples=int(num_samples),
|
|
@@ -126,7 +125,6 @@ def ui_generate(num_samples, seed):
|
|
| 126 |
|
| 127 |
def build_demo() -> gr.Blocks:
|
| 128 |
with gr.Blocks(theme=gr.themes.Soft(), title="mini-DDPM") as demo:
|
| 129 |
-
|
| 130 |
with gr.Row():
|
| 131 |
with gr.Column(scale=2):
|
| 132 |
out_img = gr.Image(type="pil", label="采样结果")
|
|
@@ -138,8 +136,9 @@ def build_demo() -> gr.Blocks:
|
|
| 138 |
value=9,
|
| 139 |
label="样本数量",
|
| 140 |
)
|
| 141 |
-
|
| 142 |
-
|
|
|
|
| 143 |
|
| 144 |
btn.click(
|
| 145 |
fn=ui_generate,
|
|
|
|
| 1 |
+
# Gradio 应用:mini-DDPM
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
import math
|
|
|
|
| 4 |
import random
|
| 5 |
+
from typing import Optional, Dict
|
| 6 |
|
| 7 |
import gradio as gr
|
| 8 |
import numpy as np
|
|
|
|
| 21 |
# 工具函数
|
| 22 |
# -----------------------------
|
| 23 |
|
| 24 |
+
|
| 25 |
def seed_all(seed: Optional[int] = 42):
|
| 26 |
if seed is None:
|
| 27 |
return
|
|
|
|
| 30 |
torch.manual_seed(seed)
|
| 31 |
if torch.cuda.is_available():
|
| 32 |
torch.cuda.manual_seed_all(seed)
|
| 33 |
+
# 为了速度;如需完全复现,可改为 False
|
| 34 |
torch.backends.cudnn.benchmark = True
|
| 35 |
|
| 36 |
|
| 37 |
+
_MODEL_CACHE: Dict[str, torch.nn.Module] = {}
|
| 38 |
+
_SCHED_CACHE: Dict[str, DDPMSchedule] = {}
|
|
|
|
| 39 |
|
| 40 |
|
| 41 |
def get_device() -> torch.device:
|
|
|
|
| 46 |
key = str(device)
|
| 47 |
if key in _MODEL_CACHE:
|
| 48 |
return _MODEL_CACHE[key]
|
| 49 |
+
|
| 50 |
state_dict = torch.hub.load_state_dict_from_url(
|
| 51 |
CHECKPOINT_URL, map_location=device, file_name="checkpoints.pt", progress=True
|
| 52 |
)
|
|
|
|
| 70 |
# 采样核心
|
| 71 |
# -----------------------------
|
| 72 |
|
| 73 |
+
|
| 74 |
def sample_ddpm(
|
| 75 |
num_samples: int = 9,
|
| 76 |
+
seed: Optional[int] = 42,
|
| 77 |
+
progress: Optional[gr.Progress] = None,
|
| 78 |
) -> Image.Image:
|
| 79 |
+
# 将样本数调整为最接近的平方数,便于拼图展示
|
| 80 |
+
side = int(math.sqrt(max(1, num_samples)))
|
| 81 |
num_samples = max(1, side * side)
|
| 82 |
|
| 83 |
device = get_device()
|
|
|
|
| 107 |
img = Image.fromarray((grid * 255).astype(np.uint8))
|
| 108 |
return img
|
| 109 |
|
| 110 |
+
|
| 111 |
# -----------------------------
|
| 112 |
# Gradio UI
|
| 113 |
# -----------------------------
|
| 114 |
|
| 115 |
+
|
| 116 |
def ui_generate(num_samples, seed):
|
| 117 |
+
# seed < 0 或 None 表示使用随机种子
|
| 118 |
real_seed = None if seed is None or int(seed) < 0 else int(seed)
|
| 119 |
return sample_ddpm(
|
| 120 |
num_samples=int(num_samples),
|
|
|
|
| 125 |
|
| 126 |
def build_demo() -> gr.Blocks:
|
| 127 |
with gr.Blocks(theme=gr.themes.Soft(), title="mini-DDPM") as demo:
|
|
|
|
| 128 |
with gr.Row():
|
| 129 |
with gr.Column(scale=2):
|
| 130 |
out_img = gr.Image(type="pil", label="采样结果")
|
|
|
|
| 136 |
value=9,
|
| 137 |
label="样本数量",
|
| 138 |
)
|
| 139 |
+
|
| 140 |
+
seed = gr.Number(value=42, precision=0, label="随机种子")
|
| 141 |
+
btn = gr.Button("开始生成", variant="primary")
|
| 142 |
|
| 143 |
btn.click(
|
| 144 |
fn=ui_generate,
|