# beatloo / app.py — MusicGen text-to-music Gradio Space
# (Hugging Face Spaces page residue converted to a comment; commit a057abd)
import gradio as gr
import tempfile
import torch
import scipy.io.wavfile as wavfile
from transformers import AutoProcessor, MusicgenForConditionalGeneration
# --- Model setup (runs once at import time) ---
MODEL_ID = "facebook/musicgen-small"

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = MusicgenForConditionalGeneration.from_pretrained(MODEL_ID)

# Prefer a GPU when one is available; otherwise generation runs on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
def generate_music(prompt, duration):
    """Generate a short music clip from a text prompt with MusicGen.

    Parameters
    ----------
    prompt : str
        Natural-language description of the desired music.
    duration : int | float
        Requested clip length in seconds; must be positive and at most 40.

    Returns
    -------
    tuple[str | None, str]
        Path to a temporary ``.wav`` file (or ``None`` on validation
        failure) and a human-readable status message.
    """
    # Guard clauses: reject empty prompts and out-of-range durations early.
    if not prompt or not prompt.strip():
        return None, "Please enter a prompt."
    if duration <= 0:
        return None, "❌ Duration must be a positive number of seconds."
    if duration > 40:
        return None, "❌ Duration too long — max allowed is 40 seconds."

    # Tokenize the prompt and move tensors to the same device as the model.
    inputs = processor(text=[prompt], return_tensors="pt").to(device)

    # Empirical scaling used by this app: 256 tokens ~ 8 s of audio.
    # NOTE(review): the model's actual rate is
    # model.config.audio_encoder.frame_rate (~50 tokens/s for
    # musicgen-small), so clips come out shorter than requested — confirm
    # and switch to frame_rate-based scaling if exact durations matter.
    approx_tokens = min(int(256 * (duration / 8)), 2048)

    with torch.no_grad():
        audio = model.generate(
            **inputs,
            do_sample=True,
            guidance_scale=3,
            max_new_tokens=approx_tokens,
        )

    sr = model.config.audio_encoder.sampling_rate
    # Output shape is (batch, channels, samples); take the first mono track.
    audio_arr = audio[0, 0].cpu().numpy()

    # delete=False so Gradio can serve the file after this function returns.
    # Close the handle immediately (fix: it was previously leaked) so
    # wavfile.write can reopen the path cleanly, including on Windows.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()
    wavfile.write(tmp.name, rate=sr, data=audio_arr)
    return tmp.name, f"✅ Generated {duration}s (approx) of audio!"
# --- Gradio UI ---
# Fix: the original labels/titles contained mojibake (UTF-8 emoji and
# em-dashes decoded through the wrong codec, e.g. "β€”", "βœ…", "🎢");
# restored to the intended characters.
with gr.Blocks(title="MusicGen 🎶") as demo:
    gr.Markdown("# 🎵 MusicGen — Text-to-Music Generator (Stable 40s Version)")

    with gr.Row():
        prompt = gr.Textbox(
            label="🎼 Describe your music",
            placeholder="e.g. dreamy lo-fi with soft piano",
        )
        # Slider bounds mirror generate_music's 40 s hard cap.
        duration = gr.Slider(4, 40, value=15, step=1, label="Duration (seconds)")

    btn = gr.Button("Generate 🎧")
    audio_out = gr.Audio(label="🎶 Output", type="filepath")
    msg = gr.Textbox(label="Status", interactive=False)

    btn.click(generate_music, inputs=[prompt, duration], outputs=[audio_out, msg])

demo.launch(share=True)