Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,7 +5,7 @@ import random
|
|
| 5 |
import os
|
| 6 |
import torch
|
| 7 |
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
| 8 |
-
|
| 9 |
|
| 10 |
# Load MusicGen model
|
| 11 |
model_name = "facebook/musicgen-small"
|
|
@@ -21,8 +21,17 @@ def generate_music(prompt):
|
|
| 21 |
return (sampling_rate, audio)
|
| 22 |
|
| 23 |
# Function to get enhanced prompt via GPT
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
def refine_prompt(user_input):
|
| 28 |
completion = client.chat.completions.create(
|
|
|
|
| 5 |
import os
|
| 6 |
import torch
|
| 7 |
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
| 8 |
+
import openai
|
| 9 |
|
| 10 |
# Load MusicGen model
|
| 11 |
model_name = "facebook/musicgen-small"
|
|
|
|
| 21 |
return (sampling_rate, audio)
|
| 22 |
|
| 23 |
# Function to get enhanced prompt via GPT
|
| 24 |
+
openai.api_key = os.getenv("OPENAI_API_KEY")
|
| 25 |
+
|
| 26 |
+
def refine_prompt(user_input):
|
| 27 |
+
response = openai.ChatCompletion.create(
|
| 28 |
+
model="gpt-4",
|
| 29 |
+
messages=[
|
| 30 |
+
{"role": "system", "content": "You are a music assistant. Make the user's input more descriptive for an AI music generator."},
|
| 31 |
+
{"role": "user", "content": user_input}
|
| 32 |
+
]
|
| 33 |
+
)
|
| 34 |
+
return response.choices[0].message["content"].strip()
|
| 35 |
|
| 36 |
def refine_prompt(user_input):
|
| 37 |
completion = client.chat.completions.create(
|