# musicgen/handler.py — AudioCraft MusicGen inference endpoint handler
import json
import logging
import torch
import numpy as np
from typing import Dict, List, Optional, Union
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
import torchaudio
# Configure logging for the whole process (INFO level so the per-request
# progress messages below are visible in endpoint logs).
logging.basicConfig(level=logging.INFO)
# Module-level logger used by EndpointHandler.
logger = logging.getLogger(__name__)
class EndpointHandler:
    """
    AudioCraft-based MusicGen handler with native segment-based generation.

    Requests up to ``max_segment_duration`` seconds are generated in a single
    pass.  Longer requests are generated as a chain of segments where each new
    segment is conditioned on the last ``overlap_duration`` seconds of the
    previous one via ``MusicGen.generate_continuation``, giving coherent
    transitions without manual cross-fading.
    """

    def __init__(self, path=""):
        """Initialize the MusicGen model using audiocraft.

        Args:
            path: Unused; present to satisfy the HF inference endpoint
                handler contract.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Initializing AudioCraft MusicGen on device: {self.device}")
        # Load MusicGen model using audiocraft
        model_name = "facebook/musicgen-large"
        self.model = MusicGen.get_pretrained(model_name, device=self.device)
        # Model specifications
        self.sample_rate = self.model.sample_rate  # 32000 Hz for musicgen-large
        self.max_segment_duration = 30.0  # model's per-call generation limit
        self.default_extend_stride = 18.0  # default stride for continuation
        logger.info("AudioCraft MusicGen initialized successfully")
        logger.info(f"Sample rate: {self.sample_rate}Hz")
        logger.info(f"Max segment duration: {self.max_segment_duration}s")

    def __call__(self, data: Dict) -> Dict:
        """
        Process the inference request with native audiocraft segment-based generation

        Expected input format:
        {
            "inputs": {
                "prompt": "description of music",
                "duration": 60.0  # Can be longer than 30 seconds
            },
            "parameters": {
                "temperature": 1.0,
                "top_k": 250,
                "top_p": 0.0,
                "cfg_coef": 3.0,
                "use_sampling": true,
                "extend_stride": 18.0  # Optional override
            }
        }

        Returns a dict with the generated mono audio as a list of floats,
        plus echo/diagnostic fields; on failure, a dict with an "error" key
        and an empty "generated_audio" list.
        """
        # Bind these before the try block so the except handler can always
        # reference them (previously `inputs` could be unbound on failure).
        inputs = data.get("inputs", {}) if isinstance(data, dict) else {}
        parameters = data.get("parameters", {}) if isinstance(data, dict) else {}
        try:
            # Get prompt and duration
            prompt = str(inputs.get("prompt", "")).strip()
            total_duration = float(inputs.get("duration", 10.0))
            # Validate inputs
            if not prompt:
                raise ValueError("Prompt cannot be empty")
            # Clamp duration to reasonable range (0.5 to 300 seconds)
            total_duration = max(0.5, min(total_duration, 300.0))
            # Format prompt for better results
            formatted_prompt = self._format_prompt(prompt)
            logger.info(f"Formatted prompt: {formatted_prompt}")
            # Extract generation parameters
            generation_params = self._extract_generation_params(parameters)
            extend_stride = float(parameters.get("extend_stride", self.default_extend_stride))
            logger.info(f"Generation params: {generation_params}")
            logger.info(f"Total duration: {total_duration}s, Extend stride: {extend_stride}s")
            # Generate audio using audiocraft's native continuation support
            if total_duration <= self.max_segment_duration:
                # Single segment generation
                logger.info(f"Single segment generation for {total_duration}s")
                audio_tensor = self._generate_single_segment(formatted_prompt, total_duration, generation_params)
            else:
                # Multi-segment generation with native continuation
                logger.info(f"Multi-segment generation for {total_duration}s")
                audio_tensor = self._generate_long_sequence_native(
                    formatted_prompt, total_duration, generation_params, extend_stride
                )
            audio_array = self._to_mono_array(audio_tensor)
            logger.info(f"Generated audio: {len(audio_array)} samples at {self.sample_rate}Hz")
            logger.info(f"Duration: {len(audio_array) / self.sample_rate:.2f} seconds")
            # Return in the expected format
            return {
                "generated_audio": audio_array.tolist(),
                "sample_rate": self.sample_rate,
                "prompt": prompt,
                "formatted_prompt": formatted_prompt,
                "duration": total_duration,
                "parameters": generation_params,
                "actual_samples": len(audio_array),
                "expected_samples": int(total_duration * self.sample_rate),
                "generation_method": (
                    "audiocraft_native_continuation"
                    if total_duration > self.max_segment_duration
                    else "audiocraft_single_segment"
                ),
            }
        except Exception as e:
            logger.error(f"Error during generation: {str(e)}", exc_info=True)
            return {
                "error": str(e),
                "generated_audio": [],
                "sample_rate": self.sample_rate,
                "prompt": inputs.get("prompt", ""),
                "duration": inputs.get("duration", 10.0),
            }

    def _to_mono_array(self, audio_tensor: torch.Tensor) -> np.ndarray:
        """Collapse a generated tensor to a mono float32 numpy array.

        Accepts [batch, channels, samples], [channels, samples] or [samples];
        takes the first batch element and first channel.
        """
        if audio_tensor.dim() == 3:
            # Remove batch dimension: [1, channels, samples] -> [channels, samples]
            audio_tensor = audio_tensor.squeeze(0)
        if audio_tensor.dim() == 2:
            # Take first channel if stereo: [channels, samples] -> [samples]
            audio_tensor = audio_tensor[0]
        return audio_tensor.cpu().float().numpy()

    def _generate_single_segment(self, prompt: str, duration: float, generation_params: Dict) -> torch.Tensor:
        """Generate a single segment (<= max_segment_duration) using audiocraft.

        Returns the raw model output tensor of shape [batch, channels, samples].
        """
        logger.info(f"Generating single segment: {duration}s")
        # Set generation parameters on the model
        self.model.set_generation_params(
            duration=duration,
            **generation_params
        )
        # Generate audio
        with torch.no_grad():
            audio_tensor = self.model.generate(descriptions=[prompt])
        return audio_tensor

    def _generate_long_sequence_native(self, prompt: str, total_duration: float,
                                       generation_params: Dict, extend_stride: float) -> torch.Tensor:
        """
        Generate sequences longer than the per-call limit via continuation.

        Each segment after the first is conditioned on the last
        ``overlap_duration`` seconds of the previous segment.
        ``generate_continuation`` returns the prompt audio followed by the
        new audio (its ``duration`` covers both), so each continuation
        contributes ``segment_duration - overlap_duration`` seconds of new
        material; the time accounting below tracks unique audio produced
        rather than advancing by ``extend_stride``.
        """
        logger.info(f"Starting native long sequence generation: {total_duration}s total")
        overlap_duration = 10.0  # seconds of previous audio used as the prompt
        segments = []
        generated = 0.0  # seconds of unique audio produced so far
        context_audio = None
        while generated < total_duration:
            remaining = total_duration - generated
            if context_audio is None:
                # First segment: everything generated is new audio.
                segment_duration = min(self.max_segment_duration, remaining)
            else:
                # Continuations re-render the overlap, so request
                # overlap + remaining, capped at the model limit.
                segment_duration = min(self.max_segment_duration, remaining + overlap_duration)
            logger.info(f"Generating segment at {generated}s, duration: {segment_duration}s")
            # Set generation parameters for this segment
            self.model.set_generation_params(
                duration=segment_duration,
                extend_stride=extend_stride,
                **generation_params
            )
            with torch.no_grad():
                if context_audio is None:
                    # First segment - text-only generation
                    audio_tensor = self.model.generate(descriptions=[prompt])
                    generated += segment_duration
                else:
                    # Subsequent segments - condition on the tail of the
                    # previous segment for a coherent transition.
                    overlap_samples = int(overlap_duration * self.sample_rate)
                    context_chunk = context_audio[:, :, -overlap_samples:]
                    audio_tensor = self.model.generate_continuation(
                        context_chunk,
                        self.sample_rate,
                        descriptions=[prompt],
                        progress=False
                    )
                    # Only the part beyond the prompt is new audio.
                    generated += segment_duration - overlap_duration
            segments.append(audio_tensor)
            context_audio = audio_tensor
        return self._combine_segments_audiocraft_style(segments, extend_stride, overlap_duration)

    def _combine_segments_audiocraft_style(self, segments: List[torch.Tensor],
                                           extend_stride: float, overlap_duration: float) -> torch.Tensor:
        """
        Concatenate continuation segments into one waveform.

        ``generate_continuation`` output begins with the audio prompt it was
        conditioned on, so every segment after the first has its leading
        ``overlap_duration`` seconds dropped — that audio is already present
        at the end of the running result.  (The previous implementation
        computed a trim amount from ``extend_stride`` but then trimmed an
        unrelated amount from the wrong end.)

        Args:
            segments: Generated tensors of shape [batch, channels, samples].
            extend_stride: Unused; kept for signature compatibility.
            overlap_duration: Seconds of prompt audio at the head of each
                continuation segment.
        """
        logger.info(f"Combining {len(segments)} segments")
        if len(segments) == 1:
            return segments[0]
        overlap_samples = int(overlap_duration * self.sample_rate)
        combined_audio = segments[0]
        for segment in segments[1:]:
            # Drop the re-rendered prompt audio from the head of the segment
            # to avoid duplication, then append the genuinely new audio.
            combined_audio = torch.cat(
                [combined_audio, segment[:, :, overlap_samples:]], dim=-1
            )
        return combined_audio

    def _format_prompt(self, prompt: str) -> str:
        """Format the prompt for optimal MusicGen results.

        Lowercases, collapses excessive punctuation, and guarantees a
        terminal punctuation mark.
        """
        formatted = prompt.lower().strip()
        # Remove excessive punctuation
        formatted = formatted.replace("...", ",").replace("!!", "!").replace("??", "?")
        # Ensure proper ending
        if not formatted.endswith(('.', '!', '?', ',')):
            formatted += '.'
        return formatted

    def _extract_generation_params(self, parameters: Dict) -> Dict:
        """Extract and validate generation parameters for audiocraft.

        Maps HF-style names (``guidance_scale``, ``do_sample``) onto
        audiocraft names (``cfg_coef``, ``use_sampling``), clamps each value
        to a safe range, and silently ignores unknown keys.
        """
        # AudioCraft parameter defaults
        defaults = {
            "use_sampling": True,
            "top_k": 250,
            "top_p": 0.0,
            "temperature": 1.0,
            "cfg_coef": 3.0,
            "two_step_cfg": False,
        }
        # Map parameters from our format to audiocraft format
        param_mapping = {
            "guidance_scale": "cfg_coef",
            "do_sample": "use_sampling",
        }
        generation_params = defaults.copy()
        for key, value in parameters.items():
            # Map parameter names
            target_key = param_mapping.get(key, key)
            if target_key in generation_params:
                # Validate parameter values
                if target_key == "cfg_coef":
                    generation_params[target_key] = max(1.0, min(float(value), 10.0))
                elif target_key == "temperature":
                    generation_params[target_key] = max(0.1, min(float(value), 2.0))
                elif target_key == "top_k":
                    generation_params[target_key] = max(1, min(int(value), 1000))
                elif target_key == "top_p":
                    generation_params[target_key] = max(0.0, min(float(value), 1.0))
                elif target_key == "use_sampling":
                    generation_params[target_key] = bool(value)
                else:
                    generation_params[target_key] = value
        return generation_params