Andrew committed on
Commit
9e2d0e8
·
1 Parent(s): a5cfbf7
acestep/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """ACE-Step package."""
acestep/audio_utils.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Audio saving and transcoding utility module
3
+
4
+ Independent audio file operations outside of handler, supporting:
5
+ - Save audio tensor/numpy to files (default FLAC format, fast)
6
+ - Format conversion (FLAC/WAV/MP3)
7
+ - Batch processing
8
+ """
9
+
10
+ import os
11
+ import hashlib
12
+ import json
13
+ from pathlib import Path
14
+ from typing import Union, Optional, List, Tuple
15
+ import torch
16
+ import numpy as np
17
+ import torchaudio
18
+ from loguru import logger
19
+
20
+
21
class AudioSaver:
    """Audio saving and transcoding utility.

    Saves audio tensors/arrays to disk (FLAC by default: lossless and fast
    to encode), converts files between formats, and saves batches.
    FLAC/WAV are written through the soundfile backend, MP3 through ffmpeg.
    """

    # Formats this saver can write; anything else falls back to the default.
    SUPPORTED_FORMATS = ("flac", "wav", "mp3")

    def __init__(self, default_format: str = "flac"):
        """
        Initialize audio saver.

        Args:
            default_format: Default save format ('flac', 'wav', 'mp3').
                Unsupported values fall back to 'flac' with a warning.
        """
        self.default_format = default_format.lower()
        if self.default_format not in self.SUPPORTED_FORMATS:
            logger.warning(f"Unsupported format {default_format}, using 'flac'")
            self.default_format = "flac"

    @staticmethod
    def _to_channels_first_tensor(
        audio_data: Union[torch.Tensor, np.ndarray],
        channels_first: bool,
    ) -> torch.Tensor:
        """Normalize input audio into a float CPU tensor shaped [channels, samples].

        NOTE(review): for numpy input with channels_first=True the array is
        transposed, i.e. the numpy data is assumed to be [samples, channels] —
        the opposite of the tensor convention stated in save_audio's docstring.
        Behavior preserved as-is; confirm against callers before changing.
        """
        if isinstance(audio_data, np.ndarray):
            if channels_first:
                # numpy [samples, channels] -> tensor [channels, samples]
                return torch.from_numpy(audio_data.T).float()
            audio_tensor = torch.from_numpy(audio_data).float()
            # Heuristic: if the first axis is shorter, transpose (preserved
            # from the original implementation).
            if audio_tensor.dim() == 2 and audio_tensor.shape[0] < audio_tensor.shape[1]:
                audio_tensor = audio_tensor.T
            return audio_tensor

        audio_tensor = audio_data.cpu().float()
        if not channels_first and audio_tensor.dim() == 2:
            # [samples, channels] -> [channels, samples]
            if audio_tensor.shape[0] > audio_tensor.shape[1]:
                audio_tensor = audio_tensor.T
        return audio_tensor

    def save_audio(
        self,
        audio_data: Union[torch.Tensor, np.ndarray],
        output_path: Union[str, Path],
        sample_rate: int = 48000,
        format: Optional[str] = None,
        channels_first: bool = True,
    ) -> str:
        """
        Save audio data to file.

        Args:
            audio_data: Audio data, torch.Tensor [channels, samples] or numpy.ndarray.
            output_path: Output file path (extension can be omitted).
            sample_rate: Sample rate in Hz.
            format: Audio format ('flac', 'wav', 'mp3'); defaults to default_format.
            channels_first: If True, tensor layout is [channels, samples],
                else [samples, channels].

        Returns:
            Actual saved file path.

        Raises:
            Exception: Re-raised when both the torchaudio save and the
                soundfile fallback fail.
        """
        format = (format or self.default_format).lower()
        if format not in self.SUPPORTED_FORMATS:
            logger.warning(f"Unsupported format {format}, using {self.default_format}")
            format = self.default_format

        # Ensure the output path carries a recognized audio extension.
        output_path = Path(output_path)
        if output_path.suffix.lower() not in ['.flac', '.wav', '.mp3']:
            output_path = output_path.with_suffix(f'.{format}')

        audio_tensor = self._to_channels_first_tensor(audio_data, channels_first)
        # Contiguous memory is required by some backends.
        audio_tensor = audio_tensor.contiguous()

        try:
            if format == "mp3":
                # MP3 encoding is only available through the ffmpeg backend.
                torchaudio.save(
                    str(output_path),
                    audio_tensor,
                    sample_rate,
                    channels_first=True,
                    backend='ffmpeg',
                )
            else:
                # format is guaranteed to be 'flac' or 'wav' here (validated
                # above); both use the soundfile backend (fastest).
                torchaudio.save(
                    str(output_path),
                    audio_tensor,
                    sample_rate,
                    channels_first=True,
                    backend='soundfile',
                )

            logger.debug(f"[AudioSaver] Saved audio to {output_path} ({format}, {sample_rate}Hz)")
            return str(output_path)

        except Exception as save_err:
            # Fallback: write directly with soundfile (expects [samples, channels]).
            try:
                import soundfile as sf
                audio_np = audio_tensor.transpose(0, 1).numpy()  # -> [samples, channels]
                sf.write(str(output_path), audio_np, sample_rate, format=format.upper())
                logger.debug(f"[AudioSaver] Fallback soundfile Saved audio to {output_path} ({format}, {sample_rate}Hz)")
                return str(output_path)
            except Exception as fallback_err:
                # Fix: the original shadowed the primary exception with the
                # fallback's, losing the root cause; log both.
                logger.error(
                    f"[AudioSaver] Failed to save audio: {save_err}; fallback failed: {fallback_err}"
                )
                raise

    def convert_audio(
        self,
        input_path: Union[str, Path],
        output_path: Union[str, Path],
        output_format: str,
        remove_input: bool = False,
    ) -> str:
        """
        Convert an audio file to another format.

        Args:
            input_path: Input audio file path.
            output_path: Output audio file path.
            output_format: Target format ('flac', 'wav', 'mp3').
            remove_input: Whether to delete the input file after conversion.

        Returns:
            Output file path.

        Raises:
            FileNotFoundError: If input_path does not exist.
        """
        input_path = Path(input_path)
        output_path = Path(output_path)

        if not input_path.exists():
            raise FileNotFoundError(f"Input file not found: {input_path}")

        # Decode, then re-encode in the requested format.
        audio_tensor, sample_rate = torchaudio.load(str(input_path))
        output_path = self.save_audio(
            audio_tensor,
            output_path,
            sample_rate=sample_rate,
            format=output_format,
            channels_first=True,
        )

        if remove_input:
            input_path.unlink()
            logger.debug(f"[AudioSaver] Removed input file: {input_path}")

        return output_path

    def save_batch(
        self,
        audio_batch: Union[List[torch.Tensor], torch.Tensor],
        output_dir: Union[str, Path],
        file_prefix: str = "audio",
        sample_rate: int = 48000,
        format: Optional[str] = None,
        channels_first: bool = True,
    ) -> List[str]:
        """
        Save a batch of audio clips as individually numbered files.

        Args:
            audio_batch: List of tensors, or a tensor [batch, channels, samples];
                a single non-batched tensor is treated as a batch of one.
            output_dir: Output directory (created if missing).
            file_prefix: File name prefix; files are named {prefix}_{i:04d}.
            sample_rate: Sample rate in Hz.
            format: Audio format; None uses default_format.
            channels_first: Tensor layout flag, see save_audio.

        Returns:
            List of saved file paths, in batch order.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        if isinstance(audio_batch, torch.Tensor) and audio_batch.dim() == 3:
            # [batch, channels, samples] -> list of [channels, samples]
            audio_list = [audio_batch[i] for i in range(audio_batch.shape[0])]
        elif isinstance(audio_batch, list):
            audio_list = audio_batch
        else:
            audio_list = [audio_batch]

        saved_paths = []
        for i, audio in enumerate(audio_list):
            saved_paths.append(self.save_audio(
                audio,
                output_dir / f"{file_prefix}_{i:04d}",
                sample_rate=sample_rate,
                format=format,
                channels_first=channels_first,
            ))
        return saved_paths
223
+
224
+
225
def get_audio_file_hash(audio_file) -> str:
    """Return an MD5 hex digest identifying *audio_file*.

    A string naming an existing file is hashed by its content; any other
    string is hashed as text. File-like objects are identified by their
    ``name`` attribute; everything else by ``str(audio_file)``.

    Args:
        audio_file: Path to audio file (str) or file-like object.

    Returns:
        Hash string, or "" when *audio_file* is None.
    """
    if audio_file is None:
        return ""

    def _hex(data: bytes) -> str:
        return hashlib.md5(data).hexdigest()

    try:
        if isinstance(audio_file, str):
            if not os.path.exists(audio_file):
                # Not a real path: hash the string itself.
                return _hex(audio_file.encode('utf-8'))
            with open(audio_file, 'rb') as fh:
                return _hex(fh.read())
        if hasattr(audio_file, 'name'):
            # File-like object: identify it by its name attribute.
            return _hex(str(audio_file.name).encode('utf-8'))
        return _hex(str(audio_file).encode('utf-8'))
    except Exception:
        # Best-effort fallback: hash the textual form.
        return _hex(str(audio_file).encode('utf-8'))
249
+
250
+
251
def generate_uuid_from_params(params_dict) -> str:
    """Derive a deterministic UUID-shaped string from generation parameters.

    Identical dictionaries (regardless of key order) always produce the same
    value: the parameters are serialized to canonical JSON (sorted keys,
    non-ASCII kept as-is), hashed with SHA-256, and the first 32 hex digits
    are grouped 8-4-4-4-12.

    Args:
        params_dict: Dictionary of parameters.

    Returns:
        UUID string.
    """
    canonical = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
    digest = hashlib.sha256(canonical.encode('utf-8')).hexdigest()
    groups = (digest[0:8], digest[8:12], digest[12:16], digest[16:20], digest[20:32])
    return "-".join(groups)
268
+
269
+
270
def generate_uuid_from_audio_data(
    audio_data: Union[torch.Tensor, np.ndarray],
    seed: Optional[int] = None
) -> str:
    """Derive a deterministic hex ID from raw audio samples.

    Intended for caching/deduplication: identical sample data (and seed)
    always yields the same ID.

    Args:
        audio_data: Audio samples as a torch tensor or numpy array.
        seed: Optional seed mixed into the hash.

    Returns:
        32-character MD5 hex string.
    """
    if isinstance(audio_data, torch.Tensor):
        # Hash the raw bytes of the CPU copy.
        audio_np = audio_data.cpu().numpy()
    else:
        audio_np = audio_data

    data_hash = hashlib.md5(audio_np.tobytes()).hexdigest()
    if seed is None:
        return data_hash
    # Mix the seed in by re-hashing "<hash>_<seed>".
    return hashlib.md5(f"{data_hash}_{seed}".encode()).hexdigest()
298
+
299
+
300
# Global default instance, shared by the module-level save_audio() wrapper.
# FLAC default: lossless and fast to encode.
_default_saver = AudioSaver(default_format="flac")
302
+
303
# Levels at or below these count as silence (e.g. zeroed conditioning output).
SILENT_RMS_THRESHOLD = 1e-5
SILENT_PEAK_THRESHOLD = 1e-5


def is_audio_silent(
    audio_data: Union[torch.Tensor, np.ndarray],
    rms_threshold: float = SILENT_RMS_THRESHOLD,
    peak_threshold: float = SILENT_PEAK_THRESHOLD,
    channels_first: bool = True,
) -> Tuple[bool, float, float]:
    """Check whether audio is silent or near-silent.

    ``None`` or empty input counts as silent with zero levels.

    Args:
        audio_data: Audio samples as a tensor or array (any shape; the
            signal is flattened before measuring).
        rms_threshold: RMS level at or below which audio counts as silent.
        peak_threshold: Peak level at or below which audio counts as silent.
        channels_first: Unused by the computation; kept for API symmetry
            with the other audio helpers in this module.

    Returns:
        Tuple (is_silent, rms, peak) computed over the full signal.
    """
    if audio_data is None:
        return True, 0.0, 0.0

    if isinstance(audio_data, np.ndarray):
        samples = np.asarray(audio_data, dtype=np.float64).ravel()
    else:
        samples = audio_data.cpu().float().numpy().ravel()

    if samples.size == 0:
        return True, 0.0, 0.0

    rms = float(np.sqrt(np.mean(samples * samples)))
    peak = float(np.max(np.abs(samples)))
    return rms <= rms_threshold and peak <= peak_threshold, rms, peak
329
+
330
+
331
def save_audio(
    audio_data: Union[torch.Tensor, np.ndarray],
    output_path: Union[str, Path],
    sample_rate: int = 48000,
    format: Optional[str] = None,
    channels_first: bool = True,
) -> str:
    """
    Convenience wrapper: save audio through the module-level default saver.

    Args:
        audio_data: Audio data (tensor or numpy array).
        output_path: Output path (extension may be omitted).
        sample_rate: Sample rate in Hz.
        format: Format ('flac', 'wav', 'mp3'); None uses the default (flac).
        channels_first: Tensor layout flag, see AudioSaver.save_audio.

    Returns:
        Saved file path.
    """
    return _default_saver.save_audio(
        audio_data,
        output_path,
        sample_rate=sample_rate,
        format=format,
        channels_first=channels_first,
    )
354
+
acestep/constants.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Constants for ACE-Step
Centralized constants used across the codebase
"""

# ==============================================================================
# Language Constants
# ==============================================================================

# Supported languages for vocal generation and language detection.
# Covers major world languages with good TTS support in the underlying model;
# 'unknown' is used when language cannot be determined automatically.
VALID_LANGUAGES = [
    'ar', 'az', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en',
    'es', 'fa', 'fi', 'fr', 'he', 'hi', 'hr', 'ht', 'hu', 'id',
    'is', 'it', 'ja', 'ko', 'la', 'lt', 'ms', 'ne', 'nl', 'no',
    'pa', 'pl', 'pt', 'ro', 'ru', 'sa', 'sk', 'sr', 'sv', 'sw',
    'ta', 'te', 'th', 'tl', 'tr', 'uk', 'ur', 'vi', 'yue', 'zh',
    'unknown'
]


# ==============================================================================
# Keyscale Constants
# ==============================================================================

# Musical note names using standard Western notation
KEYSCALE_NOTES = ['A', 'B', 'C', 'D', 'E', 'F', 'G']

# Supported accidentals: natural (empty), ASCII sharp/flat, Unicode sharp/flat
KEYSCALE_ACCIDENTALS = ['', '#', 'b', '♯', '♭']

# Major and minor scale modes
KEYSCALE_MODES = ['major', 'minor']

# All valid keyscales: 7 notes x 5 accidentals x 2 modes = 70 combinations.
# Examples: "C major", "F# minor", "B♭ major".
# Built with a set comprehension so no loop variables leak into the module
# namespace (the original for-loops left note/acc/mode behind).
VALID_KEYSCALES = {
    f"{note}{acc} {mode}"
    for note in KEYSCALE_NOTES
    for acc in KEYSCALE_ACCIDENTALS
    for mode in KEYSCALE_MODES
}


# ==============================================================================
# Metadata Range Constants
# ==============================================================================

# BPM (Beats Per Minute) range - covers most musical styles:
# 30 BPM: very slow ballads, ambient music;
# 300 BPM: fast electronic dance music, extreme metal.
BPM_MIN = 30
BPM_MAX = 300

# Duration range (seconds) - balances quality vs. computational cost:
# 10s: short loops, musical excerpts;
# 600s: full songs, extended compositions (10 minutes).
DURATION_MIN = 10
DURATION_MAX = 600

# Valid time signatures - common musical meter patterns:
# 2: 2/4 (marches, polka); 3: 3/4 (waltzes, ballads);
# 4: 4/4 (most pop, rock, hip-hop); 6: 6/8 (compound time, folk dances).
VALID_TIME_SIGNATURES = [2, 3, 4, 6]


# ==============================================================================
# Task Type Constants
# ==============================================================================

# All supported generation tasks across different model variants
TASK_TYPES = ["text2music", "repaint", "cover", "extract", "lego", "complete"]

# Task types available for turbo models (optimized subset for speed):
# - text2music: generate from text descriptions
# - repaint: selective audio editing/regeneration
# - cover: style transfer using reference audio
TASK_TYPES_TURBO = ["text2music", "repaint", "cover"]

# Task types available for base models (full feature set).
# Additional tasks requiring more computational resources:
# - extract: separate individual tracks/stems from audio
# - lego: multi-track generation (add layers)
# - complete: automatic completion of partial audio
TASK_TYPES_BASE = ["text2music", "repaint", "cover", "extract", "lego", "complete"]


# ==============================================================================
# Instruction Constants
# ==============================================================================

# Default instructions
DEFAULT_DIT_INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
DEFAULT_LM_INSTRUCTION = "Generate audio semantic tokens based on the given conditions:"
DEFAULT_LM_UNDERSTAND_INSTRUCTION = "Understand the given musical conditions and describe the audio semantics accordingly:"
DEFAULT_LM_INSPIRED_INSTRUCTION = "Expand the user's input into a more detailed and specific musical description:"
DEFAULT_LM_REWRITE_INSTRUCTION = "Format the user's input into a more detailed and specific musical description:"

# Instruction templates for each task type.
# Note: some instructions use placeholders like {TRACK_NAME} or {TRACK_CLASSES};
# format them with .format() or f-strings when used. The *_default entries are
# the placeholder-free variants.
TASK_INSTRUCTIONS = {
    "text2music": "Fill the audio semantic mask based on the given conditions:",
    "repaint": "Repaint the mask area based on the given conditions:",
    "cover": "Generate audio semantic tokens based on the given conditions:",
    "extract": "Extract the {TRACK_NAME} track from the audio:",
    "extract_default": "Extract the track from the audio:",
    "lego": "Generate the {TRACK_NAME} track based on the audio context:",
    "lego_default": "Generate the track based on the audio context:",
    "complete": "Complete the input track with {TRACK_CLASSES}:",
    "complete_default": "Complete the input track:",
}


# ==============================================================================
# Track/Instrument Constants
# ==============================================================================

# Supported instrumental track types for multi-track generation and extraction.
# Organized by instrument families for logical grouping:
# - wind instruments: woodwinds, brass
# - electronic: fx (effects), synth (synthesizer)
# - string instruments: strings, guitar, bass
# - rhythm section: percussion, drums, keyboard
# - vocals: backing_vocals, vocals (lead vocals)
TRACK_NAMES = [
    "woodwinds", "brass", "fx", "synth", "strings", "percussion",
    "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"
]

# Template for SFT (Supervised Fine-Tuning) model prompts.
# Positional slots: instruction, caption, metadata (in that order).
SFT_GEN_PROMPT = """# Instruction
{}

# Caption
{}

# Metas
{}<|endoftext|>
"""


# ==============================================================================
# GPU Memory Configuration Constants
# ==============================================================================

# GPU tier thresholds (in GB); ">= 24GB" is treated as unlimited.
GPU_TIER_THRESHOLDS = {
    "tier1": 4,   # <= 4GB
    "tier2": 6,   # 4-6GB
    "tier3": 8,   # 6-8GB
    "tier4": 12,  # 8-12GB
    "tier5": 16,  # 12-16GB
    "tier6": 24,  # 16-24GB
}

# LM model memory requirements (in GB)
LM_MODEL_MEMORY_GB = {
    "0.6B": 3.0,
    "1.7B": 8.0,
    "4B": 12.0,
}

# LM model size -> checkpoint name mapping
LM_MODEL_NAMES = {
    "0.6B": "acestep-5Hz-lm-0.6B",
    "1.7B": "acestep-5Hz-lm-1.7B",
    "4B": "acestep-5Hz-lm-4B",
}


# ==============================================================================
# Debug Constants
# ==============================================================================

# Tensor debug mode (values: "OFF" | "ON" | "VERBOSE")
TENSOR_DEBUG_MODE = "OFF"

# Placeholder debug switches for other main functionality (default "OFF").
# Update names/usage as features adopt them.
DEBUG_API_SERVER = "OFF"
DEBUG_INFERENCE = "OFF"
DEBUG_TRAINING = "OFF"
DEBUG_DATASET = "OFF"
DEBUG_AUDIO = "OFF"
DEBUG_LLM = "OFF"
DEBUG_UI = "OFF"
DEBUG_MODEL_LOADING = "OFF"
DEBUG_GPU = "OFF"
acestep/constrained_logits_processor.py ADDED
The diff for this file is too large to render. See raw diff
 
acestep/dit_alignment_score.py ADDED
@@ -0,0 +1,877 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DiT Alignment Score Module
3
+
4
+ This module provides lyrics-to-audio alignment using cross-attention matrices
5
+ from DiT model for generating LRC timestamps.
6
+
7
+ Refactored from lyrics_alignment_infos.py for integration with ACE-Step.
8
+ """
9
+ import numba
10
+ import torch
11
+ import numpy as np
12
+ import torch.nn.functional as F
13
+ from dataclasses import dataclass, asdict
14
+ from typing import List, Dict, Any, Optional, Tuple, Union
15
+
16
+
17
+ # ================= Data Classes =================
18
@dataclass
class TokenTimestamp:
    """Per-token timing information produced by the aligner."""
    token_id: int       # tokenizer vocabulary ID of this token
    text: str           # decoded text contributed by this token
    start: float        # start time in seconds
    end: float          # end time in seconds
    probability: float  # alignment confidence score for this token
26
+
27
+
28
@dataclass
class SentenceTimestamp:
    """Per-sentence timing information with its constituent token list."""
    text: str                      # sentence text
    start: float                   # sentence start time in seconds
    end: float                     # sentence end time in seconds
    tokens: List[TokenTimestamp]   # per-token timestamps within the sentence
    confidence: float              # aggregate alignment confidence
36
+
37
+
38
+ # ================= DTW Algorithm (Numba Optimized) =================
39
@numba.jit(nopython=True)
def dtw_cpu(x: np.ndarray):
    """
    Dynamic Time Warping algorithm optimized with Numba.

    Finds the minimum-cost monotonic path through the cost matrix, aligning
    the text axis (rows) with the time axis (columns).

    Args:
        x: Cost matrix of shape [N, M] (text tokens x time frames);
           lower values indicate a better match.

    Returns:
        Tuple of (text_indices, time_indices) arrays describing the
        optimal alignment path, via _backtrace.
    """
    N, M = x.shape
    # Use float32 for memory efficiency; row/column 0 form the DP boundary.
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)
    cost[0, 0] = 0

    for j in range(1, M + 1):
        for i in range(1, N + 1):
            # Candidate predecessors: diagonal, above (text step), left (time step).
            c0 = cost[i - 1, j - 1]
            c1 = cost[i - 1, j]
            c2 = cost[i, j - 1]

            # t records the chosen move: 0 = diagonal, 1 = up, 2 = left
            # (ties fall through to the left move).
            if c0 < c1 and c0 < c2:
                c, t = c0, 0
            elif c1 < c0 and c1 < c2:
                c, t = c1, 1
            else:
                c, t = c2, 2

            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

    # Recover the aligned index pairs by walking the trace matrix backwards.
    return _backtrace(trace, N, M)
