# TexttoMusic / app.py
# Hugging Face Space by LR36 — commit 7d2410b ("Update app.py", verified).
from audiocraft.models import MusicGen
import gradio as gr
import os
import torch
import torchaudio
import numpy as np
import tempfile # Safe temporary file handling
from dotenv import load_dotenv
import google.generativeai as genai
# Load the Gemini API key from the environment (.env supported via python-dotenv).
load_dotenv()
api_key = os.getenv("API_KEY")

# Configure Gemini only when a key is present; without one the app still runs
# and simply skips prompt enhancement (llm stays None).
if api_key:
    genai.configure(api_key=api_key)
    llm = genai.GenerativeModel("gemini-pro")
else:
    llm = None  # Avoid crashing if API key is missing
# Load MusicGen Model
def load_model():
    """Return the pretrained 'facebook/musicgen-small' MusicGen model."""
    return MusicGen.get_pretrained("facebook/musicgen-small")


# Loaded once at import time so every request reuses the same model instance.
model = load_model()
# Function to generate music
def generate_music(description, duration):
    """Generate an audio clip from a text description.

    The prompt is first enhanced by Gemini when an API key was configured
    (module-level ``llm``); otherwise the original description is used as-is.

    Parameters
    ----------
    description : str
        User-supplied text prompt describing the desired music.
    duration : int
        Clip length in seconds, forwarded to MusicGen's generation params.

    Returns
    -------
    tuple
        ``(audio_tensor, enhanced_prompt)`` on success, or
        ``(None, error_message)`` on failure — callers check the None sentinel.
    """
    try:
        # Improve description using Google Gemini (skipped when no API key).
        if llm:
            context = f"""Enhance the following music prompt by adding relevant musical terms, structure, and flow.
Ensure it's concise but descriptive:
ORIGINAL PROMPT: {description}
YOUR OUTPUT PROMPT:"""
            llm_result = llm.generate_content(context)
            enhanced_prompt = llm_result.text.strip()
            # Guard against an empty/whitespace-only LLM response: fall back
            # to the user's wording rather than generating from "".
            if not enhanced_prompt:
                enhanced_prompt = description
        else:
            enhanced_prompt = description  # Use original prompt if API is unavailable

        model.set_generation_params(use_sampling=True, top_k=250, duration=duration)
        output = model.generate(descriptions=[enhanced_prompt], progress=True)

        if not output or len(output) == 0:
            raise ValueError("Music generation failed. No output received.")
        return output[0], enhanced_prompt
    except Exception as e:
        # Top-level boundary for the Gradio handler: report, don't crash.
        print(f"Error generating music: {e}")
        return None, f"Error: {e}"
# Save and return music file path
def save_audio(samples):
try:
sample_rate = 32000
temp_dir = tempfile.gettempdir() # Use temp directory for safe file handling
save_path = os.path.join(temp_dir, "generated_audio.wav")
samples = samples.detach().cpu()
if samples.dim() == 2:
samples = samples[None, ...]
torchaudio.save(save_path, samples[0], sample_rate)
return save_path
except Exception as e:
print(f"Error saving audio: {e}")
return None
# Function to integrate with Gradio
def generate_music_and_return(description, duration):
    """Gradio callback: return (enhanced prompt text, audio file path)."""
    tensors, prompt_text = generate_music(description, duration)
    if tensors is None:
        # generate_music signals failure with None; prompt_text then carries
        # the error message, shown in the textbox instead of crashing the UI.
        return prompt_text, None
    return prompt_text, save_audio(tensors)
# Gradio UI
with gr.Blocks() as app:
    gr.Markdown("# 🎵 Text-to-Music Generator")
    gr.Markdown("Enter a music description, and our AI will generate a unique audio clip.")

    with gr.Row():
        description_input = gr.Textbox(label="Enter music description")
        duration_input = gr.Slider(2, 20, value=5, step=1, label="Select duration (seconds)")

    generate_button = gr.Button("🎼 Generate Music")
    enhanced_description_output = gr.Textbox(label="Enhanced Description", interactive=False)
    audio_output = gr.Audio(label="Generated Audio")

    # Wire the button to the generation callback.
    generate_button.click(
        generate_music_and_return,
        inputs=[description_input, duration_input],
        outputs=[enhanced_description_output, audio_output],
    )

app.launch()