ambujm22 committed on
Commit
e46f7d7
·
verified ·
1 Parent(s): f09078a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -24
app.py CHANGED
@@ -1,17 +1,20 @@
1
- # ---------- MUST BE FIRST: Gradio CDN + ZeroGPU probe ----------
2
- import os
3
- os.environ.setdefault("GRADIO_USE_CDN", "true")
4
-
5
- # A GPU-decorated function MUST exist at import time for ZeroGPU.
6
- # Import spaces unconditionally and register a tiny probe.
7
- import spaces
8
 
9
  @spaces.GPU(duration=10)
10
- def _gpu_probe() -> str:
11
- # Never called; only here so ZeroGPU startup check passes.
12
  return "ok"
 
 
 
 
 
13
 
14
- # ---------- Standard imports ----------
15
  import sys
16
  import subprocess
17
  from pathlib import Path
@@ -22,7 +25,6 @@ import numpy as np
22
  import soundfile as sf
23
  from huggingface_hub import hf_hub_download
24
 
25
- # Detect ZeroGPU to decide whether to CALL the GPU function.
26
  USE_ZEROGPU = os.getenv("SPACE_RUNTIME", "").lower() == "zerogpu"
27
 
28
  SPACE_ROOT = Path(__file__).parent.resolve()
@@ -32,7 +34,6 @@ WEIGHTS_FILE = "model.safetensors"
32
  CACHE_DIR = SPACE_ROOT / "weights"
33
  CACHE_DIR.mkdir(parents=True, exist_ok=True)
34
 
35
- # ---------- 1) Pull weights from HF Hub ----------
36
  def get_weights_path() -> Path:
37
  return Path(
38
  hf_hub_download(
@@ -45,7 +46,6 @@ def get_weights_path() -> Path:
45
  )
46
  )
47
 
48
- # ---------- 2) Clone GitHub repo ----------
49
  def ensure_repo() -> Path:
50
  if not REPO_DIR.exists():
