ambujm22 commited on
Commit
95c1898
·
verified ·
1 Parent(s): ef16744

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -201
app.py CHANGED
@@ -1,204 +1,11 @@
1
- # ---------- MUST BE FIRST: Gradio CDN + ZeroGPU probe ----------
2
- import os
3
- os.environ.setdefault("GRADIO_USE_CDN", "true")
4
-
5
- # A GPU-decorated function MUST exist at import time for ZeroGPU.
6
- # Import spaces unconditionally and register a tiny probe.
7
  import spaces
8
-
9
- @spaces.GPU(duration=10)
10
- def _gpu_probe() -> str:
11
- # Never called; only here so ZeroGPU startup check passes.
12
- return "ok"
13
-
14
- # ---------- Standard imports ----------
15
- import sys
16
- import subprocess
17
- from pathlib import Path
18
- from typing import Tuple, Optional
19
-
20
  import gradio as gr
21
- import numpy as np
22
- import soundfile as sf
23
- from huggingface_hub import hf_hub_download
24
-
25
- # Detect ZeroGPU to decide whether to CALL the GPU function.
26
- USE_ZEROGPU = os.getenv("SPACE_RUNTIME", "").lower() == "zerogpu"
27
-
28
- SPACE_ROOT = Path(__file__).parent.resolve()
29
- REPO_DIR = SPACE_ROOT / "SonicMasterRepo"
30
- WEIGHTS_REPO = "amaai-lab/SonicMaster"
31
- WEIGHTS_FILE = "model.safetensors"
32
- CACHE_DIR = SPACE_ROOT / "weights"
33
- CACHE_DIR.mkdir(parents=True, exist_ok=True)
34
-
35
- # ---------- 1) Pull weights from HF Hub ----------
36
- def get_weights_path() -> Path:
37
- return Path(
38
- hf_hub_download(
39
- repo_id=WEIGHTS_REPO,
40
- filename=WEIGHTS_FILE,
41
- local_dir=CACHE_DIR.as_posix(),
42
- local_dir_use_symlinks=False,
43
- force_download=False,
44
- resume_download=True,
45
- )
46
- )
47
-
48
- # ---------- 2) Clone GitHub repo ----------
49
- def ensure_repo() -> Path:
50
- if not REPO_DIR.exists():
51
- subprocess.run(
52
- ["git", "clone", "--depth", "1",
53
- "https://github.com/AMAAI-Lab/SonicMaster",
54
- REPO_DIR.as_posix()],
55
- check=True,
56
- )
57
- if REPO_DIR.as_posix() not in sys.path:
58
- sys.path.append(REPO_DIR.as_posix())
59
- return REPO_DIR
60
-
61
- # ---------- 3) Examples ----------
62
- def build_examples():
63
- repo = ensure_repo()
64
- wav_dir = repo / "samples" / "inputs"
65
- wav_paths = sorted(p for p in wav_dir.glob("*.wav") if p.is_file())
66
- prompts = [
67
- "Increase the clarity of this song by emphasizing treble frequencies.",
68
- "Make this song sound more boomy by amplifying the low end bass frequencies.",
69
- "Can you make this sound louder, please?",
70
- "Make the audio smoother and less distorted.",
71
- "Improve the balance in this song.",
72
- "Disentangle the left and right channels to give this song a stereo feeling.",
73
- "Correct the unnatural frequency emphasis. Reduce the roominess or echo.",
74
- "Raise the level of the vocals, please.",
75
- "Increase the clarity of this song by emphasizing treble frequencies.",
76
- "Please, dereverb this audio.",
77
- ]
78
- return [[p.as_posix(), prompts[i] if i < len(prompts) else prompts[-1]]
79
- for i, p in enumerate(wav_paths[:10])]
80
-
81
- # ---------- 4) I/O helpers ----------
82
- def save_temp_wav(wav: np.ndarray, sr: int, path: Path):
83
- if wav.ndim == 2 and wav.shape[0] < wav.shape[1]:
84
- wav = wav.T
85
- sf.write(path.as_posix(), wav, sr)
86
-
87
- def read_audio(path: str) -> Tuple[np.ndarray, int]:
88
- wav, sr = sf.read(path, always_2d=False)
89
- return wav.astype(np.float32) if wav.dtype == np.float64 else wav, sr
90
-
91
- # ---------- 5) Core inference (subprocess calling your repo script) ----------
92
- def run_sonicmaster_cli(input_wav_path: Path,
93
- prompt: str,
94
- out_path: Path,
95
- _logs: list,
96
- progress: Optional[gr.Progress] = None) -> bool:
97
- if progress: progress(0.15, desc="Loading weights & repo")
98
- ckpt = get_weights_path()
99
- repo = ensure_repo()
100
-
101
- py = sys.executable or "python3"
102
- script_candidates = [repo / "infer_single.py"]
103
-
104
- CANDIDATE_CMDS = []
105
- for script in script_candidates:
106
- if script.exists():
107
- CANDIDATE_CMDS.append([
108
- py, script.as_posix(),
109
- "--ckpt", ckpt.as_posix(),
110
- "--input", input_wav_path.as_posix(),
111
- "--prompt", prompt,
112
- "--output", out_path.as_posix(),
113
- ])
114
- CANDIDATE_CMDS.append([
115
- py, script.as_posix(),
116
- "--weights", ckpt.as_posix(),
117
- "--input", input_wav_path.as_posix(),
118
- "--text", prompt,
119
- "--out", out_path.as_posix(),
120
- ])
121
-
122
- for idx, cmd in enumerate(CANDIDATE_CMDS, start=1):
123
- try:
124
- if progress: progress(0.35 + 0.05*idx, desc=f"Running inference (try {idx})")
125
- # inherit env so CUDA_VISIBLE_DEVICES from ZeroGPU reaches subprocess
126
- subprocess.run(cmd, capture_output=True, text=True, check=True, env=os.environ.copy())
127
- if out_path.exists() and out_path.stat().st_size > 0:
128
- if progress: progress(0.9, desc="Post-processing output")
129
- return True
130
- except Exception:
131
- continue
132
- return False
133
-
134
- # ---------- 6) REAL GPU function (always defined; only CALLED on ZeroGPU) ----------
135
- @spaces.GPU(duration=180)
136
- def enhance_on_gpu(input_path: str, prompt: str, output_path: str) -> bool:
137
- # Import torch here so CUDA initializes inside GPU context
138
- try:
139
- import torch # noqa: F401
140
- except Exception:
141
- pass
142
- from pathlib import Path as _P
143
- return run_sonicmaster_cli(_P(input_path), prompt, _P(output_path), _logs=[], progress=None)
144
-
145
- # ---------- 7) Gradio callback ----------
146
- def enhance_audio_ui(audio_path: str,
147
- prompt: str,
148
- progress=gr.Progress(track_tqdm=True)) -> Tuple[int, np.ndarray]:
149
- if not audio_path or not prompt:
150
- raise gr.Error("Please provide audio and a text prompt.")
151
-
152
- wav, sr = read_audio(audio_path)
153
- tmp_in, tmp_out = SPACE_ROOT / "tmp_in.wav", SPACE_ROOT / "tmp_out.wav"
154
- if tmp_out.exists():
155
- try: tmp_out.unlink()
156
- except: pass
157
- save_temp_wav(wav, sr, tmp_in)
158
-
159
- if progress: progress(0.3, desc="Starting inference")
160
- if USE_ZEROGPU:
161
- ok = enhance_on_gpu(tmp_in.as_posix(), prompt, tmp_out.as_posix())
162
- else:
163
- ok = run_sonicmaster_cli(tmp_in, prompt, tmp_out, _logs=[], progress=progress)
164
-
165
- if ok and tmp_out.exists() and tmp_out.stat().st_size > 0:
166
- out_wav, out_sr = read_audio(tmp_out.as_posix())
167
- return (out_sr, out_wav)
168
- else:
169
- return (sr, wav)
170
-
171
- # ---------- 8) Gradio UI ----------
172
- with gr.Blocks(title="SonicMaster – Text-Guided Restoration & Mastering", fill_height=True) as demo:
173
- gr.Markdown("## 🎧 SonicMaster\nUpload or choose an example, write a text prompt, then click **Enhance**.")
174
- with gr.Row():
175
- with gr.Column():
176
- in_audio = gr.Audio(label="Input Audio", type="filepath")
177
- prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., reduce reverb")
178
- run_btn = gr.Button("🚀 Enhance", variant="primary")
179
- gr.Examples(examples=build_examples(), inputs=[in_audio, prompt])
180
- with gr.Column():
181
- out_audio = gr.Audio(label="Enhanced Audio (output)")
182
- run_btn.click(fn=enhance_audio_ui,
183
- inputs=[in_audio, prompt],
184
- outputs=[out_audio],
185
- concurrency_limit=1)
186
-
187
- # ---------- 9) FastAPI mount & disconnect handler ----------
188
- from fastapi import FastAPI, Request
189
- from starlette.responses import PlainTextResponse
190
- from starlette.requests import ClientDisconnect
191
-
192
- _ = get_weights_path(); _ = ensure_repo()
193
-
194
  app = FastAPI()
195
-
196
- @app.exception_handler(ClientDisconnect)
197
- async def client_disconnect_handler(request: Request, exc: ClientDisconnect):
198
- return PlainTextResponse("Client disconnected", status_code=499)
199
-
200
- app = gr.mount_gradio_app(app, demo.queue(max_size=16), path="/")
201
-
202
- if __name__ == "__main__":
203
- import uvicorn
204
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import os; os.environ.setdefault("GRADIO_USE_CDN","true")
 
 
 
 
 
2
  import spaces
3
+ @spaces.GPU(duration=10)
4
+ def ping(): return "pong"
 
 
 
 
 
 
 
 
 
 
5
  import gradio as gr
6
+ def fn(x): return f"ok:{x}"
7
+ with gr.Blocks() as demo:
8
+ t=gr.Textbox(); o=gr.Textbox(); t.submit(fn, t, o)
9
+ from fastapi import FastAPI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  app = FastAPI()
11
+ app = gr.mount_gradio_app(app, demo, path="/")