73
+
74
+
75
@numba.jit(nopython=True)
def _backtrace(trace: np.ndarray, N: int, M: int):
    """
    Optimized backtrace function for DTW.

    Walks the trace matrix from (N, M) back to the origin, recording the
    aligned (text, time) index pairs along the optimal path.

    Args:
        trace: Trace matrix of shape (N+1, M+1); entries encode the move
            taken at each cell (0 = diagonal, 1 = up, 2 = left).
        N, M: Original cost-matrix dimensions.

    Returns:
        Path array of shape (2, path_len) - first row is text indices,
        second is time indices, in forward (chronological) order.
    """
    # Boundary handling: force moves along the first row/column toward (0, 0).
    trace[0, :] = 2
    trace[:, 0] = 1

    # Pre-allocate and fill from the back; max path length is N + M.
    max_path_len = N + M
    path = np.zeros((2, max_path_len), dtype=np.int32)

    i, j = N, M
    path_idx = max_path_len - 1

    while i > 0 or j > 0:
        path[0, path_idx] = i - 1  # text index
        path[1, path_idx] = j - 1  # time index
        path_idx -= 1

        t = trace[i, j]
        if t == 0:
            i -= 1
            j -= 1
        elif t == 1:
            i -= 1
        elif t == 2:
            j -= 1
        else:
            # Defensive: unknown trace value, stop to avoid an infinite loop.
            break

    # Return only the filled tail of the pre-allocated buffer.
    # (Fix: dropped the unused `actual_len` local the original computed.)
    return path[:, path_idx + 1:max_path_len]
116
+
117
+
118
+ # ================= Utility Functions =================
119
def median_filter(signal: torch.Tensor, filter_width: int) -> torch.Tensor:
    """
    Apply a sliding median filter along the last dimension.

    The signal is reflect-padded by ``filter_width // 2`` on each side so
    the output length matches the input. Inputs whose last dimension is not
    longer than the pad width are returned unchanged.

    Args:
        signal: Input tensor; a 2-D tensor is treated as a single batch item.
        filter_width: Width of the median window.

    Returns:
        Median-filtered tensor.
    """
    half = filter_width // 2
    if signal.shape[-1] <= half:
        # Too short to reflect-pad — nothing to filter.
        return signal

    batched = signal[None, :] if signal.ndim == 2 else signal
    padded = F.pad(batched, (half, half, 0, 0), mode="reflect")
    # Sliding windows over the last dim, then pick each window's median.
    windows = padded.unfold(-1, filter_width, 1)
    filtered = windows.sort()[0][..., half]
    return filtered.squeeze(0) if filtered.ndim > 2 else filtered
140
+
141
+
142
+ # ================= Main Aligner Class =================
143
+ class MusicStampsAligner:
144
+ """
145
+ Aligner class for generating lyrics timestamps from cross-attention matrices.
146
+
147
+ Uses bidirectional consensus denoising and DTW for alignment.
148
+ """
149
+
150
    def __init__(self, tokenizer):
        """
        Initialize the aligner.

        Args:
            tokenizer: Text tokenizer used to decode lyric token IDs; must
                expose ``decode(token_ids, skip_special_tokens=...)`` as
                called by ``_decode_tokens_incrementally``.
        """
        self.tokenizer = tokenizer
158
+
159
    def _apply_bidirectional_consensus(
        self,
        weights_stack: torch.Tensor,
        violence_level: float,
        medfilt_width: int
    ) -> tuple:
        """
        Core denoising logic using bidirectional consensus.

        Multiplies the token->frame and frame->token softmax distributions so
        only cells strong in BOTH directions survive, suppresses row/column
        medians to kill crossing lines, sharpens, then normalizes and
        median-filters the result.

        Args:
            weights_stack: Attention weights [Heads, Tokens, Frames]
            violence_level: Denoising strength coefficient (scales how much
                of the row/column median is subtracted)
            medfilt_width: Median filter width

        Returns:
            Tuple of (calc_matrix, energy_matrix) as numpy arrays; both are
            head-averaged [Tokens, Frames]. calc_matrix is the z-scored,
            filtered matrix used for DTW; energy_matrix is the un-normalized
            consensus energy used for confidence.
        """
        # A. Bidirectional consensus: agreement of both softmax directions.
        row_prob = F.softmax(weights_stack, dim=-1)  # Token -> Frame
        col_prob = F.softmax(weights_stack, dim=-2)  # Frame -> Token
        processed = row_prob * col_prob

        # 1. Row suppression (kill horizontal crossing lines): subtract a
        # multiple of each row's median, clamp negatives to zero.
        row_medians = torch.quantile(processed, 0.5, dim=-1, keepdim=True)
        processed = processed - (violence_level * row_medians)
        processed = torch.relu(processed)

        # 2. Column suppression (kill vertical crossing lines).
        col_medians = torch.quantile(processed, 0.5, dim=-2, keepdim=True)
        processed = processed - (violence_level * col_medians)
        processed = torch.relu(processed)

        # C. Power sharpening: emphasize strong consensus cells.
        processed = processed ** 2

        # Energy matrix for confidence (head-averaged, before normalization).
        energy_matrix = processed.mean(dim=0).cpu().numpy()

        # D. Z-Score normalization over the whole stack (epsilon avoids /0).
        std, mean = torch.std_mean(processed, unbiased=False)
        weights_processed = (processed - mean) / (std + 1e-9)

        # E. Median filtering along the frame axis, then average over heads.
        weights_processed = median_filter(weights_processed, filter_width=medfilt_width)
        calc_matrix = weights_processed.mean(dim=0).numpy()

        return calc_matrix, energy_matrix
206
+
207
+ def _preprocess_attention(
208
+ self,
209
+ attention_matrix: torch.Tensor,
210
+ custom_config: Dict[int, List[int]],
211
+ violence_level: float,
212
+ medfilt_width: int = 7
213
+ ) -> tuple:
214
+ """
215
+ Preprocess attention matrix for alignment.
216
+
217
+ Args:
218
+ attention_matrix: Attention tensor [Layers, Heads, Tokens, Frames]
219
+ custom_config: Dict mapping layer indices to head indices
220
+ violence_level: Denoising strength
221
+ medfilt_width: Median filter width
222
+
223
+ Returns:
224
+ Tuple of (calc_matrix, energy_matrix, visual_matrix)
225
+ """
226
+ if not isinstance(attention_matrix, torch.Tensor):
227
+ weights = torch.tensor(attention_matrix)
228
+ else:
229
+ weights = attention_matrix.clone()
230
+
231
+ weights = weights.cpu().float()
232
+
233
+ selected_tensors = []
234
+ for layer_idx, head_indices in custom_config.items():
235
+ for head_idx in head_indices:
236
+ if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
237
+ head_matrix = weights[layer_idx, head_idx]
238
+ selected_tensors.append(head_matrix)
239
+
240
+ if not selected_tensors:
241
+ return None, None, None
242
+
243
+ # Stack selected heads: [Heads, Tokens, Frames]
244
+ weights_stack = torch.stack(selected_tensors, dim=0)
245
+ visual_matrix = weights_stack.mean(dim=0).numpy()
246
+
247
+ calc_matrix, energy_matrix = self._apply_bidirectional_consensus(
248
+ weights_stack, violence_level, medfilt_width
249
+ )
250
+
251
+ return calc_matrix, energy_matrix, visual_matrix
252
+
253
+ def stamps_align_info(
254
+ self,
255
+ attention_matrix: torch.Tensor,
256
+ lyrics_tokens: List[int],
257
+ total_duration_seconds: float,
258
+ custom_config: Dict[int, List[int]],
259
+ return_matrices: bool = False,
260
+ violence_level: float = 2.0,
261
+ medfilt_width: int = 1
262
+ ) -> Dict[str, Any]:
263
+ """
264
+ Get alignment information from attention matrix.
265
+
266
+ Args:
267
+ attention_matrix: Cross-attention tensor [Layers, Heads, Tokens, Frames]
268
+ lyrics_tokens: List of lyrics token IDs
269
+ total_duration_seconds: Total audio duration in seconds
270
+ custom_config: Dict mapping layer indices to head indices
271
+ return_matrices: Whether to return intermediate matrices
272
+ violence_level: Denoising strength
273
+ medfilt_width: Median filter width
274
+
275
+ Returns:
276
+ Dict containing calc_matrix, lyrics_tokens, total_duration_seconds,
277
+ and optionally energy_matrix and vis_matrix
278
+ """
279
+ calc_matrix, energy_matrix, visual_matrix = self._preprocess_attention(
280
+ attention_matrix, custom_config, violence_level, medfilt_width
281
+ )
282
+
283
+ if calc_matrix is None:
284
+ return {
285
+ "calc_matrix": None,
286
+ "lyrics_tokens": lyrics_tokens,
287
+ "total_duration_seconds": total_duration_seconds,
288
+ "error": "No valid attention heads found"
289
+ }
290
+
291
+ return_dict = {
292
+ "calc_matrix": calc_matrix,
293
+ "lyrics_tokens": lyrics_tokens,
294
+ "total_duration_seconds": total_duration_seconds
295
+ }
296
+
297
+ if return_matrices:
298
+ return_dict['energy_matrix'] = energy_matrix
299
+ return_dict['vis_matrix'] = visual_matrix
300
+
301
+ return return_dict
302
+
303
+ def _decode_tokens_incrementally(self, token_ids: List[int]) -> List[str]:
304
+ """
305
+ Decode tokens incrementally to properly handle multi-byte UTF-8 characters.
306
+
307
+ For Chinese and other multi-byte characters, the tokenizer may split them
308
+ into multiple byte-level tokens. Decoding each token individually produces
309
+ invalid UTF-8 sequences (showing as �). This method uses byte-level comparison
310
+ to correctly track which characters each token contributes.
311
+
312
+ Args:
313
+ token_ids: List of token IDs
314
+
315
+ Returns:
316
+ List of decoded text for each token position
317
+ """
318
+ decoded_tokens = []
319
+ prev_bytes = b""
320
+
321
+ for i in range(len(token_ids)):
322
+ # Decode tokens from start to current position
323
+ current_text = self.tokenizer.decode(token_ids[:i+1], skip_special_tokens=False)
324
+ current_bytes = current_text.encode('utf-8', errors='surrogatepass')
325
+
326
+ # The contribution of current token is the new bytes added
327
+ if len(current_bytes) >= len(prev_bytes):
328
+ new_bytes = current_bytes[len(prev_bytes):]
329
+ # Try to decode the new bytes; if incomplete, use empty string
330
+ try:
331
+ token_text = new_bytes.decode('utf-8')
332
+ except UnicodeDecodeError:
333
+ # Incomplete UTF-8 sequence, this token doesn't complete a character
334
+ token_text = ""
335
+ else:
336
+ # Edge case: current decode is shorter (shouldn't happen normally)
337
+ token_text = ""
338
+
339
+ decoded_tokens.append(token_text)
340
+ prev_bytes = current_bytes
341
+
342
+ return decoded_tokens
343
+
344
    def token_timestamps(
        self,
        calc_matrix: np.ndarray,
        lyrics_tokens: List[int],
        total_duration_seconds: float
    ) -> List[TokenTimestamp]:
        """
        Generate per-token timestamps by DTW alignment over the attention matrix.

        Args:
            calc_matrix: Processed attention matrix [Tokens, Frames].
            lyrics_tokens: List of token IDs (one per matrix row).
            total_duration_seconds: Total audio duration in seconds.

        Returns:
            List of TokenTimestamp objects, one per input token, in order.
        """
        n_frames = calc_matrix.shape[-1]
        # dtw_cpu minimizes cost, so negate the similarity-like matrix.
        text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float64))

        seconds_per_frame = total_duration_seconds / n_frames
        alignment_results = []

        # Use incremental decoding to properly handle multi-byte UTF-8 characters
        decoded_tokens = self._decode_tokens_incrementally(lyrics_tokens)

        for i in range(len(lyrics_tokens)):
            # Frames that the DTW path assigned to token i.
            mask = (text_indices == i)

            if not np.any(mask):
                # Token absent from the DTW path: give it a zero-length span
                # starting where the previous token ended (or 0.0 for the first).
                start = alignment_results[-1].end if alignment_results else 0.0
                end = start
                token_conf = 0.0
            else:
                times = time_indices[mask] * seconds_per_frame
                start = times[0]
                end = times[-1]
                # NOTE(review): per-token confidence is currently always 0.0;
                # downstream sentence scoring treats 0 as "no score".
                token_conf = 0.0

            # Guard against a degenerate (reversed) span.
            if end < start:
                end = start

            alignment_results.append(TokenTimestamp(
                token_id=lyrics_tokens[i],
                text=decoded_tokens[i],
                start=float(start),
                end=float(end),
                probability=token_conf
            ))

        return alignment_results
395
+
396
+ def _decode_sentence_from_tokens(self, tokens: List[TokenTimestamp]) -> str:
397
+ """
398
+ Decode a sentence by decoding all token IDs together.
399
+ This avoids UTF-8 encoding issues from joining individual token texts.
400
+
401
+ Args:
402
+ tokens: List of TokenTimestamp objects
403
+
404
+ Returns:
405
+ Properly decoded sentence text
406
+ """
407
+ token_ids = [t.token_id for t in tokens]
408
+ return self.tokenizer.decode(token_ids, skip_special_tokens=False)
409
+
410
+ def sentence_timestamps(
411
+ self,
412
+ token_alignment: List[TokenTimestamp]
413
+ ) -> List[SentenceTimestamp]:
414
+ """
415
+ Group token timestamps into sentence timestamps.
416
+
417
+ Args:
418
+ token_alignment: List of TokenTimestamp objects
419
+
420
+ Returns:
421
+ List of SentenceTimestamp objects
422
+ """
423
+ results = []
424
+ current_tokens = []
425
+
426
+ for token in token_alignment:
427
+ current_tokens.append(token)
428
+
429
+ if '\n' in token.text:
430
+ # Decode all token IDs together to avoid UTF-8 issues
431
+ full_text = self._decode_sentence_from_tokens(current_tokens)
432
+
433
+ if full_text.strip():
434
+ valid_scores = [t.probability for t in current_tokens if t.probability > 0]
435
+ sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
436
+
437
+ results.append(SentenceTimestamp(
438
+ text=full_text.strip(),
439
+ start=round(current_tokens[0].start, 3),
440
+ end=round(current_tokens[-1].end, 3),
441
+ tokens=list(current_tokens),
442
+ confidence=sent_conf
443
+ ))
444
+
445
+ current_tokens = []
446
+
447
+ # Handle last sentence
448
+ if current_tokens:
449
+ # Decode all token IDs together to avoid UTF-8 issues
450
+ full_text = self._decode_sentence_from_tokens(current_tokens)
451
+ if full_text.strip():
452
+ valid_scores = [t.probability for t in current_tokens if t.probability > 0]
453
+ sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
454
+
455
+ results.append(SentenceTimestamp(
456
+ text=full_text.strip(),
457
+ start=round(current_tokens[0].start, 3),
458
+ end=round(current_tokens[-1].end, 3),
459
+ tokens=list(current_tokens),
460
+ confidence=sent_conf
461
+ ))
462
+
463
+ # Normalize confidence scores
464
+ if results:
465
+ all_scores = [s.confidence for s in results]
466
+ min_score = min(all_scores)
467
+ max_score = max(all_scores)
468
+ score_range = max_score - min_score
469
+
470
+ if score_range > 1e-9:
471
+ for s in results:
472
+ normalized_score = (s.confidence - min_score) / score_range
473
+ s.confidence = round(normalized_score, 2)
474
+ else:
475
+ for s in results:
476
+ s.confidence = round(s.confidence, 2)
477
+
478
+ return results
479
+
480
+ def format_lrc(
481
+ self,
482
+ sentence_timestamps: List[SentenceTimestamp],
483
+ include_end_time: bool = False
484
+ ) -> str:
485
+ """
486
+ Format sentence timestamps as LRC lyrics format.
487
+
488
+ Args:
489
+ sentence_timestamps: List of SentenceTimestamp objects
490
+ include_end_time: Whether to include end time (enhanced LRC format)
491
+
492
+ Returns:
493
+ LRC formatted string
494
+ """
495
+ lines = []
496
+
497
+ for sentence in sentence_timestamps:
498
+ # Convert seconds to mm:ss.xx format
499
+ start_minutes = int(sentence.start // 60)
500
+ start_seconds = sentence.start % 60
501
+
502
+ if include_end_time:
503
+ end_minutes = int(sentence.end // 60)
504
+ end_seconds = sentence.end % 60
505
+ timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}][{end_minutes:02d}:{end_seconds:05.2f}]"
506
+ else:
507
+ timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}]"
508
+
509
+ # Clean the text (remove structural tags like [verse], [chorus])
510
+ text = sentence.text
511
+
512
+ lines.append(f"{timestamp}{text}")
513
+
514
+ return "\n".join(lines)
515
+
516
+ def get_timestamps_and_lrc(
517
+ self,
518
+ calc_matrix: np.ndarray,
519
+ lyrics_tokens: List[int],
520
+ total_duration_seconds: float
521
+ ) -> Dict[str, Any]:
522
+ """
523
+ Convenience method to get both timestamps and LRC in one call.
524
+
525
+ Args:
526
+ calc_matrix: Processed attention matrix
527
+ lyrics_tokens: List of token IDs
528
+ total_duration_seconds: Total audio duration
529
+
530
+ Returns:
531
+ Dict containing token_timestamps, sentence_timestamps, and lrc_text
532
+ """
533
+ token_stamps = self.token_timestamps(
534
+ calc_matrix=calc_matrix,
535
+ lyrics_tokens=lyrics_tokens,
536
+ total_duration_seconds=total_duration_seconds
537
+ )
538
+
539
+ sentence_stamps = self.sentence_timestamps(token_stamps)
540
+ lrc_text = self.format_lrc(sentence_stamps)
541
+
542
+ return {
543
+ "token_timestamps": token_stamps,
544
+ "sentence_timestamps": sentence_stamps,
545
+ "lrc_text": lrc_text
546
+ }
547
+
548
+
549
class MusicLyricScorer:
    """
    Scorer class for evaluating lyrics-to-audio alignment quality.

    Focuses on calculating alignment quality metrics (Coverage, Monotonicity,
    Confidence) using tensor operations for potential differentiability or
    GPU acceleration.
    """

    def __init__(self, tokenizer: Any):
        """
        Initialize the aligner.

        Args:
            tokenizer: Tokenizer instance (must implement .decode()).
        """
        self.tokenizer = tokenizer

    def _generate_token_type_mask(self, token_ids: List[int]) -> np.ndarray:
        """
        Generate a mask distinguishing lyrics (1) from structural tags (0).
        Uses self.tokenizer to decode tokens.

        Tokens from an opening '[' through the matching ']' (inclusive) are
        treated as structural-tag tokens (e.g. "[verse]") and masked out.

        Args:
            token_ids: List of token IDs.

        Returns:
            Numpy array of shape [len(token_ids)] with 1 or 0.
        """
        decoded_tokens = [self.tokenizer.decode([tid]) for tid in token_ids]
        mask = np.ones(len(token_ids), dtype=np.int32)
        in_bracket = False

        for i, token_str in enumerate(decoded_tokens):
            if '[' in token_str:
                in_bracket = True
            if in_bracket:
                mask[i] = 0
            if ']' in token_str:
                in_bracket = False
                # The token carrying ']' is itself still part of the tag.
                mask[i] = 0
        return mask

    def _preprocess_attention(
        self,
        attention_matrix: Union[torch.Tensor, np.ndarray],
        custom_config: Dict[int, List[int]],
        medfilt_width: int = 1
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[torch.Tensor]]:
        """
        Extracts and normalizes the attention matrix.

        Logic V4: Uses Min-Max normalization to highlight energy differences.

        Args:
            attention_matrix: Raw attention tensor [Layers, Heads, Tokens, Frames].
            custom_config: Config mapping layer indices to lists of head indices.
            medfilt_width: Width for median filtering.

        Returns:
            Tuple of (calc_matrix, energy_matrix, avg_weights_tensor);
            (None, None, None) when no configured head exists in the tensor.
        """
        # 1. Prepare Tensor (always work on a CPU float32 copy)
        if not isinstance(attention_matrix, torch.Tensor):
            weights = torch.tensor(attention_matrix)
        else:
            weights = attention_matrix.clone()
        weights = weights.cpu().float()

        # 2. Select Heads based on config (silently skip out-of-range indices)
        selected_tensors = []
        for layer_idx, head_indices in custom_config.items():
            for head_idx in head_indices:
                if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
                    selected_tensors.append(weights[layer_idx, head_idx])

        if not selected_tensors:
            return None, None, None

        weights_stack = torch.stack(selected_tensors, dim=0)

        # 3. Average Heads
        avg_weights = weights_stack.mean(dim=0)  # [Tokens, Frames]

        # 4. Preprocessing Logic
        # Min-Max normalization preserving energy distribution
        # Median filter is applied to the energy matrix
        energy_tensor = median_filter(avg_weights, filter_width=medfilt_width)
        energy_matrix = energy_tensor.numpy()

        e_min, e_max = energy_matrix.min(), energy_matrix.max()

        if e_max - e_min > 1e-9:
            energy_matrix = (energy_matrix - e_min) / (e_max - e_min)
        else:
            # Flat matrix: no usable signal, normalize to all zeros.
            energy_matrix = np.zeros_like(energy_matrix)

        # Contrast enhancement for DTW pathfinding
        # calc_matrix is used for pathfinding, energy_matrix for scoring
        calc_matrix = energy_matrix ** 2

        return calc_matrix, energy_matrix, avg_weights

    def _compute_alignment_metrics(
        self,
        energy_matrix: torch.Tensor,
        path_coords: torch.Tensor,
        type_mask: torch.Tensor,
        time_weight: float = 0.01,
        overlap_frames: float = 9.0,
        instrumental_weight: float = 1.0
    ) -> Tuple[float, float, float]:
        """
        Core metric calculation logic using high-precision Tensor operations.

        Args:
            energy_matrix: Normalized energy [Rows, Cols].
            path_coords: DTW path coordinates [Steps, 2] (row, col per step).
            type_mask: Token type mask [Rows] (1=Lyrics, 0=Tags).
            time_weight: Minimum energy threshold for monotonicity.
            overlap_frames: Allowed backward overlap (frames) for monotonicity.
            instrumental_weight: Weight for non-lyric steps in confidence calc
                (default 1.0 means tag steps count the same as lyric steps).

        Returns:
            Tuple of (coverage, monotonicity, confidence), each a Python float.
        """
        # Ensure high precision for internal calculation
        energy_matrix = energy_matrix.to(dtype=torch.float64)
        path_coords = path_coords.long()
        type_mask = type_mask.long()

        device = energy_matrix.device
        rows, cols = energy_matrix.shape

        is_lyrics_row = (type_mask == 1)

        # ================= A. Coverage Score =================
        # Ratio of lyric lines that have significant energy peak
        row_max_energies = energy_matrix.max(dim=1).values
        total_sung_rows = is_lyrics_row.sum().double()

        coverage_threshold = 0.1
        valid_sung_mask = is_lyrics_row & (row_max_energies > coverage_threshold)
        valid_sung_rows = valid_sung_mask.sum().double()

        if total_sung_rows > 0:
            coverage_score = valid_sung_rows / total_sung_rows
        else:
            # No lyric rows at all: vacuously full coverage.
            coverage_score = torch.tensor(1.0, device=device, dtype=torch.float64)

        # ================= B. Monotonicity Score =================
        # Check if the "center of mass" of lyric lines moves forward in time
        col_indices = torch.arange(cols, device=device, dtype=torch.float64)

        # Zero out low energy noise
        weights = torch.where(
            energy_matrix > time_weight,
            energy_matrix,
            torch.zeros_like(energy_matrix)
        )

        sum_w = weights.sum(dim=1)
        sum_t = (weights * col_indices).sum(dim=1)

        # Calculate centroids (-1 marks rows with no energy above threshold)
        centroids = torch.full((rows,), -1.0, device=device, dtype=torch.float64)
        valid_w_mask = sum_w > 1e-9
        centroids[valid_w_mask] = sum_t[valid_w_mask] / sum_w[valid_w_mask]

        # Extract sequence of valid lyrics centroids
        valid_sequence_mask = is_lyrics_row & (centroids >= 0)
        sung_centroids = centroids[valid_sequence_mask]

        cnt = sung_centroids.shape[0]
        if cnt > 1:
            curr_c = sung_centroids[:-1]
            next_c = sung_centroids[1:]

            # Check non-decreasing order with overlap tolerance
            non_decreasing = (next_c >= (curr_c - overlap_frames)).double().sum()
            pairs = torch.tensor(cnt - 1, device=device, dtype=torch.float64)
            monotonicity_score = non_decreasing / pairs
        else:
            # Zero or one centroid: trivially monotonic.
            monotonicity_score = torch.tensor(1.0, device=device, dtype=torch.float64)

        # ================= C. Path Confidence =================
        # Average energy along the optimal path
        if path_coords.shape[0] > 0:
            p_rows = path_coords[:, 0]
            p_cols = path_coords[:, 1]

            path_energies = energy_matrix[p_rows, p_cols]
            step_weights = torch.ones_like(path_energies)

            # Lower weight for instrumental/tag steps
            is_inst_step = (type_mask[p_rows] == 0)
            step_weights[is_inst_step] = instrumental_weight

            total_energy = (path_energies * step_weights).sum()
            total_steps = step_weights.sum()

            if total_steps > 0:
                path_confidence = total_energy / total_steps
            else:
                path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
        else:
            path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)

        return coverage_score.item(), monotonicity_score.item(), path_confidence.item()

    def lyrics_alignment_info(
        self,
        attention_matrix: Union[torch.Tensor, np.ndarray],
        token_ids: List[int],
        custom_config: Dict[int, List[int]],
        return_matrices: bool = False,
        medfilt_width: int = 1
    ) -> Dict[str, Any]:
        """
        Generates alignment path and processed matrices.

        Args:
            attention_matrix: Input attention tensor [Layers, Heads, Tokens, Frames].
            token_ids: Corresponding token IDs.
            custom_config: Layer/Head configuration.
            return_matrices: If True, also returns calc/vis matrices in the output.
            medfilt_width: Median filter width.

        Returns:
            Dict with "path_coords", "type_mask", "energy_matrix" (plus
            "calc_matrix"/"vis_matrix" when return_matrices), or a dict with
            an "error" key when no valid attention head is configured.
        """
        calc_matrix, energy_matrix, vis_matrix = self._preprocess_attention(
            attention_matrix, custom_config, medfilt_width
        )

        if calc_matrix is None:
            return {
                "calc_matrix": None,
                "error": "No valid attention heads found"
            }

        # 1. Generate Semantic Mask (1=Lyrics, 0=Tags)
        # Uses self.tokenizer internally
        type_mask = self._generate_token_type_mask(token_ids)

        # Safety check for shape mismatch
        if len(type_mask) != energy_matrix.shape[0]:
            # Fallback to all lyrics if shapes don't align
            type_mask = np.ones(energy_matrix.shape[0], dtype=np.int32)

        # 2. DTW Pathfinding
        # Using negative calc_matrix because DTW minimizes cost
        # NOTE(review): float32 here vs float64 in the timestamp path — confirm
        # the precision difference is intentional.
        text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float32))
        path_coords = np.stack([text_indices, time_indices], axis=1)

        return_dict = {
            "path_coords": path_coords,
            "type_mask": type_mask,
            "energy_matrix": energy_matrix
        }
        if return_matrices:
            return_dict['calc_matrix'] = calc_matrix
            return_dict['vis_matrix'] = vis_matrix

        return return_dict

    def calculate_score(
        self,
        energy_matrix: Union[torch.Tensor, np.ndarray],
        type_mask: Union[torch.Tensor, np.ndarray],
        path_coords: Union[torch.Tensor, np.ndarray],
        time_weight: float = 0.01,
        overlap_frames: float = 9.0,
        instrumental_weight: float = 1.0
    ) -> Dict[str, Any]:
        """
        Calculates the final alignment score based on pre-computed components.

        Args:
            energy_matrix: Processed energy matrix.
            type_mask: Token type mask (1=Lyrics, 0=Tags).
            path_coords: DTW path coordinates.
            time_weight: Minimum energy threshold for monotonicity.
            overlap_frames: Allowed backward movement frames.
            instrumental_weight: Weight for non-lyric path steps.

        Returns:
            Dict with key "lyrics_score": (coverage^2 * monotonicity^2 *
            confidence), clipped to [0, 1] and rounded to 4 decimals.
        """
        # Ensure Inputs are Tensors on the correct device
        if not isinstance(energy_matrix, torch.Tensor):
            # Use available accelerator device; fallback to CPU if none
            if torch.cuda.is_available():
                _score_device = "cuda"
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                _score_device = "mps"
            else:
                _score_device = "cpu"
            energy_matrix = torch.tensor(energy_matrix, device=_score_device, dtype=torch.float32)

        # All other inputs follow the energy matrix's device.
        device = energy_matrix.device

        if not isinstance(type_mask, torch.Tensor):
            type_mask = torch.tensor(type_mask, device=device, dtype=torch.long)
        else:
            type_mask = type_mask.to(device=device, dtype=torch.long)

        if not isinstance(path_coords, torch.Tensor):
            path_coords = torch.tensor(path_coords, device=device, dtype=torch.long)
        else:
            path_coords = path_coords.to(device=device, dtype=torch.long)

        # Compute Metrics
        coverage, monotonicity, confidence = self._compute_alignment_metrics(
            energy_matrix=energy_matrix,
            path_coords=path_coords,
            type_mask=type_mask,
            time_weight=time_weight,
            overlap_frames=overlap_frames,
            instrumental_weight=instrumental_weight
        )

        # Final Score Calculation
        # (Cov^2 * Mono^2 * Conf) — squaring penalizes weak coverage/monotonicity.
        final_score = (coverage ** 2) * (monotonicity ** 2) * confidence
        final_score = float(np.clip(final_score, 0.0, 1.0))

        return {
            "lyrics_score": round(final_score, 4)
        }
