import json
import logging
from typing import Dict, List, Optional, Union

import numpy as np
import torch

# Heavy audio dependencies are guarded so the module can still be imported
# (for tooling, linting, or unit tests) in environments where they are not
# installed; __init__ raises a clear ImportError in that case. Behavior in a
# fully provisioned deployment is unchanged.
try:
    import torchaudio  # noqa: F401  -- unused here; kept for environment parity
    from audiocraft.data.audio import audio_write  # noqa: F401
    from audiocraft.models import MusicGen
except ImportError:
    MusicGen = None

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    """
    AudioCraft-based MusicGen handler with native segment-based generation.

    Supports proper continuation for long sequences with coherent transitions.
    """

    def __init__(self, path: str = ""):
        """Initialize the MusicGen model using audiocraft.

        Args:
            path: Unused; present for Hugging Face inference-endpoint handler
                compatibility (the loader passes the model path).

        Raises:
            ImportError: If audiocraft/torchaudio could not be imported.
        """
        if MusicGen is None:
            raise ImportError(
                "audiocraft is required to run EndpointHandler but could not be imported"
            )

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Initializing AudioCraft MusicGen on device: {self.device}")

        # Load MusicGen model using audiocraft
        model_name = "facebook/musicgen-large"
        self.model = MusicGen.get_pretrained(model_name, device=self.device)

        # Model specifications
        self.sample_rate = self.model.sample_rate  # Should be 32000 for musicgen-large
        self.max_segment_duration = 30.0  # Maximum duration per segment
        self.default_extend_stride = 18.0  # Optimal stride for continuation

        logger.info("AudioCraft MusicGen initialized successfully")
        logger.info(f"Sample rate: {self.sample_rate}Hz")
        logger.info(f"Max segment duration: {self.max_segment_duration}s")

    def __call__(self, data: Dict) -> Dict:
        """
        Process the inference request with native audiocraft segment-based generation.

        Expected input format:
        {
            "inputs": {
                "prompt": "description of music",
                "duration": 60.0  # Can be longer than 30 seconds
            },
            "parameters": {
                "temperature": 1.0,
                "top_k": 250,
                "top_p": 0.0,
                "cfg_coef": 3.0,
                "use_sampling": true,
                "extend_stride": 18.0  # Optional override
            }
        }

        Returns:
            On success: a dict with the generated mono audio as a list of
            floats plus echo/bookkeeping fields. On failure: a dict with an
            "error" key and an empty "generated_audio" list (never raises).
        """
        # Pre-bind so the except-branch can reference it even when extraction
        # itself raises (previously `inputs` could be unbound there, turning
        # the real error into a NameError).
        inputs = {}
        try:
            # Extract inputs and parameters
            inputs = data.get("inputs", {})
            parameters = data.get("parameters", {})

            # Get prompt and duration
            prompt = inputs.get("prompt", "").strip()
            total_duration = float(inputs.get("duration", 10.0))

            # Validate inputs
            if not prompt:
                raise ValueError("Prompt cannot be empty")

            # Clamp duration to reasonable range (0.5 to 300 seconds)
            total_duration = max(0.5, min(total_duration, 300.0))

            # Format prompt for better results
            formatted_prompt = self._format_prompt(prompt)
            logger.info(f"Formatted prompt: {formatted_prompt}")

            # Extract generation parameters
            generation_params = self._extract_generation_params(parameters)
            extend_stride = parameters.get("extend_stride", self.default_extend_stride)

            logger.info(f"Generation params: {generation_params}")
            logger.info(f"Total duration: {total_duration}s, Extend stride: {extend_stride}s")

            # Generate audio using audiocraft's native continuation support
            if total_duration <= self.max_segment_duration:
                # Single segment generation
                logger.info(f"Single segment generation for {total_duration}s")
                audio_tensor = self._generate_single_segment(
                    formatted_prompt, total_duration, generation_params
                )
            else:
                # Multi-segment generation with native continuation
                logger.info(f"Multi-segment generation for {total_duration}s")
                audio_tensor = self._generate_long_sequence_native(
                    formatted_prompt, total_duration, generation_params, extend_stride
                )

            # Collapse [batch, channels, samples] to a mono numpy array
            audio_array = self._tensor_to_mono_array(audio_tensor)

            logger.info(f"Generated audio: {len(audio_array)} samples at {self.sample_rate}Hz")
            logger.info(f"Duration: {len(audio_array) / self.sample_rate:.2f} seconds")

            # Return in the expected format
            return {
                "generated_audio": audio_array.tolist(),
                "sample_rate": self.sample_rate,
                "prompt": prompt,
                "formatted_prompt": formatted_prompt,
                "duration": total_duration,
                "parameters": generation_params,
                "actual_samples": len(audio_array),
                "expected_samples": int(total_duration * self.sample_rate),
                "generation_method": (
                    "audiocraft_native_continuation"
                    if total_duration > self.max_segment_duration
                    else "audiocraft_single_segment"
                ),
            }

        except Exception as e:
            logger.error(f"Error during generation: {str(e)}", exc_info=True)
            return {
                "error": str(e),
                "generated_audio": [],
                "sample_rate": self.sample_rate,
                "prompt": inputs.get("prompt", ""),
                "duration": inputs.get("duration", 10.0),
            }

    def _tensor_to_mono_array(self, audio_tensor: torch.Tensor) -> np.ndarray:
        """Convert a generated tensor ([B, C, T], [C, T] or [T]) to a mono float32 numpy array."""
        if audio_tensor.dim() == 3:
            # Remove batch dimension: [1, channels, samples] -> [channels, samples]
            audio_tensor = audio_tensor.squeeze(0)
        if audio_tensor.dim() == 2:
            # Take first channel if stereo: [channels, samples] -> [samples]
            audio_tensor = audio_tensor[0]
        return audio_tensor.cpu().float().numpy()

    def _generate_single_segment(
        self, prompt: str, duration: float, generation_params: Dict
    ) -> torch.Tensor:
        """Generate a single segment (<= max_segment_duration) using audiocraft.

        Args:
            prompt: Formatted text description of the music.
            duration: Segment length in seconds.
            generation_params: Validated kwargs for set_generation_params.

        Returns:
            The raw tensor returned by MusicGen.generate (batched).
        """
        logger.info(f"Generating single segment: {duration}s")

        # Set generation parameters on the model
        self.model.set_generation_params(duration=duration, **generation_params)

        # Generate audio
        with torch.no_grad():
            audio_tensor = self.model.generate(descriptions=[prompt])

        return audio_tensor

    def _generate_long_sequence_native(
        self,
        prompt: str,
        total_duration: float,
        generation_params: Dict,
        extend_stride: float,
    ) -> torch.Tensor:
        """
        Generate long sequences using audiocraft's native continuation support.

        Each segment after the first is primed with the tail of the previous
        segment via generate_continuation, which keeps transitions coherent
        without manual crossfade stitching.

        Args:
            prompt: Formatted text description of the music.
            total_duration: Total requested length in seconds (> 30s).
            generation_params: Validated kwargs for set_generation_params.
            extend_stride: Seconds of new audio attributed to each segment.

        Returns:
            A single concatenated audio tensor.
        """
        logger.info(f"Starting native long sequence generation: {total_duration}s total")

        segments = []
        current_time = 0.0
        context_audio = None
        overlap_duration = 10.0  # 10 seconds overlap for context

        while current_time < total_duration:
            remaining_time = total_duration - current_time
            segment_duration = min(self.max_segment_duration, remaining_time)

            logger.info(f"Generating segment at {current_time}s, duration: {segment_duration}s")

            # Set generation parameters for this segment
            self.model.set_generation_params(
                duration=segment_duration,
                extend_stride=extend_stride,
                **generation_params,
            )

            with torch.no_grad():
                if context_audio is None:
                    # First segment - text-only generation
                    audio_tensor = self.model.generate(descriptions=[prompt])
                else:
                    # Subsequent segments - prime the model with the last
                    # `overlap_duration` seconds of the previous segment
                    overlap_samples = int(overlap_duration * self.sample_rate)
                    context_chunk = context_audio[:, :, -overlap_samples:]
                    audio_tensor = self.model.generate_continuation(
                        context_chunk,
                        self.sample_rate,
                        descriptions=[prompt],
                        progress=False,
                    )

            segments.append(audio_tensor)

            # Prepare context for next segment
            if current_time + segment_duration < total_duration:
                # Use this segment as context for the next
                context_audio = audio_tensor
                # NOTE(review): time advances by extend_stride rather than by
                # segment_duration - overlap_duration; confirm this matches the
                # overlap accounting assumed in _combine_segments_audiocraft_style.
                current_time += extend_stride
            else:
                # Last segment
                current_time = total_duration

        # Combine segments using audiocraft's approach
        if len(segments) == 1:
            return segments[0]
        return self._combine_segments_audiocraft_style(segments, extend_stride, overlap_duration)

    def _combine_segments_audiocraft_style(
        self,
        segments: List[torch.Tensor],
        extend_stride: float,
        overlap_duration: float,
    ) -> torch.Tensor:
        """
        Combine segments generated with native continuation into one tensor.

        Each continuation segment was primed with the last `overlap_duration`
        seconds of the previous one, so that span is dropped from the running
        buffer before each concatenation to avoid duplicated audio.

        Args:
            segments: Batched audio tensors ([B, C, T]) in playback order.
            extend_stride: Stride in seconds (logged for traceability).
            overlap_duration: Seconds of context shared between segments.

        Returns:
            The concatenated audio tensor.
        """
        logger.info(f"Combining {len(segments)} segments with {extend_stride}s stride")

        if len(segments) == 1:
            return segments[0]

        # Start with first segment
        combined_audio = segments[0]
        overlap_samples = int(overlap_duration * self.sample_rate)

        for segment in segments[1:]:
            # BUGFIX: the guard previously compared against a stride-based
            # sample count while trimming `overlap_samples`, so the buffer
            # could be sliced even when it held fewer than `overlap_samples`
            # samples. Guard on the actual amount being trimmed instead.
            if combined_audio.shape[-1] > overlap_samples:
                combined_audio = combined_audio[:, :, :-overlap_samples]

            # Concatenate the new segment
            combined_audio = torch.cat([combined_audio, segment], dim=-1)

        return combined_audio

    def _format_prompt(self, prompt: str) -> str:
        """Normalize the prompt for optimal MusicGen results.

        Lower-cases and strips the text, collapses excessive punctuation,
        and guarantees a terminal punctuation mark.
        """
        formatted = prompt.lower().strip()

        # Remove excessive punctuation
        formatted = formatted.replace("...", ",").replace("!!", "!").replace("??", "?")

        # Ensure proper ending
        if not formatted.endswith(('.', '!', '?', ',')):
            formatted += '.'

        return formatted

    def _extract_generation_params(self, parameters: Dict) -> Dict:
        """Extract and validate generation parameters for audiocraft.

        Maps transformers-style names (guidance_scale, do_sample) onto the
        audiocraft equivalents (cfg_coef, use_sampling), clamps numeric
        values to safe ranges, and ignores unknown keys.

        Returns:
            A dict of kwargs suitable for MusicGen.set_generation_params.
        """
        # AudioCraft parameter mapping and defaults
        defaults = {
            "use_sampling": True,
            "top_k": 250,
            "top_p": 0.0,
            "temperature": 1.0,
            "cfg_coef": 3.0,
            "two_step_cfg": False,
        }

        # Map parameters from our format to audiocraft format
        param_mapping = {
            "guidance_scale": "cfg_coef",
            "do_sample": "use_sampling",
        }

        generation_params = defaults.copy()

        for key, value in parameters.items():
            # Map parameter names
            target_key = param_mapping.get(key, key)

            if target_key in generation_params:
                # Validate parameter values (clamp to sane ranges)
                if target_key == "cfg_coef":
                    generation_params[target_key] = max(1.0, min(float(value), 10.0))
                elif target_key == "temperature":
                    generation_params[target_key] = max(0.1, min(float(value), 2.0))
                elif target_key == "top_k":
                    generation_params[target_key] = max(1, min(int(value), 1000))
                elif target_key == "top_p":
                    generation_params[target_key] = max(0.0, min(float(value), 1.0))
                elif target_key == "use_sampling":
                    generation_params[target_key] = bool(value)
                else:
                    generation_params[target_key] = value

        return generation_params