ambujm22 commited on
Commit
a42f4df
·
verified ·
1 Parent(s): 54d426a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -128
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  os.environ.setdefault("GRADIO_USE_CDN", "true")
 
3
  import sys
4
- import spaces
5
  import subprocess
6
  from pathlib import Path
7
  from typing import Tuple, Optional
@@ -11,57 +11,52 @@ import numpy as np
11
  import soundfile as sf
12
  from huggingface_hub import hf_hub_download
13
 
 
 
 
 
 
14
  SPACE_ROOT = Path(__file__).parent.resolve()
15
  REPO_DIR = SPACE_ROOT / "SonicMasterRepo"
16
  WEIGHTS_REPO = "amaai-lab/SonicMaster"
17
- WEIGHTS_FILE = "model.safetensors" # from the HF model repo
18
  CACHE_DIR = SPACE_ROOT / "weights"
19
  CACHE_DIR.mkdir(parents=True, exist_ok=True)
20
 
21
- # --- ZeroGPU compatibility (no behavior change) ---
22
- USE_ZEROGPU = os.getenv("SPACE_RUNTIME", "").lower() == "zerogpu"
23
- if USE_ZEROGPU:
24
- import spaces
25
 
26
- @spaces.GPU(duration=60)
27
- def _zgpu_dummy():
28
- # Minimal GPU-annotated function so ZeroGPU runtime is satisfied.
29
- # We never call this; actual inference stays CPU unless you choose otherwise.
30
- return "ok"
31
  # ---------- 1) Pull weights from HF Hub ----------
32
  def get_weights_path() -> Path:
33
- weights_path = hf_hub_download(
34
- repo_id=WEIGHTS_REPO,
35
- filename=WEIGHTS_FILE,
36
- local_dir=CACHE_DIR.as_posix(),
37
- local_dir_use_symlinks=False,
38
- force_download=False,
39
- resume_download=True,
 
 
40
  )
41
- return Path(weights_path)
42
 
43
- # ---------- 2) Clone GitHub repo for code (model.py / inference_*.py ) ----------
 
44
  def ensure_repo() -> Path:
45
  if not REPO_DIR.exists():
46
  subprocess.run(
47
- ["git", "clone", "--depth", "1", "https://github.com/AMAAI-Lab/SonicMaster", REPO_DIR.as_posix()],
 
 
48
  check=True,
49
  )
50
  if REPO_DIR.as_posix() not in sys.path:
51
  sys.path.append(REPO_DIR.as_posix())
52
  return REPO_DIR
53
 
54
- # ---------- 3) Examples: use only *.wav from samples/inputs ----------
 
55
  def build_examples():
56
- """
57
- Discover up to 10 .wav files from:
58
- SonicMasterRepo/samples/inputs
59
- and pair them with prompts for gr.Examples.
60
- """
61
  repo = ensure_repo()
62
  wav_dir = repo / "samples" / "inputs"
63
  wav_paths = sorted(p for p in wav_dir.glob("*.wav") if p.is_file())
64
-
65
  prompts = [
66
  "Increase the clarity of this song by emphasizing treble frequencies.",
67
  "Make this song sound more boomy by amplifying the low end bass frequencies.",
@@ -74,54 +69,33 @@ def build_examples():
74
  "Increase the clarity of this song by emphasizing treble frequencies.",
75
  "Please, dereverb this audio.",
76
  ]
 
 
77
 
78
- examples = []
79
- for i, p in enumerate(wav_paths[:10]):
80
- prompt = prompts[i] if i < len(prompts) else prompts[-1]
81
- examples.append([p.as_posix(), prompt])
82
-
83
- # Fallback: if no wavs found, provide an empty list (Gradio handles it)
84
- return examples
85
 
86
  # ---------- 4) I/O helpers ----------
87
  def save_temp_wav(wav: np.ndarray, sr: int, path: Path):
88
- # Ensure (samples, channels) for soundfile
89
  if wav.ndim == 2 and wav.shape[0] < wav.shape[1]:
90
- # (channels, samples) -> (samples, channels)
91
- data = wav.T
92
- else:
93
- data = wav
94
- sf.write(path.as_posix(), data, sr)
95
 
