CocoBro commited on
Commit
77f1338
·
1 Parent(s): 47b5ec4

fix load gpu

Browse files
Files changed (1) hide show
  1. app.py +95 -210
app.py CHANGED
@@ -172,24 +172,6 @@ def patch_paths_in_exp_config(exp_cfg: Dict[str, Any], repo_root: Path, qwen_roo
172
  # Scheduler(与你 exp_cfg.model.noise_scheduler_name 对齐)
173
  # 带 fallback:避免 404
174
  # ---------------------------------------------------------
175
- def build_scheduler(exp_cfg: Dict[str, Any]):
176
- import diffusers.schedulers as noise_schedulers
177
-
178
- name = exp_cfg["model"].get("noise_scheduler_name", "stabilityai/stable-diffusion-2-1")
179
- try:
180
- scheduler = noise_schedulers.DDIMScheduler.from_pretrained(name, subfolder="scheduler", token=HF_TOKEN)
181
- return scheduler
182
- except Exception as e:
183
- logger.warning(f"DDIMScheduler.from_pretrained failed for '{name}', fallback. err={e}")
184
- return noise_schedulers.DDIMScheduler(
185
- num_train_timesteps=1000,
186
- beta_start=0.00085,
187
- beta_end=0.012,
188
- beta_schedule="scaled_linear",
189
- clip_sample=False,
190
- set_alpha_to_one=False,
191
- steps_offset=1,
192
- )
193
 
194
 
195
  def amp_autocast(device):
@@ -205,157 +187,88 @@ def amp_autocast(device):
205
  return torch.autocast("cuda", dtype=dtype, enabled=True)
206
 
207
 
208
- # ---------------------------------------------------------
209
- # 冷启动:load+cache pipeline(缓存 CPU 上的 model)
210
- # ---------------------------------------------------------
211
- # def load_pipeline_cpu() -> Tuple[object, object, int]:
212
- # # 延迟导入(避免启动阶段触发 CUDA 初始化)
213
- # import torch
214
- # import hydra
215
- # from omegaconf import OmegaConf
216
- # from safetensors.torch import load_file
217
-
218
- # # 你的项目依赖也延迟导入
219
- # from models.common import LoadPretrainedBase
220
- # from utils.config import register_omegaconf_resolvers
221
-
222
- # register_omegaconf_resolvers()
223
-
224
- # cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
225
- # if cache_key in _PIPELINE_CACHE:
226
- # return _PIPELINE_CACHE[cache_key]
227
 
228
- # repo_root, qwen_root = resolve_model_dirs()
229
- # assert_repo_layout(repo_root)
230
-
231
- # logger.info(f"repo_root = {repo_root}")
232
- # logger.info(f"qwen_root = {qwen_root}")
233
-
234
- # exp_cfg = OmegaConf.load(repo_root / "config.yaml")
235
- # exp_cfg = OmegaConf.to_container(exp_cfg, resolve=True)
236
-
237
- # patch_paths_in_exp_config(exp_cfg, repo_root, qwen_root)
238
- # logger.info(f"patched pretrained_ckpt = {exp_cfg['model']['autoencoder'].get('pretrained_ckpt')}")
239
- # logger.info(f"patched qwen model_path = {exp_cfg['model']['content_encoder']['text_encoder'].get('model_path')}")
240
-
241
- # model: LoadPretrainedBase = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")
242
-
243
- # ckpt_path = repo_root / "model.safetensors"
244
- # sd = load_file(str(ckpt_path))
245
- # model.load_pretrained(sd)
246
- # logger.info(f"Model loaded from safetensors: {ckpt_path}")
247
- # # ZeroGPU:缓存 CPU 版
248
- # model = model.to(torch.device("cpu")).eval()
249
-
250
- # scheduler = build_scheduler(exp_cfg)
251
- # target_sr = int(exp_cfg.get("sample_rate", 24000))
252
-
253
- # _PIPELINE_CACHE[cache_key] = (model, scheduler, target_sr)
254
- # logger.info("CPU pipeline loaded and cached.")
255
- # return model, scheduler, target_sr
256
- def load_pipeline_cpu():
257
- # 延迟导入
258
  import torch
259
  import hydra
260
  from omegaconf import OmegaConf
261
  from safetensors.torch import load_file
 
262
 
263
- # 尝试导入项目模块
264
  try:
265
  from utils.config import register_omegaconf_resolvers
266
  register_omegaconf_resolvers()
267
  except: pass
268
 
269
- cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
270
- if cache_key in _PIPELINE_CACHE: return _PIPELINE_CACHE[cache_key]
271
-
272
- repo_root, qwen_root = resolve_model_dirs()
273
-
274
- # 加载 Config
275
- exp_cfg = OmegaConf.to_container(OmegaConf.load(repo_root / "config.yaml"), resolve=True)
276
-
277
- # 路径修复
278
- vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", "")
279
- if vae_ckpt:
280
- potential_paths = [repo_root / "vae" / Path(vae_ckpt).name, repo_root / Path(vae_ckpt).name]
281
- for p in potential_paths:
282
- if p.exists():
283
- exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(p)
284
- break
285
-
286
- exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
287
-
288
- logger.info("Instantiating model...")
289
- model = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")
290
-
291
- # 加载权重并立即释放 state_dict 内存
292
- ckpt_path = str(repo_root / "model.safetensors")
293
- logger.info(f"Loading state_dict from {ckpt_path}...")
294
- sd = load_file(ckpt_path)
295
-
296
- logger.info(f"Model loaded from safetensors: {ckpt_path}")
297
- model.load_pretrained(sd)
298
- del sd # <--- 关键:立即删除 state_dict 释放 20GB+ 内存
299
- gc.collect() # <--- 关键:强制回收
300
-
301
- # 确保在 CPU
302
- model = model.to("cpu").eval()
303
-
304
- # Scheduler
305
- import diffusers.schedulers as noise_schedulers
306
- try:
307
- scheduler = noise_schedulers.DDIMScheduler.from_pretrained(
308
- exp_cfg["model"].get("noise_scheduler_name", "stabilityai/stable-diffusion-2-1"),
309
- subfolder="scheduler", token=HF_TOKEN
310
- )
311
- except:
312
- scheduler = noise_schedulers.DDIMScheduler(num_train_timesteps=1000)
313
-
314
- target_sr = int(exp_cfg.get("sample_rate", 24000))
315
- _PIPELINE_CACHE[cache_key] = (model, scheduler, target_sr)
316
- return model, scheduler, target_sr
317
-
318
- # ---------------------------------------------------------
319
- # 推理:audio + caption -> edited audio
320
- # ZeroGPU:必须用 @spaces.GPU
321
- # ---------------------------------------------------------
322
- # ---------------------------------------------------------
323
- @spaces.GPU
324
- def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, seed):
325
- import torch
326
-
327
  if not audio_file: return None, "Please upload audio."
