CocoBro committed on
Commit
b75ca87
·
1 Parent(s): 66b10fb
Files changed (1) hide show
  1. app.py +207 -128
app.py CHANGED
@@ -20,17 +20,14 @@ import librosa
20
  from huggingface_hub import snapshot_download
21
 
22
  # -----------------------------
23
- # Logging 配置
24
  # -----------------------------
25
- logging.basicConfig(
26
- level=logging.INFO,
27
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
28
- datefmt="%H:%M:%S"
29
- )
30
  logger = logging.getLogger("mmedit_space")
31
 
 
32
  # ---------------------------------------------------------
33
- # 配置信息
34
  # ---------------------------------------------------------
35
  MMEDIT_REPO_ID = os.environ.get("MMEDIT_REPO_ID", "CocoBro/MMEdit")
36
  MMEDIT_REVISION = os.environ.get("MMEDIT_REVISION", None)
@@ -38,44 +35,58 @@ MMEDIT_REVISION = os.environ.get("MMEDIT_REVISION", None)
38
  QWEN_REPO_ID = os.environ.get("QWEN_REPO_ID", "Qwen/Qwen2-Audio-7B-Instruct")
39
  QWEN_REVISION = os.environ.get("QWEN_REVISION", None)
40
 
 
41
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
42
 
43
  OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "./outputs"))
44
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
45
 
46
  # ---------------------------------------------------------
47
- # 全局缓存
48
  # ---------------------------------------------------------
49
- # 存储 (model_cpu, scheduler, target_sr)
50
- # 警告:此缓存中的 model 必须始终保持在 "cpu" 设备上!
51
  _PIPELINE_CACHE: Dict[str, Tuple[object, object, int]] = {}
 
52
  _MODEL_DIR_CACHE: Dict[str, Tuple[Path, Path]] = {}
53
 
54
 
55
  # ---------------------------------------------------------
56
- # 辅助函数
57
  # ---------------------------------------------------------
58
  def resolve_model_dirs() -> Tuple[Path, Path]:
59
- """下载并返回模型路径"""
60
  cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
61
  if cache_key in _MODEL_DIR_CACHE:
62
  return _MODEL_DIR_CACHE[cache_key]
63
 
64
- logger.info(f"Downloading MMEdit repo: {MMEDIT_REPO_ID}...")
65
  repo_root = snapshot_download(
66
- repo_id=MMEDIT_REPO_ID, revision=MMEDIT_REVISION, token=HF_TOKEN
 
 
 
 
67
  )
68
-
69
- logger.info(f"Downloading Qwen repo: {QWEN_REPO_ID}...")
 
70
  qwen_root = snapshot_download(
71
- repo_id=QWEN_REPO_ID, revision=QWEN_REVISION, token=HF_TOKEN
 
 
 
 
72
  )
73
-
74
- res = (Path(repo_root).resolve(), Path(qwen_root).resolve())
75
- _MODEL_DIR_CACHE[cache_key] = res
76
- return res
77
 
78
- def load_and_process_audio(audio_path: str, target_sr: int) -> torch.Tensor:
 
 
 
 
79
  import torchaudio
80
 
81
  path = Path(audio_path)
@@ -85,34 +96,87 @@ def load_and_process_audio(audio_path: str, target_sr: int) -> torch.Tensor:
85
  waveform, orig_sr = torchaudio.load(str(path)) # (C, T)
86
 
87
  # Convert to mono
88
- if waveform.ndim > 1:
89
- waveform = waveform.mean(dim=0)
90
-
91
- # Resample logic (robust method)
92
- if int(orig_sr) != int(target_sr):
93
- wav_np = waveform.cpu().numpy()
94
-
95
- # Intermediate resampling to 16k if needed (for better stability)
96
- if int(orig_sr) != 16000:
97
- wav_np = librosa.resample(wav_np, orig_sr=int(orig_sr), target_sr=16000)
98
- orig_sr_mid = 16000
 
 
99
  else:
100
  orig_sr_mid = int(orig_sr)
101
-
102
  if int(target_sr) != orig_sr_mid:
103
- wav_np = librosa.resample(wav_np, orig_sr=orig_sr_mid, target_sr=int(target_sr))
104
-
105
- waveform = torch.from_numpy(wav_np)
106
-
107
  return waveform
108
 
109
- def build_scheduler(exp_cfg):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  import diffusers.schedulers as noise_schedulers
 
