# NOTE: the original capture began with Hugging Face Spaces page furniture
# ("Spaces: Sleeping") — UI status text, not part of the handler source.
| import json | |
| import logging | |
| import torch | |
| import numpy as np | |
| from typing import Dict, List, Optional, Union | |
| from audiocraft.models import MusicGen | |
| from audiocraft.data.audio import audio_write | |
| import torchaudio | |
# Configure logging: INFO-level root configuration plus a module-level
# logger shared by EndpointHandler below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EndpointHandler:
    """
    AudioCraft-based MusicGen handler with native segment-based generation.

    Supports continuation for requests longer than MusicGen's 30-second
    per-call limit by chaining ``generate_continuation`` calls and merging
    the resulting segments. Instances are callable with the standard
    Hugging Face Inference Endpoint request dict (see ``__call__``).
    """

    def __init__(self, path=""):
        """Load the pretrained MusicGen model and record its audio specs.

        Args:
            path: Unused; accepted for Inference Endpoint handler
                compatibility (the toolkit passes the model directory here).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Initializing AudioCraft MusicGen on device: {self.device}")

        # Load MusicGen model using audiocraft.
        model_name = "facebook/musicgen-large"
        self.model = MusicGen.get_pretrained(model_name, device=self.device)

        # Model specifications.
        self.sample_rate = self.model.sample_rate  # 32000 Hz for musicgen-large
        self.max_segment_duration = 30.0  # MusicGen's native per-call limit (seconds)
        self.default_extend_stride = 18.0  # stride between continuation windows (seconds)

        logger.info("AudioCraft MusicGen initialized successfully")
        logger.info(f"Sample rate: {self.sample_rate}Hz")
        logger.info(f"Max segment duration: {self.max_segment_duration}s")

    def __call__(self, data: Dict) -> Dict:
        """
        Process an inference request with native audiocraft segment-based
        generation.

        Expected input format::

            {
                "inputs": {
                    "prompt": "description of music",
                    "duration": 60.0   # may exceed 30 seconds
                },
                "parameters": {
                    "temperature": 1.0,
                    "top_k": 250,
                    "top_p": 0.0,
                    "cfg_coef": 3.0,
                    "use_sampling": true,
                    "extend_stride": 18.0   # optional override
                }
            }

        Returns:
            A JSON-serializable dict with ``generated_audio`` (mono samples
            as a list of floats), ``sample_rate`` and request metadata; on
            failure, a dict with an ``error`` key and an empty audio list.
        """
        # Bind before the try-block so the except-handler can read `inputs`
        # safely even if extraction itself raises (e.g. `data` is not a dict).
        # Previously a failure there caused a NameError inside the handler.
        inputs: Dict = {}
        try:
            inputs = data.get("inputs", {})
            parameters = data.get("parameters", {})

            prompt = inputs.get("prompt", "").strip()
            total_duration = float(inputs.get("duration", 10.0))

            if not prompt:
                raise ValueError("Prompt cannot be empty")

            # Clamp duration to a reasonable range (0.5 to 300 seconds).
            total_duration = max(0.5, min(total_duration, 300.0))

            # Normalize the prompt for better results.
            formatted_prompt = self._format_prompt(prompt)
            logger.info(f"Formatted prompt: {formatted_prompt}")

            generation_params = self._extract_generation_params(parameters)
            extend_stride = parameters.get("extend_stride", self.default_extend_stride)
            logger.info(f"Generation params: {generation_params}")
            logger.info(f"Total duration: {total_duration}s, Extend stride: {extend_stride}s")

            # Short requests fit in one model call; longer ones are chained
            # through audiocraft's native continuation support.
            if total_duration <= self.max_segment_duration:
                logger.info(f"Single segment generation for {total_duration}s")
                audio_tensor = self._generate_single_segment(formatted_prompt, total_duration, generation_params)
            else:
                logger.info(f"Multi-segment generation for {total_duration}s")
                audio_tensor = self._generate_long_sequence_native(
                    formatted_prompt, total_duration, generation_params, extend_stride
                )

            # Collapse to a mono 1-D numpy array: drop the batch dimension,
            # then keep only the first channel if the model returned stereo.
            if audio_tensor.dim() == 3:
                audio_tensor = audio_tensor.squeeze(0)  # [1, C, T] -> [C, T]
            if audio_tensor.dim() == 2:
                audio_array = audio_tensor[0].cpu().float().numpy()  # [C, T] -> [T]
            else:
                audio_array = audio_tensor.cpu().float().numpy()  # already [T]

            logger.info(f"Generated audio: {len(audio_array)} samples at {self.sample_rate}Hz")
            logger.info(f"Duration: {len(audio_array) / self.sample_rate:.2f} seconds")

            return {
                "generated_audio": audio_array.tolist(),
                "sample_rate": self.sample_rate,
                "prompt": prompt,
                "formatted_prompt": formatted_prompt,
                "duration": total_duration,
                "parameters": generation_params,
                "actual_samples": len(audio_array),
                "expected_samples": int(total_duration * self.sample_rate),
                "generation_method": "audiocraft_native_continuation" if total_duration > self.max_segment_duration else "audiocraft_single_segment"
            }
        except Exception as e:
            logger.error(f"Error during generation: {str(e)}", exc_info=True)
            return {
                "error": str(e),
                "generated_audio": [],
                "sample_rate": self.sample_rate,
                "prompt": inputs.get("prompt", ""),
                "duration": inputs.get("duration", 10.0)
            }

    def _generate_single_segment(self, prompt: str, duration: float, generation_params: Dict) -> torch.Tensor:
        """Generate one <=30s segment from text only.

        Returns the raw model output tensor (batch, channels, samples).
        """
        logger.info(f"Generating single segment: {duration}s")
        self.model.set_generation_params(
            duration=duration,
            **generation_params
        )
        with torch.no_grad():
            audio_tensor = self.model.generate(descriptions=[prompt])
        return audio_tensor

    def _generate_long_sequence_native(self, prompt: str, total_duration: float,
                                       generation_params: Dict, extend_stride: float) -> torch.Tensor:
        """
        Generate long sequences using audiocraft's native continuation support.

        The first segment is text-only; each subsequent segment is conditioned
        on the last ``overlap_duration`` seconds of the previous one via
        ``generate_continuation``, which keeps transitions coherent without
        manual cross-fading.

        NOTE(review): the loop advances ``current_time`` by ``extend_stride``
        while each call can emit up to ``max_segment_duration`` of audio, so
        the combined output length only approximates ``total_duration`` —
        confirm this accounting against the intended contract.
        """
        logger.info(f"Starting native long sequence generation: {total_duration}s total")

        segments = []
        current_time = 0.0
        context_audio = None
        overlap_duration = 10.0  # seconds of previous audio used as the continuation prompt

        while current_time < total_duration:
            remaining_time = total_duration - current_time
            segment_duration = min(self.max_segment_duration, remaining_time)
            logger.info(f"Generating segment at {current_time}s, duration: {segment_duration}s")

            self.model.set_generation_params(
                duration=segment_duration,
                extend_stride=extend_stride,
                **generation_params
            )

            with torch.no_grad():
                if context_audio is None:
                    # First segment: text-only generation.
                    audio_tensor = self.model.generate(descriptions=[prompt])
                else:
                    # Subsequent segments: condition on the tail of the
                    # previous segment so the music continues coherently.
                    overlap_samples = int(overlap_duration * self.sample_rate)
                    context_chunk = context_audio[:, :, -overlap_samples:]
                    audio_tensor = self.model.generate_continuation(
                        context_chunk,
                        self.sample_rate,
                        descriptions=[prompt],
                        progress=False
                    )

            segments.append(audio_tensor)

            if current_time + segment_duration < total_duration:
                # More to generate: this segment becomes the next context.
                context_audio = audio_tensor
                current_time += extend_stride
            else:
                # Last segment; terminate the loop.
                current_time = total_duration

        if len(segments) == 1:
            return segments[0]
        return self._combine_segments_audiocraft_style(segments, extend_stride, overlap_duration)

    def _combine_segments_audiocraft_style(self, segments: List[torch.Tensor],
                                           extend_stride: float, overlap_duration: float) -> torch.Tensor:
        """
        Merge continuation segments into one waveform along the time axis.

        ``generate_continuation`` returns the conditioning prompt followed by
        the newly generated audio, so each segment after the first re-contains
        the last ``overlap_duration`` seconds of the accumulated audio. That
        tail is dropped before concatenation so the overlap appears exactly
        once in the output.

        BUG FIX: the previous implementation computed a stride-based trim
        length but then unconditionally sliced off ``overlap_samples``; the
        guard now matches the slice actually performed, and the slice works
        for any tensor rank.
        """
        logger.info(f"Combining {len(segments)} segments with {extend_stride}s stride")

        if len(segments) == 1:
            return segments[0]

        combined_audio = segments[0]
        overlap_samples = int(overlap_duration * self.sample_rate)

        for segment in segments[1:]:
            # Drop the duplicated overlap tail (only if there is one to drop).
            if overlap_samples > 0 and combined_audio.shape[-1] > overlap_samples:
                combined_audio = combined_audio[..., :-overlap_samples]
            combined_audio = torch.cat([combined_audio, segment], dim=-1)

        return combined_audio

    def _format_prompt(self, prompt: str) -> str:
        """Normalize a prompt: lowercase, trim, soften excessive punctuation,
        and guarantee a terminal punctuation mark."""
        formatted = prompt.lower().strip()
        # Remove excessive punctuation.
        formatted = formatted.replace("...", ",").replace("!!", "!").replace("??", "?")
        # Ensure proper ending.
        if not formatted.endswith(('.', '!', '?', ',')):
            formatted += '.'
        return formatted

    def _extract_generation_params(self, parameters: Dict) -> Dict:
        """Extract and validate generation parameters for audiocraft.

        Unknown keys are ignored; ``guidance_scale``/``do_sample`` aliases are
        mapped to audiocraft's ``cfg_coef``/``use_sampling``, and numeric
        values are clamped to safe ranges.
        """
        # AudioCraft parameter defaults.
        defaults = {
            "use_sampling": True,
            "top_k": 250,
            "top_p": 0.0,
            "temperature": 1.0,
            "cfg_coef": 3.0,
            "two_step_cfg": False,
        }

        # Map transformers-style names onto audiocraft names.
        param_mapping = {
            "guidance_scale": "cfg_coef",
            "do_sample": "use_sampling",
        }

        generation_params = defaults.copy()
        for key, value in parameters.items():
            target_key = param_mapping.get(key, key)
            if target_key in generation_params:
                # Clamp/coerce each recognized parameter to its valid range.
                if target_key == "cfg_coef":
                    generation_params[target_key] = max(1.0, min(float(value), 10.0))
                elif target_key == "temperature":
                    generation_params[target_key] = max(0.1, min(float(value), 2.0))
                elif target_key == "top_k":
                    generation_params[target_key] = max(1, min(int(value), 1000))
                elif target_key == "top_p":
                    generation_params[target_key] = max(0.0, min(float(value), 1.0))
                elif target_key == "use_sampling":
                    generation_params[target_key] = bool(value)
                else:
                    # e.g. two_step_cfg: stored as given.
                    generation_params[target_key] = value

        return generation_params