CocoBro committed on
Commit
47b5ec4
·
1 Parent(s): 92bc756
Files changed (1) hide show
  1. app.py +128 -87
app.py CHANGED
@@ -9,7 +9,7 @@ import time
9
  import logging
10
  from pathlib import Path
11
  from typing import Tuple, Optional, Dict, Any
12
-
13
  import gradio as gr
14
  import numpy as np
15
  import soundfile as sf
@@ -208,156 +208,197 @@ def amp_autocast(device):
208
  # ---------------------------------------------------------
209
  # 冷启动:load+cache pipeline(缓存 CPU 上的 model)
210
  # ---------------------------------------------------------
211
- def load_pipeline_cpu() -> Tuple[object, object, int]:
212
- # 延迟导入(避免启动阶段触发 CUDA 初始化)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  import torch
214
  import hydra
215
  from omegaconf import OmegaConf
216
  from safetensors.torch import load_file
217
-
218
- # 你的项目依赖也延迟导入
219
- from models.common import LoadPretrainedBase
220
- from utils.config import register_omegaconf_resolvers
221
-
222
- register_omegaconf_resolvers()
223
 
224
  cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
225
- if cache_key in _PIPELINE_CACHE:
226
- return _PIPELINE_CACHE[cache_key]
227
 
228
  repo_root, qwen_root = resolve_model_dirs()
229
- assert_repo_layout(repo_root)
230
-
231
- logger.info(f"repo_root = {repo_root}")
232
- logger.info(f"qwen_root = {qwen_root}")
233
-
234
- exp_cfg = OmegaConf.load(repo_root / "config.yaml")
235
- exp_cfg = OmegaConf.to_container(exp_cfg, resolve=True)
236
-
237
- patch_paths_in_exp_config(exp_cfg, repo_root, qwen_root)
238
- logger.info(f"patched pretrained_ckpt = {exp_cfg['model']['autoencoder'].get('pretrained_ckpt')}")
239
- logger.info(f"patched qwen model_path = {exp_cfg['model']['content_encoder']['text_encoder'].get('model_path')}")
 
 
 
240
 
241
- model: LoadPretrainedBase = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")
 
 
 
 
 
 
242
 
243
- ckpt_path = repo_root / "model.safetensors"
244
- sd = load_file(str(ckpt_path))
245
  model.load_pretrained(sd)
 
 
246
 
247
- # ZeroGPU:缓存 CPU
248
- model = model.to(torch.device("cpu")).eval()
 
 
 
 
 
 
 
 
 
 
249
 
250
- scheduler = build_scheduler(exp_cfg)
251
  target_sr = int(exp_cfg.get("sample_rate", 24000))
252
-
253
  _PIPELINE_CACHE[cache_key] = (model, scheduler, target_sr)
254
- logger.info("CPU pipeline loaded and cached.")
255
  return model, scheduler, target_sr
256
 
257
-
258
  # ---------------------------------------------------------
259
  # 推理:audio + caption -> edited audio
260
  # ZeroGPU:必须用 @spaces.GPU
261
  # ---------------------------------------------------------
262
  # ---------------------------------------------------------
263
  @spaces.GPU
264
- def run_edit(
265
- audio_file: str,
266
- caption: str,
267
- num_steps: int,
268
- guidance_scale: float,
269
- guidance_rescale: float,
270
- seed: int,
271
- ) -> Tuple[Optional[str], str]:
272
  import torch
273
- import gc
274
-
275
- if not audio_file: return None, "Error: Upload audio first."
276
- if not caption: return None, "Error: Input caption."
277
-
278
- # 1. 获取 CPU 模型
279
- model_cpu, scheduler, target_sr = load_pipeline_cpu()
280
-
281
- # 2. 准备设备 (强制 float16 以防 OOM 和兼容问题)
282
- device = torch.device("cuda")
283
- dtype = torch.float16
284
 
285
- logger.info(f"🚀 [GPU Start] Device: {device}, Dtype: {dtype}")
286
-
 
 
 
287
  model_on_gpu = None
288
 
289
  try:
 
 
 
 
 
 
 
 
290
  if not torch.cuda.is_available():
291
  raise RuntimeError("ZeroGPU assigned but CUDA not found!")
292
 
293
- # --- 3. 搬运模型 (CPU -> GPU) ---
294
  gc.collect()
295
  torch.cuda.empty_cache()
296
-
297
  logger.info("Moving model to GPU...")
298
- # [关键] 原位操作警告:model_cpu.to() 会改变 cpu 对象
299
- # 我们必须在 finally 里搬回去!
300
  model_on_gpu = model_cpu.to(device, dtype=dtype)
301
-
302
- # --- 4. 数据准备 ---
303
  torch.manual_seed(int(seed))
304
  np.random.seed(int(seed))
305
 
