CocoBro committed on
Commit 3c696cb · 1 Parent(s): e72a2ea

fix rescale

Files changed (1)
  1. app.py +121 -65
app.py CHANGED
@@ -7,15 +7,12 @@ import spaces
 import os
 import time
 import logging
-import traceback  # [added] for printing error stack traces
-import gc  # [added] for freeing GPU memory
 from pathlib import Path
 from typing import Tuple, Optional, Dict, Any
 
 import gradio as gr
 import numpy as np
 import soundfile as sf
-# [changed] removed the top-level hydra/models imports so startup never triggers CUDA
 from huggingface_hub import snapshot_download
 
 
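The removed top-level imports are the point of this hunk: on ZeroGPU Spaces the GPU is only attached inside `@spaces.GPU` functions, so any import that initializes CUDA at module load time can crash startup. A minimal sketch of the deferred-import pattern the file now uses throughout (the helper name is illustrative, not from the commit):

    # Sketch of the deferred-import pattern (illustrative helper name).
    # torch is imported inside the function, so importing this module
    # at Space startup never touches CUDA.
    def pick_device():
        import torch  # deferred: runs on first call, not at module import

        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(pick_device())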
@@ -27,7 +24,7 @@ logger = logging.getLogger("mmedit_space")
 
 
 # ---------------------------------------------------------
-# HF Repo IDs
+# HF Repo IDs (per the default requirements)
 # ---------------------------------------------------------
 MMEDIT_REPO_ID = os.environ.get("MMEDIT_REPO_ID", "CocoBro/MMEdit")
 MMEDIT_REVISION = os.environ.get("MMEDIT_REVISION", None)
@@ -35,14 +32,16 @@ MMEDIT_REVISION = os.environ.get("MMEDIT_REVISION", None)
 QWEN_REPO_ID = os.environ.get("QWEN_REPO_ID", "Qwen/Qwen2-Audio-7B-Instruct")
 QWEN_REVISION = os.environ.get("QWEN_REVISION", None)
 
+# If the Qwen repo is gated: set HF_TOKEN as a Secret in the Space
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
 OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "./outputs"))
 OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
 
-# ---------------------------------------------------------
-# Cache definitions
-# ---------------------------------------------------------
+USE_AMP = os.environ.get("USE_AMP", "0") == "1"
+AMP_DTYPE = os.environ.get("AMP_DTYPE", "bf16")  # "bf16" or "fp16"
+
+# ZeroGPU: cache the CPU pipeline (never cache CUDA tensors)
 # cache: key -> (model_cpu, scheduler, target_sr)
 _PIPELINE_CACHE: Dict[str, Tuple[object, object, int]] = {}
 # cache: key -> (repo_root, qwen_root)
@@ -50,26 +49,30 @@ _MODEL_DIR_CACHE: Dict[str, Tuple[Path, Path]] = {}
 
 
 # ---------------------------------------------------------
-# Download repos
+# Download repos (downloaded once; huggingface_hub has its own cache)
 # ---------------------------------------------------------
 def resolve_model_dirs() -> Tuple[Path, Path]:
     cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
     if cache_key in _MODEL_DIR_CACHE:
         return _MODEL_DIR_CACHE[cache_key]
 
-    logger.info(f"Downloading MMEdit repo: {MMEDIT_REPO_ID}")
+    logger.info(f"Downloading MMEdit repo: {MMEDIT_REPO_ID} (revision={MMEDIT_REVISION})")
     repo_root = snapshot_download(
         repo_id=MMEDIT_REPO_ID,
         revision=MMEDIT_REVISION,
+        local_dir=None,
+        local_dir_use_symlinks=False,
         token=HF_TOKEN,
     )
     repo_root = Path(repo_root).resolve()
 
-    logger.info(f"Downloading Qwen repo: {QWEN_REPO_ID}")
+    logger.info(f"Downloading Qwen repo: {QWEN_REPO_ID} (revision={QWEN_REVISION})")
     qwen_root = snapshot_download(
         repo_id=QWEN_REPO_ID,
         revision=QWEN_REVISION,
-        token=HF_TOKEN,
+        local_dir=None,
+        local_dir_use_symlinks=False,
+        token=HF_TOKEN,  # required for gated models
     )
     qwen_root = Path(qwen_root).resolve()
 
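Taken on its own, the download step is just two `snapshot_download` calls; a standalone sketch with the repo id copied from above (network access assumed, token needed only for gated repos). Note that recent huggingface_hub releases deprecate `local_dir_use_symlinks`, so the new keyword arguments in this hunk are effectively no-ops there:

    # Standalone sketch of the snapshot download step.
    import os
    from pathlib import Path
    from huggingface_hub import snapshot_download

    root = snapshot_download(
        repo_id="CocoBro/MMEdit",
        revision=None,                     # None -> default branch head
        token=os.environ.get("HF_TOKEN"),  # required when the repo is gated
    )
    print(Path(root).resolve())            # snapshot dir inside the HF cache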
@@ -78,10 +81,10 @@ def resolve_model_dirs() -> Tuple[Path, Path]:
 
 
 # ---------------------------------------------------------
-# Audio processing
+# Audio loading (as specified: orig -> 16k -> target_sr)
 # ---------------------------------------------------------
 def load_and_process_audio(audio_path: str, target_sr: int):
-    # deferred imports, to avoid interference
+    # deferred imports (avoid triggering CUDA initialization at startup)
    import torch
    import torchaudio
    import librosa
@@ -90,14 +93,18 @@ def load_and_process_audio(audio_path: str, target_sr: int):
     if not path.exists():
         raise FileNotFoundError(f"Audio file not found: {audio_path}")
 
-    waveform, orig_sr = torchaudio.load(str(path))
+    waveform, orig_sr = torchaudio.load(str(path))  # (C, T)
 
-    if waveform.ndim > 1:
-        waveform = waveform.mean(dim=0)
+    # Convert to mono
+    if waveform.ndim == 2:
+        waveform = waveform.mean(dim=0)  # (T,)
+    elif waveform.ndim > 2:
+        waveform = waveform.reshape(-1)
 
     if target_sr and int(target_sr) != int(orig_sr):
         waveform_np = waveform.cpu().numpy()
-        # robust resampling logic
+
+        # 1) first to 16k
         sr_mid = 16000
         if int(orig_sr) != sr_mid:
             waveform_np = librosa.resample(waveform_np, orig_sr=int(orig_sr), target_sr=sr_mid)
@@ -105,6 +112,7 @@
     else:
         orig_sr_mid = int(orig_sr)
 
+    # 2) then to target_sr (e.g. 24k)
     if int(target_sr) != orig_sr_mid:
         waveform_np = librosa.resample(waveform_np, orig_sr=orig_sr_mid, target_sr=int(target_sr))
 
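The two hunks above implement the resample chain orig_sr -> 16 kHz -> target_sr. A self-contained sketch of the same chain on a synthetic tone, without the torchaudio loading step (the function name is made up for illustration):

    # Sketch of the two-stage resample (orig -> 16k -> target_sr).
    import numpy as np
    import librosa

    def two_stage_resample(wav: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
        sr_mid = 16000
        if orig_sr != sr_mid:
            wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=sr_mid)
        if target_sr != sr_mid:
            wav = librosa.resample(wav, orig_sr=sr_mid, target_sr=target_sr)
        return wav

    # 1 second of a 440 Hz tone at 44.1 kHz -> ~24000 samples at 24 kHz
    wav = np.sin(2 * np.pi * 440 * np.arange(44100) / 44100).astype(np.float32)
    print(two_stage_resample(wav, orig_sr=44100, target_sr=24000).shape)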
@@ -114,33 +122,55 @@
 
 
 # ---------------------------------------------------------
-# Path adaptation
+# Validate the repo layout
+# ---------------------------------------------------------
+def assert_repo_layout(repo_root: Path) -> None:
+    must = [repo_root / "config.yaml", repo_root / "model.safetensors", repo_root / "vae"]
+    for p in must:
+        if not p.exists():
+            raise FileNotFoundError(f"Missing required path: {p}")
+
+    vae_files = list((repo_root / "vae").glob("*.ckpt"))
+    if len(vae_files) == 0:
+        raise FileNotFoundError(f"No .ckpt found under: {repo_root/'vae'}")
+
+
+# ---------------------------------------------------------
+# Adapt the path spellings used in config.yaml
 # ---------------------------------------------------------
 def patch_paths_in_exp_config(exp_cfg: Dict[str, Any], repo_root: Path, qwen_root: Path) -> None:
+    # ---- 1) VAE ckpt ----
     vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", None)
     if vae_ckpt:
         vae_ckpt = str(vae_ckpt).replace("\\", "/")
-        if "vae/" in vae_ckpt:
-            vae_rel = vae_ckpt[vae_ckpt.find("vae/"):]
-        elif vae_ckpt.endswith(".ckpt"):
-            vae_rel = f"vae/{vae_ckpt}" if "/" not in vae_ckpt else vae_ckpt
+        idx = vae_ckpt.find("vae/")
+        if idx != -1:
+            vae_rel = vae_ckpt[idx:]  # slice from "vae/" onward
         else:
-            vae_rel = vae_ckpt
+            if vae_ckpt.endswith(".ckpt") and "/" not in vae_ckpt:
+                vae_rel = f"vae/{vae_ckpt}"
+            else:
+                vae_rel = vae_ckpt
 
         vae_path = (repo_root / vae_rel).resolve()
-        # robustness check: if the computed path doesn't exist, try the bare file name under the repo root
-        if not vae_path.exists():
-            fallback = repo_root / Path(vae_ckpt).name
-            if fallback.exists():
-                vae_path = fallback
-
         exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(vae_path)
 
+        if not vae_path.exists():
+            raise FileNotFoundError(
+                f"VAE ckpt not found after patch:\n"
+                f"  original: {vae_ckpt}\n"
+                f"  patched : {vae_path}\n"
+                f"Repo root: {repo_root}\n"
+                f"Expected: {repo_root/'vae'/'*.ckpt'}"
+            )
+
+    # ---- 2) Qwen2-Audio model_path ----
     exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
 
 
 # ---------------------------------------------------------
-# Scheduler
+# Scheduler (aligned with exp_cfg.model.noise_scheduler_name)
+# with a fallback to avoid a 404
 # ---------------------------------------------------------
 def build_scheduler(exp_cfg: Dict[str, Any]):
     import diffusers.schedulers as noise_schedulers
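The vae-path normalization above accepts three spellings of `pretrained_ckpt`; a pure-function sketch that can be checked without a repo on disk (the helper name is made up):

    # Sketch of the vae_rel normalization as a pure function.
    def normalize_vae_rel(vae_ckpt: str) -> str:
        vae_ckpt = vae_ckpt.replace("\\", "/")
        idx = vae_ckpt.find("vae/")
        if idx != -1:
            return vae_ckpt[idx:]                    # any prefix before vae/ is dropped
        if vae_ckpt.endswith(".ckpt") and "/" not in vae_ckpt:
            return f"vae/{vae_ckpt}"                 # bare filename -> vae/<name>
        return vae_ckpt                              # already relative, keep as-is

    assert normalize_vae_rel("/data/exp/vae/ae.ckpt") == "vae/ae.ckpt"
    assert normalize_vae_rel("ae.ckpt") == "vae/ae.ckpt"
    assert normalize_vae_rel("ckpts/other.ckpt") == "ckpts/other.ckpt"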
@@ -150,7 +180,7 @@ def build_scheduler(exp_cfg: Dict[str, Any]):
         scheduler = noise_schedulers.DDIMScheduler.from_pretrained(name, subfolder="scheduler", token=HF_TOKEN)
         return scheduler
     except Exception as e:
-        logger.warning(f"Scheduler fallback: {e}")
+        logger.warning(f"DDIMScheduler.from_pretrained failed for '{name}', falling back. err={e}")
        return noise_schedulers.DDIMScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
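The fallback branch builds a DDIMScheduler locally instead of fetching a scheduler config from the Hub. A standalone sketch using only the parameters visible in this hunk (the remaining arguments are truncated in the diff and left at diffusers' defaults here):

    # Sketch of the fallback path: construct the scheduler directly.
    from diffusers.schedulers import DDIMScheduler

    scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
    )
    scheduler.set_timesteps(num_inference_steps=50)
    print(len(scheduler.timesteps))  # 50 denoising steps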
@@ -162,60 +192,73 @@
     )
 
 
+def amp_autocast(device):
+    import torch
+
+    if not USE_AMP:
+        return torch.autocast("cuda", enabled=False)
+
+    if device.type != "cuda":
+        return torch.autocast("cpu", enabled=False)
+
+    dtype = torch.bfloat16 if AMP_DTYPE.lower() == "bf16" else torch.float16
+    return torch.autocast("cuda", dtype=dtype, enabled=True)
+
+
 # ---------------------------------------------------------
-# [core] Cold start: load to CPU
+# Cold start: load + cache the pipeline (model cached on CPU)
 # ---------------------------------------------------------
 def load_pipeline_cpu() -> Tuple[object, object, int]:
-    # [changed] all libraries are imported here so module-level imports never trigger CUDA initialization
+    # deferred imports (avoid triggering CUDA initialization at startup)
     import torch
     import hydra
     from omegaconf import OmegaConf
     from safetensors.torch import load_file
-
-    # project dependencies
-    try:
-        from utils.config import register_omegaconf_resolvers
-        from models.common import LoadPretrainedBase
-        register_omegaconf_resolvers()
-    except ImportError:
-        logger.warning("Could not import project utils/models. Ensure they are in the python path.")
-    except Exception:
-        pass
+
+    # project dependencies are deferred as well
+    from models.common import LoadPretrainedBase
+    from utils.config import register_omegaconf_resolvers
+
+    register_omegaconf_resolvers()
 
     cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
     if cache_key in _PIPELINE_CACHE:
         return _PIPELINE_CACHE[cache_key]
 
     repo_root, qwen_root = resolve_model_dirs()
-
-    logger.info(f"repo_root: {repo_root}")
-
+    assert_repo_layout(repo_root)
+
+    logger.info(f"repo_root = {repo_root}")
+    logger.info(f"qwen_root = {qwen_root}")
+
     exp_cfg = OmegaConf.load(repo_root / "config.yaml")
     exp_cfg = OmegaConf.to_container(exp_cfg, resolve=True)
 
     patch_paths_in_exp_config(exp_cfg, repo_root, qwen_root)
+    logger.info(f"patched pretrained_ckpt = {exp_cfg['model']['autoencoder'].get('pretrained_ckpt')}")
+    logger.info(f"patched qwen model_path = {exp_cfg['model']['content_encoder']['text_encoder'].get('model_path')}")
 
-    logger.info("Instantiating model...")
-    model = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")
+    model: LoadPretrainedBase = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")
 
     ckpt_path = repo_root / "model.safetensors"
-    logger.info(f"Loading weights: {ckpt_path}")
     sd = load_file(str(ckpt_path))
     model.load_pretrained(sd)
 
-    # [changed] make sure the model lands on CPU
+    # ZeroGPU: cache on CPU
     model = model.to(torch.device("cpu")).eval()
 
     scheduler = build_scheduler(exp_cfg)
     target_sr = int(exp_cfg.get("sample_rate", 24000))
 
     _PIPELINE_CACHE[cache_key] = (model, scheduler, target_sr)
-    logger.info("CPU pipeline cached.")
+    logger.info("CPU pipeline loaded and cached.")
     return model, scheduler, target_sr
 
 
 # ---------------------------------------------------------
-# [core] Inference function (ZeroGPU variant)
+# Inference: audio + caption -> edited audio
+# ZeroGPU: must run under @spaces.GPU
+# ---------------------------------------------------------
 # ---------------------------------------------------------
 @spaces.GPU
 def run_edit(
@@ -227,6 +270,7 @@ def run_edit(
     seed: int,
 ) -> Tuple[Optional[str], str]:
     import torch
+    import gc
 
     if not audio_file: return None, "Error: Upload audio first."
     if not caption: return None, "Error: Input caption."
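Most of run_edit's body sits outside this diff, but the new pieces it relies on are visible: the `amp_autocast` helper added above and the `empty_cache`/`gc.collect` cleanup in the next hunk. A sketch of how they plausibly fit together, with a placeholder computation standing in for the real denoising loop:

    # Sketch combining amp_autocast with try/finally GPU cleanup.
    import gc
    import torch

    USE_AMP = True
    AMP_DTYPE = "bf16"

    def amp_autocast(device):
        if not USE_AMP:
            return torch.autocast("cuda", enabled=False)
        if device.type != "cuda":
            return torch.autocast("cpu", enabled=False)
        dtype = torch.bfloat16 if AMP_DTYPE.lower() == "bf16" else torch.float16
        return torch.autocast("cuda", dtype=dtype, enabled=True)

    def fake_run_edit():
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        try:
            with amp_autocast(device):
                x = torch.randn(2, 8, device=device)
                return (x @ x.T).float().mean().item()  # placeholder for inference
        finally:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

    print(fake_run_edit())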
@@ -314,43 +358,54 @@ def run_edit(
         torch.cuda.empty_cache()
         gc.collect()
 
-
 # ---------------------------------------------------------
 # UI
 # ---------------------------------------------------------
 def build_demo():
     with gr.Blocks(title="MMEdit (ZeroGPU)") as demo:
-        gr.Markdown("# MMEdit ZeroGPU")
+        gr.Markdown("# MMEdit ZeroGPU (audio + caption → edited audio)")
 
         with gr.Row():
             with gr.Column():
                 audio_in = gr.Audio(label="Input Audio", type="filepath")
-                caption = gr.Textbox(label="Caption", lines=3)
+                caption = gr.Textbox(label="Caption (Edit Instruction)", lines=3)
 
+                # Note: pushing large wavs to the Space is discouraged; swap in a smaller demo wav
                 gr.Examples(
-                    label="Examples",
-                    examples=[["./Ym8O802VvJes.wav", "Mix in dog barking around the middle."]],
+                    label="example inputs",
+                    examples=[
+                        ["./Ym8O802VvJes.wav", "Mix in dog barking around the middle."],
+                    ],
                     inputs=[audio_in, caption],
+                    cache_examples=False,
                 )
 
                 with gr.Row():
-                    num_steps = gr.Slider(1, 100, value=50, step=1, label="Steps")
-                    guidance_scale = gr.Slider(1.0, 12.0, value=5.0, step=0.5, label="Guidance")
-                    rescale = gr.Slider(0.0, 1.0, 0.5, step=0.05, label="Rescale")
-                    seed = gr.Number(42, label="Seed")
+                    num_steps = gr.Slider(1, 100, value=50, step=1, label="num_steps")
+                    guidance_scale = gr.Slider(1.0, 12.0, value=5.0, step=0.5, label="guidance_scale")
+
+                with gr.Row():
+                    guidance_rescale = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="guidance_rescale")
+                    seed = gr.Number(value=42, precision=0, label="seed")
 
-                run_btn = gr.Button("Run", variant="primary")
+                run_btn = gr.Button("Run Editing", variant="primary")
 
             with gr.Column():
-                audio_out = gr.Audio(label="Output", type="filepath")
+                audio_out = gr.Audio(label="Edited Audio", type="filepath")
                 status = gr.Textbox(label="Status")
 
         run_btn.click(
             fn=run_edit,
-            inputs=[audio_in, caption, num_steps, guidance_scale, rescale, seed],
+            inputs=[audio_in, caption, num_steps, guidance_scale, guidance_rescale, seed],
             outputs=[audio_out, status],
         )
 
+        gr.Markdown(
+            "## Notes\n"
+            "1) The first click on ZeroGPU allocates a GPU and can be a bit slow.\n"
+            "2) If the first run reports CUDA unavailable, retrying once usually works.\n"
+        )
+
     return demo
 
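Stripped of the model, the click wiring above reduces to a 6-input / 2-output contract. A minimal runnable stub that mirrors it (the stub function name is made up; it just echoes its inputs):

    # Sketch of the click wiring with a stub in place of run_edit.
    import gradio as gr

    def fake_run_edit(audio, caption, steps, guidance, rescale, seed):
        return audio, f"caption={caption!r} steps={steps} cfg={guidance} rescale={rescale} seed={seed}"

    with gr.Blocks() as demo:
        audio_in = gr.Audio(type="filepath")
        caption = gr.Textbox()
        num_steps = gr.Slider(1, 100, value=50, step=1)
        guidance_scale = gr.Slider(1.0, 12.0, value=5.0, step=0.5)
        guidance_rescale = gr.Slider(0.0, 1.0, value=0.5, step=0.05)
        seed = gr.Number(value=42, precision=0)
        btn = gr.Button("Run")
        audio_out = gr.Audio(type="filepath")
        status = gr.Textbox()
        btn.click(fake_run_edit,
                  inputs=[audio_in, caption, num_steps, guidance_scale, guidance_rescale, seed],
                  outputs=[audio_out, status])
    # demo.launch()  # uncomment to serve locally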
@@ -361,4 +416,5 @@ if __name__ == "__main__":
         server_name="0.0.0.0",
         server_port=port,
         share=False,
-    )
+        ssr_mode=False,
+    )
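On the launch change: `ssr_mode=False` disables Gradio 5's server-side rendering, which has been a common workaround when a Space renders a blank page. A minimal sketch of the resulting launch call (the port value is illustrative):

    # Minimal launch sketch with the new flag (Gradio 5+).
    import gradio as gr

    with gr.Blocks() as demo:
        gr.Markdown("smoke test")

    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, ssr_mode=False)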