acestep/genres_vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
acestep/gpu_config.py ADDED
@@ -0,0 +1,549 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GPU Configuration Module
3
+ Centralized GPU memory detection and adaptive configuration management
4
+
5
+ Debug Mode:
6
+ Set environment variable MAX_CUDA_VRAM to simulate different GPU memory sizes.
7
+ Example: MAX_CUDA_VRAM=8 python acestep # Simulates 8GB GPU
8
+
9
+ For MPS testing, use MAX_MPS_VRAM to simulate MPS memory.
10
+ Example: MAX_MPS_VRAM=16 python acestep # Simulates 16GB MPS
11
+
12
+ This is useful for testing GPU tier configurations on high-end hardware.
13
+ """
14
+
15
+ import os
16
+ import sys
17
+ from dataclasses import dataclass
18
+ from typing import Optional, List, Dict, Tuple
19
+ from loguru import logger
20
+
21
+
22
# Environment variable for debugging/testing different GPU memory configurations
# (read by get_gpu_memory_gb() before any hardware detection).
DEBUG_MAX_CUDA_VRAM_ENV = "MAX_CUDA_VRAM"
DEBUG_MAX_MPS_VRAM_ENV = "MAX_MPS_VRAM"

# Tolerance for 16GB detection: reported VRAM like 15.5GB is effectively 16GB hardware
# Real-world 16GB GPUs often report 15.7-15.9GB due to system/driver reservations
VRAM_16GB_TOLERANCE_GB = 0.5
VRAM_16GB_MIN_GB = 16.0 - VRAM_16GB_TOLERANCE_GB  # treat as 16GB class if >= this

# PyTorch installation URLs for diagnostics (shown when no GPU is detected)
PYTORCH_CUDA_INSTALL_URL = "https://download.pytorch.org/whl/cu121"
PYTORCH_ROCM_INSTALL_URL = "https://download.pytorch.org/whl/rocm6.0"
34
+
35
+
36
@dataclass
class GPUConfig:
    """GPU configuration based on available memory.

    Bundles the generation limits and language-model options that apply to
    one VRAM tier (see GPU_TIER_CONFIGS for the per-tier values).
    """
    tier: str  # "tier1", "tier2", etc. or "unlimited"
    gpu_memory_gb: float  # detected (or debug-simulated) VRAM in GB

    # Duration limits (in seconds)
    max_duration_with_lm: int  # When LM is initialized
    max_duration_without_lm: int  # When LM is not initialized

    # Batch size limits
    max_batch_size_with_lm: int
    max_batch_size_without_lm: int

    # LM configuration
    init_lm_default: bool  # Whether to initialize LM by default
    available_lm_models: List[str]  # Available LM models for this tier

    # LM memory allocation (GB) for each model size
    lm_memory_gb: Dict[str, float]  # e.g., {"0.6B": 3, "1.7B": 8, "4B": 12}
56
+
57
+
58
# GPU tier configurations
# Keys are the tier names returned by get_gpu_tier(); values carry the raw
# fields used to populate GPUConfig for that tier.
GPU_TIER_CONFIGS = {
    "tier1": {  # <= 4GB
        "max_duration_with_lm": 180,  # 3 minutes
        "max_duration_without_lm": 180,  # 3 minutes
        "max_batch_size_with_lm": 1,
        "max_batch_size_without_lm": 1,
        "init_lm_default": False,
        "available_lm_models": [],
        "lm_memory_gb": {},
    },
    "tier2": {  # 4-6GB
        "max_duration_with_lm": 360,  # 6 minutes
        "max_duration_without_lm": 360,  # 6 minutes
        "max_batch_size_with_lm": 1,
        "max_batch_size_without_lm": 1,
        "init_lm_default": False,
        "available_lm_models": [],
        "lm_memory_gb": {},
    },
    "tier3": {  # 6-8GB
        "max_duration_with_lm": 240,  # 4 minutes with LM
        "max_duration_without_lm": 360,  # 6 minutes without LM
        "max_batch_size_with_lm": 1,
        "max_batch_size_without_lm": 2,
        "init_lm_default": False,  # Don't init by default due to limited memory
        "available_lm_models": ["acestep-5Hz-lm-0.6B"],
        "lm_memory_gb": {"0.6B": 3},
    },
    "tier4": {  # 8-12GB
        "max_duration_with_lm": 240,  # 4 minutes with LM
        "max_duration_without_lm": 360,  # 6 minutes without LM
        "max_batch_size_with_lm": 2,
        "max_batch_size_without_lm": 4,
        "init_lm_default": False,  # Don't init by default
        "available_lm_models": ["acestep-5Hz-lm-0.6B"],
        "lm_memory_gb": {"0.6B": 3},
    },
    "tier5": {  # 12-16GB
        "max_duration_with_lm": 240,  # 4 minutes with LM
        "max_duration_without_lm": 360,  # 6 minutes without LM
        "max_batch_size_with_lm": 2,
        "max_batch_size_without_lm": 4,
        "init_lm_default": True,
        "available_lm_models": ["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B"],
        "lm_memory_gb": {"0.6B": 3, "1.7B": 8},
    },
    "tier6": {  # 16-24GB
        "max_duration_with_lm": 480,  # 8 minutes
        "max_duration_without_lm": 480,  # 8 minutes
        "max_batch_size_with_lm": 4,
        "max_batch_size_without_lm": 8,
        "init_lm_default": True,
        "available_lm_models": ["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B", "acestep-5Hz-lm-4B"],
        "lm_memory_gb": {"0.6B": 3, "1.7B": 8, "4B": 12},
    },
    "unlimited": {  # >= 24GB
        "max_duration_with_lm": 600,  # 10 minutes (max supported)
        "max_duration_without_lm": 600,  # 10 minutes
        "max_batch_size_with_lm": 8,
        "max_batch_size_without_lm": 8,
        "init_lm_default": True,
        "available_lm_models": ["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B", "acestep-5Hz-lm-4B"],
        "lm_memory_gb": {"0.6B": 3, "1.7B": 8, "4B": 12},
    },
}
124
+
125
+
126
def get_gpu_memory_gb() -> float:
    """
    Get GPU memory in GB. Returns 0.0 if no GPU is available.

    Detection order: debug overrides, CUDA/ROCm, Intel XPU, Apple MPS,
    then a CPU-only fallback of 0.0.

    Debug Mode:
        Set environment variable MAX_CUDA_VRAM to override the detected GPU memory.
        Example: MAX_CUDA_VRAM=8 python acestep  # Simulates 8GB GPU

        For MPS testing, set MAX_MPS_VRAM to override MPS memory detection.
        Example: MAX_MPS_VRAM=16 python acestep  # Simulates 16GB MPS

        This allows testing different GPU tier configurations on high-end hardware.

    Returns:
        GPU memory in GB as a float; 0.0 when no GPU is detected or
        detection fails.
    """
    # Check for debug overrides first (CUDA override takes precedence).
    debug_vram = os.environ.get(DEBUG_MAX_CUDA_VRAM_ENV)
    if debug_vram is not None:
        try:
            simulated_gb = float(debug_vram)
            logger.warning(f"⚠️ DEBUG MODE: Simulating GPU memory as {simulated_gb:.1f}GB (set via {DEBUG_MAX_CUDA_VRAM_ENV} environment variable)")
            return simulated_gb
        except ValueError:
            logger.warning(f"Invalid {DEBUG_MAX_CUDA_VRAM_ENV} value: {debug_vram}, ignoring")
    debug_mps_vram = os.environ.get(DEBUG_MAX_MPS_VRAM_ENV)
    if debug_mps_vram is not None:
        try:
            simulated_gb = float(debug_mps_vram)
            logger.warning(f"⚠️ DEBUG MODE: Simulating MPS memory as {simulated_gb:.1f}GB (set via {DEBUG_MAX_MPS_VRAM_ENV} environment variable)")
            return simulated_gb
        except ValueError:
            logger.warning(f"Invalid {DEBUG_MAX_MPS_VRAM_ENV} value: {debug_mps_vram}, ignoring")

    try:
        import torch
        if torch.cuda.is_available():
            # Get total memory of the first GPU in GB
            total_memory = torch.cuda.get_device_properties(0).total_memory
            memory_gb = total_memory / (1024**3)  # Convert bytes to GB
            device_name = torch.cuda.get_device_name(0)
            # ROCm builds also report through the torch.cuda API.
            is_rocm = hasattr(torch.version, 'hip') and torch.version.hip is not None
            if is_rocm:
                logger.info(f"ROCm GPU detected: {device_name} ({memory_gb:.1f} GB, HIP {torch.version.hip})")
            else:
                logger.info(f"CUDA GPU detected: {device_name} ({memory_gb:.1f} GB)")
            return memory_gb
        elif hasattr(torch, 'xpu') and torch.xpu.is_available():
            # Get total memory of the first XPU in GB
            total_memory = torch.xpu.get_device_properties(0).total_memory
            memory_gb = total_memory / (1024**3)  # Convert bytes to GB
            # Log for consistency with the CUDA/ROCm branches.
            logger.info(f"XPU detected ({memory_gb:.1f} GB)")
            return memory_gb
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            mps_module = getattr(torch, "mps", None)
            try:
                # Preferred: torch-reported recommended working-set size.
                if mps_module is not None and hasattr(mps_module, "recommended_max_memory"):
                    total_memory = mps_module.recommended_max_memory()
                    memory_gb = total_memory / (1024**3)  # Convert bytes to GB
                    return memory_gb
                if mps_module is not None and hasattr(mps_module, "get_device_properties"):
                    props = mps_module.get_device_properties(0)
                    total_memory = getattr(props, "total_memory", None)
                    if total_memory:
                        memory_gb = total_memory / (1024**3)
                        return memory_gb
            except Exception as e:
                logger.warning(f"Failed to detect MPS memory: {e}")

            # Fallback: estimate from system unified memory (Apple Silicon shares CPU/GPU RAM)
            try:
                import subprocess
                result = subprocess.run(
                    ["sysctl", "-n", "hw.memsize"],
                    capture_output=True, text=True, timeout=5
                )
                total_system_bytes = int(result.stdout.strip())
                # MPS can use up to ~75% of unified memory for GPU workloads
                memory_gb = (total_system_bytes / (1024**3)) * 0.75
                return memory_gb
            except Exception:
                logger.warning(f"MPS available but total memory not exposed. Set {DEBUG_MAX_MPS_VRAM_ENV} to enable tiering.")
                # Conservative fallback for M1/M2
                return 8.0
        else:
            # No GPU detected - provide diagnostic information
            _log_gpu_diagnostic_info(torch)
            return 0.0
    except Exception as e:
        logger.warning(f"Failed to detect GPU memory: {e}")
        return 0.0
213
+
214
+
215
def _log_gpu_diagnostic_info(torch_module):
    """
    Log diagnostic information when GPU is not detected to help users troubleshoot.

    Branches on the PyTorch build type (ROCm / CUDA / CPU-only) and prints
    the matching troubleshooting checklist. Logging only; no return value.

    Args:
        torch_module: The torch module to inspect for build information
    """
    logger.warning("=" * 80)
    logger.warning("⚠️ GPU NOT DETECTED - DIAGNOSTIC INFORMATION")
    logger.warning("=" * 80)

    # Check PyTorch build type
    is_rocm_build = hasattr(torch_module.version, 'hip') and torch_module.version.hip is not None
    is_cuda_build = hasattr(torch_module.version, 'cuda') and torch_module.version.cuda is not None

    if is_rocm_build:
        logger.warning("✓ PyTorch ROCm build detected")
        logger.warning(f" HIP version: {torch_module.version.hip}")
        logger.warning("")
        logger.warning("❌ torch.cuda.is_available() returned False")
        logger.warning("")
        logger.warning("Common causes for AMD/ROCm GPUs:")
        logger.warning(" 1. ROCm drivers not installed or not properly configured")
        logger.warning(" 2. GPU not supported by installed ROCm version")
        logger.warning(" 3. Missing or incorrect HSA_OVERRIDE_GFX_VERSION environment variable")
        logger.warning(" 4. ROCm runtime libraries not in system path")
        logger.warning("")

        # Check for common environment variables
        hsa_override = os.environ.get('HSA_OVERRIDE_GFX_VERSION')
        if hsa_override:
            logger.warning(f" HSA_OVERRIDE_GFX_VERSION is set to: {hsa_override}")
        else:
            # Suggest per-GPU override values for RDNA3 cards.
            logger.warning(" ⚠️ HSA_OVERRIDE_GFX_VERSION is not set")
            logger.warning(" For RDNA3 GPUs (RX 7000 series, RX 9000 series):")
            logger.warning(" - RX 7900 XT/XTX, RX 9070 XT: set HSA_OVERRIDE_GFX_VERSION=11.0.0")
            logger.warning(" - RX 7800 XT, RX 7700 XT: set HSA_OVERRIDE_GFX_VERSION=11.0.1")
            logger.warning(" - RX 7600: set HSA_OVERRIDE_GFX_VERSION=11.0.2")

        logger.warning("")
        logger.warning("Troubleshooting steps:")
        logger.warning(" 1. Verify ROCm installation:")
        logger.warning(" rocm-smi # Should list your GPU")
        logger.warning(" 2. Check PyTorch ROCm build:")
        logger.warning(" python -c \"import torch; print(f'ROCm: {torch.version.hip}')\"")
        logger.warning(" 3. Set HSA_OVERRIDE_GFX_VERSION for your GPU (see above)")
        logger.warning(" 4. On Windows: Use start_gradio_ui_rocm.bat which sets required env vars")
        logger.warning(" 5. See docs/en/ACE-Step1.5-Rocm-Manual-Linux.md for Linux setup")
        logger.warning(" 6. See requirements-rocm.txt for Windows ROCm setup instructions")

    elif is_cuda_build:
        logger.warning("✓ PyTorch CUDA build detected")
        logger.warning(f" CUDA version: {torch_module.version.cuda}")
        logger.warning("")
        logger.warning("❌ torch.cuda.is_available() returned False")
        logger.warning("")
        logger.warning("Common causes for NVIDIA GPUs:")
        logger.warning(" 1. NVIDIA drivers not installed")
        logger.warning(" 2. CUDA runtime not installed or version mismatch")
        logger.warning(" 3. GPU not supported by installed CUDA version")
        logger.warning("")
        logger.warning("Troubleshooting steps:")
        logger.warning(" 1. Verify NVIDIA driver installation:")
        logger.warning(" nvidia-smi # Should list your GPU")
        logger.warning(" 2. Check CUDA version compatibility")
        logger.warning(" 3. Reinstall PyTorch with CUDA support:")
        logger.warning(f" pip install torch --index-url {PYTORCH_CUDA_INSTALL_URL}")

    else:
        # Neither CUDA nor HIP version info present: CPU-only wheel.
        logger.warning("⚠️ PyTorch build type: CPU-only")
        logger.warning("")
        logger.warning("You have installed a CPU-only version of PyTorch!")
        logger.warning("")
        logger.warning("For NVIDIA GPUs:")
        logger.warning(f" pip install torch --index-url {PYTORCH_CUDA_INSTALL_URL}")
        logger.warning("")
        logger.warning("For AMD GPUs with ROCm:")
        logger.warning(" Windows: See requirements-rocm.txt for detailed instructions")
        logger.warning(f" Linux: pip install torch --index-url {PYTORCH_ROCM_INSTALL_URL}")
        logger.warning("")
        logger.warning("For more information, see README.md section 'AMD / ROCm GPUs'")

    logger.warning("=" * 80)
298
+
299
+
300
def get_gpu_tier(gpu_memory_gb: float) -> str:
    """
    Determine GPU tier based on available memory.

    Args:
        gpu_memory_gb: GPU memory in GB

    Returns:
        Tier string: "tier1", "tier2", "tier3", "tier4", "tier5", "tier6", or "unlimited"
    """
    # Thresholds are inclusive upper bounds; each tier maps to an entry in
    # GPU_TIER_CONFIGS (defined elsewhere in this module).
    if gpu_memory_gb <= 0:
        # CPU mode - use tier1 limits
        return "tier1"
    elif gpu_memory_gb <= 4:
        return "tier1"
    elif gpu_memory_gb <= 6:
        return "tier2"
    elif gpu_memory_gb <= 8:
        return "tier3"
    elif gpu_memory_gb <= 12:
        return "tier4"
    elif gpu_memory_gb < VRAM_16GB_MIN_GB:
        # NOTE(review): VRAM_16GB_MIN_GB is presumably slightly below 16.0 so
        # that "16GB" cards reporting a bit less still reach tier6 — confirm
        # against the constant's definition.
        return "tier5"
    elif gpu_memory_gb <= 24:
        if gpu_memory_gb < 16.0:
            # Cards in [VRAM_16GB_MIN_GB, 16.0) are promoted to the 16GB class.
            logger.info(f"Detected {gpu_memory_gb:.2f}GB VRAM — treating as 16GB class GPU")
        return "tier6"
    else:
        # > 24 GB: no duration/batch restrictions.
        return "unlimited"
329
+
330
+
331
def get_gpu_config(gpu_memory_gb: Optional[float] = None) -> GPUConfig:
    """
    Build a GPUConfig for the detected (or explicitly provided) GPU memory.

    Args:
        gpu_memory_gb: GPU memory in GB. Auto-detected when None.

    Returns:
        GPUConfig populated from the matching tier's settings.
    """
    if gpu_memory_gb is None:
        gpu_memory_gb = get_gpu_memory_gb()

    tier = get_gpu_tier(gpu_memory_gb)
    tier_settings = GPU_TIER_CONFIGS[tier]

    # Copy the per-tier settings into the config object field by field.
    setting_keys = (
        "max_duration_with_lm",
        "max_duration_without_lm",
        "max_batch_size_with_lm",
        "max_batch_size_without_lm",
        "init_lm_default",
        "available_lm_models",
        "lm_memory_gb",
    )
    tier_kwargs = {key: tier_settings[key] for key in setting_keys}
    return GPUConfig(tier=tier, gpu_memory_gb=gpu_memory_gb, **tier_kwargs)
358
+
359
+
360
def get_lm_model_size(model_path: str) -> str:
    """
    Extract the LM model size tag from a model path.

    Args:
        model_path: Model path string (e.g., "acestep-5Hz-lm-0.6B")

    Returns:
        Model size string: "0.6B", "1.7B", or "4B" ("0.6B" when no tag matches).
    """
    # Check tags in the same priority order as before: smallest first.
    for size_tag in ("0.6B", "1.7B", "4B"):
        if size_tag in model_path:
            return size_tag
    # No recognizable size tag: default to the smallest model assumption.
    return "0.6B"
379
+
380
+
381
def get_lm_gpu_memory_ratio(model_path: str, total_gpu_memory_gb: float) -> Tuple[float, float]:
    """
    Calculate GPU memory utilization ratio for LM model.

    Args:
        model_path: LM model path (e.g., "acestep-5Hz-lm-0.6B")
        total_gpu_memory_gb: Total GPU memory in GB

    Returns:
        Tuple of (gpu_memory_utilization_ratio, target_memory_gb)
    """
    model_size = get_lm_model_size(model_path)

    # Target memory allocation (GB) for each model size.
    target_memory = {
        "0.6B": 3.0,
        "1.7B": 8.0,
        "4B": 12.0,
    }
    target_gb = target_memory.get(model_size, 3.0)

    # Guard: CPU mode / failed detection reports <= 0 GB (see get_gpu_tier),
    # and the division below would raise ZeroDivisionError. Return the cap
    # ratio, which is the limit of min(0.9, target/total) as total -> 0+.
    if total_gpu_memory_gb <= 0:
        return 0.9, target_gb

    if total_gpu_memory_gb >= 24:
        # For large GPUs (>=24GB), don't restrict memory too much (floor 0.2).
        ratio = min(0.9, max(0.2, target_gb / total_gpu_memory_gb))
    else:
        # For smaller GPUs, strictly limit memory usage (floor 0.1).
        ratio = min(0.9, max(0.1, target_gb / total_gpu_memory_gb))

    return ratio, target_gb
412
+
413
+
414
def check_duration_limit(
    duration: float,
    gpu_config: GPUConfig,
    lm_initialized: bool
) -> Tuple[bool, str]:
    """
    Check if requested duration is within limits for current GPU configuration.

    Args:
        duration: Requested duration in seconds
        gpu_config: Current GPU configuration
        lm_initialized: Whether LM is initialized

    Returns:
        Tuple of (is_valid, warning_message) — message is "" when valid.
    """
    # The LM consumes VRAM, so the duration budget shrinks when it is loaded.
    limit = gpu_config.max_duration_with_lm if lm_initialized else gpu_config.max_duration_without_lm

    if duration <= limit:
        return True, ""

    lm_mode = "with" if lm_initialized else "without"
    warning_msg = (
        f"⚠️ Requested duration ({duration:.0f}s) exceeds the limit for your GPU "
        f"({gpu_config.gpu_memory_gb:.1f}GB). Maximum allowed: {limit}s "
        f"({lm_mode} LM). "
        f"Duration will be clamped to {limit}s."
    )
    return False, warning_msg
442
+
443
+
444
def check_batch_size_limit(
    batch_size: int,
    gpu_config: GPUConfig,
    lm_initialized: bool
) -> Tuple[bool, str]:
    """
    Check if requested batch size is within limits for current GPU configuration.

    Args:
        batch_size: Requested batch size
        gpu_config: Current GPU configuration
        lm_initialized: Whether LM is initialized

    Returns:
        Tuple of (is_valid, warning_message) — message is "" when valid.
    """
    # Batch budget is smaller when the LM occupies part of the VRAM.
    limit = gpu_config.max_batch_size_with_lm if lm_initialized else gpu_config.max_batch_size_without_lm

    if batch_size <= limit:
        return True, ""

    lm_mode = "with" if lm_initialized else "without"
    warning_msg = (
        f"⚠️ Requested batch size ({batch_size}) exceeds the limit for your GPU "
        f"({gpu_config.gpu_memory_gb:.1f}GB). Maximum allowed: {limit} "
        f"({lm_mode} LM). "
        f"Batch size will be clamped to {limit}."
    )
    return False, warning_msg
472
+
473
+
474
def is_lm_model_supported(model_path: str, gpu_config: GPUConfig) -> Tuple[bool, str]:
    """
    Check if the specified LM model is supported for current GPU configuration.

    Args:
        model_path: LM model path
        gpu_config: Current GPU configuration

    Returns:
        Tuple of (is_supported, warning_message) — message is "" when supported.
    """
    # No models listed at all means the GPU cannot host any LM.
    if not gpu_config.available_lm_models:
        return False, (
            f"⚠️ Your GPU ({gpu_config.gpu_memory_gb:.1f}GB) does not have enough memory "
            f"to run any LM model. Please disable LM initialization."
        )

    model_size = get_lm_model_size(model_path)

    # Supported when any available model name contains the requested size tag.
    if any(model_size in candidate for candidate in gpu_config.available_lm_models):
        return True, ""

    return False, (
        f"⚠️ LM model {model_path} ({model_size}) is not supported for your GPU "
        f"({gpu_config.gpu_memory_gb:.1f}GB). Available models: {', '.join(gpu_config.available_lm_models)}"
    )
502
+
503
+
504
def get_recommended_lm_model(gpu_config: GPUConfig) -> Optional[str]:
    """
    Get recommended LM model for current GPU configuration.

    Args:
        gpu_config: Current GPU configuration

    Returns:
        Recommended LM model path, or None if LM is not supported.
    """
    models = gpu_config.available_lm_models
    if not models:
        return None
    # Models are ordered smallest to largest; recommend the largest that fits.
    return models[-1]
519
+
520
+
521
def print_gpu_config_info(gpu_config: GPUConfig):
    """Log the resolved GPU configuration (tier, limits, LM defaults) for debugging."""
    # Build all lines first, then emit them with one logging call each.
    details = [
        "GPU Configuration:",
        f" - GPU Memory: {gpu_config.gpu_memory_gb:.1f} GB",
        f" - Tier: {gpu_config.tier}",
        f" - Max Duration (with LM): {gpu_config.max_duration_with_lm}s ({gpu_config.max_duration_with_lm // 60} min)",
        f" - Max Duration (without LM): {gpu_config.max_duration_without_lm}s ({gpu_config.max_duration_without_lm // 60} min)",
        f" - Max Batch Size (with LM): {gpu_config.max_batch_size_with_lm}",
        f" - Max Batch Size (without LM): {gpu_config.max_batch_size_without_lm}",
        f" - Init LM by Default: {gpu_config.init_lm_default}",
        f" - Available LM Models: {gpu_config.available_lm_models or 'None'}",
    ]
    for line in details:
        logger.info(line)
532
+
533
+
534
# Global GPU config instance; created lazily by get_global_gpu_config() so
# import does not trigger GPU detection.
_global_gpu_config: Optional[GPUConfig] = None
536
+
537
+
538
def get_global_gpu_config() -> GPUConfig:
    """Get the global GPU configuration, initializing it on first access."""
    global _global_gpu_config
    # Lazy singleton: GPU memory detection runs only once per process.
    if _global_gpu_config is None:
        _global_gpu_config = get_gpu_config()
    return _global_gpu_config
544
+
545
+
546
def set_global_gpu_config(config: GPUConfig):
    """Set the global GPU configuration (e.g., to override auto-detection in tests)."""
    global _global_gpu_config
    _global_gpu_config = config
acestep/handler.py ADDED
The diff for this file is too large to render. See raw diff
 
acestep/inference.py ADDED
@@ -0,0 +1,1310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ACE-Step Inference API Module
3
+
4
+ This module provides a standardized inference interface for music generation,
5
+ designed for third-party integration. It offers both a simplified API and
6
+ backward-compatible Gradio UI support.
7
+ """
8
+
9
+ import math
10
+ import os
11
+ import tempfile
12
+ import shutil
13
+ import subprocess
14
+ import sys
15
+ from typing import Optional, Union, List, Dict, Any, Tuple
16
+ from dataclasses import dataclass, field, asdict
17
+ from loguru import logger
18
+
19
+ from acestep.audio_utils import AudioSaver, generate_uuid_from_params, is_audio_silent
20
+ from acestep.constants import TASK_INSTRUCTIONS
21
+ from acestep.gpu_config import get_gpu_config
22
+
23
+ # HuggingFace Space environment detection
24
+ IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
25
+
26
def _get_spaces_gpu_decorator(duration=180):
    """
    Return the @spaces.GPU decorator when running in a HuggingFace Space,
    otherwise an identity decorator.

    Args:
        duration: GPU allocation duration in seconds passed to spaces.GPU.
    """
    # Outside a Space (or when the package is missing) decoration is a no-op.
    if not IS_HUGGINGFACE_SPACE:
        return lambda func: func
    try:
        import spaces
    except ImportError:
        logger.warning("spaces package not found, GPU decorator disabled")
        return lambda func: func
    return spaces.GPU(duration=duration)
