"""Gradio demo: Audio-Omni video-to-audio (V2A) generation on Hugging Face ZeroGPU."""

import json
import os
import tempfile
import time
import traceback
from importlib.util import find_spec
from pathlib import Path
from urllib.request import urlretrieve

import gradio as gr
import numpy as np
import soundfile as sf
import spaces
import torch
from huggingface_hub import snapshot_download

# Configuration; every value can be overridden via environment variables.
MODEL_REPO_ID = os.environ.get("AUDIO_OMNI_REPO_ID", "HKUSTAudio/Audio-Omni")
MODEL_CACHE_DIR = Path(os.environ.get("AUDIO_OMNI_CACHE_DIR", "./model"))
QWEN_REPO_ID = os.environ.get("QWEN_OMNI_REPO_ID", "Qwen/Qwen2.5-Omni-3B")
QWEN_CACHE_DIR = Path(os.environ.get("QWEN_OMNI_CACHE_DIR", str(MODEL_CACHE_DIR / "qwen_omni_3b")))
GPU_DURATION_SECONDS = int(os.environ.get("GPU_DURATION_SECONDS", "120"))
MAX_RUNTIME_SECONDS = int(os.environ.get("MAX_RUNTIME_SECONDS", "115"))
MAX_STEPS = int(os.environ.get("MAX_STEPS", "85"))
MAX_AUDIO_SECONDS = int(os.environ.get("MAX_AUDIO_SECONDS", "10"))
MAX_COMPLEXITY = int(os.environ.get("MAX_COMPLEXITY", "800"))
FIXED_SECONDS_TOTAL = int(os.environ.get("FIXED_SECONDS_TOTAL", "10"))
OUTPUT_SECONDS = int(os.environ.get("OUTPUT_SECONDS", "5"))
PREFETCH_ON_START = os.environ.get("PREFETCH_ON_START", "1").strip().lower() in {"1", "true", "yes", "on"}
PREFETCH_ON_BOOT = os.environ.get("PREFETCH_ON_BOOT", "1").strip().lower() in {"1", "true", "yes", "on"}
FORCE_ZERO_SYNC_FEATURES = os.environ.get("FORCE_ZERO_SYNC_FEATURES", "1").strip().lower() in {
    "1",
    "true",
    "yes",
    "on",
}

# Lazily populated model singleton and its cached load error.
_MODEL = None
_MODEL_LOAD_ERROR = None
VOCAB_FALLBACK_URL = "https://raw.githubusercontent.com/ZeyueT/Audio-Omni/main/audio_omni/data/vocab.txt"
HF_TOKEN = os.environ.get("HF_TOKEN")
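

# Note (inferred from the defaults above, not stated in the source): MAX_RUNTIME_SECONDS
# (115) sits just under GPU_DURATION_SECONDS (120), so a job admitted by the runtime
# estimate below should finish within the ZeroGPU slot it requested.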


def _estimate_runtime_seconds(steps: int, seconds_total: int) -> float:
    # Rough heuristic: a fixed startup overhead plus a per-step cost that
    # scales with the requested audio duration.
    return 18.0 + 0.11 * float(steps) * float(seconds_total)
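

# Worked example: steps=70 and seconds_total=10 give
# 18.0 + 0.11 * 70 * 10 = 95.0 seconds, under the default 115 s runtime limit.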


def _download_model_files() -> Path:
    """Download the Audio-Omni config, checkpoint, and Synchformer weights."""
    MODEL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    local_dir = snapshot_download(
        repo_id=MODEL_REPO_ID,
        local_dir=str(MODEL_CACHE_DIR),
        token=HF_TOKEN,
        allow_patterns=[
            "Audio-Omni.json",
            "model.ckpt",
            "synchformer_state_dict.pth",
        ],
    )
    return Path(local_dir)
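

# snapshot_download with allow_patterns fetches only the matching files, so the
# cache holds just the three artifacts listed above rather than the full repo.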


def _download_qwen_files() -> Path:
    """Download the Qwen2.5-Omni model files into the local cache."""
    QWEN_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    local_dir = snapshot_download(
        repo_id=QWEN_REPO_ID,
        local_dir=str(QWEN_CACHE_DIR),
        token=HF_TOKEN,
    )
    return Path(local_dir)


def _prefetch_runtime_assets() -> None:
    """Fetch all weights and the vocab file ahead of the first generation."""
    _download_model_files()
    _resolve_vocab_file()
    qwen_dir = _download_qwen_files()
    os.environ["QWEN_OMNI_MODEL_PATH"] = str(qwen_dir)


def _resolve_vocab_file() -> Path:
    """Locate vocab.txt, falling back to a download from the upstream repo."""
    candidates = []

    # Prefer the copy bundled with an installed audio_omni package.
    pkg_spec = find_spec("audio_omni")
    if pkg_spec and pkg_spec.origin:
        pkg_root = Path(pkg_spec.origin).resolve().parent
        candidates.append(pkg_root / "data" / "vocab.txt")

    # Then a checked-out source tree, then the model cache.
    candidates.append(Path.cwd() / "audio_omni" / "data" / "vocab.txt")
    candidates.append(MODEL_CACHE_DIR / "vocab.txt")

    for candidate in candidates:
        if candidate.exists():
            return candidate

    # Last resort: download the vocab file into the model cache.
    fallback_path = MODEL_CACHE_DIR / "vocab.txt"
    fallback_path.parent.mkdir(parents=True, exist_ok=True)
    urlretrieve(VOCAB_FALLBACK_URL, str(fallback_path))
    return fallback_path


def _patch_config_vocab(config_path: Path) -> Path:
    """Point the speech_prompt conditioning entry at a locally available vocab file."""
    vocab_path = _resolve_vocab_file()

    with open(config_path, "r", encoding="utf-8") as fp:
        config = json.load(fp)

    cond_cfg = config.get("model", {}).get("conditioning", {}).get("configs", [])
    for item in cond_cfg:
        if item.get("id") == "speech_prompt":
            item.setdefault("config", {})
            item["config"]["vocab_file"] = str(vocab_path)

    # Write the patched config next to the original instead of mutating it.
    patched_path = config_path.parent / "Audio-Omni.patched.json"
    with open(patched_path, "w", encoding="utf-8") as fp:
        json.dump(config, fp)
    return patched_path
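

# A minimal sketch of the config shape the patch above assumes (abridged, illustrative):
#
#   {"model": {"conditioning": {"configs": [
#       {"id": "speech_prompt", "config": {"vocab_file": "..."}},
#       ...
#   ]}}}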


def _load_model():
    global _MODEL, _MODEL_LOAD_ERROR

    if _MODEL is not None:
        return _MODEL
    if _MODEL_LOAD_ERROR is not None:
        raise RuntimeError(_MODEL_LOAD_ERROR)

    if not torch.cuda.is_available():
        _MODEL_LOAD_ERROR = "CUDA is unavailable. ZeroGPU did not provide a GPU session."
        raise RuntimeError(_MODEL_LOAD_ERROR)

    _prefetch_runtime_assets()
    model_dir = MODEL_CACHE_DIR
    config_path = model_dir / "Audio-Omni.json"
    patched_config_path = _patch_config_vocab(config_path)
    ckpt_path = model_dir / "model.ckpt"
    sync_path = model_dir / "synchformer_state_dict.pth"
    os.environ["SYNCHFORMER_CKPT"] = str(sync_path)

    # Deferred import: audio_omni pulls in heavy dependencies that are only
    # needed once a GPU session is available.
    from audio_omni import AudioOmni

    _MODEL = AudioOmni(
        config_path=str(patched_config_path),
        ckpt_path=str(ckpt_path),
        device="cuda",
    )
    if FORCE_ZERO_SYNC_FEATURES:
        # Bypass the Synchformer video encoder by feeding zero-valued sync
        # features of the expected (1, 240, 768) shape on the model's device.
        def _zero_sync_features(_video_path: str, _duration: int = 10):
            return torch.zeros(1, 240, 768, device=_MODEL.device)

        _MODEL._extract_sync_features = _zero_sync_features
    torch.set_grad_enabled(False)
    return _MODEL


@spaces.GPU(duration=GPU_DURATION_SECONDS)
def warmup_model():
    if not PREFETCH_ON_START:
        return "Auto-warmup is disabled."
    try:
        started_at = time.time()
        _load_model()
        elapsed = time.time() - started_at
        return f"Model warmed up and ready to generate. Warmup took {elapsed:.1f}s."
    except Exception as exc:
        return f"Auto-warmup failed: {type(exc).__name__}: {exc}"


@spaces.GPU(duration=GPU_DURATION_SECONDS)
def run_v2a(
    video_path: str,
    prompt: str,
    steps: int,
    cfg_scale: float,
    seconds_total: int,
    seed: int,
):
    if not video_path:
        return None, "Please upload a video file."

    try:
        steps = int(steps)
        requested_seconds = int(seconds_total)
        # Audio-Omni always generates FIXED_SECONDS_TOTAL seconds internally;
        # the result is trimmed to OUTPUT_SECONDS below.
        seconds_total = FIXED_SECONDS_TOTAL
        cfg_scale = float(cfg_scale)
        seed = int(seed)

        if steps > MAX_STEPS:
            return None, f"Too many steps: {steps}. Maximum: {MAX_STEPS}."
        if seconds_total > MAX_AUDIO_SECONDS:
            return None, f"Duration too long: {seconds_total}s. Maximum: {MAX_AUDIO_SECONDS}s."
        if steps * seconds_total > MAX_COMPLEXITY:
            return None, (
                f"Combination too heavy (steps*seconds={steps * seconds_total}). "
                f"Limit: {MAX_COMPLEXITY}."
            )

        estimated = _estimate_runtime_seconds(steps, seconds_total)
        if estimated > MAX_RUNTIME_SECONDS:
            return None, (
                f"Estimated time ~{estimated:.0f}s exceeds the {MAX_RUNTIME_SECONDS}s limit. "
                "Reduce steps or duration."
            )

        started_at = time.time()
        model = _load_model()

        audio_i16 = model.generate(
            task="V2A",
            prompt=(prompt or "").strip(),
            video_path=video_path,
            steps=steps,
            cfg_scale=cfg_scale,
            seconds_total=seconds_total,
            seed=seed,
        )

        # Convert int16 samples to float32 in [-1, 1], then trim to the output length.
        audio_f32 = audio_i16.to(torch.float32) / 32767.0
        target_samples = int(model.sample_rate * OUTPUT_SECONDS)
        if audio_f32.shape[-1] > target_samples:
            audio_f32 = audio_f32[:, :target_samples]
        tmp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp_wav.close()
        # soundfile expects (samples, channels), so transpose from (channels, samples).
        audio_np = np.ascontiguousarray(audio_f32.transpose(0, 1).cpu().numpy())
        sf.write(tmp_wav.name, audio_np, model.sample_rate, subtype="PCM_16")
        elapsed = time.time() - started_at
        note = (
            f" Model-internal duration: {FIXED_SECONDS_TOTAL}s (Audio-Omni constraint); "
            f"output: {OUTPUT_SECONDS}s."
        )
        if requested_seconds != OUTPUT_SECONDS:
            note += f" Requested: {requested_seconds}s."
        return tmp_wav.name, f"Done in {elapsed:.1f}s (estimated ~{estimated:.0f}s).{note}"
    except Exception as exc:
        debug_trace = traceback.format_exc(limit=8)
        return None, f"Error: {type(exc).__name__}: {exc}\n\n{debug_trace}"
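

# Illustrative programmatic call (hypothetical clip path; requires a GPU session):
#
#   wav_path, status_msg = run_v2a("./clip.mp4", "rain on a tin roof",
#                                  steps=70, cfg_scale=7.0, seconds_total=5, seed=-1)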


with gr.Blocks(title="Audio-Omni V2A (ZeroGPU)") as demo:
    gr.Markdown(
        """
        # Audio-Omni: Video-to-Audio (V2A) on ZeroGPU
        Upload a video, add an optional text prompt, and generate an audio track.
        """
    )

    with gr.Row():
        with gr.Column():
            video = gr.File(
                label="Video (upload)",
                file_types=["video"],
                type="filepath",
            )
            prompt = gr.Textbox(
                label="Prompt (optional)",
                placeholder="For example: realistic city ambience with distant sirens",
                lines=2,
            )
            with gr.Row():
                steps = gr.Slider(20, MAX_STEPS, value=min(70, MAX_STEPS), step=1, label="Diffusion steps")
                cfg_scale = gr.Slider(1.0, 12.0, value=7.0, step=0.1, label="CFG scale")
            with gr.Row():
                seconds_total = gr.Slider(
                    OUTPUT_SECONDS,
                    OUTPUT_SECONDS,
                    value=OUTPUT_SECONDS,
                    step=1,
                    label="Output duration (fixed)",
                    interactive=False,
                )
                seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
            run_btn = gr.Button("Generate V2A", variant="primary")

        with gr.Column():
            audio_out = gr.Audio(label="Generated audio", type="filepath")
            status = gr.Textbox(label="Status", interactive=False)

    run_btn.click(
        fn=run_v2a,
        inputs=[video, prompt, steps, cfg_scale, seconds_total, seed],
        outputs=[audio_out, status],
    )
    # Warm the model on page load so the first generation starts faster.
    demo.load(fn=warmup_model, inputs=None, outputs=[status])


if PREFETCH_ON_BOOT:
    try:
        print("[boot] Prefetching Audio-Omni assets...")
        t0 = time.time()
        _prefetch_runtime_assets()
        print(f"[boot] Prefetch complete in {time.time() - t0:.1f}s.")
    except Exception as exc:
        # Best-effort: a failed prefetch at boot is logged, not fatal.
        print(f"[boot] Prefetch failed: {type(exc).__name__}: {exc}")


if __name__ == "__main__":
    demo.queue(max_size=8, default_concurrency_limit=1).launch()
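
# Local run sketch (assumptions: this file is saved as app.py, and gradio, spaces,
# numpy, soundfile, torch, huggingface_hub, plus the audio_omni package are installed):
#
#   python app.py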