Vansh Chugh commited on
Commit
49e9b94
·
1 Parent(s): 2fb0615

lazy loading so server startup is not delayed

Browse files
Files changed (2) hide show
  1. .gitignore +2 -0
  2. app.py +45 -30
.gitignore CHANGED
@@ -2,5 +2,7 @@ __pycache__/
2
  *.pyc
3
  pretrained/
4
  AnyAccomp-repo/
 
 
5
  DEPLOY.md
6
  instructions.md
 
2
  *.pyc
3
  pretrained/
4
  AnyAccomp-repo/
5
+ .venv-test/
6
+ test_gradio_health.py
7
  DEPLOY.md
8
  instructions.md
app.py CHANGED
@@ -3,6 +3,7 @@ sys.stdout.reconfigure(line_buffering=True)
3
 
4
  import torch
5
  import os
 
6
  # import spaces # uncomment when running on ZeroGPU
7
  import librosa
8
  import soundfile as sf
@@ -14,20 +15,7 @@ from anyaccomp.inference_utils import Sing2SongInferencePipeline
14
 
15
  repo_id = "amphion/anyaccomp"
16
  base_dir = os.path.dirname(os.path.abspath(__file__))
17
-
18
- checkpoint_marker = os.path.join(base_dir, "pretrained", "flow_matching")
19
- if not os.path.exists(checkpoint_marker):
20
- print(f"Downloading model files from {repo_id}...", flush=True)
21
- model_dir = snapshot_download(repo_id=repo_id, local_dir=base_dir)
22
- print(f"Model files downloaded to: {model_dir}", flush=True)
23
- else:
24
- model_dir = base_dir
25
- print(f"Model files already present, skipping download.", flush=True)
26
-
27
- CFG_PATH = os.path.join(base_dir, "config/flow_matching.json")
28
- VOCODER_CFG_PATH = os.path.join(base_dir, "config/vocoder.json")
29
- CHECKPOINT_PATH = os.path.join(model_dir, "pretrained/flow_matching")
30
- VOCODER_CHECKPOINT_PATH = os.path.join(model_dir, "pretrained/vocoder")
31
 
32
  INFER_DST = os.path.join(base_dir, "output_gradio")
33
  os.makedirs(INFER_DST, exist_ok=True)
@@ -36,28 +24,55 @@ mixture_dst = os.path.join(INFER_DST, "mixture")
36
  os.makedirs(acc_dst, exist_ok=True)
37
  os.makedirs(mixture_dst, exist_ok=True)
38
 
39
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
40
 
41
- print("Initializing AnyAccomp InferencePipeline...")
42
- try:
43
- inference_pipeline = Sing2SongInferencePipeline(
44
- CHECKPOINT_PATH,
45
- CFG_PATH,
46
- VOCODER_CHECKPOINT_PATH,
47
- VOCODER_CFG_PATH,
48
- device=DEVICE,
49
- )
50
- inference_pipeline.sample_rate = 24000
51
- print("Model loaded successfully.")
52
- except Exception as e:
53
- print(f"Error loading model: {e}")
54
- inference_pipeline = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
 
57
  # @spaces.GPU # uncomment when running on ZeroGPU
58
  def sing2song_inference(vocal_filepath, n_timesteps, cfg_scale):
 
 
59
  if inference_pipeline is None:
60
- raise gr.Error("Model could not be loaded. Please check paths and environment configuration.")
61
 
62
  if vocal_filepath is None:
63
  raise gr.Error("Please upload a vocal audio file.")
 
3
 
4
  import torch
5
  import os
6
+ import threading
7
  # import spaces # uncomment when running on ZeroGPU
8
  import librosa
9
  import soundfile as sf
 
15
 
16
  repo_id = "amphion/anyaccomp"
17
  base_dir = os.path.dirname(os.path.abspath(__file__))
18
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  INFER_DST = os.path.join(base_dir, "output_gradio")
21
  os.makedirs(INFER_DST, exist_ok=True)
 
24
  os.makedirs(acc_dst, exist_ok=True)
25
  os.makedirs(mixture_dst, exist_ok=True)
26
 
27
+ # Model loading state populated by the background thread.
28
+ inference_pipeline = None
29
+ model_loading = True
30
+ model_error = None
31
 
32
+
33
+ def load_model():
34
+ global inference_pipeline, model_loading, model_error
35
+ try:
36
+ checkpoint_marker = os.path.join(base_dir, "pretrained", "flow_matching")
37
+ if not os.path.exists(checkpoint_marker):
38
+ print(f"Downloading model files from {repo_id}...", flush=True)
39
+ model_dir = snapshot_download(repo_id=repo_id, local_dir=base_dir)
40
+ print(f"Model files downloaded to: {model_dir}", flush=True)
41
+ else:
42
+ model_dir = base_dir
43
+ print("Model files already present, skipping download.", flush=True)
44
+
45
+ cfg_path = os.path.join(base_dir, "config/flow_matching.json")
46
+ vocoder_cfg_path = os.path.join(base_dir, "config/vocoder.json")
47
+ checkpoint_path = os.path.join(model_dir, "pretrained/flow_matching")
48
+ vocoder_checkpoint_path = os.path.join(model_dir, "pretrained/vocoder")
49
+
50
+ print("Initializing AnyAccomp InferencePipeline...", flush=True)
51
+ pipeline = Sing2SongInferencePipeline(
52
+ checkpoint_path, cfg_path,
53
+ vocoder_checkpoint_path, vocoder_cfg_path,
54
+ device=DEVICE,
55
+ )
56
+ pipeline.sample_rate = 24000
57
+ inference_pipeline = pipeline
58
+ print("Model loaded successfully.", flush=True)
59
+ except Exception as e:
60
+ model_error = str(e)
61
+ print(f"Error loading model: {e}", flush=True)
62
+ finally:
63
+ model_loading = False
64
+
65
+
66
+ # Start model loading in the background so the server can start immediately.
67
+ threading.Thread(target=load_model, daemon=True).start()
68
 
69
 
70
  # @spaces.GPU # uncomment when running on ZeroGPU
71
  def sing2song_inference(vocal_filepath, n_timesteps, cfg_scale):
72
+ if model_loading:
73
+ raise gr.Error("Model is still loading, please wait a moment and try again.")
74
  if inference_pipeline is None:
75
+ raise gr.Error(f"Model failed to load: {model_error}")
76
 
77
  if vocal_filepath is None:
78
  raise gr.Error("Please upload a vocal audio file.")