39
+
40
+
41
@dataclass
class GenerationParams:
    """Configuration for music generation parameters.

    Attributes:
        # Text Inputs
        caption: A short text prompt describing the desired music (main prompt). < 512 characters
        lyrics: Lyrics for the music. Use "[Instrumental]" for instrumental songs. < 4096 characters
        instrumental: If True, generate instrumental music regardless of lyrics.

        # Music Metadata
        bpm: BPM (beats per minute), e.g., 120. Set to None for automatic estimation. 30 ~ 300
        keyscale: Musical key (e.g., "C Major", "Am"). Leave empty for auto-detection. A-G, #/♭, major/minor
        timesignature: Time signature (2 for '2/4', 3 for '3/4', 4 for '4/4', 6 for '6/8'). Leave empty for auto-detection.
        vocal_language: Language code for vocals, e.g., "en", "zh", "ja", or "unknown". see acestep/constants.py:VALID_LANGUAGES
        duration: Target audio length in seconds. If <0 or None, model chooses automatically. 10 ~ 600

        # Generation Parameters
        inference_steps: Number of diffusion steps (e.g., 8 for turbo, 32–100 for base model).
        guidance_scale: CFG (classifier-free guidance) strength. Higher means following the prompt more strictly. Only support for non-turbo model.
        seed: Integer seed for reproducibility. -1 means use random seed each time.

        # Advanced DiT Parameters
        use_adg: Whether to use Adaptive Dual Guidance (only works for base model).
        cfg_interval_start: Start ratio (0.0–1.0) to apply CFG.
        cfg_interval_end: End ratio (0.0–1.0) to apply CFG.
        shift: Timestep shift factor (default 1.0). When != 1.0, applies t = shift * t / (1 + (shift - 1) * t) to timesteps.

        # Task-Specific Parameters
        task_type: Type of generation task. One of: "text2music", "cover", "repaint", "lego", "extract", "complete".
        reference_audio: Path to a reference audio file for style transfer or cover tasks.
        src_audio: Path to a source audio file for audio-to-audio tasks.
        audio_codes: Audio semantic codes as a string (advanced use, for code-control generation).
        repainting_start: For repaint/lego tasks: start time in seconds for region to repaint.
        repainting_end: For repaint/lego tasks: end time in seconds for region to repaint (-1 for until end).
        audio_cover_strength: Strength of reference audio/codes influence (range 0.0–1.0). set smaller (0.2) for style transfer tasks.
        instruction: Optional task instruction prompt. If empty, auto-generated by system.

        # 5Hz Language Model Parameters for CoT reasoning
        thinking: If True, enable 5Hz Language Model "Chain-of-Thought" reasoning for semantic/music metadata and codes.
        lm_temperature: Sampling temperature for the LLM (0.0–2.0). Higher = more creative/varied results.
        lm_cfg_scale: Classifier-free guidance scale for the LLM.
        lm_top_k: LLM top-k sampling (0 = disabled).
        lm_top_p: LLM top-p nucleus sampling (1.0 = disabled).
        lm_negative_prompt: Negative prompt to use for LLM (for control).
        use_cot_metas: Whether to let LLM generate music metadata via CoT reasoning.
        use_cot_caption: Whether to let LLM rewrite or format the input caption via CoT reasoning.
        use_cot_language: Whether to let LLM detect vocal language via CoT.
    """
    # Required Inputs
    task_type: str = "text2music"
    instruction: str = "Fill the audio semantic mask based on the given conditions:"

    # Audio Uploads (filesystem paths)
    reference_audio: Optional[str] = None
    src_audio: Optional[str] = None

    # LM Codes Hints (pre-computed semantic codes; skips LM code generation when set)
    audio_codes: str = ""

    # Text Inputs
    caption: str = ""
    lyrics: str = ""
    instrumental: bool = False

    # Metadata (left at sentinel values for auto-detection)
    vocal_language: str = "unknown"
    bpm: Optional[int] = None
    keyscale: str = ""
    timesignature: str = ""
    duration: float = -1.0  # -1 means "model chooses"

    # Advanced Settings
    inference_steps: int = 8
    seed: int = -1  # -1 means random seed per call
    guidance_scale: float = 7.0
    use_adg: bool = False
    cfg_interval_start: float = 0.0
    cfg_interval_end: float = 1.0
    shift: float = 1.0
    infer_method: str = "ode"  # "ode" or "sde" - diffusion inference method
    # Custom timesteps (parsed from string like "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0")
    # If provided, overrides inference_steps and shift
    timesteps: Optional[List[float]] = None

    repainting_start: float = 0.0
    repainting_end: float = -1  # -1 means "repaint until end"
    audio_cover_strength: float = 1.0

    # 5Hz Language Model Parameters
    thinking: bool = True
    lm_temperature: float = 0.85
    lm_cfg_scale: float = 2.0
    lm_top_k: int = 0
    lm_top_p: float = 0.9
    lm_negative_prompt: str = "NO USER INPUT"
    use_cot_metas: bool = True
    use_cot_caption: bool = True
    use_cot_lyrics: bool = False  # TODO: not used yet
    use_cot_language: bool = True
    use_constrained_decoding: bool = True

    # Overrides for CoT-produced values (take precedence over LM output when set)
    cot_bpm: Optional[int] = None
    cot_keyscale: str = ""
    cot_timesignature: str = ""
    cot_duration: Optional[float] = None
    cot_vocal_language: str = "unknown"
    cot_caption: str = ""
    cot_lyrics: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to dictionary for JSON serialization."""
        return asdict(self)
155
+
156
@dataclass
class GenerationConfig:
    """Configuration for music generation.

    Attributes:
        batch_size: Number of audio samples to generate
        allow_lm_batch: Whether to allow batch processing in LM
        use_random_seed: Whether to use random seed
        seeds: Seed(s) for batch generation. Can be:
            - None: Use random seeds (when use_random_seed=True) or params.seed (when use_random_seed=False)
            - List[int]: List of seeds, will be padded with random seeds if fewer than batch_size
            - int: Single seed value (will be converted to list and padded)
        lm_batch_chunk_size: Batch chunk size for LM processing
        constrained_decoding_debug: Whether to enable constrained decoding debug
        audio_format: Output audio format, one of "mp3", "wav", "flac". Default: "flac"
    """
    batch_size: int = 2
    allow_lm_batch: bool = False
    use_random_seed: bool = True
    seeds: Optional[List[int]] = None
    lm_batch_chunk_size: int = 8
    constrained_decoding_debug: bool = False
    audio_format: str = "flac"  # Default to FLAC for fast saving

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to dictionary for JSON serialization."""
        return asdict(self)
184
+
185
@dataclass
class GenerationResult:
    """Result of music generation.

    Attributes:
        # Audio Outputs
        audios: List of audio dictionaries with paths, keys, params
        status_message: Status message from generation
        extra_outputs: Extra outputs from generation
        success: Whether generation completed successfully
        error: Error message if generation failed
    """

    # Audio Outputs — one dict per generated sample
    audios: List[Dict[str, Any]] = field(default_factory=list)
    # Generation Information
    status_message: str = ""
    extra_outputs: Dict[str, Any] = field(default_factory=dict)
    # Success Status — error is None on success
    success: bool = True
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert result to dictionary for JSON serialization."""
        return asdict(self)
210
+
211
+
212
@dataclass
class UnderstandResult:
    """Result of music understanding from audio codes.

    Attributes:
        # Metadata Fields
        caption: Generated caption describing the music
        lyrics: Generated or extracted lyrics
        bpm: Beats per minute (None if not detected)
        duration: Duration in seconds (None if not detected)
        keyscale: Musical key (e.g., "C Major")
        language: Vocal language code (e.g., "en", "zh")
        timesignature: Time signature (e.g., "4/4")

        # Status
        status_message: Status message from understanding
        success: Whether understanding completed successfully
        error: Error message if understanding failed
    """
    # Metadata Fields — empty string / None means "not detected"
    caption: str = ""
    lyrics: str = ""
    bpm: Optional[int] = None
    duration: Optional[float] = None
    keyscale: str = ""
    language: str = ""
    timesignature: str = ""

    # Status
    status_message: str = ""
    success: bool = True
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert result to dictionary for JSON serialization."""
        return asdict(self)
248
+
249
+
250
+ def _update_metadata_from_lm(
251
+ metadata: Dict[str, Any],
252
+ bpm: Optional[int],
253
+ key_scale: str,
254
+ time_signature: str,
255
+ audio_duration: Optional[float],
256
+ vocal_language: str,
257
+ caption: str,
258
+ lyrics: str,
259
+ ) -> Tuple[Optional[int], str, str, Optional[float], str, str, str]:
260
+ """Update metadata fields from LM output if not provided by user."""
261
+
262
+ if bpm is None and metadata.get('bpm'):
263
+ bpm_value = metadata.get('bpm')
264
+ if bpm_value not in ["N/A", ""]:
265
+ try:
266
+ bpm = int(bpm_value)
267
+ except (ValueError, TypeError):
268
+ pass
269
+
270
+ if not key_scale and metadata.get('keyscale'):
271
+ key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
272
+ if key_scale_value != "N/A":
273
+ key_scale = key_scale_value
274
+
275
+ if not time_signature and metadata.get('timesignature'):
276
+ time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
277
+ if time_signature_value != "N/A":
278
+ time_signature = time_signature_value
279
+
280
+ if audio_duration is None or audio_duration <= 0:
281
+ audio_duration_value = metadata.get('duration', -1)
282
+ if audio_duration_value not in ["N/A", ""]:
283
+ try:
284
+ audio_duration = float(audio_duration_value)
285
+ except (ValueError, TypeError):
286
+ pass
287
+
288
+ if not vocal_language and metadata.get('vocal_language'):
289
+ vocal_language = metadata.get('vocal_language')
290
+ if not caption and metadata.get('caption'):
291
+ caption = metadata.get('caption')
292
+ if not lyrics and metadata.get('lyrics'):
293
+ lyrics = metadata.get('lyrics')
294
+ return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics
295
+
296
+
297
+ @_get_spaces_gpu_decorator(duration=180)
298
+ def generate_music(
299
+ dit_handler,
300
+ llm_handler,
301
+ params: GenerationParams,
302
+ config: GenerationConfig,
303
+ save_dir: Optional[str] = None,
304
+ progress=None,
305
+ ) -> GenerationResult:
306
+ """Generate music using ACE-Step model with optional LM reasoning.
307
+
308
+ Args:
309
+ dit_handler: Initialized DiT model handler (AceStepHandler instance)
310
+ llm_handler: Initialized LLM handler (LLMHandler instance)
311
+ params: Generation parameters (GenerationParams instance)
312
+ config: Generation configuration (GenerationConfig instance)
313
+
314
+ Returns:
315
+ GenerationResult with generated audio files and metadata
316
+ """
317
+ try:
318
+ # Phase 1: LM-based metadata and code generation (if enabled)
319
+ audio_code_string_to_use = params.audio_codes
320
+ lm_generated_metadata = None
321
+ lm_generated_audio_codes_list = []
322
+ lm_total_time_costs = {
323
+ "phase1_time": 0.0,
324
+ "phase2_time": 0.0,
325
+ "total_time": 0.0,
326
+ }
327
+
328
+ # Extract mutable copies of metadata (will be updated by LM if needed)
329
+ bpm = params.bpm
330
+ key_scale = params.keyscale
331
+ time_signature = params.timesignature
332
+ audio_duration = params.duration
333
+ dit_input_caption = params.caption
334
+ dit_input_vocal_language = params.vocal_language
335
+ dit_input_lyrics = params.lyrics
336
+ # Determine if we need to generate audio codes
337
+ # If user has provided audio_codes, we don't need to generate them
338
+ # Otherwise, check if we need audio codes (lm_dit mode) or just metas (dit mode)
339
+ user_provided_audio_codes = bool(params.audio_codes and str(params.audio_codes).strip())
340
+
341
+ # Safety: cover task without any source audio or codes produces silence.
342
+ if params.task_type == "cover":
343
+ no_src_audio = not (params.reference_audio or params.src_audio)
344
+ if no_src_audio and not user_provided_audio_codes:
345
+ logger.warning("Cover task requested without source audio or audio codes. Falling back to text2music.")
346
+ params.task_type = "text2music"
347
+ if params.instruction == TASK_INSTRUCTIONS.get("cover"):
348
+ params.instruction = TASK_INSTRUCTIONS.get("text2music", params.instruction)
349
+
350
+ # Determine infer_type: use "llm_dit" if we need audio codes, "dit" if only metas needed
351
+ # For now, we use "llm_dit" if batch mode or if user hasn't provided codes
352
+ # Use "dit" if user has provided codes (only need metas) or if explicitly only need metas
353
+ # Note: This logic can be refined based on specific requirements
354
+ need_audio_codes = not user_provided_audio_codes
355
+
356
+ # Determine if we should use chunk-based LM generation (always use chunks for consistency)
357
+ # Determine actual batch size for chunk processing
358
+ actual_batch_size = config.batch_size if config.batch_size is not None else 1
359
+
360
+ # Prepare seeds for batch generation
361
+ # Use config.seed if provided, otherwise fallback to params.seed
362
+ # Convert config.seed (None, int, or List[int]) to format that prepare_seeds accepts
363
+ seed_for_generation = ""
364
+ # Original code (commented out because it crashes on int seeds):
365
+ # if config.seeds is not None and len(config.seeds) > 0:
366
+ # if isinstance(config.seeds, list):
367
+ # # Convert List[int] to comma-separated string
368
+ # seed_for_generation = ",".join(str(s) for s in config.seeds)
369
+
370
+ if config.seeds is not None:
371
+ if isinstance(config.seeds, list) and len(config.seeds) > 0:
372
+ # Convert List[int] to comma-separated string
373
+ seed_for_generation = ",".join(str(s) for s in config.seeds)
374
+ elif isinstance(config.seeds, int):
375
+ # Fix: Explicitly handle single integer seeds by converting to string.
376
+ # Previously, this would crash because 'len()' was called on an int.
377
+ seed_for_generation = str(config.seeds)
378
+
379
+ # Use dit_handler.prepare_seeds to handle seed list generation and padding
380
+ # This will handle all the logic: padding with random seeds if needed, etc.
381
+ actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed)
382
+
383
+ # LM-based Chain-of-Thought reasoning
384
+ # Skip LM for cover/repaint tasks - these tasks use reference/src audio directly
385
+ # and don't need LM to generate audio codes
386
+ skip_lm_tasks = {"cover", "repaint"}
387
+
388
+ # Determine if we should use LLM
389
+ # LLM is needed for:
390
+ # 1. thinking=True: generate audio codes via LM
391
+ # 2. use_cot_caption=True: enhance/generate caption via CoT
392
+ # 3. use_cot_language=True: detect vocal language via CoT
393
+ # 4. use_cot_metas=True: fill missing metadata via CoT
394
+ need_lm_for_cot = params.use_cot_caption or params.use_cot_language or params.use_cot_metas
395
+ use_lm = (params.thinking or need_lm_for_cot) and llm_handler is not None and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks
396
+ lm_status = []
397
+
398
+ if params.task_type in skip_lm_tasks:
399
+ logger.info(f"Skipping LM for task_type='{params.task_type}' - using DiT directly")
400
+
401
+ logger.info(f"[generate_music] LLM usage decision: thinking={params.thinking}, "
402
+ f"use_cot_caption={params.use_cot_caption}, use_cot_language={params.use_cot_language}, "
403
+ f"use_cot_metas={params.use_cot_metas}, need_lm_for_cot={need_lm_for_cot}, "
404
+ f"llm_initialized={llm_handler.llm_initialized if llm_handler else False}, use_lm={use_lm}")
405
+
406
+ def _infer_audio_duration_seconds(audio_path: str) -> Optional[float]:
407
+ """Best-effort duration inference for common audio formats."""
408
+ if not audio_path:
409
+ return None
410
+ # Try torchaudio (supports more formats when ffmpeg backend is available)
411
+ try:
412
+ import torchaudio
413
+ info = torchaudio.info(audio_path)
414
+ if info and info.num_frames and info.sample_rate:
415
+ return float(info.num_frames) / float(info.sample_rate)
416
+ except Exception:
417
+ pass
418
+ # Try soundfile (fast for wav/flac)
419
+ try:
420
+ import soundfile as sf
421
+ info = sf.info(audio_path)
422
+ if info and info.frames and info.samplerate:
423
+ return float(info.frames) / float(info.samplerate)
424
+ except Exception:
425
+ pass
426
+ # macOS fallback: use afinfo for m4a/aac
427
+ if sys.platform == "darwin" and shutil.which("afinfo"):
428
+ try:
429
+ result = subprocess.run(
430
+ ["afinfo", audio_path],
431
+ check=False,
432
+ capture_output=True,
433
+ text=True,
434
+ )
435
+ if result.stdout:
436
+ for line in result.stdout.splitlines():
437
+ if "duration:" in line:
438
+ # Example: "duration: 183.165s"
439
+ parts = line.strip().split()
440
+ for p in parts:
441
+ if p.endswith("s"):
442
+ try:
443
+ return float(p.rstrip("s"))
444
+ except ValueError:
445
+ continue
446
+ except Exception:
447
+ pass
448
+ return None
449
+
450
+ # Clamp duration and batch size to GPU limits (applies to non-Gradio callers too)
451
+ try:
452
+ # If duration not provided, try to infer from source audio to enable safe clamping.
453
+ if (audio_duration is None or float(audio_duration) <= 0) and (params.src_audio or params.reference_audio):
454
+ audio_path = params.src_audio or params.reference_audio
455
+ try:
456
+ inferred = _infer_audio_duration_seconds(audio_path)
457
+ if inferred and inferred > 0:
458
+ audio_duration = inferred
459
+ params.duration = inferred
460
+ logger.info(f"[generate_music] Inferred duration from audio file: {inferred:.2f}s")
461
+ except Exception as e:
462
+ logger.warning(f"[generate_music] Failed to infer duration from audio file: {e}")
463
+
464
+ gpu_config = get_gpu_config()
465
+ max_duration = gpu_config.max_duration_with_lm if use_lm else gpu_config.max_duration_without_lm
466
+ if audio_duration is not None and float(audio_duration) > 0 and float(audio_duration) > max_duration:
467
+ logger.warning(f"[generate_music] Duration {audio_duration}s exceeds GPU limit {max_duration}s. Clamping.")
468
+ audio_duration = float(max_duration)
469
+ params.duration = float(max_duration)
470
+
471
+ max_batch = gpu_config.max_batch_size_with_lm if use_lm else gpu_config.max_batch_size_without_lm
472
+ if config.batch_size is not None and config.batch_size > max_batch:
473
+ logger.warning(f"[generate_music] Batch size {config.batch_size} exceeds GPU limit {max_batch}. Clamping.")
474
+ config.batch_size = max_batch
475
+
476
+ # Extra safety for MPS: large durations can OOM with batch > 1
477
+ if (
478
+ hasattr(dit_handler, "device")
479
+ and dit_handler.device == "mps"
480
+ and audio_duration is not None
481
+ and float(audio_duration) > 180
482
+ and config.batch_size is not None
483
+ and config.batch_size > 1
484
+ ):
485
+ logger.warning("[generate_music] MPS with long duration detected; reducing batch size to 1 to avoid OOM.")
486
+ config.batch_size = 1
487
+ except Exception as e:
488
+ logger.warning(f"[generate_music] Failed to clamp duration/batch to GPU limits: {e}")
489
+
490
+ if use_lm:
491
+ # Convert sampling parameters - handle None values safely
492
+ top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k)
493
+ top_p_value = None if not params.lm_top_p or params.lm_top_p >= 1.0 else params.lm_top_p
494
+
495
+ # Build user_metadata from user-provided values
496
+ user_metadata = {}
497
+ if bpm is not None:
498
+ try:
499
+ bpm_value = float(bpm)
500
+ if bpm_value > 0:
501
+ user_metadata['bpm'] = int(bpm_value)
502
+ except (ValueError, TypeError):
503
+ pass
504
+
505
+ if key_scale and key_scale.strip():
506
+ key_scale_clean = key_scale.strip()
507
+ if key_scale_clean.lower() not in ["n/a", ""]:
508
+ user_metadata['keyscale'] = key_scale_clean
509
+
510
+ if time_signature and time_signature.strip():
511
+ time_sig_clean = time_signature.strip()
512
+ if time_sig_clean.lower() not in ["n/a", ""]:
513
+ user_metadata['timesignature'] = time_sig_clean
514
+
515
+ if audio_duration is not None:
516
+ try:
517
+ duration_value = float(audio_duration)
518
+ if duration_value > 0:
519
+ user_metadata['duration'] = int(duration_value)
520
+ except (ValueError, TypeError):
521
+ pass
522
+
523
+ user_metadata_to_pass = user_metadata if user_metadata else None
524
+
525
+ # Determine infer_type based on whether we need audio codes
526
+ # - "llm_dit": generates both metas and audio codes (two-phase internally)
527
+ # - "dit": generates only metas (single phase)
528
+ infer_type = "llm_dit" if need_audio_codes and params.thinking else "dit"
529
+
530
+ # Use chunk size from config, or default to batch_size if not set
531
+ max_inference_batch_size = int(config.lm_batch_chunk_size) if config.lm_batch_chunk_size > 0 else actual_batch_size
532
+ num_chunks = math.ceil(actual_batch_size / max_inference_batch_size)
533
+
534
+ all_metadata_list = []
535
+ all_audio_codes_list = []
536
+
537
+ for chunk_idx in range(num_chunks):
538
+ chunk_start = chunk_idx * max_inference_batch_size
539
+ chunk_end = min(chunk_start + max_inference_batch_size, actual_batch_size)
540
+ chunk_size = chunk_end - chunk_start
541
+ chunk_seeds = actual_seed_list[chunk_start:chunk_end] if chunk_start < len(actual_seed_list) else None
542
+
543
+ logger.info(f"LM chunk {chunk_idx+1}/{num_chunks} (infer_type={infer_type}) "
544
+ f"(size: {chunk_size}, seeds: {chunk_seeds})")
545
+
546
+ # Use the determined infer_type
547
+ # - "llm_dit" will internally run two phases (metas + codes)
548
+ # - "dit" will only run phase 1 (metas only)
549
+ result = llm_handler.generate_with_stop_condition(
550
+ caption=params.caption or "",
551
+ lyrics=params.lyrics or "",
552
+ infer_type=infer_type,
553
+ temperature=params.lm_temperature,
554
+ cfg_scale=params.lm_cfg_scale,
555
+ negative_prompt=params.lm_negative_prompt,
556
+ top_k=top_k_value,
557
+ top_p=top_p_value,
558
+ target_duration=audio_duration, # Pass duration to limit audio codes generation
559
+ user_metadata=user_metadata_to_pass,
560
+ use_cot_caption=params.use_cot_caption,
561
+ use_cot_language=params.use_cot_language,
562
+ use_cot_metas=params.use_cot_metas,
563
+ use_constrained_decoding=params.use_constrained_decoding,
564
+ constrained_decoding_debug=config.constrained_decoding_debug,
565
+ batch_size=chunk_size,
566
+ seeds=chunk_seeds,
567
+ progress=progress,
568
+ )
569
+
570
+ # Check if LM generation failed
571
+ if not result.get("success", False):
572
+ error_msg = result.get("error", "Unknown LM error")
573
+ lm_status.append(f"❌ LM Error: {error_msg}")
574
+ # Return early with error
575
+ return GenerationResult(
576
+ audios=[],
577
+ status_message=f"❌ LM generation failed: {error_msg}",
578
+ extra_outputs={},
579
+ success=False,
580
+ error=error_msg,
581
+ )
582
+
583
+ # Extract metadata and audio_codes from result dict
584
+ if chunk_size > 1:
585
+ metadata_list = result.get("metadata", [])
586
+ audio_codes_list = result.get("audio_codes", [])
587
+ all_metadata_list.extend(metadata_list)
588
+ all_audio_codes_list.extend(audio_codes_list)
589
+ else:
590
+ metadata = result.get("metadata", {})
591
+ audio_codes = result.get("audio_codes", "")
592
+ all_metadata_list.append(metadata)
593
+ all_audio_codes_list.append(audio_codes)
594
+
595
+ # Collect time costs from LM extra_outputs
596
+ lm_extra = result.get("extra_outputs", {})
597
+ lm_chunk_time_costs = lm_extra.get("time_costs", {})
598
+ if lm_chunk_time_costs:
599
+ # Accumulate time costs from all chunks
600
+ for key in ["phase1_time", "phase2_time", "total_time"]:
601
+ if key in lm_chunk_time_costs:
602
+ lm_total_time_costs[key] += lm_chunk_time_costs[key]
603
+
604
+ time_str = ", ".join([f"{k}: {v:.2f}s" for k, v in lm_chunk_time_costs.items()])
605
+ lm_status.append(f"✅ LM chunk {chunk_idx+1}: {time_str}")
606
+
607
+ lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
608
+ lm_generated_audio_codes_list = all_audio_codes_list
609
+
610
+ # Set audio_code_string_to_use based on infer_type
611
+ if infer_type == "llm_dit":
612
+ # If batch mode, use list; otherwise use single string
613
+ if actual_batch_size > 1:
614
+ audio_code_string_to_use = all_audio_codes_list
615
+ else:
616
+ audio_code_string_to_use = all_audio_codes_list[0] if all_audio_codes_list else ""
617
+ else:
618
+ # For "dit" mode, keep user-provided codes or empty
619
+ audio_code_string_to_use = params.audio_codes
620
+
621
+ # Update metadata from LM if not provided by user
622
+ if lm_generated_metadata:
623
+ bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics = _update_metadata_from_lm(
624
+ metadata=lm_generated_metadata,
625
+ bpm=bpm,
626
+ key_scale=key_scale,
627
+ time_signature=time_signature,
628
+ audio_duration=audio_duration,
629
+ vocal_language=dit_input_vocal_language,
630
+ caption=dit_input_caption,
631
+ lyrics=dit_input_lyrics)
632
+ if not params.bpm:
633
+ params.cot_bpm = bpm
634
+ if not params.keyscale:
635
+ params.cot_keyscale = key_scale
636
+ if not params.timesignature:
637
+ params.cot_timesignature = time_signature
638
+ if not params.duration:
639
+ params.cot_duration = audio_duration
640
+ if not params.vocal_language:
641
+ params.cot_vocal_language = vocal_language
642
+ if not params.caption:
643
+ params.cot_caption = caption
644
+ if not params.lyrics:
645
+ params.cot_lyrics = lyrics
646
+ dit_input_lyrics = lyrics
647
+
648
+ # set cot caption and language if needed
649
+ if params.use_cot_caption:
650
+ dit_input_caption = lm_generated_metadata.get("caption", dit_input_caption)
651
+ if params.use_cot_language:
652
+ dit_input_vocal_language = lm_generated_metadata.get("vocal_language", dit_input_vocal_language)
653
+
654
+ # Phase 2: DiT music generation
655
+ # Use seed_for_generation (from config.seed or params.seed) instead of params.seed for actual generation
656
+ result = dit_handler.generate_music(
657
+ captions=dit_input_caption,
658
+ lyrics=dit_input_lyrics,
659
+ bpm=bpm,
660
+ key_scale=key_scale,
661
+ time_signature=time_signature,
662
+ vocal_language=dit_input_vocal_language,
663
+ inference_steps=params.inference_steps,
664
+ guidance_scale=params.guidance_scale,
665
+ use_random_seed=config.use_random_seed,
666
+ seed=seed_for_generation, # Use config.seed (or params.seed fallback) instead of params.seed directly
667
+ reference_audio=params.reference_audio,
668
+ audio_duration=audio_duration,
669
+ batch_size=config.batch_size if config.batch_size is not None else 1,
670
+ src_audio=params.src_audio,
671
+ audio_code_string=audio_code_string_to_use,
672
+ repainting_start=params.repainting_start,
673
+ repainting_end=params.repainting_end,
674
+ instruction=params.instruction,
675
+ audio_cover_strength=params.audio_cover_strength,
676
+ task_type=params.task_type,
677
+ use_adg=params.use_adg,
678
+ cfg_interval_start=params.cfg_interval_start,
679
+ cfg_interval_end=params.cfg_interval_end,
680
+ shift=params.shift,
681
+ infer_method=params.infer_method,
682
+ timesteps=params.timesteps,
683
+ progress=progress,
684
+ )
685
+
686
+ # Check if generation failed
687
+ if not result.get("success", False):
688
+ return GenerationResult(
689
+ audios=[],
690
+ status_message=result.get("status_message", ""),
691
+ extra_outputs={},
692
+ success=False,
693
+ error=result.get("error"),
694
+ )
695
+
696
+ # Extract results from dit_handler.generate_music dict
697
+ dit_audios = result.get("audios", [])
698
+ status_message = result.get("status_message", "")
699
+ dit_extra_outputs = result.get("extra_outputs", {})
700
+
701
+ # Use the seed list already prepared above (from config.seed or params.seed fallback)
702
+ # actual_seed_list was computed earlier using dit_handler.prepare_seeds
703
+ seed_list = actual_seed_list
704
+
705
+ # Get base params dictionary
706
+ base_params_dict = params.to_dict()
707
+
708
+ # Save audio files using AudioSaver (format from config)
709
+ audio_format = config.audio_format if config.audio_format else "flac"
710
+ audio_saver = AudioSaver(default_format=audio_format)
711
+
712
+ # Use handler's temp_dir for saving files
713
+ if save_dir is not None:
714
+ os.makedirs(save_dir, exist_ok=True)
715
+
716
+ # Build audios list for GenerationResult with params and save files
717
+ # Audio saving and UUID generation handled here, outside of handler
718
+ audios = []
719
+ silent_warnings = []
720
+ for idx, dit_audio in enumerate(dit_audios):
721
+ # Create a copy of params dict for this audio
722
+ audio_params = base_params_dict.copy()
723
+
724
+ # Update audio-specific values
725
+ audio_params["seed"] = seed_list[idx] if idx < len(seed_list) else None
726
+
727
+ # Add audio codes if batch mode
728
+ if lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list):
729
+ audio_params["audio_codes"] = lm_generated_audio_codes_list[idx]
730
+
731
+ # Get audio tensor and metadata
732
+ audio_tensor = dit_audio.get("tensor")
733
+ sample_rate = dit_audio.get("sample_rate", 48000)
734
+
735
+ # Generate UUID for this audio (moved from handler)
736
+ batch_seed = seed_list[idx] if idx < len(seed_list) else seed_list[0] if seed_list else -1
737
+ audio_code_str = lm_generated_audio_codes_list[idx] if (
738
+ lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list)) else audio_code_string_to_use
739
+ if isinstance(audio_code_str, list):
740
+ audio_code_str = audio_code_str[idx] if idx < len(audio_code_str) else ""
741
+
742
+ audio_key = generate_uuid_from_params(audio_params)
743
+
744
+ silent_check = False
745
+ if audio_tensor is not None:
746
+ silent_check, rms_val, peak_val = is_audio_silent(audio_tensor, channels_first=True)
747
+ if silent_check:
748
+ logger.warning(
749
+ f"[generate_music] Silent output detected (idx={idx}, RMS={rms_val:.2e}, peak={peak_val:.2e}). "
750
+ "Likely cause: LLM backend returned empty conditioning, or incompatible torch/triton/flash-attn. "
751
+ "Suggest running with --backend pt."
752
+ )
753
+ silent_warnings.append(
754
+ f"Output {idx + 1}: silent or near-silent (RMS≈{rms_val:.2e}). "
755
+ "Likely causes: LLM backend failure, incompatible torch/triton/flash-attn, or CPU/fallback path. "
756
+ "Try running with --backend pt."
757
+ )
758
+
759
+ audio_path = None
760
+ if audio_tensor is not None and save_dir is not None and not silent_check:
761
+ try:
762
+ audio_file = os.path.join(save_dir, f"{audio_key}.{audio_format}")
763
+ audio_path = audio_saver.save_audio(audio_tensor,
764
+ audio_file,
765
+ sample_rate=sample_rate,
766
+ format=audio_format,
767
+ channels_first=True)
768
+ except Exception as e:
769
+ logger.error(f"[generate_music] Failed to save audio file: {e}")
770
+ audio_path = ""
771
+
772
+ audio_dict = {
773
+ "path": audio_path or "",
774
+ "tensor": audio_tensor,
775
+ "key": audio_key,
776
+ "sample_rate": sample_rate,
777
+ "params": audio_params,
778
+ "silent": silent_check,
779
+ }
780
+
781
+ audios.append(audio_dict)
782
+
783
+ # Merge extra_outputs: include dit_extra_outputs (latents, masks) and add LM metadata
784
+ extra_outputs = dit_extra_outputs.copy()
785
+ extra_outputs["lm_metadata"] = lm_generated_metadata
786
+
787
+ # Merge time_costs from both LM and DiT into a unified dictionary
788
+ unified_time_costs = {}
789
+
790
+ # Add LM time costs (if LM was used)
791
+ if use_lm and lm_total_time_costs:
792
+ for key, value in lm_total_time_costs.items():
793
+ unified_time_costs[f"lm_{key}"] = value
794
+
795
+ # Add DiT time costs (if available)
796
+ dit_time_costs = dit_extra_outputs.get("time_costs", {})
797
+ if dit_time_costs:
798
+ for key, value in dit_time_costs.items():
799
+ unified_time_costs[f"dit_{key}"] = value
800
+
801
+ # Calculate total pipeline time
802
+ if unified_time_costs:
803
+ lm_total = unified_time_costs.get("lm_total_time", 0.0)
804
+ dit_total = unified_time_costs.get("dit_total_time_cost", 0.0)
805
+ unified_time_costs["pipeline_total_time"] = lm_total + dit_total
806
+
807
+ # Update extra_outputs with unified time_costs
808
+ extra_outputs["time_costs"] = unified_time_costs
809
+
810
+ if lm_status:
811
+ status_message = "\n".join(lm_status) + "\n" + status_message
812
+ else:
813
+ status_message = status_message
814
+ if silent_warnings:
815
+ status_message = "⚠️ Silent output detected:\n" + "\n".join(silent_warnings) + "\n\nSuggested fix: try running with --backend pt\n\n" + (status_message or "")
816
+ # Create and return GenerationResult
817
+ return GenerationResult(
818
+ audios=audios,
819
+ status_message=status_message,
820
+ extra_outputs=extra_outputs,
821
+ success=True,
822
+ error=None,
823
+ )
824
+
825
+ except Exception as e:
826
+ logger.exception("Music generation failed")
827
+ return GenerationResult(
828
+ audios=[],
829
+ status_message=f"Error: {str(e)}",
830
+ extra_outputs={},
831
+ success=False,
832
+ error=str(e),
833
+ )
834
+
835
+
836
def understand_music(
    llm_handler,
    audio_codes: str,
    temperature: float = 0.85,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
    use_constrained_decoding: bool = True,
    constrained_decoding_debug: bool = False,
) -> UnderstandResult:
    """Analyze audio semantic codes with the 5Hz Language Model.

    Given a string of audio code tokens, asks the LM to describe the music:
    caption, lyrics, BPM, duration, key scale, vocal language, and time
    signature. An empty ``audio_codes`` (or "NO USER INPUT") makes the LM
    invent a sample example instead of analyzing existing codes.

    Note: ``cfg_scale`` and ``negative_prompt`` are not supported in
    understand mode.

    Args:
        llm_handler: Initialized LLM handler (LLMHandler instance).
        audio_codes: Audio code token string, e.g.
            ``"<|audio_code_123|><|audio_code_456|>..."``. Empty string or
            "NO USER INPUT" requests a synthetic example.
        temperature: Sampling temperature (0.0-2.0); higher is more creative.
        top_k: Top-K sampling (None or 0 = disabled).
        top_p: Top-P (nucleus) sampling (None or 1.0 = disabled).
        repetition_penalty: Repetition penalty (1.0 = no penalty).
        use_constrained_decoding: Enable FSM-based constrained decoding.
        constrained_decoding_debug: Enable constrained-decoding debug logs.

    Returns:
        UnderstandResult with parsed metadata fields and status.

    Example:
        >>> result = understand_music(llm_handler, audio_codes="<|audio_code_123|>...")
        >>> if result.success:
        ...     print(f"Caption: {result.caption}")
    """
    # The LM must be loaded before any understanding call.
    if not llm_handler.llm_initialized:
        return UnderstandResult(
            status_message="5Hz LM not initialized. Please initialize it first.",
            success=False,
            error="LLM not initialized",
        )

    # Blank/whitespace-only input is rewritten to the sentinel that asks
    # the LM to generate a sample example instead of analyzing codes.
    if not (audio_codes and audio_codes.strip()):
        audio_codes = "NO USER INPUT"

    try:
        metadata, status = llm_handler.understand_audio_from_codes(
            audio_codes=audio_codes,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            use_constrained_decoding=use_constrained_decoding,
            constrained_decoding_debug=constrained_decoding_debug,
        )

        # An empty metadata dict signals an LM-side failure.
        if not metadata:
            return UnderstandResult(
                status_message=status or "Failed to understand audio codes",
                success=False,
                error=status or "Empty metadata returned",
            )

        def _scrub(value):
            # The LM uses the literal 'N/A' for unknown fields; map it to ''.
            return '' if value == 'N/A' else value

        # BPM arrives as a string/number; keep None when missing or invalid.
        bpm = None
        raw_bpm = metadata.get('bpm')
        if raw_bpm is not None and raw_bpm != 'N/A' and raw_bpm != '':
            try:
                bpm = int(raw_bpm)
            except (ValueError, TypeError):
                bpm = None

        # Duration follows the same convention but converts to float.
        duration = None
        raw_duration = metadata.get('duration')
        if raw_duration is not None and raw_duration != 'N/A' and raw_duration != '':
            try:
                duration = float(raw_duration)
            except (ValueError, TypeError):
                duration = None

        return UnderstandResult(
            caption=metadata.get('caption', ''),
            lyrics=metadata.get('lyrics', ''),
            bpm=bpm,
            duration=duration,
            keyscale=_scrub(metadata.get('keyscale', '')),
            language=_scrub(metadata.get('language', metadata.get('vocal_language', ''))),
            timesignature=_scrub(metadata.get('timesignature', '')),
            status_message=status,
            success=True,
            error=None,
        )

    except Exception as e:
        logger.exception("Music understanding failed")
        return UnderstandResult(
            status_message=f"Error: {str(e)}",
            success=False,
            error=str(e),
        )