306
- wav = load_and_process_audio(audio_file, target_sr=target_sr).to(device, dtype=dtype)
307
 
308
  batch = {
309
  "audio_id": [Path(audio_file).stem],
310
  "content": [{"audio": wav, "caption": caption}],
311
  "task": ["audio_editing"],
312
- }
313
-
314
- kwargs = {
315
  "num_steps": int(num_steps),
316
  "guidance_scale": float(guidance_scale),
317
  "guidance_rescale": float(guidance_rescale),
318
  "use_gt_duration": False,
319
- "mask_time_aligned_content": False,
320
- **batch
321
  }
322
-
323
  # --- 5. 推理 ---
324
- logger.info("Inference start...")
325
  t0 = time.time()
326
- with torch.no_grad():
327
- with torch.autocast("cuda", dtype=dtype):
328
- out = model_on_gpu.inference(scheduler=scheduler, **kwargs)
329
- dt = time.time() - t0
330
- logger.info(f"✅ Inference done: {dt:.2f}s")
331
-
332
- # --- 6. 结果保存 ---
333
  out_audio = out[0, 0].detach().float().cpu().numpy()
334
  out_path = OUTPUT_DIR / f"{Path(audio_file).stem}_edited.wav"
335
  sf.write(str(out_path), out_audio, samplerate=target_sr)
336
-
337
- return str(out_path), f"OK | {dt:.2f}s | Seed: {seed}"
338
 
339
  except Exception as e:
340
- # [关键] 打印完整堆栈,不再报 404
341
  err_msg = traceback.format_exc()
342
  logger.error(f"❌ ERROR:\n{err_msg}")
343
- return None, f"Runtime Error: {str(e)}\nCheck Logs."
344
-
345
  finally:
346
- # --- 7. [关键] 现场恢复 ---
347
- logger.info("♻️ Restoring CPU state...")
348
  try:
349
- # 必须搬回 CPU,否则缓存中的指针指向已释放的显存,下次必崩
350
- if 'model_cpu' in locals() and model_cpu is not None:
351
  model_cpu.to("cpu")
352
- logger.info("Model restored to CPU.")
353
  except Exception as e:
354
- logger.error(f"Failed to restore model: {e}")
355
 
356
- # 清理显存
357
- if 'model_on_gpu' in locals(): del model_on_gpu
358
  torch.cuda.empty_cache()
359
  gc.collect()
360
-
361
  # ---------------------------------------------------------
362
  # UI
363
  # ---------------------------------------------------------
 
9
  import logging
10
  from pathlib import Path
11
  from typing import Tuple, Optional, Dict, Any
12
+ import gc
13
  import gradio as gr
14
  import numpy as np
15
  import soundfile as sf
 
208
  # ---------------------------------------------------------
209
  # 冷启动:load+cache pipeline(缓存 CPU 上的 model)
210
  # ---------------------------------------------------------
211
def load_pipeline_cpu() -> Tuple[object, object, int]:
    """Load the editing pipeline once and cache it on CPU.

    ZeroGPU Spaces get a GPU only inside ``@spaces.GPU`` functions, so the
    model is built and cached on CPU here and moved to CUDA per request.

    Returns:
        (model, scheduler, target_sr): the eval-mode model on CPU, a DDIM
        scheduler, and the target sample rate (defaults to 24000).
    """
    # Lazy imports: keep startup fast and avoid triggering CUDA init.
    import torch
    import hydra
    from omegaconf import OmegaConf
    from safetensors.torch import load_file

    # Project-local resolvers are best-effort (may already be registered,
    # or the module may be absent); never let this kill the load.
    try:
        from utils.config import register_omegaconf_resolvers
        register_omegaconf_resolvers()
    except Exception:
        pass

    cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
    if cache_key in _PIPELINE_CACHE:
        return _PIPELINE_CACHE[cache_key]

    repo_root, qwen_root = resolve_model_dirs()

    # Load the experiment config and resolve it to a plain dict.
    exp_cfg = OmegaConf.to_container(OmegaConf.load(repo_root / "config.yaml"), resolve=True)

    # Re-point the VAE checkpoint at whichever local copy actually exists.
    vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", "")
    if vae_ckpt:
        potential_paths = [repo_root / "vae" / Path(vae_ckpt).name, repo_root / Path(vae_ckpt).name]
        for p in potential_paths:
            if p.exists():
                exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(p)
                break

    exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)

    logger.info("Instantiating model...")
    model = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")

    # Load weights, then drop the state_dict immediately — it can be very
    # large and would otherwise double peak RAM.
    ckpt_path = str(repo_root / "model.safetensors")
    logger.info(f"Loading state_dict from {ckpt_path}...")
    sd = load_file(ckpt_path)
    logger.info(f"Model loaded from safetensors: {ckpt_path}")
    model.load_pretrained(sd)
    del sd
    gc.collect()

    # Cache strictly on CPU (ZeroGPU requirement).
    model = model.to("cpu").eval()

    # Scheduler: prefer the configured pretrained scheduler; fall back to a
    # default DDIM scheduler if the hub fetch fails (offline, bad token, ...).
    import diffusers.schedulers as noise_schedulers
    try:
        scheduler = noise_schedulers.DDIMScheduler.from_pretrained(
            exp_cfg["model"].get("noise_scheduler_name", "stabilityai/stable-diffusion-2-1"),
            subfolder="scheduler", token=HF_TOKEN
        )
    except Exception:
        scheduler = noise_schedulers.DDIMScheduler(num_train_timesteps=1000)

    target_sr = int(exp_cfg.get("sample_rate", 24000))
    _PIPELINE_CACHE[cache_key] = (model, scheduler, target_sr)
    return model, scheduler, target_sr
