CocoBro committed
Commit 0bcb372 · 1 Parent(s): 3d6e612
Files changed (1):
  1. app.py +91 -35
app.py CHANGED
@@ -20,6 +20,9 @@ from safetensors.torch import load_file
 import diffusers.schedulers as noise_schedulers
 from huggingface_hub import snapshot_download
 
+# Key for ZeroGPU: the spaces package
+import spaces
+
 from models.common import LoadPretrainedBase
 from utils.config import register_omegaconf_resolvers
 
@@ -45,17 +48,22 @@ MMEDIT_REVISION = os.environ.get("MMEDIT_REVISION", None)
45
  QWEN_REPO_ID = os.environ.get("QWEN_REPO_ID", "Qwen/Qwen2-Audio-7B-Instruct")
46
  QWEN_REVISION = os.environ.get("QWEN_REVISION", None)
47
 
 
 
 
48
  OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "./outputs"))
49
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
50
 
51
  USE_AMP = os.environ.get("USE_AMP", "0") == "1"
52
  AMP_DTYPE = os.environ.get("AMP_DTYPE", "bf16") # "bf16" or "fp16"
53
 
54
- _PIPELINE_CACHE: Dict[str, Tuple[LoadPretrainedBase, object, int, torch.device]] = {}
 
 
55
 
56
 
57
  # ---------------------------------------------------------
58
- # 下载 repo
59
  # ---------------------------------------------------------
60
  def resolve_model_dirs() -> Tuple[Path, Path]:
61
  """
@@ -63,12 +71,17 @@ def resolve_model_dirs() -> Tuple[Path, Path]:
63
  repo_root: 你的 MMEdit repo 的本地目录(包含 config.yaml / model.safetensors / vae/)
64
  qwen_root: Qwen2-Audio repo 的本地目录
65
  """
 
 
 
 
66
  logger.info(f"Downloading MMEdit repo: {MMEDIT_REPO_ID} (revision={MMEDIT_REVISION})")
67
  repo_root = snapshot_download(
68
  repo_id=MMEDIT_REPO_ID,
69
  revision=MMEDIT_REVISION,
70
  local_dir=None,
71
  local_dir_use_symlinks=False,
 
72
  )
73
  repo_root = Path(repo_root).resolve()
74
 
@@ -78,9 +91,11 @@ def resolve_model_dirs() -> Tuple[Path, Path]:
78
  revision=QWEN_REVISION,
79
  local_dir=None,
80
  local_dir_use_symlinks=False,
 
81
  )
82
  qwen_root = Path(qwen_root).resolve()
83
 
 
84
  return repo_root, qwen_root
85
 
86
 
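> Both downloads now pass token=HF_TOKEN, which gated repos such as Qwen2-Audio require. A standalone sketch of the same call follows; note that local_dir_use_symlinks is deprecated in recent huggingface_hub releases and can simply be omitted, and that running this for real fetches the full (gated) repo:

```python
import os
from pathlib import Path
from huggingface_hub import snapshot_download

token = os.environ.get("HF_TOKEN")  # None is fine for public repos

# Downloads into the local HF cache and returns the snapshot directory.
local = snapshot_download(
    repo_id="Qwen/Qwen2-Audio-7B-Instruct",
    revision=None,  # default branch
    token=token,    # needed for gated/private repos
)
print(Path(local).resolve())
```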
@@ -155,21 +170,15 @@ def patch_paths_in_exp_config(exp_cfg: Dict[str, Any], repo_root: Path, qwen_roo
     - pretrained_ckpt: ckpt/mmedit/vae/epoch=xx.ckpt -> repo_root/vae/epoch=xx.ckpt
     - model_path: ckpt/qwen2-audio-7B-instruct -> qwen_root (result of snapshot_download)
     """
-
     # ---- 1) VAE ckpt ----
     vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", None)
     if vae_ckpt:
         vae_ckpt = str(vae_ckpt).replace("\\", "/")
 
-        # Most robust approach here: take the suffix that follows the "vae/" substring
-        # e.g.:
-        #   ckpt/mmedit/vae/epoch=13-step=1000000.ckpt -> vae/epoch=13-step=1000000.ckpt
         idx = vae_ckpt.find("vae/")
         if idx != -1:
             vae_rel = vae_ckpt[idx:]  # slice from "vae/" onward
         else:
-            # Fallback: if someone writes just epoch=xx.ckpt, put it under repo_root/vae/
-            # (or writes vae/xxx.ckpt)
             if vae_ckpt.endswith(".ckpt") and "/" not in vae_ckpt:
                 vae_rel = f"vae/{vae_ckpt}"
             else:
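> The remap above is pure string handling, so it is easy to check in isolation. A worked example; the raise in the final branch is a stand-in, since the real else branch is cut off by this hunk:

```python
from pathlib import Path

def remap_vae_ckpt(vae_ckpt: str, repo_root: Path) -> Path:
    vae_ckpt = str(vae_ckpt).replace("\\", "/")
    idx = vae_ckpt.find("vae/")
    if idx != -1:
        rel = vae_ckpt[idx:]                    # "ckpt/mmedit/vae/x.ckpt" -> "vae/x.ckpt"
    elif vae_ckpt.endswith(".ckpt") and "/" not in vae_ckpt:
        rel = f"vae/{vae_ckpt}"                 # bare "x.ckpt" -> "vae/x.ckpt"
    else:
        raise ValueError(f"unexpected pretrained_ckpt: {vae_ckpt}")  # stand-in for the hidden branch
    return repo_root / rel

root = Path("/repo")
assert remap_vae_ckpt("ckpt/mmedit/vae/epoch=13-step=1000000.ckpt", root) == root / "vae/epoch=13-step=1000000.ckpt"
assert remap_vae_ckpt("epoch=13.ckpt", root) == root / "vae/epoch=13.ckpt"
```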
@@ -188,20 +197,35 @@
     )
 
     # ---- 2) Qwen2-Audio model_path ----
-    # The config says ckpt/qwen2-audio-7B-instruct, but on the Space we use the downloaded qwen_root directly
     exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
 
 
 # ---------------------------------------------------------
 # Scheduler (aligned with exp_cfg.model.noise_scheduler_name)
+# Note: some repo_ids have no scheduler subfolder and will 404.
+# Provide a fallback here instead of crashing outright.
 # ---------------------------------------------------------
 def build_scheduler(exp_cfg: Dict[str, Any]):
     name = exp_cfg["model"].get("noise_scheduler_name", "stabilityai/stable-diffusion-2-1")
-    scheduler = noise_schedulers.DDIMScheduler.from_pretrained(name, subfolder="scheduler")
-    return scheduler
+    try:
+        scheduler = noise_schedulers.DDIMScheduler.from_pretrained(name, subfolder="scheduler", token=HF_TOKEN)
+        return scheduler
+    except Exception as e:
+        logger.warning(f"DDIMScheduler.from_pretrained failed for '{name}', falling back to default DDIM config. err={e}")
+        # Fallback: no dependency on a remote repo
+        return noise_schedulers.DDIMScheduler(
+            num_train_timesteps=1000,
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+            steps_offset=1,
+        )
 
 
 def _amp_ctx(device: torch.device):
+    # ZeroGPU: autocast only when device is cuda and USE_AMP is explicitly enabled
     if not USE_AMP:
         return torch.autocast("cuda", enabled=False)
     if device.type != "cuda":
@@ -211,9 +235,10 @@ def _amp_ctx(device: torch.device):
 
 
 # ---------------------------------------------------------
-# Cold start: load + cache the pipeline
+# Cold start: load + cache the pipeline (keep the cached model on CPU)
+# ZeroGPU usually has no CUDA during startup, so no model.to("cuda") here
 # ---------------------------------------------------------
-def load_pipeline() -> Tuple[LoadPretrainedBase, object, int, torch.device]:
+def load_pipeline_cpu() -> Tuple[LoadPretrainedBase, object, int]:
     cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
     if cache_key in _PIPELINE_CACHE:
         return _PIPELINE_CACHE[cache_key]
@@ -221,10 +246,9 @@ def load_pipeline() -> Tuple[LoadPretrainedBase, object, int, torch.device]:
221
  repo_root, qwen_root = resolve_model_dirs()
222
  assert_repo_layout(repo_root)
223
 
224
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
225
  logger.info(f"repo_root = {repo_root}")
226
- logger.info(f"device = {device}")
227
  logger.info(f"qwen_root = {qwen_root}")
 
228
 
229
  exp_cfg = OmegaConf.load(repo_root / "config.yaml")
230
  exp_cfg = OmegaConf.to_container(exp_cfg, resolve=True)
@@ -233,25 +257,31 @@ def load_pipeline() -> Tuple[LoadPretrainedBase, object, int, torch.device]:
233
  logger.info(f"patched pretrained_ckpt = {exp_cfg['model']['autoencoder'].get('pretrained_ckpt')}")
234
  logger.info(f"patched qwen model_path = {exp_cfg['model']['content_encoder']['text_encoder'].get('model_path')}")
235
 
 
236
  model: LoadPretrainedBase = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")
237
 
 
238
  ckpt_path = repo_root / "model.safetensors"
239
  sd = load_file(str(ckpt_path))
240
  model.load_pretrained(sd)
241
 
242
- model = model.to(device).eval()
 
243
 
244
  scheduler = build_scheduler(exp_cfg)
245
  target_sr = int(exp_cfg.get("sample_rate", 24000))
246
 
247
- _PIPELINE_CACHE[cache_key] = (model, scheduler, target_sr, device)
248
- logger.info("Pipeline loaded and cached.")
249
- return model, scheduler, target_sr, device
250
 
251
 
252
  # ---------------------------------------------------------
253
  # 推理:audio + caption -> edited audio
 
 
254
  # ---------------------------------------------------------
 
255
  @torch.no_grad()
256
  def run_edit(
257
  audio_file: str,
@@ -268,12 +298,25 @@ def run_edit(
     if not caption:
         return None, "Error: caption is empty."
 
-    model, scheduler, target_sr, device = load_pipeline()
+    # 1) CPU-cached pipeline
+    model_cpu, scheduler, target_sr = load_pipeline_cpu()
+
+    # 2) cuda only becomes available once ZeroGPU enters the GPU section
+    if not torch.cuda.is_available():
+        return None, "Error: ZeroGPU did not allocate CUDA. Please retry (queue) or check Space hardware."
+
+    device = torch.device("cuda")
+    logger.info(f"[GPU] torch.cuda.is_available={torch.cuda.is_available()}, device={device}")
+
+    # 3) Move the model to the GPU (temporarily)
+    model = model_cpu.to(device).eval()
 
+    # seed
     seed = int(seed)
     torch.manual_seed(seed)
     np.random.seed(seed)
 
+    # audio preprocess
     wav = load_and_process_audio(audio_file, target_sr=target_sr).to(device)
 
     batch = {
@@ -282,7 +325,7 @@
         "task": ["audio_editing"],
     }
 
-    # Aligned with the infer.config you provided
+    # Aligned with infer.config
     kwargs = {
         "num_steps": int(num_steps),
         "guidance_scale": float(guidance_scale),
@@ -301,6 +344,15 @@
     out_path = OUTPUT_DIR / f"{Path(audio_file).stem}_edited.wav"
     sf.write(str(out_path), out_audio, samplerate=target_sr)
 
+    # 4) Move the model back to CPU right after inference (key for ZeroGPU: no cuda tensors left in the cache)
+    model_cpu = model.to("cpu")
+    del model
+    torch.cuda.empty_cache()
+
+    # 5) Update the cache (still the CPU version)
+    cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
+    _PIPELINE_CACHE[cache_key] = (model_cpu, scheduler, target_sr)
+
     return str(out_path), f"OK | saved={out_path.name} | time={dt:.2f}s | sr={target_sr} | seed={seed}"
 
 
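> Steps 1-5 above are the whole ZeroGPU contract: cache on CPU, borrow the GPU inside a @spaces.GPU function, and hand the weights back before returning. A minimal sketch of that round-trip on a ZeroGPU Space, with a toy model standing in for the MMEdit pipeline:

```python
import spaces
import torch

_model = torch.nn.Linear(4, 4).eval()  # stand-in for the cached CPU model

@spaces.GPU          # the GPU exists only while this function runs
@torch.no_grad()
def infer(x: torch.Tensor) -> torch.Tensor:
    global _model
    device = torch.device("cuda")
    model = _model.to(device)        # 1) borrow the GPU
    y = model(x.to(device)).cpu()    # 2) compute, move the result off-device
    _model = model.to("cpu")         # 3) return the weights to CPU before exiting
    torch.cuda.empty_cache()         # 4) leave no CUDA allocations behind
    return y
```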
@@ -308,25 +360,24 @@ def run_edit(
 # UI
 # ---------------------------------------------------------
 def build_demo():
-    with gr.Blocks(title="MMEdit Space Simulator") as demo:
-        gr.Markdown("# MMEdit Space simulation (audio + caption → edited audio)")
-        gr.Markdown(
-            "Click an example below to auto-fill the audio path and edit instruction, then click Run Editing."
-        )
+    with gr.Blocks(title="MMEdit (ZeroGPU)") as demo:
+        gr.Markdown("# MMEdit ZeroGPU (audio + caption → edited audio)")
+
 
         with gr.Row():
             with gr.Column():
                 audio_in = gr.Audio(label="Input Audio", type="filepath")
                 caption = gr.Textbox(label="Caption (Edit Instruction)", lines=3)
 
-                # One-click example fill: a single click populates audio_in + caption
+                # Note: Spaces does not allow pushing large wav examples.
+                # Safest approach: put a very small demo wav (a few hundred KB) in the Space repo yourself.
                 gr.Examples(
                     label="example inputs",
                     examples=[
-                        ["example/Ym8O802VvJes.wav", "Mix in dog barking in the middle."],
+                        ["./Ym8O802VvJes.wav", "Mix in dog barking in the middle."],
                     ],
                     inputs=[audio_in, caption],
-                    cache_examples=False,  # more stable both locally and on Spaces; no pre-caching
+                    cache_examples=False,
                 )
 
         with gr.Row():
@@ -351,15 +402,20 @@ def build_demo():
 
         gr.Markdown(
             "## Notes\n"
-            "- First load is slow\n"
-            "- There are some bugs on Spaces; in some cases the original audio is degraded\n"
+            "1) The first click allocates a ZeroGPU GPU and may be a bit slow.\n"
+            "2) If you hit an error, please retry (especially on first startup).\n"
+            "3) Preserving the original audio may still be buggy\n"
         )
-
     return demo
 
 
-
 if __name__ == "__main__":
     demo = build_demo()
-    port = int(os.environ.get("PORT", "7860"))  # Spaces default is 7860
-    demo.launch(server_name="0.0.0.0", server_port=port, share=False)
+    port = int(os.environ.get("PORT", "7860"))
+    # ZeroGPU: queue() is strongly recommended; disabling SSR is also more stable
+    demo.queue().launch(
+        server_name="0.0.0.0",
+        server_port=port,
+        share=False,
+        ssr_mode=False,
+    )