fix tf32
Browse files
app.py
CHANGED
|
@@ -23,16 +23,14 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(
|
|
| 23 |
logger = logging.getLogger("mmedit_space")
|
| 24 |
|
| 25 |
|
| 26 |
-
|
| 27 |
-
# HF Repo IDs(按你的默认需求)
|
| 28 |
-
# ---------------------------------------------------------
|
| 29 |
MMEDIT_REPO_ID = os.environ.get("MMEDIT_REPO_ID", "CocoBro/MMEdit")
|
| 30 |
MMEDIT_REVISION = os.environ.get("MMEDIT_REVISION", None)
|
| 31 |
|
| 32 |
QWEN_REPO_ID = os.environ.get("QWEN_REPO_ID", "Qwen/Qwen2-Audio-7B-Instruct")
|
| 33 |
QWEN_REVISION = os.environ.get("QWEN_REVISION", None)
|
| 34 |
|
| 35 |
-
|
| 36 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
| 37 |
|
| 38 |
OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "./outputs"))
|
|
@@ -41,8 +39,6 @@ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
|
| 41 |
USE_AMP = os.environ.get("USE_AMP", "0") == "1"
|
| 42 |
AMP_DTYPE = os.environ.get("AMP_DTYPE", "bf16") # "bf16" or "fp16"
|
| 43 |
|
| 44 |
-
# ZeroGPU:缓存 CPU pipeline(不要缓存 CUDA Tensor)
|
| 45 |
-
# cache: key -> (model_cpu, scheduler, target_sr)
|
| 46 |
_PIPELINE_CACHE: Dict[str, Tuple[object, object, int]] = {}
|
| 47 |
# cache: key -> (repo_root, qwen_root)
|
| 48 |
_MODEL_DIR_CACHE: Dict[str, Tuple[Path, Path]] = {}
|
|
@@ -89,6 +85,8 @@ def load_and_process_audio(audio_path: str, target_sr: int):
|
|
| 89 |
import torchaudio
|
| 90 |
import librosa
|
| 91 |
|
|
|
|
|
|
|
| 92 |
path = Path(audio_path)
|
| 93 |
if not path.exists():
|
| 94 |
raise FileNotFoundError(f"Audio file not found: {audio_path}")
|
|
@@ -184,7 +182,11 @@ def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, s
|
|
| 184 |
from safetensors.torch import load_file
|
| 185 |
import diffusers.schedulers as noise_schedulers
|
| 186 |
|
| 187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
try:
|
| 189 |
from utils.config import register_omegaconf_resolvers
|
| 190 |
register_omegaconf_resolvers()
|
|
@@ -192,12 +194,10 @@ def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, s
|
|
| 192 |
|
| 193 |
if not audio_file: return None, "Please upload audio."
|
| 194 |
|
| 195 |
-
|
| 196 |
model = None
|
| 197 |
|
| 198 |
try:
|
| 199 |
-
# ==========================================
|
| 200 |
-
# 1. 就在这里加载模型!利用 ZeroGPU 的大内存
|
| 201 |
# ==========================================
|
| 202 |
logger.info("🚀 Starting ZeroGPU Task...")
|
| 203 |
|
|
@@ -205,7 +205,7 @@ def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, s
|
|
| 205 |
repo_root, qwen_root = resolve_model_dirs()
|
| 206 |
exp_cfg = OmegaConf.to_container(OmegaConf.load(repo_root / "config.yaml"), resolve=True)
|
| 207 |
|
| 208 |
-
#
|
| 209 |
vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", "")
|
| 210 |
if vae_ckpt:
|
| 211 |
p1 = repo_root / "vae" / Path(vae_ckpt).name
|
|
@@ -214,7 +214,7 @@ def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, s
|
|
| 214 |
elif p2.exists(): exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(p2)
|
| 215 |
exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
|
| 216 |
|
| 217 |
-
#
|
| 218 |
logger.info("Instantiating model (Hydra)...")
|
| 219 |
model = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")
|
| 220 |
|
|
@@ -227,7 +227,6 @@ def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, s
|
|
| 227 |
gc.collect()
|
| 228 |
|
| 229 |
# ==========================================
|
| 230 |
-
# 2. 立即转到 GPU (FP16)
|
| 231 |
# ==========================================
|
| 232 |
device = torch.device("cuda")
|
| 233 |
logger.info("Moving model to CUDA (FP16)...")
|
|
@@ -279,9 +278,7 @@ def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, s
|
|
| 279 |
with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16):
|
| 280 |
out = model.inference(scheduler=scheduler, **batch)
|
| 281 |
|
| 282 |
-
|
| 283 |
-
# 4. 保存结果
|
| 284 |
-
# ==========================================
|
| 285 |
out_audio = out[0, 0].detach().float().cpu().numpy()
|
| 286 |
out_path = OUTPUT_DIR / f"{Path(audio_file).stem}_edited.wav"
|
| 287 |
sf.write(str(out_path), out_audio, samplerate=target_sr)
|
|
@@ -311,9 +308,23 @@ def build_demo():
|
|
| 311 |
audio_in = gr.Audio(label="Input", type="filepath")
|
| 312 |
caption = gr.Textbox(label="Instruction", lines=3)
|
| 313 |
gr.Examples(
|
| 314 |
-
label="Examples",
|
| 315 |
-
|
| 316 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
)
|
| 318 |
with gr.Row():
|
| 319 |
num_steps = gr.Slider(10, 100, 50, step=1, label="Steps")
|
|
|
|
| 23 |
logger = logging.getLogger("mmedit_space")
|
| 24 |
|
| 25 |
|
| 26 |
+
|
|
|
|
|
|
|
| 27 |
MMEDIT_REPO_ID = os.environ.get("MMEDIT_REPO_ID", "CocoBro/MMEdit")
|
| 28 |
MMEDIT_REVISION = os.environ.get("MMEDIT_REVISION", None)
|
| 29 |
|
| 30 |
QWEN_REPO_ID = os.environ.get("QWEN_REPO_ID", "Qwen/Qwen2-Audio-7B-Instruct")
|
| 31 |
QWEN_REVISION = os.environ.get("QWEN_REVISION", None)
|
| 32 |
|
| 33 |
+
|
| 34 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
| 35 |
|
| 36 |
OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "./outputs"))
|
|
|
|
| 39 |
USE_AMP = os.environ.get("USE_AMP", "0") == "1"
|
| 40 |
AMP_DTYPE = os.environ.get("AMP_DTYPE", "bf16") # "bf16" or "fp16"
|
| 41 |
|
|
|
|
|
|
|
| 42 |
_PIPELINE_CACHE: Dict[str, Tuple[object, object, int]] = {}
|
| 43 |
# cache: key -> (repo_root, qwen_root)
|
| 44 |
_MODEL_DIR_CACHE: Dict[str, Tuple[Path, Path]] = {}
|
|
|
|
| 85 |
import torchaudio
|
| 86 |
import librosa
|
| 87 |
|
| 88 |
+
|
| 89 |
+
|
| 90 |
path = Path(audio_path)
|
| 91 |
if not path.exists():
|
| 92 |
raise FileNotFoundError(f"Audio file not found: {audio_path}")
|
|
|
|
| 182 |
from safetensors.torch import load_file
|
| 183 |
import diffusers.schedulers as noise_schedulers
|
| 184 |
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
torch.backends.cuda.matmul.allow_tf32 = False
|
| 188 |
+
torch.backends.cudnn.allow_tf32 = False
|
| 189 |
+
|
| 190 |
try:
|
| 191 |
from utils.config import register_omegaconf_resolvers
|
| 192 |
register_omegaconf_resolvers()
|
|
|
|
| 194 |
|
| 195 |
if not audio_file: return None, "Please upload audio."
|
| 196 |
|
| 197 |
+
|
| 198 |
model = None
|
| 199 |
|
| 200 |
try:
|
|
|
|
|
|
|
| 201 |
# ==========================================
|
| 202 |
logger.info("🚀 Starting ZeroGPU Task...")
|
| 203 |
|
|
|
|
| 205 |
repo_root, qwen_root = resolve_model_dirs()
|
| 206 |
exp_cfg = OmegaConf.to_container(OmegaConf.load(repo_root / "config.yaml"), resolve=True)
|
| 207 |
|
| 208 |
+
#
|
| 209 |
vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", "")
|
| 210 |
if vae_ckpt:
|
| 211 |
p1 = repo_root / "vae" / Path(vae_ckpt).name
|
|
|
|
| 214 |
elif p2.exists(): exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(p2)
|
| 215 |
exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
|
| 216 |
|
| 217 |
+
#
|
| 218 |
logger.info("Instantiating model (Hydra)...")
|
| 219 |
model = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")
|
| 220 |
|
|
|
|
| 227 |
gc.collect()
|
| 228 |
|
| 229 |
# ==========================================
|
|
|
|
| 230 |
# ==========================================
|
| 231 |
device = torch.device("cuda")
|
| 232 |
logger.info("Moving model to CUDA (FP16)...")
|
|
|
|
| 278 |
with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16):
|
| 279 |
out = model.inference(scheduler=scheduler, **batch)
|
| 280 |
|
| 281 |
+
|
|
|
|
|
|
|
| 282 |
out_audio = out[0, 0].detach().float().cpu().numpy()
|
| 283 |
out_path = OUTPUT_DIR / f"{Path(audio_file).stem}_edited.wav"
|
| 284 |
sf.write(str(out_path), out_audio, samplerate=target_sr)
|
|
|
|
| 308 |
audio_in = gr.Audio(label="Input", type="filepath")
|
| 309 |
caption = gr.Textbox(label="Instruction", lines=3)
|
| 310 |
gr.Examples(
|
| 311 |
+
label="Examples (Click to load)",
|
| 312 |
+
# Format: [[audio_path_1, prompt_1], [audio_path_2, prompt_2], ...]
|
| 313 |
+
examples=[
|
| 314 |
+
# Example 1 (the original one)
|
| 315 |
+
["./Ym8O802VvJes.wav", "Mix in dog barking around the middle."],
|
| 316 |
+
|
| 317 |
+
# Example 2 (newly added)
|
| 318 |
+
["./YDKM2KjNkX18.wav", "Incorporate Telephone bell ringing into the background."],
|
| 319 |
+
|
| 320 |
+
# Example 3 (newly added)
|
| 321 |
+
["./drop_audiocaps_1.wav", "Remove the sound of several beeps."],
|
| 322 |
+
|
| 323 |
+
# Example 4 (newly added)
|
| 324 |
+
["./reorder_audiocaps_1.wav", "Switch the positions of the woman's voice and whistling."]
|
| 325 |
+
],
|
| 326 |
+
inputs=[audio_in, caption], # same order as each example row: the Audio input first, then the Textbox
|
| 327 |
+
cache_examples=False, # recommended to keep False on ZeroGPU to avoid expensive computation at startup
|
| 328 |
)
|
| 329 |
with gr.Row():
|
| 330 |
num_steps = gr.Slider(10, 100, 50, step=1, label="Steps")
|