111
  name = exp_cfg["model"].get("noise_scheduler_name", "stabilityai/stable-diffusion-2-1")
112
  try:
113
- return noise_schedulers.DDIMScheduler.from_pretrained(name, subfolder="scheduler", token=HF_TOKEN)
 
114
  except Exception as e:
115
- logger.warning(f"Scheduler init failed: {e}, using fallback.")
116
  return noise_schedulers.DDIMScheduler(
117
  num_train_timesteps=1000,
118
  beta_start=0.00085,
@@ -123,15 +187,20 @@ def build_scheduler(exp_cfg):
123
  steps_offset=1,
124
  )
125
 
 
 
 
 
126
  def load_pipeline_cpu() -> Tuple[object, object, int]:
127
- """加载模型到 RAM(CPU),并建立全局缓存"""
 
128
  import hydra
129
  from omegaconf import OmegaConf
130
  from safetensors.torch import load_file
131
- from utils.config import register_omegaconf_resolvers
132
  from models.common import LoadPretrainedBase
 
133
 
134
- # 注册 omegaconf
135
  try:
136
  register_omegaconf_resolvers()
137
  except Exception:
@@ -142,48 +211,36 @@ def load_pipeline_cpu() -> Tuple[object, object, int]:
142
  return _PIPELINE_CACHE[cache_key]
143
 
144
  repo_root, qwen_root = resolve_model_dirs()
145
-
146
- cfg_path = repo_root / "config.yaml"
147
- exp_cfg = OmegaConf.to_container(OmegaConf.load(cfg_path), resolve=True)
148
-
149
- # --- Config Patching ---
150
- # Fix VAE ckpt path
151
- vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", "")
152
- if vae_ckpt:
153
- # 简单暴力的路径修复:只要是 ckpt 就去 vae 目录下找
154
- fname = Path(vae_ckpt).name
155
- local_vae = repo_root / "vae" / fname
156
- if local_vae.exists():
157
- exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(local_vae)
158
- else:
159
- # 尝试直接在 repo_root 下找
160
- if (repo_root / fname).exists():
161
- exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(repo_root / fname)
162
 
163
- # Fix Qwen path
164
- exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
 
 
 
165
 
166
- logger.info("Instantiating model architecture...")
 
 
167
  model: LoadPretrainedBase = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")
168
-
169
  ckpt_path = repo_root / "model.safetensors"
170
- logger.info(f"Loading weights from {ckpt_path.name}...")
171
  sd = load_file(str(ckpt_path))
172
  model.load_pretrained(sd)
173
-
174
- # 关键:确保初始状态在 CPU
175
- model = model.to("cpu").eval()
176
-
177
  scheduler = build_scheduler(exp_cfg)
178
  target_sr = int(exp_cfg.get("sample_rate", 24000))
179
 
180
  _PIPELINE_CACHE[cache_key] = (model, scheduler, target_sr)
181
- logger.info(" Model loaded and cached in CPU RAM.")
182
  return model, scheduler, target_sr
183
 
184
 
185
  # ---------------------------------------------------------
186
- # ZeroGPU 推理函数
187
  # ---------------------------------------------------------
188
  @spaces.GPU
