CocoBro committed on
Commit
66b10fb
·
1 Parent(s): d52557e
Files changed (1) hide show
  1. app.py +169 -244
app.py CHANGED
@@ -7,24 +7,30 @@ import spaces
7
  import os
8
  import time
9
  import logging
 
 
10
  from pathlib import Path
11
  from typing import Tuple, Optional, Dict, Any
12
 
13
  import gradio as gr
14
  import numpy as np
15
  import soundfile as sf
 
 
16
  from huggingface_hub import snapshot_download
17
 
18
-
19
  # -----------------------------
20
- # Logging
21
  # -----------------------------
22
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
 
 
 
23
  logger = logging.getLogger("mmedit_space")
24
 
25
-
26
  # ---------------------------------------------------------
27
- # HF Repo IDs(按你的默认需求)
28
  # ---------------------------------------------------------
29
  MMEDIT_REPO_ID = os.environ.get("MMEDIT_REPO_ID", "CocoBro/MMEdit")
30
  MMEDIT_REVISION = os.environ.get("MMEDIT_REVISION", None)
@@ -32,63 +38,46 @@ MMEDIT_REVISION = os.environ.get("MMEDIT_REVISION", None)
32
  QWEN_REPO_ID = os.environ.get("QWEN_REPO_ID", "Qwen/Qwen2-Audio-7B-Instruct")
33
  QWEN_REVISION = os.environ.get("QWEN_REVISION", None)
34
 
35
- # 如果 Qwen gated:Space 里把 HF_TOKEN 设为 Secret
36
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
37
 
38
  OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "./outputs"))
39
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
40
 
41
- USE_AMP = os.environ.get("USE_AMP", "0") == "1"
42
- AMP_DTYPE = os.environ.get("AMP_DTYPE", "bf16") # "bf16" or "fp16"
43
-
44
- # ZeroGPU:缓存 CPU pipeline(不要缓存 CUDA Tensor)
45
- # cache: key -> (model_cpu, scheduler, target_sr)
46
  _PIPELINE_CACHE: Dict[str, Tuple[object, object, int]] = {}
47
- # cache: key -> (repo_root, qwen_root)
48
  _MODEL_DIR_CACHE: Dict[str, Tuple[Path, Path]] = {}
49
 
50
 
51
  # ---------------------------------------------------------
52
- # 下载 repo(只下载一次;huggingface_hub 自带缓存)
53
  # ---------------------------------------------------------
54
  def resolve_model_dirs() -> Tuple[Path, Path]:
 
55
  cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
56
  if cache_key in _MODEL_DIR_CACHE:
57
  return _MODEL_DIR_CACHE[cache_key]
58
 
59
- logger.info(f"Downloading MMEdit repo: {MMEDIT_REPO_ID} (revision={MMEDIT_REVISION})")
60
  repo_root = snapshot_download(
61
- repo_id=MMEDIT_REPO_ID,
62
- revision=MMEDIT_REVISION,
63
- local_dir=None,
64
- local_dir_use_symlinks=False,
65
- token=HF_TOKEN,
66
  )
67
- repo_root = Path(repo_root).resolve()
68
-
69
- logger.info(f"Downloading Qwen repo: {QWEN_REPO_ID} (revision={QWEN_REVISION})")
70
  qwen_root = snapshot_download(
71
- repo_id=QWEN_REPO_ID,
72
- revision=QWEN_REVISION,
73
- local_dir=None,
74
- local_dir_use_symlinks=False,
75
- token=HF_TOKEN, # gated 模型必须
76
  )
77
- qwen_root = Path(qwen_root).resolve()
78
-
79
- _MODEL_DIR_CACHE[cache_key] = (repo_root, qwen_root)
80
- return repo_root, qwen_root
81
-
82
 
83
- # ---------------------------------------------------------
84
- # 你的音频加载(按你要求:orig -> 16k -> target_sr)
85
- # ---------------------------------------------------------
86
- def load_and_process_audio(audio_path: str, target_sr: int):
87
- # 延迟导入(避免启动阶段触发 CUDA 初始化)
88
- import torch
89
  import torchaudio
90
- import librosa
91
-
92
  path = Path(audio_path)
93
  if not path.exists():
94
  raise FileNotFoundError(f"Audio file not found: {audio_path}")
