Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,14 +4,20 @@ import os
|
|
| 4 |
import torch
|
| 5 |
import torchaudio
|
| 6 |
import numpy as np
|
| 7 |
-
import
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
import google.generativeai as genai
|
| 10 |
|
| 11 |
# Load API key from environment
|
| 12 |
load_dotenv()
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
# Load MusicGen Model
|
| 17 |
def load_model():
|
|
@@ -22,36 +28,56 @@ model = load_model()
|
|
| 22 |
|
| 23 |
# Function to generate music
|
| 24 |
def generate_music(description, duration):
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
# Save and return music file path
|
| 39 |
def save_audio(samples):
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
samples = samples
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
|
|
|
| 49 |
|
| 50 |
# Function to integrate with Gradio
|
| 51 |
def generate_music_and_return(description, duration):
|
| 52 |
music_tensors, enhanced_prompt = generate_music(description, duration)
|
| 53 |
-
audio_file_path = save_audio(music_tensors)
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
return enhanced_prompt, audio_file_path
|
| 56 |
|
| 57 |
# Gradio UI
|
|
@@ -73,4 +99,4 @@ with gr.Blocks() as app:
|
|
| 73 |
outputs=[enhanced_description_output, audio_output]
|
| 74 |
)
|
| 75 |
|
| 76 |
-
app.launch()
|
|
|
|
| 4 |
import torch
|
| 5 |
import torchaudio
|
| 6 |
import numpy as np
|
| 7 |
+
import tempfile # Safe temporary file handling
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
import google.generativeai as genai
|
| 10 |
|
| 11 |
# Load API key from environment
|
| 12 |
load_dotenv()
|
| 13 |
+
api_key = os.getenv("API_KEY")
|
| 14 |
+
|
| 15 |
+
# Ensure the API key is set, otherwise, prevent errors
|
| 16 |
+
if api_key:
|
| 17 |
+
genai.configure(api_key=api_key)
|
| 18 |
+
llm = genai.GenerativeModel("gemini-pro")
|
| 19 |
+
else:
|
| 20 |
+
llm = None # Avoid crashing if API key is missing
|
| 21 |
|
| 22 |
# Load MusicGen Model
|
| 23 |
def load_model():
|
|
|
|
| 28 |
|
| 29 |
# Function to generate music
|
| 30 |
def generate_music(description, duration):
|
| 31 |
+
try:
|
| 32 |
+
# Improve description using Google Gemini
|
| 33 |
+
if llm:
|
| 34 |
+
context = f"""Enhance the following music prompt by adding relevant musical terms, structure, and flow.
|
| 35 |
+
Ensure it's concise but descriptive:
|
| 36 |
+
ORIGINAL PROMPT: {description}
|
| 37 |
+
YOUR OUTPUT PROMPT:"""
|
| 38 |
+
llm_result = llm.generate_content(context)
|
| 39 |
+
enhanced_prompt = llm_result.text.strip()
|
| 40 |
+
else:
|
| 41 |
+
enhanced_prompt = description # Use original prompt if API is unavailable
|
| 42 |
|
| 43 |
+
model.set_generation_params(use_sampling=True, top_k=250, duration=duration)
|
| 44 |
+
output = model.generate(descriptions=[enhanced_prompt], progress=True)
|
| 45 |
+
|
| 46 |
+
if not output or len(output) == 0:
|
| 47 |
+
raise ValueError("Music generation failed. No output received.")
|
| 48 |
+
|
| 49 |
+
return output[0], enhanced_prompt
|
| 50 |
+
|
| 51 |
+
except Exception as e:
|
| 52 |
+
print(f"Error generating music: {e}")
|
| 53 |
+
return None, f"Error: {e}"
|
| 54 |
|
| 55 |
# Save and return music file path
|
| 56 |
def save_audio(samples):
|
| 57 |
+
try:
|
| 58 |
+
sample_rate = 32000
|
| 59 |
+
temp_dir = tempfile.gettempdir() # Use temp directory for safe file handling
|
| 60 |
+
save_path = os.path.join(temp_dir, "generated_audio.wav")
|
| 61 |
+
|
| 62 |
+
samples = samples.detach().cpu()
|
| 63 |
+
if samples.dim() == 2:
|
| 64 |
+
samples = samples[None, ...]
|
| 65 |
+
|
| 66 |
+
torchaudio.save(save_path, samples[0], sample_rate)
|
| 67 |
+
return save_path
|
| 68 |
|
| 69 |
+
except Exception as e:
|
| 70 |
+
print(f"Error saving audio: {e}")
|
| 71 |
+
return None
|
| 72 |
|
| 73 |
# Function to integrate with Gradio
|
| 74 |
def generate_music_and_return(description, duration):
|
| 75 |
music_tensors, enhanced_prompt = generate_music(description, duration)
|
|
|
|
| 76 |
|
| 77 |
+
if music_tensors is None:
|
| 78 |
+
return enhanced_prompt, None # Return error message instead of crashing
|
| 79 |
+
|
| 80 |
+
audio_file_path = save_audio(music_tensors)
|
| 81 |
return enhanced_prompt, audio_file_path
|
| 82 |
|
| 83 |
# Gradio UI
|
|
|
|
| 99 |
outputs=[enhanced_description_output, audio_output]
|
| 100 |
)
|
| 101 |
|
| 102 |
+
app.launch()
|