File size: 12,884 Bytes
a602628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee19acb
a602628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4709141
 
 
 
052ca84
a602628
 
4709141
 
a602628
 
052ca84
a602628
 
 
ee19acb
a602628
ee19acb
 
 
a602628
 
ee19acb
 
a602628
 
ee19acb
 
fa7f63d
ee19acb
a602628
 
ee19acb
 
a602628
ee19acb
 
a602628
 
 
 
 
 
78910e3
a602628
 
 
 
 
 
 
 
78910e3
 
 
a602628
 
 
78910e3
a602628
 
6a590ee
 
a602628
 
6a590ee
a602628
 
 
 
 
 
 
 
 
78910e3
 
a602628
 
 
 
78910e3
a602628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
052ca84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a602628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
052ca84
 
a602628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
052ca84
 
a602628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
"""
ACE-Step Engine - Wrapper for ACE-Step 1.5 official architecture
Properly integrates AceStepHandler (DiT) and LLMHandler (5Hz LM)
"""

import torch
from pathlib import Path
import logging
from typing import Optional, Dict, Any, Tuple
import os

logger = logging.getLogger(__name__)

# Import ACE-Step 1.5 official handlers. If the acestep package is not
# installed, ACE_STEP_AVAILABLE is set False and the engine degrades to a
# non-functional stub (see ACEStepEngine.__init__, which logs and returns
# early instead of raising).
try:
    from acestep.handler import AceStepHandler
    from acestep.llm_inference import LLMHandler
    from acestep.inference import GenerationParams, GenerationConfig, generate_music
    from acestep.model_downloader import ensure_main_model, get_checkpoints_dir, check_main_model_exists
    ACE_STEP_AVAILABLE = True
except ImportError as e:
    logger.warning(f"ACE-Step 1.5 modules not available: {e}")
    ACE_STEP_AVAILABLE = False


