Vansh Chugh commited on
Commit ·
49e9b94
1
Parent(s): 2fb0615
lazy loading so server startup is not delayed
Browse files- .gitignore +2 -0
- app.py +45 -30
.gitignore
CHANGED
|
@@ -2,5 +2,7 @@ __pycache__/
|
|
| 2 |
*.pyc
|
| 3 |
pretrained/
|
| 4 |
AnyAccomp-repo/
|
|
|
|
|
|
|
| 5 |
DEPLOY.md
|
| 6 |
instructions.md
|
|
|
|
| 2 |
*.pyc
|
| 3 |
pretrained/
|
| 4 |
AnyAccomp-repo/
|
| 5 |
+
.venv-test/
|
| 6 |
+
test_gradio_health.py
|
| 7 |
DEPLOY.md
|
| 8 |
instructions.md
|
app.py
CHANGED
|
@@ -3,6 +3,7 @@ sys.stdout.reconfigure(line_buffering=True)
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import os
|
|
|
|
| 6 |
# import spaces # uncomment when running on ZeroGPU
|
| 7 |
import librosa
|
| 8 |
import soundfile as sf
|
|
@@ -14,20 +15,7 @@ from anyaccomp.inference_utils import Sing2SongInferencePipeline
|
|
| 14 |
|
| 15 |
repo_id = "amphion/anyaccomp"
|
| 16 |
base_dir = os.path.dirname(os.path.abspath(__file__))
|
| 17 |
-
|
| 18 |
-
checkpoint_marker = os.path.join(base_dir, "pretrained", "flow_matching")
|
| 19 |
-
if not os.path.exists(checkpoint_marker):
|
| 20 |
-
print(f"Downloading model files from {repo_id}...", flush=True)
|
| 21 |
-
model_dir = snapshot_download(repo_id=repo_id, local_dir=base_dir)
|
| 22 |
-
print(f"Model files downloaded to: {model_dir}", flush=True)
|
| 23 |
-
else:
|
| 24 |
-
model_dir = base_dir
|
| 25 |
-
print(f"Model files already present, skipping download.", flush=True)
|
| 26 |
-
|
| 27 |
-
CFG_PATH = os.path.join(base_dir, "config/flow_matching.json")
|
| 28 |
-
VOCODER_CFG_PATH = os.path.join(base_dir, "config/vocoder.json")
|
| 29 |
-
CHECKPOINT_PATH = os.path.join(model_dir, "pretrained/flow_matching")
|
| 30 |
-
VOCODER_CHECKPOINT_PATH = os.path.join(model_dir, "pretrained/vocoder")
|
| 31 |
|
| 32 |
INFER_DST = os.path.join(base_dir, "output_gradio")
|
| 33 |
os.makedirs(INFER_DST, exist_ok=True)
|
|
@@ -36,28 +24,55 @@ mixture_dst = os.path.join(INFER_DST, "mixture")
|
|
| 36 |
os.makedirs(acc_dst, exist_ok=True)
|
| 37 |
os.makedirs(mixture_dst, exist_ok=True)
|
| 38 |
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
inference_pipeline
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
|
| 57 |
# @spaces.GPU # uncomment when running on ZeroGPU
|
| 58 |
def sing2song_inference(vocal_filepath, n_timesteps, cfg_scale):
|
|
|
|
|
|
|
| 59 |
if inference_pipeline is None:
|
| 60 |
-
raise gr.Error("Model
|
| 61 |
|
| 62 |
if vocal_filepath is None:
|
| 63 |
raise gr.Error("Please upload a vocal audio file.")
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import os
|
| 6 |
+
import threading
|
| 7 |
# import spaces # uncomment when running on ZeroGPU
|
| 8 |
import librosa
|
| 9 |
import soundfile as sf
|
|
|
|
| 15 |
|
| 16 |
repo_id = "amphion/anyaccomp"
|
| 17 |
base_dir = os.path.dirname(os.path.abspath(__file__))
|
| 18 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
INFER_DST = os.path.join(base_dir, "output_gradio")
|
| 21 |
os.makedirs(INFER_DST, exist_ok=True)
|
|
|
|
| 24 |
os.makedirs(acc_dst, exist_ok=True)
|
| 25 |
os.makedirs(mixture_dst, exist_ok=True)
|
| 26 |
|
| 27 |
+
# Model loading state — populated by the background thread.
|
| 28 |
+
inference_pipeline = None
|
| 29 |
+
model_loading = True
|
| 30 |
+
model_error = None
|
| 31 |
|
| 32 |
+
|
| 33 |
+
def load_model():
|
| 34 |
+
global inference_pipeline, model_loading, model_error
|
| 35 |
+
try:
|
| 36 |
+
checkpoint_marker = os.path.join(base_dir, "pretrained", "flow_matching")
|
| 37 |
+
if not os.path.exists(checkpoint_marker):
|
| 38 |
+
print(f"Downloading model files from {repo_id}...", flush=True)
|
| 39 |
+
model_dir = snapshot_download(repo_id=repo_id, local_dir=base_dir)
|
| 40 |
+
print(f"Model files downloaded to: {model_dir}", flush=True)
|
| 41 |
+
else:
|
| 42 |
+
model_dir = base_dir
|
| 43 |
+
print("Model files already present, skipping download.", flush=True)
|
| 44 |
+
|
| 45 |
+
cfg_path = os.path.join(base_dir, "config/flow_matching.json")
|
| 46 |
+
vocoder_cfg_path = os.path.join(base_dir, "config/vocoder.json")
|
| 47 |
+
checkpoint_path = os.path.join(model_dir, "pretrained/flow_matching")
|
| 48 |
+
vocoder_checkpoint_path = os.path.join(model_dir, "pretrained/vocoder")
|
| 49 |
+
|
| 50 |
+
print("Initializing AnyAccomp InferencePipeline...", flush=True)
|
| 51 |
+
pipeline = Sing2SongInferencePipeline(
|
| 52 |
+
checkpoint_path, cfg_path,
|
| 53 |
+
vocoder_checkpoint_path, vocoder_cfg_path,
|
| 54 |
+
device=DEVICE,
|
| 55 |
+
)
|
| 56 |
+
pipeline.sample_rate = 24000
|
| 57 |
+
inference_pipeline = pipeline
|
| 58 |
+
print("Model loaded successfully.", flush=True)
|
| 59 |
+
except Exception as e:
|
| 60 |
+
model_error = str(e)
|
| 61 |
+
print(f"Error loading model: {e}", flush=True)
|
| 62 |
+
finally:
|
| 63 |
+
model_loading = False
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# Start model loading in the background so the server can start immediately.
|
| 67 |
+
threading.Thread(target=load_model, daemon=True).start()
|
| 68 |
|
| 69 |
|
| 70 |
# @spaces.GPU # uncomment when running on ZeroGPU
|
| 71 |
def sing2song_inference(vocal_filepath, n_timesteps, cfg_scale):
|
| 72 |
+
if model_loading:
|
| 73 |
+
raise gr.Error("Model is still loading, please wait a moment and try again.")
|
| 74 |
if inference_pipeline is None:
|
| 75 |
+
raise gr.Error(f"Model failed to load: {model_error}")
|
| 76 |
|
| 77 |
if vocal_filepath is None:
|
| 78 |
raise gr.Error("Please upload a vocal audio file.")
|