# app.py — Audio-Omni V2A Gradio Space ("testovich", by GiorgioV; revision 3892913)
import json
import os
import tempfile
import time
import traceback
from importlib.util import find_spec
from pathlib import Path
from urllib.request import urlretrieve
import gradio as gr
import numpy as np
import spaces
import soundfile as sf
import torch
from huggingface_hub import snapshot_download
# Hugging Face repo and local cache location for the Audio-Omni checkpoint.
MODEL_REPO_ID = os.environ.get("AUDIO_OMNI_REPO_ID", "HKUSTAudio/Audio-Omni")
MODEL_CACHE_DIR = Path(os.environ.get("AUDIO_OMNI_CACHE_DIR", "./model"))
# Companion Qwen2.5-Omni model; by default cached under the main model dir.
QWEN_REPO_ID = os.environ.get("QWEN_OMNI_REPO_ID", "Qwen/Qwen2.5-Omni-3B")
QWEN_CACHE_DIR = Path(os.environ.get("QWEN_OMNI_CACHE_DIR", str(MODEL_CACHE_DIR / "qwen_omni_3b")))
# ZeroGPU session length, and the soft runtime budget kept just below it.
GPU_DURATION_SECONDS = int(os.environ.get("GPU_DURATION_SECONDS", "120"))
MAX_RUNTIME_SECONDS = int(os.environ.get("MAX_RUNTIME_SECONDS", "115"))
# Per-request limits enforced in run_v2a before generation starts.
MAX_STEPS = int(os.environ.get("MAX_STEPS", "85"))
MAX_AUDIO_SECONDS = int(os.environ.get("MAX_AUDIO_SECONDS", "10"))
MAX_COMPLEXITY = int(os.environ.get("MAX_COMPLEXITY", "800"))
# The model is always driven with FIXED_SECONDS_TOTAL seconds internally;
# the clip returned to the user is trimmed to OUTPUT_SECONDS.
FIXED_SECONDS_TOTAL = int(os.environ.get("FIXED_SECONDS_TOTAL", "10"))
OUTPUT_SECONDS = int(os.environ.get("OUTPUT_SECONDS", "5"))
# Boolean feature flags parsed from env ("1"/"true"/"yes"/"on" enable).
PREFETCH_ON_START = os.environ.get("PREFETCH_ON_START", "1").strip().lower() in {"1", "true", "yes", "on"}
PREFETCH_ON_BOOT = os.environ.get("PREFETCH_ON_BOOT", "1").strip().lower() in {"1", "true", "yes", "on"}
FORCE_ZERO_SYNC_FEATURES = os.environ.get("FORCE_ZERO_SYNC_FEATURES", "1").strip().lower() in {
    "1",
    "true",
    "yes",
    "on",
}
# Lazily-created model singleton and a sticky fatal-load error message
# (set only for unrecoverable failures; see _load_model).
_MODEL = None
_MODEL_LOAD_ERROR = None
# Upstream copy of vocab.txt, fetched when no local copy is found.
VOCAB_FALLBACK_URL = "https://raw.githubusercontent.com/ZeyueT/Audio-Omni/main/audio_omni/data/vocab.txt"
HF_TOKEN = os.environ.get("HF_TOKEN")
def _estimate_runtime_seconds(steps: int, seconds_total: int) -> float:
# Practical heuristic for ZeroGPU sessions (warm model).
return 18.0 + 0.11 * float(steps) * float(seconds_total)
def _download_model_files() -> Path:
    """Fetch the Audio-Omni config and checkpoints into the local cache.

    Only the three files needed for inference are downloaded. Returns the
    directory that snapshot_download materialized the files into.
    """
    MODEL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    wanted_files = [
        "Audio-Omni.json",
        "model.ckpt",
        "synchformer_state_dict.pth",
    ]
    snapshot_dir = snapshot_download(
        repo_id=MODEL_REPO_ID,
        local_dir=str(MODEL_CACHE_DIR),
        token=HF_TOKEN,
        allow_patterns=wanted_files,
    )
    return Path(snapshot_dir)