class ACEStepEngine:
    """Wrapper engine for ACE-Step 1.5 with a custom interface.

    Lazily initializes the official AceStepHandler (DiT diffusion model) and
    LLMHandler (5Hz language model) on first use, so the engine can be
    constructed cheaply and CUDA detection happens inside a GPU context
    (required for ZeroGPU compatibility).
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize ACE-Step engine (no models are loaded here).

        Args:
            config: Configuration dictionary. Recognized keys:
                ``checkpoint_dir``, ``dit_model_path``, ``lm_model_path``,
                ``output_dir``.
        """
        self.config = config
        self._initialized = False
        self.dit_handler = None
        self.llm_handler = None
        # Resolved lazily in _ensure_models_loaded() so CUDA probing happens
        # inside a GPU context on ZeroGPU; None until first use.
        self.device: Optional[torch.device] = None

        logger.info("ACE-Step Engine created (GPU will be detected on first use)")

        if not ACE_STEP_AVAILABLE:
            # Degrade to a stub rather than raising: callers can still
            # construct the engine and inspect is_initialized().
            logger.error("ACE-Step 1.5 modules not available")
            logger.error("Please ensure acestep package is installed in your environment")
            return

        logger.info("✓ ACE-Step Engine created (models will load on first use)")

    def _download_checkpoints(self):
        """Download model checkpoints from HuggingFace if not present.

        Raises:
            RuntimeError: If the downloader reports failure.
        """
        checkpoints_dir = get_checkpoints_dir(self.config.get("checkpoint_dir"))

        # Skip the download entirely when the main model is already on disk.
        if check_main_model_exists(checkpoints_dir):
            logger.info(f"✓ ACE-Step 1.5 models already exist at {checkpoints_dir}")
            return

        logger.info("Downloading ACE-Step 1.5 models from HuggingFace...")
        logger.info("This may take several minutes (models are ~7GB total)...")

        try:
            # Use the built-in model downloader.
            success, message = ensure_main_model(
                checkpoints_dir=checkpoints_dir,
                prefer_source="huggingface"  # Use HuggingFace for Spaces
            )

            if not success:
                raise RuntimeError(f"Failed to download models: {message}")

            logger.info(f"✓ {message}")
            logger.info("✓ All ACE-Step 1.5 models downloaded successfully")

        except Exception as e:
            logger.error(f"Failed to download checkpoints: {e}")
            raise

    def _load_models(self):
        """Initialize and load ACE-Step models.

        Expects ``self.dit_handler`` and ``self.llm_handler`` to have been
        created already (see ``_ensure_models_loaded``). LLM initialization
        failure is tolerated (DiT-only mode); DiT failure raises.

        Raises:
            RuntimeError: If ACE-Step is unavailable or DiT init fails.
        """
        try:
            if not ACE_STEP_AVAILABLE:
                raise RuntimeError("ACE-Step 1.5 not available")

            checkpoint_dir = self.config.get("checkpoint_dir", "./checkpoints")
            dit_model_path = self.config.get("dit_model_path", "acestep-v15-turbo")
            lm_model_path = self.config.get("lm_model_path", "acestep-5Hz-lm-1.7B")

            # Get checkpoints directory using helper function.
            checkpoints_dir = get_checkpoints_dir(checkpoint_dir)

            # Project root: two levels above this file.
            project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

            logger.info(f"Initializing DiT handler with model: {dit_model_path}")

            # Initialize DiT handler (main diffusion model, VAE, text encoder).
            # Note: the handler auto-detects the checkpoints dir as
            # project_root/checkpoints, so config_path is just the model name.
            status_dit, success_dit = self.dit_handler.initialize_service(
                project_root=project_root,
                config_path=dit_model_path,  # Just model name, handler adds checkpoints/
                device="auto",
                use_flash_attention=False,
                compile_model=False,
                offload_to_cpu=False,
            )

            if not success_dit:
                raise RuntimeError(f"Failed to initialize DiT: {status_dit}")

            logger.info(f"✓ DiT initialized: {status_dit}")

            # Initialize LLM handler (5Hz Language Model).
            logger.info(f"Initializing LLM handler with model: {lm_model_path}")

            status_llm, success_llm = self.llm_handler.initialize(
                checkpoint_dir=str(checkpoints_dir),
                lm_model_path=lm_model_path,
                backend="pt",  # Use PyTorch backend for compatibility
                device="auto",
                offload_to_cpu=False,
            )

            if not success_llm:
                # LLM is optional: fall back to DiT-only generation.
                logger.warning(f"LLM initialization failed: {status_llm}")
                logger.warning("Continuing without LLM (DiT-only mode)")
            else:
                logger.info(f" LLM initialized: {status_llm}")

            self._initialized = True
            logger.info(" ACE-Step engine fully initialized")

        except Exception as e:
            logger.error(f"Failed to initialize models: {e}")
            raise

    def _ensure_models_loaded(self):
        """Ensure models are loaded (lazy loading for ZeroGPU compatibility).

        Safe to call repeatedly; a no-op once initialization has succeeded.
        """
        if self._initialized:
            return

        logger.info("Lazy loading models on first use...")

        # Detect device now (within GPU context on ZeroGPU).
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        # Create handlers if not already created.
        if self.dit_handler is None:
            self.dit_handler = AceStepHandler()
        if self.llm_handler is None:
            self.llm_handler = LLMHandler()

        try:
            # Download and load models.
            self._download_checkpoints()
            self._load_models()
            logger.info("✓ Models loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load models: {e}")
            raise

    def generate(
        self,
        prompt: str,
        lyrics: Optional[str] = None,
        duration: int = 30,
        temperature: float = 0.7,
        top_p: float = 0.9,
        seed: int = -1,
        style: str = "auto",
        lora_path: Optional[str] = None
    ) -> str:
        """
        Generate music using ACE-Step.

        Args:
            prompt: Text description of desired music
            lyrics: Optional lyrics
            duration: Duration in seconds
            temperature: Sampling temperature (for LLM)
            top_p: Nucleus sampling parameter (for LLM)
            seed: Random seed (-1 for random)
            style: Music style (currently unused by the underlying call)
            lora_path: Path to LoRA model if using (currently unused)

        Returns:
            Path to generated audio file

        Raises:
            RuntimeError: If no audio was produced.
        """
        # Ensure models are loaded (lazy loading for ZeroGPU).
        self._ensure_models_loaded()

        try:
            # Prepare generation parameters.
            params = GenerationParams(
                task_type="text2music",
                caption=prompt,
                lyrics=lyrics or "",
                duration=duration,
                inference_steps=8,  # Turbo model default
                seed=seed if seed >= 0 else -1,
                thinking=True,  # Use LLM planning
                lm_temperature=temperature,
                lm_top_p=top_p,
            )

            # Prepare generation config.
            config = GenerationConfig(
                batch_size=1,
                use_random_seed=(seed < 0),
                audio_format="wav",
            )

            # Generate using official inference.
            output_dir = self.config.get("output_dir", "outputs")
            os.makedirs(output_dir, exist_ok=True)

            logger.info(f"Generating {duration}s audio: {prompt[:50]}...")

            result = generate_music(
                dit_handler=self.dit_handler,
                llm_handler=self.llm_handler,
                params=params,
                config=config,
                save_dir=output_dir,
            )

            if result.audio_paths:
                output_path = result.audio_paths[0]
                logger.info(f" Generated: {output_path}")
                return output_path
            else:
                raise RuntimeError("No audio generated")

        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise

    def generate_clip(
        self,
        prompt: str,
        lyrics: str,
        duration: int,
        context_audio: Optional[str] = None,
        style: str = "auto",
        temperature: float = 0.7,
        seed: int = -1
    ) -> str:
        """
        Generate audio clip for timeline (with context conditioning).

        Args:
            prompt: Text prompt
            lyrics: Lyrics for this clip
            duration: Duration in seconds (typically 32)
            context_audio: Path to previous audio for style conditioning.
                NOTE: currently ignored — true context conditioning would
                require a custom implementation.
            style: Music style
            temperature: Sampling temperature
            seed: Random seed

        Returns:
            Path to generated clip
        """
        # For timeline clips, fall back to regular generation; context
        # conditioning is not yet implemented.
        return self.generate(
            prompt=prompt,
            lyrics=lyrics,
            duration=duration,
            temperature=temperature,
            seed=seed,
            style=style
        )

    def generate_variation(self, audio_path: str, strength: float = 0.5) -> str:
        """Generate variation of existing audio.

        Args:
            audio_path: Path to the source audio.
            strength: Cover strength in [0, 1]; higher deviates more.

        Returns:
            Path to the variation, or ``audio_path`` unchanged if the
            pipeline produced no output.
        """
        # Ensure models are loaded (lazy loading for ZeroGPU).
        self._ensure_models_loaded()

        try:
            params = GenerationParams(
                task_type="audio_variation",
                audio_path=audio_path,
                audio_cover_strength=strength,
                inference_steps=8,
            )

            config = GenerationConfig(
                batch_size=1,
                audio_format="wav",
            )

            output_dir = self.config.get("output_dir", "outputs")

            result = generate_music(
                self.dit_handler,
                self.llm_handler,
                params,
                config,
                save_dir=output_dir,
            )

            return result.audio_paths[0] if result.audio_paths else audio_path

        except Exception as e:
            logger.error(f"Variation generation failed: {e}")
            raise

    def repaint(
        self,
        audio_path: str,
        start_time: float,
        end_time: float,
        new_prompt: str
    ) -> str:
        """Repaint specific section of audio.

        Args:
            audio_path: Path to the source audio.
            start_time: Section start in seconds.
            end_time: Section end in seconds.
            new_prompt: Caption describing the replacement content.

        Returns:
            Path to the repainted audio, or ``audio_path`` unchanged if the
            pipeline produced no output.
        """
        # BUGFIX: previously raised RuntimeError when called before any other
        # generation method; now lazy-loads like generate()/generate_variation().
        self._ensure_models_loaded()

        try:
            params = GenerationParams(
                task_type="repainting",
                audio_path=audio_path,
                caption=new_prompt,
                repainting_start=start_time,
                repainting_end=end_time,
                inference_steps=8,
            )

            config = GenerationConfig(
                batch_size=1,
                audio_format="wav",
            )

            output_dir = self.config.get("output_dir", "outputs")

            result = generate_music(
                self.dit_handler,
                self.llm_handler,
                params,
                config,
                save_dir=output_dir,
            )

            return result.audio_paths[0] if result.audio_paths else audio_path

        except Exception as e:
            logger.error(f"Repainting failed: {e}")
            raise

    def edit_lyrics(self, audio_path: str, new_lyrics: str) -> str:
        """Edit lyrics while maintaining music.

        NOTE: not fully implemented — the reference audio is currently
        ignored and a fresh 30s clip is generated with the new lyrics.
        """
        logger.warning("Lyric editing not fully implemented - regenerating with new lyrics")

        return self.generate(
            prompt="Match the style of the reference",
            lyrics=new_lyrics,
            duration=30,
        )

    def is_initialized(self) -> bool:
        """Check if engine is initialized (models loaded)."""
        return self._initialized