mrfakename commited on
Commit
bcb09d1
·
verified ·
1 Parent(s): 4af893c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -40
app.py CHANGED
@@ -1,13 +1,9 @@
1
- """
2
- HeartMuLa Gradio App - Music Generation with Lyrics and Tags
3
- A self-contained Gradio app for Hugging Face Spaces
4
- """
5
-
6
  import os
7
  import tempfile
8
  import torch
9
  import gradio as gr
10
  from huggingface_hub import hf_hub_download, snapshot_download
 
11
 
12
  # Download models from HuggingFace Hub on startup
13
  def download_models():
@@ -44,40 +40,29 @@ def download_models():
44
  print("All models downloaded successfully!")
45
  return model_dir
46
 
 
47
 
48
- # Global pipeline instance
49
- pipeline = None
50
-
51
-
52
- def load_pipeline():
53
- """Load the HeartMuLa pipeline."""
54
- global pipeline
55
- if pipeline is not None:
56
- return pipeline
57
 
58
- from heartlib import HeartMuLaGenPipeline
 
 
 
 
 
 
59
 
60
- model_dir = download_models()
61
-
62
- # Determine device and dtype
63
- if torch.cuda.is_available():
64
- device = torch.device("cuda")
65
- dtype = torch.bfloat16
66
- else:
67
- device = torch.device("cpu")
68
- dtype = torch.float32
69
-
70
- print(f"Loading pipeline on {device} with {dtype}...")
71
- pipeline = HeartMuLaGenPipeline.from_pretrained(
72
- model_dir,
73
- device=device,
74
- dtype=dtype,
75
- version="3B",
76
- )
77
- print("Pipeline loaded successfully!")
78
- return pipeline
79
 
80
 
 
81
  def generate_music(
82
  lyrics: str,
83
  tags: str,
@@ -286,10 +271,5 @@ with gr.Blocks(
286
  )
287
 
288
 
289
- if __name__ == "__main__":
290
- # Preload models on startup
291
- print("Initializing HeartMuLa...")
292
- load_pipeline()
293
 
294
- # Launch the app
295
- demo.launch()
 
 
 
 
 
 
1
  import os
2
  import tempfile
3
  import torch
4
  import gradio as gr
5
  from huggingface_hub import hf_hub_download, snapshot_download
6
+ import spaces
7
 
8
  # Download models from HuggingFace Hub on startup
9
  def download_models():
 
40
  print("All models downloaded successfully!")
41
  return model_dir
42
 
43
+ from heartlib import HeartMuLaGenPipeline
44
 
45
+ model_dir = download_models()
 
 
 
 
 
 
 
 
46
 
47
+ # Determine device and dtype
48
+ if torch.cuda.is_available():
49
+ device = torch.device("cuda")
50
+ dtype = torch.bfloat16
51
+ else:
52
+ device = torch.device("cpu")
53
+ dtype = torch.float32
54
 
55
+ print(f"Loading pipeline on {device} with {dtype}...")
56
+ pipeline = HeartMuLaGenPipeline.from_pretrained(
57
+ model_dir,
58
+ device=device,
59
+ dtype=dtype,
60
+ version="3B",
61
+ )
62
+ print("Pipeline loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
63
 
64
 
65
+ @spaces.GPU
66
  def generate_music(
67
  lyrics: str,
68
  tags: str,
 
271
  )
272
 
273
 
 
 
 
 
274
 
275
+ demo.launch()