zerogpu
Browse files
app.py
CHANGED
|
@@ -20,6 +20,9 @@ from safetensors.torch import load_file
|
|
| 20 |
import diffusers.schedulers as noise_schedulers
|
| 21 |
from huggingface_hub import snapshot_download
|
| 22 |
|
|
|
|
|
|
|
|
|
|
| 23 |
from models.common import LoadPretrainedBase
|
| 24 |
from utils.config import register_omegaconf_resolvers
|
| 25 |
|
|
@@ -45,17 +48,22 @@ MMEDIT_REVISION = os.environ.get("MMEDIT_REVISION", None)
|
|
| 45 |
QWEN_REPO_ID = os.environ.get("QWEN_REPO_ID", "Qwen/Qwen2-Audio-7B-Instruct")
|
| 46 |
QWEN_REVISION = os.environ.get("QWEN_REVISION", None)
|
| 47 |
|
|
|
|
|
|
|
|
|
|
| 48 |
OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "./outputs"))
|
| 49 |
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
| 50 |
|
| 51 |
USE_AMP = os.environ.get("USE_AMP", "0") == "1"
|
| 52 |
AMP_DTYPE = os.environ.get("AMP_DTYPE", "bf16") # "bf16" or "fp16"
|
| 53 |
|
| 54 |
-
|
|
|
|
|
|
|
| 55 |
|
| 56 |
|
| 57 |
# ---------------------------------------------------------
|
| 58 |
-
# 下载 repo
|
| 59 |
# ---------------------------------------------------------
|
| 60 |
def resolve_model_dirs() -> Tuple[Path, Path]:
|
| 61 |
"""
|
|
@@ -63,12 +71,17 @@ def resolve_model_dirs() -> Tuple[Path, Path]:
|
|
| 63 |
repo_root: 你的 MMEdit repo 的本地目录(包含 config.yaml / model.safetensors / vae/)
|
| 64 |
qwen_root: Qwen2-Audio repo 的本地目录
|
| 65 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
logger.info(f"Downloading MMEdit repo: {MMEDIT_REPO_ID} (revision={MMEDIT_REVISION})")
|
| 67 |
repo_root = snapshot_download(
|
| 68 |
repo_id=MMEDIT_REPO_ID,
|
| 69 |
revision=MMEDIT_REVISION,
|
| 70 |
local_dir=None,
|
| 71 |
local_dir_use_symlinks=False,
|
|
|
|
| 72 |
)
|
| 73 |
repo_root = Path(repo_root).resolve()
|
| 74 |
|
|
@@ -78,9 +91,11 @@ def resolve_model_dirs() -> Tuple[Path, Path]:
|
|
| 78 |
revision=QWEN_REVISION,
|
| 79 |
local_dir=None,
|
| 80 |
local_dir_use_symlinks=False,
|
|
|
|
| 81 |
)
|
| 82 |
qwen_root = Path(qwen_root).resolve()
|
| 83 |
|
|
|
|
| 84 |
return repo_root, qwen_root
|
| 85 |
|
| 86 |
|
|
@@ -155,21 +170,15 @@ def patch_paths_in_exp_config(exp_cfg: Dict[str, Any], repo_root: Path, qwen_roo
|
|
| 155 |
- pretrained_ckpt: ckpt/mmedit/vae/epoch=xx.ckpt -> repo_root/vae/epoch=xx.ckpt
|
| 156 |
- model_path: ckpt/qwen2-audio-7B-instruct -> qwen_root (snapshot_download 结果)
|
| 157 |
"""
|
| 158 |
-
|
| 159 |
# ---- 1) VAE ckpt ----
|
| 160 |
vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", None)
|
| 161 |
if vae_ckpt:
|
| 162 |
vae_ckpt = str(vae_ckpt).replace("\\", "/")
|
| 163 |
|
| 164 |
-
# 你这里最稳定的做法:找到 "vae/" 子串之后的后缀
|
| 165 |
-
# 例如:
|
| 166 |
-
# ckpt/mmedit/vae/epoch=13-step=1000000.ckpt -> vae/epoch=13-step=1000000.ckpt
|
| 167 |
idx = vae_ckpt.find("vae/")
|
| 168 |
if idx != -1:
|
| 169 |
vae_rel = vae_ckpt[idx:] # 从 vae/ 开始截断
|
| 170 |
else:
|
| 171 |
-
# 兜底:如果有人直接写 epoch=xx.ckpt,那就放到 repo_root/vae/
|
| 172 |
-
# 或者写 vae/xxx.ckpt
|
| 173 |
if vae_ckpt.endswith(".ckpt") and "/" not in vae_ckpt:
|
| 174 |
vae_rel = f"vae/{vae_ckpt}"
|
| 175 |
else:
|
|
@@ -188,20 +197,35 @@ def patch_paths_in_exp_config(exp_cfg: Dict[str, Any], repo_root: Path, qwen_roo
|
|
| 188 |
)
|
| 189 |
|
| 190 |
# ---- 2) Qwen2-Audio model_path ----
|
| 191 |
-
# 你的 config 里写的是 ckpt/qwen2-audio-7B-instruct,但 Space 上我们直接用下载后的 qwen_root
|
| 192 |
exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
|
| 193 |
|
| 194 |
|
| 195 |
# ---------------------------------------------------------
|
| 196 |
# Scheduler(与你 exp_cfg.model.noise_scheduler_name 对齐)
|
|
|
|
|
|
|
| 197 |
# ---------------------------------------------------------
|
| 198 |
def build_scheduler(exp_cfg: Dict[str, Any]):
|
| 199 |
name = exp_cfg["model"].get("noise_scheduler_name", "stabilityai/stable-diffusion-2-1")
|
| 200 |
-
|
| 201 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
|
| 204 |
def _amp_ctx(device: torch.device):
|
|
|
|
| 205 |
if not USE_AMP:
|
| 206 |
return torch.autocast("cuda", enabled=False)
|
| 207 |
if device.type != "cuda":
|
|
@@ -211,9 +235,10 @@ def _amp_ctx(device: torch.device):
|
|
| 211 |
|
| 212 |
|
| 213 |
# ---------------------------------------------------------
|
| 214 |
-
# 冷启动:load+cache pipeline
|
|
|
|
| 215 |
# ---------------------------------------------------------
|
| 216 |
-
def
|
| 217 |
cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
|
| 218 |
if cache_key in _PIPELINE_CACHE:
|
| 219 |
return _PIPELINE_CACHE[cache_key]
|
|
@@ -221,10 +246,9 @@ def load_pipeline() -> Tuple[LoadPretrainedBase, object, int, torch.device]:
|
|
| 221 |
repo_root, qwen_root = resolve_model_dirs()
|
| 222 |
assert_repo_layout(repo_root)
|
| 223 |
|
| 224 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 225 |
logger.info(f"repo_root = {repo_root}")
|
| 226 |
-
logger.info(f"device = {device}")
|
| 227 |
logger.info(f"qwen_root = {qwen_root}")
|
|
|
|
| 228 |
|
| 229 |
exp_cfg = OmegaConf.load(repo_root / "config.yaml")
|
| 230 |
exp_cfg = OmegaConf.to_container(exp_cfg, resolve=True)
|
|
@@ -233,25 +257,31 @@ def load_pipeline() -> Tuple[LoadPretrainedBase, object, int, torch.device]:
|
|
| 233 |
logger.info(f"patched pretrained_ckpt = {exp_cfg['model']['autoencoder'].get('pretrained_ckpt')}")
|
| 234 |
logger.info(f"patched qwen model_path = {exp_cfg['model']['content_encoder']['text_encoder'].get('model_path')}")
|
| 235 |
|
|
|
|
| 236 |
model: LoadPretrainedBase = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")
|
| 237 |
|
|
|
|
| 238 |
ckpt_path = repo_root / "model.safetensors"
|
| 239 |
sd = load_file(str(ckpt_path))
|
| 240 |
model.load_pretrained(sd)
|
| 241 |
|
| 242 |
-
|
|
|
|
| 243 |
|
| 244 |
scheduler = build_scheduler(exp_cfg)
|
| 245 |
target_sr = int(exp_cfg.get("sample_rate", 24000))
|
| 246 |
|
| 247 |
-
_PIPELINE_CACHE[cache_key] = (model, scheduler, target_sr
|
| 248 |
-
logger.info("
|
| 249 |
-
return model, scheduler, target_sr
|
| 250 |
|
| 251 |
|
| 252 |
# ---------------------------------------------------------
|
| 253 |
# 推理:audio + caption -> edited audio
|
|
|
|
|
|
|
| 254 |
# ---------------------------------------------------------
|
|
|
|
| 255 |
@torch.no_grad()
|
| 256 |
def run_edit(
|
| 257 |
audio_file: str,
|
|
@@ -268,12 +298,25 @@ def run_edit(
|
|
| 268 |
if not caption:
|
| 269 |
return None, "Error: caption is empty."
|
| 270 |
|
| 271 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
|
|
|
|
| 273 |
seed = int(seed)
|
| 274 |
torch.manual_seed(seed)
|
| 275 |
np.random.seed(seed)
|
| 276 |
|
|
|
|
| 277 |
wav = load_and_process_audio(audio_file, target_sr=target_sr).to(device)
|
| 278 |
|
| 279 |
batch = {
|
|
@@ -282,7 +325,7 @@ def run_edit(
|
|
| 282 |
"task": ["audio_editing"],
|
| 283 |
}
|
| 284 |
|
| 285 |
-
#
|
| 286 |
kwargs = {
|
| 287 |
"num_steps": int(num_steps),
|
| 288 |
"guidance_scale": float(guidance_scale),
|
|
@@ -301,6 +344,15 @@ def run_edit(
|
|
| 301 |
out_path = OUTPUT_DIR / f"{Path(audio_file).stem}_edited.wav"
|
| 302 |
sf.write(str(out_path), out_audio, samplerate=target_sr)
|
| 303 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
return str(out_path), f"OK | saved={out_path.name} | time={dt:.2f}s | sr={target_sr} | seed={seed}"
|
| 305 |
|
| 306 |
|
|
@@ -308,25 +360,24 @@ def run_edit(
|
|
| 308 |
# UI
|
| 309 |
# ---------------------------------------------------------
|
| 310 |
def build_demo():
|
| 311 |
-
with gr.Blocks(title="MMEdit
|
| 312 |
-
gr.Markdown("# MMEdit
|
| 313 |
-
|
| 314 |
-
"点下面的示例即可自动填充音频路径与编辑指令,然后点击 Run Editing。"
|
| 315 |
-
)
|
| 316 |
|
| 317 |
with gr.Row():
|
| 318 |
with gr.Column():
|
| 319 |
audio_in = gr.Audio(label="Input Audio", type="filepath")
|
| 320 |
caption = gr.Textbox(label="Caption (Edit Instruction)", lines=3)
|
| 321 |
|
| 322 |
-
#
|
|
|
|
| 323 |
gr.Examples(
|
| 324 |
label="example inputs",
|
| 325 |
examples=[
|
| 326 |
-
["
|
| 327 |
],
|
| 328 |
inputs=[audio_in, caption],
|
| 329 |
-
cache_examples=False,
|
| 330 |
)
|
| 331 |
|
| 332 |
with gr.Row():
|
|
@@ -351,15 +402,20 @@ def build_demo():
|
|
| 351 |
|
| 352 |
gr.Markdown(
|
| 353 |
"## 注意事项\n"
|
| 354 |
-
"
|
| 355 |
-
"
|
|
|
|
| 356 |
)
|
| 357 |
-
|
| 358 |
return demo
|
| 359 |
|
| 360 |
|
| 361 |
-
|
| 362 |
if __name__ == "__main__":
|
| 363 |
demo = build_demo()
|
| 364 |
-
port = int(os.environ.get("PORT", "7860"))
|
| 365 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
import diffusers.schedulers as noise_schedulers
|
| 21 |
from huggingface_hub import snapshot_download
|
| 22 |
|
| 23 |
+
# ZeroGPU 关键:spaces
|
| 24 |
+
import spaces
|
| 25 |
+
|
| 26 |
from models.common import LoadPretrainedBase
|
| 27 |
from utils.config import register_omegaconf_resolvers
|
| 28 |
|
|
|
|
| 48 |
QWEN_REPO_ID = os.environ.get("QWEN_REPO_ID", "Qwen/Qwen2-Audio-7B-Instruct")
|
| 49 |
QWEN_REVISION = os.environ.get("QWEN_REVISION", None)
|
| 50 |
|
| 51 |
+
# 如果 Qwen gated:Space 里把 HF_TOKEN 设为 Secret
|
| 52 |
+
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
| 53 |
+
|
| 54 |
OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "./outputs"))
|
| 55 |
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
| 56 |
|
| 57 |
USE_AMP = os.environ.get("USE_AMP", "0") == "1"
|
| 58 |
AMP_DTYPE = os.environ.get("AMP_DTYPE", "bf16") # "bf16" or "fp16"
|
| 59 |
|
| 60 |
+
# ZeroGPU:缓存 CPU pipeline(不要缓存在 CUDA)
|
| 61 |
+
_PIPELINE_CACHE: Dict[str, Tuple[LoadPretrainedBase, object, int]] = {}
|
| 62 |
+
_MODEL_DIR_CACHE: Dict[str, Tuple[Path, Path]] = {}
|
| 63 |
|
| 64 |
|
| 65 |
# ---------------------------------------------------------
|
| 66 |
+
# 下载 repo(只下载一次;huggingface_hub 自带缓存)
|
| 67 |
# ---------------------------------------------------------
|
| 68 |
def resolve_model_dirs() -> Tuple[Path, Path]:
|
| 69 |
"""
|
|
|
|
| 71 |
repo_root: 你的 MMEdit repo 的本地目录(包含 config.yaml / model.safetensors / vae/)
|
| 72 |
qwen_root: Qwen2-Audio repo 的本地目录
|
| 73 |
"""
|
| 74 |
+
cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
|
| 75 |
+
if cache_key in _MODEL_DIR_CACHE:
|
| 76 |
+
return _MODEL_DIR_CACHE[cache_key]
|
| 77 |
+
|
| 78 |
logger.info(f"Downloading MMEdit repo: {MMEDIT_REPO_ID} (revision={MMEDIT_REVISION})")
|
| 79 |
repo_root = snapshot_download(
|
| 80 |
repo_id=MMEDIT_REPO_ID,
|
| 81 |
revision=MMEDIT_REVISION,
|
| 82 |
local_dir=None,
|
| 83 |
local_dir_use_symlinks=False,
|
| 84 |
+
token=HF_TOKEN, # 私有 repo 时也可用
|
| 85 |
)
|
| 86 |
repo_root = Path(repo_root).resolve()
|
| 87 |
|
|
|
|
| 91 |
revision=QWEN_REVISION,
|
| 92 |
local_dir=None,
|
| 93 |
local_dir_use_symlinks=False,
|
| 94 |
+
token=HF_TOKEN, # gated 模型必须
|
| 95 |
)
|
| 96 |
qwen_root = Path(qwen_root).resolve()
|
| 97 |
|
| 98 |
+
_MODEL_DIR_CACHE[cache_key] = (repo_root, qwen_root)
|
| 99 |
return repo_root, qwen_root
|
| 100 |
|
| 101 |
|
|
|
|
| 170 |
- pretrained_ckpt: ckpt/mmedit/vae/epoch=xx.ckpt -> repo_root/vae/epoch=xx.ckpt
|
| 171 |
- model_path: ckpt/qwen2-audio-7B-instruct -> qwen_root (snapshot_download 结果)
|
| 172 |
"""
|
|
|
|
| 173 |
# ---- 1) VAE ckpt ----
|
| 174 |
vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", None)
|
| 175 |
if vae_ckpt:
|
| 176 |
vae_ckpt = str(vae_ckpt).replace("\\", "/")
|
| 177 |
|
|
|
|
|
|
|
|
|
|
| 178 |
idx = vae_ckpt.find("vae/")
|
| 179 |
if idx != -1:
|
| 180 |
vae_rel = vae_ckpt[idx:] # 从 vae/ 开始截断
|
| 181 |
else:
|
|
|
|
|
|
|
| 182 |
if vae_ckpt.endswith(".ckpt") and "/" not in vae_ckpt:
|
| 183 |
vae_rel = f"vae/{vae_ckpt}"
|
| 184 |
else:
|
|
|
|
| 197 |
)
|
| 198 |
|
| 199 |
# ---- 2) Qwen2-Audio model_path ----
|
|
|
|
| 200 |
exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
|
| 201 |
|
| 202 |
|
| 203 |
# ---------------------------------------------------------
|
| 204 |
# Scheduler(与你 exp_cfg.model.noise_scheduler_name 对齐)
|
| 205 |
+
# 注意:有些 repo_id 不存在 scheduler 子目录会 404。
|
| 206 |
+
# 这里给一个 fallback,避免直接炸。
|
| 207 |
# ---------------------------------------------------------
|
| 208 |
def build_scheduler(exp_cfg: Dict[str, Any]):
|
| 209 |
name = exp_cfg["model"].get("noise_scheduler_name", "stabilityai/stable-diffusion-2-1")
|
| 210 |
+
try:
|
| 211 |
+
scheduler = noise_schedulers.DDIMScheduler.from_pretrained(name, subfolder="scheduler", token=HF_TOKEN)
|
| 212 |
+
return scheduler
|
| 213 |
+
except Exception as e:
|
| 214 |
+
logger.warning(f"DDIMScheduler.from_pretrained failed for '{name}', fallback to default DDIM config. err={e}")
|
| 215 |
+
# fallback:不依赖远端 repo
|
| 216 |
+
return noise_schedulers.DDIMScheduler(
|
| 217 |
+
num_train_timesteps=1000,
|
| 218 |
+
beta_start=0.00085,
|
| 219 |
+
beta_end=0.012,
|
| 220 |
+
beta_schedule="scaled_linear",
|
| 221 |
+
clip_sample=False,
|
| 222 |
+
set_alpha_to_one=False,
|
| 223 |
+
steps_offset=1,
|
| 224 |
+
)
|
| 225 |
|
| 226 |
|
| 227 |
def _amp_ctx(device: torch.device):
|
| 228 |
+
# ZeroGPU:只有在 device=cuda 且你明确开启 USE_AMP 才 autocast
|
| 229 |
if not USE_AMP:
|
| 230 |
return torch.autocast("cuda", enabled=False)
|
| 231 |
if device.type != "cuda":
|
|
|
|
| 235 |
|
| 236 |
|
| 237 |
# ---------------------------------------------------------
|
| 238 |
+
# 冷启动:load+cache pipeline(缓存 CPU 上的 model)
|
| 239 |
+
# ZeroGPU 启动阶段一般没有 CUDA,所以这里不要 model.to("cuda")
|
| 240 |
# ---------------------------------------------------------
|
| 241 |
+
def load_pipeline_cpu() -> Tuple[LoadPretrainedBase, object, int]:
|
| 242 |
cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
|
| 243 |
if cache_key in _PIPELINE_CACHE:
|
| 244 |
return _PIPELINE_CACHE[cache_key]
|
|
|
|
| 246 |
repo_root, qwen_root = resolve_model_dirs()
|
| 247 |
assert_repo_layout(repo_root)
|
| 248 |
|
|
|
|
| 249 |
logger.info(f"repo_root = {repo_root}")
|
|
|
|
| 250 |
logger.info(f"qwen_root = {qwen_root}")
|
| 251 |
+
logger.info(f"torch.cuda.is_available (startup) = {torch.cuda.is_available()}")
|
| 252 |
|
| 253 |
exp_cfg = OmegaConf.load(repo_root / "config.yaml")
|
| 254 |
exp_cfg = OmegaConf.to_container(exp_cfg, resolve=True)
|
|
|
|
| 257 |
logger.info(f"patched pretrained_ckpt = {exp_cfg['model']['autoencoder'].get('pretrained_ckpt')}")
|
| 258 |
logger.info(f"patched qwen model_path = {exp_cfg['model']['content_encoder']['text_encoder'].get('model_path')}")
|
| 259 |
|
| 260 |
+
# instantiate model(在 CPU 上构建)
|
| 261 |
model: LoadPretrainedBase = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")
|
| 262 |
|
| 263 |
+
# load weights(你的 mmedit 权重)
|
| 264 |
ckpt_path = repo_root / "model.safetensors"
|
| 265 |
sd = load_file(str(ckpt_path))
|
| 266 |
model.load_pretrained(sd)
|
| 267 |
|
| 268 |
+
# 强制留在 CPU(ZeroGPU 关键)
|
| 269 |
+
model = model.to(torch.device("cpu")).eval()
|
| 270 |
|
| 271 |
scheduler = build_scheduler(exp_cfg)
|
| 272 |
target_sr = int(exp_cfg.get("sample_rate", 24000))
|
| 273 |
|
| 274 |
+
_PIPELINE_CACHE[cache_key] = (model, scheduler, target_sr)
|
| 275 |
+
logger.info("CPU pipeline loaded and cached.")
|
| 276 |
+
return model, scheduler, target_sr
|
| 277 |
|
| 278 |
|
| 279 |
# ---------------------------------------------------------
|
| 280 |
# 推理:audio + caption -> edited audio
|
| 281 |
+
# ZeroGPU:必须用 @spaces.GPU
|
| 282 |
+
# 并且:函数内再把模型搬到 cuda,推完搬回 cpu
|
| 283 |
# ---------------------------------------------------------
|
| 284 |
+
@spaces.GPU
|
| 285 |
@torch.no_grad()
|
| 286 |
def run_edit(
|
| 287 |
audio_file: str,
|
|
|
|
| 298 |
if not caption:
|
| 299 |
return None, "Error: caption is empty."
|
| 300 |
|
| 301 |
+
# 1) 取 CPU 缓存
|
| 302 |
+
model_cpu, scheduler, target_sr = load_pipeline_cpu()
|
| 303 |
+
|
| 304 |
+
# 2) ZeroGPU 进入 GPU 区域后,cuda 才会 available
|
| 305 |
+
if not torch.cuda.is_available():
|
| 306 |
+
return None, "Error: ZeroGPU did not allocate CUDA. Please retry (queue) or check Space hardware."
|
| 307 |
+
|
| 308 |
+
device = torch.device("cuda")
|
| 309 |
+
logger.info(f"[GPU] torch.cuda.is_available={torch.cuda.is_available()}, device={device}")
|
| 310 |
+
|
| 311 |
+
# 3) 把模型搬到 GPU(临时)
|
| 312 |
+
model = model_cpu.to(device).eval()
|
| 313 |
|
| 314 |
+
# seed
|
| 315 |
seed = int(seed)
|
| 316 |
torch.manual_seed(seed)
|
| 317 |
np.random.seed(seed)
|
| 318 |
|
| 319 |
+
# audio preprocess
|
| 320 |
wav = load_and_process_audio(audio_file, target_sr=target_sr).to(device)
|
| 321 |
|
| 322 |
batch = {
|
|
|
|
| 325 |
"task": ["audio_editing"],
|
| 326 |
}
|
| 327 |
|
| 328 |
+
# 与 infer.config 对齐
|
| 329 |
kwargs = {
|
| 330 |
"num_steps": int(num_steps),
|
| 331 |
"guidance_scale": float(guidance_scale),
|
|
|
|
| 344 |
out_path = OUTPUT_DIR / f"{Path(audio_file).stem}_edited.wav"
|
| 345 |
sf.write(str(out_path), out_audio, samplerate=target_sr)
|
| 346 |
|
| 347 |
+
# 4) 推完立刻把模型搬回 CPU(ZeroGPU 关键:避免缓存里残留 cuda tensor)
|
| 348 |
+
model_cpu = model.to("cpu")
|
| 349 |
+
del model
|
| 350 |
+
torch.cuda.empty_cache()
|
| 351 |
+
|
| 352 |
+
# 5) 更新缓存(仍然缓存 CPU 版本)
|
| 353 |
+
cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
|
| 354 |
+
_PIPELINE_CACHE[cache_key] = (model_cpu, scheduler, target_sr)
|
| 355 |
+
|
| 356 |
return str(out_path), f"OK | saved={out_path.name} | time={dt:.2f}s | sr={target_sr} | seed={seed}"
|
| 357 |
|
| 358 |
|
|
|
|
| 360 |
# UI
|
| 361 |
# ---------------------------------------------------------
|
| 362 |
def build_demo():
|
| 363 |
+
with gr.Blocks(title="MMEdit (ZeroGPU)") as demo:
|
| 364 |
+
gr.Markdown("# MMEdit ZeroGPU(audio + caption → edited audio)")
|
| 365 |
+
|
|
|
|
|
|
|
| 366 |
|
| 367 |
with gr.Row():
|
| 368 |
with gr.Column():
|
| 369 |
audio_in = gr.Audio(label="Input Audio", type="filepath")
|
| 370 |
caption = gr.Textbox(label="Caption (Edit Instruction)", lines=3)
|
| 371 |
|
| 372 |
+
# 注意:Spaces 不允许你 push 大的 wav 示例。
|
| 373 |
+
# 最稳的方式:你自己在 Space repo 放一个很小的 demo wav(几百 KB)。
|
| 374 |
gr.Examples(
|
| 375 |
label="example inputs",
|
| 376 |
examples=[
|
| 377 |
+
["./Ym8O802VvJes.wav", "Mix in dog barking in the middle."],
|
| 378 |
],
|
| 379 |
inputs=[audio_in, caption],
|
| 380 |
+
cache_examples=False,
|
| 381 |
)
|
| 382 |
|
| 383 |
with gr.Row():
|
|
|
|
| 402 |
|
| 403 |
gr.Markdown(
|
| 404 |
"## 注意事项\n"
|
| 405 |
+
"1) ZeroGPU 首次点击会分配 GPU,可能稍慢。\n"
|
| 406 |
+
"2) 如果遇到错误,请重试(尤其是首次启动时)。\n"
|
| 407 |
+
"3) 原始音频保留可能有bug\n"
|
| 408 |
)
|
|
|
|
| 409 |
return demo
|
| 410 |
|
| 411 |
|
|
|
|
| 412 |
if __name__ == "__main__":
|
| 413 |
demo = build_demo()
|
| 414 |
+
port = int(os.environ.get("PORT", "7860"))
|
| 415 |
+
# ZeroGPU:强烈建议 queue;并禁用 SSR 更稳
|
| 416 |
+
demo.queue().launch(
|
| 417 |
+
server_name="0.0.0.0",
|
| 418 |
+
server_port=port,
|
| 419 |
+
share=False,
|
| 420 |
+
ssr_mode=False,
|
| 421 |
+
)
|