swayamshetkar committed on
Commit
c357bca
·
verified ·
1 Parent(s): ca004f3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -0
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tempfile
3
+ import torch
4
+ import scipy.io.wavfile as wavfile
5
+ from transformers import AutoProcessor, MusicgenForConditionalGeneration
6
+
7
+ # Load model
8
+ processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
9
+ model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ model.to(device)
def generate_music(prompt, duration):
    """Generate a short music clip from a text prompt with MusicGen.

    Args:
        prompt: Free-text description of the desired music.
        duration: Target clip length in seconds (the UI slider supplies 4-20).

    Returns:
        A ``(filepath, status)`` tuple: the path to a temporary ``.wav`` file
        (or ``None`` when the prompt is blank) and a human-readable status
        message for the UI.
    """
    if not prompt.strip():
        return None, "Please enter a prompt."
    inputs = processor(text=[prompt], return_tensors="pt").to(device)
    # MusicGen emits roughly 256 tokens per 8 s of audio; scale linearly and
    # floor at 1 so generate() never receives max_new_tokens=0.
    max_new_tokens = max(1, int(256 * (duration / 8)))
    with torch.no_grad():
        audio = model.generate(**inputs, max_new_tokens=max_new_tokens)
    sr = model.config.audio_encoder.sampling_rate
    audio_arr = audio[0, 0].cpu().numpy()
    # Close the handle before scipy reopens the path by name: writing while
    # the NamedTemporaryFile is still open leaks the descriptor and fails on
    # Windows, where the file cannot be opened a second time.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()
    wavfile.write(tmp.name, rate=sr, data=audio_arr)
    return tmp.name, f"Generated {duration}s of audio!"
# --- UI -------------------------------------------------------------------
# Assemble the Gradio interface: the prompt textbox and duration slider feed
# generate_music, which returns an audio file path plus a status string.
with gr.Blocks(title="MusicGen 🎶") as demo:
    gr.Markdown("# 🎵 MusicGen — Text-to-Music Generator (Small Model)")

    with gr.Row():
        prompt_box = gr.Textbox(
            label="Describe your music",
            placeholder="e.g. lo-fi hip hop with piano",
        )
        duration_slider = gr.Slider(4, 20, value=8, step=1, label="Duration (seconds)")

    generate_btn = gr.Button("Generate 🎧")
    output_audio = gr.Audio(label="Output", type="filepath")
    status_box = gr.Textbox(label="Status", interactive=False)

    generate_btn.click(
        generate_music,
        inputs=[prompt_box, duration_slider],
        outputs=[output_audio, status_box],
    )

demo.launch()