Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,11 +1,17 @@
|
|
| 1 |
-
# ---------- Gradio CDN
|
| 2 |
import os
|
| 3 |
-
import spaces
|
| 4 |
os.environ.setdefault("GRADIO_USE_CDN", "true")
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
def _gpu_probe() -> str:
|
| 7 |
-
#
|
| 8 |
return "ok"
|
|
|
|
|
|
|
| 9 |
import sys
|
| 10 |
import subprocess
|
| 11 |
from pathlib import Path
|
|
@@ -16,14 +22,14 @@ import numpy as np
|
|
| 16 |
import soundfile as sf
|
| 17 |
from huggingface_hub import hf_hub_download
|
| 18 |
|
| 19 |
-
# Detect ZeroGPU
|
| 20 |
USE_ZEROGPU = os.getenv("SPACE_RUNTIME", "").lower() == "zerogpu"
|
| 21 |
|
| 22 |
-
SPACE_ROOT
|
| 23 |
-
REPO_DIR
|
| 24 |
WEIGHTS_REPO = "amaai-lab/SonicMaster"
|
| 25 |
WEIGHTS_FILE = "model.safetensors"
|
| 26 |
-
CACHE_DIR
|
| 27 |
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
| 28 |
|
| 29 |
# ---------- 1) Pull weights from HF Hub ----------
|
|
@@ -82,7 +88,7 @@ def read_audio(path: str) -> Tuple[np.ndarray, int]:
|
|
| 82 |
wav, sr = sf.read(path, always_2d=False)
|
| 83 |
return wav.astype(np.float32) if wav.dtype == np.float64 else wav, sr
|
| 84 |
|
| 85 |
-
# ---------- 5) Core inference ----------
|
| 86 |
def run_sonicmaster_cli(input_wav_path: Path,
|
| 87 |
prompt: str,
|
| 88 |
out_path: Path,
|
|
@@ -116,8 +122,8 @@ def run_sonicmaster_cli(input_wav_path: Path,
|
|
| 116 |
for idx, cmd in enumerate(CANDIDATE_CMDS, start=1):
|
| 117 |
try:
|
| 118 |
if progress: progress(0.35 + 0.05*idx, desc=f"Running inference (try {idx})")
|
| 119 |
-
|
| 120 |
-
|
| 121 |
if out_path.exists() and out_path.stat().st_size > 0:
|
| 122 |
if progress: progress(0.9, desc="Post-processing output")
|
| 123 |
return True
|
|
@@ -125,10 +131,10 @@ def run_sonicmaster_cli(input_wav_path: Path,
|
|
| 125 |
continue
|
| 126 |
return False
|
| 127 |
|
| 128 |
-
# ---------- 6)
|
| 129 |
@spaces.GPU(duration=180)
|
| 130 |
def enhance_on_gpu(input_path: str, prompt: str, output_path: str) -> bool:
|
| 131 |
-
#
|
| 132 |
try:
|
| 133 |
import torch # noqa: F401
|
| 134 |
except Exception:
|
|
@@ -150,7 +156,6 @@ def enhance_audio_ui(audio_path: str,
|
|
| 150 |
except: pass
|
| 151 |
save_temp_wav(wav, sr, tmp_in)
|
| 152 |
|
| 153 |
-
# Use GPU only when ZeroGPU is active; otherwise CPU fallback.
|
| 154 |
if progress: progress(0.3, desc="Starting inference")
|
| 155 |
if USE_ZEROGPU:
|
| 156 |
ok = enhance_on_gpu(tmp_in.as_posix(), prompt, tmp_out.as_posix())
|
|
@@ -179,7 +184,7 @@ with gr.Blocks(title="SonicMaster – Text-Guided Restoration & Mastering", fill
|
|
| 179 |
outputs=[out_audio],
|
| 180 |
concurrency_limit=1)
|
| 181 |
|
| 182 |
-
# ---------- 9) FastAPI mount ----------
|
| 183 |
from fastapi import FastAPI, Request
|
| 184 |
from starlette.responses import PlainTextResponse
|
| 185 |
from starlette.requests import ClientDisconnect
|
|
@@ -196,4 +201,4 @@ app = gr.mount_gradio_app(app, demo.queue(max_size=16), path="/")
|
|
| 196 |
|
| 197 |
if __name__ == "__main__":
|
| 198 |
import uvicorn
|
| 199 |
-
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|
| 1 |
+
# ---------- MUST BE FIRST: Gradio CDN + ZeroGPU probe ----------
|
| 2 |
import os
|
|
|
|
| 3 |
os.environ.setdefault("GRADIO_USE_CDN", "true")
|
| 4 |
+
|
| 5 |
+
# A GPU-decorated function MUST exist at import time for ZeroGPU.
|
| 6 |
+
# Import spaces unconditionally and register a tiny probe.
|
| 7 |
+
import spaces
|
| 8 |
+
|
| 9 |
+
@spaces.GPU(duration=10)
|
| 10 |
def _gpu_probe() -> str:
|
| 11 |
+
# Never called; only here so ZeroGPU startup check passes.
|
| 12 |
return "ok"
|
| 13 |
+
|
| 14 |
+
# ---------- Standard imports ----------
|
| 15 |
import sys
|
| 16 |
import subprocess
|
| 17 |
from pathlib import Path
|
|
|
|
| 22 |
import soundfile as sf
|
| 23 |
from huggingface_hub import hf_hub_download
|
| 24 |
|
| 25 |
+
# Detect ZeroGPU to decide whether to CALL the GPU function.
|
| 26 |
USE_ZEROGPU = os.getenv("SPACE_RUNTIME", "").lower() == "zerogpu"
|
| 27 |
|
| 28 |
+
SPACE_ROOT = Path(__file__).parent.resolve()
|
| 29 |
+
REPO_DIR = SPACE_ROOT / "SonicMasterRepo"
|
| 30 |
WEIGHTS_REPO = "amaai-lab/SonicMaster"
|
| 31 |
WEIGHTS_FILE = "model.safetensors"
|
| 32 |
+
CACHE_DIR = SPACE_ROOT / "weights"
|
| 33 |
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
| 34 |
|
| 35 |
# ---------- 1) Pull weights from HF Hub ----------
|
|
|
|
| 88 |
wav, sr = sf.read(path, always_2d=False)
|
| 89 |
return wav.astype(np.float32) if wav.dtype == np.float64 else wav, sr
|
| 90 |
|
| 91 |
+
# ---------- 5) Core inference (subprocess calling your repo script) ----------
|
| 92 |
def run_sonicmaster_cli(input_wav_path: Path,
|
| 93 |
prompt: str,
|
| 94 |
out_path: Path,
|
|
|
|
| 122 |
for idx, cmd in enumerate(CANDIDATE_CMDS, start=1):
|
| 123 |
try:
|
| 124 |
if progress: progress(0.35 + 0.05*idx, desc=f"Running inference (try {idx})")
|
| 125 |
+
# inherit env so CUDA_VISIBLE_DEVICES from ZeroGPU reaches subprocess
|
| 126 |
+
subprocess.run(cmd, capture_output=True, text=True, check=True, env=os.environ.copy())
|
| 127 |
if out_path.exists() and out_path.stat().st_size > 0:
|
| 128 |
if progress: progress(0.9, desc="Post-processing output")
|
| 129 |
return True
|
|
|
|
| 131 |
continue
|
| 132 |
return False
|
| 133 |
|
| 134 |
+
# ---------- 6) REAL GPU function (always defined; only CALLED on ZeroGPU) ----------
|
| 135 |
@spaces.GPU(duration=180)
|
| 136 |
def enhance_on_gpu(input_path: str, prompt: str, output_path: str) -> bool:
|
| 137 |
+
# Import torch here so CUDA initializes inside GPU context
|
| 138 |
try:
|
| 139 |
import torch # noqa: F401
|
| 140 |
except Exception:
|
|
|
|
| 156 |
except: pass
|
| 157 |
save_temp_wav(wav, sr, tmp_in)
|
| 158 |
|
|
|
|
| 159 |
if progress: progress(0.3, desc="Starting inference")
|
| 160 |
if USE_ZEROGPU:
|
| 161 |
ok = enhance_on_gpu(tmp_in.as_posix(), prompt, tmp_out.as_posix())
|
|
|
|
| 184 |
outputs=[out_audio],
|
| 185 |
concurrency_limit=1)
|
| 186 |
|
| 187 |
+
# ---------- 9) FastAPI mount & disconnect handler ----------
|
| 188 |
from fastapi import FastAPI, Request
|
| 189 |
from starlette.responses import PlainTextResponse
|
| 190 |
from starlette.requests import ClientDisconnect
|
|
|
|
| 201 |
|
| 202 |
if __name__ == "__main__":
|
| 203 |
import uvicorn
|
| 204 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|