@@ -96,91 +85,34 @@ def load_and_process_audio(audio_path: str, target_sr: int):
96
  waveform, orig_sr = torchaudio.load(str(path)) # (C, T)
97
 
98
  # Convert to mono
99
- if waveform.ndim == 2:
100
- waveform = waveform.mean(dim=0) # (T,)
101
- elif waveform.ndim > 2:
102
- waveform = waveform.reshape(-1)
103
-
104
- if target_sr and int(target_sr) != int(orig_sr):
105
- waveform_np = waveform.cpu().numpy()
106
-
107
- # 1) 先到 16k
108
- sr_mid = 16000
109
- if int(orig_sr) != sr_mid:
110
- waveform_np = librosa.resample(waveform_np, orig_sr=int(orig_sr), target_sr=sr_mid)
111
- orig_sr_mid = sr_mid
112
  else:
113
  orig_sr_mid = int(orig_sr)
114
-
115
- # 2) 再到 target_sr(如 24k)
116
  if int(target_sr) != orig_sr_mid:
117
- waveform_np = librosa.resample(waveform_np, orig_sr=orig_sr_mid, target_sr=int(target_sr))
118
-
119
- waveform = torch.from_numpy(waveform_np)
120
-
121
  return waveform
122
 
123
-
124
- # ---------------------------------------------------------
125
- # 校验 repo 结构
126
- # ---------------------------------------------------------
127
- def assert_repo_layout(repo_root: Path) -> None:
128
- must = [repo_root / "config.yaml", repo_root / "model.safetensors", repo_root / "vae"]
129
- for p in must:
130
- if not p.exists():
131
- raise FileNotFoundError(f"Missing required path: {p}")
132
-
133
- vae_files = list((repo_root / "vae").glob("*.ckpt"))
134
- if len(vae_files) == 0:
135
- raise FileNotFoundError(f"No .ckpt found under: {repo_root/'vae'}")
136
-
137
-
138
- # ---------------------------------------------------------
139
- # 适配 config.yaml 的路径写法
140
- # ---------------------------------------------------------
141
- def patch_paths_in_exp_config(exp_cfg: Dict[str, Any], repo_root: Path, qwen_root: Path) -> None:
142
- # ---- 1) VAE ckpt ----
143
- vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", None)
144
- if vae_ckpt:
145
- vae_ckpt = str(vae_ckpt).replace("\\", "/")
146
- idx = vae_ckpt.find("vae/")
147
- if idx != -1:
148
- vae_rel = vae_ckpt[idx:] # 从 vae/ 开始截断
149
- else:
150
- if vae_ckpt.endswith(".ckpt") and "/" not in vae_ckpt:
151
- vae_rel = f"vae/{vae_ckpt}"
152
- else:
153
- vae_rel = vae_ckpt
154
-
155
- vae_path = (repo_root / vae_rel).resolve()
156
- exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(vae_path)
157
-
158
- if not vae_path.exists():
159
- raise FileNotFoundError(
160
- f"VAE ckpt not found after patch:\n"
161
- f" original: {vae_ckpt}\n"
162
- f" patched : {vae_path}\n"
163
- f"Repo root: {repo_root}\n"
164
- f"Expected: {repo_root/'vae'/'*.ckpt'}"
165
- )
166
-
167
- # ---- 2) Qwen2-Audio model_path ----
168
- exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
169
-
170
-
171
- # ---------------------------------------------------------
172
- # Scheduler(与你 exp_cfg.model.noise_scheduler_name 对齐)
173
- # 带 fallback:避免 404
174
- # ---------------------------------------------------------
175
- def build_scheduler(exp_cfg: Dict[str, Any]):
176
  import diffusers.schedulers as noise_schedulers
177
-
178
  name = exp_cfg["model"].get("noise_scheduler_name", "stabilityai/stable-diffusion-2-1")
179
  try:
180
- scheduler = noise_schedulers.DDIMScheduler.from_pretrained(name, subfolder="scheduler", token=HF_TOKEN)
181
- return scheduler
182
  except Exception as e:
183
- logger.warning(f"DDIMScheduler.from_pretrained failed for '{name}', fallback. err={e}")
184
  return noise_schedulers.DDIMScheduler(
185
  num_train_timesteps=1000,
186
  beta_start=0.00085,
@@ -191,73 +123,67 @@ def build_scheduler(exp_cfg: Dict[str, Any]):
191
  steps_offset=1,
192
  )
193
 
194
-
195
- def amp_autocast(device):
196
- import torch
197
-
198
- if not USE_AMP:
199
- return torch.autocast("cuda", enabled=False)
200
-
201
- if device.type != "cuda":
202
- return torch.autocast("cpu", enabled=False)
203
-
204
- dtype = torch.bfloat16 if AMP_DTYPE.lower() == "bf16" else torch.float16
205
- return torch.autocast("cuda", dtype=dtype, enabled=True)
206
-
207
-
208
- # ---------------------------------------------------------
209
- # 冷启动:load+cache pipeline(缓存 CPU 上的 model)
210
- # ---------------------------------------------------------
211
  def load_pipeline_cpu() -> Tuple[object, object, int]:
212
- # 延迟导入(避免启动阶段触发 CUDA 初始化)
213
- import torch
214
  import hydra
215
  from omegaconf import OmegaConf
216
  from safetensors.torch import load_file
217
-
218
- # 你的项目依赖也延迟导入
219
- from models.common import LoadPretrainedBase
220
  from utils.config import register_omegaconf_resolvers
 
221
 
222
- register_omegaconf_resolvers()
 
 
 
 
223
 
224
  cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
225
  if cache_key in _PIPELINE_CACHE:
226
  return _PIPELINE_CACHE[cache_key]
227
 
228
  repo_root, qwen_root = resolve_model_dirs()
229
- assert_repo_layout(repo_root)
230
-
231
- logger.info(f"repo_root = {repo_root}")
232
- logger.info(f"qwen_root = {qwen_root}")
233
-
234
- exp_cfg = OmegaConf.load(repo_root / "config.yaml")
235
- exp_cfg = OmegaConf.to_container(exp_cfg, resolve=True)
 
 
 
 
 
 
 
 
 
 
236
 
237
- patch_paths_in_exp_config(exp_cfg, repo_root, qwen_root)
238
- logger.info(f"patched pretrained_ckpt = {exp_cfg['model']['autoencoder'].get('pretrained_ckpt')}")
239
- logger.info(f"patched qwen model_path = {exp_cfg['model']['content_encoder']['text_encoder'].get('model_path')}")
240
 
 
241
  model: LoadPretrainedBase = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")
242
-
243
  ckpt_path = repo_root / "model.safetensors"
 
244
  sd = load_file(str(ckpt_path))
245
  model.load_pretrained(sd)
246
-
247
- # ZeroGPU:缓存 CPU
248
- model = model.to(torch.device("cpu")).eval()
249
-
250
  scheduler = build_scheduler(exp_cfg)
251
  target_sr = int(exp_cfg.get("sample_rate", 24000))
252
 
253
  _PIPELINE_CACHE[cache_key] = (model, scheduler, target_sr)
254
- logger.info("CPU pipeline loaded and cached.")
255
  return model, scheduler, target_sr
256
 
257
 
258
  # ---------------------------------------------------------
259
- # 推理:audio + caption -> edited audio
260
- # ZeroGPU:必须用 @spaces.GPU
261
  # ---------------------------------------------------------
262
  @spaces.GPU
263
  def run_edit(
@@ -268,148 +194,147 @@ def run_edit(
268
  guidance_rescale: float,
269
  seed: int,
270
  ) -> Tuple[Optional[str], str]:
271
- import torch
272
-
273
- # 1. 基础检查
274
- if audio_file is None or not Path(audio_file).exists():
275
- return None, "Error: please upload an audio file."
276
 
 
277
  caption = (caption or "").strip()
278
- if not caption:
279
- return None, "Error: caption is empty."
280
 
281
- # 2. 获取缓存模型
282
- # 注意:此时 model_cpu 在 CPU 上
283
  model_cpu, scheduler, target_sr = load_pipeline_cpu()
 
 
 
 
 
 
 
 
 
 
284
 
285
- # 使用 try-finally 确保无论是否出错,最后都把模型搬回 CPU
286
- # 使用 try-except 确保捕获所有推理错误,打印日志
287
  try:
288
- # --- 检查 GPU ---
289
  if not torch.cuda.is_available():
290
- return None, "Error: ZeroGPU did not allocate CUDA."
291
 
292
- device = torch.device("cuda")
293
- logger.info(f"[GPU] Assigned device: {device}")
294
-
295
- # --- 关键修改:模型上 GPU ---
296
- # model_cpu.to(device) 是原位操作!会修改全局缓存!
297
- # 所以必须在 finally 里搬回去,或者在这里使用深拷贝(深拷贝太慢,建议搬回去)
298
 
299
- model = model_cpu.to(device).eval()
300
- logger.info("Moving model to GPU for inference...")
301
- # --- 数据预处理 ---
302
- seed = int(seed)
303
- torch.manual_seed(seed)
304
- np.random.seed(seed)
305
 
306
- # 加载音频并转到 GPU
307
- wav = load_and_process_audio(audio_file, target_sr=target_sr).to(device)
 
308
 
 
 
 
 
 
 
309
  batch = {
310
  "audio_id": [Path(audio_file).stem],
311
- "content": [{"audio": wav, "caption": caption}],
312
  "task": ["audio_editing"],
313
  }
314
-
315
  kwargs = {
316
  "num_steps": int(num_steps),
317
  "guidance_scale": float(guidance_scale),
318
  "guidance_rescale": float(guidance_rescale),
319
  "use_gt_duration": False,
320
  "mask_time_aligned_content": False,
 
321
  }
322
- kwargs.update(batch)
323
 
324
- # --- 推理 ---
 
325
  t0 = time.time()
 
326
  with torch.no_grad():
327
- with amp_autocast(device):
328
- # 这里的报错现在能被捕获了
329
- out = model.inference(scheduler=scheduler, **kwargs)
 
330
  dt = time.time() - t0
 
331
 
332
- # --- 后处理 ---
 
333
  out_audio = out[0, 0].detach().float().cpu().numpy()
334
  out_path = OUTPUT_DIR / f"{Path(audio_file).stem}_edited.wav"
335
  sf.write(str(out_path), out_audio, samplerate=target_sr)
336
-
337
- return str(out_path), f"OK | time={dt:.2f}s | seed={seed}"
338
 
339
  except Exception as e:
340
- # 这里会打印完整的堆栈信息,让你看到真正的报错原因
341
- logger.exception("Error during inference")
342
- return None, f"Runtime Error: {str(e)}"
 
343
 
344
  finally:
345
- # --- 关键修改:清理现场 ---
346
- # 无论 try 里面是否成功,这里都会执行
347
- # 必须把模型搬回 CPU,否则全局缓存 _PIPELINE_CACHE 将指向损坏的 CUDA 地址
348
- if 'model_cpu' in locals() and model_cpu is not None:
349
- logger.info("Moving model back to CPU to preserve cache integrity...")
350
- model_cpu.to("cpu")
 
 
 
 
 
 
 
 
351
 
352
- # 强制清理显存
353
  torch.cuda.empty_cache()
 
354
 
355
 
356
  # ---------------------------------------------------------
357
- # UI
358
  # ---------------------------------------------------------
359
  def build_demo():
360
- with gr.Blocks(title="MMEdit (ZeroGPU)") as demo:
361
- gr.Markdown("# MMEdit ZeroGPU(audio + caption → edited audio)")
362
-
 
363
  with gr.Row():
364
  with gr.Column():
365
  audio_in = gr.Audio(label="Input Audio", type="filepath")
366
- caption = gr.Textbox(label="Caption (Edit Instruction)", lines=3)
367
-
368
- # 注意:Space 不建议推大 wav;你可以换成更小的 demo wav
369
- gr.Examples(
370
- label="example inputs",
371
- examples=[
372
- ["./Ym8O802VvJes.wav", "Mix in dog barking around the middle."],
373
- ],
374
- inputs=[audio_in, caption],
375
- cache_examples=False,
376
- )
377
-
378
- with gr.Row():
379
- num_steps = gr.Slider(1, 100, value=50, step=1, label="num_steps")
380
- guidance_scale = gr.Slider(1.0, 12.0, value=5.0, step=0.5, label="guidance_scale")
381
-
382
- with gr.Row():
383
- guidance_rescale = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="guidance_rescale")
384
- seed = gr.Number(value=42, precision=0, label="seed")
385
 
386
  run_btn = gr.Button("Run Editing", variant="primary")
387
 
388
  with gr.Column():
389
- audio_out = gr.Audio(label="Edited Audio", type="filepath")
390
- status = gr.Textbox(label="Status")
391
 
392
  run_btn.click(
393
- fn=run_edit,
394
- inputs=[audio_in, caption, num_steps, guidance_scale, guidance_rescale, seed],
395
- outputs=[audio_out, status],
396
- )
397
-
398
- gr.Markdown(
399
- "## 注意事项\n"
400
- "1) ZeroGPU 首次点击会分配 GPU,可能稍慢。\n"
401
- "2) 如果首次报 cuda 不可用,通常重试一次即可。\n"
402
  )
403
-
404
  return demo
405
 
406
 
407
  if __name__ == "__main__":
408
  demo = build_demo()
409
- port = int(os.environ.get("PORT", "7860"))
 
410
  demo.queue().launch(
411
- server_name="0.0.0.0",
412
  server_port=port,
413
- share=False,
414
- ssr_mode=False,
415
- )
 
7
  import os
8
  import time
9
  import logging
10
+ import traceback
11
+ import gc
12
  from pathlib import Path
13
  from typing import Tuple, Optional, Dict, Any
14
 
15
  import gradio as gr
16
  import numpy as np
17
  import soundfile as sf
18
+ import torch
19
+ import librosa
20
  from huggingface_hub import snapshot_download
21
 
 
22
  # -----------------------------
23
+ # Logging 配置
24
  # -----------------------------
25
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger("mmedit_space")


# ---------------------------------------------------------
# Configuration (every value can be overridden via environment variables)
# ---------------------------------------------------------
def _env(name: str, default: Optional[str] = None) -> Optional[str]:
    """Return environment variable *name* stripped, or *default* if unset/blank.

    A variable that is defined but empty (or whitespace only) is treated the
    same as an undefined one, so a blank setting never ends up being used as a
    repo id, revision, token or directory name.
    """
    value = os.environ.get(name, "").strip()
    return value or default


MMEDIT_REPO_ID = _env("MMEDIT_REPO_ID", "CocoBro/MMEdit")
MMEDIT_REVISION = _env("MMEDIT_REVISION")

QWEN_REPO_ID = _env("QWEN_REPO_ID", "Qwen/Qwen2-Audio-7B-Instruct")
QWEN_REVISION = _env("QWEN_REVISION")

HF_TOKEN = _env("HF_TOKEN")

OUTPUT_DIR = Path(_env("OUTPUT_DIR", "./outputs"))
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# ---------------------------------------------------------
# Global caches
# ---------------------------------------------------------
# Maps a cache key to (model_cpu, scheduler, target_sr).
# WARNING: the model stored in this cache must always stay on the "cpu" device!
_PIPELINE_CACHE: Dict[str, Tuple[object, object, int]] = {}
# Maps a cache key to the two local model directories (repo_root, qwen_root).
_MODEL_DIR_CACHE: Dict[str, Tuple[Path, Path]] = {}
53
 
54
 
55
  # ---------------------------------------------------------
56
+ # 辅助函数
57
  # ---------------------------------------------------------
58
def resolve_model_dirs() -> Tuple[Path, Path]:
    """Download both model repositories and return their local directories.

    The result is memoised in ``_MODEL_DIR_CACHE`` under a key built from the
    two repo ids and revisions, so a repeated call with the same settings
    returns the stored pair without calling ``snapshot_download`` again.

    Returns:
        ``(mmedit_dir, qwen_dir)`` as resolved ``Path`` objects.
    """
    cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
    known_dirs = _MODEL_DIR_CACHE.get(cache_key)
    if known_dirs is not None:
        return known_dirs

    logger.info(f"Downloading MMEdit repo: {MMEDIT_REPO_ID}...")
    mmedit_snapshot = snapshot_download(
        repo_id=MMEDIT_REPO_ID,
        revision=MMEDIT_REVISION,
        token=HF_TOKEN,
    )

    logger.info(f"Downloading Qwen repo: {QWEN_REPO_ID}...")
    qwen_snapshot = snapshot_download(
        repo_id=QWEN_REPO_ID,
        revision=QWEN_REVISION,
        token=HF_TOKEN,
    )

    mmedit_dir = Path(mmedit_snapshot).resolve()
    qwen_dir = Path(qwen_snapshot).resolve()
    _MODEL_DIR_CACHE[cache_key] = (mmedit_dir, qwen_dir)
    return _MODEL_DIR_CACHE[cache_key]
 
77
 
78
+ def load_and_process_audio(audio_path: str, target_sr: int) -> torch.Tensor:
 
 
 
 
 
79
  import torchaudio
80
+
 
81
  path = Path(audio_path)
82
  if not path.exists():
83
  raise FileNotFoundError(f"Audio file not found: {audio_path}")
 
85
  waveform, orig_sr = torchaudio.load(str(path)) # (C, T)
86
 
87
  # Convert to mono
88
+ if waveform.ndim > 1:
89
+ waveform = waveform.mean(dim=0)
90
+
91
+ # Resample logic (robust method)
92
+ if int(orig_sr) != int(target_sr):
93
+ wav_np = waveform.cpu().numpy()
94
+
95
+ # Intermediate resampling to 16k if needed (for better stability)
96
+ if int(orig_sr) != 16000:
97
+ wav_np = librosa.resample(wav_np, orig_sr=int(orig_sr), target_sr=16000)
98
+ orig_sr_mid = 16000
 
 
99
  else:
100
  orig_sr_mid = int(orig_sr)
101
+
 
102
  if int(target_sr) != orig_sr_mid:
103
+ wav_np = librosa.resample(wav_np, orig_sr=orig_sr_mid, target_sr=int(target_sr))
104
+
105
+ waveform = torch.from_numpy(wav_np)
106
+
107
  return waveform
108
 
109
+ def build_scheduler(exp_cfg):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  import diffusers.schedulers as noise_schedulers
 
111
  name = exp_cfg["model"].get("noise_scheduler_name", "stabilityai/stable-diffusion-2-1")
112
  try:
113
+ return noise_schedulers.DDIMScheduler.from_pretrained(name, subfolder="scheduler", token=HF_TOKEN)
 
114
  except Exception as e:
115
+ logger.warning(f"Scheduler init failed: {e}, using fallback.")
116
  return noise_schedulers.DDIMScheduler(
117
  num_train_timesteps=1000,
118
  beta_start=0.00085,
 
123
  steps_offset=1,
124
  )
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
def load_pipeline_cpu() -> Tuple[object, object, int]:
    """Load the model on the CPU and store the pipeline in the global cache.

    On the first call for a given repo/revision combination this:

    1. resolves the local model directories via ``resolve_model_dirs()``;
    2. loads ``config.yaml`` from the MMEdit directory into a plain container;
    3. points the autoencoder checkpoint at a local file with the same file
       name (looked up under ``vae/`` first, then the repo root) and points
       the text encoder's ``model_path`` at the Qwen directory;
    4. instantiates the model with hydra and loads ``model.safetensors``;
    5. moves the model to ``"cpu"``, switches it to eval mode and builds the
       scheduler.

    Later calls with the same settings return the cached tuple immediately.

    Returns:
        ``(model, scheduler, target_sr)`` where ``target_sr`` is the config's
        ``sample_rate`` entry (24000 when the entry is absent).
    """
    cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
    # Look the cache up first so a hit skips the imports and the resolver
    # registration below instead of repeating them on every request.
    cached = _PIPELINE_CACHE.get(cache_key)
    if cached is not None:
        return cached

    # Kept function-local, as in the original block.
    import hydra
    from omegaconf import OmegaConf
    from safetensors.torch import load_file
    from utils.config import register_omegaconf_resolvers
    from models.common import LoadPretrainedBase

    # Registration stays best-effort (a failure must not abort loading), but
    # the failure is logged with its traceback instead of being discarded.
    try:
        register_omegaconf_resolvers()
    except Exception:
        logger.warning(
            "register_omegaconf_resolvers() raised; continuing without re-registering.",
            exc_info=True,
        )

    repo_root, qwen_root = resolve_model_dirs()

    cfg_path = repo_root / "config.yaml"
    exp_cfg = OmegaConf.to_container(OmegaConf.load(cfg_path), resolve=True)

    # --- Config patching ---
    # VAE checkpoint: keep only the file name from the configured path and
    # look for that file inside the downloaded repo.
    autoencoder_cfg = exp_cfg["model"]["autoencoder"]
    vae_ckpt = autoencoder_cfg.get("pretrained_ckpt", "")
    if vae_ckpt:
        fname = Path(vae_ckpt).name
        candidates = (repo_root / "vae" / fname, repo_root / fname)
        local_vae = next((path for path in candidates if path.exists()), None)
        if local_vae is not None:
            autoencoder_cfg["pretrained_ckpt"] = str(local_vae)
        else:
            # The configured value is left untouched, exactly as before; the
            # warning only makes the miss visible in the logs.
            logger.warning(
                "VAE checkpoint '%s' not found under %s or %s; keeping configured path '%s'.",
                fname,
                candidates[0].parent,
                candidates[1].parent,
                vae_ckpt,
            )

    # Qwen path: always use the locally downloaded snapshot.
    exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)

    logger.info("Instantiating model architecture...")
    model: LoadPretrainedBase = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")

    ckpt_path = repo_root / "model.safetensors"
    logger.info(f"Loading weights from {ckpt_path.name}...")
    sd = load_file(str(ckpt_path))
    model.load_pretrained(sd)

    # Key point: make sure the cached model starts out on the CPU.
    model = model.to("cpu").eval()

    scheduler = build_scheduler(exp_cfg)
    target_sr = int(exp_cfg.get("sample_rate", 24000))

    _PIPELINE_CACHE[cache_key] = (model, scheduler, target_sr)
    logger.info(" Model loaded and cached in CPU RAM.")
    return model, scheduler, target_sr
183
 
184
 
185
  # ---------------------------------------------------------
186
+ # ZeroGPU 推理函数
 
187
  # ---------------------------------------------------------
188
  @spaces.GPU
189
  def run_edit(
 
194
  guidance_rescale: float,
195
  seed: int,
196
  ) -> Tuple[Optional[str], str]:
 
 
 
 
 
197
 
198
+ if not audio_file: return None, "Please upload an audio file."
199
  caption = (caption or "").strip()
200
+ if not caption: return None, "Please enter an instruction caption."
 
201
 
202
+ # 1. 获取 CPU 上的模型引用
 
203
  model_cpu, scheduler, target_sr = load_pipeline_cpu()
204
+
205
+ # 2. 准备设备 - 强制使用 float16
206
+ device = torch.device("cuda")
207
+ dtype = torch.float16 # <--- 强制 FP16
208
+
209
+ logger.info(f"🚀 [GPU Task Start] Device: {device}, Precision: {dtype}")
210
+
211
+ # 用于 finally 清理的变量
212
+ model_on_gpu = None
213
+ wav_on_gpu = None
214
 
 
 
215
  try:
216
+ # --- GPU 环境检查 ---
217
  if not torch.cuda.is_available():
218
+ raise RuntimeError("ZeroGPU assigned but CUDA unavailable.")
219
 
220
+ # --- 3. 模型搬运 (CPU -> GPU) ---
221
+ # 显式清理,为大模型腾出完整空间
222
+ gc.collect()
223
+ torch.cuda.empty_cache()
 
 
224
 
225
+ logger.info("Moving model to GPU...")
 
 
 
 
 
226
 
227
+ # ⚠️ 核心逻辑:这里虽然用了 to(device),这会修改 model_cpu 的设备属性
228
+ # 所以我们在 finally 块中必须将其搬回 CPU,否则下次运行会因为设备失效而崩溃
229
+ model_on_gpu = model_cpu.to(device, dtype=dtype)
230
 
231
+ # --- 4. 数据准备 ---
232
+ torch.manual_seed(int(seed))
233
+ np.random.seed(int(seed))
234
+
235
+ wav_on_gpu = load_and_process_audio(audio_file, target_sr).to(device, dtype=dtype)
236
+
237
  batch = {
238
  "audio_id": [Path(audio_file).stem],
239
+ "content": [{"audio": wav_on_gpu, "caption": caption}],
240
  "task": ["audio_editing"],
241
  }
242
+
243
  kwargs = {
244
  "num_steps": int(num_steps),
245
  "guidance_scale": float(guidance_scale),
246
  "guidance_rescale": float(guidance_rescale),
247
  "use_gt_duration": False,
248
  "mask_time_aligned_content": False,
249
+ **batch
250
  }
 
251
 
252
+ # --- 5. 推理 ---
253
+ logger.info("Starting inference...")
254
  t0 = time.time()
255
+
256
  with torch.no_grad():
257
+ # 使用 float16
258
+ with torch.autocast("cuda", dtype=dtype):
259
+ out = model_on_gpu.inference(scheduler=scheduler, **kwargs)
260
+
261
  dt = time.time() - t0
262
+ logger.info(f"✅ Inference finished in {dt:.2f}s")
263
 
264
+ # --- 6. 保存结果 ---
265
+ # 立即 detach 并转回 CPU
266
  out_audio = out[0, 0].detach().float().cpu().numpy()
267
  out_path = OUTPUT_DIR / f"{Path(audio_file).stem}_edited.wav"
268
  sf.write(str(out_path), out_audio, samplerate=target_sr)
269
+
270
+ return str(out_path), f"Success | Time: {dt:.2f}s | Seed: {seed}"
271
 
272
  except Exception as e:
273
+ # 🔥 捕捉所有错误,防止 spaces 吞掉报错,打印完整堆栈
274
+ err_msg = traceback.format_exc()
275
+ logger.error(f" CRITICAL ERROR:\n{err_msg}")
276
+ return None, f"Runtime Error: {str(e)}\n(See logs for details)"
277
 
278
  finally:
279
+ # --- 7. 关键:现场恢复 ---
280
+ # 无论成功还是失败,必须把模型搬回 CPU,否则全局缓存 _PIPELINE_CACHE 将指向已释放的显存
281
+ logger.info("♻️ Cleaning up resources...")
282
+ try:
283
+ # 只要 model_cpu 还在,就强制搬回 CPU
284
+ if 'model_cpu' in locals() and model_cpu is not None:
285
+ model_cpu.to("cpu")
286
+ logger.info("Model restored to CPU.")
287
+ except Exception as e:
288
+ logger.error(f"Failed to restore model to CPU: {e}")
289
+
290
+ # 删除局部引用
291
+ if 'model_on_gpu' in locals(): del model_on_gpu
292
+ if 'wav_on_gpu' in locals(): del wav_on_gpu
293
 
294
+ # 强制显存清理
295
  torch.cuda.empty_cache()
296
+ gc.collect()
297
 
298
 
299
  # ---------------------------------------------------------
300
+ # UI 启动
301
  # ---------------------------------------------------------
302
def build_demo():
    """Build the Gradio Blocks UI and wire the run button to ``run_edit``.

    The layout has two columns: inputs (audio, instruction, a collapsed
    "Advanced Settings" accordion and the run button) and outputs (result
    audio and a status text box).

    Returns:
        The constructed ``gr.Blocks`` object (not yet launched).
    """
    with gr.Blocks(title="MMEdit ZeroGPU") as demo:
        gr.Markdown("## MMEdit")
        gr.Markdown("ZeroGPU environment detected. Resources are allocated dynamically.")

        with gr.Row():
            with gr.Column():
                audio_in = gr.Audio(label="Input Audio", type="filepath")
                caption = gr.Textbox(
                    label="Editing Instruction",
                    placeholder="e.g., Add rain sound in the background",
                )

                with gr.Accordion("Advanced Settings", open=False):
                    steps = gr.Slider(minimum=10, maximum=100, value=50, step=1, label="Steps")
                    cfg = gr.Slider(minimum=1.0, maximum=15.0, value=5.0, step=0.5, label="Guidance Scale")
                    rescale = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Guidance Rescale")
                    # precision=0 makes the component deliver an integer, so
                    # the seed shown in the status line is the seed that
                    # run_edit actually uses after its int() conversion.
                    seed = gr.Number(value=42, precision=0, label="Seed")

                run_btn = gr.Button("Run Editing", variant="primary")

            with gr.Column():
                audio_out = gr.Audio(label="Result", type="filepath")
                status = gr.Textbox(label="Status Logs")

        # Input order must match run_edit's positional parameters.
        run_btn.click(
            run_edit,
            inputs=[audio_in, caption, steps, cfg, rescale, seed],
            outputs=[audio_out, status],
        )

    return demo
330
 
331
 
332
if __name__ == "__main__":
    demo = build_demo()
    # Compatibility setting: ssr_mode is not passed, so Gradio decides on its own.
    port = int(os.environ.get("PORT", 7860))
    queued_demo = demo.queue()
    queued_demo.launch(server_name="0.0.0.0", server_port=port, share=False)