328
- if not caption: return None, "Please input caption."
329
 
330
- # 局部变量初始化,防 finally 报错
331
- model_cpu = None
332
- model_on_gpu = None
333
 
334
  try:
335
- # --- 1. 将加载过程放入 try 块保护 ---
336
- logger.info("Loading pipeline (CPU)...")
337
- model_cpu, scheduler, target_sr = load_pipeline_cpu()
338
-
339
- # --- 2. 准备 GPU 环境 ---
340
- device = torch.device("cuda")
341
- dtype = torch.float16
342
 
343
- if not torch.cuda.is_available():
344
- raise RuntimeError("ZeroGPU assigned but CUDA not found!")
345
-
346
- # --- 3. 搬运 (CPU -> GPU) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  gc.collect()
348
- torch.cuda.empty_cache()
349
- logger.info("Moving model to GPU...")
 
 
 
 
350
 
351
- # 原位操作,finally 必须移回
352
- model_on_gpu = model_cpu.to(device, dtype=dtype)
353
 
354
- # --- 4. 数据处理 ---
 
 
 
 
 
 
 
 
 
 
 
 
355
  torch.manual_seed(int(seed))
356
  np.random.seed(int(seed))
357
 
358
- wav = load_and_process_audio(audio_file, target_sr).to(device, dtype=dtype)
359
 
360
  batch = {
361
  "audio_id": [Path(audio_file).stem],
@@ -368,13 +281,14 @@ def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, s
368
  "mask_time_aligned_content": False
369
  }
370
 
371
- # --- 5. 推理 ---
372
- logger.info("Running inference...")
373
  t0 = time.time()
374
- with torch.no_grad(), torch.autocast("cuda", dtype=dtype):
375
- out = model_on_gpu.inference(scheduler=scheduler, **batch)
376
 
377
- # --- 6. 保存 ---
 
 
378
  out_audio = out[0, 0].detach().float().cpu().numpy()
379
  out_path = OUTPUT_DIR / f"{Path(audio_file).stem}_edited.wav"
380
  sf.write(str(out_path), out_audio, samplerate=target_sr)
@@ -382,80 +296,51 @@ def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, s
382
  return str(out_path), f"Success | {time.time()-t0:.2f}s"
383
 
384
  except Exception as e:
385
- # 🔥 现在你可以看到真正的报错了!
386
- err_msg = traceback.format_exc()
387
- logger.error(f" ERROR:\n{err_msg}")
388
- return None, f"Error: {str(e)}\n(Check Logs for Traceback)"
389
 
390
  finally:
391
- # --- 7. 还原现场 ---
392
- logger.info("Restoring CPU state...")
393
- try:
394
- if model_cpu is not None:
395
- model_cpu.to("cpu")
396
- except Exception as e:
397
- logger.error(f"Restore failed: {e}")
398
-
399
- if model_on_gpu is not None: del model_on_gpu
400
  torch.cuda.empty_cache()
401
  gc.collect()
402
- # ---------------------------------------------------------
 
403
  # UI
404
- # ---------------------------------------------------------
405
  def build_demo():
406
- with gr.Blocks(title="MMEdit (ZeroGPU)") as demo:
407
- gr.Markdown("# MMEdit ZeroGPU(audio + caption → edited audio)")
408
-
409
  with gr.Row():
410
  with gr.Column():
411
- audio_in = gr.Audio(label="Input Audio", type="filepath")
412
- caption = gr.Textbox(label="Caption (Edit Instruction)", lines=3)
413
-
414
- # 注意:Space 不建议推大 wav;你可以换成更小的 demo wav
415
  gr.Examples(
416
- label="example inputs",
417
- examples=[
418
- ["./Ym8O802VvJes.wav", "Mix in dog barking around the middle."],
419
- ],
420
  inputs=[audio_in, caption],
421
- cache_examples=False,
422
  )
423
-
424
- with gr.Row():
425
- num_steps = gr.Slider(1, 100, value=50, step=1, label="num_steps")
426
- guidance_scale = gr.Slider(1.0, 12.0, value=5.0, step=0.5, label="guidance_scale")
427
-
428
  with gr.Row():
429
- guidance_rescale = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="guidance_rescale")
430
- seed = gr.Number(value=42, precision=0, label="seed")
431
-
432
- run_btn = gr.Button("Run Editing", variant="primary")
433
-
 
434
  with gr.Column():
435
- audio_out = gr.Audio(label="Edited Audio", type="filepath")
436
  status = gr.Textbox(label="Status")
437
-
438
- run_btn.click(
439
- fn=run_edit,
440
- inputs=[audio_in, caption, num_steps, guidance_scale, guidance_rescale, seed],
441
- outputs=[audio_out, status],
442
- )
443
-
444
- gr.Markdown(
445
- "## 注意事项\n"
446
- "1) ZeroGPU 首次点击会分配 GPU,可能稍慢。\n"
447
- "2) 如果首次报 cuda 不可用,通常重试一次即可。\n"
448
- )
449
-
450
  return demo
451
 
452
-
453
  if __name__ == "__main__":
454
  demo = build_demo()
455
- port = int(os.environ.get("PORT", "7860"))
456
  demo.queue().launch(
457
- server_name="0.0.0.0",
458
- server_port=port,
459
- share=False,
460
- ssr_mode=False,
461
- )
 
172
  # Scheduler(与你 exp_cfg.model.noise_scheduler_name 对齐)
173
  # 带 fallback:避免 404
174
  # ---------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
 
177
  def amp_autocast(device):
 
187
  return torch.autocast("cuda", dtype=dtype, enabled=True)
188
 
189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
+ # -----------------------------
192
+ # ZeroGPU 核心任务
193
+ # -----------------------------
194
+ # 学长说的就是这里:所有费资源的操作(加载+推理)都要放在这里面
195
+ @spaces.GPU(duration=150)
196
+ def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, seed):
197
+ # 延迟导入,防止全局污染
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  import torch
199
  import hydra
200
  from omegaconf import OmegaConf
201
  from safetensors.torch import load_file
202
+ import diffusers.schedulers as noise_schedulers
203
 
204
+ # 尝试导入项目配置
205
  try:
206
  from utils.config import register_omegaconf_resolvers
207
  register_omegaconf_resolvers()
208
  except: pass
209
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  if not audio_file: return None, "Please upload audio."
 
211
 
212
+ # 局部变量,用于 finally 清理
213
+ model = None
 
214
 
215
  try:
216
+ # ==========================================
217
+ # 1. 就在这里加载模型!利用 ZeroGPU 的大内存
218
+ # ==========================================
219
+ logger.info("🚀 Starting ZeroGPU Task...")
 
 
 
220
 
221
+ # 路径准备
222
+ repo_root, qwen_root = resolve_model_dirs()
223
+ exp_cfg = OmegaConf.to_container(OmegaConf.load(repo_root / "config.yaml"), resolve=True)
224
+
225
+ # 路径修复逻辑
226
+ vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", "")
227
+ if vae_ckpt:
228
+ p1 = repo_root / "vae" / Path(vae_ckpt).name
229
+ p2 = repo_root / Path(vae_ckpt).name
230
+ if p1.exists(): exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(p1)
231
+ elif p2.exists(): exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(p2)
232
+ exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
233
+
234
+ # 实例化模型 (此时消耗大量 CPU 内存,但 ZeroGPU 环境扛得住)
235
+ logger.info("Instantiating model (Hydra)...")
236
+ model = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")
237
+
238
+ # 加载权重
239
+ ckpt_path = str(repo_root / "model.safetensors")
240
+ logger.info(f"Loading weights from {ckpt_path}...")
241
+ sd = load_file(ckpt_path)
242
+ model.load_pretrained(sd)
243
+ del sd # 立即释放
244
  gc.collect()