def _download_qwen_files() -> Path:
    """Fetch the full Qwen2.5-Omni snapshot into its local cache directory."""
    QWEN_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    snapshot_dir = snapshot_download(
        repo_id=QWEN_REPO_ID,
        local_dir=str(QWEN_CACHE_DIR),
        token=HF_TOKEN,
    )
    return Path(snapshot_dir)
def _prefetch_runtime_assets() -> None:
    """Download every artifact generation needs: model files, vocab, Qwen."""
    _download_model_files()
    _resolve_vocab_file()
    # Point downstream code at the local Qwen snapshot via the env var
    # (presumably read by the audio_omni package — confirm upstream).
    os.environ["QWEN_OMNI_MODEL_PATH"] = str(_download_qwen_files())
def _resolve_vocab_file() -> Path:
    """Locate vocab.txt, downloading the upstream fallback copy if absent.

    Search order: the installed audio_omni package's data dir, a checkout
    under the current working directory, then the model cache. If none
    exists, the GitHub fallback is downloaded into the model cache.
    """
    search_paths = []
    spec = find_spec("audio_omni")
    if spec and spec.origin:
        package_root = Path(spec.origin).resolve().parent
        search_paths.append(package_root / "data" / "vocab.txt")
    search_paths.append(Path.cwd() / "audio_omni" / "data" / "vocab.txt")
    search_paths.append(MODEL_CACHE_DIR / "vocab.txt")

    for path in search_paths:
        if path.exists():
            return path

    # Nothing on disk — pull the reference copy from the upstream repo.
    target = MODEL_CACHE_DIR / "vocab.txt"
    target.parent.mkdir(parents=True, exist_ok=True)
    urlretrieve(VOCAB_FALLBACK_URL, str(target))
    return target
def _patch_config_vocab(config_path: Path) -> Path:
    """Rewrite the model config so speech_prompt points at a real vocab file.

    Loads the JSON config, injects the resolved vocab path into every
    conditioning entry with id "speech_prompt", and writes the result next
    to the original as "Audio-Omni.patched.json". Returns the patched path.
    """
    vocab_file = _resolve_vocab_file()
    config = json.loads(config_path.read_text(encoding="utf-8"))
    conditioning = config.get("model", {}).get("conditioning", {}).get("configs", [])
    for entry in conditioning:
        if entry.get("id") != "speech_prompt":
            continue
        entry.setdefault("config", {})
        entry["config"]["vocab_file"] = str(vocab_file)
    patched_path = config_path.parent / "Audio-Omni.patched.json"
    with open(patched_path, "w", encoding="utf-8") as fh:
        json.dump(config, fh)
    return patched_path
def _load_model():
    """Lazily construct the global AudioOmni model, caching it in _MODEL.

    Only the unrecoverable no-CUDA case is cached in _MODEL_LOAD_ERROR;
    other failures propagate uncached so a later call can retry the load.

    Raises:
        RuntimeError: if CUDA is unavailable, or a previous call already
            recorded a fatal load error.
    """
    global _MODEL, _MODEL_LOAD_ERROR
    if _MODEL is not None:
        return _MODEL
    if _MODEL_LOAD_ERROR is not None:
        raise RuntimeError(_MODEL_LOAD_ERROR)
    if not torch.cuda.is_available():
        _MODEL_LOAD_ERROR = "CUDA is unavailable. ZeroGPU did not provide a GPU session."
        raise RuntimeError(_MODEL_LOAD_ERROR)
    # Make sure checkpoints, config, vocab, and Qwen are on disk first.
    _prefetch_runtime_assets()
    model_dir = MODEL_CACHE_DIR
    config_path = model_dir / "Audio-Omni.json"
    patched_config_path = _patch_config_vocab(config_path)
    ckpt_path = model_dir / "model.ckpt"
    sync_path = model_dir / "synchformer_state_dict.pth"
    # Synchformer checkpoint path is passed via env — presumably read by
    # audio_omni during model construction; confirm upstream.
    os.environ["SYNCHFORMER_CKPT"] = str(sync_path)
    # Deferred import: audio_omni is only needed once a GPU session exists.
    from audio_omni import AudioOmni
    _MODEL = AudioOmni(
        config_path=str(patched_config_path),
        ckpt_path=str(ckpt_path),
        device="cuda",
    )
    if FORCE_ZERO_SYNC_FEATURES:
        # Workaround for an upstream SynchformerConditioner bug in Audio-Omni:
        # non-zero sync features can trigger a shape-mismatch assignment path.
        # For V2A in Space, a zero sync tensor keeps inference stable.
        def _zero_sync_features(_video_path: str, _duration: int = 10):
            # (1, 240, 768) — presumably the conditioner's expected feature
            # shape for a 10s clip; verify against upstream if changed.
            return torch.zeros(1, 240, 768, device=_MODEL.device)
        _MODEL._extract_sync_features = _zero_sync_features
    # Inference-only process: disable autograd globally.
    torch.set_grad_enabled(False)
    return _MODEL