962
+
963
+
964
@dataclass
class CreateSampleResult:
    """Outcome of "Simple Mode" / "Inspiration Mode" sample creation.

    Produced when a user describes the music they want in natural language
    and the LLM turns that description into a complete sample: a detailed
    caption, lyrics, and structured metadata.

    Attributes:
        caption: Generated detailed music description/caption.
        lyrics: Generated lyrics ("[Instrumental]" for instrumental music).
        bpm: Beats per minute, or None when not generated.
        duration: Length in seconds, or None when not generated.
        keyscale: Musical key, e.g. "C Major".
        language: Vocal language code such as "en" or "zh".
        timesignature: Time signature, e.g. "4".
        instrumental: True when the piece has no vocals.
        status_message: Human-readable status from sample creation.
        success: True when creation finished without error.
        error: Error description when creation failed, else None.
    """
    # Metadata fields (order matters: positional construction is public API)
    caption: str = ""
    lyrics: str = ""
    bpm: Optional[int] = None
    duration: Optional[float] = None
    keyscale: str = ""
    language: str = ""
    timesignature: str = ""
    instrumental: bool = False

    # Status fields
    status_message: str = ""
    success: bool = True
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the result to a plain, JSON-friendly dictionary."""
        return asdict(self)
1006
+
1007
+
1008
def create_sample(
    llm_handler,
    query: str,
    instrumental: bool = False,
    vocal_language: Optional[str] = None,
    temperature: float = 0.85,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
    use_constrained_decoding: bool = True,
    constrained_decoding_debug: bool = False,
) -> CreateSampleResult:
    """Create a full music sample from a natural-language query ("Simple Mode").

    The 5Hz Language Model expands the user's plain-language description into
    a complete sample: a detailed caption, lyrics (unless instrumental), and
    metadata (BPM, duration, key, language, time signature).

    Note: ``cfg_scale`` and ``negative_prompt`` are not supported in
    create_sample mode.

    Args:
        llm_handler: Initialized LLM handler (LLMHandler instance).
        query: Natural-language music description
            (e.g. "a soft Bengali love song").
        instrumental: Generate instrumental music (no vocals).
        vocal_language: Allowed vocal language for constrained decoding
            (e.g. "en", "zh"). None or "unknown" applies no constraint.
        temperature: Sampling temperature (0.0-2.0); higher is more creative.
        top_k: Top-K sampling (None or 0 = disabled).
        top_p: Top-P (nucleus) sampling (None or 1.0 = disabled).
        repetition_penalty: Repetition penalty (1.0 = no penalty).
        use_constrained_decoding: Enable FSM-based constrained decoding.
        constrained_decoding_debug: Enable constrained-decoding debug logs.

    Returns:
        CreateSampleResult with generated sample fields and status.

    Example:
        >>> result = create_sample(llm_handler, "a soft Bengali love song", vocal_language="bn")
        >>> if result.success:
        ...     print(f"Caption: {result.caption}")
    """
    # The LM must be loaded before any generation call.
    if not llm_handler.llm_initialized:
        return CreateSampleResult(
            status_message="5Hz LM not initialized. Please initialize it first.",
            success=False,
            error="LLM not initialized",
        )

    try:
        metadata, status = llm_handler.create_sample_from_query(
            query=query,
            instrumental=instrumental,
            vocal_language=vocal_language,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            use_constrained_decoding=use_constrained_decoding,
            constrained_decoding_debug=constrained_decoding_debug,
        )

        # An empty metadata dict signals an LM-side failure.
        if not metadata:
            return CreateSampleResult(
                status_message=status or "Failed to create sample",
                success=False,
                error=status or "Empty metadata returned",
            )

        def _scrub(value):
            # The LM uses the literal 'N/A' for unknown fields; map it to ''.
            return '' if value == 'N/A' else value

        # BPM arrives as a string/number; keep None when missing or invalid.
        bpm = None
        raw_bpm = metadata.get('bpm')
        if raw_bpm is not None and raw_bpm != 'N/A' and raw_bpm != '':
            try:
                bpm = int(raw_bpm)
            except (ValueError, TypeError):
                bpm = None

        # Duration follows the same convention but converts to float.
        duration = None
        raw_duration = metadata.get('duration')
        if raw_duration is not None and raw_duration != 'N/A' and raw_duration != '':
            try:
                duration = float(raw_duration)
            except (ValueError, TypeError):
                duration = None

        return CreateSampleResult(
            caption=metadata.get('caption', ''),
            lyrics=metadata.get('lyrics', ''),
            bpm=bpm,
            duration=duration,
            keyscale=_scrub(metadata.get('keyscale', '')),
            language=_scrub(metadata.get('language', metadata.get('vocal_language', ''))),
            timesignature=_scrub(metadata.get('timesignature', '')),
            # Fall back to the caller's flag when the LM omits the field.
            instrumental=metadata.get('instrumental', instrumental),
            status_message=status,
            success=True,
            error=None,
        )

    except Exception as e:
        logger.exception("Sample creation failed")
        return CreateSampleResult(
            status_message=f"Error: {str(e)}",
            success=False,
            error=str(e),
        )
1139
+
1140
+
1141
@dataclass
class FormatSampleResult:
    """Outcome of formatting user-provided caption and lyrics.

    Produced by the "Format" feature: the LLM takes raw user caption and
    lyrics and emits structured music metadata plus an enhanced description.

    Attributes:
        caption: Enhanced/formatted music description/caption.
        lyrics: Formatted lyrics (may equal the input or be restructured).
        bpm: Beats per minute, or None when not detected.
        duration: Length in seconds, or None when not detected.
        keyscale: Musical key, e.g. "C Major".
        language: Vocal language code such as "en" or "zh".
        timesignature: Time signature, e.g. "4".
        status_message: Human-readable status from formatting.
        success: True when formatting finished without error.
        error: Error description when formatting failed, else None.
    """
    # Metadata fields (order matters: positional construction is public API)
    caption: str = ""
    lyrics: str = ""
    bpm: Optional[int] = None
    duration: Optional[float] = None
    keyscale: str = ""
    language: str = ""
    timesignature: str = ""

    # Status fields
    status_message: str = ""
    success: bool = True
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the result to a plain, JSON-friendly dictionary."""
        return asdict(self)
1178
+ """Convert result to dictionary for JSON serialization."""
1179
+ return asdict(self)
1180
+
1181
+
1182
def format_sample(
    llm_handler,
    caption: str,
    lyrics: str,
    user_metadata: Optional[Dict[str, Any]] = None,
    temperature: float = 0.85,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
    use_constrained_decoding: bool = True,
    constrained_decoding_debug: bool = False,
) -> FormatSampleResult:
    """Format user caption and lyrics into structured metadata via the 5Hz LM.

    Takes raw user input and produces an enhanced caption plus BPM, duration,
    key, language, and time signature. When ``user_metadata`` is supplied,
    its values constrain decoding so the output honors user-specified fields.

    Note: ``cfg_scale`` and ``negative_prompt`` are not supported in
    format mode.

    Args:
        llm_handler: Initialized LLM handler (LLMHandler instance).
        caption: User's caption/description (e.g. "Latin pop, reggaeton").
        lyrics: User's lyrics with structure tags.
        user_metadata: Optional dict constraining decoding. Supported keys:
            bpm, duration, keyscale, timesignature, language.
        temperature: Sampling temperature (0.0-2.0); higher is more creative.
        top_k: Top-K sampling (None or 0 = disabled).
        top_p: Top-P (nucleus) sampling (None or 1.0 = disabled).
        repetition_penalty: Repetition penalty (1.0 = no penalty).
        use_constrained_decoding: Enable FSM-based constrained decoding.
        constrained_decoding_debug: Enable constrained-decoding debug logs.

    Returns:
        FormatSampleResult with formatted metadata fields and status.

    Example:
        >>> result = format_sample(llm_handler, "Latin pop, reggaeton", "[Verse 1]\\nHola mundo...")
        >>> if result.success:
        ...     print(f"Caption: {result.caption}")
    """
    # The LM must be loaded before any formatting call.
    if not llm_handler.llm_initialized:
        return FormatSampleResult(
            status_message="5Hz LM not initialized. Please initialize it first.",
            success=False,
            error="LLM not initialized",
        )

    try:
        metadata, status = llm_handler.format_sample_from_input(
            caption=caption,
            lyrics=lyrics,
            user_metadata=user_metadata,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            use_constrained_decoding=use_constrained_decoding,
            constrained_decoding_debug=constrained_decoding_debug,
        )

        # An empty metadata dict signals an LM-side failure.
        if not metadata:
            return FormatSampleResult(
                status_message=status or "Failed to format input",
                success=False,
                error=status or "Empty metadata returned",
            )

        def _scrub(value):
            # The LM uses the literal 'N/A' for unknown fields; map it to ''.
            return '' if value == 'N/A' else value

        # BPM arrives as a string/number; keep None when missing or invalid.
        bpm = None
        raw_bpm = metadata.get('bpm')
        if raw_bpm is not None and raw_bpm != 'N/A' and raw_bpm != '':
            try:
                bpm = int(raw_bpm)
            except (ValueError, TypeError):
                bpm = None

        # Duration follows the same convention but converts to float.
        duration = None
        raw_duration = metadata.get('duration')
        if raw_duration is not None and raw_duration != 'N/A' and raw_duration != '':
            try:
                duration = float(raw_duration)
            except (ValueError, TypeError):
                duration = None

        return FormatSampleResult(
            caption=metadata.get('caption', ''),
            # Fall back to the caller's lyrics when the LM omits the field.
            lyrics=metadata.get('lyrics', lyrics),
            bpm=bpm,
            duration=duration,
            keyscale=_scrub(metadata.get('keyscale', '')),
            language=_scrub(metadata.get('language', metadata.get('vocal_language', ''))),
            timesignature=_scrub(metadata.get('timesignature', '')),
            status_message=status,
            success=True,
            error=None,
        )

    except Exception as e:
        logger.exception("Format sample failed")
        return FormatSampleResult(
            status_message=f"Error: {str(e)}",
            success=False,
            error=str(e),
        )
acestep/llm_inference.py ADDED
The diff for this file is too large to render. See raw diff
 
acestep/model_downloader.py ADDED
@@ -0,0 +1,634 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ACE-Step Model Downloader
3
+
4
+ This module provides functionality to download models from HuggingFace Hub or ModelScope.
5
+ It supports automatic downloading when models are not found locally,
6
+ with intelligent fallback between download sources.
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import argparse
12
+ from typing import Optional, List, Dict, Tuple
13
+ from pathlib import Path
14
+
15
+ from loguru import logger
16
+
17
+
18
+ # =============================================================================
19
+ # Network Detection & Smart Download
20
+ # =============================================================================
21
+
22
def _can_access_google(timeout: float = 3.0) -> bool:
    """
    Check if Google is reachable (used to decide whether HuggingFace Hub or
    ModelScope should be the primary download source).

    Args:
        timeout: Connection timeout in seconds.

    Returns:
        True if a TCP connection to www.google.com:443 succeeds, False otherwise.
    """
    import socket
    try:
        # create_connection applies the timeout to both DNS resolution and
        # connect, and falls back across IPv4/IPv6 addresses — the original
        # bare AF_INET socket could hang on DNS and was IPv4-only.
        with socket.create_connection(("www.google.com", 443), timeout=timeout):
            return True
    except OSError:
        # socket.timeout and socket.error are aliases of OSError, so this
        # single clause covers every failure mode the original caught.
        return False
42
+
43
+
44
def _download_from_huggingface_internal(
    repo_id: str,
    local_dir: Path,
    token: Optional[str] = None,
) -> None:
    """
    Internal function to download from HuggingFace Hub.

    Args:
        repo_id: HuggingFace repository ID (e.g., "ACE-Step/Ace-Step1.5")
        local_dir: Local directory to save the model
        token: HuggingFace token for private repos (optional)

    Raises:
        Exception: If download fails
    """
    # Deferred import so the module is usable even when huggingface_hub is
    # not installed (e.g. in ModelScope-only environments).
    from huggingface_hub import snapshot_download

    logger.info(f"[Model Download] Downloading from HuggingFace: {repo_id} -> {local_dir}")

    snapshot_download(
        repo_id=repo_id,
        local_dir=str(local_dir),
        # NOTE(review): local_dir_use_symlinks is deprecated and ignored by
        # recent huggingface_hub releases — confirm against the pinned version.
        local_dir_use_symlinks=False,
        token=token,
    )
70
+
71
+
72
def _download_from_modelscope_internal(
    repo_id: str,
    local_dir: Path,
) -> None:
    """
    Internal function to download from ModelScope.

    Args:
        repo_id: ModelScope repository ID (e.g., "ACE-Step/Ace-Step1.5")
        local_dir: Local directory to save the model

    Raises:
        Exception: If download fails
    """
    # Deferred import so the module is usable even when modelscope is not
    # installed (e.g. in HuggingFace-only environments).
    from modelscope import snapshot_download

    logger.info(f"[Model Download] Downloading from ModelScope: {repo_id} -> {local_dir}")

    # Note: ModelScope's snapshot_download takes `model_id`, not `repo_id`,
    # and has no token parameter (private repos are not supported here).
    snapshot_download(
        model_id=repo_id,
        local_dir=str(local_dir),
    )
94
+
95
+
96
def _smart_download(
    repo_id: str,
    local_dir: Path,
    token: Optional[str] = None,
    prefer_source: Optional[str] = None,
) -> Tuple[bool, str]:
    """
    Smart download with automatic fallback between HuggingFace and ModelScope.

    Automatically detects the network environment and chooses the best download
    source. If the primary source fails, automatically falls back to the
    alternative.

    Args:
        repo_id: Repository ID (same format for both HF and ModelScope)
        local_dir: Local directory to save the model
        token: HuggingFace token for private repos (optional; only used for HF)
        prefer_source: Preferred download source ("huggingface", "modelscope",
            or None for auto-detect)

    Returns:
        Tuple of (success, message)
    """
    # Ensure the destination directory exists before any download starts.
    local_dir.mkdir(parents=True, exist_ok=True)

    # Each source: (log label, short label for the combined error message,
    # name used in result messages, zero-arg downloader). Defining them once
    # removes the mirror-image branch duplication of the original: the
    # preference logic now only decides ordering.
    huggingface = (
        "HuggingFace Hub", "HF", "HuggingFace",
        lambda: _download_from_huggingface_internal(repo_id, local_dir, token),
    )
    modelscope = (
        "ModelScope", "MS", "ModelScope",
        lambda: _download_from_modelscope_internal(repo_id, local_dir),
    )

    if prefer_source == "huggingface":
        logger.info("[Model Download] User preference: HuggingFace Hub")
        attempts = [huggingface, modelscope]
    elif prefer_source == "modelscope":
        logger.info("[Model Download] User preference: ModelScope")
        attempts = [modelscope, huggingface]
    else:
        # Auto-detect: Google reachable -> prefer HuggingFace, else ModelScope.
        google_ok = _can_access_google()
        logger.info(f"[Model Download] Auto-detected: {'HuggingFace Hub' if google_ok else 'ModelScope'}")
        attempts = [huggingface, modelscope] if google_ok else [modelscope, huggingface]

    failures = []  # (short label, message name, exception) per failed attempt
    for position, (log_label, short_label, msg_name, download) in enumerate(attempts):
        logger.info(f"[Model Download] Using {log_label}...")
        try:
            download()
            return True, f"Successfully downloaded from {msg_name}: {repo_id}"
        except Exception as exc:
            failures.append((short_label, msg_name, exc))
            logger.warning(f"[Model Download] {msg_name} download failed: {exc}")
            if position + 1 < len(attempts):
                logger.info(f"[Model Download] Falling back to {attempts[position + 1][0]}...")

    # Both attempts failed: report both errors in the order they were tried.
    (short1, name1, err1), (short2, name2, err2) = failures
    error_msg = f"Both {name1} and {name2} downloads failed. {short1}: {err1}, {short2}: {err2}"
    logger.error(error_msg)
    return False, error_msg
163
+
164
+
165
# =============================================================================
# Model Registry
# =============================================================================
# Main model contains core components (vae, text_encoder, default DiT)
MAIN_MODEL_REPO = "ACE-Step/Ace-Step1.5"

# Sub-models that can be downloaded separately into the checkpoints directory.
# Keys are local folder names under checkpoints/; values are repo IDs, which
# are identical on HuggingFace Hub and ModelScope.
SUBMODEL_REGISTRY: Dict[str, str] = {
    # LM models
    "acestep-5Hz-lm-0.6B": "ACE-Step/acestep-5Hz-lm-0.6B",
    "acestep-5Hz-lm-4B": "ACE-Step/acestep-5Hz-lm-4B",
    # DiT models
    "acestep-v15-turbo-shift3": "ACE-Step/acestep-v15-turbo-shift3",
    "acestep-v15-sft": "ACE-Step/acestep-v15-sft",
    "acestep-v15-base": "ACE-Step/acestep-v15-base",
    "acestep-v15-turbo-shift1": "ACE-Step/acestep-v15-turbo-shift1",
    "acestep-v15-turbo-continuous": "ACE-Step/acestep-v15-turbo-continuous",
}

# Components that come from the main model repo (ACE-Step/Ace-Step1.5).
# check_main_model_exists() requires ALL of these directories to be present.
MAIN_MODEL_COMPONENTS = [
    "acestep-v15-turbo",  # Default DiT model
    "vae",  # VAE for audio encoding/decoding
    "Qwen3-Embedding-0.6B",  # Text encoder
    "acestep-5Hz-lm-1.7B",  # Default LM model (1.7B)
]

# Default LM model (included in main model)
DEFAULT_LM_MODEL = "acestep-5Hz-lm-1.7B"
194
+
195
+
196
def get_project_root() -> Path:
    """Return the repository root (the parent of the package directory)."""
    # __file__ lives at <root>/acestep/model_downloader.py, so the root is
    # two levels up from the resolved module path.
    return Path(__file__).resolve().parents[1]
200
+
201
+
202
def get_checkpoints_dir(custom_dir: Optional[str] = None) -> Path:
    """Resolve the checkpoints directory, honoring an optional override."""
    return Path(custom_dir) if custom_dir else get_project_root() / "checkpoints"
