swayamshetkar commited on
Commit
0aa6bf9
·
verified ·
1 Parent(s): 19a49da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -43
app.py CHANGED
@@ -2,71 +2,48 @@ import gradio as gr
2
  import tempfile
3
  import torch
4
  import scipy.io.wavfile as wavfile
5
- import os
6
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
7
 
8
- # --- CPU optimization setup ---
9
- torch.set_num_threads(os.cpu_count()) # use all CPU cores
10
- torch.set_num_interop_threads(4) # reasonable inter-op threads
11
- torch.backends.quantized.engine = 'qnnpack' # enable quantized ops where possible
12
-
13
- # --- Model Loading ---
14
- print("🧠 Loading model (optimized for CPU)...")
15
  processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
 
 
 
16
 
17
- try:
18
- # βœ… Try quantized load (if bitsandbytes available)
19
- from transformers import BitsAndBytesConfig
20
- bnb_config = BitsAndBytesConfig(load_in_8bit=True)
21
- model = MusicgenForConditionalGeneration.from_pretrained(
22
- "facebook/musicgen-small",
23
- quantization_config=bnb_config,
24
- device_map="cpu"
25
- )
26
- print("βœ… Using 8-bit quantized model")
27
- except Exception:
28
- # fallback
29
- model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
30
- model = model.to("cpu").to(torch.float16)
31
- print("βš™οΈ Using standard float16 CPU model")
32
-
33
- device = "cpu"
34
- MAX_DURATION = 30 # hard cap for CPU β€” can do 30s comfortably
35
-
36
- # --- Generation Function ---
37
  def generate_music(prompt, duration):
38
  if not prompt.strip():
39
- return None, "❌ Please enter a music prompt."
40
- if duration > MAX_DURATION:
41
- return None, f"⚠️ Duration too long for CPU β€” max allowed is {MAX_DURATION} seconds."
42
-
43
- # Dynamic token scaling (smaller = faster)
44
- max_new_tokens = int(128 * (duration / 8))
45
- max_new_tokens = min(max_new_tokens, 1024) # cap for stability
46
 
 
47
  inputs = processor(text=[prompt], return_tensors="pt").to(device)
48
 
 
 
 
 
 
49
  with torch.no_grad():
50
  audio = model.generate(**inputs, max_new_tokens=max_new_tokens)
51
 
52
  sr = model.config.audio_encoder.sampling_rate
53
  audio_arr = audio[0, 0].cpu().numpy()
54
 
 
55
  tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
56
  wavfile.write(tmp.name, rate=sr, data=audio_arr)
57
 
58
- return tmp.name, f"βœ… Generated {duration}s of audio on CPU!"
59
 
60
- # --- Gradio UI ---
61
- with gr.Blocks(title="🎢 MusicGen β€” CPU Optimized") as demo:
62
- gr.Markdown("# 🎡 MusicGen β€” Text-to-Music (CPU Fast Mode)")
63
  with gr.Row():
64
  prompt = gr.Textbox(label="🎼 Describe your music", placeholder="e.g. dreamy lo-fi with soft piano")
65
- duration = gr.Slider(4, MAX_DURATION, value=10, step=1, label="Duration (seconds)")
66
- btn = gr.Button("🎧 Generate Music")
67
  audio_out = gr.Audio(label="🎢 Output", type="filepath")
68
  msg = gr.Textbox(label="Status", interactive=False)
69
  btn.click(generate_music, inputs=[prompt, duration], outputs=[audio_out, msg])
70
 
71
- if __name__ == "__main__":
72
- demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
 
2
  import tempfile
3
  import torch
4
  import scipy.io.wavfile as wavfile
 
5
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
6
 
7
# --- Model loading (runs once at startup) --------------------------
# Download/load the small MusicGen checkpoint, then move the model to
# the best available device: GPU when present, otherwise CPU.
_CHECKPOINT = "facebook/musicgen-small"

processor = AutoProcessor.from_pretrained(_CHECKPOINT)
model = MusicgenForConditionalGeneration.from_pretrained(_CHECKPOINT)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
def generate_music(prompt, duration):
    """Generate a music clip from a text prompt with MusicGen.

    Args:
        prompt: Text description of the desired music.
        duration: Requested clip length in seconds (max 60).

    Returns:
        A ``(filepath, status)`` tuple: the path of a temporary ``.wav``
        file (or ``None`` when validation fails) and a human-readable
        status message.
    """
    if not prompt.strip():
        return None, "Please enter a prompt."
    if duration > 60:
        return None, "❌ Duration too long β€” max allowed is 60 seconds."

    # Prepare inputs on the same device as the model.
    inputs = processor(text=[prompt], return_tensors="pt").to(device)

    # Scale tokens with duration (MusicGen: roughly 256 tokens per 8 s).
    # Clamp to at least 1 token (generate() rejects 0) and cap at 2048
    # as a safety limit.
    max_new_tokens = min(max(int(256 * (duration / 8)), 1), 2048)

    # Generate audio; inference only, so no gradients are needed.
    with torch.no_grad():
        audio = model.generate(**inputs, max_new_tokens=max_new_tokens)

    sr = model.config.audio_encoder.sampling_rate
    audio_arr = audio[0, 0].cpu().numpy()

    # Save to a temp file. Close the handle before writing so that
    # wavfile.write can reopen the path (required on Windows) and the
    # file descriptor is not leaked; delete=False keeps the file for
    # Gradio to serve.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()
    wavfile.write(tmp.name, rate=sr, data=audio_arr)

    return tmp.name, f"βœ… Generated {duration}s of audio!"
38
 
39
# --- Gradio UI -----------------------------------------------------
# Layout: a markdown title, one row holding the prompt box and the
# duration slider, then the generate button, the audio player, and a
# read-only status textbox.
with gr.Blocks(title="MusicGen 🎢") as demo:
    gr.Markdown("# 🎡 MusicGen β€” Text-to-Music Generator (Extended 60 s Version)")
    with gr.Row():
        prompt = gr.Textbox(label="🎼 Describe your music", placeholder="e.g. dreamy lo-fi with soft piano")
        # Slider range 4-60 s matches the 60 s cap enforced in generate_music.
        duration = gr.Slider(4, 60, value=15, step=1, label="Duration (seconds)")
    btn = gr.Button("Generate 🎧")
    audio_out = gr.Audio(label="🎢 Output", type="filepath")
    msg = gr.Textbox(label="Status", interactive=False)
    # Wire the button to the generator; outputs map to (audio file, status text).
    btn.click(generate_music, inputs=[prompt, duration], outputs=[audio_out, msg])

# share=True requests a public tunnel URL in addition to the local server.
demo.launch(share=True)