File size: 12,469 Bytes
c14d03d
 
 
dabc4d4
 
92bc756
c14d03d
 
 
 
 
47b5ec4
5a53edf
c14d03d
 
 
 
688592f
c14d03d
b75ca87
c14d03d
b75ca87
c14d03d
 
b75ca87
a2ca450
c14d03d
 
 
 
 
 
a2ca450
0bcb372
 
c14d03d
 
 
3c696cb
 
 
dabc4d4
b75ca87
0bcb372
c14d03d
 
 
3c696cb
c14d03d
 
0bcb372
 
 
 
3c696cb
c14d03d
b75ca87
 
3c696cb
 
b75ca87
c14d03d
b75ca87
 
3c696cb
c14d03d
b75ca87
 
3c696cb
 
 
c14d03d
b75ca87
 
 
 
c14d03d
b75ca87
 
3c696cb
b75ca87
 
3c696cb
688592f
dabc4d4
688592f
 
a2ca450
 
c14d03d
 
 
 
3c696cb
c14d03d
3c696cb
 
 
 
 
b75ca87
 
 
3c696cb
 
b75ca87
 
 
 
c14d03d
 
b75ca87
3c696cb
c14d03d
b75ca87
 
 
 
c14d03d
 
b75ca87
 
3c696cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b75ca87
 
3c696cb
b75ca87
 
 
3c696cb
 
 
b75ca87
3c696cb
 
 
 
b75ca87
 
688592f
b75ca87
3c696cb
 
 
 
 
 
 
 
 
 
b75ca87
 
 
b312d65
caea508
77f1338
43bf2c1
 
 
 
b75ca87
dabc4d4
 
 
77f1338
11cf650
9466f64
 
a2ca450
 
 
47b5ec4
 
 
 
dabc4d4
47b5ec4
 
a2ca450
77f1338
688592f
b2d65b0
77f1338
 
47b5ec4
77f1338
 
 
 
a2ca450
77f1338
 
 
 
 
 
 
 
a2ca450
77f1338
 
 
 
 
 
 
 
 
66b10fb
77f1338
 
 
 
 
47b5ec4
77f1338
43bf2c1
62c2ea1
43bf2c1
62c2ea1
 
 
43bf2c1
62c2ea1
43bf2c1
 
 
 
77f1338
 
 
 
 
 
 
 
 
 
 
 
 
688592f
 
 
62c2ea1
66b10fb
f3f0643
 
688592f
f3f0643
 
 
 
 
47b5ec4
f3f0643
47b5ec4
77f1338
f3f0643
62c2ea1
77f1338
47b5ec4
a2ca450
f3f0643
 
 
47b5ec4
 
b2d65b0
 
77f1338
 
 
47b5ec4
f3f0643
77f1338
 
 
f3f0643
66b10fb
77f1338
 
688592f
77f1338
c14d03d
77f1338
 
c14d03d
 
77f1338
 
b75ca87
a2ca450
 
 
 
 
 
11cf650
a2ca450
 
6b00c17
 
b75ca87
3c696cb
77f1338
 
 
 
 
 
c14d03d
77f1338
b75ca87
77f1338
 
c14d03d
5a53edf
 
9466f64
c14d03d
caea508
9466f64
0bcb372
caea508
 
 
 
77f1338
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# ZeroGPU 关键:必须最先导入
import spaces
import traceback
import os
import time
import logging
from pathlib import Path
from typing import Tuple, Optional, Dict, Any
import gc
import gradio as gr
import numpy as np
import soundfile as sf
from huggingface_hub import snapshot_download


# -----------------------------
# Logging
# -----------------------------
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("mmedit_space")


# Hub repo ids / revisions, overridable via environment variables.
MMEDIT_REPO_ID = os.environ.get("MMEDIT_REPO_ID", "CocoBro/MMEdit")
MMEDIT_REVISION = os.environ.get("MMEDIT_REVISION", None)

QWEN_REPO_ID = os.environ.get("QWEN_REPO_ID", "Qwen/Qwen2-Audio-7B-Instruct")
QWEN_REVISION = os.environ.get("QWEN_REVISION", None)