317
 
 
318
  # ---------------------------------------------------------
319
  # 推理:audio + caption -> edited audio
320
  # ZeroGPU:必须用 @spaces.GPU
321
  # ---------------------------------------------------------
322
  # ---------------------------------------------------------
323
@spaces.GPU
def run_edit(
    audio_file: str,
    caption: str,
    num_steps: int,
    guidance_scale: float,
    guidance_rescale: float,
    seed: int,
) -> Tuple[Optional[str], str]:
    """Edit an uploaded audio clip according to a text caption.

    Runs on a ZeroGPU-allocated CUDA device: moves the cached CPU model to
    the GPU, runs inference under fp16 autocast, writes the result to a wav
    file, and ALWAYS restores the cached model to CPU in ``finally``.

    Returns:
        (output_path, status): path to the edited wav and a status string,
        or (None, error_message) on failure.
    """
    import torch

    if not audio_file:
        return None, "Please upload audio."
    if not caption:
        return None, "Please input caption."

    # Initialize up front so the finally block can reference them safely.
    model_cpu = None
    model_on_gpu = None

    try:
        # --- 1. Load (or fetch cached) CPU pipeline inside the try so
        # loading failures surface as a user-visible error, not a 500.
        logger.info("Loading pipeline (CPU)...")
        model_cpu, scheduler, target_sr = load_pipeline_cpu()

        # --- 2. GPU setup; fp16 to limit memory pressure.
        device = torch.device("cuda")
        dtype = torch.float16

        if not torch.cuda.is_available():
            raise RuntimeError("ZeroGPU assigned but CUDA not found!")

        # --- 3. Move model CPU -> GPU. NOTE: .to() mutates the cached
        # object in place, so the finally block MUST move it back.
        gc.collect()
        torch.cuda.empty_cache()
        logger.info("Moving model to GPU...")
        model_on_gpu = model_cpu.to(device, dtype=dtype)

        # --- 4. Deterministic seeding and input preparation.
        torch.manual_seed(int(seed))
        np.random.seed(int(seed))

        wav = load_and_process_audio(audio_file, target_sr).to(device, dtype=dtype)

        batch = {
            "audio_id": [Path(audio_file).stem],
            "content": [{"audio": wav, "caption": caption}],
            "task": ["audio_editing"],
            "num_steps": int(num_steps),
            "guidance_scale": float(guidance_scale),
            "guidance_rescale": float(guidance_rescale),
            "use_gt_duration": False,
            "mask_time_aligned_content": False
        }

        # --- 5. Inference under autocast.
        logger.info("Running inference...")
        t0 = time.time()
        with torch.no_grad(), torch.autocast("cuda", dtype=dtype):
            out = model_on_gpu.inference(scheduler=scheduler, **batch)

        # --- 6. Save result (out is assumed (batch, ch, samples) — matches
        # the [0, 0] indexing below).
        out_audio = out[0, 0].detach().float().cpu().numpy()
        out_path = OUTPUT_DIR / f"{Path(audio_file).stem}_edited.wav"
        sf.write(str(out_path), out_audio, samplerate=target_sr)

        return str(out_path), f"Success | {time.time()-t0:.2f}s"

    except Exception as e:
        # Boundary handler: log the full traceback, return a friendly error.
        err_msg = traceback.format_exc()
        logger.error(f"❌ ERROR:\n{err_msg}")
        return None, f"Error: {str(e)}\n(Check Logs for Traceback)"

    finally:
        # --- 7. Restore the cached model to CPU no matter what; otherwise
        # the cache would hold pointers into freed GPU memory.
        logger.info("Restoring CPU state...")
        try:
            if model_cpu is not None:
                model_cpu.to("cpu")
        except Exception as e:
            logger.error(f"Restore failed: {e}")

        if model_on_gpu is not None:
            del model_on_gpu
        torch.cuda.empty_cache()
        gc.collect()
402
  # ---------------------------------------------------------
403
  # UI
404
  # ---------------------------------------------------------