#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ZeroGPU requirement: spaces must be imported before anything else
import spaces
import traceback
import os
import time
import logging
from pathlib import Path
from typing import Tuple, Optional, Dict, Any
import gc
import gradio as gr
import numpy as np
import soundfile as sf
from huggingface_hub import snapshot_download
# -----------------------------
# Logging
# -----------------------------
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("mmedit_space")
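# -----------------------------
# Configuration (overridable via environment variables)
# -----------------------------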
MMEDIT_REPO_ID = os.environ.get("MMEDIT_REPO_ID", "CocoBro/MMEdit")
MMEDIT_REVISION = os.environ.get("MMEDIT_REVISION", None)
QWEN_REPO_ID = os.environ.get("QWEN_REPO_ID", "Qwen/Qwen2-Audio-7B-Instruct")
QWEN_REVISION = os.environ.get("QWEN_REVISION", None)
HF_TOKEN = os.environ.get("HF_TOKEN", None)
OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "./outputs"))
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
USE_AMP = os.environ.get("USE_AMP", "0") == "1"
AMP_DTYPE = os.environ.get("AMP_DTYPE", "bf16") # "bf16" or "fp16"
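# cache: key -> (object, object, int); reserved for pipeline reuse (currently unused)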
_PIPELINE_CACHE: Dict[str, Tuple[object, object, int]] = {}
# cache: key -> (repo_root, qwen_root)
_MODEL_DIR_CACHE: Dict[str, Tuple[Path, Path]] = {}
# ---------------------------------------------------------
# Download the repos (only once; huggingface_hub caches snapshots)
# ---------------------------------------------------------
def resolve_model_dirs() -> Tuple[Path, Path]:
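    """Download the MMEdit and Qwen2-Audio snapshots (once) and return their local roots."""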
cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
if cache_key in _MODEL_DIR_CACHE:
return _MODEL_DIR_CACHE[cache_key]
logger.info(f"Downloading MMEdit repo: {MMEDIT_REPO_ID} (revision={MMEDIT_REVISION})")
repo_root = snapshot_download(
repo_id=MMEDIT_REPO_ID,
revision=MMEDIT_REVISION,
token=HF_TOKEN,
)
repo_root = Path(repo_root).resolve()
logger.info(f"Downloading Qwen repo: {QWEN_REPO_ID} (revision={QWEN_REVISION})")
qwen_root = snapshot_download(
repo_id=QWEN_REPO_ID,
revision=QWEN_REVISION,
        token=HF_TOKEN,  # required for gated models
)
qwen_root = Path(qwen_root).resolve()
_MODEL_DIR_CACHE[cache_key] = (repo_root, qwen_root)
return repo_root, qwen_root
# ---------------------------------------------------------
# Audio loading (resample chain: orig -> 16k -> target_sr)
# ---------------------------------------------------------
def load_and_process_audio(audio_path: str, target_sr: int):
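    """Load an audio file as a mono float tensor, resampled orig_sr -> 16 kHz -> target_sr."""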
    # Lazy imports (avoid triggering CUDA initialization at startup)
import torch
import torchaudio
import librosa
path = Path(audio_path)
if not path.exists():
raise FileNotFoundError(f"Audio file not found: {audio_path}")
waveform, orig_sr = torchaudio.load(str(path)) # (C, T)
# Convert to mono
if waveform.ndim == 2:
waveform = waveform.mean(dim=0) # (T,)
elif waveform.ndim > 2:
waveform = waveform.reshape(-1)
if target_sr and int(target_sr) != int(orig_sr):
waveform_np = waveform.cpu().numpy()
        # 1) Resample to 16 kHz first
sr_mid = 16000
if int(orig_sr) != sr_mid:
waveform_np = librosa.resample(waveform_np, orig_sr=int(orig_sr), target_sr=sr_mid)
orig_sr_mid = sr_mid
else:
orig_sr_mid = int(orig_sr)
        # 2) Then to target_sr (e.g. 24 kHz)
if int(target_sr) != orig_sr_mid:
waveform_np = librosa.resample(waveform_np, orig_sr=orig_sr_mid, target_sr=int(target_sr))
waveform = torch.from_numpy(waveform_np)
return waveform
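# Usage sketch (hypothetical file name):
#   wav = load_and_process_audio("input.wav", 24000)  # 1-D float tensor at 24 kHz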
# ---------------------------------------------------------
# Validate the repo layout
# ---------------------------------------------------------
def assert_repo_layout(repo_root: Path) -> None:
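    """Fail fast if the MMEdit snapshot lacks config.yaml, model.safetensors, or a VAE .ckpt."""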
must = [repo_root / "config.yaml", repo_root / "model.safetensors", repo_root / "vae"]
for p in must:
if not p.exists():
raise FileNotFoundError(f"Missing required path: {p}")
vae_files = list((repo_root / "vae").glob("*.ckpt"))
if len(vae_files) == 0:
raise FileNotFoundError(f"No .ckpt found under: {repo_root/'vae'}")
# ---------------------------------------------------------
# Patch the path spellings used in config.yaml
# ---------------------------------------------------------
def patch_paths_in_exp_config(exp_cfg: Dict[str, Any], repo_root: Path, qwen_root: Path) -> None:
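    """Point the VAE ckpt and the Qwen2-Audio model_path in the config at the local snapshots."""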
# ---- 1) VAE ckpt ----
vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", None)
if vae_ckpt:
vae_ckpt = str(vae_ckpt).replace("\\", "/")
idx = vae_ckpt.find("vae/")
if idx != -1:
            vae_rel = vae_ckpt[idx:]  # keep everything from "vae/" onward
else:
if vae_ckpt.endswith(".ckpt") and "/" not in vae_ckpt:
vae_rel = f"vae/{vae_ckpt}"
else:
vae_rel = vae_ckpt
vae_path = (repo_root / vae_rel).resolve()
exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(vae_path)
if not vae_path.exists():
raise FileNotFoundError(
f"VAE ckpt not found after patch:\n"
f" original: {vae_ckpt}\n"
f" patched : {vae_path}\n"
f"Repo root: {repo_root}\n"
f"Expected: {repo_root/'vae'/'*.ckpt'}"
)
# ---- 2) Qwen2-Audio model_path ----
exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
@spaces.GPU
def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, seed):
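    """ZeroGPU entry point: build the model, run one audio-editing pass, return (wav_path, status).

    Heavy imports happen inside the function so CUDA is only touched within the
    @spaces.GPU context.
    """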
import torch
import hydra
from omegaconf import OmegaConf
from safetensors.torch import load_file
import diffusers.schedulers as noise_schedulers
logger.info("🚀 Starting ..")
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
try:
from utils.config import register_omegaconf_resolvers
register_omegaconf_resolvers()
    except Exception:
        pass  # utils.config may be absent; resolvers are optional
if not audio_file: return None, "Please upload audio."
model = None
try:
# ==========================================
logger.info("🚀 Starting ZeroGPU Task...")
        # Resolve model directories
repo_root, qwen_root = resolve_model_dirs()
exp_cfg = OmegaConf.to_container(OmegaConf.load(repo_root / "config.yaml"), resolve=True)
#
vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", "")
if vae_ckpt:
p1 = repo_root / "vae" / Path(vae_ckpt).name
p2 = repo_root / Path(vae_ckpt).name
if p1.exists(): exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(p1)
elif p2.exists(): exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(p2)
exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
#
logger.info("Instantiating model (Hydra)...")
model = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")
        # Load weights
ckpt_path = str(repo_root / "model.safetensors")
logger.info(f"Loading weights from {ckpt_path}...")
sd = load_file(ckpt_path)
model.load_pretrained(sd)
        del sd  # release immediately
gc.collect()
# ==========================================
# ==========================================
device = torch.device("cuda")
logger.info("Moving model to CUDA (FP16)...")
# 这一步将模型送入显卡
        def safe_move_model(m, dev):
            logger.info("🛡️ Moving model to GPU in FP32...")
            # Move submodules one at a time so a failure is attributable to a child
            for name, child in m.named_children():
                logger.info(f"Moving {name} to GPU (fp32)...")
                child.to(dev, dtype=torch.float32)
            # Pick up any parameters/buffers registered directly on the root module
            m.to(dev, dtype=torch.float32)
            return m
model = safe_move_model(model, device)
model.eval()
logger.info("Model is moved to CUDA.")
        # Scheduler: try the repo-specified scheduler config, fall back to DDIM defaults
        try:
            scheduler = noise_schedulers.DDIMScheduler.from_pretrained(
                exp_cfg["model"].get("noise_scheduler_name", ""),
                subfolder="scheduler", token=HF_TOKEN
            )
        except Exception:
            scheduler = noise_schedulers.DDIMScheduler(num_train_timesteps=1000)
# ==========================================
        # 3. Run inference
# ==========================================
target_sr = int(exp_cfg.get("sample_rate", 24000))
torch.manual_seed(int(seed))
np.random.seed(int(seed))
wav = load_and_process_audio(audio_file, target_sr).to(device, dtype=torch.float32)
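        # Inference request; this key layout is what model.inference expects
        # (each "content" entry pairs the source audio with its edit instruction)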
batch = {
"audio_id": [Path(audio_file).stem],
"content": [{"audio": wav, "caption": caption}],
"task": ["audio_editing"],
"num_steps": int(num_steps),
"guidance_scale": float(guidance_scale),
"guidance_rescale": float(guidance_rescale),
"use_gt_duration": False,
"mask_time_aligned_content": False
}
logger.info("Inference running...")
t0 = time.time()
with torch.no_grad():
out = model.inference(scheduler=scheduler, **batch)
out_audio = out[0, 0].detach().float().cpu().numpy()
out_path = OUTPUT_DIR / f"{Path(audio_file).stem}_edited.wav"
sf.write(str(out_path), out_audio, samplerate=target_sr)
return str(out_path), f"Success | {time.time()-t0:.2f}s"
except Exception as e:
err = traceback.format_exc()
logger.error(f"❌ ERROR:\n{err}")
return None, f"Runtime Error: {e}"
finally:
        # Force cleanup so the next task does not run out of GPU memory
logger.info("Cleaning up...")
if model is not None: del model
torch.cuda.empty_cache()
gc.collect()
# -----------------------------
# UI
# -----------------------------
def build_demo():
with gr.Blocks(title="MMEdit") as demo:
gr.Markdown("# MMEdit ZeroGPU (Direct Load)")
with gr.Row():
with gr.Column():
audio_in = gr.Audio(label="Input", type="filepath")
caption = gr.Textbox(label="Instruction", lines=3)
gr.Examples(
label="Examples (Click to load)",
                    # Format: [[audio_path_1, prompt_1], [audio_path_2, prompt_2], ...]
examples=[
                        # Example 1 (the original sample)
["./Ym8O802VvJes.wav", "Mix in dog barking around the middle."],
["./YDKM2KjNkX18.wav", "Incorporate Telephone bell ringing into the background."],
["./drop_audiocaps_1.wav", "Erase the rain falling sound from the background."],
["./reorder_audiocaps_1.wav", "Switch the positions of the woman's voice and whistling."]
],
                    inputs=[audio_in, caption],  # order matches the lists above: Audio first, then Textbox
                    cache_examples=False,  # keep False on ZeroGPU to avoid slow computation at startup
)
with gr.Row():
num_steps = gr.Slider(10, 100, 50, step=1, label="Steps")
guidance_scale = gr.Slider(1.0, 12.0, 5.0, step=0.5, label="Guidance")
guidance_rescale = gr.Slider(0.0, 1.0, 0.5, step=0.05, label="Rescale")
seed = gr.Number(42, label="Seed")
run_btn = gr.Button("Run", variant="primary")
with gr.Column():
out = gr.Audio(label="Output")
status = gr.Textbox(label="Status")
run_btn.click(run_edit, [audio_in, caption, num_steps, guidance_scale, guidance_rescale, seed], [out, status])
return demo
if __name__ == "__main__":
print("[BOOT] entering main()", flush=True)
demo = build_demo()
port = int(os.environ.get("PORT", "7860"))
print(f"[BOOT] launching gradio on 0.0.0.0:{port}", flush=True)
demo.queue().launch(
server_name="0.0.0.0",
server_port=port,
share=False,
ssr_mode=False,
    )