import logging
from typing import Dict, List

import torch
from audiocraft.models import MusicGen

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class EndpointHandler:
    """
    AudioCraft-based MusicGen handler with segment-based generation.
    Uses the model's native continuation support so that sequences longer
    than a single 30-second window keep coherent transitions.
    """
    
    def __init__(self, path=""):
        """Initialize the MusicGen model using audiocraft"""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Initializing AudioCraft MusicGen on device: {self.device}")
        
        # Load MusicGen model using audiocraft
        model_name = "facebook/musicgen-large"
        self.model = MusicGen.get_pretrained(model_name, device=self.device)
        
        # Model specifications
        self.sample_rate = self.model.sample_rate  # Should be 32000 for musicgen-large
        self.max_segment_duration = 30.0  # Maximum duration per segment
        self.default_extend_stride = 18.0  # Stride between continuation windows; must stay below max_segment_duration
        
        logger.info(f"AudioCraft MusicGen initialized successfully")
        logger.info(f"Sample rate: {self.sample_rate}Hz")
        logger.info(f"Max segment duration: {self.max_segment_duration}s")
    
    def __call__(self, data: Dict) -> Dict:
        """
        Process the inference request with native audiocraft segment-based generation
        
        Expected input format:
        {
            "inputs": {
                "prompt": "description of music",
                "duration": 60.0  # Can be longer than 30 seconds
            },
            "parameters": {
                "temperature": 1.0,
                "top_k": 250,
                "top_p": 0.0,
                "cfg_coef": 3.0,
                "use_sampling": true,
                "extend_stride": 18.0  # Optional override
            }
        }
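        
        Returns a dict containing "generated_audio" (a list of float samples),
        "sample_rate", the original and formatted prompts, the clamped
        "duration", the resolved generation parameters, and actual/expected
        sample counts; on failure, an "error" key is set instead.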
        """
        inputs = {}
        try:
            # Extract inputs and parameters
            inputs = data.get("inputs", {})
            parameters = data.get("parameters", {})
            
            # Get prompt and duration
            prompt = inputs.get("prompt", "").strip()
            total_duration = float(inputs.get("duration", 10.0))
            
            # Validate inputs
            if not prompt:
                raise ValueError("Prompt cannot be empty")
            
            # Clamp duration to reasonable range (0.5 to 300 seconds)
            total_duration = max(0.5, min(total_duration, 300.0))
            
            # Format prompt for better results
            formatted_prompt = self._format_prompt(prompt)
            logger.info(f"Formatted prompt: {formatted_prompt}")
            
            # Extract generation parameters
            generation_params = self._extract_generation_params(parameters)
            extend_stride = parameters.get("extend_stride", self.default_extend_stride)
            
            logger.info(f"Generation params: {generation_params}")
            logger.info(f"Total duration: {total_duration}s, Extend stride: {extend_stride}s")
            
            # Generate audio using audiocraft's native continuation support
            if total_duration <= self.max_segment_duration:
                # Single segment generation
                logger.info(f"Single segment generation for {total_duration}s")
                audio_tensor = self._generate_single_segment(formatted_prompt, total_duration, generation_params)
            else:
                # Multi-segment generation with native continuation
                logger.info(f"Multi-segment generation for {total_duration}s")
                audio_tensor = self._generate_long_sequence_native(
                    formatted_prompt, total_duration, generation_params, extend_stride
                )
            
            # Convert to numpy array
            if audio_tensor.dim() == 3:
                # Remove batch dimension: [1, channels, samples] -> [channels, samples]
                audio_tensor = audio_tensor.squeeze(0)
            
            if audio_tensor.dim() == 2:
                # Take first channel if stereo: [channels, samples] -> [samples]
                audio_array = audio_tensor[0].cpu().float().numpy()
            else:
                # Already mono: [samples]
                audio_array = audio_tensor.cpu().float().numpy()
            
            logger.info(f"Generated audio: {len(audio_array)} samples at {self.sample_rate}Hz")
            logger.info(f"Duration: {len(audio_array) / self.sample_rate:.2f} seconds")
            
            # Return in the expected format
            return {
                "generated_audio": audio_array.tolist(),
                "sample_rate": self.sample_rate,
                "prompt": prompt,
                "formatted_prompt": formatted_prompt,
                "duration": total_duration,
                "parameters": generation_params,
                "actual_samples": len(audio_array),
                "expected_samples": int(total_duration * self.sample_rate),
                "generation_method": "audiocraft_native_continuation" if total_duration > self.max_segment_duration else "audiocraft_single_segment"
            }
            
        except Exception as e:
            logger.error(f"Error during generation: {str(e)}", exc_info=True)
            return {
                "error": str(e),
                "generated_audio": [],
                "sample_rate": self.sample_rate,
                "prompt": inputs.get("prompt", ""),
                "duration": inputs.get("duration", 10.0)
            }
    
    def _generate_single_segment(self, prompt: str, duration: float, generation_params: Dict) -> torch.Tensor:
        """Generate a single segment using audiocraft"""
        logger.info(f"Generating single segment: {duration}s")
        
        # Set generation parameters on the model
        self.model.set_generation_params(
            duration=duration,
            **generation_params
        )
        
        # Generate audio
        with torch.no_grad():
            audio_tensor = self.model.generate(descriptions=[prompt])
        
        return audio_tensor
    
    def _generate_long_sequence_native(self, prompt: str, total_duration: float,
                                       generation_params: Dict, extend_stride: float) -> torch.Tensor:
        """
        Generate long sequences using audiocraft's native continuation support.
        Each window is conditioned on the tail of the previous one, which
        yields coherent transitions without manual crossfading or stitching.
        """
        logger.info(f"Starting native long sequence generation: {total_duration}s total")
        
        segments = []
        generated_time = 0.0
        context_audio = None
        overlap_duration = 10.0  # seconds of previous audio re-used as context
        overlap_samples = int(overlap_duration * self.sample_rate)
        
        while generated_time < total_duration:
            remaining_time = total_duration - generated_time
            if context_audio is None:
                # First window contributes its full duration
                segment_duration = min(self.max_segment_duration, remaining_time)
            else:
                # Continuation windows begin with the overlap, so they only
                # contribute (segment_duration - overlap_duration) new seconds
                segment_duration = min(self.max_segment_duration,
                                       remaining_time + overlap_duration)
            
            logger.info(f"Generating segment at {generated_time}s, window duration: {segment_duration}s")
            
            # Set generation parameters for this window
            self.model.set_generation_params(
                duration=segment_duration,
                extend_stride=extend_stride,
                **generation_params
            )
            
            with torch.no_grad():
                if context_audio is None:
                    # First segment - text-only generation
                    audio_tensor = self.model.generate(descriptions=[prompt])
                    generated_time += segment_duration
                else:
                    # Subsequent segments - continue from the tail of the
                    # previous segment; generate_continuation returns the
                    # context followed by the newly generated audio
                    context_chunk = context_audio[:, :, -overlap_samples:]
                    audio_tensor = self.model.generate_continuation(
                        context_chunk,
                        self.sample_rate,
                        descriptions=[prompt],
                        progress=False
                    )
                    generated_time += segment_duration - overlap_duration
            
            segments.append(audio_tensor)
            context_audio = audio_tensor
        
        # Combine segments, dropping the re-generated overlap from each
        # continuation so the context region is not duplicated
        if len(segments) == 1:
            return segments[0]
        return self._combine_segments_audiocraft_style(segments, overlap_samples)
    
    def _combine_segments_audiocraft_style(self, segments: List[torch.Tensor],
                                           overlap_samples: int) -> torch.Tensor:
        """
        Combine segments produced by generate_continuation.
        Each continuation segment starts with the overlap it was conditioned
        on, so that span is dropped before concatenation to avoid duplicating
        audio at the seams.
        """
        logger.info(f"Combining {len(segments)} segments ({overlap_samples} overlap samples)")
        
        if len(segments) == 1:
            return segments[0]
        
        # Start with the first segment, then append only the newly generated
        # portion of each continuation
        combined_audio = segments[0]
        for segment in segments[1:]:
            combined_audio = torch.cat(
                [combined_audio, segment[..., overlap_samples:]], dim=-1
            )
        
        return combined_audio
    
    def _format_prompt(self, prompt: str) -> str:
        """Format the prompt for optimal MusicGen results"""
        formatted = prompt.lower().strip()
        
        # Remove excessive punctuation
        formatted = formatted.replace("...", ",").replace("!!", "!").replace("??", "?")
        
        # Ensure proper ending
        if not formatted.endswith(('.', '!', '?', ',')):
            formatted += '.'
        
        return formatted
    
    def _extract_generation_params(self, parameters: Dict) -> Dict:
        """Extract and validate generation parameters for audiocraft"""
        # AudioCraft parameter mapping and defaults
        defaults = {
            "use_sampling": True,
            "top_k": 250,
            "top_p": 0.0,
            "temperature": 1.0,
            "cfg_coef": 3.0,
            "two_step_cfg": False,
        }
        
        # Map parameters from our format to audiocraft format
        param_mapping = {
            "guidance_scale": "cfg_coef",
            "do_sample": "use_sampling",
        }
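        # Example: {"guidance_scale": 4.0, "do_sample": True} maps to
        # {"cfg_coef": 4.0, "use_sampling": True}; keys that do not resolve
        # to a known audiocraft parameter are ignored below.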
        
        generation_params = defaults.copy()
        
        for key, value in parameters.items():
            # Map parameter names
            target_key = param_mapping.get(key, key)
            
            if target_key in generation_params:
                # Validate parameter values
                if target_key == "cfg_coef":
                    generation_params[target_key] = max(1.0, min(float(value), 10.0))
                elif target_key == "temperature":
                    generation_params[target_key] = max(0.1, min(float(value), 2.0))
                elif target_key == "top_k":
                    generation_params[target_key] = max(1, min(int(value), 1000))
                elif target_key == "top_p":
                    generation_params[target_key] = max(0.0, min(float(value), 1.0))
                elif target_key == "use_sampling":
                    generation_params[target_key] = bool(value)
                else:
                    generation_params[target_key] = value
        
        return generation_params
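

# A minimal local smoke test (an illustrative sketch, not part of the handler
# contract): it assumes an environment with audiocraft and torch installed and
# simply exercises __call__ with the request shape documented above. The short
# 5-second duration keeps the run cheap; the prompt is a made-up example.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler({
        "inputs": {"prompt": "calm lo-fi beat with soft piano", "duration": 5.0},
        "parameters": {"temperature": 1.0, "cfg_coef": 3.0},
    })
    if "error" in result:
        print(f"Generation failed: {result['error']}")
    else:
        print(f"Generated {result['actual_samples']} samples at {result['sample_rate']}Hz")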