# musicgen/app.py
"""
Gradio Space for MusicGen with API access
Mirrors the exact functionality of handler.py with API endpoints
"""
import gradio as gr
import json
import logging
from typing import Dict, Any, Optional
import numpy as np
# Import the existing handler
from handler import EndpointHandler
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize the handler once at import time so the model is loaded before
# the first request hits the UI or the API (slow startup, fast first call).
# NOTE(review): EndpointHandler is constructed with path="" — presumably the
# handler resolves the model weights itself in that case; confirm against
# handler.py.
logger.info("Initializing MusicGen handler...")
handler = EndpointHandler(path="")
logger.info("Handler initialized successfully!")
def generate_music(
    prompt: str,
    duration: float = 10.0,
    temperature: float = 1.0,
    top_k: int = 250,
    top_p: float = 0.0,
    cfg_coef: float = 3.0,
    use_sampling: bool = True,
    extend_stride: float = 18.0
) -> tuple:
    """
    Generate music using MusicGen via the shared EndpointHandler.

    Args:
        prompt: Text description of the music to generate.
        duration: Target clip length in seconds.
        temperature: Sampling temperature; higher values are more random.
        top_k: Top-K sampling cutoff.
        top_p: Top-P (nucleus) sampling cutoff; 0.0 disables it.
        cfg_coef: Classifier-free guidance scale.
        use_sampling: Whether to sample tokens instead of greedy decoding.
        extend_stride: Stride in seconds used when extending long sequences.

    Returns:
        A 2-tuple matching the two Gradio outputs this function is wired to:
        ``((sample_rate, audio_array), metadata_dict)`` on success, or
        ``(None, {"error": message})`` on any failure.
    """
    try:
        # Build the request in the exact schema handler.py's __call__ expects.
        request_data = {
            "inputs": {
                "prompt": prompt,
                "duration": duration
            },
            "parameters": {
                "temperature": temperature,
                "top_k": top_k,
                "top_p": top_p,
                "cfg_coef": cfg_coef,
                "use_sampling": use_sampling,
                "extend_stride": extend_stride
            }
        }
        logger.info(f"Processing request: {prompt[:50]}... duration={duration}s")

        # Call the handler (exact same as handler.py __call__).
        result = handler(request_data)

        # The handler reports failures in-band via an "error" key.
        if "error" in result:
            error_msg = f"Generation failed: {result['error']}"
            logger.error(error_msg)
            # FIX: always return exactly two values — this function feeds
            # two Gradio outputs (audio, metadata). The previous 3-tuple
            # (None, None, msg) mismatched the output count on error paths.
            return None, {"error": error_msg}

        audio_list = result.get("generated_audio", [])
        sample_rate = result.get("sample_rate", 32000)
        if not audio_list:
            return None, {"error": "No audio generated"}

        # gr.Audio with type="numpy" expects (sample_rate, np.ndarray).
        audio_array = np.array(audio_list, dtype=np.float32)

        # Surface the handler's bookkeeping so the UI can display it,
        # falling back to the request values when a key is absent.
        metadata = {
            "prompt": result.get("prompt", prompt),
            "formatted_prompt": result.get("formatted_prompt", ""),
            "duration": result.get("duration", duration),
            "sample_rate": sample_rate,
            "actual_samples": result.get("actual_samples", 0),
            "expected_samples": result.get("expected_samples", 0),
            "generation_method": result.get("generation_method", ""),
            "parameters": result.get("parameters", {})
        }
        logger.info(f"Generation successful: {len(audio_array)} samples")
        return (sample_rate, audio_array), metadata
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        logger.error(error_msg, exc_info=True)
        return None, {"error": error_msg}
