CocoBro committed on
Commit
688592f
·
1 Parent(s): b75ca87
Files changed (1)
  1. app.py +88 -143
app.py CHANGED
@@ -7,18 +7,18 @@ import spaces
  import os
  import time
  import logging
- import traceback
- import gc
+ import traceback  # [new] used to print error stack traces
+ import gc  # [new] used for GPU memory cleanup
  from pathlib import Path
  from typing import Tuple, Optional, Dict, Any

  import gradio as gr
  import numpy as np
  import soundfile as sf
- import torch
- import librosa
+ # [changed] Removed the top-level torch/librosa imports so importing app.py cannot trigger CUDA at startup
  from huggingface_hub import snapshot_download

+
  # -----------------------------
  # Logging
  # -----------------------------
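Why this matters: on ZeroGPU Spaces the main process must not initialize CUDA before a `@spaces.GPU` call runs, so heavyweight imports move inside the functions that need them. A minimal sanity check (a sketch, run inside the Space environment; `app` is this file's module name):

```python
# Illustrative check that importing the app does not touch CUDA.
import importlib

import torch

importlib.import_module("app")  # runs all module-level imports
assert not torch.cuda.is_initialized(), "a module-level import initialized CUDA"
print("OK: importing app left CUDA uninitialized")
```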
@@ -35,7 +35,6 @@ MMEDIT_REVISION = os.environ.get("MMEDIT_REVISION", None)
  QWEN_REPO_ID = os.environ.get("QWEN_REPO_ID", "Qwen/Qwen2-Audio-7B-Instruct")
  QWEN_REVISION = os.environ.get("QWEN_REVISION", None)

- # If Qwen is gated: set HF_TOKEN as a Secret in the Space
  HF_TOKEN = os.environ.get("HF_TOKEN", None)

  OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "./outputs"))
@@ -45,37 +44,32 @@ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
  # Cache definitions
  # ---------------------------------------------------------
  # cache: key -> (model_cpu, scheduler, target_sr)
- # Note: model_cpu must always stay on the CPU
  _PIPELINE_CACHE: Dict[str, Tuple[object, object, int]] = {}
  # cache: key -> (repo_root, qwen_root)
  _MODEL_DIR_CACHE: Dict[str, Tuple[Path, Path]] = {}


  # ---------------------------------------------------------
- # 1. Download repos
+ # Download repos
  # ---------------------------------------------------------
  def resolve_model_dirs() -> Tuple[Path, Path]:
      cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
      if cache_key in _MODEL_DIR_CACHE:
          return _MODEL_DIR_CACHE[cache_key]

-     logger.info(f"Downloading MMEdit repo: {MMEDIT_REPO_ID} (revision={MMEDIT_REVISION})")
+     logger.info(f"Downloading MMEdit repo: {MMEDIT_REPO_ID}")
      repo_root = snapshot_download(
          repo_id=MMEDIT_REPO_ID,
          revision=MMEDIT_REVISION,
-         local_dir=None,
-         local_dir_use_symlinks=False,
          token=HF_TOKEN,
      )
      repo_root = Path(repo_root).resolve()

-     logger.info(f"Downloading Qwen repo: {QWEN_REPO_ID} (revision={QWEN_REVISION})")
+     logger.info(f"Downloading Qwen repo: {QWEN_REPO_ID}")
      qwen_root = snapshot_download(
          repo_id=QWEN_REPO_ID,
          revision=QWEN_REVISION,
-         local_dir=None,
-         local_dir_use_symlinks=False,
-         token=HF_TOKEN,  # required for gated models
+         token=HF_TOKEN,
      )
      qwen_root = Path(qwen_root).resolve()
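`snapshot_download` already reuses the local HF cache across calls, so `_MODEL_DIR_CACHE` only saves the hub metadata round-trip. The same pattern in isolation (a sketch; `get_repo` and its arguments are illustrative):

```python
# Minimal sketch of the cached snapshot_download pattern above.
from pathlib import Path
from typing import Dict, Optional

from huggingface_hub import snapshot_download

_DIR_CACHE: Dict[str, Path] = {}

def get_repo(repo_id: str, revision: Optional[str] = None, token: Optional[str] = None) -> Path:
    key = f"{repo_id}@{revision}"
    if key not in _DIR_CACHE:
        # First call downloads (or resolves from the local HF cache); later calls skip the hub.
        _DIR_CACHE[key] = Path(
            snapshot_download(repo_id=repo_id, revision=revision, token=token)
        ).resolve()
    return _DIR_CACHE[key]
```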
 
@@ -84,27 +78,26 @@ def resolve_model_dirs() -> Tuple[Path, Path]:


  # ---------------------------------------------------------
- # 2. Audio loading (keeps your logic, adds robustness)
+ # Audio processing
  # ---------------------------------------------------------
  def load_and_process_audio(audio_path: str, target_sr: int):
+     # Lazy imports, to avoid interfering with startup
+     import torch
      import torchaudio
-
+     import librosa
+
      path = Path(audio_path)
      if not path.exists():
          raise FileNotFoundError(f"Audio file not found: {audio_path}")

-     waveform, orig_sr = torchaudio.load(str(path))  # (C, T)
+     waveform, orig_sr = torchaudio.load(str(path))

-     # Convert to mono
-     if waveform.ndim == 2:
-         waveform = waveform.mean(dim=0)  # (T,)
-     elif waveform.ndim > 2:
-         waveform = waveform.reshape(-1)
+     if waveform.ndim > 1:
+         waveform = waveform.mean(dim=0)

      if target_sr and int(target_sr) != int(orig_sr):
          waveform_np = waveform.cpu().numpy()
-
-         # Robust two-step resampling logic
+         # Robust resampling logic
          sr_mid = 16000
          if int(orig_sr) != sr_mid:
              waveform_np = librosa.resample(waveform_np, orig_sr=int(orig_sr), target_sr=sr_mid)
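The hunk is cut off mid-function, but the intent of the resampling step is visible: go through a fixed 16 kHz intermediate rate before (presumably) a second hop to `target_sr`. As a standalone sketch on a mono 1-D array:

```python
# Sketch of the two-step resampling idea (assumes a mono float array).
import numpy as np
import librosa

def two_step_resample(x: np.ndarray, orig_sr: int, target_sr: int, sr_mid: int = 16000) -> np.ndarray:
    if orig_sr != sr_mid:
        x = librosa.resample(x, orig_sr=orig_sr, target_sr=sr_mid)
    if sr_mid != target_sr:
        x = librosa.resample(x, orig_sr=sr_mid, target_sr=target_sr)
    return x

# e.g. a 44.1 kHz clip resampled to a 24 kHz model rate via 16 kHz
y = two_step_resample(np.zeros(44100, dtype=np.float32), orig_sr=44100, target_sr=24000)
```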
@@ -121,52 +114,33 @@ def load_and_process_audio(audio_path: str, target_sr: int):


  # ---------------------------------------------------------
- # 3. Validate repo layout (keeps your logic)
- # ---------------------------------------------------------
- def assert_repo_layout(repo_root: Path) -> None:
-     must = [repo_root / "config.yaml", repo_root / "model.safetensors", repo_root / "vae"]
-     for p in must:
-         if not p.exists():
-             raise FileNotFoundError(f"Missing required path: {p}")
-
-     vae_files = list((repo_root / "vae").glob("*.ckpt"))
-     if len(vae_files) == 0:
-         raise FileNotFoundError(f"No .ckpt found under: {repo_root/'vae'}")
-
-
- # ---------------------------------------------------------
- # 4. Patch config.yaml (keeps your logic)
+ # Path patching
  # ---------------------------------------------------------
  def patch_paths_in_exp_config(exp_cfg: Dict[str, Any], repo_root: Path, qwen_root: Path) -> None:
-     # ---- 1) VAE ckpt ----
      vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", None)
      if vae_ckpt:
          vae_ckpt = str(vae_ckpt).replace("\\", "/")
-         idx = vae_ckpt.find("vae/")
-         if idx != -1:
-             vae_rel = vae_ckpt[idx:]  # keep the path from "vae/" onward
+         if "vae/" in vae_ckpt:
+             vae_rel = vae_ckpt[vae_ckpt.find("vae/"):]
+         elif vae_ckpt.endswith(".ckpt"):
+             vae_rel = f"vae/{vae_ckpt}" if "/" not in vae_ckpt else vae_ckpt
          else:
-             if vae_ckpt.endswith(".ckpt") and "/" not in vae_ckpt:
-                 vae_rel = f"vae/{vae_ckpt}"
-             else:
-                 vae_rel = vae_ckpt
+             vae_rel = vae_ckpt

          vae_path = (repo_root / vae_rel).resolve()
-         exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(vae_path)
-
+         # Robustness check: if the computed path doesn't exist, try the bare filename in the repo root
          if not vae_path.exists():
-             # Fallback check (robustness)
-             if (repo_root / Path(vae_ckpt).name).exists():
-                 exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(repo_root / Path(vae_ckpt).name)
-             else:
-                 logger.warning(f"VAE ckpt warning: {vae_path} not found. Model loading might fail.")
+             fallback = repo_root / Path(vae_ckpt).name
+             if fallback.exists():
+                 vae_path = fallback
+
+         exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(vae_path)

-     # ---- 2) Qwen2-Audio model_path ----
      exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)


  # ---------------------------------------------------------
- # 5. Scheduler (keeps your logic)
+ # Scheduler
  # ---------------------------------------------------------
  def build_scheduler(exp_cfg: Dict[str, Any]):
      import diffusers.schedulers as noise_schedulers
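The rewritten branch logic reduces to a small pure function, which makes the three cases easy to verify (the paths below are hypothetical):

```python
# Pure-function view of the new vae_rel normalization (example paths are made up).
def vae_relpath(vae_ckpt: str) -> str:
    vae_ckpt = vae_ckpt.replace("\\", "/")
    if "vae/" in vae_ckpt:
        return vae_ckpt[vae_ckpt.find("vae/"):]  # drop any prefix before "vae/"
    if vae_ckpt.endswith(".ckpt"):
        # bare filename -> assume it lives under vae/
        return f"vae/{vae_ckpt}" if "/" not in vae_ckpt else vae_ckpt
    return vae_ckpt

assert vae_relpath("C:\\train\\run1\\vae\\ae.ckpt") == "vae/ae.ckpt"
assert vae_relpath("ae.ckpt") == "vae/ae.ckpt"
assert vae_relpath("ckpts/ae.ckpt") == "ckpts/ae.ckpt"
```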
@@ -176,7 +150,7 @@ def build_scheduler(exp_cfg: Dict[str, Any]):
          scheduler = noise_schedulers.DDIMScheduler.from_pretrained(name, subfolder="scheduler", token=HF_TOKEN)
          return scheduler
      except Exception as e:
-         logger.warning(f"DDIMScheduler.from_pretrained failed for '{name}', fallback. err={e}")
+         logger.warning(f"Scheduler fallback: {e}")
          return noise_schedulers.DDIMScheduler(
              num_train_timesteps=1000,
              beta_start=0.00085,
@@ -189,20 +163,22 @@ def build_scheduler(exp_cfg: Dict[str, Any]):


  # ---------------------------------------------------------
- # 6. Cold start: load pipeline to CPU
+ # [core] Cold start: load to CPU
  # ---------------------------------------------------------
  def load_pipeline_cpu() -> Tuple[object, object, int]:
-     # Lazy imports
+     # [changed] All libraries are imported here, so module-level imports cannot trigger CUDA initialization
      import torch
      import hydra
      from omegaconf import OmegaConf
      from safetensors.torch import load_file
-
-     from models.common import LoadPretrainedBase
-     from utils.config import register_omegaconf_resolvers
-
+
+     # Your project dependencies
      try:
+         from utils.config import register_omegaconf_resolvers
+         from models.common import LoadPretrainedBase
          register_omegaconf_resolvers()
+     except ImportError:
+         logger.warning("Could not import project utils/models. Ensure they are in the python path.")
      except Exception:
          pass
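`register_omegaconf_resolvers` lives in the project's `utils.config` and isn't part of this diff; a typical implementation just registers custom interpolation resolvers, along these lines (the resolver name `mul` is made up for illustration):

```python
# Illustrative stand-in for utils.config.register_omegaconf_resolvers.
from omegaconf import OmegaConf

def register_omegaconf_resolvers() -> None:
    # replace=True makes repeated registration (e.g. on app reload) harmless instead of an error
    OmegaConf.register_new_resolver("mul", lambda a, b: a * b, replace=True)

register_omegaconf_resolvers()
cfg = OmegaConf.create({"base": 8, "dim": "${mul:${base},64}"})
assert cfg.dim == 512
```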
 
@@ -211,36 +187,35 @@ def load_pipeline_cpu() -> Tuple[object, object, int]:
          return _PIPELINE_CACHE[cache_key]

      repo_root, qwen_root = resolve_model_dirs()
-     assert_repo_layout(repo_root)
-
-     logger.info(f"repo_root = {repo_root}")
-     logger.info(f"qwen_root = {qwen_root}")
-
+
+     logger.info(f"repo_root: {repo_root}")
+
      exp_cfg = OmegaConf.load(repo_root / "config.yaml")
      exp_cfg = OmegaConf.to_container(exp_cfg, resolve=True)

      patch_paths_in_exp_config(exp_cfg, repo_root, qwen_root)

      logger.info("Instantiating model...")
-     model: LoadPretrainedBase = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")
+     model = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")

      ckpt_path = repo_root / "model.safetensors"
+     logger.info(f"Loading weights: {ckpt_path}")
      sd = load_file(str(ckpt_path))
      model.load_pretrained(sd)

-     # Key: make sure the model is on the CPU and in eval mode
+     # [changed] Make sure the model loads onto the CPU
      model = model.to(torch.device("cpu")).eval()

      scheduler = build_scheduler(exp_cfg)
      target_sr = int(exp_cfg.get("sample_rate", 24000))

      _PIPELINE_CACHE[cache_key] = (model, scheduler, target_sr)
-     logger.info("CPU pipeline loaded and cached.")
+     logger.info("CPU pipeline cached.")
      return model, scheduler, target_sr


  # ---------------------------------------------------------
- # 7. ZeroGPU inference core (fixed)
+ # [core] Inference entry point (ZeroGPU adaptation)
  # ---------------------------------------------------------
  @spaces.GPU
  def run_edit(
@@ -253,57 +228,45 @@ def run_edit(
  ) -> Tuple[Optional[str], str]:
      import torch

-     # 1. Basic checks
-     if audio_file is None or not Path(audio_file).exists():
-         return None, "Error: please upload an audio file."
-
-     caption = (caption or "").strip()
-     if not caption:
-         return None, "Error: caption is empty."
+     if not audio_file: return None, "Error: Upload audio first."
+     if not caption: return None, "Error: Input caption."

-     # 2. Fetch the cached model (CPU)
+     # 1. Fetch the CPU model
      model_cpu, scheduler, target_sr = load_pipeline_cpu()

-     # Force float16: best compatibility
+     # 2. Prepare the device (force float16 to avoid OOM and compatibility issues)
      device = torch.device("cuda")
      dtype = torch.float16

-     logger.info(f"🚀 [GPU Task Start] Device: {device}, Dtype: {dtype}")
-
-     # For cleanup in finally
+     logger.info(f"🚀 [GPU Start] Device: {device}, Dtype: {dtype}")
+
      model_on_gpu = None
-     wav_on_gpu = None

      try:
-         # --- Environment check ---
          if not torch.cuda.is_available():
              raise RuntimeError("ZeroGPU assigned but CUDA not found!")

-         # --- 3. Model transfer (CPU -> GPU) ---
+         # --- 3. Move the model (CPU -> GPU) ---
          gc.collect()
          torch.cuda.empty_cache()

-         logger.info("Moving model to GPU...")
-
-         # ⚠️ Key point: model_cpu.to(device) is an in-place operation;
-         # we must move the model back in finally to keep the global cache intact.
-         # The dtype conversion also saves GPU memory.
+         logger.info("Moving model to GPU...")
+         # [key] In-place caveat: model_cpu.to() mutates the CPU object,
+         # so we must move it back in finally!
          model_on_gpu = model_cpu.to(device, dtype=dtype)

-         # --- 4. Data preprocessing ---
-         seed = int(seed)
-         torch.manual_seed(seed)
-         np.random.seed(seed)
-
-         # Load the audio and move it to the GPU
-         wav_on_gpu = load_and_process_audio(audio_file, target_sr=target_sr).to(device, dtype=dtype)
+         # --- 4. Data preparation ---
+         torch.manual_seed(int(seed))
+         np.random.seed(int(seed))
+
+         wav = load_and_process_audio(audio_file, target_sr=target_sr).to(device, dtype=dtype)

          batch = {
              "audio_id": [Path(audio_file).stem],
-             "content": [{"audio": wav_on_gpu, "caption": caption}],
+             "content": [{"audio": wav, "caption": caption}],
              "task": ["audio_editing"],
          }

          kwargs = {
              "num_steps": int(num_steps),
              "guidance_scale": float(guidance_scale),
@@ -314,84 +277,72 @@ def run_edit(
          }

          # --- 5. Inference ---
-         logger.info("Starting inference...")
+         logger.info("Inference start...")
          t0 = time.time()

          with torch.no_grad():
-             # float16 autocast
              with torch.autocast("cuda", dtype=dtype):
                  out = model_on_gpu.inference(scheduler=scheduler, **kwargs)

          dt = time.time() - t0
-         logger.info(f"✅ Inference finished in {dt:.2f}s")
+         logger.info(f"✅ Inference done: {dt:.2f}s")

-         # --- 6. Post-processing ---
+         # --- 6. Save the result ---
          out_audio = out[0, 0].detach().float().cpu().numpy()
          out_path = OUTPUT_DIR / f"{Path(audio_file).stem}_edited.wav"
          sf.write(str(out_path), out_audio, samplerate=target_sr)

-         return str(out_path), f"OK | time={dt:.2f}s | seed={seed}"
+         return str(out_path), f"OK | {dt:.2f}s | Seed: {seed}"

      except Exception as e:
-         # 🔥 Print the full stack trace so a 404 doesn't mask the real error
+         # [key] Print the full stack trace; no more opaque 404s
          err_msg = traceback.format_exc()
-         logger.error(f"❌ CRITICAL ERROR:\n{err_msg}")
-         return None, f"Runtime Error: {str(e)}\n(See logs for details)"
+         logger.error(f"❌ ERROR:\n{err_msg}")
+         return None, f"Runtime Error: {str(e)}\nCheck Logs."

      finally:
-         # --- 7. Critical state restore (must always run) ---
-         logger.info("♻️ Cleaning up resources...")
+         # --- 7. [key] Restore the shared state ---
+         logger.info("♻️ Restoring CPU state...")
          try:
-             # The model must go back to CPU, or the global _PIPELINE_CACHE points at freed GPU memory
+             # Must move back to CPU; otherwise the cached reference points at freed GPU memory and the next call crashes
              if 'model_cpu' in locals() and model_cpu is not None:
                  model_cpu.to("cpu")
                  logger.info("Model restored to CPU.")
          except Exception as e:
-             logger.error(f"Failed to restore model to CPU: {e}")
+             logger.error(f"Failed to restore model: {e}")

-         # Drop references
+         # Free GPU memory
          if 'model_on_gpu' in locals(): del model_on_gpu
-         if 'wav_on_gpu' in locals(): del wav_on_gpu
-
-         # Force GPU memory cleanup
          torch.cuda.empty_cache()
          gc.collect()


  # ---------------------------------------------------------
- # UI (keeps your Examples unchanged)
+ # UI
  # ---------------------------------------------------------
  def build_demo():
      with gr.Blocks(title="MMEdit (ZeroGPU)") as demo:
-         gr.Markdown("# MMEdit ZeroGPU (audio + caption → edited audio)")
+         gr.Markdown("# MMEdit ZeroGPU")

          with gr.Row():
              with gr.Column():
                  audio_in = gr.Audio(label="Input Audio", type="filepath")
-                 caption = gr.Textbox(label="Caption (Edit Instruction)", lines=3)
+                 caption = gr.Textbox(label="Caption", lines=3)

-                 # Restored your Examples
                  gr.Examples(
-                     label="example inputs",
-                     examples=[
-                         ["./Ym8O802VvJes.wav", "Mix in dog barking around the middle."],
-                     ],
+                     label="Examples",
+                     examples=[["./Ym8O802VvJes.wav", "Mix in dog barking around the middle."]],
                      inputs=[audio_in, caption],
-                     cache_examples=False,
                  )

                  with gr.Row():
-                     num_steps = gr.Slider(1, 100, value=50, step=1, label="num_steps")
-                     guidance_scale = gr.Slider(1.0, 12.0, value=5.0, step=0.5, label="guidance_scale")
-
-                 with gr.Row():
-                     guidance_rescale = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="guidance_rescale")
-                     seed = gr.Number(value=42, precision=0, label="seed")
+                     num_steps = gr.Slider(1, 100, value=50, step=1, label="Steps")
+                     guidance_scale = gr.Slider(1.0, 12.0, value=5.0, step=0.5, label="Guidance")
+                     rescale = gr.Slider(0.0, 1.0, 0.5, step=0.05, label="Rescale")
+                     seed = gr.Number(42, label="Seed")

-                 run_btn = gr.Button("Run Editing", variant="primary")
+                 run_btn = gr.Button("Run", variant="primary")

              with gr.Column():
-                 audio_out = gr.Audio(label="Edited Audio", type="filepath")
+                 audio_out = gr.Audio(label="Output", type="filepath")
                  status = gr.Textbox(label="Status")

          run_btn.click(
@@ -400,12 +351,6 @@ def build_demo():
              outputs=[audio_out, status],
          )

-         gr.Markdown(
-             "## Notes\n"
-             "1) The first click allocates a ZeroGPU GPU, so it may be a bit slow.\n"
-             "2) If the first run reports CUDA unavailable, retrying once usually works.\n"
-         )
-
      return demo
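The heart of this commit is the move-to-GPU / restore-to-CPU discipline around the globally cached model: `nn.Module.to()` mutates the module in place, so the cached object must be moved back in `finally` or the next request inherits dangling GPU weights. Stripped of the app specifics, the pattern is roughly (a sketch; `model` stands for any cached module, `fn` for the inference call):

```python
# Sketch of the cached-model GPU round-trip enforced in run_edit.
import gc
import torch

def with_model_on_gpu(model: torch.nn.Module, fn):
    try:
        model.to("cuda", dtype=torch.float16)  # in-place move of the cached module
        with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16):
            return fn(model)
    finally:
        model.to("cpu")            # restore so the global cache stays valid
        torch.cuda.empty_cache()   # release GPU memory before the worker is reclaimed
        gc.collect()
```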
 
 