96
  def read_audio(path: str) -> Tuple[np.ndarray, int]:
97
  wav, sr = sf.read(path, always_2d=False)
98
- if wav.dtype == np.float64:
99
- wav = wav.astype(np.float32)
100
- return wav, sr
101
-
102
- def run_sonicmaster_cli(
103
- input_wav_path: Path,
104
- prompt: str,
105
- out_path: Path,
106
- _logs: list, # kept for compatibility, but not shown in UI
107
- progress: Optional[gr.Progress] = None
108
- ) -> bool:
109
- """
110
- Uses the current Python interpreter and tries a few script names/flags.
111
- """
112
- import sys, shutil
113
 
 
 
 
 
 
 
114
  if progress: progress(0.15, desc="Loading weights & repo")
115
  ckpt = get_weights_path()
116
  repo = ensure_repo()
117
 
118
- # Use the exact Python interpreter running this process
119
- py = sys.executable or shutil.which("python3") or shutil.which("python") or "python3"
120
-
121
- # Prefer the scripts we know accept --ckpt/--input/--prompt/--output
122
- script_candidates = [
123
- repo / "infer_single.py", # if you kept your own name
124
- ]
125
 
126
  CANDIDATE_CMDS = []
127
  for script in script_candidates:
@@ -133,10 +107,6 @@ def run_sonicmaster_cli(
133
  "--prompt", prompt,
134
  "--output", out_path.as_posix(),
135
  ])