# Auth token for gated Hub repos (e.g. Qwen2-Audio); may be None for public repos.
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Directory where edited audio is written; created eagerly at import time.
OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "./outputs"))
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Mixed-precision switches read from env.
# NOTE(review): USE_AMP / AMP_DTYPE are not consumed anywhere visible in this
# file — confirm they are used elsewhere before removing.
USE_AMP = os.environ.get("USE_AMP", "0") == "1"
AMP_DTYPE = os.environ.get("AMP_DTYPE", "bf16")  # "bf16" or "fp16"

# NOTE(review): _PIPELINE_CACHE appears unused in this file — verify before deleting.
_PIPELINE_CACHE: Dict[str, Tuple[object, object, int]] = {}
# cache: key -> (repo_root, qwen_root)
_MODEL_DIR_CACHE: Dict[str, Tuple[Path, Path]] = {}


# ---------------------------------------------------------
# Download repos (done once; huggingface_hub caches snapshots)
# ---------------------------------------------------------
def resolve_model_dirs() -> Tuple[Path, Path]:
    """Download (or reuse) the MMEdit and Qwen snapshots and return their local roots.

    Results are memoized in ``_MODEL_DIR_CACHE`` keyed by repo-id/revision so
    repeated calls within one process skip the snapshot lookup entirely.
    """
    key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
    hit = _MODEL_DIR_CACHE.get(key)
    if hit is not None:
        return hit

    def _fetch(repo_id, revision, label):
        # One snapshot per repo; huggingface_hub deduplicates via its own cache.
        logger.info(f"Downloading {label} repo: {repo_id} (revision={revision})")
        root = snapshot_download(
            repo_id=repo_id,
            revision=revision,
            local_dir=None,
            local_dir_use_symlinks=False,
            token=HF_TOKEN,  # required for gated models
        )
        return Path(root).resolve()

    repo_root = _fetch(MMEDIT_REPO_ID, MMEDIT_REVISION, "MMEdit")
    qwen_root = _fetch(QWEN_REPO_ID, QWEN_REVISION, "Qwen")

    _MODEL_DIR_CACHE[key] = (repo_root, qwen_root)
    return repo_root, qwen_root


# ---------------------------------------------------------
# Audio loading (intentional two-stage resample: orig -> 16k -> target_sr)
# ---------------------------------------------------------
def load_and_process_audio(audio_path: str, target_sr: int):
    """Load *audio_path* as a mono 1-D tensor, resampled to *target_sr* via a 16 kHz hop.

    The detour through 16 kHz before the final rate is deliberate and must be
    preserved: it reproduces the preprocessing the model was trained with.
    """
    # Deferred imports keep CUDA from initializing during app startup.
    import torch
    import torchaudio
    import librosa

    src = Path(audio_path)
    if not src.exists():
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    wav, orig_sr = torchaudio.load(str(src))  # shape: (channels, samples)

    # Collapse to a single channel.
    if wav.ndim == 2:
        wav = wav.mean(dim=0)  # (samples,)
    elif wav.ndim > 2:
        wav = wav.reshape(-1)

    if target_sr and int(target_sr) != int(orig_sr):
        samples = wav.cpu().numpy()
        cur_sr = int(orig_sr)

        # Stage 1: bring everything to 16 kHz first.
        mid_sr = 16000
        if cur_sr != mid_sr:
            samples = librosa.resample(samples, orig_sr=cur_sr, target_sr=mid_sr)
            cur_sr = mid_sr

        # Stage 2: 16 kHz -> final rate (e.g. 24 kHz).
        if int(target_sr) != cur_sr:
            samples = librosa.resample(samples, orig_sr=cur_sr, target_sr=int(target_sr))

        wav = torch.from_numpy(samples)

    return wav


# ---------------------------------------------------------
# Validate the downloaded repo layout
# ---------------------------------------------------------
def assert_repo_layout(repo_root: Path) -> None:
    """Raise FileNotFoundError unless *repo_root* contains config.yaml,
    model.safetensors, and a vae/ directory with at least one ``*.ckpt``."""
    required = (repo_root / "config.yaml", repo_root / "model.safetensors", repo_root / "vae")
    missing = [p for p in required if not p.exists()]
    if missing:
        # Report the first missing path (same order as the required tuple).
        raise FileNotFoundError(f"Missing required path: {missing[0]}")

    if not list((repo_root / "vae").glob("*.ckpt")):
        raise FileNotFoundError(f"No .ckpt found under: {repo_root/'vae'}")