207
+
208
+
209
def check_main_model_exists(checkpoints_dir: Optional[Path] = None) -> bool:
    """
    Check if the main model components exist in the checkpoints directory.

    Args:
        checkpoints_dir: Custom checkpoints directory (optional; defaults to
            the project-level ``checkpoints`` directory).

    Returns:
        True if all main model components exist, False otherwise.
    """
    if checkpoints_dir is None:
        checkpoints_dir = get_checkpoints_dir()

    # All component directories must be present; a partial download counts as
    # missing so callers re-trigger the download. all() with a generator
    # short-circuits on the first missing component, like the original loop.
    return all((checkpoints_dir / component).exists() for component in MAIN_MODEL_COMPONENTS)
224
+
225
+
226
def check_model_exists(model_name: str, checkpoints_dir: Optional[Path] = None) -> bool:
    """
    Check if a specific model exists in the checkpoints directory.

    Args:
        model_name: Name of the model to check
        checkpoints_dir: Custom checkpoints directory (optional)

    Returns:
        True if the model exists, False otherwise.
    """
    base_dir = checkpoints_dir if checkpoints_dir is not None else get_checkpoints_dir()
    return (base_dir / model_name).exists()
242
+
243
+
244
def list_available_models() -> Dict[str, str]:
    """
    List all available models for download.

    Returns:
        Dictionary mapping local names to HuggingFace repo IDs (the "main"
        entry first, followed by every registered sub-model).
    """
    catalog: Dict[str, str] = {"main": MAIN_MODEL_REPO}
    catalog.update(SUBMODEL_REGISTRY)
    return catalog
256
+
257
+
258
def download_main_model(
    checkpoints_dir: Optional[Path] = None,
    force: bool = False,
    token: Optional[str] = None,
    prefer_source: Optional[str] = None,
) -> Tuple[bool, str]:
    """
    Download the main ACE-Step model from HuggingFace or ModelScope.

    The main model bundle contains the default DiT (acestep-v15-turbo), the
    vae, the Qwen3-Embedding-0.6B text encoder, and the default 1.7B LM.

    Args:
        checkpoints_dir: Custom checkpoints directory (optional)
        force: Force re-download even if model exists
        token: HuggingFace token for private repos (optional)
        prefer_source: Preferred download source ("huggingface", "modelscope",
            or None for auto-detect)

    Returns:
        Tuple of (success, message)
    """
    target_dir = checkpoints_dir if checkpoints_dir is not None else get_checkpoints_dir()
    target_dir.mkdir(parents=True, exist_ok=True)

    # Skip the download entirely when everything is already on disk.
    if not force and check_main_model_exists(target_dir):
        return True, f"Main model already exists at {target_dir}"

    for line in (
        f"Downloading main model from {MAIN_MODEL_REPO}...",
        f"Destination: {target_dir}",
        "This may take a while depending on your internet connection...",
    ):
        print(line)

    # Delegate to the smart downloader (handles source choice + fallback).
    return _smart_download(MAIN_MODEL_REPO, target_dir, token, prefer_source)
297
+
298
+
299
def download_submodel(
    model_name: str,
    checkpoints_dir: Optional[Path] = None,
    force: bool = False,
    token: Optional[str] = None,
    prefer_source: Optional[str] = None,
) -> Tuple[bool, str]:
    """
    Download a specific sub-model from HuggingFace or ModelScope.

    Args:
        model_name: Name of the model to download (must be in SUBMODEL_REGISTRY)
        checkpoints_dir: Custom checkpoints directory (optional)
        force: Force re-download even if model exists
        token: HuggingFace token for private repos (optional)
        prefer_source: Preferred download source ("huggingface", "modelscope",
            or None for auto-detect)

    Returns:
        Tuple of (success, message)
    """
    # EAFP lookup: an unknown name yields the same error message as before.
    try:
        repo_id = SUBMODEL_REGISTRY[model_name]
    except KeyError:
        available = ", ".join(SUBMODEL_REGISTRY.keys())
        return False, f"Unknown model '{model_name}'. Available models: {available}"

    base_dir = checkpoints_dir if checkpoints_dir is not None else get_checkpoints_dir()
    base_dir.mkdir(parents=True, exist_ok=True)

    model_path = base_dir / model_name
    if model_path.exists() and not force:
        return True, f"Model '{model_name}' already exists at {model_path}"

    print(f"Downloading {model_name} from {repo_id}...")
    print(f"Destination: {model_path}")

    # Delegate to the smart downloader (handles source choice + fallback).
    return _smart_download(repo_id, model_path, token, prefer_source)
341
+
342
+
343
def download_all_models(
    checkpoints_dir: Optional[Path] = None,
    force: bool = False,
    token: Optional[str] = None,
    prefer_source: Optional[str] = None,
) -> Tuple[bool, List[str]]:
    """
    Download all available models (main model plus every registered sub-model).

    Args:
        checkpoints_dir: Custom checkpoints directory (optional)
        force: Force re-download even if models exist
        token: HuggingFace token for private repos (optional)
        prefer_source: Preferred download source ("huggingface", "modelscope",
            or None for auto-detect). New, defaulted parameter added for
            consistency with the other download helpers; existing callers
            are unaffected.

    Returns:
        Tuple of (all_success, list of messages) — one message per model,
        main model first, then sub-models in registry order.
    """
    if checkpoints_dir is None:
        checkpoints_dir = get_checkpoints_dir()

    messages: List[str] = []
    all_success = True

    # Download main model first: sub-models depend on its shared components.
    success, msg = download_main_model(checkpoints_dir, force, token, prefer_source)
    messages.append(msg)
    all_success = all_success and success

    # Continue through the registry even after a failure so the caller gets a
    # complete per-model status report.
    for model_name in SUBMODEL_REGISTRY:
        success, msg = download_submodel(model_name, checkpoints_dir, force, token, prefer_source)
        messages.append(msg)
        all_success = all_success and success

    return all_success, messages
379
+
380
+
381
def ensure_main_model(
    checkpoints_dir: Optional[Path] = None,
    token: Optional[str] = None,
    prefer_source: Optional[str] = None,
) -> Tuple[bool, str]:
    """
    Ensure the main model is available, downloading it only when missing.

    Intended for use during initialization: a no-op when the model is
    already on disk.

    Args:
        checkpoints_dir: Custom checkpoints directory (optional)
        token: HuggingFace token for private repos (optional)
        prefer_source: Preferred download source ("huggingface", "modelscope",
            or None for auto-detect)

    Returns:
        Tuple of (success, message)
    """
    target_dir = checkpoints_dir if checkpoints_dir is not None else get_checkpoints_dir()

    if check_main_model_exists(target_dir):
        return True, "Main model is available"

    banner = "=" * 60
    print(f"\n{banner}")
    print("Main model not found. Starting automatic download...")
    print(f"{banner}\n")

    return download_main_model(target_dir, token=token, prefer_source=prefer_source)
411
+
412
+
413
def ensure_lm_model(
    model_name: Optional[str] = None,
    checkpoints_dir: Optional[Path] = None,
    token: Optional[str] = None,
    prefer_source: Optional[str] = None,
) -> Tuple[bool, str]:
    """
    Ensure an LM model is available, downloading it only when missing.

    Args:
        model_name: Name of the LM model (defaults to DEFAULT_LM_MODEL)
        checkpoints_dir: Custom checkpoints directory (optional)
        token: HuggingFace token for private repos (optional)
        prefer_source: Preferred download source ("huggingface", "modelscope",
            or None for auto-detect)

    Returns:
        Tuple of (success, message)
    """
    resolved_name = model_name if model_name is not None else DEFAULT_LM_MODEL
    base_dir = checkpoints_dir if checkpoints_dir is not None else get_checkpoints_dir()

    if check_model_exists(resolved_name, base_dir):
        return True, f"LM model '{resolved_name}' is available"

    if resolved_name not in SUBMODEL_REGISTRY:
        # Fuzzy match: accept a variant name that is a substring of a
        # registered LM model (first match in registry order wins).
        needle = resolved_name.lower()
        match = next(
            (known for known in SUBMODEL_REGISTRY
             if "lm" in known.lower() and needle in known.lower()),
            None,
        )
        if match is None:
            return False, f"Unknown LM model: {resolved_name}"
        resolved_name = match

    banner = "=" * 60
    print(f"\n{banner}")
    print(f"LM model '{resolved_name}' not found. Starting automatic download...")
    print(f"{banner}\n")

    return download_submodel(resolved_name, base_dir, token=token, prefer_source=prefer_source)
455
+
456
+
457
def ensure_dit_model(
    model_name: str,
    checkpoints_dir: Optional[Path] = None,
    token: Optional[str] = None,
    prefer_source: Optional[str] = None,
) -> Tuple[bool, str]:
    """
    Ensure a DiT model is available, downloading it only when missing.

    Args:
        model_name: Name of the DiT model
        checkpoints_dir: Custom checkpoints directory (optional)
        token: HuggingFace token for private repos (optional)
        prefer_source: Preferred download source ("huggingface", "modelscope",
            or None for auto-detect)

    Returns:
        Tuple of (success, message)
    """
    base_dir = checkpoints_dir if checkpoints_dir is not None else get_checkpoints_dir()

    if check_model_exists(model_name, base_dir):
        return True, f"DiT model '{model_name}' is available"

    # The default turbo DiT ships inside the main model bundle.
    if model_name == "acestep-v15-turbo":
        return ensure_main_model(base_dir, token, prefer_source)

    # Guard clause: bail out early on names we cannot download.
    if model_name not in SUBMODEL_REGISTRY:
        return False, f"Unknown DiT model: {model_name}"

    banner = "=" * 60
    print(f"\n{banner}")
    print(f"DiT model '{model_name}' not found. Starting automatic download...")
    print(f"{banner}\n")
    return download_submodel(model_name, base_dir, token=token, prefer_source=prefer_source)
493
+
494
+
495
def print_model_list():
    """Print a formatted catalog of all downloadable models."""
    separator = "=" * 60
    # Partition the registry once instead of filtering it twice inline.
    lm_models = {n: r for n, r in SUBMODEL_REGISTRY.items() if "lm" in n.lower()}
    dit_models = {n: r for n, r in SUBMODEL_REGISTRY.items() if "lm" not in n.lower()}

    print("\nAvailable Models for Download:")
    print(separator)
    print("\nSupported Sources: HuggingFace Hub <-> ModelScope (auto-fallback)")

    print("\n[Main Model]")
    print(f" main -> {MAIN_MODEL_REPO}")
    print(" Contains: vae, Qwen3-Embedding-0.6B, acestep-v15-turbo, acestep-5Hz-lm-1.7B")

    print("\n[Optional LM Models]")
    for name, repo in lm_models.items():
        print(f" {name} -> {repo}")

    print("\n[Optional DiT Models]")
    for name, repo in dit_models.items():
        print(f" {name} -> {repo}")

    print("\n" + separator)
