LR36 commited on
Commit
7d2410b
·
verified ·
1 Parent(s): 42c40ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -24
app.py CHANGED
@@ -4,14 +4,20 @@ import os
4
  import torch
5
  import torchaudio
6
  import numpy as np
7
- import base64
8
  from dotenv import load_dotenv
9
  import google.generativeai as genai
10
 
11
  # Load API key from environment
12
  load_dotenv()
13
- genai.configure(api_key=os.getenv("API_KEY"))
14
- llm = genai.GenerativeModel("gemini-pro")
 
 
 
 
 
 
15
 
16
  # Load MusicGen Model
17
  def load_model():
@@ -22,36 +28,56 @@ model = load_model()
22
 
23
  # Function to generate music
24
  def generate_music(description, duration):
25
- context = f"""Enhance the following music prompt by adding relevant musical terms, structure, and flow.
26
- Ensure it's concise but descriptive:
27
- ORIGINAL PROMPT: {description}
28
- YOUR OUTPUT PROMPT:"""
29
-
30
- llm_result = llm.generate_content(context)
31
- enhanced_prompt = llm_result.text.strip()
 
 
 
 
32
 
33
- model.set_generation_params(use_sampling=True, top_k=250, duration=duration)
34
- output = model.generate(descriptions=[enhanced_prompt], progress=True, return_tokens=True)
35
-
36
- return output[0], enhanced_prompt
 
 
 
 
 
 
 
37
 
38
  # Save and return music file path
39
  def save_audio(samples):
40
- sample_rate = 32000
41
- save_path = "generated_audio.wav"
42
-
43
- samples = samples.detach().cpu()
44
- if samples.dim() == 2:
45
- samples = samples[None, ...]
 
 
 
 
 
46
 
47
- torchaudio.save(save_path, samples[0], sample_rate)
48
- return save_path
 
49
 
50
  # Function to integrate with Gradio
51
  def generate_music_and_return(description, duration):
52
  music_tensors, enhanced_prompt = generate_music(description, duration)
53
- audio_file_path = save_audio(music_tensors)
54
 
 
 
 
 
55
  return enhanced_prompt, audio_file_path
56
 
57
  # Gradio UI
@@ -73,4 +99,4 @@ with gr.Blocks() as app:
73
  outputs=[enhanced_description_output, audio_output]
74
  )
75
 
76
- app.launch()
 
4
  import torch
5
  import torchaudio
6
  import numpy as np
7
+ import tempfile # Safe temporary file handling
8
  from dotenv import load_dotenv
9
  import google.generativeai as genai
10
 
11
  # Load API key from environment
12
  load_dotenv()
13
+ api_key = os.getenv("API_KEY")
14
+
15
+ # Ensure the API key is set, otherwise, prevent errors
16
+ if api_key:
17
+ genai.configure(api_key=api_key)
18
+ llm = genai.GenerativeModel("gemini-pro")
19
+ else:
20
+ llm = None # Avoid crashing if API key is missing
21
 
22
  # Load MusicGen Model
23
  def load_model():
 
28
 
29
  # Function to generate music
30
  def generate_music(description, duration):
31
+ try:
32
+ # Improve description using Google Gemini
33
+ if llm:
34
+ context = f"""Enhance the following music prompt by adding relevant musical terms, structure, and flow.
35
+ Ensure it's concise but descriptive:
36
+ ORIGINAL PROMPT: {description}
37
+ YOUR OUTPUT PROMPT:"""
38
+ llm_result = llm.generate_content(context)
39
+ enhanced_prompt = llm_result.text.strip()
40
+ else:
41
+ enhanced_prompt = description # Use original prompt if API is unavailable
42
 
43
+ model.set_generation_params(use_sampling=True, top_k=250, duration=duration)
44
+ output = model.generate(descriptions=[enhanced_prompt], progress=True)
45
+
46
+ if not output or len(output) == 0:
47
+ raise ValueError("Music generation failed. No output received.")
48
+
49
+ return output[0], enhanced_prompt
50
+
51
+ except Exception as e:
52
+ print(f"Error generating music: {e}")
53
+ return None, f"Error: {e}"
54
 
55
  # Save and return music file path
56
  def save_audio(samples):
57
+ try:
58
+ sample_rate = 32000
59
+ temp_dir = tempfile.gettempdir() # Use temp directory for safe file handling
60
+ save_path = os.path.join(temp_dir, "generated_audio.wav")
61
+
62
+ samples = samples.detach().cpu()
63
+ if samples.dim() == 2:
64
+ samples = samples[None, ...]
65
+
66
+ torchaudio.save(save_path, samples[0], sample_rate)
67
+ return save_path
68
 
69
+ except Exception as e:
70
+ print(f"Error saving audio: {e}")
71
+ return None
72
 
73
  # Function to integrate with Gradio
74
  def generate_music_and_return(description, duration):
75
  music_tensors, enhanced_prompt = generate_music(description, duration)
 
76
 
77
+ if music_tensors is None:
78
+ return enhanced_prompt, None # Return error message instead of crashing
79
+
80
+ audio_file_path = save_audio(music_tensors)
81
  return enhanced_prompt, audio_file_path
82
 
83
  # Gradio UI
 
99
  outputs=[enhanced_description_output, audio_output]
100
  )
101
 
102
+ app.launch()