136
-
137
- # As a last resort, try alternative flag names (if someone changed the CLI)
138
- for script in script_candidates:
139
- if script.exists():
140
  CANDIDATE_CMDS.append([
141
  py, script.as_posix(),
142
  "--weights", ckpt.as_posix(),
@@ -145,110 +115,86 @@ def run_sonicmaster_cli(
145
  "--out", out_path.as_posix(),
146
  ])
147
 
148
- if not CANDIDATE_CMDS:
149
- return False
150
-
151
  for idx, cmd in enumerate(CANDIDATE_CMDS, start=1):
152
  try:
153
  if progress: progress(0.35 + 0.05*idx, desc=f"Running inference (try {idx})")
154
- res = subprocess.run(cmd, capture_output=True, text=True, check=True)
 
155
  if out_path.exists() and out_path.stat().st_size > 0:
156
  if progress: progress(0.9, desc="Post-processing output")
157
  return True
158
- except subprocess.CalledProcessError:
159
- continue
160
  except Exception:
161
  continue
162
  return False
163
 
164
 
165
- def enhance_audio_ui(
166
- audio_path: str,
167
- prompt: str,
168
- progress=gr.Progress(track_tqdm=True)
169
- ) -> Tuple[int, np.ndarray]:
170
- """
171
- Gradio callback: accepts a file path, a prompt, and returns enhanced audio.
172
- """
173
- if progress: progress(0.0, desc="Validating input")
 
 
 
 
174
  if not audio_path or not prompt:
175
  raise gr.Error("Please provide audio and a text prompt.")
176
 
177
- # Standardize input -> temp wav
178
  wav, sr = read_audio(audio_path)
179
- if progress: progress(0.15, desc="Preparing audio")
180
- tmp_in = SPACE_ROOT / "tmp_in.wav"
181
- tmp_out = SPACE_ROOT / "tmp_out.wav"
182
  if tmp_out.exists():
183
- try:
184
- tmp_out.unlink()
185
- except Exception:
186
- pass
187
-
188
  save_temp_wav(wav, sr, tmp_in)
189
 
190
- # Run model
191
- if progress: progress(0.3, desc="Starting inference")
192
- ok = run_sonicmaster_cli(tmp_in, prompt, tmp_out, _logs=[], progress=progress)
 
193
 
194
- # Return output (or echo input)
195
  if ok and tmp_out.exists() and tmp_out.stat().st_size > 0:
196
  out_wav, out_sr = read_audio(tmp_out.as_posix())
197
- if progress: progress(1.0, desc="Done")
198
  return (out_sr, out_wav)
199
  else:
200
- if progress: progress(1.0, desc="No output produced")
201
- # Return original audio if model didn't produce output
202
  return (sr, wav)
203
 
204
- # ---------- 6) Gradio UI ----------
 
205
  with gr.Blocks(title="SonicMaster – Text-Guided Restoration & Mastering", fill_height=True) as demo:
206
- gr.Markdown("## 🎧 SonicMaster\nUpload or choose an example (from repo: `samples/inputs/*.wav`), write a text prompt (e.g., *reduce reverb*, *clean distortion*), then click **Enhance**.")
207
  with gr.Row():
208
- with gr.Column(scale=1):
209
- in_audio = gr.Audio(label="Input Audio (upload or use examples)", type="filepath")
210
- prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., reduce reverb and enhance clarity")
211
  run_btn = gr.Button("🚀 Enhance", variant="primary")
212
-
213
- # Use wavs from SonicMasterRepo/samples/inputs
214
- gr.Examples(
215
- examples=build_examples(),
216
- inputs=[in_audio, prompt],
217
- label="Examples (repo: samples/inputs/*.wav)"
218
- )
219
- with gr.Column(scale=1):
220
  out_audio = gr.Audio(label="Enhanced Audio (output)")
 
 
 
 
221
 
222
- # Per-event concurrency (use 1 unless you know your VRAM/CPU can handle more)
223
- run_btn.click(
224
- fn=enhance_audio_ui,
225
- inputs=[in_audio, prompt],
226
- outputs=[out_audio],
227
- concurrency_limit=1,
228
- )
229
 
230
- # ---------- FastAPI mount with ClientDisconnect handler (Spaces-friendly) ----------
231
  from fastapi import FastAPI, Request
232
  from starlette.responses import PlainTextResponse
233
  from starlette.requests import ClientDisconnect
234
- import gradio as gr
235
 
236
- # Warm up cache & repo (optional but nice)
237
- _ = get_weights_path()
238
- _ = ensure_repo()
239
 
240
  app = FastAPI()
241
 
242
  @app.exception_handler(ClientDisconnect)
243
  async def client_disconnect_handler(request: Request, exc: ClientDisconnect):
244
- # Treat as benign client cancel; avoids noisy tracebacks on Spaces
245
  return PlainTextResponse("Client disconnected", status_code=499)
246
 
247
- # Mount Gradio at root. Use queue if you were using it before.
248
  app = gr.mount_gradio_app(app, demo.queue(max_size=16), path="/")
249
 
250
- # Optional: allow local dev with `python app.py`
251
  if __name__ == "__main__":
252
  import uvicorn
253
- uvicorn.run(app, host="0.0.0.0", port=7860)
254
-
 
1
  import os
2
  os.environ.setdefault("GRADIO_USE_CDN", "true")
3
+
4
  import sys
 
5
  import subprocess
6
  from pathlib import Path
7
  from typing import Tuple, Optional
 
11
  import soundfile as sf
12
  from huggingface_hub import hf_hub_download
13
 
14
+ # Detect ZeroGPU
15
+ USE_ZEROGPU = os.getenv("SPACE_RUNTIME", "").lower() == "zerogpu"
16
+ if USE_ZEROGPU:
17
+ import spaces
18
+
19
  SPACE_ROOT = Path(__file__).parent.resolve()
20
  REPO_DIR = SPACE_ROOT / "SonicMasterRepo"
21
  WEIGHTS_REPO = "amaai-lab/SonicMaster"
22
+ WEIGHTS_FILE = "model.safetensors"
23
  CACHE_DIR = SPACE_ROOT / "weights"
24
  CACHE_DIR.mkdir(parents=True, exist_ok=True)
25
 
 
 
 
 
26
 
 
 
 
 
 
27
# ---------- 1) Pull weights from HF Hub ----------
def get_weights_path() -> Path:
    """Download (or reuse) the SonicMaster weights file and return its local path.

    Returns:
        Path to ``model.safetensors`` inside ``CACHE_DIR``.

    Note:
        ``local_dir_use_symlinks`` and ``resume_download`` were dropped: both
        are deprecated in recent huggingface_hub releases (resuming is now the
        default, and files are materialized into ``local_dir`` regardless).
    """
    return Path(
        hf_hub_download(
            repo_id=WEIGHTS_REPO,
            filename=WEIGHTS_FILE,
            local_dir=CACHE_DIR.as_posix(),
        )
    )
 
39
 
40
+
41
# ---------- 2) Clone GitHub repo ----------
def ensure_repo() -> Path:
    """Clone the SonicMaster sources once and make them importable.

    Returns:
        Path to the local checkout (``REPO_DIR``).
    """
    if not REPO_DIR.exists():
        # Shallow clone is enough: we only need the code, not the history.
        clone_cmd = [
            "git", "clone", "--depth", "1",
            "https://github.com/AMAAI-Lab/SonicMaster",
            REPO_DIR.as_posix(),
        ]
        subprocess.run(clone_cmd, check=True)
    repo_str = REPO_DIR.as_posix()
    if repo_str not in sys.path:
        sys.path.append(repo_str)
    return REPO_DIR
53
 
54
+
55
+ # ---------- 3) Examples ----------
56
  def build_examples():
 
 
 
 
 
57
  repo = ensure_repo()
58
  wav_dir = repo / "samples" / "inputs"
59
  wav_paths = sorted(p for p in wav_dir.glob("*.wav") if p.is_file())
 
60
  prompts = [
61
  "Increase the clarity of this song by emphasizing treble frequencies.",
62
  "Make this song sound more boomy by amplifying the low end bass frequencies.",
 
69
  "Increase the clarity of this song by emphasizing treble frequencies.",
70
  "Please, dereverb this audio.",
71
  ]
72
+ return [[p.as_posix(), prompts[i] if i < len(prompts) else prompts[-1]]
73
+ for i, p in enumerate(wav_paths[:10])]
74
 
 
 
 
 
 
 
 
75
 
76
# ---------- 4) I/O helpers ----------
def save_temp_wav(wav: np.ndarray, sr: int, path: Path):
    """Write *wav* at sample rate *sr* to *path* via soundfile.

    soundfile expects (samples, channels); a 2-D array with fewer rows than
    columns is presumed channel-major and transposed first.
    # NOTE(review): the rows<cols heuristic misclassifies clips shorter than
    # their channel count — acceptable for real audio, but worth knowing.
    """
    data = wav.T if wav.ndim == 2 and wav.shape[0] < wav.shape[1] else wav
    sf.write(path.as_posix(), data, sr)
 
 
 
81
 
82
def read_audio(path: str) -> Tuple[np.ndarray, int]:
    """Read an audio file and return ``(samples, samplerate)``.

    float64 data (soundfile's default for many formats) is narrowed to
    float32 to keep memory and downstream processing cheap.
    """
    wav, sr = sf.read(path, always_2d=False)
    if wav.dtype == np.float64:
        wav = wav.astype(np.float32)
    return wav, sr
85
+
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
+ # ---------- 5) Core inference ----------
88
+ def run_sonicmaster_cli(input_wav_path: Path,
89
+ prompt: str,
90
+ out_path: Path,
91
+ _logs: list,
92
+ progress: Optional[gr.Progress] = None) -> bool:
93
  if progress: progress(0.15, desc="Loading weights & repo")
94
  ckpt = get_weights_path()
95
  repo = ensure_repo()
96
 
97
+ py = sys.executable or "python3"
98
+ script_candidates = [repo / "infer_single.py"]
 
 
 
 
 
99
 
100
  CANDIDATE_CMDS = []
101
  for script in script_candidates:
 
107
  "--prompt", prompt,
108
  "--output", out_path.as_posix(),
109
  ])
 
 
 
 