51
  subprocess.run(
@@ -58,7 +58,6 @@ def ensure_repo() -> Path:
58
  sys.path.append(REPO_DIR.as_posix())
59
  return REPO_DIR
60
 
61
- # ---------- 3) Examples ----------
62
  def build_examples():
63
  repo = ensure_repo()
64
  wav_dir = repo / "samples" / "inputs"
@@ -78,7 +77,6 @@ def build_examples():
78
  return [[p.as_posix(), prompts[i] if i < len(prompts) else prompts[-1]]
79
  for i, p in enumerate(wav_paths[:10])]
80
 
81
- # ---------- 4) I/O helpers ----------
82
  def save_temp_wav(wav: np.ndarray, sr: int, path: Path):
83
  if wav.ndim == 2 and wav.shape[0] < wav.shape[1]:
84
  wav = wav.T
@@ -88,7 +86,6 @@ def read_audio(path: str) -> Tuple[np.ndarray, int]:
88
  wav, sr = sf.read(path, always_2d=False)
89
  return wav.astype(np.float32) if wav.dtype == np.float64 else wav, sr
90
 
91
- # ---------- 5) Core inference (subprocess calling your repo script) ----------
92
  def run_sonicmaster_cli(input_wav_path: Path,
93
  prompt: str,
94
  out_path: Path,
@@ -122,7 +119,6 @@ def run_sonicmaster_cli(input_wav_path: Path,
122
  for idx, cmd in enumerate(CANDIDATE_CMDS, start=1):
123
  try:
124
  if progress: progress(0.35 + 0.05*idx, desc=f"Running inference (try {idx})")
125
- # inherit env so CUDA_VISIBLE_DEVICES from ZeroGPU reaches subprocess
126
  subprocess.run(cmd, capture_output=True, text=True, check=True, env=os.environ.copy())
127
  if out_path.exists() and out_path.stat().st_size > 0:
128
  if progress: progress(0.9, desc="Post-processing output")
@@ -131,18 +127,15 @@ def run_sonicmaster_cli(input_wav_path: Path,
131
  continue
132
  return False
133
 
134
- # ---------- 6) REAL GPU function (always defined; only CALLED on ZeroGPU) ----------
135
  @spaces.GPU(duration=180)
136
  def enhance_on_gpu(input_path: str, prompt: str, output_path: str) -> bool:
137
- # Import torch here so CUDA initializes inside GPU context
138
  try:
139
- import torch # noqa: F401
140
  except Exception:
141
  pass
142
  from pathlib import Path as _P
143
  return run_sonicmaster_cli(_P(input_path), prompt, _P(output_path), _logs=[], progress=None)
144
 
145
- # ---------- 7) Gradio callback ----------
146
  def enhance_audio_ui(audio_path: str,
147
  prompt: str,
148
  progress=gr.Progress(track_tqdm=True)) -> Tuple[int, np.ndarray]:
@@ -166,9 +159,9 @@ def enhance_audio_ui(audio_path: str,
166
  out_wav, out_sr = read_audio(tmp_out.as_posix())
167
  return (out_sr, out_wav)
168
  else:
 
169
  return (sr, wav)
170
 
171
- # ---------- 8) Gradio UI ----------
172
  with gr.Blocks(title="SonicMaster – Text-Guided Restoration & Mastering", fill_height=True) as demo:
173
  gr.Markdown("## 🎧 SonicMaster\nUpload or choose an example, write a text prompt, then click **Enhance**.")
174
  with gr.Row():
@@ -184,12 +177,13 @@ with gr.Blocks(title="SonicMaster – Text-Guided Restoration & Mastering", fill
184
  outputs=[out_audio],
185
  concurrency_limit=1)
186
 
187
- # ---------- 9) FastAPI mount & disconnect handler ----------
188
  from fastapi import FastAPI, Request
189
  from starlette.responses import PlainTextResponse
190
  from starlette.requests import ClientDisconnect
191
 
192
- _ = get_weights_path(); _ = ensure_repo()
 
 
193
 
194
  app = FastAPI()
195
 
@@ -197,6 +191,7 @@ app = FastAPI()
197
  async def client_disconnect_handler(request: Request, exc: ClientDisconnect):
198
  return PlainTextResponse("Client disconnected", status_code=499)
199
 
 
200
  app = gr.mount_gradio_app(app, demo.queue(max_size=16), path="/")
201
 
202
  if __name__ == "__main__":
 
1
+ # --- ABSOLUTE TOP (see above block) ---
2
+ import importlib
3
+ try:
4
+ import spaces
5
+ except Exception as e:
6
+ raise RuntimeError("Failed to import 'spaces' package. Add `spaces` to requirements.txt.") from e
 
7
 
8
  @spaces.GPU(duration=10)
9
+ def _zerogpu_probe():
 
10
  return "ok"
11
+ # --------------------------------------
12
+
13
+ # You can set env vars after the probe—safe.
14
+ import os
15
+ os.environ.setdefault("GRADIO_USE_CDN", "true")
16
 
17
+ # Standard imports
18
  import sys
19
  import subprocess
20
  from pathlib import Path
 
25
  import soundfile as sf
26
  from huggingface_hub import hf_hub_download
27
 
 
28
  USE_ZEROGPU = os.getenv("SPACE_RUNTIME", "").lower() == "zerogpu"
29
 
30
  SPACE_ROOT = Path(__file__).parent.resolve()
 
34
  CACHE_DIR = SPACE_ROOT / "weights"
35
  CACHE_DIR.mkdir(parents=True, exist_ok=True)
36
 
 
37
  def get_weights_path() -> Path:
38
  return Path(
39
  hf_hub_download(
 
46
  )
47
  )
48
 
 
49
  def ensure_repo() -> Path:
50
  if not REPO_DIR.exists():
51
  subprocess.run(
 
58
  sys.path.append(REPO_DIR.as_posix())
59
  return REPO_DIR
60
 
 
61
  def build_examples():
62
  repo = ensure_repo()
63
  wav_dir = repo / "samples" / "inputs"
 
77
  return [[p.as_posix(), prompts[i] if i < len(prompts) else prompts[-1]]
78
  for i, p in enumerate(wav_paths[:10])]
79
 
 
80
  def save_temp_wav(wav: np.ndarray, sr: int, path: Path):
81
  if wav.ndim == 2 and wav.shape[0] < wav.shape[1]:
82
  wav = wav.T
 
86
  wav, sr = sf.read(path, always_2d=False)
87
  return wav.astype(np.float32) if wav.dtype == np.float64 else wav, sr
88
 
 
89
  def run_sonicmaster_cli(input_wav_path: Path,
90
  prompt: str,
91
  out_path: Path,
 
119
  for idx, cmd in enumerate(CANDIDATE_CMDS, start=1):
120
  try:
121
  if progress: progress(0.35 + 0.05*idx, desc=f"Running inference (try {idx})")
 
122
  subprocess.run(cmd, capture_output=True, text=True, check=True, env=os.environ.copy())
123
  if out_path.exists() and out_path.stat().st_size > 0:
124
  if progress: progress(0.9, desc="Post-processing output")
 
127
  continue
128
  return False
129
 
 
130
  @spaces.GPU(duration=180)
131
  def enhance_on_gpu(input_path: str, prompt: str, output_path: str) -> bool:
 
132
  try:
133
+ import torch # ensure CUDA init happens in GPU context
134
  except Exception:
135
  pass
136
  from pathlib import Path as _P
137
  return run_sonicmaster_cli(_P(input_path), prompt, _P(output_path), _logs=[], progress=None)
138
 
 
139
  def enhance_audio_ui(audio_path: str,
140
  prompt: str,
141
  progress=gr.Progress(track_tqdm=True)) -> Tuple[int, np.ndarray]:
 
159
  out_wav, out_sr = read_audio(tmp_out.as_posix())
160
  return (out_sr, out_wav)
161
  else:
162
+ # If inference fails, return original audio (your chosen fallback).
163
  return (sr, wav)
164
 
 
165
  with gr.Blocks(title="SonicMaster – Text-Guided Restoration & Mastering", fill_height=True) as demo:
166
  gr.Markdown("## 🎧 SonicMaster\nUpload or choose an example, write a text prompt, then click **Enhance**.")
167
  with gr.Row():
 
177
  outputs=[out_audio],
178
  concurrency_limit=1)
179
 
 
180
  from fastapi import FastAPI, Request
181
  from starlette.responses import PlainTextResponse
182
  from starlette.requests import ClientDisconnect
183
 
184
+ # Preload light assets after probe so import won’t fail before detector.
185
+ _ = get_weights_path()
186
+ _ = ensure_repo()
187
 
188
  app = FastAPI()
189
 
 
191
  async def client_disconnect_handler(request: Request, exc: ClientDisconnect):
192
  return PlainTextResponse("Client disconnected", status_code=499)
193
 
194
+ # (Queue() at mount time is fine)
195
  app = gr.mount_gradio_app(app, demo.queue(max_size=16), path="/")
196
 
197
  if __name__ == "__main__":