Spaces:
Sleeping
Sleeping
Commit
·
3c30d79
1
Parent(s):
fcbf149
Added: Music Duration Control from the client end
Browse files
client.py
CHANGED
|
@@ -16,14 +16,16 @@ parser.add_argument(
|
|
| 16 |
parser.add_argument(
|
| 17 |
"--output_file", type=str, default="output.wav", help="Output file name"
|
| 18 |
)
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
args = parser.parse_args()
|
| 21 |
|
| 22 |
-
|
| 23 |
-
def generate_music(server_url, prompts, output_file):
|
| 24 |
url = f"{server_url}/generate_music"
|
| 25 |
headers = {"Content-Type": "application/json"}
|
| 26 |
-
data = {"prompts": prompts}
|
| 27 |
|
| 28 |
response = requests.post(url, json=data, headers=headers)
|
| 29 |
|
|
@@ -34,6 +36,5 @@ def generate_music(server_url, prompts, output_file):
|
|
| 34 |
else:
|
| 35 |
print(f"Failed to generate music: {response.status_code}, {response.text}")
|
| 36 |
|
| 37 |
-
|
| 38 |
if __name__ == "__main__":
|
| 39 |
-
generate_music(args.server_url, args.prompts, args.output_file)
|
|
|
|
| 16 |
parser.add_argument(
|
| 17 |
"--output_file", type=str, default="output.wav", help="Output file name"
|
| 18 |
)
|
| 19 |
+
parser.add_argument(
|
| 20 |
+
"--duration", type=int, default=10, help="Duration of generated music in seconds"
|
| 21 |
+
)
|
| 22 |
|
| 23 |
args = parser.parse_args()
|
| 24 |
|
| 25 |
+
def generate_music(server_url, prompts, duration, output_file):
|
|
|
|
| 26 |
url = f"{server_url}/generate_music"
|
| 27 |
headers = {"Content-Type": "application/json"}
|
| 28 |
+
data = {"prompts": prompts, "duration": duration}
|
| 29 |
|
| 30 |
response = requests.post(url, json=data, headers=headers)
|
| 31 |
|
|
|
|
| 36 |
else:
|
| 37 |
print(f"Failed to generate music: {response.status_code}, {response.text}")
|
| 38 |
|
|
|
|
| 39 |
if __name__ == "__main__":
|
| 40 |
+
generate_music(args.server_url, args.prompts, args.duration, args.output_file)
|
server.py
CHANGED
|
@@ -2,7 +2,7 @@ import warnings
|
|
| 2 |
import argparse
|
| 3 |
from fastapi import FastAPI, HTTPException
|
| 4 |
from pydantic import BaseModel
|
| 5 |
-
from typing import List
|
| 6 |
import torch
|
| 7 |
from audiocraft.models import musicgen
|
| 8 |
import numpy as np
|
|
@@ -17,7 +17,6 @@ warnings.simplefilter('ignore')
|
|
| 17 |
parser = argparse.ArgumentParser(description="Music Generation Server")
|
| 18 |
parser.add_argument("--model_name", type=str, default="small", help="Pretrained model name")
|
| 19 |
parser.add_argument("--device", type=str, default="cuda", help="Device to load the model on")
|
| 20 |
-
parser.add_argument("--duration", type=int, default=10, help="Duration of generated music in seconds")
|
| 21 |
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
|
| 22 |
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
|
| 23 |
|
|
@@ -28,14 +27,15 @@ app = FastAPI()
|
|
| 28 |
|
| 29 |
# Load the model with the provided arguments
|
| 30 |
musicgen_model = musicgen.MusicGen.get_pretrained(args.model_name, device=args.device)
|
| 31 |
-
musicgen_model.set_generation_params(duration=args.duration)
|
| 32 |
|
| 33 |
class MusicRequest(BaseModel):
|
| 34 |
prompts: List[str]
|
|
|
|
| 35 |
|
| 36 |
@app.post("/generate_music")
|
| 37 |
def generate_music(request: MusicRequest):
|
| 38 |
try:
|
|
|
|
| 39 |
result = musicgen_model.generate(request.prompts, progress=False)
|
| 40 |
result = result.squeeze().cpu().numpy()
|
| 41 |
|
|
|
|
| 2 |
import argparse
|
| 3 |
from fastapi import FastAPI, HTTPException
|
| 4 |
from pydantic import BaseModel
|
| 5 |
+
from typing import List, Optional
|
| 6 |
import torch
|
| 7 |
from audiocraft.models import musicgen
|
| 8 |
import numpy as np
|
|
|
|
| 17 |
parser = argparse.ArgumentParser(description="Music Generation Server")
|
| 18 |
parser.add_argument("--model_name", type=str, default="small", help="Pretrained model name")
|
| 19 |
parser.add_argument("--device", type=str, default="cuda", help="Device to load the model on")
|
|
|
|
| 20 |
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
|
| 21 |
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
|
| 22 |
|
|
|
|
| 27 |
|
| 28 |
# Load the model with the provided arguments
|
| 29 |
musicgen_model = musicgen.MusicGen.get_pretrained(args.model_name, device=args.device)
|
|
|
|
| 30 |
|
| 31 |
class MusicRequest(BaseModel):
|
| 32 |
prompts: List[str]
|
| 33 |
+
duration: Optional[int] = 10 # Default duration is 10 seconds if not provided
|
| 34 |
|
| 35 |
@app.post("/generate_music")
|
| 36 |
def generate_music(request: MusicRequest):
|
| 37 |
try:
|
| 38 |
+
musicgen_model.set_generation_params(duration=request.duration)
|
| 39 |
result = musicgen_model.generate(request.prompts, progress=False)
|
| 40 |
result = result.squeeze().cpu().numpy()
|
| 41 |
|