File size: 3,395 Bytes
318ee25
77f602d
318ee25
 
 
 
7d2410b
318ee25
 
 
77f602d
 
7d2410b
 
 
 
 
 
 
 
318ee25
77f602d
318ee25
 
 
 
77f602d
318ee25
77f602d
 
7d2410b
 
 
 
 
 
 
 
 
 
 
318ee25
7d2410b
 
 
 
 
 
 
 
 
 
 
318ee25
77f602d
 
7d2410b
 
 
 
 
 
 
 
 
 
 
318ee25
7d2410b
 
 
77f602d
 
 
 
 
7d2410b
 
 
 
77f602d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318ee25
7d2410b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from audiocraft.models import MusicGen
import gradio as gr
import os
import torch
import torchaudio
import numpy as np
import tempfile  # Safe temporary file handling
from dotenv import load_dotenv
import google.generativeai as genai

# Load API key from environment
load_dotenv()
api_key = os.getenv("API_KEY")

# Configure Gemini only when a key is present; otherwise fall back to
# running without prompt enhancement instead of crashing at import time.
llm = None
if api_key:
    genai.configure(api_key=api_key)
    llm = genai.GenerativeModel("gemini-pro")

# Load MusicGen Model
def load_model():
    """Return the pretrained small MusicGen checkpoint."""
    return MusicGen.get_pretrained("facebook/musicgen-small")


model = load_model()

# Function to generate music
def generate_music(description, duration):
    """Generate an audio tensor from a text description.

    Parameters:
        description: free-form text prompt describing the music.
        duration: clip length in seconds, forwarded to MusicGen.

    Returns:
        (audio_tensor, enhanced_prompt) on success, or
        (None, error_message) on failure.
    """
    try:
        # Improve description using Google Gemini when the API is configured.
        if llm:
            context = f"""Enhance the following music prompt by adding relevant musical terms, structure, and flow. 
            Ensure it's concise but descriptive:
            ORIGINAL PROMPT: {description}
            YOUR OUTPUT PROMPT:"""
            llm_result = llm.generate_content(context)
            enhanced_prompt = llm_result.text.strip()
        else:
            enhanced_prompt = description  # Use original prompt if API is unavailable

        model.set_generation_params(use_sampling=True, top_k=250, duration=duration)
        output = model.generate(descriptions=[enhanced_prompt], progress=True)

        # BUG FIX: the previous check `if not output or ...` called __bool__ on
        # a multi-element torch.Tensor, which raises "Boolean value of Tensor
        # with more than one element is ambiguous" — so every SUCCESSFUL
        # generation was caught below and reported as an error. Test identity
        # and length explicitly instead of truthiness.
        if output is None or len(output) == 0:
            raise ValueError("Music generation failed. No output received.")

        return output[0], enhanced_prompt

    except Exception as e:
        # Best-effort boundary: surface the failure to the UI as a string
        # rather than crashing the Gradio worker.
        print(f"Error generating music: {e}")
        return None, f"Error: {e}"

# Save and return music file path
def save_audio(samples):
    """Write a generated audio tensor to a unique temporary WAV file.

    Parameters:
        samples: torch tensor of shape (channels, time) or
            (batch, channels, time) — assumed 32 kHz MusicGen output.

    Returns:
        Path to the written WAV file, or None on failure.
    """
    try:
        sample_rate = 32000  # MusicGen small's output rate — TODO confirm if checkpoint changes

        # BUG FIX: previously every call wrote to the SAME fixed path
        # (<tmp>/generated_audio.wav), so concurrent Gradio requests
        # overwrote each other's output. Create a unique file per call.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            save_path = tmp.name

        samples = samples.detach().cpu()
        if samples.dim() == 2:
            samples = samples[None, ...]  # add batch dim so indexing below is uniform

        torchaudio.save(save_path, samples[0], sample_rate)
        return save_path

    except Exception as e:
        # Best-effort: report and let the caller return None to the UI.
        print(f"Error saving audio: {e}")
        return None

# Function to integrate with Gradio
def generate_music_and_return(description, duration):
    """Gradio callback: return (enhanced prompt text, audio file path or None)."""
    tensors, prompt_text = generate_music(description, duration)

    # On failure, generate_music puts the error message in the prompt slot;
    # pass it through so the UI shows it instead of crashing.
    if tensors is None:
        return prompt_text, None

    return prompt_text, save_audio(tensors)

# Gradio UI
with gr.Blocks() as app:
    # Page header
    gr.Markdown("# 🎵 Text-to-Music Generator")
    gr.Markdown("Enter a music description, and our AI will generate a unique audio clip.")

    # Inputs: prompt text plus clip length in seconds
    with gr.Row():
        prompt_box = gr.Textbox(label="Enter music description")
        length_slider = gr.Slider(2, 20, value=5, step=1, label="Select duration (seconds)")

    run_button = gr.Button("🎼 Generate Music")

    # Outputs: the LLM-enhanced prompt (read-only) and the rendered audio
    prompt_output = gr.Textbox(label="Enhanced Description", interactive=False)
    clip_output = gr.Audio(label="Generated Audio")

    run_button.click(
        generate_music_and_return,
        inputs=[prompt_box, length_slider],
        outputs=[prompt_output, clip_output],
    )

app.launch()