swayamshetkar commited on
Commit
3a0b35b
Β·
verified Β·
1 Parent(s): bfd8166

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -7
app.py CHANGED
@@ -13,24 +13,37 @@ model.to(device)
13
  def generate_music(prompt, duration):
14
  if not prompt.strip():
15
  return None, "Please enter a prompt."
 
 
 
 
16
  inputs = processor(text=[prompt], return_tensors="pt").to(device)
17
- max_new_tokens = int(256 * (duration / 8)) # simple scaling
 
 
 
 
 
18
  with torch.no_grad():
19
  audio = model.generate(**inputs, max_new_tokens=max_new_tokens)
 
20
  sr = model.config.audio_encoder.sampling_rate
21
  audio_arr = audio[0, 0].cpu().numpy()
 
 
22
  tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
23
  wavfile.write(tmp.name, rate=sr, data=audio_arr)
24
- return tmp.name, f"Generated {duration}s of audio!"
 
25
 
26
  with gr.Blocks(title="MusicGen 🎢") as demo:
27
- gr.Markdown("# 🎡 MusicGen β€” Text-to-Music Generator (Small Model)")
28
  with gr.Row():
29
- prompt = gr.Textbox(label="Describe your music", placeholder="e.g. lo-fi hip hop with piano")
30
- duration = gr.Slider(4, 20, value=8, step=1, label="Duration (seconds)")
31
  btn = gr.Button("Generate 🎧")
32
- audio_out = gr.Audio(label="Output", type="filepath")
33
  msg = gr.Textbox(label="Status", interactive=False)
34
  btn.click(generate_music, inputs=[prompt, duration], outputs=[audio_out, msg])
35
 
36
- demo.launch()
 
13
  def generate_music(prompt, duration):
14
  if not prompt.strip():
15
  return None, "Please enter a prompt."
16
+ if duration > 60:
17
+ return None, "❌ Duration too long β€” max allowed is 60 seconds."
18
+
19
+ # Prepare inputs
20
  inputs = processor(text=[prompt], return_tensors="pt").to(device)
21
+
22
+ # Scale tokens with duration (MusicGen β‰ˆ 256 tokens β‰ˆ 8 seconds)
23
+ max_new_tokens = int(256 * (duration / 8))
24
+ max_new_tokens = min(max_new_tokens, 2048) # Safety cap
25
+
26
+ # Generate audio
27
  with torch.no_grad():
28
  audio = model.generate(**inputs, max_new_tokens=max_new_tokens)
29
+
30
  sr = model.config.audio_encoder.sampling_rate
31
  audio_arr = audio[0, 0].cpu().numpy()
32
+
33
+ # Save temp file
34
  tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
35
  wavfile.write(tmp.name, rate=sr, data=audio_arr)
36
+
37
+ return tmp.name, f"βœ… Generated {duration}s of audio!"
38
 
39
  with gr.Blocks(title="MusicGen 🎢") as demo:
40
+ gr.Markdown("# 🎡 MusicGen β€” Text-to-Music Generator (Extended 60 s Version)")
41
  with gr.Row():
42
+ prompt = gr.Textbox(label="🎼 Describe your music", placeholder="e.g. dreamy lo-fi with soft piano")
43
+ duration = gr.Slider(4, 60, value=15, step=1, label="Duration (seconds)")
44
  btn = gr.Button("Generate 🎧")
45
+ audio_out = gr.Audio(label="🎢 Output", type="filepath")
46
  msg = gr.Textbox(label="Status", interactive=False)
47
  btn.click(generate_music, inputs=[prompt, duration], outputs=[audio_out, msg])
48
 
49
+ demo.launch(share=True)