def api_generate(request: Dict[str, Any]) -> Dict[str, Any]:
    """
    API endpoint that mirrors handler.py exactly.

    Forwards the raw request dict straight to the handler so programmatic
    callers get the identical request/response contract as the Inference
    Endpoint.

    Args:
        request: Payload in handler.py's format, e.g.
            ``{"inputs": {"prompt": ..., "duration": ...}, "parameters": {...}}``.

    Returns:
        The handler's response dict on success; on failure, a dict with an
        "error" message plus empty audio fields so callers always receive
        the same response shape.
    """
    try:
        # Call handler directly with the request.
        return handler(request)
    except Exception as e:
        # Boundary handler: log the full traceback instead of silently
        # swallowing it, then return the error in the standard shape.
        logging.getLogger(__name__).exception("api_generate failed")
        return {
            "error": str(e),
            "generated_audio": [],
            "sample_rate": 32000
        }
# Create Gradio Interface.
# NOTE(review): the component default values below (duration=10.0,
# temperature=1.0, top_k=250, top_p=0.0, cfg_coef=3.0, extend_stride=18.0)
# mirror generate_music()'s defaults — keep the two in sync.
with gr.Blocks(title="MusicGen API - AudioCraft", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎵 MusicGen Music Generation
    Generate music using Meta's AudioCraft MusicGen model.
    **Two ways to use this Space:**
    1. **UI Below**: Interactive interface for testing
    2. **API Access**: Programmatic access (see API tab)
    """)

    with gr.Tabs():
        # --- Interactive generation tab ---
        with gr.Tab("Generate Music"):
            with gr.Row():
                with gr.Column():
                    # Free-text description forwarded as the handler prompt.
                    prompt_input = gr.Textbox(
                        label="Music Description",
                        placeholder="e.g., upbeat electronic dance music with heavy bass",
                        lines=3
                    )
                    # Hard cap of 300 s matches the "5 minutes" note in the
                    # API documentation tab below.
                    duration_input = gr.Slider(
                        minimum=0.5,
                        maximum=300,
                        value=10.0,
                        step=0.5,
                        label="Duration (seconds)"
                    )
                    # Sampling knobs, collapsed by default; all are passed
                    # verbatim into generate_music().
                    with gr.Accordion("Advanced Parameters", open=False):
                        temperature_input = gr.Slider(
                            minimum=0.1,
                            maximum=2.0,
                            value=1.0,
                            step=0.1,
                            label="Temperature"
                        )
                        cfg_coef_input = gr.Slider(
                            minimum=1.0,
                            maximum=10.0,
                            value=3.0,
                            step=0.5,
                            label="Guidance Scale (cfg_coef)"
                        )
                        top_k_input = gr.Slider(
                            minimum=1,
                            maximum=1000,
                            value=250,
                            step=1,
                            label="Top-K"
                        )
                        top_p_input = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.0,
                            step=0.05,
                            label="Top-P"
                        )
                        use_sampling_input = gr.Checkbox(
                            value=True,
                            label="Use Sampling"
                        )
                        extend_stride_input = gr.Slider(
                            minimum=1.0,
                            maximum=30.0,
                            value=18.0,
                            step=1.0,
                            label="Extend Stride (for long sequences)"
                        )
                    generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg")
                with gr.Column():
                    # type="numpy" means generate_music must return
                    # (sample_rate, np.ndarray) for this component.
                    audio_output = gr.Audio(
                        label="Generated Music",
                        type="numpy"
                    )
                    metadata_output = gr.JSON(
                        label="Generation Metadata"
                    )

            # Connect the button. Input order must match generate_music's
            # positional parameter order exactly.
            generate_btn.click(
                fn=generate_music,
                inputs=[
                    prompt_input,
                    duration_input,
                    temperature_input,
                    top_k_input,
                    top_p_input,
                    cfg_coef_input,
                    use_sampling_input,
                    extend_stride_input
                ],
                outputs=[audio_output, metadata_output]
            )

            # Clickable example rows: (prompt, duration) pairs only; the
            # advanced parameters keep their current values.
            gr.Examples(
                examples=[
                    ["upbeat electronic dance music with heavy bass", 10.0],
                    ["calm piano melody with soft strings", 15.0],
                    ["epic orchestral soundtrack with drums and brass", 20.0],
                    ["ambient atmospheric soundscape with synthesizers", 30.0],
                    ["jazz trio with piano, bass and drums", 15.0]
                ],
                inputs=[prompt_input, duration_input],
                label="Example Prompts"
            )

        # --- Static API documentation tab ---
        # NOTE(review): the URL format shown in this markdown
        # (huggingface.co/spaces/<user>/<space>/api/predict) does not look
        # like a reachable Gradio endpoint — Spaces are normally served at
        # https://<user>-<space>.hf.space. Confirm and update the docs text.
        with gr.Tab("API Documentation"):
            gr.Markdown("""
    ## 🔌 API Access
    This Space exposes a REST API that mirrors the exact functionality of `handler.py`.
    ### API Endpoint
    ```
    POST https://huggingface.co/spaces/YOUR-USERNAME/YOUR-SPACE-NAME/api/predict
    ```
    ### Python Example
    ```python
    from gradio_client import Client
    # Connect to your Space
    client = Client("YOUR-USERNAME/YOUR-SPACE-NAME")
    # Generate music
    result = client.predict(
        prompt="epic orchestral music",
        duration=30.0,
        temperature=1.0,
        top_k=250,
        top_p=0.0,
        cfg_coef=3.0,
        use_sampling=True,
        extend_stride=18.0,
        api_name="/predict"
    )
    print(result)  # Returns (audio, metadata)
    ```
    ### Direct HTTP API (matches handler.py format)
    ```python
    import requests
    API_URL = "https://huggingface.co/spaces/YOUR-USERNAME/YOUR-SPACE-NAME/api/predict"
    payload = {
        "data": [
            "epic orchestral music",  # prompt
            30.0,   # duration
            1.0,    # temperature
            250,    # top_k
            0.0,    # top_p
            3.0,    # cfg_coef
            True,   # use_sampling
            18.0    # extend_stride
        ]
    }
    response = requests.post(API_URL, json=payload)
    result = response.json()
    # Extract audio and metadata
    audio_data = result["data"][0]  # (sample_rate, audio_array)
    metadata = result["data"][1]    # JSON metadata
    ```
    ### cURL Example
    ```bash
    curl -X POST https://huggingface.co/spaces/YOUR-USERNAME/YOUR-SPACE-NAME/api/predict \\
        -H "Content-Type: application/json" \\
        -d '{
            "data": [
                "upbeat electronic dance music",
                10.0,
                1.0,
                250,
                0.0,
                3.0,
                true,
                18.0
            ]
        }'
    ```
    ### Response Format
    The API returns a tuple of:
    1. **Audio**: `(sample_rate, numpy_array)`
    2. **Metadata**: JSON with generation details
    ```json
    {
        "prompt": "epic orchestral music",
        "formatted_prompt": "epic orchestral music.",
        "duration": 30.0,
        "sample_rate": 32000,
        "actual_samples": 960000,
        "expected_samples": 960000,
        "generation_method": "audiocraft_native_continuation",
        "parameters": {
            "use_sampling": true,
            "top_k": 250,
            "top_p": 0.0,
            "temperature": 1.0,
            "cfg_coef": 3.0,
            "two_step_cfg": false
        }
    }
    ```
    ### Notes
    - The Space needs to be running (not sleeping) to respond to API calls
    - First request after sleep may take 30-60 seconds to load the model
    - Maximum duration: 300 seconds (5 minutes)
    - Responses include full metadata about the generation
    """)

    # Footer shown under both tabs.
    gr.Markdown("""
    ---
    **Model:** Meta AudioCraft MusicGen Large
    **GPU:** Required for reasonable inference times
    **API:** Fully accessible via Gradio Client or HTTP
    """)
# Launch the app
if __name__ == "__main__":
    # Queue requests so concurrent users wait their turn instead of
    # overloading the single model instance; at most 10 pending jobs.
    demo.queue(max_size=10)
    demo.launch(
        # Bind all interfaces on port 7860 (Gradio's default port) so the
        # server is reachable from outside its container; no public
        # share link is created.
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )