# Hugging Face Spaces status header captured during page extraction:
# Spaces: Runtime error
| from audiocraft.models import MusicGen | |
| import gradio as gr | |
| import os | |
| import torch | |
| import torchaudio | |
| import numpy as np | |
| import tempfile # Safe temporary file handling | |
| from dotenv import load_dotenv | |
| import google.generativeai as genai | |
# Load environment variables and configure the Gemini client when possible.
load_dotenv()
api_key = os.getenv("API_KEY")

# Default to no LLM so the app keeps running when the key is missing;
# prompt enhancement is simply skipped in that case.
llm = None
if api_key:
    genai.configure(api_key=api_key)
    llm = genai.GenerativeModel("gemini-pro")
| # Load MusicGen Model | |
| def load_model(): | |
| model = MusicGen.get_pretrained("facebook/musicgen-small") | |
| return model | |
| model = load_model() | |
| # Function to generate music | |
| def generate_music(description, duration): | |
| try: | |
| # Improve description using Google Gemini | |
| if llm: | |
| context = f"""Enhance the following music prompt by adding relevant musical terms, structure, and flow. | |
| Ensure it's concise but descriptive: | |
| ORIGINAL PROMPT: {description} | |
| YOUR OUTPUT PROMPT:""" | |
| llm_result = llm.generate_content(context) | |
| enhanced_prompt = llm_result.text.strip() | |
| else: | |
| enhanced_prompt = description # Use original prompt if API is unavailable | |
| model.set_generation_params(use_sampling=True, top_k=250, duration=duration) | |
| output = model.generate(descriptions=[enhanced_prompt], progress=True) | |
| if not output or len(output) == 0: | |
| raise ValueError("Music generation failed. No output received.") | |
| return output[0], enhanced_prompt | |
| except Exception as e: | |
| print(f"Error generating music: {e}") | |
| return None, f"Error: {e}" | |
| # Save and return music file path | |
| def save_audio(samples): | |
| try: | |
| sample_rate = 32000 | |
| temp_dir = tempfile.gettempdir() # Use temp directory for safe file handling | |
| save_path = os.path.join(temp_dir, "generated_audio.wav") | |
| samples = samples.detach().cpu() | |
| if samples.dim() == 2: | |
| samples = samples[None, ...] | |
| torchaudio.save(save_path, samples[0], sample_rate) | |
| return save_path | |
| except Exception as e: | |
| print(f"Error saving audio: {e}") | |
| return None | |
| # Function to integrate with Gradio | |
| def generate_music_and_return(description, duration): | |
| music_tensors, enhanced_prompt = generate_music(description, duration) | |
| if music_tensors is None: | |
| return enhanced_prompt, None # Return error message instead of crashing | |
| audio_file_path = save_audio(music_tensors) | |
| return enhanced_prompt, audio_file_path | |
| # Gradio UI | |
| with gr.Blocks() as app: | |
| gr.Markdown("# 🎵 Text-to-Music Generator") | |
| gr.Markdown("Enter a music description, and our AI will generate a unique audio clip.") | |
| with gr.Row(): | |
| description_input = gr.Textbox(label="Enter music description") | |
| duration_input = gr.Slider(2, 20, value=5, step=1, label="Select duration (seconds)") | |
| generate_button = gr.Button("🎼 Generate Music") | |
| enhanced_description_output = gr.Textbox(label="Enhanced Description", interactive=False) | |
| audio_output = gr.Audio(label="Generated Audio") | |
| generate_button.click( | |
| generate_music_and_return, | |
| inputs=[description_input, duration_input], | |
| outputs=[enhanced_description_output, audio_output] | |
| ) | |
| app.launch() | |