@spaces.GPU(duration=GPU_DURATION_SECONDS)
def warmup_model():
    """Eagerly load the model inside a GPU session; return a status string.

    Warmup is best-effort: failures are reported as text, never raised.
    """
    if not PREFETCH_ON_START:
        return "Автопрогрев отключен."
    try:
        t0 = time.time()
        _load_model()
        warm_seconds = time.time() - t0
        return f"Модель прогрета и готова к генерации. Время прогрева: {warm_seconds:.1f}с."
    except Exception as exc:
        return f"Автопрогрев не удался: {type(exc).__name__}: {exc}"
@spaces.GPU(duration=GPU_DURATION_SECONDS)
def run_v2a(
    video_path: str,
    prompt: str,
    steps: int,
    cfg_scale: float,
    seconds_total: int,
    seed: int,
):
    """Validate the request, run Audio-Omni V2A, and write the result to WAV.

    Returns:
        (wav_path, status): wav_path is a temp-file path, or None on any
        validation failure or exception. status is a human-readable
        (Russian) message; errors are reported there, never raised.
    """
    if not video_path:
        return None, "Загрузите видео-файл."
    try:
        steps = int(steps)
        # Remember the user's requested duration for the status note; the
        # model itself always runs with the fixed internal duration.
        requested_seconds = int(seconds_total)
        seconds_total = FIXED_SECONDS_TOTAL
        cfg_scale = float(cfg_scale)
        seed = int(seed)
        # Guardrails: reject combinations that would blow the ZeroGPU budget.
        if steps > MAX_STEPS:
            return None, f"Слишком много steps: {steps}. Максимум: {MAX_STEPS}."
        if seconds_total > MAX_AUDIO_SECONDS:
            return None, f"Слишком большая длительность: {seconds_total}s. Максимум: {MAX_AUDIO_SECONDS}s."
        if steps * seconds_total > MAX_COMPLEXITY:
            return None, (
                f"Слишком тяжелая комбинация (steps*seconds={steps * seconds_total}). "
                f"Лимит: {MAX_COMPLEXITY}."
            )
        estimated = _estimate_runtime_seconds(steps, seconds_total)
        if estimated > MAX_RUNTIME_SECONDS:
            return None, (
                f"Ожидаемое время ~{estimated:.0f}с, это больше лимита {MAX_RUNTIME_SECONDS}с. "
                "Уменьшите steps или длительность."
            )
        started_at = time.time()
        model = _load_model()
        audio_i16 = model.generate(
            task="V2A",
            prompt=(prompt or "").strip(),
            video_path=video_path,
            steps=steps,
            cfg_scale=cfg_scale,
            seconds_total=seconds_total,
            seed=seed,
        )
        # audio_i16 is presumably int16 PCM (hence the 32767 scale below);
        # normalize to float32 in [-1, 1].
        audio_f32 = audio_i16.to(torch.float32) / 32767.0
        # Trim the fixed-length internal output down to OUTPUT_SECONDS.
        target_samples = int(model.sample_rate * OUTPUT_SECONDS)
        if audio_f32.shape[-1] > target_samples:
            audio_f32 = audio_f32[:, :target_samples]
        # delete=False: gradio serves the file by path after we return.
        tmp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp_wav.close()
        # Transpose presumably (channels, samples) -> (frames, channels),
        # which is the layout soundfile.write expects.
        audio_np = np.ascontiguousarray(audio_f32.transpose(0, 1).cpu().numpy())
        sf.write(tmp_wav.name, audio_np, model.sample_rate, subtype="PCM_16")
        elapsed = time.time() - started_at
        note = (
            f" Внутри модели: {FIXED_SECONDS_TOTAL}s (ограничение Audio-Omni), "
            f"на выходе: {OUTPUT_SECONDS}s."
        )
        if requested_seconds != OUTPUT_SECONDS:
            note += f" Запрошено {requested_seconds}s."
        return tmp_wav.name, f"Готово за {elapsed:.1f}с (оценка ~{estimated:.0f}с).{note}"
    except Exception as exc:
        # Broad catch by design: any failure becomes a UI status message.
        debug_trace = traceback.format_exc(limit=8)
        return None, f"Ошибка: {type(exc).__name__}: {exc}\n\n{debug_trace}"
