caixiaoshun commited on
Commit
2b3facb
·
verified ·
1 Parent(s): fd0ad94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -19
app.py CHANGED
@@ -1,12 +1,8 @@
1
- """
2
- Gradio 应用:mini-DDPM
3
-
4
- """
5
 
6
  import math
7
- import os
8
  import random
9
- from typing import Optional
10
 
11
  import gradio as gr
12
  import numpy as np
@@ -25,6 +21,7 @@ CHECKPOINT_URL = (
25
  # 工具函数
26
  # -----------------------------
27
 
 
28
  def seed_all(seed: Optional[int] = 42):
29
  if seed is None:
30
  return
@@ -33,13 +30,12 @@ def seed_all(seed: Optional[int] = 42):
33
  torch.manual_seed(seed)
34
  if torch.cuda.is_available():
35
  torch.cuda.manual_seed_all(seed)
36
- # 为了速度;如需完全复现,可改为 False
37
  torch.backends.cudnn.benchmark = True
38
 
39
 
40
-
41
- _MODEL_CACHE: dict[str, torch.nn.Module] = {}
42
- _SCHED_CACHE: dict[str, DDPMSchedule] = {}
43
 
44
 
45
  def get_device() -> torch.device:
@@ -50,7 +46,7 @@ def load_model(device: torch.device) -> torch.nn.Module:
50
  key = str(device)
51
  if key in _MODEL_CACHE:
52
  return _MODEL_CACHE[key]
53
-
54
  state_dict = torch.hub.load_state_dict_from_url(
55
  CHECKPOINT_URL, map_location=device, file_name="checkpoints.pt", progress=True
56
  )
@@ -74,13 +70,14 @@ def load_scheduler(device: torch.device) -> DDPMSchedule:
74
  # 采样核心
75
  # -----------------------------
76
 
 
77
  def sample_ddpm(
78
  num_samples: int = 9,
79
- seed: int | None = 42,
80
- progress: gr.Progress | None = None,
81
  ) -> Image.Image:
82
-
83
- side = int(math.sqrt(num_samples))
84
  num_samples = max(1, side * side)
85
 
86
  device = get_device()
@@ -110,12 +107,14 @@ def sample_ddpm(
110
  img = Image.fromarray((grid * 255).astype(np.uint8))
111
  return img
112
 
 
113
  # -----------------------------
114
  # Gradio UI
115
  # -----------------------------
116
 
 
117
  def ui_generate(num_samples, seed):
118
- # seed == -1 表示使用随机种子
119
  real_seed = None if seed is None or int(seed) < 0 else int(seed)
120
  return sample_ddpm(
121
  num_samples=int(num_samples),
@@ -126,7 +125,6 @@ def ui_generate(num_samples, seed):
126
 
127
  def build_demo() -> gr.Blocks:
128
  with gr.Blocks(theme=gr.themes.Soft(), title="mini-DDPM") as demo:
129
-
130
  with gr.Row():
131
  with gr.Column(scale=2):
132
  out_img = gr.Image(type="pil", label="采样结果")
@@ -138,8 +136,9 @@ def build_demo() -> gr.Blocks:
138
  value=9,
139
  label="样本数量",
140
  )
141
- seed = gr.Number(value=42, precision=0, label="随机种子")
142
- btn = gr.Button("开始生成", variant="primary")
 
143
 
144
  btn.click(
145
  fn=ui_generate,
 
1
+ # Gradio 应用:mini-DDPM
 
 
 
2
 
3
  import math
 
4
  import random
5
+ from typing import Optional, Dict
6
 
7
  import gradio as gr
8
  import numpy as np
 
21
  # 工具函数
22
  # -----------------------------
23
 
24
+
25
  def seed_all(seed: Optional[int] = 42):
26
  if seed is None:
27
  return
 
30
  torch.manual_seed(seed)
31
  if torch.cuda.is_available():
32
  torch.cuda.manual_seed_all(seed)
33
+ # 为了速度;如需完全复现,可改为 False
34
  torch.backends.cudnn.benchmark = True
35
 
36
 
37
+ _MODEL_CACHE: Dict[str, torch.nn.Module] = {}
38
+ _SCHED_CACHE: Dict[str, DDPMSchedule] = {}
 
39
 
40
 
41
  def get_device() -> torch.device:
 
46
  key = str(device)
47
  if key in _MODEL_CACHE:
48
  return _MODEL_CACHE[key]
49
+
50
  state_dict = torch.hub.load_state_dict_from_url(
51
  CHECKPOINT_URL, map_location=device, file_name="checkpoints.pt", progress=True
52
  )
 
70
  # 采样核心
71
  # -----------------------------
72
 
73
+
74
  def sample_ddpm(
75
  num_samples: int = 9,
76
+ seed: Optional[int] = 42,
77
+ progress: Optional[gr.Progress] = None,
78
  ) -> Image.Image:
79
+ # 将样本数调整为最接近的平方数,便于拼图展示
80
+ side = int(math.sqrt(max(1, num_samples)))
81
  num_samples = max(1, side * side)
82
 
83
  device = get_device()
 
107
  img = Image.fromarray((grid * 255).astype(np.uint8))
108
  return img
109
 
110
+
111
  # -----------------------------
112
  # Gradio UI
113
  # -----------------------------
114
 
115
+
116
  def ui_generate(num_samples, seed):
117
+ # seed < 0 或 None 表示使用随机种子
118
  real_seed = None if seed is None or int(seed) < 0 else int(seed)
119
  return sample_ddpm(
120
  num_samples=int(num_samples),
 
125
 
126
  def build_demo() -> gr.Blocks:
127
  with gr.Blocks(theme=gr.themes.Soft(), title="mini-DDPM") as demo:
 
128
  with gr.Row():
129
  with gr.Column(scale=2):
130
  out_img = gr.Image(type="pil", label="采样结果")
 
136
  value=9,
137
  label="样本数量",
138
  )
139
+
140
+ seed = gr.Number(value=42, precision=0, label="随机种子")
141
+ btn = gr.Button("开始生成", variant="primary")
142
 
143
  btn.click(
144
  fn=ui_generate,