189
  def run_edit(
@@ -194,52 +251,59 @@ def run_edit(
194
  guidance_rescale: float,
195
  seed: int,
196
  ) -> Tuple[Optional[str], str]:
 
 
 
 
 
197
 
198
- if not audio_file: return None, "Please upload an audio file."
199
  caption = (caption or "").strip()
200
- if not caption: return None, "Please enter an instruction caption."
 
201
 
202
- # 1. 获取 CPU 上的模型引用
203
  model_cpu, scheduler, target_sr = load_pipeline_cpu()
204
-
205
- # 2. 准备设备 - 强制使用 float16
206
  device = torch.device("cuda")
207
- dtype = torch.float16 # <--- 强制 FP16
208
-
209
- logger.info(f"🚀 [GPU Task Start] Device: {device}, Precision: {dtype}")
210
 
211
- # 用于 finally 清理的变量
212
- model_on_gpu = None
 
 
213
  wav_on_gpu = None
214
 
215
  try:
216
- # --- GPU 环境检查 ---
217
  if not torch.cuda.is_available():
218
- raise RuntimeError("ZeroGPU assigned but CUDA unavailable.")
219
 
220
  # --- 3. 模型搬运 (CPU -> GPU) ---
221
- # 显式清理,为大模型腾出完整空间
222
  gc.collect()
223
  torch.cuda.empty_cache()
224
-
225
  logger.info("Moving model to GPU...")
226
 
227
- # ⚠️ 核心逻辑:这里虽然用了 to(device),这会修改 model_cpu 的设备属性
228
- # 所以我们在 finally 块中必须将其搬回 CPU,否则下次运行会因为设备失效而崩溃
 
229
  model_on_gpu = model_cpu.to(device, dtype=dtype)
230
 
231
- # --- 4. 数据准备 ---
232
- torch.manual_seed(int(seed))
233
- np.random.seed(int(seed))
234
-
235
- wav_on_gpu = load_and_process_audio(audio_file, target_sr).to(device, dtype=dtype)
236
 
 
 
 
237
  batch = {
238
  "audio_id": [Path(audio_file).stem],
239
  "content": [{"audio": wav_on_gpu, "caption": caption}],
240
  "task": ["audio_editing"],
241
  }
242
-
243
  kwargs = {
244
  "num_steps": int(num_steps),
245
  "guidance_scale": float(guidance_scale),
@@ -254,87 +318,102 @@ def run_edit(
254
  t0 = time.time()
255
 
256
  with torch.no_grad():
257
- # 使用 float16
258
  with torch.autocast("cuda", dtype=dtype):
259
  out = model_on_gpu.inference(scheduler=scheduler, **kwargs)
260
-
261
  dt = time.time() - t0
262
  logger.info(f"✅ Inference finished in {dt:.2f}s")
263
 
264
- # --- 6. 保存结果 ---
265
- # 立即 detach 并转回 CPU
266
  out_audio = out[0, 0].detach().float().cpu().numpy()
267
  out_path = OUTPUT_DIR / f"{Path(audio_file).stem}_edited.wav"
268
  sf.write(str(out_path), out_audio, samplerate=target_sr)
269
-
270
- return str(out_path), f"Success | Time: {dt:.2f}s | Seed: {seed}"
271
 
272
  except Exception as e:
273
- # 🔥 捕捉所有错误,防止 spaces 吞掉报错,打印完整堆栈
274
  err_msg = traceback.format_exc()
275
  logger.error(f"❌ CRITICAL ERROR:\n{err_msg}")
276
  return None, f"Runtime Error: {str(e)}\n(See logs for details)"
277
 
278
  finally:
279
- # --- 7. 关键:现场恢复 ---
280
- # 无论成功还是失败,必须把模型搬回 CPU,否则全局缓存 _PIPELINE_CACHE 将指向已释放的显存
281
  logger.info("♻️ Cleaning up resources...")
282
  try:
283
- # 只要 model_cpu 还在,就强制搬回 CPU
284
  if 'model_cpu' in locals() and model_cpu is not None:
285
  model_cpu.to("cpu")
286
  logger.info("Model restored to CPU.")
287
  except Exception as e:
288
  logger.error(f"Failed to restore model to CPU: {e}")
289
 
290
- # 删除局部引用
291
  if 'model_on_gpu' in locals(): del model_on_gpu
292
  if 'wav_on_gpu' in locals(): del wav_on_gpu
293
-
294
- # 强制显存清理
295
  torch.cuda.empty_cache()
296
  gc.collect()
297
 
298
 
299
  # ---------------------------------------------------------
300
- # UI 启动
301
  # ---------------------------------------------------------
302
  def build_demo():
303
- with gr.Blocks(title="MMEdit ZeroGPU") as demo:
304
- gr.Markdown("## MMEdit")
305
- gr.Markdown("ZeroGPU environment detected. Resources are allocated dynamically.")
306
-
307
  with gr.Row():
308
  with gr.Column():
309
  audio_in = gr.Audio(label="Input Audio", type="filepath")
310
- caption = gr.Textbox(label="Editing Instruction", placeholder="e.g., Add rain sound in the background")
311
-
312
- with gr.Accordion("Advanced Settings", open=False):
313
- steps = gr.Slider(10, 100, 50, step=1, label="Steps")
314
- cfg = gr.Slider(1.0, 15.0, 5.0, step=0.5, label="Guidance Scale")
315
- rescale = gr.Slider(0.0, 1.0, 0.5, step=0.05, label="Guidance Rescale")
316
- seed = gr.Number(42, label="Seed")
 
 
 
 
 
 
 
 
 
 
 
 
317
 
318
  run_btn = gr.Button("Run Editing", variant="primary")
319
 
320
  with gr.Column():
321
- audio_out = gr.Audio(label="Result", type="filepath")
322
- status = gr.Textbox(label="Status Logs")
323
 
324
  run_btn.click(
325
- run_edit,
326
- inputs=[audio_in, caption, steps, cfg, rescale, seed],
327
- outputs=[audio_out, status]
 
 
 
 
 
 
328
  )
 
329
  return demo
330
 
331
 
332
  if __name__ == "__main__":
333
  demo = build_demo()
334
- # 兼容性设置:去掉 ssr_mode,让 Gradio 自动处理
335
- port = int(os.environ.get("PORT", 7860))
336
  demo.queue().launch(
337
- server_name="0.0.0.0",
338
  server_port=port,
339
- share=False
340
  )
 
20
  from huggingface_hub import snapshot_download
21
 
22
  # -----------------------------
23
+ # Logging
24
  # -----------------------------
25
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
 
 
 
26
  logger = logging.getLogger("mmedit_space")
27
 
28
+
29
  # ---------------------------------------------------------
30
+ # HF Repo IDs
31
  # ---------------------------------------------------------
32
  MMEDIT_REPO_ID = os.environ.get("MMEDIT_REPO_ID", "CocoBro/MMEdit")
33
  MMEDIT_REVISION = os.environ.get("MMEDIT_REVISION", None)
 
35
  QWEN_REPO_ID = os.environ.get("QWEN_REPO_ID", "Qwen/Qwen2-Audio-7B-Instruct")
36
  QWEN_REVISION = os.environ.get("QWEN_REVISION", None)
37
 
38
+ # 如果 Qwen gated:Space 里把 HF_TOKEN 设为 Secret
39
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
40
 
41
  OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "./outputs"))
42
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
43
 
44
  # ---------------------------------------------------------
45
+ # 缓存定义
46
  # ---------------------------------------------------------
47
+ # cache: key -> (model_cpu, scheduler, target_sr)
48
+ # 注意:model_cpu 必须始终在 CPU
49
  _PIPELINE_CACHE: Dict[str, Tuple[object, object, int]] = {}
50
+ # cache: key -> (repo_root, qwen_root)
51
  _MODEL_DIR_CACHE: Dict[str, Tuple[Path, Path]] = {}
52
 
53
 
54
  # ---------------------------------------------------------
55
+ # 1. 下载 repo
56
  # ---------------------------------------------------------
57
  def resolve_model_dirs() -> Tuple[Path, Path]:
 
58
  cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
59
  if cache_key in _MODEL_DIR_CACHE:
60
  return _MODEL_DIR_CACHE[cache_key]
61
 
62
+ logger.info(f"Downloading MMEdit repo: {MMEDIT_REPO_ID} (revision={MMEDIT_REVISION})")
63
  repo_root = snapshot_download(
64
+ repo_id=MMEDIT_REPO_ID,
65
+ revision=MMEDIT_REVISION,
66
+ local_dir=None,
67
+ local_dir_use_symlinks=False,
68
+ token=HF_TOKEN,
69
  )
70
+ repo_root = Path(repo_root).resolve()
71
+
72
+ logger.info(f"Downloading Qwen repo: {QWEN_REPO_ID} (revision={QWEN_REVISION})")
73
  qwen_root = snapshot_download(
74
+ repo_id=QWEN_REPO_ID,
75
+ revision=QWEN_REVISION,
76
+ local_dir=None,
77
+ local_dir_use_symlinks=False,
78
+ token=HF_TOKEN, # gated 模型必须
79
  )
80
+ qwen_root = Path(qwen_root).resolve()
81
+
82
+ _MODEL_DIR_CACHE[cache_key] = (repo_root, qwen_root)
83
+ return repo_root, qwen_root
84
 
85
+
86
+ # ---------------------------------------------------------
87
+ # 2. 音频加载(保留你的逻辑,增强鲁棒性)
88
+ # ---------------------------------------------------------
89
+ def load_and_process_audio(audio_path: str, target_sr: int):
90
  import torchaudio
91
 
92
  path = Path(audio_path)
 
96
  waveform, orig_sr = torchaudio.load(str(path)) # (C, T)
97
 
98
  # Convert to mono
99
+ if waveform.ndim == 2:
100
+ waveform = waveform.mean(dim=0) # (T,)
101
+ elif waveform.ndim > 2:
102
+ waveform = waveform.reshape(-1)
103
+
104
+ if target_sr and int(target_sr) != int(orig_sr):
105
+ waveform_np = waveform.cpu().numpy()
106
+
107
+ # 稳健的两步重采样逻辑
108
+ sr_mid = 16000
109
+ if int(orig_sr) != sr_mid:
110
+ waveform_np = librosa.resample(waveform_np, orig_sr=int(orig_sr), target_sr=sr_mid)
111
+ orig_sr_mid = sr_mid
112
  else:
113
  orig_sr_mid = int(orig_sr)
114
+
115
  if int(target_sr) != orig_sr_mid:
116
+ waveform_np = librosa.resample(waveform_np, orig_sr=orig_sr_mid, target_sr=int(target_sr))
117
+
118
+ waveform = torch.from_numpy(waveform_np)
119
+
120
  return waveform
121
 
122
+
123
+ # ---------------------------------------------------------
124
+ # 3. 校验 repo 结构(保留你的逻辑)
125
+ # ---------------------------------------------------------
126
+ def assert_repo_layout(repo_root: Path) -> None:
127
+ must = [repo_root / "config.yaml", repo_root / "model.safetensors", repo_root / "vae"]
128
+ for p in must:
129
+ if not p.exists():
130
+ raise FileNotFoundError(f"Missing required path: {p}")
131
+
132
+ vae_files = list((repo_root / "vae").glob("*.ckpt"))
133
+ if len(vae_files) == 0:
134
+ raise FileNotFoundError(f"No .ckpt found under: {repo_root/'vae'}")
135
+
136
+
137
+ # ---------------------------------------------------------
138
+ # 4. 适配 config.yaml(保留你的逻辑)
139
+ # ---------------------------------------------------------
140
+ def patch_paths_in_exp_config(exp_cfg: Dict[str, Any], repo_root: Path, qwen_root: Path) -> None:
141
+ # ---- 1) VAE ckpt ----
142
+ vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", None)
143
+ if vae_ckpt:
144
+ vae_ckpt = str(vae_ckpt).replace("\\", "/")
145
+ idx = vae_ckpt.find("vae/")
146
+ if idx != -1:
147
+ vae_rel = vae_ckpt[idx:] # 从 vae/ 开始截断
148
+ else:
149
+ if vae_ckpt.endswith(".ckpt") and "/" not in vae_ckpt:
150
+ vae_rel = f"vae/{vae_ckpt}"
151
+ else:
152
+ vae_rel = vae_ckpt
153
+
154
+ vae_path = (repo_root / vae_rel).resolve()
155
+ exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(vae_path)
156
+
157
+ if not vae_path.exists():
158
+ # Fallback check (鲁棒性增强)
159
+ if (repo_root / Path(vae_ckpt).name).exists():
160
+ exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(repo_root / Path(vae_ckpt).name)
161
+ else:
162
+ logger.warning(f"VAE ckpt warning: {vae_path} not found. Model loading might fail.")
163
+
164
+ # ---- 2) Qwen2-Audio model_path ----
165
+ exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
166
+
167
+
168
+ # ---------------------------------------------------------
169
+ # 5. Scheduler(保留你的逻辑)
170
+ # ---------------------------------------------------------
171
+ def build_scheduler(exp_cfg: Dict[str, Any]):
172
  import diffusers.schedulers as noise_schedulers
173
+
174
  name = exp_cfg["model"].get("noise_scheduler_name", "stabilityai/stable-diffusion-2-1")
175
  try:
176
+ scheduler = noise_schedulers.DDIMScheduler.from_pretrained(name, subfolder="scheduler", token=HF_TOKEN)
177
+ return scheduler
178
  except Exception as e:
179
+ logger.warning(f"DDIMScheduler.from_pretrained failed for '{name}', fallback. err={e}")
180
  return noise_schedulers.DDIMScheduler(
181
  num_train_timesteps=1000,
182
  beta_start=0.00085,
 
187
  steps_offset=1,
188
  )
189
 
190
+
191
+ # ---------------------------------------------------------
192
+ # 6. 冷启动:Load Pipeline to CPU
193
+ # ---------------------------------------------------------
194
  def load_pipeline_cpu() -> Tuple[object, object, int]:
195
+ # 延迟导入
196
+ import torch
197
  import hydra
198
  from omegaconf import OmegaConf
199
  from safetensors.torch import load_file
200
+
201
  from models.common import LoadPretrainedBase
202
+ from utils.config import register_omegaconf_resolvers
203
 
 
204
  try:
205
  register_omegaconf_resolvers()
206
  except Exception:
 
211
  return _PIPELINE_CACHE[cache_key]
212
 
213
  repo_root, qwen_root = resolve_model_dirs()
214
+ assert_repo_layout(repo_root)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
+ logger.info(f"repo_root = {repo_root}")
217
+ logger.info(f"qwen_root = {qwen_root}")
218
+
219
+ exp_cfg = OmegaConf.load(repo_root / "config.yaml")
220
+ exp_cfg = OmegaConf.to_container(exp_cfg, resolve=True)
221
 
222
+ patch_paths_in_exp_config(exp_cfg, repo_root, qwen_root)
223
+
224
+ logger.info("Instantiating model...")
225
  model: LoadPretrainedBase = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")
226
+
227
  ckpt_path = repo_root / "model.safetensors"
 
228
  sd = load_file(str(ckpt_path))
229
  model.load_pretrained(sd)
230
+
231
+ # 关键:确保模型在 CPU 上,并且是 eval 模式
232
+ model = model.to(torch.device("cpu")).eval()
233
+
234
  scheduler = build_scheduler(exp_cfg)
235
  target_sr = int(exp_cfg.get("sample_rate", 24000))
236
 
237
  _PIPELINE_CACHE[cache_key] = (model, scheduler, target_sr)
238
+ logger.info("CPU pipeline loaded and cached.")
239
  return model, scheduler, target_sr
240
 
241
 
242
  # ---------------------------------------------------------
243
+ # 7. ZeroGPU 推理核心(修复版)
244
  # ---------------------------------------------------------
245
  @spaces.GPU
246
  def run_edit(
 
251
  guidance_rescale: float,
252
  seed: int,
253
  ) -> Tuple[Optional[str], str]:
254
+ import torch
255
+
256
+ # 1. 基础检查
257
+ if audio_file is None or not Path(audio_file).exists():
258
+ return None, "Error: please upload an audio file."
259
 
 
260
  caption = (caption or "").strip()
261
+ if not caption:
262
+ return None, "Error: caption is empty."
263
 
264
+ # 2. 获取缓存模型 (CPU)
265
  model_cpu, scheduler, target_sr = load_pipeline_cpu()
266
+
267
+ # 强制使用 float16,兼容性最好
268
  device = torch.device("cuda")
269
+ dtype = torch.float16
 
 
270
 
271
+ logger.info(f"🚀 [GPU Task Start] Device: {device}, Dtype: {dtype}")
272
+
273
+ # 用于 finally 清理
274
+ model_on_gpu = None
275
  wav_on_gpu = None
276
 
277
  try:
278
+ # --- 检查环境 ---
279
  if not torch.cuda.is_available():
280
+ raise RuntimeError("ZeroGPU assigned but CUDA not found!")
281
 
282
  # --- 3. 模型搬运 (CPU -> GPU) ---
 
283
  gc.collect()
284
  torch.cuda.empty_cache()
285
+
286
  logger.info("Moving model to GPU...")
287
 
288
+ # ⚠️ 关键点:这里 model_cpu.to(device) 是原位操作,
289
+ # 我们必须在 finally 里搬回去,才能保证全局缓存不坏。
290
+ # 同时做 dtype 转换以节省显存。
291
  model_on_gpu = model_cpu.to(device, dtype=dtype)
292
 
293
+ # --- 4. 数据预处理 ---
294
+ seed = int(seed)
295
+ torch.manual_seed(seed)
296
+ np.random.seed(seed)
 
297
 
298
+ # 加载音频并转到 GPU
299
+ wav_on_gpu = load_and_process_audio(audio_file, target_sr=target_sr).to(device, dtype=dtype)
300
+
301
  batch = {
302
  "audio_id": [Path(audio_file).stem],
303
  "content": [{"audio": wav_on_gpu, "caption": caption}],
304
  "task": ["audio_editing"],
305
  }
306
+
307
  kwargs = {
308
  "num_steps": int(num_steps),
309
  "guidance_scale": float(guidance_scale),
 
318
  t0 = time.time()
319
 
320
  with torch.no_grad():
321
+ # 使用 float16 autocast
322
  with torch.autocast("cuda", dtype=dtype):
323
  out = model_on_gpu.inference(scheduler=scheduler, **kwargs)
324
+
325
  dt = time.time() - t0
326
  logger.info(f"✅ Inference finished in {dt:.2f}s")
327
 
328
+ # --- 6. 后处理 ---
 
329
  out_audio = out[0, 0].detach().float().cpu().numpy()
330
  out_path = OUTPUT_DIR / f"{Path(audio_file).stem}_edited.wav"
331
  sf.write(str(out_path), out_audio, samplerate=target_sr)
332
+
333
+ return str(out_path), f"OK | time={dt:.2f}s | seed={seed}"
334
 
335
  except Exception as e:
336
+ # 🔥 打印完整堆栈,防止 404 掩盖真实错误
337
  err_msg = traceback.format_exc()
338
  logger.error(f"❌ CRITICAL ERROR:\n{err_msg}")
339
  return None, f"Runtime Error: {str(e)}\n(See logs for details)"
340
 
341
  finally:
342
+ # --- 7. 关键:现场恢复(必须执行)---
 
343
  logger.info("♻️ Cleaning up resources...")
344
  try:
345
+ # 必须把模型搬回 CPU,否则全局缓存 _PIPELINE_CACHE 指向已释放的显存
346
  if 'model_cpu' in locals() and model_cpu is not None:
347
  model_cpu.to("cpu")
348
  logger.info("Model restored to CPU.")
349
  except Exception as e:
350
  logger.error(f"Failed to restore model to CPU: {e}")
351
 
352
+ # 删除引用
353
  if 'model_on_gpu' in locals(): del model_on_gpu
354
  if 'wav_on_gpu' in locals(): del wav_on_gpu
355
+
356
+ # 强制清理显存
357
  torch.cuda.empty_cache()
358
  gc.collect()
359
 
360
 
361
  # ---------------------------------------------------------
362
+ # UI (完全保留你的 Examples)
363
  # ---------------------------------------------------------
364
  def build_demo():
365
+ with gr.Blocks(title="MMEdit (ZeroGPU)") as demo:
366
+ gr.Markdown("# MMEdit ZeroGPU(audio + caption → edited audio)")
367
+
 
368
  with gr.Row():
369
  with gr.Column():
370
  audio_in = gr.Audio(label="Input Audio", type="filepath")
371
+ caption = gr.Textbox(label="Caption (Edit Instruction)", lines=3)
372
+
373
+ # 恢复了你的 Examples
374
+ gr.Examples(
375
+ label="example inputs",
376
+ examples=[
377
+ ["./Ym8O802VvJes.wav", "Mix in dog barking around the middle."],
378
+ ],
379
+ inputs=[audio_in, caption],
380
+ cache_examples=False,
381
+ )
382
+
383
+ with gr.Row():
384
+ num_steps = gr.Slider(1, 100, value=50, step=1, label="num_steps")
385
+ guidance_scale = gr.Slider(1.0, 12.0, value=5.0, step=0.5, label="guidance_scale")
386
+
387
+ with gr.Row():
388
+ guidance_rescale = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="guidance_rescale")
389
+ seed = gr.Number(value=42, precision=0, label="seed")
390
 
391
  run_btn = gr.Button("Run Editing", variant="primary")
392
 
393
  with gr.Column():
394
+ audio_out = gr.Audio(label="Edited Audio", type="filepath")
395
+ status = gr.Textbox(label="Status")
396
 
397
  run_btn.click(
398
+ fn=run_edit,
399
+ inputs=[audio_in, caption, num_steps, guidance_scale, guidance_rescale, seed],
400
+ outputs=[audio_out, status],
401
+ )
402
+
403
+ gr.Markdown(
404
+ "## 注意事项\n"
405
+ "1) ZeroGPU 首次点击会分配 GPU,可能稍慢。\n"
406
+ "2) 如果首次报 cuda 不可用,通常重试一次即可。\n"
407
  )
408
+
409
  return demo
410
 
411
 
412
  if __name__ == "__main__":
413
  demo = build_demo()
414
+ port = int(os.environ.get("PORT", "7860"))
 
415
  demo.queue().launch(
416
+ server_name="0.0.0.0",
417
  server_port=port,
418
+ share=False,
419
  )