110
  CANDIDATE_CMDS.append([
111
  py, script.as_posix(),
112
  "--weights", ckpt.as_posix(),
 
115
  "--out", out_path.as_posix(),
116
  ])
117
 
 
 
 
118
  for idx, cmd in enumerate(CANDIDATE_CMDS, start=1):
119
  try:
120
  if progress: progress(0.35 + 0.05*idx, desc=f"Running inference (try {idx})")
121
+ res = subprocess.run(cmd, capture_output=True,
122
+ text=True, check=True, env=os.environ.copy())
123
  if out_path.exists() and out_path.stat().st_size > 0:
124
  if progress: progress(0.9, desc="Post-processing output")
125
  return True
 
 
126
  except Exception:
127
  continue
128
  return False
129
 
130
 
131
# ---------- 6) ZeroGPU wrapper ----------
if USE_ZEROGPU:

    @spaces.GPU(duration=180)
    def enhance_on_gpu(input_path: str, prompt: str, output_path: str) -> bool:
        """GPU-annotated entry point so the ZeroGPU runtime schedules this call.

        Delegates to the CLI runner; returns True when an output file was produced.
        """
        import torch  # noqa: F401  # ensures CUDA initializes inside GPU context
        return run_sonicmaster_cli(
            Path(input_path),
            prompt,
            Path(output_path),
            _logs=[],
            progress=None,
        )