# Gradio UI: video upload + controls on the left, audio + status on the right.
with gr.Blocks(title="Audio-Omni V2A (ZeroGPU)") as demo:
    gr.Markdown(
        """
        # Audio-Omni: Video-to-Audio (V2A) on ZeroGPU
        Загрузите видео, добавьте опциональный текстовый промпт и сгенерируйте аудио-дорожку.
        """
    )
    with gr.Row():
        with gr.Column():
            video = gr.File(
                label="Видео (upload)",
                file_types=["video"],
                type="filepath",
            )
            prompt = gr.Textbox(
                label="Промпт (опционально)",
                placeholder="Например: realistic city ambience with distant sirens",
                lines=2,
            )
            with gr.Row():
                steps = gr.Slider(20, MAX_STEPS, value=min(70, MAX_STEPS), step=1, label="Diffusion steps")
                cfg_scale = gr.Slider(1.0, 12.0, value=7.0, step=0.1, label="CFG scale")
            with gr.Row():
                # Duration slider is fixed: min == max == OUTPUT_SECONDS and
                # non-interactive; it only displays the output length.
                seconds_total = gr.Slider(
                    OUTPUT_SECONDS,
                    OUTPUT_SECONDS,
                    value=OUTPUT_SECONDS,
                    step=1,
                    label="Длительность результата (фиксировано)",
                    interactive=False,
                )
                seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
            run_btn = gr.Button("Generate V2A", variant="primary")
        with gr.Column():
            audio_out = gr.Audio(label="Сгенерированное аудио", type="filepath")
            status = gr.Textbox(label="Статус", interactive=False)
    run_btn.click(
        fn=run_v2a,
        inputs=[video, prompt, steps, cfg_scale, seconds_total, seed],
        outputs=[audio_out, status],
    )
    # Kick off model warmup (inside a GPU session) as soon as the page loads.
    demo.load(fn=warmup_model, inputs=None, outputs=[status])
# Run download prefetch during container startup (before first request).
# This runs at import time, on CPU — no GPU session is needed for downloads.
if PREFETCH_ON_BOOT:
    try:
        print("[boot] Prefetching Audio-Omni assets...")
        t0 = time.time()
        _prefetch_runtime_assets()
        print(f"[boot] Prefetch complete in {time.time() - t0:.1f}s.")
    except Exception as exc:
        # Keep app alive even if prefetch fails; generation path can retry.
        print(f"[boot] Prefetch failed: {type(exc).__name__}: {exc}")
if __name__ == "__main__":
    # Single concurrent worker (default_concurrency_limit=1) so generation
    # requests are serialized; up to 8 requests may wait in the queue.
    demo.queue(max_size=8, default_concurrency_limit=1).launch()