245
+
246
+ # ==========================================
247
+ # 2. 立即转到 GPU (FP16)
248
+ # ==========================================
249
+ device = torch.device("cuda")
250
+ logger.info("Moving model to CUDA (FP16)...")
251
 
252
+ # 这一步将模型送入显卡
253
+ model = model.to(device, dtype=torch.float16).eval()
254
 
255
+ # Scheduler
256
+ try:
257
+ scheduler = noise_schedulers.DDIMScheduler.from_pretrained(
258
+ exp_cfg["model"].get("noise_scheduler_name", ""),
259
+ subfolder="scheduler", token=HF_TOKEN
260
+ )
261
+ except:
262
+ scheduler = noise_schedulers.DDIMScheduler(num_train_timesteps=1000)
263
+
264
+ # ==========================================
265
+ # 3. 开始推理
266
+ # ==========================================
267
+ target_sr = int(exp_cfg.get("sample_rate", 24000))
268
  torch.manual_seed(int(seed))
269
  np.random.seed(int(seed))
270
 
271
+ wav = load_and_process_audio(audio_file, target_sr).to(device, dtype=torch.float16)
272
 
273
  batch = {
274
  "audio_id": [Path(audio_file).stem],
 
281
  "mask_time_aligned_content": False
282
  }
283
 
284
+ logger.info("Inference running...")
 
285
  t0 = time.time()
286
+ with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16):
287
+ out = model.inference(scheduler=scheduler, **batch)
288
 
289
+ # ==========================================
290
+ # 4. 保存结果
291
+ # ==========================================
292
  out_audio = out[0, 0].detach().float().cpu().numpy()
293
  out_path = OUTPUT_DIR / f"{Path(audio_file).stem}_edited.wav"
294
  sf.write(str(out_path), out_audio, samplerate=target_sr)
 
296
  return str(out_path), f"Success | {time.time()-t0:.2f}s"
297
 
298
  except Exception as e:
299
+ err = traceback.format_exc()
300
+ logger.error(f"❌ ERROR:\n{err}")
301
+ return None, f"Runtime Error: {e}"
 
302
 
303
  finally:
304
+ # 强制清理,防止下一次任务显存不够
305
+ logger.info("Cleaning up...")
306
+ if model is not None: del model
 
 
 
 
 
 
307
  torch.cuda.empty_cache()
308
  gc.collect()
309
+
310
+ # -----------------------------
311
  # UI
312
+ # -----------------------------
313
  def build_demo():
314
+ with gr.Blocks(title="MMEdit") as demo:
315
+ gr.Markdown("# MMEdit ZeroGPU (Direct Load)")
 
316
  with gr.Row():
317
  with gr.Column():
318
+ audio_in = gr.Audio(label="Input", type="filepath")
319
+ caption = gr.Textbox(label="Instruction", lines=3)
 
 
320
  gr.Examples(
321
+ label="Examples",
322
+ examples=[["./Ym8O802VvJes.wav", "Mix in dog barking around the middle."]],
 
 
323
  inputs=[audio_in, caption],
 
324
  )
 
 
 
 
 
325
  with gr.Row():
326
+ num_steps = gr.Slider(10, 100, 50, step=1, label="Steps")
327
+ guidance_scale = gr.Slider(1.0, 12.0, 5.0, step=0.5, label="Guidance")
328
+ guidance_rescale = gr.Slider(0.0, 1.0, 0.5, step=0.05, label="Rescale")
329
+ seed = gr.Number(42, label="Seed")
330
+ run_btn = gr.Button("Run", variant="primary")
331
+
332
  with gr.Column():
333
+ out = gr.Audio(label="Output")
334
  status = gr.Textbox(label="Status")
335
+
336
+ run_btn.click(run_edit, [audio_in, caption, num_steps, guidance_scale, guidance_rescale, seed], [out, status])
 
 
 
 
 
 
 
 
 
 
 
337
  return demo
338
 
 
339
  if __name__ == "__main__":
340
  demo = build_demo()
341
+ # 必须 ssr_mode=False
342
  demo.queue().launch(
343
+ server_name="0.0.0.0",
344
+ server_port=int(os.environ.get("PORT", 7860)),
345
+ ssr_mode=False
346
+ )