138
+
139
+
140
# ---------- 7) Gradio callback ----------
def enhance_audio_ui(audio_path: str,
                     prompt: str,
                     progress=gr.Progress(track_tqdm=True)) -> Tuple[int, np.ndarray]:
    """Gradio callback: enhance *audio_path* according to *prompt*.

    Returns:
        ``(sample_rate, samples)`` as expected by ``gr.Audio`` — the enhanced
        audio on success, or the original audio unchanged if inference failed.

    Raises:
        gr.Error: if either the audio or the prompt is missing.
    """
    if not audio_path or not prompt:
        raise gr.Error("Please provide audio and a text prompt.")

    wav, sr = read_audio(audio_path)
    tmp_in, tmp_out = SPACE_ROOT / "tmp_in.wav", SPACE_ROOT / "tmp_out.wav"
    # Remove any stale output. missing_ok avoids a racy exists()+unlink()
    # pair, and catching only OSError (instead of the original bare `except:`)
    # no longer swallows KeyboardInterrupt/SystemExit.
    try:
        tmp_out.unlink(missing_ok=True)
    except OSError:
        pass
    save_temp_wav(wav, sr, tmp_in)

    if USE_ZEROGPU:
        ok = enhance_on_gpu(tmp_in.as_posix(), prompt, tmp_out.as_posix())
    else:
        ok = run_sonicmaster_cli(tmp_in, prompt, tmp_out, _logs=[], progress=progress)

    if ok and tmp_out.exists() and tmp_out.stat().st_size > 0:
        out_wav, out_sr = read_audio(tmp_out.as_posix())
        return (out_sr, out_wav)
    # Inference produced nothing usable: echo the input back to the UI.
    return (sr, wav)
164
 
165
+
166
# ---------- 8) Gradio UI ----------
with gr.Blocks(title="SonicMaster – Text-Guided Restoration & Mastering", fill_height=True) as demo:
    gr.Markdown("## 🎧 SonicMaster\nUpload or choose an example, write a text prompt, then click **Enhance**.")
    with gr.Row():
        # Left column: inputs and example picker.
        with gr.Column():
            in_audio = gr.Audio(label="Input Audio", type="filepath")
            prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., reduce reverb")
            run_btn = gr.Button("🚀 Enhance", variant="primary")
            gr.Examples(examples=build_examples(), inputs=[in_audio, prompt])
        # Right column: result playback.
        with gr.Column():
            out_audio = gr.Audio(label="Enhanced Audio (output)")
    # Serialize inference: one job at a time keeps CPU/VRAM usage bounded.
    run_btn.click(
        fn=enhance_audio_ui,
        inputs=[in_audio, prompt],
        outputs=[out_audio],
        concurrency_limit=1,
    )
181
 
 
 
 
 
 
 
 
182
 
183
# ---------- 9) FastAPI mount ----------
from fastapi import FastAPI, Request
from starlette.responses import PlainTextResponse
from starlette.requests import ClientDisconnect

# Warm caches at startup so the first request does not pay for the
# weights download and the repo clone.
_ = get_weights_path()
_ = ensure_repo()

app = FastAPI()


@app.exception_handler(ClientDisconnect)
async def client_disconnect_handler(request: Request, exc: ClientDisconnect):
    """Treat client-side cancellation as benign (499, nginx convention)."""
    return PlainTextResponse("Client disconnected", status_code=499)


# Mount Gradio at the root path, with a bounded request queue.
app = gr.mount_gradio_app(app, demo.queue(max_size=16), path="/")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)