516
+
517
+
518
def main():
    """CLI entry point for model downloading.

    Returns:
        Process exit code: 0 on success, 1 on failure.
    """
    parser = argparse.ArgumentParser(
        description="Download ACE-Step models with automatic fallback (HuggingFace <-> ModelScope)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  acestep-download                           # Download main model (includes LM 1.7B)
  acestep-download --all                     # Download all available models
  acestep-download --model acestep-v15-sft   # Download a specific model
  acestep-download --list                    # List all available models

Network Detection:
  Automatically detects network environment and chooses the best download source:
  - Google accessible -> HuggingFace (fallback to ModelScope)
  - Google blocked    -> ModelScope (fallback to HuggingFace)

Alternative using huggingface-cli:
  huggingface-cli download ACE-Step/Ace-Step1.5 --local-dir ./checkpoints
  huggingface-cli download ACE-Step/acestep-5Hz-lm-0.6B --local-dir ./checkpoints/acestep-5Hz-lm-0.6B
"""
    )

    parser.add_argument(
        "--model", "-m",
        type=str,
        help="Specific model to download (use --list to see available models)"
    )
    parser.add_argument(
        "--all", "-a",
        action="store_true",
        help="Download all available models"
    )
    parser.add_argument(
        "--list", "-l",
        action="store_true",
        help="List all available models"
    )
    parser.add_argument(
        "--dir", "-d",
        type=str,
        default=None,
        help="Custom checkpoints directory (default: ./checkpoints)"
    )
    parser.add_argument(
        "--force", "-f",
        action="store_true",
        help="Force re-download even if model exists"
    )
    parser.add_argument(
        "--token", "-t",
        type=str,
        default=None,
        help="HuggingFace token for private repos"
    )
    parser.add_argument(
        "--skip-main",
        action="store_true",
        help="Skip downloading the main model (only download specified sub-model)"
    )
    # New backward-compatible flag exposing the prefer_source API already
    # supported by the download functions; default None keeps auto-detect.
    parser.add_argument(
        "--source", "-s",
        type=str,
        choices=("huggingface", "modelscope"),
        default=None,
        help="Preferred download source (default: auto-detect by network)"
    )

    args = parser.parse_args()

    # Handle --list
    if args.list:
        print_model_list()
        return 0

    # get_checkpoints_dir already handles a None/empty custom dir, so the
    # original `if args.dir else` branching was redundant.
    checkpoints_dir = get_checkpoints_dir(args.dir)
    print(f"Checkpoints directory: {checkpoints_dir}")

    # Handle --all
    if args.all:
        success, messages = download_all_models(checkpoints_dir, args.force, args.token)
        for msg in messages:
            print(msg)
        return 0 if success else 1

    # Handle --model
    if args.model:
        if args.model == "main":
            success, msg = download_main_model(checkpoints_dir, args.force, args.token, args.source)
        elif args.model in SUBMODEL_REGISTRY:
            # Download main model first if needed (unless --skip-main), since
            # sub-models rely on the shared vae/text-encoder components.
            if not args.skip_main and not check_main_model_exists(checkpoints_dir):
                print("Main model not found. Downloading main model first...")
                main_success, main_msg = download_main_model(checkpoints_dir, args.force, args.token, args.source)
                print(main_msg)
                if not main_success:
                    return 1

            success, msg = download_submodel(args.model, checkpoints_dir, args.force, args.token, args.source)
        else:
            print(f"Unknown model: {args.model}")
            print("Use --list to see available models")
            return 1

        print(msg)
        return 0 if success else 1

    # Default: download main model (includes default LM 1.7B)
    print("Downloading main model (includes vae, text encoder, DiT, and LM 1.7B)...")

    success, msg = download_main_model(checkpoints_dir, args.force, args.token, args.source)
    print(msg)

    if success:
        print("\nDownload complete!")
        print(f"Models are available at: {checkpoints_dir}")

    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main())
handler.py CHANGED
@@ -1,15 +1,13 @@
1
  # handler.py
2
  import base64
3
- import inspect
4
  import io
5
  import os
6
  import traceback
7
- from typing import Any, Dict, Tuple
8
 
9
  import numpy as np
10
  import soundfile as sf
11
 
12
- # Optional torch import for dtype/device handling
13
  try:
14
  import torch
15
  except Exception:
@@ -20,7 +18,7 @@ class EndpointHandler:
20
  """
21
  Hugging Face Inference Endpoints custom handler for ACE-Step 1.5.
22
 
23
- Request body shape:
24
  {
25
  "inputs": {
26
  "prompt": "upbeat pop rap, emotional guitar",
@@ -29,130 +27,144 @@ class EndpointHandler:
29
  "sample_rate": 44100,
30
  "seed": 42,
31
  "guidance_scale": 7.0,
32
- "steps": 50,
33
  "use_lm": true,
34
  "simple_prompt": false,
35
- "model_repo": "ACE-Step/Ace-Step1.5"
 
36
  }
37
  }
38
 
39
- Also supported for simple mode:
40
  {
41
  "inputs": "upbeat pop rap with emotional guitar"
42
  }
43
 
44
- Response:
45
- {
46
- "audio_base64_wav": "...",
47
- "sample_rate": 44100,
48
- "duration_sec": 12,
49
- "used_fallback": false,
50
- "model_loaded": true,
51
- "model_error": null,
52
- "meta": {...}
53
- }
54
  """
55
 
56
  def __init__(self, path: str = ""):
57
  self.path = path
58
- self.model = None
59
- self.model_error = None
60
  self.model_repo = os.getenv("ACE_MODEL_REPO", "ACE-Step/Ace-Step1.5")
 
 
 
 
 
61
  self.default_sr = int(os.getenv("DEFAULT_SAMPLE_RATE", "44100"))
 
 
 
62
 
63
- # Runtime knobs
64
  self.device = "cuda" if (torch is not None and torch.cuda.is_available()) else "cpu"
65
  self.dtype = "float16" if self.device == "cuda" else "float32"
66
 
67
- # Try to initialize ACE-Step pipeline from repo code paths.
68
- # Repo mentions Python API and module path `acestep.acestep_v15_pipeline`.
69
- self._init_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  # --------------------------
72
- # Initialization helpers
73
  # --------------------------
74
  def _init_model(self) -> None:
75
  err_msgs = []
76
 
77
- # Strategy A: class/factory in acestep.acestep_v15_pipeline
78
  try:
79
- from acestep import acestep_v15_pipeline as m # type: ignore
80
-
81
- # Try common factory patterns
82
- if hasattr(m, "from_pretrained"):
83
- self.model = m.from_pretrained(self.model_repo) # type: ignore
84
- elif hasattr(m, "AceStepV15Pipeline"):
85
- cls = getattr(m, "AceStepV15Pipeline")
86
- if hasattr(cls, "from_pretrained"):
87
- self.model = cls.from_pretrained(self.model_repo)
88
- else:
89
- self.model = cls(model_path=self.model_repo)
90
- elif hasattr(m, "Pipeline"):
91
- cls = getattr(m, "Pipeline")
92
- if hasattr(cls, "from_pretrained"):
93
- self.model = cls.from_pretrained(self.model_repo)
94
- else:
95
- self.model = cls(self.model_repo)
96
- else:
97
- raise RuntimeError("No known pipeline class/factory found in acestep_v15_pipeline")
98
 
99
- # Move device if supported
100
- if self.model is not None and hasattr(self.model, "to"):
101
- try:
102
- self.model.to(self.device)
103
- except Exception:
104
- pass
105
 
106
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  except Exception as e:
108
- err_msgs.append(f"Strategy A failed: {type(e).__name__}: {e}")
109
 
110
- # Strategy B: import root `acestep` and find a likely pipeline symbol
111
  try:
112
- import acestep # type: ignore
113
-
114
- candidates = [
115
- "AceStepV15Pipeline",
116
- "AceStepPipeline",
117
- "Pipeline",
118
- "create_pipeline",
119
- "build_pipeline",
120
- "load_pipeline",
121
- ]
122
-
123
- obj = None
124
- for name in candidates:
125
- if hasattr(acestep, name):
126
- obj = getattr(acestep, name)
127
- break
128
-
129
- if obj is None:
130
- raise RuntimeError("No known pipeline symbol found in `acestep` package")
131
-
132
- if callable(obj):
133
- # class or factory
134
- if hasattr(obj, "from_pretrained"):
135
- self.model = obj.from_pretrained(self.model_repo)
136
- else:
137
- # try keyword variants
138
- try:
139
- self.model = obj(model_path=self.model_repo)
140
- except TypeError:
141
- self.model = obj(self.model_repo)
142
- else:
143
- self.model = obj
144
-
145
- if self.model is not None and hasattr(self.model, "to"):
146
- try:
147
- self.model.to(self.device)
148
- except Exception:
149
- pass
150
- return
151
  except Exception as e:
152
- err_msgs.append(f"Strategy B failed: {type(e).__name__}: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
- self.model = None
155
- self.model_error = " | ".join(err_msgs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  # --------------------------
158
  # Audio helpers
@@ -166,7 +178,6 @@ class EndpointHandler:
166
  else:
167
  arr = np.asarray(audio)
168
 
169
- # Convert common tensor shape [channels, samples] to [samples, channels].
170
  if arr.ndim == 2 and arr.shape[0] in (1, 2) and arr.shape[1] > arr.shape[0]:
171
  arr = arr.T
172
 
@@ -188,6 +199,9 @@ class EndpointHandler:
188
  y = (0.07 * np.sin(2 * np.pi * 440 * t) + 0.01 * rng.standard_normal(len(t))).astype(np.float32)
189
  return np.clip(y, -1.0, 1.0)
190
 
 
 
 
191
  @staticmethod
192
  def _to_bool(value: Any, default: bool = False) -> bool:
193
  if value is None:
@@ -242,39 +256,18 @@ class EndpointHandler:
242
  if not lyrics and (instrumental or simple_prompt):
243
  lyrics = "[Instrumental]"
244
 
245
- duration_sec = self._to_int(raw_inputs.get("duration_sec", raw_inputs.get("duration", 10)), 10)
246
- duration_sec = max(1, min(duration_sec, 600))
247
 
248
  sample_rate = self._to_int(raw_inputs.get("sample_rate", self.default_sr), self.default_sr)
249
  sample_rate = max(8000, min(sample_rate, 48000))
250
 
251
  seed = self._to_int(raw_inputs.get("seed", 42), 42)
252
  guidance_scale = self._to_float(raw_inputs.get("guidance_scale", 7.0), 7.0)
253
- steps = self._to_int(raw_inputs.get("steps", raw_inputs.get("inference_steps", 50)), 50)
254
- steps = max(1, min(steps, 500))
255
  use_lm = self._to_bool(raw_inputs.get("use_lm", raw_inputs.get("thinking", True)), True)
256
- task_type = self._pick_text(raw_inputs, "task_type") or "text2music"
257
-
258
- model_repo = raw_inputs.get("model_repo")
259
-
260
- model_kwargs = {
261
- "task_type": task_type,
262
- "prompt": prompt,
263
- "caption": prompt,
264
- "query": prompt,
265
- "lyrics": lyrics,
266
- "duration_sec": duration_sec,
267
- "duration": duration_sec,
268
- "sample_rate": sample_rate,
269
- "seed": seed,
270
- "guidance_scale": guidance_scale,
271
- "steps": steps,
272
- "inference_steps": steps,
273
- "num_inference_steps": steps,
274
- "use_lm": use_lm,
275
- "thinking": use_lm,
276
- "instrumental": instrumental,
277
- }
278
 
279
  return {
280
  "prompt": prompt,
@@ -287,165 +280,144 @@ class EndpointHandler:
287
  "use_lm": use_lm,
288
  "instrumental": instrumental,
289
  "simple_prompt": simple_prompt,
290
- "model_repo": model_repo,
291
- "model_kwargs": model_kwargs,
292
  }
293
 
294
- @staticmethod
295
- def _invoke_with_supported_kwargs(fn: Any, kwargs: Dict[str, Any]) -> Any:
296
- try:
297
- sig = inspect.signature(fn)
298
- has_var_kw = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values())
299
- if has_var_kw:
300
- return fn(**kwargs)
301
- accepted = {
302
- name
303
- for name, p in sig.parameters.items()
304
- if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
305
- }
306
- filtered = {k: v for k, v in kwargs.items() if k in accepted}
307
- return fn(**filtered)
308
- except Exception:
309
- # Fallback for C-extension callables or dynamic signatures.
310
- return fn(**kwargs)
311
-
312
- def _normalize_model_output(self, out: Any, default_sr: int) -> Tuple[np.ndarray, int]:
313
- if out is None:
314
- raise RuntimeError("Model returned None")
315
-
316
- if hasattr(out, "success") and not getattr(out, "success"):
317
- err = getattr(out, "error", "unknown model error")
318
- raise RuntimeError(str(err))
319
-
320
- if hasattr(out, "audios"):
321
- audios = getattr(out, "audios") or []
322
- if not audios:
323
- raise RuntimeError("Model result has no audios")
324
- first = audios[0]
325
- if isinstance(first, dict):
326
- audio = first.get("tensor", first.get("audio", first.get("waveform", first.get("wav"))))
327
- sr = first.get("sample_rate", default_sr)
328
- else:
329
- audio = getattr(first, "tensor", getattr(first, "audio", None))
330
- sr = getattr(first, "sample_rate", default_sr)
331
- if audio is None:
332
- raise RuntimeError("Model result audio entry is missing tensor/audio")
333
- return self._as_float32(audio), int(sr)
334
-
335
- if isinstance(out, tuple) and len(out) >= 1:
336
- audio = out[0]
337
- sr = int(out[1]) if len(out) > 1 and out[1] is not None else default_sr
338
- return self._as_float32(audio), sr
339
-
340
- if isinstance(out, dict):
341
- if "audios" in out:
342
- audios = out.get("audios") or []
343
- if not audios:
344
- raise RuntimeError("Model output `audios` is empty")
345
- first = audios[0]
346
- if not isinstance(first, dict):
347
- raise RuntimeError("Model output `audios[0]` must be a dict")
348
- audio = first.get("tensor", first.get("audio", first.get("waveform", first.get("wav"))))
349
- sr = first.get("sample_rate", default_sr)
350
- if audio is None:
351
- raise RuntimeError("Model output `audios[0]` missing tensor/audio")
352
- return self._as_float32(audio), int(sr)
353
-
354
- audio = out.get("audio", out.get("waveform", out.get("wav", out.get("tensor"))))
355
- sr = out.get("sample_rate", out.get("sr", default_sr))
356
- if audio is None:
357
- raise RuntimeError("Model dict output missing audio/waveform field")
358
- return self._as_float32(audio), int(sr)
359
-
360
- for name in ("audio", "waveform", "wav", "tensor"):
361
- if hasattr(out, name):
362
- audio = getattr(out, name)
363
- if audio is not None:
364
- sr = getattr(out, "sample_rate", getattr(out, "sr", default_sr))
365
- return self._as_float32(audio), int(sr)
366
-
367
- return self._as_float32(out), default_sr
368
-
369
  # --------------------------
370
- # Inference
371
  # --------------------------
372
- def _call_model(
373
- self,
374
- model_kwargs: Dict[str, Any],
375
- sample_rate: int,
376
- ) -> Tuple[np.ndarray, int]:
377
- """
378
- Tries multiple invocation styles to tolerate minor ACE-Step API differences.
379
- Returns (audio_np, sample_rate).
380
- """
381
- if self.model is None:
382
- raise RuntimeError("Model is not loaded")
383
-
384
- # Common callable entrypoints
385
- methods = [
386
- "__call__",
387
- "generate",
388
- "infer",
389
- "inference",
390
- "text_to_music",
391
- "run",
392
- ]
393
-
394
- last_err = None
395
- for m in methods:
396
- try:
397
- fn = self.model if m == "__call__" else getattr(self.model, m, None)
398
- if fn is None:
399
- continue
400
-
401
- # Try full kwargs
402
- try:
403
- out = self._invoke_with_supported_kwargs(fn, model_kwargs)
404
- except TypeError:
405
- # Narrow payload if signature is strict
406
- skinny = {
407
- "prompt": model_kwargs.get("prompt"),
408
- "caption": model_kwargs.get("caption"),
409
- "lyrics": model_kwargs.get("lyrics"),
410
- "duration": model_kwargs.get("duration"),
411
- "seed": model_kwargs.get("seed"),
412
- }
413
- skinny = {k: v for k, v in skinny.items() if v is not None and (k != "prompt" or str(v).strip())}
414
- out = self._invoke_with_supported_kwargs(fn, skinny)
415
-
416
- return self._normalize_model_output(out, sample_rate)
417
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
  except Exception as e:
419
- last_err = e
420
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
 
422
- raise RuntimeError(f"No compatible inference method worked: {last_err}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
425
  try:
426
  req = self._normalize_request(data)
427
 
428
- # Optional override
429
- model_repo = req.get("model_repo")
430
- if model_repo and model_repo != self.model_repo:
431
- # hot-switch model only if user asks
432
- self.model_repo = str(model_repo)
433
- self._init_model()
434
-
435
  used_fallback = False
 
 
 
 
 
 
 
 
436
 
437
- if self.model is not None:
438
- try:
439
- audio, out_sr = self._call_model(
440
- model_kwargs=req["model_kwargs"],
441
- sample_rate=req["sample_rate"],
442
- )
443
- except Exception as e:
444
- used_fallback = True
445
- self.model_error = f"Inference failed: {type(e).__name__}: {e}"
446
- audio = self._fallback_sine(req["duration_sec"], req["sample_rate"], req["seed"])
447
- out_sr = req["sample_rate"]
448
- else:
449
  used_fallback = True
450
  audio = self._fallback_sine(req["duration_sec"], req["sample_rate"], req["seed"])
451
  out_sr = req["sample_rate"]
@@ -455,7 +427,7 @@ class EndpointHandler:
455
  "sample_rate": int(out_sr),
456
  "duration_sec": int(req["duration_sec"]),
457
  "used_fallback": used_fallback,
458
- "model_loaded": self.model is not None,
459
  "model_repo": self.model_repo,
460
  "model_error": self.model_error,
461
  "meta": {
@@ -469,16 +441,34 @@ class EndpointHandler:
469
  "use_lm": req["use_lm"],
470
  "simple_prompt": req["simple_prompt"],
471
  "instrumental": req["instrumental"],
472
- "resolved_prompt": req["prompt"],
473
- "resolved_lyrics": req["lyrics"],
 
 
 
 
 
 
 
 
474
  },
475
  }
476
 
477
  except Exception as e:
478
  return {
479
  "error": f"{type(e).__name__}: {e}",
480
- "traceback": traceback.format_exc(limit=3),
481
  "audio_base64_wav": None,
482
  "sample_rate": None,
483
  "duration_sec": None,
484
- }
 
 
 
 
 
 
 
 
 
 
 
1
  # handler.py
2
  import base64
 
3
  import io
4
  import os
5
  import traceback
6
+ from typing import Any, Dict, Optional, Tuple
7
 
8
  import numpy as np
9
  import soundfile as sf
10
 
 
11
  try:
12
  import torch
13
  except Exception:
 
18
  """
19
  Hugging Face Inference Endpoints custom handler for ACE-Step 1.5.
20
 
21
+ Supported request shapes:
22
  {
23
  "inputs": {
24
  "prompt": "upbeat pop rap, emotional guitar",
 
27
  "sample_rate": 44100,
28
  "seed": 42,
29
  "guidance_scale": 7.0,
30
+ "steps": 8,
31
  "use_lm": true,
32
  "simple_prompt": false,
33
+ "instrumental": false,
34
+ "allow_fallback": false
35
  }
36
  }
37
 
38
+ Or simple mode:
39
  {
40
  "inputs": "upbeat pop rap with emotional guitar"
41
  }
42
 
43
+ Notes:
44
+ - This handler uses ACE-Step's official Python API internally.
45
+ - Fallback sine generation is disabled by default so model failures are explicit.
 
 
 
 
 
 
 
46
  """
47
 
48
  def __init__(self, path: str = ""):
49
  self.path = path
50
+ self.project_root = os.path.dirname(os.path.abspath(__file__))
51
+
52
  self.model_repo = os.getenv("ACE_MODEL_REPO", "ACE-Step/Ace-Step1.5")
53
+ self.config_path = os.getenv("ACE_CONFIG_PATH", "acestep-v15-turbo")
54
+ self.lm_model_path = os.getenv("ACE_LM_MODEL_PATH", "acestep-5Hz-lm-1.7B")
55
+ self.lm_backend = os.getenv("ACE_LM_BACKEND", "pt")
56
+ self.download_source = os.getenv("ACE_DOWNLOAD_SOURCE", "huggingface")
57
+
58
  self.default_sr = int(os.getenv("DEFAULT_SAMPLE_RATE", "44100"))
59
+ self.enable_fallback = self._to_bool(os.getenv("ACE_ENABLE_FALLBACK"), False)
60
+ self.init_lm_on_start = self._to_bool(os.getenv("ACE_INIT_LLM"), False)
61
+ self.skip_init = self._to_bool(os.getenv("ACE_SKIP_INIT"), False)
62
 
 
63
  self.device = "cuda" if (torch is not None and torch.cuda.is_available()) else "cpu"
64
  self.dtype = "float16" if self.device == "cuda" else "float32"
65
 
66
+ self.model_loaded = False
67
+ self.model_error: Optional[str] = None
68
+ self.init_details: Dict[str, Any] = {}
69
+
70
+ self.dit_handler = None
71
+ self.llm_handler = None
72
+ self.llm_initialized = False
73
+ self.llm_error: Optional[str] = None
74
+
75
+ self._GenerationParams = None
76
+ self._GenerationConfig = None
77
+ self._generate_music = None
78
+ self._create_sample = None
79
+
80
+ if self.skip_init:
81
+ self.model_error = "Initialization skipped because ACE_SKIP_INIT=true"
82
+ else:
83
+ self._init_model()
84
 
85
  # --------------------------
86
+ # Initialization
87
  # --------------------------
88
  def _init_model(self) -> None:
89
  err_msgs = []
90
 
 
91
  try:
92
+ from acestep.handler import AceStepHandler
93
+ from acestep.inference import GenerationConfig, GenerationParams, create_sample, generate_music
94
+ from acestep.llm_inference import LLMHandler
95
+ except Exception as e:
96
+ self.model_error = f"ACE-Step import failed: {type(e).__name__}: {e}"
97
+ return
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
+ self._GenerationParams = GenerationParams
100
+ self._GenerationConfig = GenerationConfig
101
+ self._generate_music = generate_music
102
+ self._create_sample = create_sample
 
 
103
 
104
+ try:
105
+ self.dit_handler = AceStepHandler()
106
+ prefer_source = self.download_source if self.download_source in {"huggingface", "modelscope"} else None
107
+ init_status, ok = self.dit_handler.initialize_service(
108
+ project_root=self.project_root,
109
+ config_path=self.config_path,
110
+ device=self.device,
111
+ use_flash_attention=False,
112
+ compile_model=False,
113
+ offload_to_cpu=False,
114
+ offload_dit_to_cpu=False,
115
+ prefer_source=prefer_source,
116
+ )
117
+ self.init_details["dit_status"] = init_status
118
+ if not ok:
119
+ raise RuntimeError(init_status)
120
  except Exception as e:
121
+ err_msgs.append(f"DiT init failed: {type(e).__name__}: {e}")
122
 
 
123
  try:
124
+ self.llm_handler = LLMHandler()
125
+ if self.init_lm_on_start:
126
+ self._ensure_llm_initialized()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  except Exception as e:
128
+ err_msgs.append(f"LLM bootstrap failed: {type(e).__name__}: {e}")
129
+
130
+ if err_msgs:
131
+ self.model_loaded = False
132
+ self.model_error = " | ".join(err_msgs)
133
+ return
134
+
135
+ self.model_loaded = True
136
+ self.model_error = None
137
+
138
+ def _ensure_llm_initialized(self) -> bool:
139
+ if self.llm_handler is None:
140
+ self.llm_error = "LLM handler is not available"
141
+ return False
142
+
143
+ if self.llm_initialized:
144
+ return True
145
 
146
+ try:
147
+ checkpoint_dir = os.path.join(self.project_root, "checkpoints")
148
+ status, ok = self.llm_handler.initialize(
149
+ checkpoint_dir=checkpoint_dir,
150
+ lm_model_path=self.lm_model_path,
151
+ backend=self.lm_backend,
152
+ device=self.device,
153
+ offload_to_cpu=False,
154
+ )
155
+ self.init_details["llm_status"] = status
156
+ if not ok:
157
+ self.llm_error = status
158
+ self.llm_initialized = False
159
+ return False
160
+
161
+ self.llm_error = None
162
+ self.llm_initialized = True
163
+ return True
164
+ except Exception as e:
165
+ self.llm_error = f"LLM init exception: {type(e).__name__}: {e}"
166
+ self.llm_initialized = False
167
+ return False
168
 
169
  # --------------------------
170
  # Audio helpers
 
178
  else:
179
  arr = np.asarray(audio)
180
 
 
181
  if arr.ndim == 2 and arr.shape[0] in (1, 2) and arr.shape[1] > arr.shape[0]:
182
  arr = arr.T
183
 
 
199
  y = (0.07 * np.sin(2 * np.pi * 440 * t) + 0.01 * rng.standard_normal(len(t))).astype(np.float32)
200
  return np.clip(y, -1.0, 1.0)
201
 
202
+ # --------------------------
203
+ # Request normalization
204
+ # --------------------------
205
  @staticmethod
206
  def _to_bool(value: Any, default: bool = False) -> bool:
207
  if value is None:
 
256
  if not lyrics and (instrumental or simple_prompt):
257
  lyrics = "[Instrumental]"
258
 
259
+ duration_sec = self._to_int(raw_inputs.get("duration_sec", raw_inputs.get("duration", 12)), 12)
260
+ duration_sec = max(10, min(duration_sec, 600))
261
 
262
  sample_rate = self._to_int(raw_inputs.get("sample_rate", self.default_sr), self.default_sr)
263
  sample_rate = max(8000, min(sample_rate, 48000))
264
 
265
  seed = self._to_int(raw_inputs.get("seed", 42), 42)
266
  guidance_scale = self._to_float(raw_inputs.get("guidance_scale", 7.0), 7.0)
267
+ steps = self._to_int(raw_inputs.get("steps", raw_inputs.get("inference_steps", 8)), 8)
268
+ steps = max(1, min(steps, 200))
269
  use_lm = self._to_bool(raw_inputs.get("use_lm", raw_inputs.get("thinking", True)), True)
270
+ allow_fallback = self._to_bool(raw_inputs.get("allow_fallback"), self.enable_fallback)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
  return {
273
  "prompt": prompt,
 
280
  "use_lm": use_lm,
281
  "instrumental": instrumental,
282
  "simple_prompt": simple_prompt,
283
+ "allow_fallback": allow_fallback,
 
284
  }
285
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
  # --------------------------
287
+ # ACE-Step invocation
288
  # --------------------------
289
+ def _build_generation_inputs(self, req: Dict[str, Any], llm_ready: bool) -> Tuple[Dict[str, Any], Dict[str, Any]]:
290
+ caption = req["prompt"]
291
+ lyrics = req["lyrics"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
 
293
+ extras: Dict[str, Any] = {
294
+ "simple_expansion_used": False,
295
+ "simple_expansion_error": None,
296
+ }
297
+
298
+ bpm = None
299
+ keyscale = ""
300
+ timesignature = ""
301
+ vocal_language = "unknown"
302
+ duration = float(req["duration_sec"])
303
+
304
+ if req["simple_prompt"] and req["use_lm"] and llm_ready and caption:
305
+ try:
306
+ sample = self._create_sample(
307
+ llm_handler=self.llm_handler,
308
+ query=caption,
309
+ instrumental=req["instrumental"],
310
+ )
311
+ if getattr(sample, "success", False):
312
+ caption = getattr(sample, "caption", "") or caption
313
+ lyrics = getattr(sample, "lyrics", "") or lyrics
314
+ bpm = getattr(sample, "bpm", None)
315
+ keyscale = getattr(sample, "keyscale", "") or ""
316
+ timesignature = getattr(sample, "timesignature", "") or ""
317
+ vocal_language = getattr(sample, "language", "") or "unknown"
318
+ sample_duration = getattr(sample, "duration", None)
319
+ if sample_duration:
320
+ duration = float(sample_duration)
321
+ extras["simple_expansion_used"] = True
322
+ else:
323
+ extras["simple_expansion_error"] = getattr(sample, "error", "create_sample failed")
324
  except Exception as e:
325
+ extras["simple_expansion_error"] = f"{type(e).__name__}: {e}"
326
+
327
+ params = self._GenerationParams(
328
+ task_type="text2music",
329
+ caption=caption,
330
+ lyrics=lyrics,
331
+ instrumental=req["instrumental"],
332
+ duration=duration,
333
+ inference_steps=req["steps"],
334
+ guidance_scale=req["guidance_scale"],
335
+ seed=req["seed"],
336
+ bpm=bpm,
337
+ keyscale=keyscale,
338
+ timesignature=timesignature,
339
+ vocal_language=vocal_language,
340
+ thinking=bool(req["use_lm"] and llm_ready),
341
+ use_cot_metas=bool(req["use_lm"] and llm_ready),
342
+ use_cot_caption=bool(req["use_lm"] and llm_ready and not req["simple_prompt"]),
343
+ use_cot_language=bool(req["use_lm"] and llm_ready),
344
+ )
345
+
346
+ config = self._GenerationConfig(
347
+ batch_size=1,
348
+ allow_lm_batch=False,
349
+ use_random_seed=False,
350
+ seeds=[req["seed"]],
351
+ audio_format="wav",
352
+ )
353
+
354
+ extras["resolved_prompt"] = caption
355
+ extras["resolved_lyrics"] = lyrics
356
+ extras["resolved_duration"] = duration
357
 
358
+ return {"params": params, "config": config}, extras
359
+
360
+ def _call_model(self, req: Dict[str, Any]) -> Tuple[np.ndarray, int, Dict[str, Any]]:
361
+ if not self.model_loaded or self.dit_handler is None:
362
+ raise RuntimeError(self.model_error or "Model is not loaded")
363
+
364
+ llm_ready = False
365
+ if req["use_lm"]:
366
+ llm_ready = self._ensure_llm_initialized()
367
+
368
+ generation_inputs, extras = self._build_generation_inputs(req, llm_ready)
369
+
370
+ result = self._generate_music(
371
+ self.dit_handler,
372
+ self.llm_handler if llm_ready else None,
373
+ generation_inputs["params"],
374
+ generation_inputs["config"],
375
+ save_dir=None,
376
+ progress=None,
377
+ )
378
 
379
+ if not getattr(result, "success", False):
380
+ raise RuntimeError(getattr(result, "error", "generation failed"))
381
+
382
+ audios = getattr(result, "audios", None) or []
383
+ if not audios:
384
+ raise RuntimeError("generation succeeded but no audio was returned")
385
+
386
+ first = audios[0]
387
+ audio_tensor = first.get("tensor") if isinstance(first, dict) else None
388
+ if audio_tensor is None:
389
+ raise RuntimeError("generated audio tensor is missing")
390
+
391
+ sample_rate = int(first.get("sample_rate", req["sample_rate"]))
392
+ status_message = getattr(result, "status_message", "")
393
+
394
+ meta = {
395
+ "llm_requested": req["use_lm"],
396
+ "llm_initialized": llm_ready,
397
+ "llm_error": self.llm_error,
398
+ "status_message": status_message,
399
+ }
400
+ meta.update(extras)
401
+
402
+ return self._as_float32(audio_tensor), sample_rate, meta
403
+
404
+ # --------------------------
405
+ # Endpoint entry
406
+ # --------------------------
407
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
408
  try:
409
  req = self._normalize_request(data)
410
 
 
 
 
 
 
 
 
411
  used_fallback = False
412
+ runtime_meta: Dict[str, Any] = {}
413
+
414
+ try:
415
+ audio, out_sr, runtime_meta = self._call_model(req)
416
+ except Exception as model_exc:
417
+ self.model_error = f"Inference failed: {type(model_exc).__name__}: {model_exc}"
418
+ if not req["allow_fallback"]:
419
+ raise RuntimeError(self.model_error)
420
 
 
 
 
 
 
 
 
 
 
 
 
 
421
  used_fallback = True
422
  audio = self._fallback_sine(req["duration_sec"], req["sample_rate"], req["seed"])
423
  out_sr = req["sample_rate"]
 
427
  "sample_rate": int(out_sr),
428
  "duration_sec": int(req["duration_sec"]),
429
  "used_fallback": used_fallback,
430
+ "model_loaded": self.model_loaded,
431
  "model_repo": self.model_repo,
432
  "model_error": self.model_error,
433
  "meta": {
 
441
  "use_lm": req["use_lm"],
442
  "simple_prompt": req["simple_prompt"],
443
  "instrumental": req["instrumental"],
444
+ "allow_fallback": req["allow_fallback"],
445
+ "resolved_prompt": runtime_meta.get("resolved_prompt", req["prompt"]),
446
+ "resolved_lyrics": runtime_meta.get("resolved_lyrics", req["lyrics"]),
447
+ "simple_expansion_used": runtime_meta.get("simple_expansion_used", False),
448
+ "simple_expansion_error": runtime_meta.get("simple_expansion_error"),
449
+ "llm_requested": runtime_meta.get("llm_requested", False),
450
+ "llm_initialized": runtime_meta.get("llm_initialized", False),
451
+ "llm_error": runtime_meta.get("llm_error"),
452
+ "status_message": runtime_meta.get("status_message", ""),
453
+ "init_details": self.init_details,
454
  },
455
  }
456
 
457
  except Exception as e:
458
  return {
459
  "error": f"{type(e).__name__}: {e}",
460
+ "traceback": traceback.format_exc(limit=4),
461
  "audio_base64_wav": None,
462
  "sample_rate": None,
463
  "duration_sec": None,
464
+ "used_fallback": False,
465
+ "model_loaded": self.model_loaded,
466
+ "model_repo": self.model_repo,
467
+ "model_error": self.model_error,
468
+ "meta": {
469
+ "device": self.device,
470
+ "dtype": self.dtype,
471
+ "init_details": self.init_details,
472
+ "llm_error": self.llm_error,
473
+ },
474
+ }
requirements.txt CHANGED
@@ -2,7 +2,13 @@ numpy
2
  soundfile
3
  torch
4
  torchaudio
5
- transformers
6
  accelerate
7
  huggingface_hub
8
- git+https://github.com/ace-step/ACE-Step-1.5.git@f30aee4c186c33b7b8a6ea59a4b7fc36f795b49f
 
 
 
 
 
 
 
2
  soundfile
3
  torch
4
  torchaudio
5
+ transformers>=4.51.0,<4.58.0
6
  accelerate
7
  huggingface_hub
8
+ diffusers
9
+ loguru
10
+ tqdm
11
+ numba>=0.63.1
12
+ PyYAML
13
+ modelscope
14
+ filelock>=3.13.0