# ---------------------------------------------------------
# Rewrite config.yaml paths to point at the downloaded snapshots
# ---------------------------------------------------------
def patch_paths_in_exp_config(exp_cfg: Dict[str, Any], repo_root: Path, qwen_root: Path) -> None:
    """Patch the VAE checkpoint path and the Qwen model path in *exp_cfg*, in place.

    Raises:
        FileNotFoundError: if the rewritten VAE checkpoint path does not exist.
    """
    # ---- 1) VAE ckpt ----
    raw_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", None)
    if raw_ckpt:
        normalized = str(raw_ckpt).replace("\\", "/")
        pos = normalized.find("vae/")
        if pos != -1:
            # Keep everything from the "vae/" component onward.
            relative = normalized[pos:]
        elif normalized.endswith(".ckpt") and "/" not in normalized:
            # Bare filename: assume it lives under vae/.
            relative = f"vae/{normalized}"
        else:
            relative = normalized

        resolved = (repo_root / relative).resolve()
        exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(resolved)

        if not resolved.exists():
            raise FileNotFoundError(
                f"VAE ckpt not found after patch:\n"
                f"  original: {normalized}\n"
                f"  patched : {resolved}\n"
                f"Repo root: {repo_root}\n"
                f"Expected:  {repo_root/'vae'/'*.ckpt'}"
            )

    # ---- 2) Qwen2-Audio model_path ----
    exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)



@spaces.GPU
def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, seed):
    """Run one MMEdit audio-editing pass on the ZeroGPU worker.

    Builds the model fresh from the downloaded snapshot (ZeroGPU workers are
    stateless between tasks), runs inference on *audio_file* guided by the
    *caption* instruction, and writes the result under OUTPUT_DIR.

    Returns:
        Tuple of (output_path_or_None, status_message) matching the two
        Gradio outputs wired in build_demo.
    """
    # Deferred imports: heavy / CUDA-touching modules must not load at startup.
    import torch
    import hydra
    from omegaconf import OmegaConf
    from safetensors.torch import load_file
    import diffusers.schedulers as noise_schedulers

    logger.info("🚀 Starting ..")
    # Disable TF32 so matmul/conv results match the full-precision reference.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    # Best-effort: register project-specific OmegaConf resolvers if available.
    # Narrowed from a bare `except:` so Ctrl-C / SystemExit still propagate.
    try:
        from utils.config import register_omegaconf_resolvers
        register_omegaconf_resolvers()
    except Exception:
        pass  # missing helper is fine; config may not need custom resolvers

    if not audio_file:
        return None, "Please upload audio."

    model = None

    try:
        logger.info("🚀 Starting ZeroGPU Task...")

        # Resolve local snapshot paths and load the experiment config.
        repo_root, qwen_root = resolve_model_dirs()
        exp_cfg = OmegaConf.to_container(OmegaConf.load(repo_root / "config.yaml"), resolve=True)

        # Point the config's VAE checkpoint at whichever local copy exists.
        vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", "")
        if vae_ckpt:
            p1 = repo_root / "vae" / Path(vae_ckpt).name
            p2 = repo_root / Path(vae_ckpt).name
            if p1.exists():
                exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(p1)
            elif p2.exists():
                exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(p2)
        exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)

        # Build the model graph from the config.
        logger.info("Instantiating model (Hydra)...")
        model = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")

        # Load trained weights and release the state dict immediately.
        ckpt_path = str(repo_root / "model.safetensors")
        logger.info(f"Loading weights from {ckpt_path}...")
        sd = load_file(ckpt_path)
        model.load_pretrained(sd)
        del sd
        gc.collect()

        # Move the model onto the GPU — in FP32, not FP16 (the old log message
        # said FP16 but every .to() below uses torch.float32).
        device = torch.device("cuda")
        logger.info("Moving model to CUDA (FP32)...")

        def safe_move_model(m, dev):
            # Move child modules one at a time so an OOM/failure is attributable.
            logger.info("🛡️ Moving model to GPU in FP32...")
            for name, child in m.named_children():
                child.to(dev, dtype=torch.float32)
                logger.info(f"Moving {name} to GPU (fp32)...")
            m.to(dev, dtype=torch.float32)
            return m

        model = safe_move_model(model, device)
        model.eval()
        logger.info("Model is moved to CUDA.")

        # Scheduler: prefer the configured pretrained one; fall back to defaults.
        # Narrowed from a bare `except:`.
        try:
            scheduler = noise_schedulers.DDIMScheduler.from_pretrained(
                exp_cfg["model"].get("noise_scheduler_name", ""),
                subfolder="scheduler", token=HF_TOKEN
            )
        except Exception:
            scheduler = noise_schedulers.DDIMScheduler(num_train_timesteps=1000)

        # ==========================================
        # 3. Inference
        # ==========================================
        target_sr = int(exp_cfg.get("sample_rate", 24000))
        torch.manual_seed(int(seed))
        np.random.seed(int(seed))

        wav = load_and_process_audio(audio_file, target_sr).to(device, dtype=torch.float32)

        batch = {
            "audio_id": [Path(audio_file).stem],
            "content": [{"audio": wav, "caption": caption}],
            "task": ["audio_editing"],
            "num_steps": int(num_steps),
            "guidance_scale": float(guidance_scale),
            "guidance_rescale": float(guidance_rescale),
            "use_gt_duration": False,
            "mask_time_aligned_content": False,
        }

        logger.info("Inference running...")
        t0 = time.time()
        with torch.no_grad():
            out = model.inference(scheduler=scheduler, **batch)

        # out is indexed [0, 0] — assumes (batch, channel, samples); TODO confirm.
        out_audio = out[0, 0].detach().float().cpu().numpy()
        out_path = OUTPUT_DIR / f"{Path(audio_file).stem}_edited.wav"
        sf.write(str(out_path), out_audio, samplerate=target_sr)

        return str(out_path), f"Success | {time.time()-t0:.2f}s"

    except Exception as e:
        err = traceback.format_exc()
        logger.error(f"❌ ERROR:\n{err}")
        return None, f"Runtime Error: {e}"

    finally:
        # Force cleanup so the next ZeroGPU task has enough VRAM.
        logger.info("Cleaning up...")
        if model is not None:
            del model
        torch.cuda.empty_cache()
        gc.collect()

# -----------------------------
# UI
# -----------------------------
def build_demo():
    """Build the Gradio Blocks UI and wire the Run button to run_edit."""
    with gr.Blocks(title="MMEdit") as demo:
        gr.Markdown("# MMEdit ZeroGPU (Direct Load)")
        with gr.Row():
            with gr.Column():
                audio_in = gr.Audio(label="Input", type="filepath")
                caption = gr.Textbox(label="Instruction", lines=3)
                gr.Examples(
                    label="Examples (Click to load)",
                    # Format: [[audio_path_1, prompt_1], [audio_path_2, prompt_2], ...]
                    examples=[
                        # Example 1 (the original one)
                        ["./Ym8O802VvJes.wav", "Mix in dog barking around the middle."],
                        ["./YDKM2KjNkX18.wav", "Incorporate Telephone bell ringing into the background."],
                        ["./drop_audiocaps_1.wav", "Erase the rain falling sound  from the background."],
                        ["./reorder_audiocaps_1.wav", "Switch the positions of the woman's voice and whistling."]
                    ],
                    inputs=[audio_in, caption],  # order matches the lists above: Audio first, then Textbox
                    cache_examples=False,        # keep False on ZeroGPU to avoid slow startup-time computation
                )
                with gr.Row():
                    num_steps = gr.Slider(10, 100, 50, step=1, label="Steps")
                    guidance_scale = gr.Slider(1.0, 12.0, 5.0, step=0.5, label="Guidance")
                    guidance_rescale = gr.Slider(0.0, 1.0, 0.5, step=0.05, label="Rescale")
                    seed = gr.Number(42, label="Seed")
                run_btn = gr.Button("Run", variant="primary")
            
            with gr.Column():
                out = gr.Audio(label="Output")
                status = gr.Textbox(label="Status")
        
        # Input order here must match run_edit's parameter order.
        run_btn.click(run_edit, [audio_in, caption, num_steps, guidance_scale, guidance_rescale, seed], [out, status])
    return demo

if __name__ == "__main__":
    # Script entry point: build the UI and serve it on all interfaces.
    print("[BOOT] entering main()", flush=True)
    app = build_demo()
    serve_port = int(os.environ.get("PORT", "7860"))
    print(f"[BOOT] launching gradio on 0.0.0.0:{serve_port}", flush=True)
    app.queue().launch(
        server_name="0.0.0.0",
        server_port=serve_port,
        share=False,
        ssr_mode=False,
    )